From 8972e47a2edb01d66c2fc6373a5663b68e3da82c Mon Sep 17 00:00:00 2001 From: Chris Kuiper Date: Thu, 2 May 2019 19:39:55 -0700 Subject: Support reception of multicast data on more than one socket This requires two changes: 1) Support for more than one socket to join a given multicast group. 2) Duplicate delivery of incoming multicast packets to all sockets listening for it. In addition, I tweaked the code (and added a test) to disallow duplicates IP_ADD_MEMBERSHIP calls for the same group and NIC. This is how Linux does it. PiperOrigin-RevId: 246437315 Change-Id: Icad8300b4a8c3f501d9b4cd283bd3beabef88b72 --- pkg/tcpip/stack/nic.go | 80 +++++++++++++++++++++++++------ pkg/tcpip/stack/stack.go | 16 ++++++- pkg/tcpip/stack/transport_demuxer.go | 6 +-- pkg/tcpip/transport/udp/endpoint.go | 39 ++++++++++----- pkg/tcpip/transport/udp/endpoint_state.go | 12 ++--- 5 files changed, 116 insertions(+), 37 deletions(-) (limited to 'pkg') diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 8008d9870..a4117d98e 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -42,6 +42,7 @@ type NIC struct { primary map[tcpip.NetworkProtocolNumber]*ilist.List endpoints map[NetworkEndpointID]*referencedNetworkEndpoint subnets []tcpip.Subnet + mcastJoins map[NetworkEndpointID]int32 stats NICStats } @@ -79,14 +80,15 @@ const ( func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback bool) *NIC { return &NIC{ - stack: stack, - id: id, - name: name, - linkEP: ep, - loopback: loopback, - demux: newTransportDemuxer(stack), - primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List), - endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint), + stack: stack, + id: id, + name: name, + linkEP: ep, + loopback: loopback, + demux: newTransportDemuxer(stack), + primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List), + endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint), + mcastJoins: make(map[NetworkEndpointID]int32), stats: NICStats{ Tx: DirectionStats{ Packets: &tcpip.StatCounter{}, @@ -384,20 +386,62 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) { n.mu.Unlock() } -// RemoveAddress removes an address from n. -func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error { - n.mu.Lock() +func (n *NIC) removeAddressLocked(addr tcpip.Address) *tcpip.Error { r := n.endpoints[NetworkEndpointID{addr}] if r == nil || !r.holdsInsertRef { - n.mu.Unlock() return tcpip.ErrBadLocalAddress } r.holdsInsertRef = false - n.mu.Unlock() - r.decRef() + r.decRefLocked() + + return nil +} + +// RemoveAddress removes an address from n. +func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error { + n.mu.Lock() + defer n.mu.Unlock() + return n.removeAddressLocked(addr) +} + +// joinGroup adds a new endpoint for the given multicast address, if none +// exists yet. Otherwise it just increments its count. +func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { + n.mu.Lock() + defer n.mu.Unlock() + id := NetworkEndpointID{addr} + joins := n.mcastJoins[id] + if joins == 0 { + if _, err := n.addAddressLocked(protocol, addr, NeverPrimaryEndpoint, false); err != nil { + return err + } + } + n.mcastJoins[id] = joins + 1 + return nil +} + +// leaveGroup decrements the count for the given multicast address, and when it +// reaches zero removes the endpoint for this address. +func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { + n.mu.Lock() + defer n.mu.Unlock() + + id := NetworkEndpointID{addr} + joins := n.mcastJoins[id] + switch joins { + case 0: + // There are no joins with this address on this NIC. + return tcpip.ErrBadLocalAddress + case 1: + // This is the last one, clean up. + if err := n.removeAddressLocked(addr); err != nil { + return err + } + } + n.mcastJoins[id] = joins - 1 return nil } @@ -644,6 +688,14 @@ func (r *referencedNetworkEndpoint) decRef() { } } +// decRefLocked is the same as decRef but assumes that the NIC.mu mutex is +// locked. +func (r *referencedNetworkEndpoint) decRefLocked() { + if atomic.AddInt32(&r.refs, -1) == 0 { + r.nic.removeEndpointLocked(r) + } +} + // incRef increments the ref count. It must only be called when the caller is // known to be holding a reference to the endpoint, otherwise tryIncRef should // be used. diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index f204ca790..c82822ee2 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1062,10 +1062,22 @@ func (s *Stack) RemoveTCPProbe() { // JoinGroup joins the given multicast group on the given NIC. func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error { // TODO: notify network of subscription via igmp protocol. - return s.AddAddressWithOptions(nicID, protocol, multicastAddr, NeverPrimaryEndpoint) + s.mu.RLock() + defer s.mu.RUnlock() + + if nic, ok := s.nics[nicID]; ok { + return nic.joinGroup(protocol, multicastAddr) + } + return tcpip.ErrUnknownNICID } // LeaveGroup leaves the given multicast group on the given NIC. func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error { - return s.RemoveAddress(nicID, multicastAddr) + s.mu.RLock() + defer s.mu.RUnlock() + + if nic, ok := s.nics[nicID]; ok { + return nic.leaveGroup(multicastAddr) + } + return tcpip.ErrUnknownNICID } diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index e8b562ad9..66c564613 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -141,9 +141,9 @@ func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEnd // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { - // If this is a broadcast datagram, deliver the datagram to all endpoints - // managed by ep. - if id.LocalAddress == header.IPv4Broadcast { + // If this is a broadcast or multicast datagram, deliver the datagram to all + // endpoints managed by ep. + if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) { for i, endpoint := range ep.endpointsArr { // HandlePacket modifies vv, so each endpoint needs its own copy. if i == len(ep.endpointsArr)-1 { diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index db65a4e88..0ed0902b0 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -458,14 +458,22 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrUnknownDevice } - if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil { - return err - } + memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr} e.mu.Lock() defer e.mu.Unlock() - e.multicastMemberships = append(e.multicastMemberships, multicastMembership{nicID, v.MulticastAddr}) + for _, mem := range e.multicastMemberships { + if mem == memToInsert { + return tcpip.ErrPortInUse + } + } + + if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil { + return err + } + + e.multicastMemberships = append(e.multicastMemberships, memToInsert) case tcpip.RemoveMembershipOption: if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) { @@ -488,21 +496,28 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrUnknownDevice } - if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil { - return err - } + memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr} + memToRemoveIndex := -1 e.mu.Lock() defer e.mu.Unlock() + for i, mem := range e.multicastMemberships { - if mem.nicID == nicID && mem.multicastAddr == v.MulticastAddr { - // Only remove the first match, so that each added membership above is - // paired with exactly 1 removal. - e.multicastMemberships[i] = e.multicastMemberships[len(e.multicastMemberships)-1] - e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1] + if mem == memToRemove { + memToRemoveIndex = i break } } + if memToRemoveIndex == -1 { + return tcpip.ErrBadLocalAddress + } + + if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil { + return err + } + + e.multicastMemberships[memToRemoveIndex] = e.multicastMemberships[len(e.multicastMemberships)-1] + e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1] case tcpip.MulticastLoopOption: e.mu.Lock() diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 163dcbc13..74e8e9fd5 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -66,6 +66,12 @@ func (e *endpoint) loadRcvBufSizeMax(max int) { func (e *endpoint) afterLoad() { e.stack = stack.StackFromEnv + for _, m := range e.multicastMemberships { + if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil { + panic(err) + } + } + if e.state != stateBound && e.state != stateConnected { return } @@ -103,10 +109,4 @@ func (e *endpoint) afterLoad() { if err != nil { panic(*err) } - - for _, m := range e.multicastMemberships { - if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil { - panic(err) - } - } } -- cgit v1.2.3