summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack/nic.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack/nic.go')
-rw-r--r--pkg/tcpip/stack/nic.go80
1 files changed, 66 insertions, 14 deletions
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 8008d9870..a4117d98e 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -42,6 +42,7 @@ type NIC struct {
primary map[tcpip.NetworkProtocolNumber]*ilist.List
endpoints map[NetworkEndpointID]*referencedNetworkEndpoint
subnets []tcpip.Subnet
+ mcastJoins map[NetworkEndpointID]int32
stats NICStats
}
@@ -79,14 +80,15 @@ const (
func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback bool) *NIC {
return &NIC{
- stack: stack,
- id: id,
- name: name,
- linkEP: ep,
- loopback: loopback,
- demux: newTransportDemuxer(stack),
- primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List),
- endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint),
+ stack: stack,
+ id: id,
+ name: name,
+ linkEP: ep,
+ loopback: loopback,
+ demux: newTransportDemuxer(stack),
+ primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List),
+ endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint),
+ mcastJoins: make(map[NetworkEndpointID]int32),
stats: NICStats{
Tx: DirectionStats{
Packets: &tcpip.StatCounter{},
@@ -384,20 +386,62 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) {
n.mu.Unlock()
}
-// RemoveAddress removes an address from n.
-func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error {
- n.mu.Lock()
+func (n *NIC) removeAddressLocked(addr tcpip.Address) *tcpip.Error {
r := n.endpoints[NetworkEndpointID{addr}]
if r == nil || !r.holdsInsertRef {
- n.mu.Unlock()
return tcpip.ErrBadLocalAddress
}
r.holdsInsertRef = false
- n.mu.Unlock()
- r.decRef()
+ r.decRefLocked()
+
+ return nil
+}
+
+// RemoveAddress removes an address from n.
+func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+ return n.removeAddressLocked(addr)
+}
+
+// joinGroup adds a new endpoint for the given multicast address, if none
+// exists yet. Otherwise it just increments its count.
+func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+ id := NetworkEndpointID{addr}
+ joins := n.mcastJoins[id]
+ if joins == 0 {
+ if _, err := n.addAddressLocked(protocol, addr, NeverPrimaryEndpoint, false); err != nil {
+ return err
+ }
+ }
+ n.mcastJoins[id] = joins + 1
+ return nil
+}
+
+// leaveGroup decrements the count for the given multicast address, and when it
+// reaches zero removes the endpoint for this address.
+func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ id := NetworkEndpointID{addr}
+ joins := n.mcastJoins[id]
+ switch joins {
+ case 0:
+ // There are no joins with this address on this NIC.
+ return tcpip.ErrBadLocalAddress
+ case 1:
+ // This is the last one, clean up.
+ if err := n.removeAddressLocked(addr); err != nil {
+ return err
+ }
+ }
+ n.mcastJoins[id] = joins - 1
return nil
}
@@ -644,6 +688,14 @@ func (r *referencedNetworkEndpoint) decRef() {
}
}
+// decRefLocked is the same as decRef but assumes that the NIC.mu mutex is
+// locked.
+func (r *referencedNetworkEndpoint) decRefLocked() {
+ if atomic.AddInt32(&r.refs, -1) == 0 {
+ r.nic.removeEndpointLocked(r)
+ }
+}
+
// incRef increments the ref count. It must only be called when the caller is
// known to be holding a reference to the endpoint, otherwise tryIncRef should
// be used.