summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
authorChris Kuiper <ckuiper@google.com>2019-05-02 19:39:55 -0700
committerShentubot <shentubot@google.com>2019-05-02 19:41:00 -0700
commit8972e47a2edb01d66c2fc6373a5663b68e3da82c (patch)
tree28195396bd4e663cfd2886ffbfd93f8d66af24dd /pkg/tcpip/stack
parent5f8225c009fcf297139c54c7b329da4aff679ece (diff)
Support reception of multicast data on more than one socket
This requires two changes: 1) Support for more than one socket to join a given multicast group. 2) Duplicate delivery of incoming multicast packets to all sockets listening for it. In addition, I tweaked the code (and added a test) to disallow duplicates IP_ADD_MEMBERSHIP calls for the same group and NIC. This is how Linux does it. PiperOrigin-RevId: 246437315 Change-Id: Icad8300b4a8c3f501d9b4cd283bd3beabef88b72
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/nic.go80
-rw-r--r--pkg/tcpip/stack/stack.go16
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go6
3 files changed, 83 insertions, 19 deletions
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 8008d9870..a4117d98e 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -42,6 +42,7 @@ type NIC struct {
primary map[tcpip.NetworkProtocolNumber]*ilist.List
endpoints map[NetworkEndpointID]*referencedNetworkEndpoint
subnets []tcpip.Subnet
+ mcastJoins map[NetworkEndpointID]int32
stats NICStats
}
@@ -79,14 +80,15 @@ const (
func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback bool) *NIC {
return &NIC{
- stack: stack,
- id: id,
- name: name,
- linkEP: ep,
- loopback: loopback,
- demux: newTransportDemuxer(stack),
- primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List),
- endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint),
+ stack: stack,
+ id: id,
+ name: name,
+ linkEP: ep,
+ loopback: loopback,
+ demux: newTransportDemuxer(stack),
+ primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List),
+ endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint),
+ mcastJoins: make(map[NetworkEndpointID]int32),
stats: NICStats{
Tx: DirectionStats{
Packets: &tcpip.StatCounter{},
@@ -384,20 +386,62 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) {
n.mu.Unlock()
}
-// RemoveAddress removes an address from n.
-func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error {
- n.mu.Lock()
+func (n *NIC) removeAddressLocked(addr tcpip.Address) *tcpip.Error {
r := n.endpoints[NetworkEndpointID{addr}]
if r == nil || !r.holdsInsertRef {
- n.mu.Unlock()
return tcpip.ErrBadLocalAddress
}
r.holdsInsertRef = false
- n.mu.Unlock()
- r.decRef()
+ r.decRefLocked()
+
+ return nil
+}
+
+// RemoveAddress removes an address from n.
+func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+ return n.removeAddressLocked(addr)
+}
+
+// joinGroup adds a new endpoint for the given multicast address, if none
+// exists yet. Otherwise it just increments its count.
+func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+ id := NetworkEndpointID{addr}
+ joins := n.mcastJoins[id]
+ if joins == 0 {
+ if _, err := n.addAddressLocked(protocol, addr, NeverPrimaryEndpoint, false); err != nil {
+ return err
+ }
+ }
+ n.mcastJoins[id] = joins + 1
+ return nil
+}
+
+// leaveGroup decrements the count for the given multicast address, and when it
+// reaches zero removes the endpoint for this address.
+func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ id := NetworkEndpointID{addr}
+ joins := n.mcastJoins[id]
+ switch joins {
+ case 0:
+ // There are no joins with this address on this NIC.
+ return tcpip.ErrBadLocalAddress
+ case 1:
+ // This is the last one, clean up.
+ if err := n.removeAddressLocked(addr); err != nil {
+ return err
+ }
+ }
+ n.mcastJoins[id] = joins - 1
return nil
}
@@ -644,6 +688,14 @@ func (r *referencedNetworkEndpoint) decRef() {
}
}
+// decRefLocked is the same as decRef but assumes that the NIC.mu mutex is
+// locked.
+func (r *referencedNetworkEndpoint) decRefLocked() {
+ if atomic.AddInt32(&r.refs, -1) == 0 {
+ r.nic.removeEndpointLocked(r)
+ }
+}
+
// incRef increments the ref count. It must only be called when the caller is
// known to be holding a reference to the endpoint, otherwise tryIncRef should
// be used.
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index f204ca790..c82822ee2 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -1062,10 +1062,22 @@ func (s *Stack) RemoveTCPProbe() {
// JoinGroup joins the given multicast group on the given NIC.
func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error {
// TODO: notify network of subscription via igmp protocol.
- return s.AddAddressWithOptions(nicID, protocol, multicastAddr, NeverPrimaryEndpoint)
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nic, ok := s.nics[nicID]; ok {
+ return nic.joinGroup(protocol, multicastAddr)
+ }
+ return tcpip.ErrUnknownNICID
}
// LeaveGroup leaves the given multicast group on the given NIC.
func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error {
- return s.RemoveAddress(nicID, multicastAddr)
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nic, ok := s.nics[nicID]; ok {
+ return nic.leaveGroup(multicastAddr)
+ }
+ return tcpip.ErrUnknownNICID
}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index e8b562ad9..66c564613 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -141,9 +141,9 @@ func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEnd
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
- // If this is a broadcast datagram, deliver the datagram to all endpoints
- // managed by ep.
- if id.LocalAddress == header.IPv4Broadcast {
+ // If this is a broadcast or multicast datagram, deliver the datagram to all
+ // endpoints managed by ep.
+ if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) {
for i, endpoint := range ep.endpointsArr {
// HandlePacket modifies vv, so each endpoint needs its own copy.
if i == len(ep.endpointsArr)-1 {