summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/tcpip/stack/nic.go80
-rw-r--r--pkg/tcpip/stack/stack.go16
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go6
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go39
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go12
5 files changed, 116 insertions, 37 deletions
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 8008d9870..a4117d98e 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -42,6 +42,7 @@ type NIC struct {
primary map[tcpip.NetworkProtocolNumber]*ilist.List
endpoints map[NetworkEndpointID]*referencedNetworkEndpoint
subnets []tcpip.Subnet
+ mcastJoins map[NetworkEndpointID]int32
stats NICStats
}
@@ -79,14 +80,15 @@ const (
func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback bool) *NIC {
return &NIC{
- stack: stack,
- id: id,
- name: name,
- linkEP: ep,
- loopback: loopback,
- demux: newTransportDemuxer(stack),
- primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List),
- endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint),
+ stack: stack,
+ id: id,
+ name: name,
+ linkEP: ep,
+ loopback: loopback,
+ demux: newTransportDemuxer(stack),
+ primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List),
+ endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint),
+ mcastJoins: make(map[NetworkEndpointID]int32),
stats: NICStats{
Tx: DirectionStats{
Packets: &tcpip.StatCounter{},
@@ -384,20 +386,62 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) {
n.mu.Unlock()
}
-// RemoveAddress removes an address from n.
-func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error {
- n.mu.Lock()
+func (n *NIC) removeAddressLocked(addr tcpip.Address) *tcpip.Error {
r := n.endpoints[NetworkEndpointID{addr}]
if r == nil || !r.holdsInsertRef {
- n.mu.Unlock()
return tcpip.ErrBadLocalAddress
}
r.holdsInsertRef = false
- n.mu.Unlock()
- r.decRef()
+ r.decRefLocked()
+
+ return nil
+}
+
+// RemoveAddress removes an address from n.
+func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+ return n.removeAddressLocked(addr)
+}
+
+// joinGroup adds a new endpoint for the given multicast address, if none
+// exists yet. Otherwise it just increments its count.
+func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+ id := NetworkEndpointID{addr}
+ joins := n.mcastJoins[id]
+ if joins == 0 {
+ if _, err := n.addAddressLocked(protocol, addr, NeverPrimaryEndpoint, false); err != nil {
+ return err
+ }
+ }
+ n.mcastJoins[id] = joins + 1
+ return nil
+}
+
+// leaveGroup decrements the count for the given multicast address, and when it
+// reaches zero removes the endpoint for this address.
+func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ id := NetworkEndpointID{addr}
+ joins := n.mcastJoins[id]
+ switch joins {
+ case 0:
+ // There are no joins with this address on this NIC.
+ return tcpip.ErrBadLocalAddress
+ case 1:
+ // This is the last one, clean up.
+ if err := n.removeAddressLocked(addr); err != nil {
+ return err
+ }
+ }
+ n.mcastJoins[id] = joins - 1
return nil
}
@@ -644,6 +688,14 @@ func (r *referencedNetworkEndpoint) decRef() {
}
}
+// decRefLocked is the same as decRef but assumes that the NIC.mu mutex is
+// locked.
+func (r *referencedNetworkEndpoint) decRefLocked() {
+ if atomic.AddInt32(&r.refs, -1) == 0 {
+ r.nic.removeEndpointLocked(r)
+ }
+}
+
// incRef increments the ref count. It must only be called when the caller is
// known to be holding a reference to the endpoint, otherwise tryIncRef should
// be used.
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index f204ca790..c82822ee2 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -1062,10 +1062,22 @@ func (s *Stack) RemoveTCPProbe() {
// JoinGroup joins the given multicast group on the given NIC.
func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error {
// TODO: notify network of subscription via igmp protocol.
- return s.AddAddressWithOptions(nicID, protocol, multicastAddr, NeverPrimaryEndpoint)
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nic, ok := s.nics[nicID]; ok {
+ return nic.joinGroup(protocol, multicastAddr)
+ }
+ return tcpip.ErrUnknownNICID
}
// LeaveGroup leaves the given multicast group on the given NIC.
func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error {
- return s.RemoveAddress(nicID, multicastAddr)
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nic, ok := s.nics[nicID]; ok {
+ return nic.leaveGroup(multicastAddr)
+ }
+ return tcpip.ErrUnknownNICID
}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index e8b562ad9..66c564613 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -141,9 +141,9 @@ func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEnd
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
- // If this is a broadcast datagram, deliver the datagram to all endpoints
- // managed by ep.
- if id.LocalAddress == header.IPv4Broadcast {
+ // If this is a broadcast or multicast datagram, deliver the datagram to all
+ // endpoints managed by ep.
+ if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) {
for i, endpoint := range ep.endpointsArr {
// HandlePacket modifies vv, so each endpoint needs its own copy.
if i == len(ep.endpointsArr)-1 {
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index db65a4e88..0ed0902b0 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -458,14 +458,22 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrUnknownDevice
}
- if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
- return err
- }
+ memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
e.mu.Lock()
defer e.mu.Unlock()
- e.multicastMemberships = append(e.multicastMemberships, multicastMembership{nicID, v.MulticastAddr})
+ for _, mem := range e.multicastMemberships {
+ if mem == memToInsert {
+ return tcpip.ErrPortInUse
+ }
+ }
+
+ if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
+ return err
+ }
+
+ e.multicastMemberships = append(e.multicastMemberships, memToInsert)
case tcpip.RemoveMembershipOption:
if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
@@ -488,21 +496,28 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrUnknownDevice
}
- if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
- return err
- }
+ memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
+ memToRemoveIndex := -1
e.mu.Lock()
defer e.mu.Unlock()
+
for i, mem := range e.multicastMemberships {
- if mem.nicID == nicID && mem.multicastAddr == v.MulticastAddr {
- // Only remove the first match, so that each added membership above is
- // paired with exactly 1 removal.
- e.multicastMemberships[i] = e.multicastMemberships[len(e.multicastMemberships)-1]
- e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1]
+ if mem == memToRemove {
+ memToRemoveIndex = i
break
}
}
+ if memToRemoveIndex == -1 {
+ return tcpip.ErrBadLocalAddress
+ }
+
+ if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
+ return err
+ }
+
+ e.multicastMemberships[memToRemoveIndex] = e.multicastMemberships[len(e.multicastMemberships)-1]
+ e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1]
case tcpip.MulticastLoopOption:
e.mu.Lock()
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 163dcbc13..74e8e9fd5 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -66,6 +66,12 @@ func (e *endpoint) loadRcvBufSizeMax(max int) {
func (e *endpoint) afterLoad() {
e.stack = stack.StackFromEnv
+ for _, m := range e.multicastMemberships {
+ if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil {
+ panic(err)
+ }
+ }
+
if e.state != stateBound && e.state != stateConnected {
return
}
@@ -103,10 +109,4 @@ func (e *endpoint) afterLoad() {
if err != nil {
panic(*err)
}
-
- for _, m := range e.multicastMemberships {
- if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil {
- panic(err)
- }
- }
}