diff options
-rw-r--r-- | pkg/tcpip/stack/nic.go | 80 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 16 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 39 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint_state.go | 12 | ||||
-rw-r--r-- | test/syscalls/linux/socket_ipv4_udp_unbound.cc | 219 |
6 files changed, 335 insertions, 37 deletions
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 8008d9870..a4117d98e 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -42,6 +42,7 @@ type NIC struct { primary map[tcpip.NetworkProtocolNumber]*ilist.List endpoints map[NetworkEndpointID]*referencedNetworkEndpoint subnets []tcpip.Subnet + mcastJoins map[NetworkEndpointID]int32 stats NICStats } @@ -79,14 +80,15 @@ const ( func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback bool) *NIC { return &NIC{ - stack: stack, - id: id, - name: name, - linkEP: ep, - loopback: loopback, - demux: newTransportDemuxer(stack), - primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List), - endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint), + stack: stack, + id: id, + name: name, + linkEP: ep, + loopback: loopback, + demux: newTransportDemuxer(stack), + primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List), + endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint), + mcastJoins: make(map[NetworkEndpointID]int32), stats: NICStats{ Tx: DirectionStats{ Packets: &tcpip.StatCounter{}, @@ -384,20 +386,62 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) { n.mu.Unlock() } -// RemoveAddress removes an address from n. -func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error { - n.mu.Lock() +func (n *NIC) removeAddressLocked(addr tcpip.Address) *tcpip.Error { r := n.endpoints[NetworkEndpointID{addr}] if r == nil || !r.holdsInsertRef { - n.mu.Unlock() return tcpip.ErrBadLocalAddress } r.holdsInsertRef = false - n.mu.Unlock() - r.decRef() + r.decRefLocked() + + return nil +} + +// RemoveAddress removes an address from n. +func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error { + n.mu.Lock() + defer n.mu.Unlock() + return n.removeAddressLocked(addr) +} + +// joinGroup adds a new endpoint for the given multicast address, if none +// exists yet. Otherwise it just increments its count. +func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { + n.mu.Lock() + defer n.mu.Unlock() + id := NetworkEndpointID{addr} + joins := n.mcastJoins[id] + if joins == 0 { + if _, err := n.addAddressLocked(protocol, addr, NeverPrimaryEndpoint, false); err != nil { + return err + } + } + n.mcastJoins[id] = joins + 1 + return nil +} + +// leaveGroup decrements the count for the given multicast address, and when it +// reaches zero removes the endpoint for this address. +func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { + n.mu.Lock() + defer n.mu.Unlock() + + id := NetworkEndpointID{addr} + joins := n.mcastJoins[id] + switch joins { + case 0: + // There are no joins with this address on this NIC. + return tcpip.ErrBadLocalAddress + case 1: + // This is the last one, clean up. + if err := n.removeAddressLocked(addr); err != nil { + return err + } + } + n.mcastJoins[id] = joins - 1 return nil } @@ -644,6 +688,14 @@ func (r *referencedNetworkEndpoint) decRef() { } } +// decRefLocked is the same as decRef but assumes that the NIC.mu mutex is +// locked. +func (r *referencedNetworkEndpoint) decRefLocked() { + if atomic.AddInt32(&r.refs, -1) == 0 { + r.nic.removeEndpointLocked(r) + } +} + // incRef increments the ref count. It must only be called when the caller is // known to be holding a reference to the endpoint, otherwise tryIncRef should // be used. diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index f204ca790..c82822ee2 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1062,10 +1062,22 @@ func (s *Stack) RemoveTCPProbe() { // JoinGroup joins the given multicast group on the given NIC. func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error { // TODO: notify network of subscription via igmp protocol. - return s.AddAddressWithOptions(nicID, protocol, multicastAddr, NeverPrimaryEndpoint) + s.mu.RLock() + defer s.mu.RUnlock() + + if nic, ok := s.nics[nicID]; ok { + return nic.joinGroup(protocol, multicastAddr) + } + return tcpip.ErrUnknownNICID } // LeaveGroup leaves the given multicast group on the given NIC. func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error { - return s.RemoveAddress(nicID, multicastAddr) + s.mu.RLock() + defer s.mu.RUnlock() + + if nic, ok := s.nics[nicID]; ok { + return nic.leaveGroup(multicastAddr) + } + return tcpip.ErrUnknownNICID } diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index e8b562ad9..66c564613 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -141,9 +141,9 @@ func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEnd // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { - // If this is a broadcast datagram, deliver the datagram to all endpoints - // managed by ep. - if id.LocalAddress == header.IPv4Broadcast { + // If this is a broadcast or multicast datagram, deliver the datagram to all + // endpoints managed by ep. + if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) { for i, endpoint := range ep.endpointsArr { // HandlePacket modifies vv, so each endpoint needs its own copy. if i == len(ep.endpointsArr)-1 { diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index db65a4e88..0ed0902b0 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -458,14 +458,22 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrUnknownDevice } - if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil { - return err - } + memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr} e.mu.Lock() defer e.mu.Unlock() - e.multicastMemberships = append(e.multicastMemberships, multicastMembership{nicID, v.MulticastAddr}) + for _, mem := range e.multicastMemberships { + if mem == memToInsert { + return tcpip.ErrPortInUse + } + } + + if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil { + return err + } + + e.multicastMemberships = append(e.multicastMemberships, memToInsert) case tcpip.RemoveMembershipOption: if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) { @@ -488,21 +496,28 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrUnknownDevice } - if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil { - return err - } + memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr} + memToRemoveIndex := -1 e.mu.Lock() defer e.mu.Unlock() + for i, mem := range e.multicastMemberships { - if mem.nicID == nicID && mem.multicastAddr == v.MulticastAddr { - // Only remove the first match, so that each added membership above is - // paired with exactly 1 removal. - e.multicastMemberships[i] = e.multicastMemberships[len(e.multicastMemberships)-1] - e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1] + if mem == memToRemove { + memToRemoveIndex = i break } } + if memToRemoveIndex == -1 { + return tcpip.ErrBadLocalAddress + } + + if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil { + return err + } + + e.multicastMemberships[memToRemoveIndex] = e.multicastMemberships[len(e.multicastMemberships)-1] + e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1] case tcpip.MulticastLoopOption: e.mu.Lock() diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 163dcbc13..74e8e9fd5 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -66,6 +66,12 @@ func (e *endpoint) loadRcvBufSizeMax(max int) { func (e *endpoint) afterLoad() { e.stack = stack.StackFromEnv + for _, m := range e.multicastMemberships { + if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil { + panic(err) + } + } + if e.state != stateBound && e.state != stateConnected { return } @@ -103,10 +109,4 @@ func (e *endpoint) afterLoad() { if err != nil { panic(*err) } - - for _, m := range e.multicastMemberships { - if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil { - panic(err) - } - } } diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc index 709172580..0ec828d8d 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc @@ -1207,5 +1207,224 @@ TEST_P(IPv4UDPUnboundSocketPairTest, TestJoinGroupInvalidIf) { SyscallFailsWithErrno(ENODEV)); } +// Check that multiple memberships are not allowed on the same socket. +TEST_P(IPv4UDPUnboundSocketPairTest, TestMultipleJoinsOnSingleSocket) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + auto fd = sockets->first_fd(); + ip_mreqn group = {}; + group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); + group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); + + EXPECT_THAT( + setsockopt(fd, IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, sizeof(group)), + SyscallSucceeds()); + + EXPECT_THAT( + setsockopt(fd, IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, sizeof(group)), + SyscallFailsWithErrno(EADDRINUSE)); +} + +// Check that two sockets can join the same multicast group at the same time. +TEST_P(IPv4UDPUnboundSocketPairTest, TestTwoSocketsJoinSameMulticastGroup) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + ip_mreqn group = {}; + group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); + group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, + &group, sizeof(group)), + SyscallSucceeds()); + EXPECT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, + &group, sizeof(group)), + SyscallSucceeds()); + + // Drop the membership twice on each socket, the second call for each socket + // should fail. + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_DROP_MEMBERSHIP, + &group, sizeof(group)), + SyscallSucceeds()); + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_DROP_MEMBERSHIP, + &group, sizeof(group)), + SyscallFailsWithErrno(EADDRNOTAVAIL)); + EXPECT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_DROP_MEMBERSHIP, + &group, sizeof(group)), + SyscallSucceeds()); + EXPECT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_DROP_MEMBERSHIP, + &group, sizeof(group)), + SyscallFailsWithErrno(EADDRNOTAVAIL)); +} + +// Check that two sockets can join the same multicast group at the same time, +// and both will receive data on it. +TEST_P(IPv4UDPUnboundSocketPairTest, TestMcastReceptionOnTwoSockets) { + std::unique_ptr<SocketPair> socket_pairs[2] = { + ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()), + ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair())}; + + ip_mreq iface = {}, group = {}; + iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); + group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); + group.imr_interface.s_addr = htonl(INADDR_LOOPBACK); + auto receiver_addr = V4Any(); + int bound_port = 0; + + // Create two socketpairs with the exact same configuration. + for (auto& sockets : socket_pairs) { + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, + &iface, sizeof(iface)), + SyscallSucceeds()); + ASSERT_THAT(setsockopt(sockets->second_fd(), SOL_SOCKET, SO_REUSEPORT, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, + &group, sizeof(group)), + SyscallSucceeds()); + ASSERT_THAT(bind(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + // Get the port assigned. + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + // On the first iteration, save the port we are bound to. On the second + // iteration, verify the port is the same as the one from the first + // iteration. In other words, both sockets listen on the same port. + if (bound_port == 0) { + bound_port = + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; + } else { + EXPECT_EQ(bound_port, + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port); + } + } + + // Send a multicast packet to the group from two different sockets and verify + // it is received by both sockets that joined that group. + auto send_addr = V4Multicast(); + reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = bound_port; + for (auto& sockets : socket_pairs) { + char send_buf[200]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + ASSERT_THAT( + RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&send_addr.addr), + send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Check that we received the multicast packet on both sockets. + for (auto& sockets : socket_pairs) { + char recv_buf[sizeof(send_buf)] = {}; + ASSERT_THAT( + RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), 0), + SyscallSucceedsWithValue(sizeof(recv_buf))); + EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); + } + } +} + +// Check that on two sockets that joined a group and listen on ANY, dropping +// memberships one by one will continue to deliver packets to both sockets until +// both memberships have been dropped. +TEST_P(IPv4UDPUnboundSocketPairTest, + TestMcastReceptionWhenDroppingMemberships) { + std::unique_ptr<SocketPair> socket_pairs[2] = { + ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()), + ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair())}; + + ip_mreq iface = {}, group = {}; + iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); + group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); + group.imr_interface.s_addr = htonl(INADDR_LOOPBACK); + auto receiver_addr = V4Any(); + int bound_port = 0; + + // Create two socketpairs with the exact same configuration. + for (auto& sockets : socket_pairs) { + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, + &iface, sizeof(iface)), + SyscallSucceeds()); + ASSERT_THAT(setsockopt(sockets->second_fd(), SOL_SOCKET, SO_REUSEPORT, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, + &group, sizeof(group)), + SyscallSucceeds()); + ASSERT_THAT(bind(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + // Get the port assigned. + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + // On the first iteration, save the port we are bound to. On the second + // iteration, verify the port is the same as the one from the first + // iteration. In other words, both sockets listen on the same port. + if (bound_port == 0) { + bound_port = + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; + } else { + EXPECT_EQ(bound_port, + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port); + } + } + + // Drop the membership of the first socket pair and verify data is still + // received. + ASSERT_THAT(setsockopt(socket_pairs[0]->second_fd(), IPPROTO_IP, + IP_DROP_MEMBERSHIP, &group, sizeof(group)), + SyscallSucceeds()); + // Send a packet from each socket_pair. + auto send_addr = V4Multicast(); + reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = bound_port; + for (auto& sockets : socket_pairs) { + char send_buf[200]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + ASSERT_THAT( + RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&send_addr.addr), + send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Check that we received the multicast packet on both sockets. + for (auto& sockets : socket_pairs) { + char recv_buf[sizeof(send_buf)] = {}; + ASSERT_THAT( + RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), 0), + SyscallSucceedsWithValue(sizeof(recv_buf))); + EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); + } + } + + // Drop the membership of the second socket pair and verify data stops being + // received. + ASSERT_THAT(setsockopt(socket_pairs[1]->second_fd(), IPPROTO_IP, + IP_DROP_MEMBERSHIP, &group, sizeof(group)), + SyscallSucceeds()); + // Send a packet from each socket_pair. + for (auto& sockets : socket_pairs) { + char send_buf[200]; + ASSERT_THAT( + RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&send_addr.addr), + send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + + char recv_buf[sizeof(send_buf)] = {}; + for (auto& sockets : socket_pairs) { + ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, + sizeof(recv_buf), MSG_DONTWAIT), + SyscallFailsWithErrno(EAGAIN)); + } + } +} + } // namespace testing } // namespace gvisor |