summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/stack/ndp_test.go75
-rw-r--r--pkg/tcpip/stack/nic.go52
2 files changed, 91 insertions, 36 deletions
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 58f1ebf60..ae326b3ab 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -421,28 +421,52 @@ func TestDADResolve(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
+ // We add a default route so the call to FindRoute below will succeed
+ // once we have an assigned address.
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: header.IPv6EmptySubnet,
+ Gateway: addr3,
+ NIC: nicID,
+ }})
+
if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
}
// Address should not be considered bound to the NIC yet (DAD ongoing).
- addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
+ if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ } else if want := (tcpip.AddressWithPrefix{}); addr != want {
t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
}
// Make sure the address does not resolve before the resolution time has
// passed.
time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncEventTimeout)
- addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ } else if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ }
+ // Should not get a route even if we specify the local address as the
+ // tentative address.
+ {
+ r, err := s.FindRoute(nicID, "", addr2, header.IPv6ProtocolNumber, false)
+ if err != tcpip.ErrNoRoute {
+ t.Errorf("got FindRoute(%d, '', %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute)
+ }
+ r.Release()
}
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ {
+ r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false)
+ if err != tcpip.ErrNoRoute {
+ t.Errorf("got FindRoute(%d, %s, %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr1, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute)
+ }
+ r.Release()
+ }
+
+ if t.Failed() {
+ t.FailNow()
}
// Wait for DAD to resolve.
@@ -454,12 +478,33 @@ func TestDADResolve(t *testing.T) {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
}
- addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ } else if addr.Address != addr1 {
+ t.Errorf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1)
}
- if addr.Address != addr1 {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1)
+ // Should get a route using the address now that it is resolved.
+ {
+ r, err := s.FindRoute(nicID, "", addr2, header.IPv6ProtocolNumber, false)
+ if err != nil {
+ t.Errorf("got FindRoute(%d, '', %s, %d, false): %s", nicID, addr2, header.IPv6ProtocolNumber, err)
+ } else if r.LocalAddress != addr1 {
+ t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, addr1)
+ }
+ r.Release()
+ }
+ {
+ r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false)
+ if err != nil {
+ t.Errorf("got FindRoute(%d, %s, %s, %d, false): %s", nicID, addr1, addr2, header.IPv6ProtocolNumber, err)
+ } else if r.LocalAddress != addr1 {
+ t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, addr1)
+ }
+ r.Release()
+ }
+
+ if t.Failed() {
+ t.FailNow()
}
// Should not have sent any more NS messages.
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 644c0d437..afb7dfeaf 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -610,18 +610,14 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok {
// An endpoint with this id exists, check if it can be used and return it.
- switch ref.getKind() {
- case permanentExpired:
- if !spoofingOrPromiscuous {
- n.mu.RUnlock()
- return nil
- }
- fallthrough
- case temporary, permanent:
- if ref.tryIncRef() {
- n.mu.RUnlock()
- return ref
- }
+ if !ref.isAssignedRLocked(spoofingOrPromiscuous) {
+ n.mu.RUnlock()
+ return nil
+ }
+
+ if ref.tryIncRef() {
+ n.mu.RUnlock()
+ return ref
}
}
@@ -689,7 +685,6 @@ func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, add
PrefixLen: netProto.DefaultPrefixLen(),
},
}, peb, temporary, static, false)
-
return ref
}
@@ -1660,8 +1655,8 @@ func (r *referencedNetworkEndpoint) setKind(kind networkEndpointKind) {
}
// isValidForOutgoing returns true if the endpoint can be used to send out a
-// packet. It requires the endpoint to not be marked expired (i.e., its address
-// has been removed), or the NIC to be in spoofing mode.
+// packet. It requires the endpoint to not be marked expired (i.e., its address)
+// has been removed) unless the NIC is in spoofing mode, or temporary.
func (r *referencedNetworkEndpoint) isValidForOutgoing() bool {
r.nic.mu.RLock()
defer r.nic.mu.RUnlock()
@@ -1669,13 +1664,28 @@ func (r *referencedNetworkEndpoint) isValidForOutgoing() bool {
return r.isValidForOutgoingRLocked()
}
-// isValidForOutgoingRLocked returns true if the endpoint can be used to send
-// out a packet. It requires the endpoint to not be marked expired (i.e., its
-// address has been removed), or the NIC to be in spoofing mode.
-//
-// r's NIC must be read locked.
+// isValidForOutgoingRLocked is the same as isValidForOutgoing but requires
+// r.nic.mu to be read locked.
func (r *referencedNetworkEndpoint) isValidForOutgoingRLocked() bool {
- return r.nic.mu.enabled && (r.getKind() != permanentExpired || r.nic.mu.spoofing)
+ if !r.nic.mu.enabled {
+ return false
+ }
+
+ return r.isAssignedRLocked(r.nic.mu.spoofing)
+}
+
+// isAssignedRLocked returns true if r is considered to be assigned to the NIC.
+//
+// r.nic.mu must be read locked.
+func (r *referencedNetworkEndpoint) isAssignedRLocked(spoofingOrPromiscuous bool) bool {
+ switch r.getKind() {
+ case permanentTentative:
+ return false
+ case permanentExpired:
+ return spoofingOrPromiscuous
+ default:
+ return true
+ }
}
// expireLocked decrements the reference count and marks the permanent endpoint