diff options
author | Chris Kuiper <ckuiper@google.com> | 2019-05-02 19:39:55 -0700 |
---|---|---|
committer | Shentubot <shentubot@google.com> | 2019-05-02 19:41:00 -0700 |
commit | 8972e47a2edb01d66c2fc6373a5663b68e3da82c (patch) | |
tree | 28195396bd4e663cfd2886ffbfd93f8d66af24dd /pkg/tcpip/stack | |
parent | 5f8225c009fcf297139c54c7b329da4aff679ece (diff) |
Support reception of multicast data on more than one socket
This requires two changes:
1) Support for more than one socket to join a given multicast group.
2) Duplicate delivery of incoming multicast packets to all sockets listening
for it.
In addition, I tweaked the code (and added a test) to disallow duplicates
IP_ADD_MEMBERSHIP calls for the same group and NIC. This is how Linux does
it.
PiperOrigin-RevId: 246437315
Change-Id: Icad8300b4a8c3f501d9b4cd283bd3beabef88b72
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/nic.go | 80 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 16 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 6 |
3 files changed, 83 insertions, 19 deletions
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 8008d9870..a4117d98e 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -42,6 +42,7 @@ type NIC struct { primary map[tcpip.NetworkProtocolNumber]*ilist.List endpoints map[NetworkEndpointID]*referencedNetworkEndpoint subnets []tcpip.Subnet + mcastJoins map[NetworkEndpointID]int32 stats NICStats } @@ -79,14 +80,15 @@ const ( func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback bool) *NIC { return &NIC{ - stack: stack, - id: id, - name: name, - linkEP: ep, - loopback: loopback, - demux: newTransportDemuxer(stack), - primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List), - endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint), + stack: stack, + id: id, + name: name, + linkEP: ep, + loopback: loopback, + demux: newTransportDemuxer(stack), + primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List), + endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint), + mcastJoins: make(map[NetworkEndpointID]int32), stats: NICStats{ Tx: DirectionStats{ Packets: &tcpip.StatCounter{}, @@ -384,20 +386,62 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) { n.mu.Unlock() } -// RemoveAddress removes an address from n. -func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error { - n.mu.Lock() +func (n *NIC) removeAddressLocked(addr tcpip.Address) *tcpip.Error { r := n.endpoints[NetworkEndpointID{addr}] if r == nil || !r.holdsInsertRef { - n.mu.Unlock() return tcpip.ErrBadLocalAddress } r.holdsInsertRef = false - n.mu.Unlock() - r.decRef() + r.decRefLocked() + + return nil +} + +// RemoveAddress removes an address from n. +func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error { + n.mu.Lock() + defer n.mu.Unlock() + return n.removeAddressLocked(addr) +} + +// joinGroup adds a new endpoint for the given multicast address, if none +// exists yet. Otherwise it just increments its count. +func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { + n.mu.Lock() + defer n.mu.Unlock() + id := NetworkEndpointID{addr} + joins := n.mcastJoins[id] + if joins == 0 { + if _, err := n.addAddressLocked(protocol, addr, NeverPrimaryEndpoint, false); err != nil { + return err + } + } + n.mcastJoins[id] = joins + 1 + return nil +} + +// leaveGroup decrements the count for the given multicast address, and when it +// reaches zero removes the endpoint for this address. +func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { + n.mu.Lock() + defer n.mu.Unlock() + + id := NetworkEndpointID{addr} + joins := n.mcastJoins[id] + switch joins { + case 0: + // There are no joins with this address on this NIC. + return tcpip.ErrBadLocalAddress + case 1: + // This is the last one, clean up. + if err := n.removeAddressLocked(addr); err != nil { + return err + } + } + n.mcastJoins[id] = joins - 1 return nil } @@ -644,6 +688,14 @@ func (r *referencedNetworkEndpoint) decRef() { } } +// decRefLocked is the same as decRef but assumes that the NIC.mu mutex is +// locked. +func (r *referencedNetworkEndpoint) decRefLocked() { + if atomic.AddInt32(&r.refs, -1) == 0 { + r.nic.removeEndpointLocked(r) + } +} + // incRef increments the ref count. It must only be called when the caller is // known to be holding a reference to the endpoint, otherwise tryIncRef should // be used. diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index f204ca790..c82822ee2 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1062,10 +1062,22 @@ func (s *Stack) RemoveTCPProbe() { // JoinGroup joins the given multicast group on the given NIC. func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error { // TODO: notify network of subscription via igmp protocol. - return s.AddAddressWithOptions(nicID, protocol, multicastAddr, NeverPrimaryEndpoint) + s.mu.RLock() + defer s.mu.RUnlock() + + if nic, ok := s.nics[nicID]; ok { + return nic.joinGroup(protocol, multicastAddr) + } + return tcpip.ErrUnknownNICID } // LeaveGroup leaves the given multicast group on the given NIC. func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error { - return s.RemoveAddress(nicID, multicastAddr) + s.mu.RLock() + defer s.mu.RUnlock() + + if nic, ok := s.nics[nicID]; ok { + return nic.leaveGroup(multicastAddr) + } + return tcpip.ErrUnknownNICID } diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index e8b562ad9..66c564613 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -141,9 +141,9 @@ func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEnd // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { - // If this is a broadcast datagram, deliver the datagram to all endpoints - // managed by ep. - if id.LocalAddress == header.IPv4Broadcast { + // If this is a broadcast or multicast datagram, deliver the datagram to all + // endpoints managed by ep. + if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) { for i, endpoint := range ep.endpointsArr { // HandlePacket modifies vv, so each endpoint needs its own copy. if i == len(ep.endpointsArr)-1 { |