summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/ndp_test.go8
-rw-r--r--pkg/tcpip/stack/neighbor_cache.go22
-rw-r--r--pkg/tcpip/stack/neighbor_cache_test.go32
-rw-r--r--pkg/tcpip/stack/neighbor_entry_test.go102
-rw-r--r--pkg/tcpip/stack/nic.go60
-rw-r--r--pkg/tcpip/stack/nud.go2
-rw-r--r--pkg/tcpip/stack/stack.go8
-rw-r--r--pkg/tcpip/stack/stack_test.go193
8 files changed, 303 insertions, 124 deletions
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 644ba7c33..5d286ccbc 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -1689,13 +1689,7 @@ func containsV6Addr(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix)
AddressWithPrefix: item,
}
- for _, i := range list {
- if i == protocolAddress {
- return true
- }
- }
-
- return false
+ return containsAddr(list, protocolAddress)
}
// TestNoAutoGenAddr tests that SLAAC is not performed when configured not to.
diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go
index 1d37716c2..27e1feec0 100644
--- a/pkg/tcpip/stack/neighbor_cache.go
+++ b/pkg/tcpip/stack/neighbor_cache.go
@@ -115,17 +115,15 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, li
// channel is returned for the top level caller to block. Channel is closed
// once address resolution is complete (success or not).
func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, w *sleep.Waker) (NeighborEntry, <-chan struct{}, *tcpip.Error) {
- if linkRes != nil {
- if linkAddr, ok := linkRes.ResolveStaticAddress(remoteAddr); ok {
- e := NeighborEntry{
- Addr: remoteAddr,
- LocalAddr: localAddr,
- LinkAddr: linkAddr,
- State: Static,
- UpdatedAt: time.Now(),
- }
- return e, nil, nil
+ if linkAddr, ok := linkRes.ResolveStaticAddress(remoteAddr); ok {
+ e := NeighborEntry{
+ Addr: remoteAddr,
+ LocalAddr: localAddr,
+ LinkAddr: linkAddr,
+ State: Static,
+ UpdatedAt: time.Now(),
}
+ return e, nil, nil
}
entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes)
@@ -289,8 +287,8 @@ func (n *neighborCache) setConfig(config NUDConfigurations) {
// HandleProbe implements NUDHandler.HandleProbe by following the logic defined
// in RFC 4861 section 7.2.3. Validation of the probe is expected to be handled
// by the caller.
-func (n *neighborCache) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress) {
- entry := n.getOrCreateEntry(remoteAddr, localAddr, nil)
+func (n *neighborCache) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) {
+ entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes)
entry.mu.Lock()
entry.handleProbeLocked(remoteLinkAddr)
entry.mu.Unlock()
diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go
index 4cb2c9c6b..b4fa69e3e 100644
--- a/pkg/tcpip/stack/neighbor_cache_test.go
+++ b/pkg/tcpip/stack/neighbor_cache_test.go
@@ -335,32 +335,6 @@ func TestNeighborCacheEntry(t *testing.T) {
}
}
-// TestNeighborCacheEntryNoLinkAddress verifies calling entry() without a
-// LinkAddressResolver returns ErrNoLinkAddress.
-func TestNeighborCacheEntryNoLinkAddress(t *testing.T) {
- nudDisp := testNUDDispatcher{}
- c := DefaultNUDConfigurations()
- clock := newFakeClock()
- neigh := newTestNeighborCache(&nudDisp, c, clock)
- store := newTestEntryStore()
-
- entry, ok := store.entry(0)
- if !ok {
- t.Fatalf("store.entry(0) not found")
- }
- _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, nil, nil)
- if err != tcpip.ErrNoLinkAddress {
- t.Errorf("got neigh.entry(%s, %s, nil, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress)
- }
-
- // No events should have been dispatched.
- nudDisp.mu.Lock()
- defer nudDisp.mu.Unlock()
- if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
- }
-}
-
func TestNeighborCacheRemoveEntry(t *testing.T) {
config := DefaultNUDConfigurations()
@@ -1048,9 +1022,9 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
t.Fatalf("c.store.entry(0) not found")
}
c.neigh.addStaticEntry(entry.Addr, entry.LinkAddr)
- e, _, err := c.neigh.entry(entry.Addr, "", nil, nil)
+ e, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil)
if err != nil {
- t.Errorf("unexpected error from c.neigh.entry(%s, \"\", nil nil): %s", entry.Addr, err)
+ t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil): %s", entry.Addr, err)
}
want := NeighborEntry{
Addr: entry.Addr,
@@ -1059,7 +1033,7 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
State: Static,
}
if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("c.neigh.entry(%s, \"\", nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
+ t.Errorf("c.neigh.entry(%s, \"\", _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
}
wantEvents := []testEntryEventInfo{
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
index 08c9ccd25..b769fb2fa 100644
--- a/pkg/tcpip/stack/neighbor_entry_test.go
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -236,7 +236,7 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
nudState := NewNUDState(c, rng)
linkRes := entryTestLinkResolver{}
- entry := newNeighborEntry(&nic, entryTestAddr1, entryTestAddr2, nudState, &linkRes)
+ entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, entryTestAddr2 /* localAddr */, nudState, &linkRes)
// Stub out ndpState to verify modification of default routers.
nic.mu.ndp = ndpState{
@@ -2344,6 +2344,106 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) {
nudDisp.mu.Unlock()
}
+// TestEntryUnknownToStaleToProbeToReachable exercises the following scenario:
+// 1. Probe is received
+// 2. Entry is created in Stale
+// 3. Packet is queued on the entry
+// 4. Entry transitions to Delay then Probe
+// 5. Probe is sent
+func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ // Eliminate random factors from ReachableTime computation so the transition
+ // from Probe to Reachable will only take BaseReachableTime duration.
+ c.MinRandomFactor = 1
+ c.MaxRandomFactor = 1
+
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handleProbeLocked(entryTestLinkAddr1)
+ e.handlePacketQueuedLocked()
+ e.mu.Unlock()
+
+ clock.advance(c.DelayFirstProbeTime)
+
+ wantProbes := []entryTestProbeInfo{
+ // Probe caused by the Delay-to-Probe transition
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Probe; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ clock.advance(c.BaseReachableTime)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
c := DefaultNUDConfigurations()
// Eliminate random factors from ReachableTime computation so the transition
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index f21066fce..eaaf756cd 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -217,6 +217,11 @@ func (n *NIC) disableLocked() *tcpip.Error {
}
if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok {
+ // The NIC may have already left the multicast group.
+ if err := n.leaveGroupLocked(header.IPv4AllSystems, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress {
+ return err
+ }
+
// The address may have already been removed.
if err := n.removePermanentAddressLocked(ipv4BroadcastAddr.AddressWithPrefix.Address); err != nil && err != tcpip.ErrBadLocalAddress {
return err
@@ -255,6 +260,13 @@ func (n *NIC) enable() *tcpip.Error {
if _, err := n.addAddressLocked(ipv4BroadcastAddr, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil {
return err
}
+
+ // As per RFC 1122 section 3.3.7, all hosts should join the all-hosts
+ // multicast group. Note, the IANA calls the all-hosts multicast group the
+ // all-systems multicast group.
+ if err := n.joinGroupLocked(header.IPv4ProtocolNumber, header.IPv4AllSystems); err != nil {
+ return err
+ }
}
// Join the IPv6 All-Nodes Multicast group if the stack is configured to
@@ -609,6 +621,9 @@ func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.A
// If none exists a temporary one may be created if we are in promiscuous mode
// or spoofing. Promiscuous mode will only be checked if promiscuous is true.
// Similarly, spoofing will only be checked if spoofing is true.
+//
+// If the address is the IPv4 broadcast address for an endpoint's network, that
+// endpoint will be returned.
func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getRefBehaviour) *referencedNetworkEndpoint {
n.mu.RLock()
@@ -633,6 +648,16 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
}
}
+ // Check if address is a broadcast address for the endpoint's network.
+ //
+ // Only IPv4 has a notion of broadcast addresses.
+ if protocol == header.IPv4ProtocolNumber {
+ if ref := n.getRefForBroadcastRLocked(address); ref != nil {
+ n.mu.RUnlock()
+ return ref
+ }
+ }
+
// A usable reference was not found, create a temporary one if requested by
// the caller or if the address is found in the NIC's subnets.
createTempEP := spoofingOrPromiscuous
@@ -670,8 +695,34 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
return ref
}
+// getRefForBroadcastLocked returns an endpoint where address is the IPv4
+// broadcast address for the endpoint's network.
+//
+// n.mu MUST be read locked.
+func (n *NIC) getRefForBroadcastRLocked(address tcpip.Address) *referencedNetworkEndpoint {
+ for _, ref := range n.mu.endpoints {
+ // Only IPv4 has a notion of broadcast addresses.
+ if ref.protocol != header.IPv4ProtocolNumber {
+ continue
+ }
+
+ addr := ref.addrWithPrefix()
+ subnet := addr.Subnet()
+ if subnet.IsBroadcast(address) && ref.tryIncRef() {
+ return ref
+ }
+ }
+
+ return nil
+}
+
/// getRefOrCreateTempLocked returns an existing endpoint for address or creates
/// and returns a temporary endpoint.
+//
+// If the address is the IPv4 broadcast address for an endpoint's network, that
+// endpoint will be returned.
+//
+// n.mu must be write locked.
func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint {
if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok {
// No need to check the type as we are ok with expired endpoints at this
@@ -685,6 +736,15 @@ func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, add
n.removeEndpointLocked(ref)
}
+ // Check if address is a broadcast address for an endpoint's network.
+ //
+ // Only IPv4 has a notion of broadcast addresses.
+ if protocol == header.IPv4ProtocolNumber {
+ if ref := n.getRefForBroadcastRLocked(address); ref != nil {
+ return ref
+ }
+ }
+
// Add a new temporary endpoint.
netProto, ok := n.stack.networkProtocols[protocol]
if !ok {
diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go
index f848d50ad..e1ec15487 100644
--- a/pkg/tcpip/stack/nud.go
+++ b/pkg/tcpip/stack/nud.go
@@ -177,7 +177,7 @@ type NUDHandler interface {
// Neighbor Solicitation for ARP or NDP, respectively). Validation of the
// probe needs to be performed before calling this function since the
// Neighbor Cache doesn't have access to view the NIC's assigned addresses.
- HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress)
+ HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver)
// HandleConfirmation processes an incoming neighbor confirmation (e.g. ARP
// reply or Neighbor Advertisement for ARP or NDP, respectively).
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 7189e8e7e..5b19c5d59 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -1985,8 +1985,8 @@ func generateRandInt64() int64 {
// FindNetworkEndpoint returns the network endpoint for the given address.
func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) (NetworkEndpoint, *tcpip.Error) {
- s.mu.Lock()
- defer s.mu.Unlock()
+ s.mu.RLock()
+ defer s.mu.RUnlock()
for _, nic := range s.nics {
id := NetworkEndpointID{address}
@@ -2005,8 +2005,8 @@ func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, addres
// FindNICNameFromID returns the name of the nic for the given NICID.
func (s *Stack) FindNICNameFromID(id tcpip.NICID) string {
- s.mu.Lock()
- defer s.mu.Unlock()
+ s.mu.RLock()
+ defer s.mu.RUnlock()
nic, ok := s.nics[id]
if !ok {
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index f22062889..0b6deda02 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -277,6 +277,17 @@ func (l *linkEPWithMockedAttach) isAttached() bool {
return l.attached
}
+// Checks to see if list contains an address.
+func containsAddr(list []tcpip.ProtocolAddress, item tcpip.ProtocolAddress) bool {
+ for _, i := range list {
+ if i == item {
+ return true
+ }
+ }
+
+ return false
+}
+
func TestNetworkReceive(t *testing.T) {
// Create a stack with the fake network protocol, one nic, and two
// addresses attached to it: 1 & 2.
@@ -1704,7 +1715,7 @@ func testNicForAddressRange(t *testing.T, nicID tcpip.NICID, s *stack.Stack, sub
// Trying the next address should always fail since it is outside the range.
if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addrBytes)); gotNicID != 0 {
- t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addrBytes), gotNicID, 0)
+ t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = 0", fakeNetNumber, tcpip.Address(addrBytes), gotNicID)
}
}
@@ -3089,6 +3100,13 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) {
const nicID = 1
+ broadcastAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: header.IPv4Broadcast,
+ PrefixLen: 32,
+ },
+ }
e := loopback.New()
s := stack.New(stack.Options{
@@ -3099,49 +3117,41 @@ func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) {
t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err)
}
- allStackAddrs := s.AllAddresses()
- allNICAddrs, ok := allStackAddrs[nicID]
- if !ok {
- t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
- }
- if l := len(allNICAddrs); l != 0 {
- t.Fatalf("got len(allNICAddrs) = %d, want = 0", l)
+ {
+ allStackAddrs := s.AllAddresses()
+ if allNICAddrs, ok := allStackAddrs[nicID]; !ok {
+ t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
+ } else if containsAddr(allNICAddrs, broadcastAddr) {
+ t.Fatalf("got allNICAddrs = %+v, don't want = %+v", allNICAddrs, broadcastAddr)
+ }
}
// Enabling the NIC should add the IPv4 broadcast address.
if err := s.EnableNIC(nicID); err != nil {
t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
}
- allStackAddrs = s.AllAddresses()
- allNICAddrs, ok = allStackAddrs[nicID]
- if !ok {
- t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
- }
- if l := len(allNICAddrs); l != 1 {
- t.Fatalf("got len(allNICAddrs) = %d, want = 1", l)
- }
- want := tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: header.IPv4Broadcast,
- PrefixLen: 32,
- },
- }
- if allNICAddrs[0] != want {
- t.Fatalf("got allNICAddrs[0] = %+v, want = %+v", allNICAddrs[0], want)
+
+ {
+ allStackAddrs := s.AllAddresses()
+ if allNICAddrs, ok := allStackAddrs[nicID]; !ok {
+ t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
+ } else if !containsAddr(allNICAddrs, broadcastAddr) {
+ t.Fatalf("got allNICAddrs = %+v, want = %+v", allNICAddrs, broadcastAddr)
+ }
}
// Disabling the NIC should remove the IPv4 broadcast address.
if err := s.DisableNIC(nicID); err != nil {
t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
}
- allStackAddrs = s.AllAddresses()
- allNICAddrs, ok = allStackAddrs[nicID]
- if !ok {
- t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
- }
- if l := len(allNICAddrs); l != 0 {
- t.Fatalf("got len(allNICAddrs) = %d, want = 0", l)
+
+ {
+ allStackAddrs := s.AllAddresses()
+ if allNICAddrs, ok := allStackAddrs[nicID]; !ok {
+ t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
+ } else if containsAddr(allNICAddrs, broadcastAddr) {
+ t.Fatalf("got allNICAddrs = %+v, don't want = %+v", allNICAddrs, broadcastAddr)
+ }
}
}
@@ -3189,50 +3199,93 @@ func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) {
}
}
-func TestJoinLeaveAllNodesMulticastOnNICEnableDisable(t *testing.T) {
+func TestJoinLeaveMulticastOnNICEnableDisable(t *testing.T) {
const nicID = 1
- e := loopback.New()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- })
- nicOpts := stack.NICOptions{Disabled: true}
- if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
- t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err)
+ tests := []struct {
+ name string
+ proto tcpip.NetworkProtocolNumber
+ addr tcpip.Address
+ }{
+ {
+ name: "IPv6 All-Nodes",
+ proto: header.IPv6ProtocolNumber,
+ addr: header.IPv6AllNodesMulticastAddress,
+ },
+ {
+ name: "IPv4 All-Systems",
+ proto: header.IPv4ProtocolNumber,
+ addr: header.IPv4AllSystems,
+ },
}
- // Should not be in the IPv6 all-nodes multicast group yet because the NIC has
- // not been enabled yet.
- isInGroup, err := s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress)
- if err != nil {
- t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err)
- }
- if isInGroup {
- t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress)
- }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e := loopback.New()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ })
+ nicOpts := stack.NICOptions{Disabled: true}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err)
+ }
- // The all-nodes multicast group should be joined when the NIC is enabled.
- if err := s.EnableNIC(nicID); err != nil {
- t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
- }
- isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress)
- if err != nil {
- t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err)
- }
- if !isInGroup {
- t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, header.IPv6AllNodesMulticastAddress)
- }
+ // Should not be in the multicast group yet because the NIC has not been
+ // enabled yet.
+ if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err)
+ } else if isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr)
+ }
- // The all-nodes multicast group should be left when the NIC is disabled.
- if err := s.DisableNIC(nicID); err != nil {
- t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
- }
- isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress)
- if err != nil {
- t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err)
- }
- if isInGroup {
- t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress)
+ // The all-nodes multicast group should be joined when the NIC is enabled.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+
+ if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err)
+ } else if !isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, test.addr)
+ }
+
+ // The multicast group should be left when the NIC is disabled.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+
+ if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err)
+ } else if isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr)
+ }
+
+ // The all-nodes multicast group should be joined when the NIC is enabled.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+
+ if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err)
+ } else if !isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, test.addr)
+ }
+
+ // Leaving the group before disabling the NIC should not cause an error.
+ if err := s.LeaveGroup(test.proto, nicID, test.addr); err != nil {
+ t.Fatalf("s.LeaveGroup(%d, %d, %s): %s", test.proto, nicID, test.addr, err)
+ }
+
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+
+ if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err)
+ } else if isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr)
+ }
+ })
}
}