summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/stack/nic.go80
-rw-r--r--pkg/tcpip/stack/stack.go16
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go6
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go39
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go12
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound.cc219
6 files changed, 335 insertions, 37 deletions
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 8008d9870..a4117d98e 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -42,6 +42,7 @@ type NIC struct {
primary map[tcpip.NetworkProtocolNumber]*ilist.List
endpoints map[NetworkEndpointID]*referencedNetworkEndpoint
subnets []tcpip.Subnet
+ mcastJoins map[NetworkEndpointID]int32
stats NICStats
}
@@ -79,14 +80,15 @@ const (
func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback bool) *NIC {
return &NIC{
- stack: stack,
- id: id,
- name: name,
- linkEP: ep,
- loopback: loopback,
- demux: newTransportDemuxer(stack),
- primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List),
- endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint),
+ stack: stack,
+ id: id,
+ name: name,
+ linkEP: ep,
+ loopback: loopback,
+ demux: newTransportDemuxer(stack),
+ primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List),
+ endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint),
+ mcastJoins: make(map[NetworkEndpointID]int32),
stats: NICStats{
Tx: DirectionStats{
Packets: &tcpip.StatCounter{},
@@ -384,20 +386,62 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) {
n.mu.Unlock()
}
-// RemoveAddress removes an address from n.
-func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error {
- n.mu.Lock()
+func (n *NIC) removeAddressLocked(addr tcpip.Address) *tcpip.Error {
r := n.endpoints[NetworkEndpointID{addr}]
if r == nil || !r.holdsInsertRef {
- n.mu.Unlock()
return tcpip.ErrBadLocalAddress
}
r.holdsInsertRef = false
- n.mu.Unlock()
- r.decRef()
+ r.decRefLocked()
+
+ return nil
+}
+
+// RemoveAddress removes an address from n.
+func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+ return n.removeAddressLocked(addr)
+}
+
+// joinGroup adds a new endpoint for the given multicast address, if none
+// exists yet. Otherwise it just increments its count.
+func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+ id := NetworkEndpointID{addr}
+ joins := n.mcastJoins[id]
+ if joins == 0 {
+ if _, err := n.addAddressLocked(protocol, addr, NeverPrimaryEndpoint, false); err != nil {
+ return err
+ }
+ }
+ n.mcastJoins[id] = joins + 1
+ return nil
+}
+
+// leaveGroup decrements the count for the given multicast address, and when it
+// reaches zero removes the endpoint for this address.
+func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ id := NetworkEndpointID{addr}
+ joins := n.mcastJoins[id]
+ switch joins {
+ case 0:
+ // There are no joins with this address on this NIC.
+ return tcpip.ErrBadLocalAddress
+ case 1:
+ // This is the last one, clean up.
+ if err := n.removeAddressLocked(addr); err != nil {
+ return err
+ }
+ }
+ n.mcastJoins[id] = joins - 1
return nil
}
@@ -644,6 +688,14 @@ func (r *referencedNetworkEndpoint) decRef() {
}
}
+// decRefLocked is the same as decRef but assumes that the NIC.mu mutex is
+// locked.
+func (r *referencedNetworkEndpoint) decRefLocked() {
+ if atomic.AddInt32(&r.refs, -1) == 0 {
+ r.nic.removeEndpointLocked(r)
+ }
+}
+
// incRef increments the ref count. It must only be called when the caller is
// known to be holding a reference to the endpoint, otherwise tryIncRef should
// be used.
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index f204ca790..c82822ee2 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -1062,10 +1062,22 @@ func (s *Stack) RemoveTCPProbe() {
// JoinGroup joins the given multicast group on the given NIC.
func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error {
// TODO: notify network of subscription via igmp protocol.
- return s.AddAddressWithOptions(nicID, protocol, multicastAddr, NeverPrimaryEndpoint)
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nic, ok := s.nics[nicID]; ok {
+ return nic.joinGroup(protocol, multicastAddr)
+ }
+ return tcpip.ErrUnknownNICID
}
// LeaveGroup leaves the given multicast group on the given NIC.
func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error {
- return s.RemoveAddress(nicID, multicastAddr)
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nic, ok := s.nics[nicID]; ok {
+ return nic.leaveGroup(multicastAddr)
+ }
+ return tcpip.ErrUnknownNICID
}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index e8b562ad9..66c564613 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -141,9 +141,9 @@ func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEnd
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
- // If this is a broadcast datagram, deliver the datagram to all endpoints
- // managed by ep.
- if id.LocalAddress == header.IPv4Broadcast {
+ // If this is a broadcast or multicast datagram, deliver the datagram to all
+ // endpoints managed by ep.
+ if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) {
for i, endpoint := range ep.endpointsArr {
// HandlePacket modifies vv, so each endpoint needs its own copy.
if i == len(ep.endpointsArr)-1 {
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index db65a4e88..0ed0902b0 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -458,14 +458,22 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrUnknownDevice
}
- if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
- return err
- }
+ memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
e.mu.Lock()
defer e.mu.Unlock()
- e.multicastMemberships = append(e.multicastMemberships, multicastMembership{nicID, v.MulticastAddr})
+ for _, mem := range e.multicastMemberships {
+ if mem == memToInsert {
+ return tcpip.ErrPortInUse
+ }
+ }
+
+ if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
+ return err
+ }
+
+ e.multicastMemberships = append(e.multicastMemberships, memToInsert)
case tcpip.RemoveMembershipOption:
if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
@@ -488,21 +496,28 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrUnknownDevice
}
- if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
- return err
- }
+ memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
+ memToRemoveIndex := -1
e.mu.Lock()
defer e.mu.Unlock()
+
for i, mem := range e.multicastMemberships {
- if mem.nicID == nicID && mem.multicastAddr == v.MulticastAddr {
- // Only remove the first match, so that each added membership above is
- // paired with exactly 1 removal.
- e.multicastMemberships[i] = e.multicastMemberships[len(e.multicastMemberships)-1]
- e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1]
+ if mem == memToRemove {
+ memToRemoveIndex = i
break
}
}
+ if memToRemoveIndex == -1 {
+ return tcpip.ErrBadLocalAddress
+ }
+
+ if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
+ return err
+ }
+
+ e.multicastMemberships[memToRemoveIndex] = e.multicastMemberships[len(e.multicastMemberships)-1]
+ e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1]
case tcpip.MulticastLoopOption:
e.mu.Lock()
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 163dcbc13..74e8e9fd5 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -66,6 +66,12 @@ func (e *endpoint) loadRcvBufSizeMax(max int) {
func (e *endpoint) afterLoad() {
e.stack = stack.StackFromEnv
+ for _, m := range e.multicastMemberships {
+ if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil {
+ panic(err)
+ }
+ }
+
if e.state != stateBound && e.state != stateConnected {
return
}
@@ -103,10 +109,4 @@ func (e *endpoint) afterLoad() {
if err != nil {
panic(*err)
}
-
- for _, m := range e.multicastMemberships {
- if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil {
- panic(err)
- }
- }
}
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc
index 709172580..0ec828d8d 100644
--- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc
@@ -1207,5 +1207,224 @@ TEST_P(IPv4UDPUnboundSocketPairTest, TestJoinGroupInvalidIf) {
SyscallFailsWithErrno(ENODEV));
}
+// Check that multiple memberships are not allowed on the same socket.
+TEST_P(IPv4UDPUnboundSocketPairTest, TestMultipleJoinsOnSingleSocket) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ auto fd = sockets->first_fd();
+ ip_mreqn group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
+
+ EXPECT_THAT(
+ setsockopt(fd, IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, sizeof(group)),
+ SyscallSucceeds());
+
+ EXPECT_THAT(
+ setsockopt(fd, IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, sizeof(group)),
+ SyscallFailsWithErrno(EADDRINUSE));
+}
+
+// Check that two sockets can join the same multicast group at the same time.
+TEST_P(IPv4UDPUnboundSocketPairTest, TestTwoSocketsJoinSameMulticastGroup) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ip_mreqn group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
+ &group, sizeof(group)),
+ SyscallSucceeds());
+ EXPECT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
+ &group, sizeof(group)),
+ SyscallSucceeds());
+
+ // Drop the membership twice on each socket, the second call for each socket
+ // should fail.
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_DROP_MEMBERSHIP,
+ &group, sizeof(group)),
+ SyscallSucceeds());
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_DROP_MEMBERSHIP,
+ &group, sizeof(group)),
+ SyscallFailsWithErrno(EADDRNOTAVAIL));
+ EXPECT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_DROP_MEMBERSHIP,
+ &group, sizeof(group)),
+ SyscallSucceeds());
+ EXPECT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_DROP_MEMBERSHIP,
+ &group, sizeof(group)),
+ SyscallFailsWithErrno(EADDRNOTAVAIL));
+}
+
+// Check that two sockets can join the same multicast group at the same time,
+// and both will receive data on it.
+TEST_P(IPv4UDPUnboundSocketPairTest, TestMcastReceptionOnTwoSockets) {
+ std::unique_ptr<SocketPair> socket_pairs[2] = {
+ ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()),
+ ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair())};
+
+ ip_mreq iface = {}, group = {};
+ iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ group.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
+ auto receiver_addr = V4Any();
+ int bound_port = 0;
+
+ // Create two socketpairs with the exact same configuration.
+ for (auto& sockets : socket_pairs) {
+ ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF,
+ &iface, sizeof(iface)),
+ SyscallSucceeds());
+ ASSERT_THAT(setsockopt(sockets->second_fd(), SOL_SOCKET, SO_REUSEPORT,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
+ &group, sizeof(group)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(sockets->second_fd(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ // Get the port assigned.
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(sockets->second_fd(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+ // On the first iteration, save the port we are bound to. On the second
+ // iteration, verify the port is the same as the one from the first
+ // iteration. In other words, both sockets listen on the same port.
+ if (bound_port == 0) {
+ bound_port =
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
+ } else {
+ EXPECT_EQ(bound_port,
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port);
+ }
+ }
+
+ // Send a multicast packet to the group from two different sockets and verify
+ // it is received by both sockets that joined that group.
+ auto send_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = bound_port;
+ for (auto& sockets : socket_pairs) {
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ ASSERT_THAT(
+ RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Check that we received the multicast packet on both sockets.
+ for (auto& sockets : socket_pairs) {
+ char recv_buf[sizeof(send_buf)] = {};
+ ASSERT_THAT(
+ RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), 0),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
+ EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
+ }
+ }
+}
+
+// Check that on two sockets that joined a group and listen on ANY, dropping
+// memberships one by one will continue to deliver packets to both sockets until
+// both memberships have been dropped.
+TEST_P(IPv4UDPUnboundSocketPairTest,
+ TestMcastReceptionWhenDroppingMemberships) {
+ std::unique_ptr<SocketPair> socket_pairs[2] = {
+ ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()),
+ ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair())};
+
+ ip_mreq iface = {}, group = {};
+ iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ group.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
+ auto receiver_addr = V4Any();
+ int bound_port = 0;
+
+ // Create two socketpairs with the exact same configuration.
+ for (auto& sockets : socket_pairs) {
+ ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF,
+ &iface, sizeof(iface)),
+ SyscallSucceeds());
+ ASSERT_THAT(setsockopt(sockets->second_fd(), SOL_SOCKET, SO_REUSEPORT,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
+ &group, sizeof(group)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(sockets->second_fd(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ // Get the port assigned.
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(sockets->second_fd(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+ // On the first iteration, save the port we are bound to. On the second
+ // iteration, verify the port is the same as the one from the first
+ // iteration. In other words, both sockets listen on the same port.
+ if (bound_port == 0) {
+ bound_port =
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
+ } else {
+ EXPECT_EQ(bound_port,
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port);
+ }
+ }
+
+ // Drop the membership of the first socket pair and verify data is still
+ // received.
+ ASSERT_THAT(setsockopt(socket_pairs[0]->second_fd(), IPPROTO_IP,
+ IP_DROP_MEMBERSHIP, &group, sizeof(group)),
+ SyscallSucceeds());
+ // Send a packet from each socket_pair.
+ auto send_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = bound_port;
+ for (auto& sockets : socket_pairs) {
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ ASSERT_THAT(
+ RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Check that we received the multicast packet on both sockets.
+ for (auto& sockets : socket_pairs) {
+ char recv_buf[sizeof(send_buf)] = {};
+ ASSERT_THAT(
+ RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), 0),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
+ EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
+ }
+ }
+
+ // Drop the membership of the second socket pair and verify data stops being
+ // received.
+ ASSERT_THAT(setsockopt(socket_pairs[1]->second_fd(), IPPROTO_IP,
+ IP_DROP_MEMBERSHIP, &group, sizeof(group)),
+ SyscallSucceeds());
+ // Send a packet from each socket_pair.
+ for (auto& sockets : socket_pairs) {
+ char send_buf[200];
+ ASSERT_THAT(
+ RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ char recv_buf[sizeof(send_buf)] = {};
+ for (auto& sockets : socket_pairs) {
+ ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf,
+ sizeof(recv_buf), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+ }
+ }
+}
+
} // namespace testing
} // namespace gvisor