summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/BUILD6
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state.go191
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state_test.go22
-rw-r--r--pkg/tcpip/stack/conntrack.go14
-rw-r--r--pkg/tcpip/stack/forwarding_test.go62
-rw-r--r--pkg/tcpip/stack/iptables.go75
-rw-r--r--pkg/tcpip/stack/iptables_targets.go102
-rw-r--r--pkg/tcpip/stack/iptables_types.go41
-rw-r--r--pkg/tcpip/stack/linkaddrcache.go8
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go4
-rw-r--r--pkg/tcpip/stack/ndp_test.go259
-rw-r--r--pkg/tcpip/stack/neighbor_cache.go49
-rw-r--r--pkg/tcpip/stack/neighbor_cache_test.go538
-rw-r--r--pkg/tcpip/stack/neighbor_entry.go248
-rw-r--r--pkg/tcpip/stack/neighbor_entry_test.go2609
-rw-r--r--pkg/tcpip/stack/nic.go297
-rw-r--r--pkg/tcpip/stack/nic_test.go5
-rw-r--r--pkg/tcpip/stack/nud.go8
-rw-r--r--pkg/tcpip/stack/packet_buffer.go60
-rw-r--r--pkg/tcpip/stack/pending_packets.go2
-rw-r--r--pkg/tcpip/stack/registration.go70
-rw-r--r--pkg/tcpip/stack/route.go370
-rw-r--r--pkg/tcpip/stack/stack.go425
-rw-r--r--pkg/tcpip/stack/stack_test.go869
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go63
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go15
-rw-r--r--pkg/tcpip/stack/transport_test.go156
27 files changed, 4041 insertions, 2527 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index d09ebe7fa..9cc6074da 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test", "most_shards")
load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
@@ -112,7 +112,7 @@ go_test(
"transport_demuxer_test.go",
"transport_test.go",
],
- shard_count = 20,
+ shard_count = most_shards,
deps = [
":stack",
"//pkg/rand",
@@ -120,6 +120,7 @@ go_test(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
@@ -131,7 +132,6 @@ go_test(
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
"@com_github_google_go_cmp//cmp:go_default_library",
- "@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
],
)
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go
index 4d3acab96..cd423bf71 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state.go
@@ -21,7 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
)
-var _ GroupAddressableEndpoint = (*AddressableEndpointState)(nil)
var _ AddressableEndpoint = (*AddressableEndpointState)(nil)
// AddressableEndpointState is an implementation of an AddressableEndpoint.
@@ -37,10 +36,6 @@ type AddressableEndpointState struct {
endpoints map[tcpip.Address]*addressState
primary []*addressState
-
- // groups holds the mapping between group addresses and the number of times
- // they have been joined.
- groups map[tcpip.Address]uint32
}
}
@@ -53,65 +48,33 @@ func (a *AddressableEndpointState) Init(networkEndpoint NetworkEndpoint) {
a.mu.Lock()
defer a.mu.Unlock()
a.mu.endpoints = make(map[tcpip.Address]*addressState)
- a.mu.groups = make(map[tcpip.Address]uint32)
-}
-
-// ReadOnlyAddressableEndpointState provides read-only access to an
-// AddressableEndpointState.
-type ReadOnlyAddressableEndpointState struct {
- inner *AddressableEndpointState
}
-// AddrOrMatching returns an endpoint for the passed address that is consisdered
-// bound to the wrapped AddressableEndpointState.
+// GetAddress returns the AddressEndpoint for the passed address.
//
-// If addr is an exact match with an existing address, that address is returned.
-// Otherwise, f is called with each address and the address that f returns true
-// for is returned.
+// GetAddress does not increment the address's reference count or check if the
+// address is considered bound to the endpoint.
//
-// Returns nil of no address matches.
-func (m ReadOnlyAddressableEndpointState) AddrOrMatching(addr tcpip.Address, spoofingOrPrimiscuous bool, f func(AddressEndpoint) bool) AddressEndpoint {
- m.inner.mu.RLock()
- defer m.inner.mu.RUnlock()
-
- if ep, ok := m.inner.mu.endpoints[addr]; ok {
- if ep.IsAssigned(spoofingOrPrimiscuous) && ep.IncRef() {
- return ep
- }
- }
-
- for _, ep := range m.inner.mu.endpoints {
- if ep.IsAssigned(spoofingOrPrimiscuous) && f(ep) && ep.IncRef() {
- return ep
- }
- }
-
- return nil
-}
-
-// Lookup returns the AddressEndpoint for the passed address.
-//
-// Returns nil if the passed address is not associated with the
-// AddressableEndpointState.
-func (m ReadOnlyAddressableEndpointState) Lookup(addr tcpip.Address) AddressEndpoint {
- m.inner.mu.RLock()
- defer m.inner.mu.RUnlock()
+// Returns nil if the passed address is not associated with the endpoint.
+func (a *AddressableEndpointState) GetAddress(addr tcpip.Address) AddressEndpoint {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
- ep, ok := m.inner.mu.endpoints[addr]
+ ep, ok := a.mu.endpoints[addr]
if !ok {
return nil
}
return ep
}
-// ForEach calls f for each address pair.
+// ForEachEndpoint calls f for each address.
//
-// If f returns false, f is no longer be called.
-func (m ReadOnlyAddressableEndpointState) ForEach(f func(AddressEndpoint) bool) {
- m.inner.mu.RLock()
- defer m.inner.mu.RUnlock()
+// Once f returns false, f will no longer be called.
+func (a *AddressableEndpointState) ForEachEndpoint(f func(AddressEndpoint) bool) {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
- for _, ep := range m.inner.mu.endpoints {
+ for _, ep := range a.mu.endpoints {
if !f(ep) {
return
}
@@ -120,18 +83,16 @@ func (m ReadOnlyAddressableEndpointState) ForEach(f func(AddressEndpoint) bool)
// ForEachPrimaryEndpoint calls f for each primary address.
//
-// If f returns false, f is no longer be called.
-func (m ReadOnlyAddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint)) {
- m.inner.mu.RLock()
- defer m.inner.mu.RUnlock()
- for _, ep := range m.inner.mu.primary {
- f(ep)
- }
-}
+// Once f returns false, f will no longer be called.
+func (a *AddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint) bool) {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
-// ReadOnly returns a readonly reference to a.
-func (a *AddressableEndpointState) ReadOnly() ReadOnlyAddressableEndpointState {
- return ReadOnlyAddressableEndpointState{inner: a}
+ for _, ep := range a.mu.primary {
+ if !f(ep) {
+ return
+ }
+ }
}
func (a *AddressableEndpointState) releaseAddressState(addrState *addressState) {
@@ -272,6 +233,9 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
addrState = &addressState{
addressableEndpointState: a,
addr: addr,
+ // Cache the subnet in addrState to avoid calls to addr.Subnet() as that
+ // results in allocations on every call.
+ subnet: addr.Subnet(),
}
a.mu.endpoints[addr.Address] = addrState
addrState.mu.Lock()
@@ -332,11 +296,6 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error {
a.mu.Lock()
defer a.mu.Unlock()
-
- if _, ok := a.mu.groups[addr]; ok {
- panic(fmt.Sprintf("group address = %s must be removed with LeaveGroup", addr))
- }
-
return a.removePermanentAddressLocked(addr)
}
@@ -361,6 +320,8 @@ func (a *AddressableEndpointState) RemovePermanentEndpoint(ep AddressEndpoint) *
return tcpip.ErrInvalidEndpointState
}
+ a.mu.Lock()
+ defer a.mu.Unlock()
return a.removePermanentEndpointLocked(addrState)
}
@@ -466,8 +427,19 @@ func (a *AddressableEndpointState) acquirePrimaryAddressRLocked(isValid func(*ad
return deprecatedEndpoint
}
-// AcquireAssignedAddress implements AddressableEndpoint.
-func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint {
+// AcquireAssignedAddressOrMatching returns an address endpoint that is
+// considered assigned to the addressable endpoint.
+//
+// If the address is an exact match with an existing address, that address is
+// returned. Otherwise, if f is provided, f is called with each address and
+// the address that f returns true for is returned.
+//
+// If there is no matching address, a temporary address will be returned if
+// allowTemp is true.
+//
+// Regardless how the address was obtained, it will be acquired before it is
+// returned.
+func (a *AddressableEndpointState) AcquireAssignedAddressOrMatching(localAddr tcpip.Address, f func(AddressEndpoint) bool, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint {
a.mu.Lock()
defer a.mu.Unlock()
@@ -483,6 +455,14 @@ func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Addres
return addrState
}
+ if f != nil {
+ for _, addrState := range a.mu.endpoints {
+ if addrState.IsAssigned(allowTemp) && f(addrState) && addrState.IncRef() {
+ return addrState
+ }
+ }
+ }
+
if !allowTemp {
return nil
}
@@ -515,6 +495,11 @@ func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Addres
return ep
}
+// AcquireAssignedAddress implements AddressableEndpoint.
+func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint {
+ return a.AcquireAssignedAddressOrMatching(localAddr, nil, allowTemp, tempPEB)
+}
+
// AcquireOutgoingPrimaryAddress implements AddressableEndpoint.
func (a *AddressableEndpointState) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint {
a.mu.RLock()
@@ -583,72 +568,11 @@ func (a *AddressableEndpointState) PermanentAddresses() []tcpip.AddressWithPrefi
return addrs
}
-// JoinGroup implements GroupAddressableEndpoint.
-func (a *AddressableEndpointState) JoinGroup(group tcpip.Address) (bool, *tcpip.Error) {
- a.mu.Lock()
- defer a.mu.Unlock()
-
- joins, ok := a.mu.groups[group]
- if !ok {
- ep, err := a.addAndAcquireAddressLocked(group.WithPrefix(), NeverPrimaryEndpoint, AddressConfigStatic, false /* deprecated */, true /* permanent */)
- if err != nil {
- return false, err
- }
- // We have no need for the address endpoint.
- a.decAddressRefLocked(ep)
- }
-
- a.mu.groups[group] = joins + 1
- return !ok, nil
-}
-
-// LeaveGroup implements GroupAddressableEndpoint.
-func (a *AddressableEndpointState) LeaveGroup(group tcpip.Address) (bool, *tcpip.Error) {
- a.mu.Lock()
- defer a.mu.Unlock()
-
- joins, ok := a.mu.groups[group]
- if !ok {
- return false, tcpip.ErrBadLocalAddress
- }
-
- if joins == 1 {
- a.removeGroupAddressLocked(group)
- delete(a.mu.groups, group)
- return true, nil
- }
-
- a.mu.groups[group] = joins - 1
- return false, nil
-}
-
-// IsInGroup implements GroupAddressableEndpoint.
-func (a *AddressableEndpointState) IsInGroup(group tcpip.Address) bool {
- a.mu.RLock()
- defer a.mu.RUnlock()
- _, ok := a.mu.groups[group]
- return ok
-}
-
-func (a *AddressableEndpointState) removeGroupAddressLocked(group tcpip.Address) {
- if err := a.removePermanentAddressLocked(group); err != nil {
- // removePermanentEndpointLocked would only return an error if group is
- // not bound to the addressable endpoint, but we know it MUST be assigned
- // since we have group in our map of groups.
- panic(fmt.Sprintf("error removing group address = %s: %s", group, err))
- }
-}
-
// Cleanup forcefully leaves all groups and removes all permanent addresses.
func (a *AddressableEndpointState) Cleanup() {
a.mu.Lock()
defer a.mu.Unlock()
- for group := range a.mu.groups {
- a.removeGroupAddressLocked(group)
- }
- a.mu.groups = make(map[tcpip.Address]uint32)
-
for _, ep := range a.mu.endpoints {
// removePermanentEndpointLocked returns tcpip.ErrBadLocalAddress if ep is
// not a permanent address.
@@ -664,7 +588,7 @@ var _ AddressEndpoint = (*addressState)(nil)
type addressState struct {
addressableEndpointState *AddressableEndpointState
addr tcpip.AddressWithPrefix
-
+ subnet tcpip.Subnet
// Lock ordering (from outer to inner lock ordering):
//
// AddressableEndpointState.mu
@@ -684,6 +608,11 @@ func (a *addressState) AddressWithPrefix() tcpip.AddressWithPrefix {
return a.addr
}
+// Subnet implements AddressEndpoint.
+func (a *addressState) Subnet() tcpip.Subnet {
+ return a.subnet
+}
+
// GetKind implements AddressEndpoint.
func (a *addressState) GetKind() AddressKind {
a.mu.RLock()
diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go
index 26787d0a3..140f146f6 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state_test.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go
@@ -53,25 +53,9 @@ func TestAddressableEndpointStateCleanup(t *testing.T) {
ep.DecRef()
}
- group := tcpip.Address("\x02")
- if added, err := s.JoinGroup(group); err != nil {
- t.Fatalf("s.JoinGroup(%s): %s", group, err)
- } else if !added {
- t.Fatalf("got s.JoinGroup(%s) = false, want = true", group)
- }
- if !s.IsInGroup(group) {
- t.Fatalf("got s.IsInGroup(%s) = false, want = true", group)
- }
-
s.Cleanup()
- {
- ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint)
- if ep != nil {
- ep.DecRef()
- t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix())
- }
- }
- if s.IsInGroup(group) {
- t.Fatalf("got s.IsInGroup(%s) = true, want = false", group)
+ if ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint); ep != nil {
+ ep.DecRef()
+ t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix())
}
}
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 0cd1da11f..9a17efcba 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -269,7 +269,7 @@ func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) {
return nil, dirOriginal
}
-func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt *RedirectTarget) *conn {
+func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn {
tid, err := packetToTupleID(pkt)
if err != nil {
return nil
@@ -282,8 +282,8 @@ func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt *Redire
// rule. This tuple will be used to manipulate the packet in
// handlePacket.
replyTID := tid.reply()
- replyTID.srcAddr = rt.Addr
- replyTID.srcPort = rt.Port
+ replyTID.srcAddr = address
+ replyTID.srcPort = port
var manip manipType
switch hook {
case Prerouting:
@@ -401,12 +401,12 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d
// Calculate the TCP checksum and set it.
tcpHeader.SetChecksum(0)
- length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
- xsum := r.PseudoHeaderChecksum(header.TCPProtocolNumber, length)
+ length := uint16(len(tcpHeader) + pkt.Data.Size())
+ xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
if gso != nil && gso.NeedsCsum {
tcpHeader.SetChecksum(xsum)
- } else if r.Capabilities()&CapabilityTXChecksumOffload == 0 {
- xsum = header.ChecksumVVWithOffset(pkt.Data, xsum, int(tcpHeader.DataOffset()), pkt.Data.Size())
+ } else if r.RequiresTXTransportChecksum() {
+ xsum = header.ChecksumVV(pkt.Data, xsum)
tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum))
}
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go
index cf042309e..5ec9b3411 100644
--- a/pkg/tcpip/stack/forwarding_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -73,9 +73,31 @@ func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 {
return 123
}
-func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) {
- // Dispatch the packet to the transport protocol.
- f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt)
+func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) {
+ netHdr := pkt.NetworkHeader().View()
+ _, dst := f.proto.ParseAddresses(netHdr)
+
+ addressEndpoint := f.AcquireAssignedAddress(dst, f.nic.Promiscuous(), CanBePrimaryEndpoint)
+ if addressEndpoint != nil {
+ addressEndpoint.DecRef()
+ // Dispatch the packet to the transport protocol.
+ f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(netHdr[protocolNumberOffset]), pkt)
+ return
+ }
+
+ r, err := f.proto.stack.FindRoute(0, "", dst, fwdTestNetNumber, false /* multicastLoop */)
+ if err != nil {
+ return
+ }
+ defer r.Release()
+
+ vv := buffer.NewVectorisedView(pkt.Size(), pkt.Views())
+ pkt = NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: vv.ToView().ToVectorisedView(),
+ })
+ // TODO(b/143425874) Decrease the TTL field in forwarded packets.
+ _ = r.WriteHeaderIncludedPacket(pkt)
}
func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 {
@@ -106,8 +128,13 @@ func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuf
panic("not implemented")
}
-func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error {
- return tcpip.ErrNotSupported
+func (f *fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error {
+ // The network header should not already be populated.
+ if _, ok := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen); !ok {
+ return tcpip.ErrMalformedHeader
+ }
+
+ return f.nic.WritePacket(r, nil /* gso */, fwdTestNetNumber, pkt)
}
func (f *fwdTestNetworkEndpoint) Close() {
@@ -117,6 +144,8 @@ func (f *fwdTestNetworkEndpoint) Close() {
// fwdTestNetworkProtocol is a network-layer protocol that implements Address
// resolution.
type fwdTestNetworkProtocol struct {
+ stack *Stack
+
addrCache *linkAddrCache
neigh *neighborCache
addrResolveDelay time.Duration
@@ -178,7 +207,7 @@ func (*fwdTestNetworkProtocol) Close() {}
func (*fwdTestNetworkProtocol) Wait() {}
-func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error {
+func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error {
if f.onLinkAddressResolved != nil {
time.AfterFunc(f.addrResolveDelay, func() {
f.onLinkAddressResolved(f.addrCache, f.neigh, addr, remoteLinkAddr)
@@ -280,7 +309,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress {
func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
p := fwdTestPacketInfo{
- RemoteLinkAddress: r.RemoteLinkAddress,
+ RemoteLinkAddress: r.RemoteLinkAddress(),
LocalLinkAddress: r.LocalLinkAddress,
Pkt: pkt,
}
@@ -304,20 +333,6 @@ func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuffer
return n, nil
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- p := fwdTestPacketInfo{
- Pkt: NewPacketBuffer(PacketBufferOptions{Data: vv}),
- }
-
- select {
- case e.C <- p:
- default:
- }
-
- return nil
-}
-
// Wait implements stack.LinkEndpoint.Wait.
func (*fwdTestLinkEndpoint) Wait() {}
@@ -334,7 +349,10 @@ func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protoco
func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborCache bool) (ep1, ep2 *fwdTestLinkEndpoint) {
// Create a stack with the network protocol and two NICs.
s := New(Options{
- NetworkProtocols: []NetworkProtocolFactory{func(*Stack) NetworkProtocol { return proto }},
+ NetworkProtocols: []NetworkProtocolFactory{func(s *Stack) NetworkProtocol {
+ proto.stack = s
+ return proto
+ }},
UseNeighborCache: useNeighborCache,
})
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 8d6d9a7f1..2d8c883cd 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -22,30 +22,17 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
)
-// tableID is an index into IPTables.tables.
-type tableID int
+// TableID identifies a specific table.
+type TableID int
+// Each value identifies a specific table.
const (
- natID tableID = iota
- mangleID
- filterID
- numTables
+ NATID TableID = iota
+ MangleID
+ FilterID
+ NumTables
)
-// Table names.
-const (
- NATTable = "nat"
- MangleTable = "mangle"
- FilterTable = "filter"
-)
-
-// nameToID is immutable.
-var nameToID = map[string]tableID{
- NATTable: natID,
- MangleTable: mangleID,
- FilterTable: filterID,
-}
-
// HookUnset indicates that there is no hook set for an entrypoint or
// underflow.
const HookUnset = -1
@@ -57,8 +44,8 @@ const reaperDelay = 5 * time.Second
// all packets.
func DefaultTables() *IPTables {
return &IPTables{
- v4Tables: [numTables]Table{
- natID: Table{
+ v4Tables: [NumTables]Table{
+ NATID: Table{
Rules: []Rule{
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
@@ -81,7 +68,7 @@ func DefaultTables() *IPTables {
Postrouting: 3,
},
},
- mangleID: Table{
+ MangleID: Table{
Rules: []Rule{
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
@@ -99,7 +86,7 @@ func DefaultTables() *IPTables {
Postrouting: HookUnset,
},
},
- filterID: Table{
+ FilterID: Table{
Rules: []Rule{
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
@@ -122,8 +109,8 @@ func DefaultTables() *IPTables {
},
},
},
- v6Tables: [numTables]Table{
- natID: Table{
+ v6Tables: [NumTables]Table{
+ NATID: Table{
Rules: []Rule{
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
@@ -146,7 +133,7 @@ func DefaultTables() *IPTables {
Postrouting: 3,
},
},
- mangleID: Table{
+ MangleID: Table{
Rules: []Rule{
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
@@ -164,7 +151,7 @@ func DefaultTables() *IPTables {
Postrouting: HookUnset,
},
},
- filterID: Table{
+ FilterID: Table{
Rules: []Rule{
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
@@ -187,10 +174,10 @@ func DefaultTables() *IPTables {
},
},
},
- priorities: [NumHooks][]tableID{
- Prerouting: []tableID{mangleID, natID},
- Input: []tableID{natID, filterID},
- Output: []tableID{mangleID, natID, filterID},
+ priorities: [NumHooks][]TableID{
+ Prerouting: []TableID{MangleID, NATID},
+ Input: []TableID{NATID, FilterID},
+ Output: []TableID{MangleID, NATID, FilterID},
},
connections: ConnTrack{
seed: generateRandUint32(),
@@ -229,26 +216,20 @@ func EmptyNATTable() Table {
}
}
-// GetTable returns a table by name.
-func (it *IPTables) GetTable(name string, ipv6 bool) (Table, bool) {
- id, ok := nameToID[name]
- if !ok {
- return Table{}, false
- }
+// GetTable returns a table with the given id and IP version. It panics when an
+// invalid id is provided.
+func (it *IPTables) GetTable(id TableID, ipv6 bool) Table {
it.mu.RLock()
defer it.mu.RUnlock()
if ipv6 {
- return it.v6Tables[id], true
+ return it.v6Tables[id]
}
- return it.v4Tables[id], true
+ return it.v4Tables[id]
}
-// ReplaceTable replaces or inserts table by name.
-func (it *IPTables) ReplaceTable(name string, table Table, ipv6 bool) *tcpip.Error {
- id, ok := nameToID[name]
- if !ok {
- return tcpip.ErrInvalidOptionValue
- }
+// ReplaceTable replaces or inserts table by name. It panics when an invalid id
+// is provided.
+func (it *IPTables) ReplaceTable(id TableID, table Table, ipv6 bool) *tcpip.Error {
it.mu.Lock()
defer it.mu.Unlock()
// If iptables is being enabled, initialize the conntrack table and
@@ -311,7 +292,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, prer
for _, tableID := range priorities {
// If handlePacket already NATed the packet, we don't need to
// check the NAT table.
- if tableID == natID && pkt.NatDone {
+ if tableID == NATID && pkt.NatDone {
continue
}
var table Table
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index 538c4625d..d63e9757c 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -15,6 +15,8 @@
package stack
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -26,13 +28,6 @@ type AcceptTarget struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-// ID implements Target.ID.
-func (at *AcceptTarget) ID() TargetID {
- return TargetID{
- NetworkProtocol: at.NetworkProtocol,
- }
-}
-
// Action implements Target.Action.
func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleAccept, 0
@@ -44,22 +39,11 @@ type DropTarget struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-// ID implements Target.ID.
-func (dt *DropTarget) ID() TargetID {
- return TargetID{
- NetworkProtocol: dt.NetworkProtocol,
- }
-}
-
// Action implements Target.Action.
func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleDrop, 0
}
-// ErrorTargetName is used to mark targets as error targets. Error targets
-// shouldn't be reached - an error has occurred if we fall through to one.
-const ErrorTargetName = "ERROR"
-
// ErrorTarget logs an error and drops the packet. It represents a target that
// should be unreachable.
type ErrorTarget struct {
@@ -67,14 +51,6 @@ type ErrorTarget struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-// ID implements Target.ID.
-func (et *ErrorTarget) ID() TargetID {
- return TargetID{
- Name: ErrorTargetName,
- NetworkProtocol: et.NetworkProtocol,
- }
-}
-
// Action implements Target.Action.
func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
log.Debugf("ErrorTarget triggered.")
@@ -90,14 +66,6 @@ type UserChainTarget struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-// ID implements Target.ID.
-func (uc *UserChainTarget) ID() TargetID {
- return TargetID{
- Name: ErrorTargetName,
- NetworkProtocol: uc.NetworkProtocol,
- }
-}
-
// Action implements Target.Action.
func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
panic("UserChainTarget should never be called.")
@@ -110,50 +78,39 @@ type ReturnTarget struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-// ID implements Target.ID.
-func (rt *ReturnTarget) ID() TargetID {
- return TargetID{
- NetworkProtocol: rt.NetworkProtocol,
- }
-}
-
// Action implements Target.Action.
func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleReturn, 0
}
-// RedirectTargetName is used to mark targets as redirect targets. Redirect
-// targets should be reached for only NAT and Mangle tables. These targets will
-// change the destination port/destination IP for packets.
-const RedirectTargetName = "REDIRECT"
-
-// RedirectTarget redirects the packet by modifying the destination port/IP.
+// RedirectTarget redirects the packet to this machine by modifying the
+// destination port/IP. Outgoing packets are redirected to the loopback device,
+// and incoming packets are redirected to the incoming interface (rather than
+// forwarded).
+//
// TODO(gvisor.dev/issue/170): Other flags need to be added after we support
// them.
type RedirectTarget struct {
- // Addr indicates address used to redirect.
- Addr tcpip.Address
-
- // Port indicates port used to redirect.
+ // Port indicates port used to redirect. It is immutable.
Port uint16
- // NetworkProtocol is the network protocol the target is used with.
+ // NetworkProtocol is the network protocol the target is used with. It
+ // is immutable.
NetworkProtocol tcpip.NetworkProtocolNumber
}
-// ID implements Target.ID.
-func (rt *RedirectTarget) ID() TargetID {
- return TargetID{
- Name: RedirectTargetName,
- NetworkProtocol: rt.NetworkProtocol,
- }
-}
-
// Action implements Target.Action.
// TODO(gvisor.dev/issue/170): Parse headers without copying. The current
-// implementation only works for PREROUTING and calls pkt.Clone(), neither
+// implementation only works for Prerouting and calls pkt.Clone(), neither
// of which should be the case.
func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) {
+ // Sanity check.
+ if rt.NetworkProtocol != pkt.NetworkProtocolNumber {
+ panic(fmt.Sprintf(
+ "RedirectTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
+ rt.NetworkProtocol, pkt.NetworkProtocolNumber))
+ }
+
// Packet is already manipulated.
if pkt.NatDone {
return RuleAccept, 0
@@ -164,17 +121,17 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs
return RuleDrop, 0
}
- // Change the address to localhost (127.0.0.1 or ::1) in Output and to
+ // Change the address to loopback (127.0.0.1 or ::1) in Output and to
// the primary address of the incoming interface in Prerouting.
switch hook {
case Output:
if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
- rt.Addr = tcpip.Address([]byte{127, 0, 0, 1})
+ address = tcpip.Address([]byte{127, 0, 0, 1})
} else {
- rt.Addr = header.IPv6Loopback
+ address = header.IPv6Loopback
}
case Prerouting:
- rt.Addr = address
+ // No-op, as address is already set correctly.
default:
panic("redirect target is supported only on output and prerouting hooks")
}
@@ -189,21 +146,18 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs
// Calculate UDP checksum and set it.
if hook == Output {
udpHeader.SetChecksum(0)
+ netHeader := pkt.Network()
+ netHeader.SetDestinationAddress(address)
// Only calculate the checksum if offloading isn't supported.
- if r.Capabilities()&CapabilityTXChecksumOffload == 0 {
+ if r.RequiresTXTransportChecksum() {
length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
- xsum := r.PseudoHeaderChecksum(protocol, length)
- for _, v := range pkt.Data.Views() {
- xsum = header.Checksum(v, xsum)
- }
- udpHeader.SetChecksum(0)
+ xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
+ xsum = header.ChecksumVV(pkt.Data, xsum)
udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
}
}
- pkt.Network().SetDestinationAddress(rt.Addr)
-
// After modification, IPv4 packets need a valid checksum.
if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
netHeader := header.IPv4(pkt.NetworkHeader().View())
@@ -219,7 +173,7 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs
// Set up conection for matching NAT rule. Only the first
// packet of the connection comes here. Other packets will be
// manipulated in connection tracking.
- if conn := ct.insertRedirectConn(pkt, hook, rt); conn != nil {
+ if conn := ct.insertRedirectConn(pkt, hook, rt.Port, address); conn != nil {
ct.handlePacket(pkt, hook, gso, r)
}
default:
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index 7b3f3e88b..4b86c1be9 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -37,7 +37,6 @@ import (
// ----->[Prerouting]----->routing----->[Forward]---------[Postrouting]----->
type Hook uint
-// These values correspond to values in include/uapi/linux/netfilter.h.
const (
// Prerouting happens before a packet is routed to applications or to
// be forwarded.
@@ -86,8 +85,8 @@ type IPTables struct {
mu sync.RWMutex
// v4Tables and v6tables map tableIDs to tables. They hold builtin
// tables only, not user tables. mu must be locked for accessing.
- v4Tables [numTables]Table
- v6Tables [numTables]Table
+ v4Tables [NumTables]Table
+ v6Tables [NumTables]Table
// modified is whether tables have been modified at least once. It is
// used to elide the iptables performance overhead for workloads that
// don't utilize iptables.
@@ -96,7 +95,7 @@ type IPTables struct {
// priorities maps each hook to a list of table names. The order of the
// list is the order in which each table should be visited for that
// hook. It is immutable.
- priorities [NumHooks][]tableID
+ priorities [NumHooks][]TableID
connections ConnTrack
@@ -104,6 +103,24 @@ type IPTables struct {
reaperDone chan struct{}
}
+// VisitTargets traverses all the targets of all tables and replaces each with
+// transform(target).
+func (it *IPTables) VisitTargets(transform func(Target) Target) {
+ it.mu.Lock()
+ defer it.mu.Unlock()
+
+ for tid := range it.v4Tables {
+ for i, rule := range it.v4Tables[tid].Rules {
+ it.v4Tables[tid].Rules[i].Target = transform(rule.Target)
+ }
+ }
+ for tid := range it.v6Tables {
+ for i, rule := range it.v6Tables[tid].Rules {
+ it.v6Tables[tid].Rules[i].Target = transform(rule.Target)
+ }
+ }
+}
+
// A Table defines a set of chains and hooks into the network stack.
//
// It is a list of Rules, entry points (BuiltinChains), and error handlers
@@ -169,7 +186,6 @@ type IPHeaderFilter struct {
// CheckProtocol determines whether the Protocol field should be
// checked during matching.
- // TODO(gvisor.dev/issue/3549): Check this field during matching.
CheckProtocol bool
// Dst matches the destination IP address.
@@ -309,23 +325,8 @@ type Matcher interface {
Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool)
}
-// A TargetID uniquely identifies a target.
-type TargetID struct {
- // Name is the target name as stored in the xt_entry_target struct.
- Name string
-
- // NetworkProtocol is the protocol to which the target applies.
- NetworkProtocol tcpip.NetworkProtocolNumber
-
- // Revision is the version of the target.
- Revision uint8
-}
-
// A Target is the interface for taking an action for a packet.
type Target interface {
- // ID uniquely identifies the Target.
- ID() TargetID
-
// Action takes an action on the packet and returns a verdict on how
// traversal should (or should not) continue. If the return value is
// Jump, it also returns the index of the rule to jump to.
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go
index 6f73a0ce4..c9b13cd0e 100644
--- a/pkg/tcpip/stack/linkaddrcache.go
+++ b/pkg/tcpip/stack/linkaddrcache.go
@@ -180,7 +180,7 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEnt
}
// get reports any known link address for k.
-func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
+func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
if linkRes != nil {
if addr, ok := linkRes.ResolveStaticAddress(k.Addr); ok {
return addr, nil, nil
@@ -221,7 +221,7 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo
}
entry.done = make(chan struct{})
- go c.startAddressResolution(k, linkRes, localAddr, linkEP, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously.
+ go c.startAddressResolution(k, linkRes, localAddr, nic, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously.
}
return entry.linkAddr, entry.done, tcpip.ErrWouldBlock
@@ -240,11 +240,11 @@ func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) {
}
}
-func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, done <-chan struct{}) {
+func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) {
for i := 0; ; i++ {
// Send link request, then wait for the timeout limit and check
// whether the request succeeded.
- linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, linkEP)
+ linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, nic)
select {
case now := <-time.After(c.resolutionTimeout):
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
index 33806340e..d2e37f38d 100644
--- a/pkg/tcpip/stack/linkaddrcache_test.go
+++ b/pkg/tcpip/stack/linkaddrcache_test.go
@@ -49,8 +49,8 @@ type testLinkAddressResolver struct {
onLinkAddressRequest func()
}
-func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error {
- time.AfterFunc(r.delay, func() { r.fakeRequest(addr) })
+func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error {
+ time.AfterFunc(r.delay, func() { r.fakeRequest(targetAddr) })
if f := r.onLinkAddressRequest; f != nil {
f()
}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 73a01c2dd..03d7b4e0d 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -26,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
@@ -352,7 +353,7 @@ func TestDADDisabled(t *testing.T) {
}
// We should not have sent any NDP NS messages.
- if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != 0 {
+ if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got != 0 {
t.Fatalf("got NeighborSolicit = %d, want = 0", got)
}
}
@@ -465,14 +466,18 @@ func TestDADResolve(t *testing.T) {
if err != tcpip.ErrNoRoute {
t.Errorf("got FindRoute(%d, '', %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute)
}
- r.Release()
+ if r != nil {
+ r.Release()
+ }
}
{
r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false)
if err != tcpip.ErrNoRoute {
t.Errorf("got FindRoute(%d, %s, %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr1, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute)
}
- r.Release()
+ if r != nil {
+ r.Release()
+ }
}
if t.Failed() {
@@ -510,7 +515,9 @@ func TestDADResolve(t *testing.T) {
} else if r.LocalAddress != addr1 {
t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, addr1)
}
- r.Release()
+ if r != nil {
+ r.Release()
+ }
}
if t.Failed() {
@@ -518,7 +525,7 @@ func TestDADResolve(t *testing.T) {
}
// Should not have sent any more NS messages.
- if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != uint64(test.dupAddrDetectTransmits) {
+ if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got != uint64(test.dupAddrDetectTransmits) {
t.Fatalf("got NeighborSolicit = %d, want = %d", got, test.dupAddrDetectTransmits)
}
@@ -533,8 +540,8 @@ func TestDADResolve(t *testing.T) {
// Make sure the right remote link address is used.
snmc := header.SolicitedNodeAddr(addr1)
- if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want {
- t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
+ if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(snmc); got != want {
+ t.Errorf("got remote link address = %s, want = %s", got, want)
}
// Check NDP NS packet.
@@ -563,18 +570,18 @@ func rxNDPSolicit(e *channel.Endpoint, tgt tcpip.Address) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize))
pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns := header.NDPNeighborSolicit(pkt.MessageBody())
ns.SetTargetAddress(tgt)
snmc := header.SolicitedNodeAddr(tgt)
pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{}))
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: 255,
- SrcAddr: header.IPv6Any,
- DstAddr: snmc,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: 255,
+ SrcAddr: header.IPv6Any,
+ DstAddr: snmc,
})
e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()}))
}
@@ -605,7 +612,7 @@ func TestDADFail(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize)
pkt := header.ICMPv6(hdr.Prepend(naSize))
pkt.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(pkt.NDPPayload())
+ na := header.NDPNeighborAdvert(pkt.MessageBody())
na.SetSolicitedFlag(true)
na.SetOverrideFlag(true)
na.SetTargetAddress(tgt)
@@ -616,11 +623,11 @@ func TestDADFail(t *testing.T) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: 255,
- SrcAddr: tgt,
- DstAddr: header.IPv6AllNodesMulticastAddress,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: 255,
+ SrcAddr: tgt,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
})
e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()}))
},
@@ -666,7 +673,7 @@ func TestDADFail(t *testing.T) {
// Receive a packet to simulate an address conflict.
test.rxPkt(e, addr1)
- stat := test.getStat(s.Stats().ICMP.V6PacketsReceived)
+ stat := test.getStat(s.Stats().ICMP.V6.PacketsReceived)
if got := stat.Value(); got != 1 {
t.Fatalf("got stat = %d, want = 1", got)
}
@@ -803,7 +810,7 @@ func TestDADStop(t *testing.T) {
}
// Should not have sent more than 1 NS message.
- if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 {
+ if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got > 1 {
t.Errorf("got NeighborSolicit = %d, want <= 1", got)
}
})
@@ -982,7 +989,7 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo
pkt := header.ICMPv6(hdr.Prepend(icmpSize))
pkt.SetType(header.ICMPv6RouterAdvert)
pkt.SetCode(0)
- raPayload := pkt.NDPPayload()
+ raPayload := pkt.MessageBody()
ra := header.NDPRouterAdvert(raPayload)
// Populate the Router Lifetime.
binary.BigEndian.PutUint16(raPayload[2:], rl)
@@ -1004,11 +1011,11 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo
payloadLength := hdr.UsedLength()
iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
iph.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: header.NDPHopLimit,
- SrcAddr: ip,
- DstAddr: header.IPv6AllNodesMulticastAddress,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: ip,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
})
return stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -2162,8 +2169,8 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) {
NDPConfigs: ipv6.NDPConfigurations{
AutoGenTempGlobalAddresses: true,
},
- NDPDisp: &ndpDisp,
- AutoGenIPv6LinkLocal: true,
+ NDPDisp: &ndpDisp,
+ AutoGenLinkLocal: true,
})},
})
@@ -2843,9 +2850,7 @@ func addrForNewConnectionTo(t *testing.T, s *stack.Stack, addr tcpip.FullAddress
t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
}
defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
- }
+ ep.SocketOptions().SetV6Only(true)
if err := ep.Connect(addr); err != nil {
t.Fatalf("ep.Connect(%+v): %s", addr, err)
}
@@ -2879,9 +2884,7 @@ func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullA
t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
}
defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
- }
+ ep.SocketOptions().SetV6Only(true)
if err := ep.Bind(addr); err != nil {
t.Fatalf("ep.Bind(%+v): %s", addr, err)
}
@@ -3250,9 +3253,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
}
defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
- }
+ ep.SocketOptions().SetV6Only(true)
if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute {
t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, tcpip.ErrNoRoute)
@@ -4044,9 +4045,9 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
ndpConfigs.AutoGenAddressConflictRetries = maxRetries
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal,
- NDPConfigs: ndpConfigs,
- NDPDisp: &ndpDisp,
+ AutoGenLinkLocal: addrType.autoGenLinkLocal,
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{
NICNameFromID: func(_ tcpip.NICID, nicName string) string {
return nicName
@@ -4179,9 +4180,9 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal,
- NDPConfigs: addrType.ndpConfigs,
- NDPDisp: &ndpDisp,
+ AutoGenLinkLocal: addrType.autoGenLinkLocal,
+ NDPConfigs: addrType.ndpConfigs,
+ NDPDisp: &ndpDisp,
})},
})
if err := s.CreateNIC(nicID, e); err != nil {
@@ -4708,7 +4709,7 @@ func TestCleanupNDPState(t *testing.T) {
}
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- AutoGenIPv6LinkLocal: true,
+ AutoGenLinkLocal: true,
NDPConfigs: ipv6.NDPConfigurations{
HandleRAs: true,
DiscoverDefaultRouters: true,
@@ -5174,113 +5175,99 @@ func TestRouterSolicitation(t *testing.T) {
},
}
- // This Run will not return until the parallel tests finish.
- //
- // We need this because we need to do some teardown work after the
- // parallel tests complete.
- //
- // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
- // more details.
- t.Run("group", func(t *testing.T) {
- for _, test := range tests {
- test := test
-
- t.Run(test.name, func(t *testing.T) {
- t.Parallel()
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
+ e := channelLinkWithHeaderLength{
+ Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr),
+ headerLength: test.linkHeaderLen,
+ }
+ e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ waitForPkt := func(timeout time.Duration) {
+ t.Helper()
- e := channelLinkWithHeaderLength{
- Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr),
- headerLength: test.linkHeaderLen,
+ clock.Advance(timeout)
+ p, ok := e.Read()
+ if !ok {
+ t.Fatal("expected router solicitation packet")
}
- e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- waitForPkt := func(timeout time.Duration) {
- t.Helper()
- ctx, cancel := context.WithTimeout(context.Background(), timeout)
- defer cancel()
- p, ok := e.ReadContext(ctx)
- if !ok {
- t.Fatal("timed out waiting for packet")
- return
- }
- if p.Proto != header.IPv6ProtocolNumber {
- t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
- }
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
+ }
- // Make sure the right remote link address is used.
- if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want {
- t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
- }
+ // Make sure the right remote link address is used.
+ if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); got != want {
+ t.Errorf("got remote link address = %s, want = %s", got, want)
+ }
- checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
- checker.SrcAddr(test.expectedSrcAddr),
- checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
- checker.TTL(header.NDPHopLimit),
- checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)),
- )
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(test.expectedSrcAddr),
+ checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)),
+ )
- if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want {
- t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want)
- }
- }
- waitForNothing := func(timeout time.Duration) {
- t.Helper()
- ctx, cancel := context.WithTimeout(context.Background(), timeout)
- defer cancel()
- if _, ok := e.ReadContext(ctx); ok {
- t.Fatal("unexpectedly got a packet")
- }
- }
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- MaxRtrSolicitations: test.maxRtrSolicit,
- RtrSolicitationInterval: test.rtrSolicitInt,
- MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
- },
- })},
- })
- if err := s.CreateNIC(nicID, &e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want {
+ t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want)
}
+ }
+ waitForNothing := func(timeout time.Duration) {
+ t.Helper()
- if addr := test.nicAddr; addr != "" {
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err)
- }
+ clock.Advance(timeout)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("unexpectedly got a packet = %#v", p)
}
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ MaxRtrSolicitations: test.maxRtrSolicit,
+ RtrSolicitationInterval: test.rtrSolicitInt,
+ MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
+ },
+ })},
+ Clock: clock,
+ })
+ if err := s.CreateNIC(nicID, &e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
- // Make sure each RS is sent at the right time.
- remaining := test.maxRtrSolicit
- if remaining > 0 {
- waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncPositiveEventTimeout)
- remaining--
+ if addr := test.nicAddr; addr != "" {
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err)
}
+ }
- for ; remaining > 0; remaining-- {
- if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout {
- waitForNothing(test.effectiveRtrSolicitInt - defaultAsyncNegativeEventTimeout)
- waitForPkt(defaultAsyncPositiveEventTimeout)
- } else {
- waitForPkt(test.effectiveRtrSolicitInt + defaultAsyncPositiveEventTimeout)
- }
- }
+ // Make sure each RS is sent at the right time.
+ remaining := test.maxRtrSolicit
+ if remaining > 0 {
+ waitForPkt(test.effectiveMaxRtrSolicitDelay)
+ remaining--
+ }
- // Make sure no more RS.
- if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay {
- waitForNothing(test.effectiveRtrSolicitInt + defaultAsyncNegativeEventTimeout)
+ for ; remaining > 0; remaining-- {
+ if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout {
+ waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond)
+ waitForPkt(time.Nanosecond)
} else {
- waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultAsyncNegativeEventTimeout)
+ waitForPkt(test.effectiveRtrSolicitInt)
}
+ }
- // Make sure the counter got properly
- // incremented.
- if got, want := s.Stats().ICMP.V6PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want {
- t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want)
- }
- })
- }
- })
+ // Make sure no more RS.
+ if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay {
+ waitForNothing(test.effectiveRtrSolicitInt)
+ } else {
+ waitForNothing(test.effectiveMaxRtrSolicitDelay)
+ }
+
+ if got, want := s.Stats().ICMP.V6.PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want {
+ t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want)
+ }
+ })
+ }
}
func TestStopStartSolicitingRouters(t *testing.T) {
diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go
index 4df288798..317f6871d 100644
--- a/pkg/tcpip/stack/neighbor_cache.go
+++ b/pkg/tcpip/stack/neighbor_cache.go
@@ -16,7 +16,6 @@ package stack
import (
"fmt"
- "time"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
@@ -25,9 +24,16 @@ import (
const neighborCacheSize = 512 // max entries per interface
+// NeighborStats holds metrics for the neighbor table.
+type NeighborStats struct {
+ // FailedEntryLookups counts the number of lookups performed on an entry in
+ // Failed state.
+ FailedEntryLookups *tcpip.StatCounter
+}
+
// neighborCache maps IP addresses to link addresses. It uses the Least
// Recently Used (LRU) eviction strategy to implement a bounded cache for
-// dynmically acquired entries. It contains the state machine and configuration
+// dynamically acquired entries. It contains the state machine and configuration
// for running Neighbor Unreachability Detection (NUD).
//
// There are two types of entries in the neighbor cache:
@@ -68,7 +74,7 @@ var _ NUDHandler = (*neighborCache)(nil)
// reset to state incomplete, and returned. If no matching entry exists and the
// cache is not full, a new entry with state incomplete is allocated and
// returned.
-func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver) *neighborEntry {
+func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkAddressResolver) *neighborEntry {
n.mu.Lock()
defer n.mu.Unlock()
@@ -84,7 +90,7 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, li
// The entry that needs to be created must be dynamic since all static
// entries are directly added to the cache via addStaticEntry.
- entry := newNeighborEntry(n.nic, remoteAddr, localAddr, n.state, linkRes)
+ entry := newNeighborEntry(n.nic, remoteAddr, n.state, linkRes)
if n.dynamic.count == neighborCacheSize {
e := n.dynamic.lru.Back()
e.mu.Lock()
@@ -111,28 +117,31 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, li
// provided, it will be notified when address resolution is complete (success
// or not).
//
+// If specified, the local address must be an address local to the interface the
+// neighbor cache belongs to. The local address is the source address of a
+// packet prompting NUD/link address resolution.
+//
// If address resolution is required, ErrNoLinkAddress and a notification
// channel is returned for the top level caller to block. Channel is closed
// once address resolution is complete (success or not).
func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, w *sleep.Waker) (NeighborEntry, <-chan struct{}, *tcpip.Error) {
if linkAddr, ok := linkRes.ResolveStaticAddress(remoteAddr); ok {
e := NeighborEntry{
- Addr: remoteAddr,
- LocalAddr: localAddr,
- LinkAddr: linkAddr,
- State: Static,
- UpdatedAt: time.Now(),
+ Addr: remoteAddr,
+ LinkAddr: linkAddr,
+ State: Static,
+ UpdatedAtNanos: 0,
}
return e, nil, nil
}
- entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes)
+ entry := n.getOrCreateEntry(remoteAddr, linkRes)
entry.mu.Lock()
defer entry.mu.Unlock()
switch s := entry.neigh.State; s {
case Stale:
- entry.handlePacketQueuedLocked()
+ entry.handlePacketQueuedLocked(localAddr)
fallthrough
case Reachable, Static, Delay, Probe:
// As per RFC 4861 section 7.3.3:
@@ -152,7 +161,7 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA
entry.done = make(chan struct{})
}
- entry.handlePacketQueuedLocked()
+ entry.handlePacketQueuedLocked(localAddr)
return entry.neigh, entry.done, tcpip.ErrWouldBlock
case Failed:
return entry.neigh, nil, tcpip.ErrNoLinkAddress
@@ -173,14 +182,15 @@ func (n *neighborCache) removeWaker(addr tcpip.Address, waker *sleep.Waker) {
// entries returns all entries in the neighbor cache.
func (n *neighborCache) entries() []NeighborEntry {
- entries := make([]NeighborEntry, 0, len(n.cache))
n.mu.RLock()
+ defer n.mu.RUnlock()
+
+ entries := make([]NeighborEntry, 0, len(n.cache))
for _, entry := range n.cache {
entry.mu.RLock()
entries = append(entries, entry.neigh)
entry.mu.RUnlock()
}
- n.mu.RUnlock()
return entries
}
@@ -207,7 +217,7 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd
} else {
// Static entry found with the same address but different link address.
entry.neigh.LinkAddr = linkAddr
- entry.dispatchChangeEventLocked(entry.neigh.State)
+ entry.dispatchChangeEventLocked()
entry.mu.Unlock()
return
}
@@ -220,11 +230,12 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd
entry.mu.Unlock()
}
- entry := newStaticNeighborEntry(n.nic, addr, linkAddr, n.state)
- n.cache[addr] = entry
+ n.cache[addr] = newStaticNeighborEntry(n.nic, addr, linkAddr, n.state)
}
// removeEntryLocked removes the specified entry from the neighbor cache.
+//
+// Prerequisite: n.mu and entry.mu MUST be locked.
func (n *neighborCache) removeEntryLocked(entry *neighborEntry) {
if entry.neigh.State != Static {
n.dynamic.lru.Remove(entry)
@@ -292,8 +303,8 @@ func (n *neighborCache) setConfig(config NUDConfigurations) {
// HandleProbe implements NUDHandler.HandleProbe by following the logic defined
// in RFC 4861 section 7.2.3. Validation of the probe is expected to be handled
// by the caller.
-func (n *neighborCache) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) {
- entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes)
+func (n *neighborCache) HandleProbe(remoteAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) {
+ entry := n.getOrCreateEntry(remoteAddr, linkRes)
entry.mu.Lock()
entry.handleProbeLocked(remoteLinkAddr)
entry.mu.Unlock()
diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go
index fcd54ed83..732a299f7 100644
--- a/pkg/tcpip/stack/neighbor_cache_test.go
+++ b/pkg/tcpip/stack/neighbor_cache_test.go
@@ -61,39 +61,39 @@ const (
)
// entryDiffOpts returns the options passed to cmp.Diff to compare neighbor
-// entries. The UpdatedAt field is ignored due to a lack of a deterministic
-// method to predict the time that an event will be dispatched.
+// entries. The UpdatedAtNanos field is ignored due to a lack of a
+// deterministic method to predict the time that an event will be dispatched.
func entryDiffOpts() []cmp.Option {
return []cmp.Option{
- cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAt"),
+ cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAtNanos"),
}
}
// entryDiffOptsWithSort is like entryDiffOpts but also includes an option to
// sort slices of entries for cases where ordering must be ignored.
func entryDiffOptsWithSort() []cmp.Option {
- return []cmp.Option{
- cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAt"),
- cmpopts.SortSlices(func(a, b NeighborEntry) bool {
- return strings.Compare(string(a.Addr), string(b.Addr)) < 0
- }),
- }
+ return append(entryDiffOpts(), cmpopts.SortSlices(func(a, b NeighborEntry) bool {
+ return strings.Compare(string(a.Addr), string(b.Addr)) < 0
+ }))
}
func newTestNeighborCache(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *neighborCache {
config.resetInvalidFields()
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
- return &neighborCache{
+ neigh := &neighborCache{
nic: &NIC{
stack: &Stack{
clock: clock,
nudDisp: nudDisp,
},
- id: 1,
+ id: 1,
+ stats: makeNICStats(),
},
state: NewNUDState(config, rng),
cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize),
}
+ neigh.nic.neigh = neigh
+ return neigh
}
// testEntryStore contains a set of IP to NeighborEntry mappings.
@@ -128,9 +128,8 @@ func newTestEntryStore() *testEntryStore {
linkAddr := toLinkAddress(i)
store.entriesMap[addr] = NeighborEntry{
- Addr: addr,
- LocalAddr: testEntryLocalAddr,
- LinkAddr: linkAddr,
+ Addr: addr,
+ LinkAddr: linkAddr,
}
}
return store
@@ -195,10 +194,10 @@ type testNeighborResolver struct {
var _ LinkAddressResolver = (*testNeighborResolver)(nil)
-func (r *testNeighborResolver) LinkAddressRequest(addr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error {
+func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error {
// Delay handling the request to emulate network latency.
r.clock.AfterFunc(r.delay, func() {
- r.fakeRequest(addr)
+ r.fakeRequest(targetAddr)
})
// Execute post address resolution action, if available.
@@ -294,9 +293,8 @@ func TestNeighborCacheEntry(t *testing.T) {
if !ok {
t.Fatalf("store.entry(0) not found")
}
- _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
@@ -305,15 +303,19 @@ func TestNeighborCacheEntry(t *testing.T) {
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -324,8 +326,8 @@ func TestNeighborCacheEntry(t *testing.T) {
t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
}
- if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != nil {
- t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil {
+ t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
}
// No more events should have been dispatched.
@@ -354,9 +356,9 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
if !ok {
t.Fatalf("store.entry(0) not found")
}
- _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
@@ -365,15 +367,19 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -391,9 +397,11 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
{
EventType: entryTestRemoved,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -404,8 +412,8 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
}
}
- if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
}
@@ -452,8 +460,8 @@ func (c *testContext) overflowCache(opts overflowOptions) error {
if !ok {
return fmt.Errorf("c.store.entry(%d) not found", i)
}
- if _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil); err != tcpip.ErrWouldBlock {
- return fmt.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
+ return fmt.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
c.clock.Advance(c.neigh.config().RetransmitTimer)
@@ -470,23 +478,29 @@ func (c *testContext) overflowCache(opts overflowOptions) error {
wantEvents = append(wantEvents, testEntryEventInfo{
EventType: entryTestRemoved,
NICID: 1,
- Addr: removedEntry.Addr,
- LinkAddr: removedEntry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: removedEntry.Addr,
+ LinkAddr: removedEntry.LinkAddr,
+ State: Reachable,
+ },
})
}
wantEvents = append(wantEvents, testEntryEventInfo{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
}, testEntryEventInfo{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
})
c.nudDisp.mu.Lock()
@@ -508,10 +522,9 @@ func (c *testContext) overflowCache(opts overflowOptions) error {
return fmt.Errorf("c.store.entry(%d) not found", i)
}
wantEntry := NeighborEntry{
- Addr: entry.Addr,
- LocalAddr: entry.LocalAddr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
}
wantUnsortedEntries = append(wantUnsortedEntries, wantEntry)
}
@@ -564,24 +577,27 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) {
if !ok {
t.Fatalf("c.store.entry(0) not found")
}
- _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil)
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
c.clock.Advance(c.neigh.config().RetransmitTimer)
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -600,9 +616,11 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) {
{
EventType: entryTestRemoved,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -640,9 +658,11 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) {
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -682,9 +702,11 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T)
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -703,9 +725,11 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T)
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -740,9 +764,11 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) {
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -760,9 +786,11 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) {
{
EventType: entryTestRemoved,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -800,24 +828,27 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
if !ok {
t.Fatalf("c.store.entry(0) not found")
}
- _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil)
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
c.clock.Advance(typicalLatency)
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -836,16 +867,20 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
{
EventType: entryTestRemoved,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -861,10 +896,9 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
startAtEntryIndex: 1,
wantStaticEntries: []NeighborEntry{
{
- Addr: entry.Addr,
- LocalAddr: "", // static entries don't need a local address
- LinkAddr: staticLinkAddr,
- State: Static,
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
},
},
}
@@ -896,12 +930,12 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) {
if !ok {
t.Fatalf("store.entry(0) not found")
}
- _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, &w)
+ _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, &w)
if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, %s, _, _ = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ t.Fatalf("got neigh.entry(%s, '', _, _ = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
if doneCh == nil {
- t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr)
+ t.Fatalf("expected done channel from neigh.entry(%s, '', _, _)", entry.Addr)
}
clock.Advance(typicalLatency)
@@ -913,7 +947,7 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) {
id, ok := s.Fetch(false /* block */)
if !ok {
- t.Errorf("expected waker to be notified after neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr)
+ t.Errorf("expected waker to be notified after neigh.entry(%s, '', _, _)", entry.Addr)
}
if id != wakerID {
t.Errorf("got s.Fetch(false) = %d, want = %d", id, wakerID)
@@ -923,15 +957,19 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) {
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -964,12 +1002,12 @@ func TestNeighborCacheRemoveWaker(t *testing.T) {
if !ok {
t.Fatalf("store.entry(0) not found")
}
- _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, &w)
+ _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, &w)
if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, %s, _, _) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ t.Fatalf("got neigh.entry(%s, '', _, _) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
if doneCh == nil {
- t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr)
+ t.Fatalf("expected done channel from neigh.entry(%s, '', _, _)", entry.Addr)
}
// Remove the waker before the neighbor cache has the opportunity to send a
@@ -991,15 +1029,19 @@ func TestNeighborCacheRemoveWaker(t *testing.T) {
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -1028,10 +1070,9 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil): %s", entry.Addr, err)
}
want := NeighborEntry{
- Addr: entry.Addr,
- LocalAddr: "", // static entries don't need a local address
- LinkAddr: entry.LinkAddr,
- State: Static,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Static,
}
if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
t.Errorf("c.neigh.entry(%s, \"\", _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
@@ -1041,9 +1082,11 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Static,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Static,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -1058,10 +1101,9 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
startAtEntryIndex: 1,
wantStaticEntries: []NeighborEntry{
{
- Addr: entry.Addr,
- LocalAddr: "", // static entries don't need a local address
- LinkAddr: entry.LinkAddr,
- State: Static,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Static,
},
},
}
@@ -1089,9 +1131,8 @@ func TestNeighborCacheClear(t *testing.T) {
if !ok {
t.Fatalf("store.entry(0) not found")
}
- _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
@@ -1099,15 +1140,19 @@ func TestNeighborCacheClear(t *testing.T) {
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -1126,9 +1171,11 @@ func TestNeighborCacheClear(t *testing.T) {
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Static,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Static,
+ },
},
}
nudDisp.mu.Lock()
@@ -1149,16 +1196,20 @@ func TestNeighborCacheClear(t *testing.T) {
{
EventType: entryTestRemoved,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
{
EventType: entryTestRemoved,
NICID: 1,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Static,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Static,
+ },
},
}
nudDisp.mu.Lock()
@@ -1185,24 +1236,27 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) {
if !ok {
t.Fatalf("c.store.entry(0) not found")
}
- _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil)
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
c.clock.Advance(typicalLatency)
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -1220,9 +1274,11 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) {
{
EventType: entryTestRemoved,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -1274,29 +1330,33 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
if !ok {
t.Fatalf("store.entry(%d) not found", i)
}
- _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
select {
case <-doneCh:
default:
- t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr)
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr)
}
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -1312,9 +1372,8 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
for i := neighborCacheSize; i < store.size(); i++ {
// Periodically refresh the frequently used entry
if i%(neighborCacheSize/2) == 0 {
- _, _, err := neigh.entry(frequentlyUsedEntry.Addr, frequentlyUsedEntry.LocalAddr, linkRes, nil)
- if err != nil {
- t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", frequentlyUsedEntry.Addr, frequentlyUsedEntry.LocalAddr, err)
+ if _, _, err := neigh.entry(frequentlyUsedEntry.Addr, "", linkRes, nil); err != nil {
+ t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", frequentlyUsedEntry.Addr, err)
}
}
@@ -1322,15 +1381,15 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
if !ok {
t.Fatalf("store.entry(%d) not found", i)
}
- _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
select {
case <-doneCh:
default:
- t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr)
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr)
}
// An entry should have been removed, as per the LRU eviction strategy
@@ -1342,22 +1401,28 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
{
EventType: entryTestRemoved,
NICID: 1,
- Addr: removedEntry.Addr,
- LinkAddr: removedEntry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: removedEntry.Addr,
+ LinkAddr: removedEntry.LinkAddr,
+ State: Reachable,
+ },
},
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -1374,10 +1439,9 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
// have to be sorted before comparison.
wantUnsortedEntries := []NeighborEntry{
{
- Addr: frequentlyUsedEntry.Addr,
- LocalAddr: frequentlyUsedEntry.LocalAddr,
- LinkAddr: frequentlyUsedEntry.LinkAddr,
- State: Reachable,
+ Addr: frequentlyUsedEntry.Addr,
+ LinkAddr: frequentlyUsedEntry.LinkAddr,
+ State: Reachable,
},
}
@@ -1387,10 +1451,9 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
t.Fatalf("store.entry(%d) not found", i)
}
wantEntry := NeighborEntry{
- Addr: entry.Addr,
- LocalAddr: entry.LocalAddr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
}
wantUnsortedEntries = append(wantUnsortedEntries, wantEntry)
}
@@ -1430,9 +1493,8 @@ func TestNeighborCacheConcurrent(t *testing.T) {
wg.Add(1)
go func(entry NeighborEntry) {
defer wg.Done()
- e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
- if err != nil && err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, %s, _, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, entry.LocalAddr, e, err, tcpip.ErrWouldBlock)
+ if e, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil && err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, '', _, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, tcpip.ErrWouldBlock)
}
}(entry)
}
@@ -1456,10 +1518,9 @@ func TestNeighborCacheConcurrent(t *testing.T) {
t.Errorf("store.entry(%d) not found", i)
}
wantEntry := NeighborEntry{
- Addr: entry.Addr,
- LocalAddr: entry.LocalAddr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
}
wantUnsortedEntries = append(wantUnsortedEntries, wantEntry)
}
@@ -1488,37 +1549,36 @@ func TestNeighborCacheReplace(t *testing.T) {
if !ok {
t.Fatalf("store.entry(0) not found")
}
- _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
select {
case <-doneCh:
default:
- t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr)
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr)
}
// Verify the entry exists
{
- e, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ e, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != nil {
- t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
}
if doneCh != nil {
- t.Errorf("unexpected done channel from neigh.entry(%s, %s, _, nil): %v", entry.Addr, entry.LocalAddr, doneCh)
+ t.Errorf("unexpected done channel from neigh.entry(%s, '', _, nil): %v", entry.Addr, doneCh)
}
if t.Failed() {
t.FailNow()
}
want := NeighborEntry{
- Addr: entry.Addr,
- LocalAddr: entry.LocalAddr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
}
if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LinkAddr, diff)
+ t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
}
}
@@ -1542,37 +1602,35 @@ func TestNeighborCacheReplace(t *testing.T) {
//
// Verify the entry's new link address and the new state.
{
- e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ e, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != nil {
- t.Fatalf("neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ t.Fatalf("neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
}
want := NeighborEntry{
- Addr: entry.Addr,
- LocalAddr: entry.LocalAddr,
- LinkAddr: updatedLinkAddr,
- State: Delay,
+ Addr: entry.Addr,
+ LinkAddr: updatedLinkAddr,
+ State: Delay,
}
if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff)
+ t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
}
clock.Advance(config.DelayFirstProbeTime + typicalLatency)
}
// Verify that the neighbor is now reachable.
{
- e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ e, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
clock.Advance(typicalLatency)
if err != nil {
- t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
}
want := NeighborEntry{
- Addr: entry.Addr,
- LocalAddr: entry.LocalAddr,
- LinkAddr: updatedLinkAddr,
- State: Reachable,
+ Addr: entry.Addr,
+ LinkAddr: updatedLinkAddr,
+ State: Reachable,
}
if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff)
+ t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
}
}
}
@@ -1601,35 +1659,34 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
if !ok {
t.Fatalf("store.entry(0) not found")
}
- if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
- got, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ got, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != nil {
- t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
}
want := NeighborEntry{
- Addr: entry.Addr,
- LocalAddr: entry.LocalAddr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
}
if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff)
+ t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
}
// Verify that address resolution for an unknown address returns ErrNoLinkAddress
before := atomic.LoadUint32(&requestCount)
entry.Addr += "2"
- if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes)
clock.Advance(waitFor)
- if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress {
- t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress)
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrNoLinkAddress {
+ t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrNoLinkAddress)
}
maxAttempts := neigh.config().MaxUnicastProbes
@@ -1659,13 +1716,13 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) {
if !ok {
t.Fatalf("store.entry(0) not found")
}
- if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes)
clock.Advance(waitFor)
- if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress {
- t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress)
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrNoLinkAddress {
+ t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrNoLinkAddress)
}
}
@@ -1683,18 +1740,17 @@ func TestNeighborCacheStaticResolution(t *testing.T) {
delay: typicalLatency,
}
- got, _, err := neigh.entry(testEntryBroadcastAddr, testEntryLocalAddr, linkRes, nil)
+ got, _, err := neigh.entry(testEntryBroadcastAddr, "", linkRes, nil)
if err != nil {
- t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", testEntryBroadcastAddr, testEntryLocalAddr, err)
+ t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", testEntryBroadcastAddr, err)
}
want := NeighborEntry{
- Addr: testEntryBroadcastAddr,
- LocalAddr: testEntryLocalAddr,
- LinkAddr: testEntryBroadcastLinkAddr,
- State: Static,
+ Addr: testEntryBroadcastAddr,
+ LinkAddr: testEntryBroadcastLinkAddr,
+ State: Static,
}
if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, testEntryLocalAddr, diff)
+ t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, diff)
}
}
@@ -1719,9 +1775,9 @@ func BenchmarkCacheClear(b *testing.B) {
if !ok {
b.Fatalf("store.entry(%d) not found", i)
}
- _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != tcpip.ErrWouldBlock {
- b.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ b.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
if doneCh != nil {
<-doneCh
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
index be61a21af..32399b4f5 100644
--- a/pkg/tcpip/stack/neighbor_entry.go
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -24,13 +24,18 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
)
+const (
+ // immediateDuration is a duration of zero for scheduling work that needs to
+ // be done immediately but asynchronously to avoid deadlock.
+ immediateDuration time.Duration = 0
+)
+
// NeighborEntry describes a neighboring device in the local network.
type NeighborEntry struct {
- Addr tcpip.Address
- LocalAddr tcpip.Address
- LinkAddr tcpip.LinkAddress
- State NeighborState
- UpdatedAt time.Time
+ Addr tcpip.Address
+ LinkAddr tcpip.LinkAddress
+ State NeighborState
+ UpdatedAtNanos int64
}
// NeighborState defines the state of a NeighborEntry within the Neighbor
@@ -106,35 +111,35 @@ type neighborEntry struct {
// state, Unknown. Transition out of Unknown by calling either
// `handlePacketQueuedLocked` or `handleProbeLocked` on the newly created
// neighborEntry.
-func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, localAddr tcpip.Address, nudState *NUDState, linkRes LinkAddressResolver) *neighborEntry {
+func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, nudState *NUDState, linkRes LinkAddressResolver) *neighborEntry {
return &neighborEntry{
nic: nic,
linkRes: linkRes,
nudState: nudState,
neigh: NeighborEntry{
- Addr: remoteAddr,
- LocalAddr: localAddr,
- State: Unknown,
+ Addr: remoteAddr,
+ State: Unknown,
},
}
}
-// newStaticNeighborEntry creates a neighbor cache entry starting at the Static
-// state. The entry can only transition out of Static by directly calling
-// `setStateLocked`.
+// newStaticNeighborEntry creates a neighbor cache entry starting at the
+// Static state. The entry can only transition out of Static by directly
+// calling `setStateLocked`.
func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry {
+ entry := NeighborEntry{
+ Addr: addr,
+ LinkAddr: linkAddr,
+ State: Static,
+ UpdatedAtNanos: nic.stack.clock.NowNanoseconds(),
+ }
if nic.stack.nudDisp != nil {
- nic.stack.nudDisp.OnNeighborAdded(nic.id, addr, linkAddr, Static, time.Now())
+ nic.stack.nudDisp.OnNeighborAdded(nic.id, entry)
}
return &neighborEntry{
nic: nic,
nudState: state,
- neigh: NeighborEntry{
- Addr: addr,
- LinkAddr: linkAddr,
- State: Static,
- UpdatedAt: time.Now(),
- },
+ neigh: entry,
}
}
@@ -165,17 +170,17 @@ func (e *neighborEntry) notifyWakersLocked() {
// dispatchAddEventLocked signals to stack's NUD Dispatcher that the entry has
// been added.
-func (e *neighborEntry) dispatchAddEventLocked(nextState NeighborState) {
+func (e *neighborEntry) dispatchAddEventLocked() {
if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
- nudDisp.OnNeighborAdded(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, nextState, time.Now())
+ nudDisp.OnNeighborAdded(e.nic.id, e.neigh)
}
}
// dispatchChangeEventLocked signals to stack's NUD Dispatcher that the entry
// has changed state or link-layer address.
-func (e *neighborEntry) dispatchChangeEventLocked(nextState NeighborState) {
+func (e *neighborEntry) dispatchChangeEventLocked() {
if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
- nudDisp.OnNeighborChanged(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, nextState, time.Now())
+ nudDisp.OnNeighborChanged(e.nic.id, e.neigh)
}
}
@@ -183,7 +188,7 @@ func (e *neighborEntry) dispatchChangeEventLocked(nextState NeighborState) {
// has been removed.
func (e *neighborEntry) dispatchRemoveEventLocked() {
if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
- nudDisp.OnNeighborRemoved(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, e.neigh.State, time.Now())
+ nudDisp.OnNeighborRemoved(e.nic.id, e.neigh)
}
}
@@ -201,68 +206,24 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
prev := e.neigh.State
e.neigh.State = next
- e.neigh.UpdatedAt = time.Now()
+ e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds()
config := e.nudState.Config()
switch next {
case Incomplete:
- var retryCounter uint32
- var sendMulticastProbe func()
-
- sendMulticastProbe = func() {
- if retryCounter == config.MaxMulticastProbes {
- // "If no Neighbor Advertisement is received after
- // MAX_MULTICAST_SOLICIT solicitations, address resolution has failed.
- // The sender MUST return ICMP destination unreachable indications with
- // code 3 (Address Unreachable) for each packet queued awaiting address
- // resolution." - RFC 4861 section 7.2.2
- //
- // There is no need to send an ICMP destination unreachable indication
- // since the failure to resolve the address is expected to only occur
- // on this node. Thus, redirecting traffic is currently not supported.
- //
- // "If the error occurs on a node other than the node originating the
- // packet, an ICMP error message is generated. If the error occurs on
- // the originating node, an implementation is not required to actually
- // create and send an ICMP error packet to the source, as long as the
- // upper-layer sender is notified through an appropriate mechanism
- // (e.g. return value from a procedure call). Note, however, that an
- // implementation may find it convenient in some cases to return errors
- // to the sender by taking the offending packet, generating an ICMP
- // error message, and then delivering it (locally) through the generic
- // error-handling routines.' - RFC 4861 section 2.1
- e.dispatchRemoveEventLocked()
- e.setStateLocked(Failed)
- return
- }
-
- if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.LinkEndpoint); err != nil {
- // There is no need to log the error here; the NUD implementation may
- // assume a working link. A valid link should be the responsibility of
- // the NIC/stack.LinkEndpoint.
- e.dispatchRemoveEventLocked()
- e.setStateLocked(Failed)
- return
- }
-
- retryCounter++
- e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe)
- e.job.Schedule(config.RetransmitTimer)
- }
-
- sendMulticastProbe()
+ panic(fmt.Sprintf("should never transition to Incomplete with setStateLocked; neigh = %#v, prev state = %s", e.neigh, prev))
case Reachable:
e.job = e.nic.stack.newJob(&e.mu, func() {
- e.dispatchChangeEventLocked(Stale)
e.setStateLocked(Stale)
+ e.dispatchChangeEventLocked()
})
e.job.Schedule(e.nudState.ReachableTime())
case Delay:
e.job = e.nic.stack.newJob(&e.mu, func() {
- e.dispatchChangeEventLocked(Probe)
e.setStateLocked(Probe)
+ e.dispatchChangeEventLocked()
})
e.job.Schedule(config.DelayFirstProbeTime)
@@ -277,28 +238,27 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
return
}
- if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.LinkEndpoint); err != nil {
+ if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, "" /* localAddr */, e.neigh.LinkAddr, e.nic); err != nil {
e.dispatchRemoveEventLocked()
e.setStateLocked(Failed)
return
}
retryCounter++
- if retryCounter == config.MaxUnicastProbes {
- e.dispatchRemoveEventLocked()
- e.setStateLocked(Failed)
- return
- }
-
e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe)
e.job.Schedule(config.RetransmitTimer)
}
- sendUnicastProbe()
+ // Send a probe in another gorountine to free this thread of execution
+ // for finishing the state transition. This is necessary to avoid
+ // deadlock where sending and processing probes are done synchronously,
+ // such as loopback and integration tests.
+ e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe)
+ e.job.Schedule(immediateDuration)
case Failed:
e.notifyWakersLocked()
- e.job = e.nic.stack.newJob(&e.mu, func() {
+ e.job = e.nic.stack.newJob(&doubleLock{first: &e.nic.neigh.mu, second: &e.mu}, func() {
e.nic.neigh.removeEntryLocked(e)
})
e.job.Schedule(config.UnreachableTime)
@@ -315,19 +275,82 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
// being queued for outgoing transmission.
//
// Follows the logic defined in RFC 4861 section 7.3.3.
-func (e *neighborEntry) handlePacketQueuedLocked() {
+func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
switch e.neigh.State {
case Unknown:
- e.dispatchAddEventLocked(Incomplete)
- e.setStateLocked(Incomplete)
+ e.neigh.State = Incomplete
+ e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds()
+
+ e.dispatchAddEventLocked()
+
+ config := e.nudState.Config()
+
+ var retryCounter uint32
+ var sendMulticastProbe func()
+
+ sendMulticastProbe = func() {
+ if retryCounter == config.MaxMulticastProbes {
+ // "If no Neighbor Advertisement is received after
+ // MAX_MULTICAST_SOLICIT solicitations, address resolution has failed.
+ // The sender MUST return ICMP destination unreachable indications with
+ // code 3 (Address Unreachable) for each packet queued awaiting address
+ // resolution." - RFC 4861 section 7.2.2
+ //
+ // There is no need to send an ICMP destination unreachable indication
+ // since the failure to resolve the address is expected to only occur
+ // on this node. Thus, redirecting traffic is currently not supported.
+ //
+ // "If the error occurs on a node other than the node originating the
+ // packet, an ICMP error message is generated. If the error occurs on
+ // the originating node, an implementation is not required to actually
+ // create and send an ICMP error packet to the source, as long as the
+ // upper-layer sender is notified through an appropriate mechanism
+ // (e.g. return value from a procedure call). Note, however, that an
+ // implementation may find it convenient in some cases to return errors
+ // to the sender by taking the offending packet, generating an ICMP
+ // error message, and then delivering it (locally) through the generic
+ // error-handling routines.' - RFC 4861 section 2.1
+ e.dispatchRemoveEventLocked()
+ e.setStateLocked(Failed)
+ return
+ }
+
+ // As per RFC 4861 section 7.2.2:
+ //
+ // If the source address of the packet prompting the solicitation is the
+ // same as one of the addresses assigned to the outgoing interface, that
+ // address SHOULD be placed in the IP Source Address of the outgoing
+ // solicitation.
+ //
+ if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, localAddr, "", e.nic); err != nil {
+ // There is no need to log the error here; the NUD implementation may
+ // assume a working link. A valid link should be the responsibility of
+ // the NIC/stack.LinkEndpoint.
+ e.dispatchRemoveEventLocked()
+ e.setStateLocked(Failed)
+ return
+ }
+
+ retryCounter++
+ e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe)
+ e.job.Schedule(config.RetransmitTimer)
+ }
+
+ // Send a probe in another gorountine to free this thread of execution
+ // for finishing the state transition. This is necessary to avoid
+ // deadlock where sending and processing probes are done synchronously,
+ // such as loopback and integration tests.
+ e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe)
+ e.job.Schedule(immediateDuration)
case Stale:
- e.dispatchChangeEventLocked(Delay)
e.setStateLocked(Delay)
+ e.dispatchChangeEventLocked()
- case Incomplete, Reachable, Delay, Probe, Static, Failed:
+ case Incomplete, Reachable, Delay, Probe, Static:
// Do nothing
-
+ case Failed:
+ e.nic.stats.Neighbor.FailedEntryLookups.Increment()
default:
panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State))
}
@@ -345,21 +368,21 @@ func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) {
switch e.neigh.State {
case Unknown, Incomplete, Failed:
e.neigh.LinkAddr = remoteLinkAddr
- e.dispatchAddEventLocked(Stale)
e.setStateLocked(Stale)
e.notifyWakersLocked()
+ e.dispatchAddEventLocked()
case Reachable, Delay, Probe:
if e.neigh.LinkAddr != remoteLinkAddr {
e.neigh.LinkAddr = remoteLinkAddr
- e.dispatchChangeEventLocked(Stale)
e.setStateLocked(Stale)
+ e.dispatchChangeEventLocked()
}
case Stale:
if e.neigh.LinkAddr != remoteLinkAddr {
e.neigh.LinkAddr = remoteLinkAddr
- e.dispatchChangeEventLocked(Stale)
+ e.dispatchChangeEventLocked()
}
case Static:
@@ -393,12 +416,11 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
e.neigh.LinkAddr = linkAddr
if flags.Solicited {
- e.dispatchChangeEventLocked(Reachable)
e.setStateLocked(Reachable)
} else {
- e.dispatchChangeEventLocked(Stale)
e.setStateLocked(Stale)
}
+ e.dispatchChangeEventLocked()
e.isRouter = flags.IsRouter
e.notifyWakersLocked()
@@ -411,8 +433,8 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
if isLinkAddrDifferent {
if !flags.Override {
if e.neigh.State == Reachable {
- e.dispatchChangeEventLocked(Stale)
e.setStateLocked(Stale)
+ e.dispatchChangeEventLocked()
}
break
}
@@ -421,23 +443,24 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
if !flags.Solicited {
if e.neigh.State != Stale {
- e.dispatchChangeEventLocked(Stale)
e.setStateLocked(Stale)
+ e.dispatchChangeEventLocked()
} else {
// Notify the LinkAddr change, even though NUD state hasn't changed.
- e.dispatchChangeEventLocked(e.neigh.State)
+ e.dispatchChangeEventLocked()
}
break
}
}
if flags.Solicited && (flags.Override || !isLinkAddrDifferent) {
- if e.neigh.State != Reachable {
- e.dispatchChangeEventLocked(Reachable)
- }
+ wasReachable := e.neigh.State == Reachable
// Set state to Reachable again to refresh timers.
e.setStateLocked(Reachable)
e.notifyWakersLocked()
+ if !wasReachable {
+ e.dispatchChangeEventLocked()
+ }
}
if e.isRouter && !flags.IsRouter && header.IsV6UnicastAddress(e.neigh.Addr) {
@@ -475,11 +498,12 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
func (e *neighborEntry) handleUpperLevelConfirmationLocked() {
switch e.neigh.State {
case Reachable, Stale, Delay, Probe:
- if e.neigh.State != Reachable {
- e.dispatchChangeEventLocked(Reachable)
- // Set state to Reachable again to refresh timers.
- }
+ wasReachable := e.neigh.State == Reachable
+ // Set state to Reachable again to refresh timers.
e.setStateLocked(Reachable)
+ if !wasReachable {
+ e.dispatchChangeEventLocked()
+ }
case Unknown, Incomplete, Failed, Static:
// Do nothing
@@ -488,3 +512,23 @@ func (e *neighborEntry) handleUpperLevelConfirmationLocked() {
panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State))
}
}
+
+// doubleLock combines two locks into one while maintaining lock ordering.
+//
+// TODO(gvisor.dev/issue/4796): Remove this once subsequent traffic to a Failed
+// neighbor is allowed.
+type doubleLock struct {
+ first, second sync.Locker
+}
+
+// Lock locks both locks in order: first then second.
+func (l *doubleLock) Lock() {
+ l.first.Lock()
+ l.second.Lock()
+}
+
+// Unlock unlocks both locks in reverse order: second then first.
+func (l *doubleLock) Unlock() {
+ l.second.Unlock()
+ l.first.Unlock()
+}
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
index 3ee2a3b31..c497d3932 100644
--- a/pkg/tcpip/stack/neighbor_entry_test.go
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -47,24 +47,27 @@ const (
entryTestNetDefaultMTU = 65536
)
+// runImmediatelyScheduledJobs runs all jobs scheduled to run at the current
+// time.
+func runImmediatelyScheduledJobs(clock *faketime.ManualClock) {
+ clock.Advance(immediateDuration)
+}
+
// eventDiffOpts are the options passed to cmp.Diff to compare entry events.
-// The UpdatedAt field is ignored due to a lack of a deterministic method to
-// predict the time that an event will be dispatched.
+// The UpdatedAtNanos field is ignored due to a lack of a deterministic method
+// to predict the time that an event will be dispatched.
func eventDiffOpts() []cmp.Option {
return []cmp.Option{
- cmpopts.IgnoreFields(testEntryEventInfo{}, "UpdatedAt"),
+ cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAtNanos"),
}
}
// eventDiffOptsWithSort is like eventDiffOpts but also includes an option to
// sort slices of events for cases where ordering must be ignored.
func eventDiffOptsWithSort() []cmp.Option {
- return []cmp.Option{
- cmpopts.IgnoreFields(testEntryEventInfo{}, "UpdatedAt"),
- cmpopts.SortSlices(func(a, b testEntryEventInfo) bool {
- return strings.Compare(string(a.Addr), string(b.Addr)) < 0
- }),
- }
+ return append(eventDiffOpts(), cmpopts.SortSlices(func(a, b testEntryEventInfo) bool {
+ return strings.Compare(string(a.Entry.Addr), string(b.Entry.Addr)) < 0
+ }))
}
// The following unit tests exercise every state transition and verify its
@@ -86,7 +89,7 @@ func eventDiffOptsWithSort() []cmp.Option {
// | Stale | Reachable | Solicited confirmation w/o address | Notify wakers | Changed |
// | Stale | Stale | Override confirmation | Update LinkAddr | Changed |
// | Stale | Stale | Probe w/ different address | Update LinkAddr | Changed |
-// | Stale | Delay | Packet sent | | Changed |
+// | Stale | Delay | Packet queued | | Changed |
// | Delay | Reachable | Upper-layer confirmation | | Changed |
// | Delay | Reachable | Solicited override confirmation | Update LinkAddr | Changed |
// | Delay | Reachable | Solicited confirmation w/o address | Notify wakers | Changed |
@@ -98,6 +101,7 @@ func eventDiffOptsWithSort() []cmp.Option {
// | Probe | Stale | Probe or confirmation w/ different address | | Changed |
// | Probe | Probe | Retransmit timer expired | Send probe | Changed |
// | Probe | Failed | Max probes sent without reply | Notify wakers | Removed |
+// | Failed | Failed | Packet queued | | |
// | Failed | | Unreachability timer expired | Delete entry | |
type testEntryEventType uint8
@@ -125,14 +129,11 @@ func (t testEntryEventType) String() string {
type testEntryEventInfo struct {
EventType testEntryEventType
NICID tcpip.NICID
- Addr tcpip.Address
- LinkAddr tcpip.LinkAddress
- State NeighborState
- UpdatedAt time.Time
+ Entry NeighborEntry
}
func (e testEntryEventInfo) String() string {
- return fmt.Sprintf("%s event for NIC #%d, addr=%q, linkAddr=%q, state=%q", e.EventType, e.NICID, e.Addr, e.LinkAddr, e.State)
+ return fmt.Sprintf("%s event for NIC #%d, %#v", e.EventType, e.NICID, e.Entry)
}
// testNUDDispatcher implements NUDDispatcher to validate the dispatching of
@@ -150,36 +151,27 @@ func (d *testNUDDispatcher) queueEvent(e testEntryEventInfo) {
d.events = append(d.events, e)
}
-func (d *testNUDDispatcher) OnNeighborAdded(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) {
+func (d *testNUDDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry NeighborEntry) {
d.queueEvent(testEntryEventInfo{
EventType: entryTestAdded,
NICID: nicID,
- Addr: addr,
- LinkAddr: linkAddr,
- State: state,
- UpdatedAt: updatedAt,
+ Entry: entry,
})
}
-func (d *testNUDDispatcher) OnNeighborChanged(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) {
+func (d *testNUDDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry NeighborEntry) {
d.queueEvent(testEntryEventInfo{
EventType: entryTestChanged,
NICID: nicID,
- Addr: addr,
- LinkAddr: linkAddr,
- State: state,
- UpdatedAt: updatedAt,
+ Entry: entry,
})
}
-func (d *testNUDDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) {
+func (d *testNUDDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry NeighborEntry) {
d.queueEvent(testEntryEventInfo{
EventType: entryTestRemoved,
NICID: nicID,
- Addr: addr,
- LinkAddr: linkAddr,
- State: state,
- UpdatedAt: updatedAt,
+ Entry: entry,
})
}
@@ -202,9 +194,9 @@ func (p entryTestProbeInfo) String() string {
// LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts
// to the local network if linkAddr is the zero value.
-func (r *entryTestLinkResolver) LinkAddressRequest(addr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error {
+func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error {
p := entryTestProbeInfo{
- RemoteAddress: addr,
+ RemoteAddress: targetAddr,
RemoteLinkAddress: linkAddr,
LocalAddress: localAddr,
}
@@ -237,6 +229,7 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e
clock: clock,
nudDisp: &disp,
},
+ stats: makeNICStats(),
}
nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{
header.IPv6ProtocolNumber: (&testIPv6Protocol{}).NewEndpoint(&nic, nil, nil, nil),
@@ -245,7 +238,7 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
nudState := NewNUDState(c, rng)
linkRes := entryTestLinkResolver{}
- entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, entryTestAddr2 /* localAddr */, nudState, &linkRes)
+ entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, nudState, &linkRes)
// Stub out the neighbor cache to verify deletion from the cache.
nic.neigh = &neighborCache{
@@ -323,15 +316,16 @@ func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) {
func TestEntryUnknownToIncomplete(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
if got, want := e.neigh.State, Incomplete; got != want {
t.Errorf("got e.neigh.State = %q, want = %q", got, want)
}
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -350,9 +344,11 @@ func TestEntryUnknownToIncomplete(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
}
{
@@ -367,7 +363,7 @@ func TestEntryUnknownToIncomplete(t *testing.T) {
func TestEntryUnknownToStale(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handleProbeLocked(entryTestLinkAddr1)
@@ -377,6 +373,7 @@ func TestEntryUnknownToStale(t *testing.T) {
e.mu.Unlock()
// No probes should have been sent.
+ runImmediatelyScheduledJobs(clock)
linkRes.mu.Lock()
diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil))
linkRes.mu.Unlock()
@@ -388,9 +385,11 @@ func TestEntryUnknownToStale(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -406,11 +405,11 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
if got, want := e.neigh.State, Incomplete; got != want {
t.Errorf("got e.neigh.State = %q, want = %q", got, want)
}
- updatedAt := e.neigh.UpdatedAt
+ updatedAtNanos := e.neigh.UpdatedAtNanos
e.mu.Unlock()
clock.Advance(c.RetransmitTimer)
@@ -437,7 +436,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
}
e.mu.Lock()
- if got, want := e.neigh.UpdatedAt, updatedAt; got != want {
+ if got, want := e.neigh.UpdatedAtNanos, updatedAtNanos; got != want {
t.Errorf("got e.neigh.UpdatedAt = %q, want = %q", got, want)
}
e.mu.Unlock()
@@ -468,16 +467,20 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestRemoved,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
}
nudDisp.mu.Lock()
@@ -487,7 +490,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
nudDisp.mu.Unlock()
e.mu.Lock()
- if got, notWant := e.neigh.UpdatedAt, updatedAt; got == notWant {
+ if got, notWant := e.neigh.UpdatedAtNanos, updatedAtNanos; got == notWant {
t.Errorf("expected e.neigh.UpdatedAt to change, got = %q", got)
}
e.mu.Unlock()
@@ -495,23 +498,16 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
func TestEntryIncompleteToReachable(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- if got, want := e.neigh.State, Incomplete; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
}
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -526,20 +522,35 @@ func TestEntryIncompleteToReachable(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -555,7 +566,7 @@ func TestEntryIncompleteToReachable(t *testing.T) {
// to Reachable.
func TestEntryAddsAndClearsWakers(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
w := sleep.Waker{}
s := sleep.Sleeper{}
@@ -563,7 +574,25 @@ func TestEntryAddsAndClearsWakers(t *testing.T) {
defer s.Done()
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
if got := e.wakers; got != nil {
t.Errorf("got e.wakers = %v, want = nil", got)
}
@@ -587,34 +616,24 @@ func TestEntryAddsAndClearsWakers(t *testing.T) {
}
e.mu.Unlock()
- wantProbes := []entryTestProbeInfo{
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
- }
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -626,26 +645,16 @@ func TestEntryAddsAndClearsWakers(t *testing.T) {
func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- if got, want := e.neigh.State, Incomplete; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: true,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.isRouter, true; got != want {
- t.Errorf("got e.isRouter = %t, want = %t", got, want)
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
}
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -659,20 +668,38 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
}
linkRes.mu.Unlock()
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: true,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ if !e.isRouter {
+ t.Errorf("got e.isRouter = %t, want = true", e.isRouter)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -684,23 +711,16 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
func TestEntryIncompleteToStale(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- if got, want := e.neigh.State, Incomplete; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
}
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -715,20 +735,35 @@ func TestEntryIncompleteToStale(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -744,7 +779,7 @@ func TestEntryIncompleteToFailed(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
if got, want := e.neigh.State, Incomplete; got != want {
t.Errorf("got e.neigh.State = %q, want = %q", got, want)
}
@@ -783,16 +818,20 @@ func TestEntryIncompleteToFailed(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestRemoved,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
}
nudDisp.mu.Lock()
@@ -817,12 +856,30 @@ func (*testLocker) Unlock() {}
func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
ipv6EP := e.nic.networkEndpoints[header.IPv6ProtocolNumber].(*testIPv6Endpoint)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: true,
Override: false,
@@ -848,34 +905,24 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
}
e.mu.Unlock()
- wantProbes := []entryTestProbeInfo{
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
- }
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -893,27 +940,13 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleProbeLocked(entryTestLinkAddr1)
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -928,20 +961,42 @@ func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ e.handleProbeLocked(entryTestLinkAddr1)
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr1 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -961,17 +1016,10 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -986,29 +1034,46 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ e.mu.Unlock()
+
clock.Advance(c.BaseReachableTime)
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1026,24 +1091,13 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) {
func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleProbeLocked(entryTestLinkAddr2)
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1058,27 +1112,48 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ e.handleProbeLocked(entryTestLinkAddr2)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1086,38 +1161,17 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) {
t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
}
nudDisp.mu.Unlock()
-
- e.mu.Lock()
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.mu.Unlock()
}
func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1132,27 +1186,52 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T)
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1160,38 +1239,17 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T)
t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
}
nudDisp.mu.Unlock()
-
- e.mu.Lock()
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.mu.Unlock()
}
func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: true,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1206,27 +1264,52 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1234,37 +1317,17 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t
t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
}
nudDisp.mu.Unlock()
-
- e.mu.Lock()
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.mu.Unlock()
}
func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleProbeLocked(entryTestLinkAddr1)
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1279,20 +1342,42 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.handleProbeLocked(entryTestLinkAddr1)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr1 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1304,31 +1389,13 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) {
func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: true,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1343,27 +1410,55 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: true,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr2 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -1375,10 +1470,28 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
e.mu.Lock()
- e.handlePacketQueuedLocked()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
@@ -1400,41 +1513,33 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
}
e.mu.Unlock()
- wantProbes := []entryTestProbeInfo{
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
- }
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -1446,31 +1551,13 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: true,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1485,27 +1572,55 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr2 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1517,27 +1632,13 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) {
func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleProbeLocked(entryTestLinkAddr2)
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1552,27 +1653,51 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.handleProbeLocked(entryTestLinkAddr2)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr2 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1584,24 +1709,13 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) {
func TestEntryStaleToDelay(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handlePacketQueuedLocked()
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1616,27 +1730,48 @@ func TestEntryStaleToDelay(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
}
nudDisp.mu.Lock()
@@ -1656,22 +1791,10 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- e.handlePacketQueuedLocked()
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleUpperLevelConfirmationLocked()
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1686,43 +1809,68 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
- clock.Advance(c.BaseReachableTime)
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ }
+ e.handleUpperLevelConfirmationLocked()
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ e.mu.Unlock()
+ clock.Advance(c.BaseReachableTime)
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1743,29 +1891,10 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- e.handlePacketQueuedLocked()
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: true,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1780,43 +1909,75 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
- clock.Advance(c.BaseReachableTime)
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: true,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr2 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2)
+ }
+ e.mu.Unlock()
+ clock.Advance(c.BaseReachableTime)
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1837,13 +1998,31 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
if e.neigh.State != Delay {
t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
}
@@ -1860,57 +2039,52 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
}
e.mu.Unlock()
- wantProbes := []entryTestProbeInfo{
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
- }
-
clock.Advance(c.BaseReachableTime)
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1922,32 +2096,13 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- e.handlePacketQueuedLocked()
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: true,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1962,27 +2117,56 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr1 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
}
nudDisp.mu.Lock()
@@ -1994,25 +2178,13 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) {
func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- e.handlePacketQueuedLocked()
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleProbeLocked(entryTestLinkAddr2)
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -2027,34 +2199,58 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ }
+ e.handleProbeLocked(entryTestLinkAddr2)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -2066,29 +2262,13 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) {
func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- e.handlePacketQueuedLocked()
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: true,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -2103,34 +2283,62 @@ func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -2145,69 +2353,91 @@ func TestEntryDelayToProbe(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
if got, want := e.neigh.State, Delay; got != want {
t.Errorf("got e.neigh.State = %q, want = %q", got, want)
}
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The second probe is caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
}
nudDisp.mu.Lock()
@@ -2228,36 +2458,50 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The second probe is caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
e.mu.Lock()
@@ -2274,37 +2518,47 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -2312,12 +2566,6 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
}
nudDisp.mu.Unlock()
-
- e.mu.Lock()
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.mu.Unlock()
}
func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
@@ -2325,36 +2573,50 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The second probe is caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
e.mu.Lock()
@@ -2375,37 +2637,47 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -2413,12 +2685,6 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
}
nudDisp.mu.Unlock()
-
- e.mu.Lock()
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.mu.Unlock()
}
func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) {
@@ -2426,36 +2692,51 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The second probe is caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ // The second probe is caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
e.mu.Lock()
@@ -2479,30 +2760,38 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
}
nudDisp.mu.Lock()
@@ -2529,17 +2818,14 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) {
e.mu.Lock()
e.handleProbeLocked(entryTestLinkAddr1)
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
wantProbes := []entryTestProbeInfo{
- // Probe caused by the Delay-to-Probe transition
{
RemoteAddress: entryTestAddr1,
RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
},
}
linkRes.mu.Lock()
@@ -2567,42 +2853,51 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) {
e.mu.Unlock()
clock.Advance(c.BaseReachableTime)
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -2622,36 +2917,50 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The second probe is caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
e.mu.Lock()
@@ -2672,49 +2981,60 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
e.mu.Unlock()
clock.Advance(c.BaseReachableTime)
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -2734,36 +3054,50 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The second probe is caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
e.mu.Lock()
@@ -2781,49 +3115,60 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin
e.mu.Unlock()
clock.Advance(c.BaseReachableTime)
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -2843,36 +3188,50 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The second probe is caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
e.mu.Lock()
@@ -2890,49 +3249,60 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
e.mu.Unlock()
clock.Advance(c.BaseReachableTime)
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -2946,87 +3316,238 @@ func TestEntryProbeToFailed(t *testing.T) {
c := DefaultNUDConfigurations()
c.MaxMulticastProbes = 3
c.MaxUnicastProbes = 3
+ c.DelayFirstProbeTime = c.RetransmitTimer
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
- waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes)
- clock.Advance(waitFor)
+ // Observe each probe sent while in the Probe state.
+ for i := uint32(0); i < c.MaxUnicastProbes; i++ {
+ clock.Advance(c.RetransmitTimer)
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probe #%d mismatch (-got, +want):\n%s", i+1, diff)
+ }
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
+ e.mu.Lock()
+ if e.neigh.State != Probe {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
+ }
+ e.mu.Unlock()
+ }
+
+ // Wait for the last probe to expire, causing a transition to Failed.
+ clock.Advance(c.RetransmitTimer)
+ e.mu.Lock()
+ if e.neigh.State != Failed {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Failed)
+ }
+ e.mu.Unlock()
+
+ wantEvents := []testEntryEventInfo{
{
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
- // The next three probe are caused by the Delay-to-Probe transition.
{
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ },
+ {
+ EventType: entryTestRemoved,
+ NICID: entryTestNICID,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
}
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryFailedToFailed(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ c.MaxMulticastProbes = 3
+ c.MaxUnicastProbes = 3
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ // Verify the cache contains the entry.
+ if _, ok := e.nic.neigh.cache[entryTestAddr1]; !ok {
+ t.Errorf("expected entry %q to exist in the neighbor cache", entryTestAddr1)
+ }
+
+ // TODO(gvisor.dev/issue/4872): Use helper functions to start entry tests in
+ // their expected state.
+ e.mu.Lock()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes)
+ clock.Advance(waitFor)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
{
EventType: entryTestRemoved,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
}
nudDisp.mu.Lock()
@@ -3035,11 +3556,23 @@ func TestEntryProbeToFailed(t *testing.T) {
}
nudDisp.mu.Unlock()
- e.mu.Lock()
- if got, want := e.neigh.State, Failed; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ failedLookups := e.nic.stats.Neighbor.FailedEntryLookups
+ if got := failedLookups.Value(); got != 0 {
+ t.Errorf("got Neighbor.FailedEntryLookups = %d, want = 0", got)
}
+
+ e.mu.Lock()
+ // Verify queuing a packet to the entry immediately fails.
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ state := e.neigh.State
e.mu.Unlock()
+ if state != Failed {
+ t.Errorf("got e.neigh.State = %q, want = %q", state, Failed)
+ }
+
+ if got := failedLookups.Value(); got != 1 {
+ t.Errorf("got Neighbor.FailedEntryLookups = %d, want = 1", got)
+ }
}
func TestEntryFailedGetsDeleted(t *testing.T) {
@@ -3054,84 +3587,106 @@ func TestEntryFailedGetsDeleted(t *testing.T) {
}
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + c.UnreachableTime
clock.Advance(waitFor)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The next three probe are caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ // The next three probe are sent in Probe.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
{
EventType: entryTestRemoved,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
}
nudDisp.mu.Lock()
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index dcd4319bf..5d037a27e 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -54,18 +54,20 @@ type NIC struct {
sync.RWMutex
spoofing bool
promiscuous bool
- // packetEPs is protected by mu, but the contained PacketEndpoint
- // values are not.
- packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint
+ // packetEPs is protected by mu, but the contained packetEndpointList are
+ // not.
+ packetEPs map[tcpip.NetworkProtocolNumber]*packetEndpointList
}
}
-// NICStats includes transmitted and received stats.
+// NICStats hold statistics for a NIC.
type NICStats struct {
Tx DirectionStats
Rx DirectionStats
DisabledRx DirectionStats
+
+ Neighbor NeighborStats
}
func makeNICStats() NICStats {
@@ -80,6 +82,39 @@ type DirectionStats struct {
Bytes *tcpip.StatCounter
}
+type packetEndpointList struct {
+ mu sync.RWMutex
+
+ // eps is protected by mu, but the contained PacketEndpoint values are not.
+ eps []PacketEndpoint
+}
+
+func (p *packetEndpointList) add(ep PacketEndpoint) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ p.eps = append(p.eps, ep)
+}
+
+func (p *packetEndpointList) remove(ep PacketEndpoint) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ for i, epOther := range p.eps {
+ if epOther == ep {
+ p.eps = append(p.eps[:i], p.eps[i+1:]...)
+ break
+ }
+ }
+}
+
+// forEach calls fn with each endpoints in p while holding the read lock on p.
+func (p *packetEndpointList) forEach(fn func(PacketEndpoint)) {
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+ for _, ep := range p.eps {
+ fn(ep)
+ }
+}
+
// newNIC returns a new NIC using the default NDP configurations from stack.
func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *NIC {
// TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For
@@ -100,7 +135,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
stats: makeNICStats(),
networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint),
}
- nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint)
+ nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList)
// Check for Neighbor Unreachability Detection support.
var nud NUDHandler
@@ -123,11 +158,11 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
// Register supported packet and network endpoint protocols.
for _, netProto := range header.Ethertypes {
- nic.mu.packetEPs[netProto] = []PacketEndpoint{}
+ nic.mu.packetEPs[netProto] = new(packetEndpointList)
}
for _, netProto := range stack.networkProtocols {
netNum := netProto.Number()
- nic.mu.packetEPs[netNum] = nil
+ nic.mu.packetEPs[netNum] = new(packetEndpointList)
nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic)
}
@@ -170,7 +205,7 @@ func (n *NIC) disable() {
//
// n MUST be locked.
func (n *NIC) disableLocked() {
- if !n.setEnabled(false) {
+ if !n.Enabled() {
return
}
@@ -182,6 +217,10 @@ func (n *NIC) disableLocked() {
for _, ep := range n.networkEndpoints {
ep.Disable()
}
+
+ if !n.setEnabled(false) {
+ panic("should have only done work to disable the NIC if it was enabled")
+ }
}
// enable enables n.
@@ -232,7 +271,8 @@ func (n *NIC) setPromiscuousMode(enable bool) {
n.mu.Unlock()
}
-func (n *NIC) isPromiscuousMode() bool {
+// Promiscuous implements NetworkInterface.
+func (n *NIC) Promiscuous() bool {
n.mu.RLock()
rv := n.mu.promiscuous
n.mu.RUnlock()
@@ -264,7 +304,7 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb
if ch, err := r.Resolve(nil); err != nil {
if err == tcpip.ErrWouldBlock {
r := r.Clone()
- n.stack.linkResQueue.enqueue(ch, &r, protocol, pkt)
+ n.stack.linkResQueue.enqueue(ch, r, protocol, pkt)
return nil
}
return err
@@ -273,6 +313,15 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb
return n.writePacket(r, gso, protocol, pkt)
}
+// WritePacketToRemote implements NetworkInterface.
+func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
+ r := Route{
+ NetProto: protocol,
+ }
+ r.ResolveWith(remoteLinkAddr)
+ return n.writePacket(&r, gso, protocol, pkt)
+}
+
func (n *NIC) writePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
// WritePacket takes ownership of pkt, calculate numBytes first.
numBytes := pkt.Size()
@@ -311,16 +360,21 @@ func (n *NIC) setSpoofing(enable bool) {
// primaryAddress returns an address that can be used to communicate with
// remoteAddr.
func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) AssignableAddressEndpoint {
- n.mu.RLock()
- spoofing := n.mu.spoofing
- n.mu.RUnlock()
-
ep, ok := n.networkEndpoints[protocol]
if !ok {
return nil
}
- return ep.AcquireOutgoingPrimaryAddress(remoteAddr, spoofing)
+ addressableEndpoint, ok := ep.(AddressableEndpoint)
+ if !ok {
+ return nil
+ }
+
+ n.mu.RLock()
+ spoofing := n.mu.spoofing
+ n.mu.RUnlock()
+
+ return addressableEndpoint.AcquireOutgoingPrimaryAddress(remoteAddr, spoofing)
}
type getAddressBehaviour int
@@ -339,6 +393,16 @@ func (n *NIC) getAddress(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address
return n.getAddressOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, promiscuous)
}
+func (n *NIC) hasAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
+ ep := n.getAddressOrCreateTempInner(protocol, addr, false, NeverPrimaryEndpoint)
+ if ep != nil {
+ ep.DecRef()
+ return true
+ }
+
+ return false
+}
+
// findEndpoint finds the endpoint, if any, with the given address.
func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) AssignableAddressEndpoint {
return n.getAddressOrCreateTemp(protocol, address, peb, spoofing)
@@ -369,11 +433,17 @@ func (n *NIC) getAddressOrCreateTemp(protocol tcpip.NetworkProtocolNumber, addre
// getAddressOrCreateTempInner is like getAddressEpOrCreateTemp except a boolean
// is passed to indicate whether or not we should generate temporary endpoints.
func (n *NIC) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, createTemp bool, peb PrimaryEndpointBehavior) AssignableAddressEndpoint {
- if ep, ok := n.networkEndpoints[protocol]; ok {
- return ep.AcquireAssignedAddress(address, createTemp, peb)
+ ep, ok := n.networkEndpoints[protocol]
+ if !ok {
+ return nil
}
- return nil
+ addressableEndpoint, ok := ep.(AddressableEndpoint)
+ if !ok {
+ return nil
+ }
+
+ return addressableEndpoint.AcquireAssignedAddress(address, createTemp, peb)
}
// addAddress adds a new address to n, so that it starts accepting packets
@@ -384,7 +454,12 @@ func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo
return tcpip.ErrUnknownProtocol
}
- addressEndpoint, err := ep.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */)
+ addressableEndpoint, ok := ep.(AddressableEndpoint)
+ if !ok {
+ return tcpip.ErrNotSupported
+ }
+
+ addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */)
if err == nil {
// We have no need for the address endpoint.
addressEndpoint.DecRef()
@@ -397,7 +472,12 @@ func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo
func (n *NIC) allPermanentAddresses() []tcpip.ProtocolAddress {
var addrs []tcpip.ProtocolAddress
for p, ep := range n.networkEndpoints {
- for _, a := range ep.PermanentAddresses() {
+ addressableEndpoint, ok := ep.(AddressableEndpoint)
+ if !ok {
+ continue
+ }
+
+ for _, a := range addressableEndpoint.PermanentAddresses() {
addrs = append(addrs, tcpip.ProtocolAddress{Protocol: p, AddressWithPrefix: a})
}
}
@@ -408,7 +488,12 @@ func (n *NIC) allPermanentAddresses() []tcpip.ProtocolAddress {
func (n *NIC) primaryAddresses() []tcpip.ProtocolAddress {
var addrs []tcpip.ProtocolAddress
for p, ep := range n.networkEndpoints {
- for _, a := range ep.PrimaryAddresses() {
+ addressableEndpoint, ok := ep.(AddressableEndpoint)
+ if !ok {
+ continue
+ }
+
+ for _, a := range addressableEndpoint.PrimaryAddresses() {
addrs = append(addrs, tcpip.ProtocolAddress{Protocol: p, AddressWithPrefix: a})
}
}
@@ -426,13 +511,23 @@ func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWit
return tcpip.AddressWithPrefix{}
}
- return ep.MainAddress()
+ addressableEndpoint, ok := ep.(AddressableEndpoint)
+ if !ok {
+ return tcpip.AddressWithPrefix{}
+ }
+
+ return addressableEndpoint.MainAddress()
}
// removeAddress removes an address from n.
func (n *NIC) removeAddress(addr tcpip.Address) *tcpip.Error {
for _, ep := range n.networkEndpoints {
- if err := ep.RemovePermanentAddress(addr); err == tcpip.ErrBadLocalAddress {
+ addressableEndpoint, ok := ep.(AddressableEndpoint)
+ if !ok {
+ continue
+ }
+
+ if err := addressableEndpoint.RemovePermanentAddress(addr); err == tcpip.ErrBadLocalAddress {
continue
} else {
return err
@@ -505,8 +600,7 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address
return tcpip.ErrNotSupported
}
- _, err := gep.JoinGroup(addr)
- return err
+ return gep.JoinGroup(addr)
}
// leaveGroup decrements the count for the given multicast address, and when it
@@ -522,11 +616,7 @@ func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Addres
return tcpip.ErrNotSupported
}
- if _, err := gep.LeaveGroup(addr); err != nil {
- return err
- }
-
- return nil
+ return gep.LeaveGroup(addr)
}
// isInGroup returns true if n has joined the multicast group addr.
@@ -545,13 +635,6 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool {
return false
}
-func (n *NIC) handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, remotelinkAddr tcpip.LinkAddress, addressEndpoint AssignableAddressEndpoint, pkt *PacketBuffer) {
- r := makeRoute(protocol, dst, src, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */)
- defer r.Release()
- r.RemoteLinkAddress = remotelinkAddr
- n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt)
-}
-
// DeliverNetworkPacket finds the appropriate network protocol endpoint and
// hands the packet over for further processing. This function is called when
// the NIC receives a packet from the link endpoint.
@@ -573,7 +656,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
n.stats.Rx.Packets.Increment()
n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data.Size()))
- netProto, ok := n.stack.networkProtocols[protocol]
+ networkEndpoint, ok := n.networkEndpoints[protocol]
if !ok {
n.mu.RUnlock()
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
@@ -585,23 +668,29 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
if local == "" {
local = n.LinkEndpoint.LinkAddress()
}
+ pkt.RXTransportChecksumValidated = n.LinkEndpoint.Capabilities()&CapabilityRXChecksumOffload != 0
// Are any packet type sockets listening for this network protocol?
- packetEPs := n.mu.packetEPs[protocol]
- // Add any other packet type sockets that may be listening for all protocols.
- packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...)
+ protoEPs := n.mu.packetEPs[protocol]
+ // Other packet type sockets that are listening for all protocols.
+ anyEPs := n.mu.packetEPs[header.EthernetProtocolAll]
n.mu.RUnlock()
- for _, ep := range packetEPs {
+
+ // Deliver to interested packet endpoints without holding NIC lock.
+ deliverPacketEPs := func(ep PacketEndpoint) {
p := pkt.Clone()
p.PktType = tcpip.PacketHost
ep.HandlePacket(n.id, local, protocol, p)
}
-
- if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber {
- n.stack.stats.IP.PacketsReceived.Increment()
+ if protoEPs != nil {
+ protoEPs.forEach(deliverPacketEPs)
+ }
+ if anyEPs != nil {
+ anyEPs.forEach(deliverPacketEPs)
}
// Parse headers.
+ netProto := n.stack.NetworkProtocolInstance(protocol)
transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt)
if !ok {
// The packet is too small to contain a network header.
@@ -616,9 +705,8 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
}
}
- src, dst := netProto.ParseAddresses(pkt.NetworkHeader().View())
-
if n.stack.handleLocal && !n.IsLoopback() {
+ src, _ := netProto.ParseAddresses(pkt.NetworkHeader().View())
if r := n.getAddress(protocol, src); r != nil {
r.DecRef()
@@ -631,78 +719,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
}
}
- // Loopback traffic skips the prerouting chain.
- if !n.IsLoopback() {
- // iptables filtering.
- ipt := n.stack.IPTables()
- address := n.primaryAddress(protocol)
- if ok := ipt.Check(Prerouting, pkt, nil, nil, address.Address, ""); !ok {
- // iptables is telling us to drop the packet.
- n.stack.stats.IP.IPTablesPreroutingDropped.Increment()
- return
- }
- }
-
- if addressEndpoint := n.getAddress(protocol, dst); addressEndpoint != nil {
- n.handlePacket(protocol, dst, src, remote, addressEndpoint, pkt)
- return
- }
-
- // This NIC doesn't care about the packet. Find a NIC that cares about the
- // packet and forward it to the NIC.
- //
- // TODO: Should we be forwarding the packet even if promiscuous?
- if n.stack.Forwarding(protocol) {
- r, err := n.stack.FindRoute(0, "", dst, protocol, false /* multicastLoop */)
- if err != nil {
- n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment()
- return
- }
-
- // Found a NIC.
- n := r.nic
- if addressEndpoint := n.getAddressOrCreateTempInner(protocol, dst, false, NeverPrimaryEndpoint); addressEndpoint != nil {
- if n.isValidForOutgoing(addressEndpoint) {
- r.LocalLinkAddress = n.LinkEndpoint.LinkAddress()
- r.RemoteLinkAddress = remote
- r.RemoteAddress = src
- // TODO(b/123449044): Update the source NIC as well.
- n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt)
- addressEndpoint.DecRef()
- r.Release()
- return
- }
-
- addressEndpoint.DecRef()
- }
-
- // n doesn't have a destination endpoint.
- // Send the packet out of n.
- // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6.
-
- // pkt may have set its header and may not have enough headroom for
- // link-layer header for the other link to prepend. Here we create a new
- // packet to forward.
- fwdPkt := NewPacketBuffer(PacketBufferOptions{
- ReserveHeaderBytes: int(n.LinkEndpoint.MaxHeaderLength()),
- // We need to do a deep copy of the IP packet because WritePacket (and
- // friends) take ownership of the packet buffer, but we do not own it.
- Data: PayloadSince(pkt.NetworkHeader()).ToVectorisedView(),
- })
-
- // TODO(b/143425874) Decrease the TTL field in forwarded packets.
- if err := n.WritePacket(&r, nil, protocol, fwdPkt); err != nil {
- n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment()
- }
-
- r.Release()
- return
- }
-
- // If a packet socket handled the packet, don't treat it as invalid.
- if len(packetEPs) == 0 {
- n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment()
- }
+ networkEndpoint.HandlePacket(pkt)
}
// DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket.
@@ -711,21 +728,22 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc
// We do not deliver to protocol specific packet endpoints as on Linux
// only ETH_P_ALL endpoints get outbound packets.
// Add any other packet sockets that maybe listening for all protocols.
- packetEPs := n.mu.packetEPs[header.EthernetProtocolAll]
+ eps := n.mu.packetEPs[header.EthernetProtocolAll]
n.mu.RUnlock()
- for _, ep := range packetEPs {
+
+ eps.forEach(func(ep PacketEndpoint) {
p := pkt.Clone()
p.PktType = tcpip.PacketOutgoing
// Add the link layer header as outgoing packets are intercepted
// before the link layer header is created.
n.LinkEndpoint.AddHeader(local, remote, protocol, p)
ep.HandlePacket(n.id, local, protocol, p)
- }
+ })
}
// DeliverTransportPacket delivers the packets to the appropriate transport
// protocol endpoint.
-func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition {
+func (n *NIC) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition {
state, ok := n.stack.transportProtocols[protocol]
if !ok {
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
@@ -737,7 +755,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// Raw socket packets are delivered based solely on the transport
// protocol number. We do not inspect the payload to ensure it's
// validly formed.
- n.stack.demux.deliverRawPacket(r, protocol, pkt)
+ n.stack.demux.deliverRawPacket(protocol, pkt)
// TransportHeader is empty only when pkt is an ICMP packet or was reassembled
// from fragments.
@@ -766,14 +784,25 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
return TransportPacketHandled
}
- id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress}
- if n.stack.demux.deliverPacket(r, protocol, pkt, id) {
+ netProto, ok := n.stack.networkProtocols[pkt.NetworkProtocolNumber]
+ if !ok {
+ panic(fmt.Sprintf("expected network protocol = %d, have = %#v", pkt.NetworkProtocolNumber, n.stack.networkProtocolNumbers()))
+ }
+
+ src, dst := netProto.ParseAddresses(pkt.NetworkHeader().View())
+ id := TransportEndpointID{
+ LocalPort: dstPort,
+ LocalAddress: dst,
+ RemotePort: srcPort,
+ RemoteAddress: src,
+ }
+ if n.stack.demux.deliverPacket(protocol, pkt, id) {
return TransportPacketHandled
}
// Try to deliver to per-stack default handler.
if state.defaultHandler != nil {
- if state.defaultHandler(r, id, pkt) {
+ if state.defaultHandler(id, pkt) {
return TransportPacketHandled
}
}
@@ -781,7 +810,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// We could not find an appropriate destination for this packet so
// give the protocol specific error handler a chance to handle it.
// If it doesn't handle it then we should do so.
- switch res := transProto.HandleUnknownDestinationPacket(r, id, pkt); res {
+ switch res := transProto.HandleUnknownDestinationPacket(id, pkt); res {
case UnknownDestinationPacketMalformed:
n.stack.stats.MalformedRcvdPackets.Increment()
return TransportPacketHandled
@@ -862,7 +891,7 @@ func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep Pa
if !ok {
return tcpip.ErrNotSupported
}
- n.mu.packetEPs[netProto] = append(eps, ep)
+ eps.add(ep)
return nil
}
@@ -875,17 +904,11 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep
if !ok {
return
}
-
- for i, epOther := range eps {
- if epOther == ep {
- n.mu.packetEPs[netProto] = append(eps[:i], eps[i+1:]...)
- return
- }
- }
+ eps.remove(ep)
}
// isValidForOutgoing returns true if the endpoint can be used to send out a
-// packet. It requires the endpoint to not be marked expired (i.e., its address)
+// packet. It requires the endpoint to not be marked expired (i.e., its address
// has been removed) unless the NIC is in spoofing mode, or temporary.
func (n *NIC) isValidForOutgoing(ep AssignableAddressEndpoint) bool {
n.mu.RLock()
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index 97a96af62..5b5c58afb 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -83,8 +83,7 @@ func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip
}
// HandlePacket implements NetworkEndpoint.HandlePacket.
-func (*testIPv6Endpoint) HandlePacket(*Route, *PacketBuffer) {
-}
+func (*testIPv6Endpoint) HandlePacket(*PacketBuffer) {}
// Close implements NetworkEndpoint.Close.
func (e *testIPv6Endpoint) Close() {
@@ -169,7 +168,7 @@ func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
}
// LinkAddressRequest implements LinkAddressResolver.
-func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error {
+func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error {
return nil
}
diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go
index e1ec15487..ab629b3a4 100644
--- a/pkg/tcpip/stack/nud.go
+++ b/pkg/tcpip/stack/nud.go
@@ -129,7 +129,7 @@ type NUDDispatcher interface {
// the stack's operation.
//
// May be called concurrently.
- OnNeighborAdded(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time)
+ OnNeighborAdded(tcpip.NICID, NeighborEntry)
// OnNeighborChanged will be called when an entry in a NIC's (with ID nicID)
// neighbor table changes state and/or link address.
@@ -138,7 +138,7 @@ type NUDDispatcher interface {
// the stack's operation.
//
// May be called concurrently.
- OnNeighborChanged(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time)
+ OnNeighborChanged(tcpip.NICID, NeighborEntry)
// OnNeighborRemoved will be called when an entry is removed from a NIC's
// (with ID nicID) neighbor table.
@@ -147,7 +147,7 @@ type NUDDispatcher interface {
// the stack's operation.
//
// May be called concurrently.
- OnNeighborRemoved(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time)
+ OnNeighborRemoved(tcpip.NICID, NeighborEntry)
}
// ReachabilityConfirmationFlags describes the flags used within a reachability
@@ -177,7 +177,7 @@ type NUDHandler interface {
// Neighbor Solicitation for ARP or NDP, respectively). Validation of the
// probe needs to be performed before calling this function since the
// Neighbor Cache doesn't have access to view the NIC's assigned addresses.
- HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver)
+ HandleProbe(remoteAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver)
// HandleConfirmation processes an incoming neighbor confirmation (e.g. ARP
// reply or Neighbor Advertisement for ARP or NDP, respectively).
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 7f54a6de8..664cc6fa0 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -112,6 +112,16 @@ type PacketBuffer struct {
// PktType indicates the SockAddrLink.PacketType of the packet as defined in
// https://www.man7.org/linux/man-pages/man7/packet.7.html.
PktType tcpip.PacketType
+
+ // NICID is the ID of the interface the network packet was received at.
+ NICID tcpip.NICID
+
+ // RXTransportChecksumValidated indicates that transport checksum verification
+ // may be safely skipped.
+ RXTransportChecksumValidated bool
+
+ // NetworkPacketInfo holds an incoming packet's network-layer information.
+ NetworkPacketInfo NetworkPacketInfo
}
// NewPacketBuffer creates a new PacketBuffer with opts.
@@ -240,20 +250,33 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consum
// Clone should be called in such cases so that no modifications is done to
// underlying packet payload.
func (pk *PacketBuffer) Clone() *PacketBuffer {
- newPk := &PacketBuffer{
- PacketBufferEntry: pk.PacketBufferEntry,
- Data: pk.Data.Clone(nil),
- headers: pk.headers,
- header: pk.header,
- Hash: pk.Hash,
- Owner: pk.Owner,
- EgressRoute: pk.EgressRoute,
- GSOOptions: pk.GSOOptions,
- NetworkProtocolNumber: pk.NetworkProtocolNumber,
- NatDone: pk.NatDone,
- TransportProtocolNumber: pk.TransportProtocolNumber,
+ return &PacketBuffer{
+ PacketBufferEntry: pk.PacketBufferEntry,
+ Data: pk.Data.Clone(nil),
+ headers: pk.headers,
+ header: pk.header,
+ Hash: pk.Hash,
+ Owner: pk.Owner,
+ GSOOptions: pk.GSOOptions,
+ NetworkProtocolNumber: pk.NetworkProtocolNumber,
+ NatDone: pk.NatDone,
+ TransportProtocolNumber: pk.TransportProtocolNumber,
+ PktType: pk.PktType,
+ NICID: pk.NICID,
+ RXTransportChecksumValidated: pk.RXTransportChecksumValidated,
+ NetworkPacketInfo: pk.NetworkPacketInfo,
}
- return newPk
+}
+
+// SourceLinkAddress returns the source link address of the packet.
+func (pk *PacketBuffer) SourceLinkAddress() tcpip.LinkAddress {
+ link := pk.LinkHeader().View()
+
+ if link.IsEmpty() {
+ return ""
+ }
+
+ return header.Ethernet(link).SourceAddress()
}
// Network returns the network header as a header.Network.
@@ -270,6 +293,17 @@ func (pk *PacketBuffer) Network() header.Network {
}
}
+// CloneToInbound makes a shallow copy of the packet buffer to be used as an
+// inbound packet.
+//
+// See PacketBuffer.Data for details about how a packet buffer holds an inbound
+// packet.
+func (pk *PacketBuffer) CloneToInbound() *PacketBuffer {
+ return NewPacketBuffer(PacketBufferOptions{
+ Data: buffer.NewVectorisedView(pk.Size(), pk.Views()),
+ })
+}
+
// headerInfo stores metadata about a header in a packet.
type headerInfo struct {
// buf is the memorized slice for both prepended and consumed header.
diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go
index f838eda8d..5d364a2b0 100644
--- a/pkg/tcpip/stack/pending_packets.go
+++ b/pkg/tcpip/stack/pending_packets.go
@@ -106,7 +106,7 @@ func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, pro
} else if _, err := p.route.Resolve(nil); err != nil {
p.route.Stats().IP.OutgoingPacketErrors.Increment()
} else {
- p.route.nic.writePacket(p.route, nil /* gso */, p.proto, p.pkt)
+ p.route.outgoingNIC.writePacket(p.route, nil /* gso */, p.proto, p.pkt)
}
p.route.Release()
}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index defb9129b..b334e27c4 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -63,17 +63,24 @@ const (
ControlUnknown
)
+// NetworkPacketInfo holds information about a network layer packet.
+type NetworkPacketInfo struct {
+ // LocalAddressBroadcast is true if the packet's local address is a broadcast
+ // address.
+ LocalAddressBroadcast bool
+}
+
// TransportEndpoint is the interface that needs to be implemented by transport
// protocol (e.g., tcp, udp) endpoints that can handle packets.
type TransportEndpoint interface {
// UniqueID returns an unique ID for this transport endpoint.
UniqueID() uint64
- // HandlePacket is called by the stack when new packets arrive to
- // this transport endpoint. It sets pkt.TransportHeader.
+ // HandlePacket is called by the stack when new packets arrive to this
+ // transport endpoint. It sets the packet buffer's transport header.
//
- // HandlePacket takes ownership of pkt.
- HandlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer)
+ // HandlePacket takes ownership of the packet.
+ HandlePacket(TransportEndpointID, *PacketBuffer)
// HandleControlPacket is called by the stack when new control (e.g.
// ICMP) packets arrive to this transport endpoint.
@@ -105,8 +112,8 @@ type RawTransportEndpoint interface {
// this transport endpoint. The packet contains all data from the link
// layer up.
//
- // HandlePacket takes ownership of pkt.
- HandlePacket(r *Route, pkt *PacketBuffer)
+ // HandlePacket takes ownership of the packet.
+ HandlePacket(*PacketBuffer)
}
// PacketEndpoint is the interface that needs to be implemented by packet
@@ -127,7 +134,7 @@ type PacketEndpoint interface {
HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
}
-// UnknownDestinationPacketDisposition enumerates the possible return vaues from
+// UnknownDestinationPacketDisposition enumerates the possible return values from
// HandleUnknownDestinationPacket().
type UnknownDestinationPacketDisposition int
@@ -172,9 +179,9 @@ type TransportProtocol interface {
// protocol that don't match any existing endpoint. For example,
// it is targeted at a port that has no listeners.
//
- // HandleUnknownDestinationPacket takes ownership of pkt if it handles
+ // HandleUnknownDestinationPacket takes ownership of the packet if it handles
// the issue.
- HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) UnknownDestinationPacketDisposition
+ HandleUnknownDestinationPacket(TransportEndpointID, *PacketBuffer) UnknownDestinationPacketDisposition
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
@@ -227,8 +234,8 @@ type TransportDispatcher interface {
//
// pkt.NetworkHeader must be set before calling DeliverTransportPacket.
//
- // DeliverTransportPacket takes ownership of pkt.
- DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition
+ // DeliverTransportPacket takes ownership of the packet.
+ DeliverTransportPacket(tcpip.TransportProtocolNumber, *PacketBuffer) TransportPacketDisposition
// DeliverTransportControlPacket delivers control packets to the
// appropriate transport protocol endpoint.
@@ -270,15 +277,11 @@ type NetworkHeaderParams struct {
// An endpoint is considered to support group addressing when one or more
// endpoints may associate themselves with the same identifier (group address).
type GroupAddressableEndpoint interface {
- // JoinGroup joins the spcified group.
- //
- // Returns true if the group was newly joined.
- JoinGroup(group tcpip.Address) (bool, *tcpip.Error)
+ // JoinGroup joins the specified group.
+ JoinGroup(group tcpip.Address) *tcpip.Error
// LeaveGroup attempts to leave the specified group.
- //
- // Returns tcpip.ErrBadLocalAddress if the endpoint has not joined the group.
- LeaveGroup(group tcpip.Address) (bool, *tcpip.Error)
+ LeaveGroup(group tcpip.Address) *tcpip.Error
// IsInGroup returns true if the endpoint is a member of the specified group.
IsInGroup(group tcpip.Address) bool
@@ -329,6 +332,9 @@ type AssignableAddressEndpoint interface {
// AddressWithPrefix returns the endpoint's address.
AddressWithPrefix() tcpip.AddressWithPrefix
+ // Subnet returns the subnet of the endpoint's address.
+ Subnet() tcpip.Subnet
+
// IsAssigned returns whether or not the endpoint is considered bound
// to its NetworkEndpoint.
IsAssigned(allowExpired bool) bool
@@ -364,7 +370,7 @@ type AddressEndpoint interface {
SetDeprecated(bool)
}
-// AddressKind is the kind of of an address.
+// AddressKind is the kind of an address.
//
// See the values of AddressKind for more details.
type AddressKind int
@@ -490,13 +496,17 @@ type NetworkInterface interface {
// Enabled returns true if the interface is enabled.
Enabled() bool
+
+ // Promiscuous returns true if the interface is in promiscuous mode.
+ Promiscuous() bool
+
+ // WritePacketToRemote writes the packet to the given remote link address.
+ WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error
}
// NetworkEndpoint is the interface that needs to be implemented by endpoints
// of network layer protocols (e.g., ipv4, ipv6).
type NetworkEndpoint interface {
- AddressableEndpoint
-
// Enable enables the endpoint.
//
// Must only be called when the stack is in a state that allows the endpoint
@@ -544,7 +554,7 @@ type NetworkEndpoint interface {
// this network endpoint. It sets pkt.NetworkHeader.
//
// HandlePacket takes ownership of pkt.
- HandlePacket(r *Route, pkt *PacketBuffer)
+ HandlePacket(pkt *PacketBuffer)
// Close is called when the endpoint is reomved from a stack.
Close()
@@ -712,10 +722,6 @@ type LinkEndpoint interface {
// endpoint.
Capabilities() LinkEndpointCapabilities
- // WriteRawPacket writes a packet directly to the link. The packet
- // should already have an ethernet header. It takes ownership of vv.
- WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error
-
// Attach attaches the data link layer endpoint to the network-layer
// dispatcher of the stack.
//
@@ -764,13 +770,13 @@ type InjectableLinkEndpoint interface {
// A LinkAddressResolver is an extension to a NetworkProtocol that
// can resolve link addresses.
type LinkAddressResolver interface {
- // LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts
- // the request on the local network if remoteLinkAddr is the zero value. The
- // request is sent on linkEP with localAddr as the source.
+ // LinkAddressRequest sends a request for the link address of the target
+ // address. The request is broadcasted on the local network if a remote link
+ // address is not provided.
//
- // A valid response will cause the discovery protocol's network
- // endpoint to call AddLinkAddress.
- LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error
+ // The request is sent from the passed network interface. If the interface
+ // local address is unspecified, any interface local address may be used.
+ LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic NetworkInterface) *tcpip.Error
// ResolveStaticAddress attempts to resolve address without sending
// requests. It either resolves the name immediately or returns the
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index b76e2d37b..de5fe6ffe 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -15,20 +15,25 @@
package stack
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
// Route represents a route through the networking stack to a given destination.
+//
+// It is safe to call Route's methods from multiple goroutines.
+//
+// The exported fields are immutable.
+//
+// TODO(gvisor.dev/issue/4902): Unexpose immutable fields.
type Route struct {
// RemoteAddress is the final destination of the route.
RemoteAddress tcpip.Address
- // RemoteLinkAddress is the link-layer (MAC) address of the
- // final destination of the route.
- RemoteLinkAddress tcpip.LinkAddress
-
// LocalAddress is the local address where the route starts.
LocalAddress tcpip.Address
@@ -45,11 +50,24 @@ type Route struct {
// Loop controls where WritePacket should send packets.
Loop PacketLooping
- // nic is the NIC the route goes through.
- nic *NIC
+ // localAddressNIC is the interface the address is associated with.
+ // TODO(gvisor.dev/issue/4548): Remove this field once we can query the
+ // address's assigned status without the NIC.
+ localAddressNIC *NIC
- // addressEndpoint is the local address this route is associated with.
- addressEndpoint AssignableAddressEndpoint
+ mu struct {
+ sync.RWMutex
+
+ // localAddressEndpoint is the local address this route is associated with.
+ localAddressEndpoint AssignableAddressEndpoint
+
+ // remoteLinkAddress is the link-layer (MAC) address of the next hop in the
+ // route.
+ remoteLinkAddress tcpip.LinkAddress
+ }
+
+ // outgoingNIC is the interface this route uses to write packets.
+ outgoingNIC *NIC
// linkCache is set if link address resolution is enabled for this protocol on
// the route's NIC.
@@ -60,51 +78,139 @@ type Route struct {
linkRes LinkAddressResolver
}
+// constructAndValidateRoute validates and initializes a route. It takes
+// ownership of the provided local address.
+//
+// Returns an empty route if validation fails.
+func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, localAddr, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) *Route {
+ if len(localAddr) == 0 {
+ localAddr = addressEndpoint.AddressWithPrefix().Address
+ }
+
+ if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(localAddr) {
+ addressEndpoint.DecRef()
+ return nil
+ }
+
+ // If no remote address is provided, use the local address.
+ if len(remoteAddr) == 0 {
+ remoteAddr = localAddr
+ }
+
+ r := makeRoute(
+ netProto,
+ localAddr,
+ remoteAddr,
+ outgoingNIC,
+ localAddressNIC,
+ addressEndpoint,
+ handleLocal,
+ multicastLoop,
+ )
+
+ // If the route requires us to send a packet through some gateway, do not
+ // broadcast it.
+ if len(gateway) > 0 {
+ r.NextHop = gateway
+ } else if subnet := addressEndpoint.Subnet(); subnet.IsBroadcast(remoteAddr) {
+ r.ResolveWith(header.EthernetBroadcastAddress)
+ }
+
+ return r
+}
+
// makeRoute initializes a new route. It takes ownership of the provided
// AssignableAddressEndpoint.
-func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, nic *NIC, addressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route {
+func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) *Route {
+ if localAddressNIC.stack != outgoingNIC.stack {
+ panic(fmt.Sprintf("cannot create a route with NICs from different stacks"))
+ }
+
+ if len(localAddr) == 0 {
+ localAddr = localAddressEndpoint.AddressWithPrefix().Address
+ }
+
loop := PacketOut
- if handleLocal && localAddr != "" && remoteAddr == localAddr {
- loop = PacketLoop
- } else if multicastLoop && (header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)) {
- loop |= PacketLoop
- } else if remoteAddr == header.IPv4Broadcast {
- loop |= PacketLoop
+
+ // TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the
+ // link endpoint level. We can remove this check once loopback interfaces
+ // loop back packets at the network layer.
+ if !outgoingNIC.IsLoopback() {
+ if handleLocal && localAddr != "" && remoteAddr == localAddr {
+ loop = PacketLoop
+ } else if multicastLoop && (header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)) {
+ loop |= PacketLoop
+ } else if remoteAddr == header.IPv4Broadcast {
+ loop |= PacketLoop
+ } else if subnet := localAddressEndpoint.AddressWithPrefix().Subnet(); subnet.IsBroadcast(remoteAddr) {
+ loop |= PacketLoop
+ }
}
- r := Route{
+ return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop)
+}
+
+func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) *Route {
+ r := &Route{
NetProto: netProto,
LocalAddress: localAddr,
- LocalLinkAddress: nic.LinkEndpoint.LinkAddress(),
+ LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(),
RemoteAddress: remoteAddr,
- addressEndpoint: addressEndpoint,
- nic: nic,
+ localAddressNIC: localAddressNIC,
+ outgoingNIC: outgoingNIC,
Loop: loop,
}
- if r.nic.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 {
- if linkRes, ok := r.nic.stack.linkAddrResolvers[r.NetProto]; ok {
+ r.mu.Lock()
+ r.mu.localAddressEndpoint = localAddressEndpoint
+ r.mu.Unlock()
+
+ if r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 {
+ if linkRes, ok := r.outgoingNIC.stack.linkAddrResolvers[r.NetProto]; ok {
r.linkRes = linkRes
- r.linkCache = r.nic.stack
+ r.linkCache = r.outgoingNIC.stack
}
}
return r
}
+// makeLocalRoute initializes a new local route. It takes ownership of the
+// provided AssignableAddressEndpoint.
+//
+// A local route is a route to a destination that is local to the stack.
+func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint) *Route {
+ loop := PacketLoop
+ // TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the
+ // link endpoint level. We can remove this check once loopback interfaces
+ // loop back packets at the network layer.
+ if outgoingNIC.IsLoopback() {
+ loop = PacketOut
+ }
+ return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop)
+}
+
+// RemoteLinkAddress returns the link-layer (MAC) address of the next hop in
+// the route.
+func (r *Route) RemoteLinkAddress() tcpip.LinkAddress {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ return r.mu.remoteLinkAddress
+}
+
// NICID returns the id of the NIC from which this route originates.
func (r *Route) NICID() tcpip.NICID {
- return r.nic.ID()
+ return r.outgoingNIC.ID()
}
// MaxHeaderLength forwards the call to the network endpoint's implementation.
func (r *Route) MaxHeaderLength() uint16 {
- return r.nic.getNetworkEndpoint(r.NetProto).MaxHeaderLength()
+ return r.outgoingNIC.getNetworkEndpoint(r.NetProto).MaxHeaderLength()
}
// Stats returns a mutable copy of current stats.
func (r *Route) Stats() tcpip.Stats {
- return r.nic.stack.Stats()
+ return r.outgoingNIC.stack.Stats()
}
// PseudoHeaderChecksum forwards the call to the network endpoint's
@@ -113,14 +219,38 @@ func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, tot
return header.PseudoHeaderChecksum(protocol, r.LocalAddress, r.RemoteAddress, totalLen)
}
-// Capabilities returns the link-layer capabilities of the route.
-func (r *Route) Capabilities() LinkEndpointCapabilities {
- return r.nic.LinkEndpoint.Capabilities()
+// RequiresTXTransportChecksum returns false if the route does not require
+// transport checksums to be populated.
+func (r *Route) RequiresTXTransportChecksum() bool {
+ if r.local() {
+ return false
+ }
+ return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityTXChecksumOffload == 0
+}
+
+// HasSoftwareGSOCapability returns true if the route supports software GSO.
+func (r *Route) HasSoftwareGSOCapability() bool {
+ return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilitySoftwareGSO != 0
+}
+
+// HasHardwareGSOCapability returns true if the route supports hardware GSO.
+func (r *Route) HasHardwareGSOCapability() bool {
+ return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityHardwareGSO != 0
+}
+
+// HasSaveRestoreCapability returns true if the route supports save/restore.
+func (r *Route) HasSaveRestoreCapability() bool {
+ return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilitySaveRestore != 0
+}
+
+// HasDisconncetOkCapability returns true if the route supports disconnecting.
+func (r *Route) HasDisconncetOkCapability() bool {
+ return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityDisconnectOk != 0
}
// GSOMaxSize returns the maximum GSO packet size.
func (r *Route) GSOMaxSize() uint32 {
- if gso, ok := r.nic.LinkEndpoint.(GSOEndpoint); ok {
+ if gso, ok := r.outgoingNIC.LinkEndpoint.(GSOEndpoint); ok {
return gso.GSOMaxSize()
}
return 0
@@ -129,7 +259,9 @@ func (r *Route) GSOMaxSize() uint32 {
// ResolveWith immediately resolves a route with the specified remote link
// address.
func (r *Route) ResolveWith(addr tcpip.LinkAddress) {
- r.RemoteLinkAddress = addr
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.mu.remoteLinkAddress = addr
}
// Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in
@@ -142,7 +274,10 @@ func (r *Route) ResolveWith(addr tcpip.LinkAddress) {
//
// The NIC r uses must not be locked.
func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
- if !r.IsResolutionRequired() {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if !r.isResolutionRequiredRLocked() {
// Nothing to do if there is no cache (which does the resolution on cache miss) or
// link address is already known.
return nil, nil
@@ -152,26 +287,33 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
if nextAddr == "" {
// Local link address is already known.
if r.RemoteAddress == r.LocalAddress {
- r.RemoteLinkAddress = r.LocalLinkAddress
+ r.mu.remoteLinkAddress = r.LocalLinkAddress
return nil, nil
}
nextAddr = r.RemoteAddress
}
- if neigh := r.nic.neigh; neigh != nil {
- entry, ch, err := neigh.entry(nextAddr, r.LocalAddress, r.linkRes, waker)
+ // If specified, the local address used for link address resolution must be an
+ // address on the outgoing interface.
+ var linkAddressResolutionRequestLocalAddr tcpip.Address
+ if r.localAddressNIC == r.outgoingNIC {
+ linkAddressResolutionRequestLocalAddr = r.LocalAddress
+ }
+
+ if neigh := r.outgoingNIC.neigh; neigh != nil {
+ entry, ch, err := neigh.entry(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, waker)
if err != nil {
return ch, err
}
- r.RemoteLinkAddress = entry.LinkAddr
+ r.mu.remoteLinkAddress = entry.LinkAddr
return nil, nil
}
- linkAddr, ch, err := r.linkCache.GetLinkAddress(r.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker)
+ linkAddr, ch, err := r.linkCache.GetLinkAddress(r.outgoingNIC.ID(), nextAddr, linkAddressResolutionRequestLocalAddr, r.NetProto, waker)
if err != nil {
return ch, err
}
- r.RemoteLinkAddress = linkAddr
+ r.mu.remoteLinkAddress = linkAddr
return nil, nil
}
@@ -182,100 +324,146 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) {
nextAddr = r.RemoteAddress
}
- if neigh := r.nic.neigh; neigh != nil {
+ if neigh := r.outgoingNIC.neigh; neigh != nil {
neigh.removeWaker(nextAddr, waker)
return
}
- r.linkCache.RemoveWaker(r.nic.ID(), nextAddr, waker)
+ r.linkCache.RemoveWaker(r.outgoingNIC.ID(), nextAddr, waker)
+}
+
+// local returns true if the route is a local route.
+func (r *Route) local() bool {
+ return r.Loop == PacketLoop || r.outgoingNIC.IsLoopback()
}
// IsResolutionRequired returns true if Resolve() must be called to resolve
-// the link address before the this route can be written to.
+// the link address before the route can be written to.
//
-// The NIC r uses must not be locked.
+// The NICs the route is associated with must not be locked.
func (r *Route) IsResolutionRequired() bool {
- if r.nic.neigh != nil {
- return r.nic.isValidForOutgoing(r.addressEndpoint) && r.linkRes != nil && r.RemoteLinkAddress == ""
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ return r.isResolutionRequiredRLocked()
+}
+
+func (r *Route) isResolutionRequiredRLocked() bool {
+ if !r.isValidForOutgoingRLocked() || r.mu.remoteLinkAddress != "" || r.local() {
+ return false
}
- return r.nic.isValidForOutgoing(r.addressEndpoint) && r.linkCache != nil && r.RemoteLinkAddress == ""
+
+ return (r.outgoingNIC.neigh != nil && r.linkRes != nil) || r.linkCache != nil
+}
+
+func (r *Route) isValidForOutgoing() bool {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ return r.isValidForOutgoingRLocked()
+}
+
+func (r *Route) isValidForOutgoingRLocked() bool {
+ if !r.outgoingNIC.Enabled() {
+ return false
+ }
+
+ localAddressEndpoint := r.mu.localAddressEndpoint
+ if localAddressEndpoint == nil || !r.localAddressNIC.isValidForOutgoing(localAddressEndpoint) {
+ return false
+ }
+
+ // If the source NIC and outgoing NIC are different, make sure the stack has
+ // forwarding enabled, or the packet will be handled locally.
+ if r.outgoingNIC != r.localAddressNIC && !r.outgoingNIC.stack.Forwarding(r.NetProto) && (!r.outgoingNIC.stack.handleLocal || !r.outgoingNIC.hasAddress(r.NetProto, r.RemoteAddress)) {
+ return false
+ }
+
+ return true
}
// WritePacket writes the packet through the given route.
func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error {
- if !r.nic.isValidForOutgoing(r.addressEndpoint) {
+ if !r.isValidForOutgoing() {
return tcpip.ErrInvalidEndpointState
}
- return r.nic.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt)
+ return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt)
}
// WritePackets writes a list of n packets through the given route and returns
// the number of packets written.
func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) {
- if !r.nic.isValidForOutgoing(r.addressEndpoint) {
+ if !r.isValidForOutgoing() {
return 0, tcpip.ErrInvalidEndpointState
}
- return r.nic.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params)
+ return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params)
}
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error {
- if !r.nic.isValidForOutgoing(r.addressEndpoint) {
+ if !r.isValidForOutgoing() {
return tcpip.ErrInvalidEndpointState
}
- return r.nic.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt)
+ return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt)
}
// DefaultTTL returns the default TTL of the underlying network endpoint.
func (r *Route) DefaultTTL() uint8 {
- return r.nic.getNetworkEndpoint(r.NetProto).DefaultTTL()
+ return r.outgoingNIC.getNetworkEndpoint(r.NetProto).DefaultTTL()
}
// MTU returns the MTU of the underlying network endpoint.
func (r *Route) MTU() uint32 {
- return r.nic.getNetworkEndpoint(r.NetProto).MTU()
+ return r.outgoingNIC.getNetworkEndpoint(r.NetProto).MTU()
}
// Release frees all resources associated with the route.
func (r *Route) Release() {
- if r.addressEndpoint != nil {
- r.addressEndpoint.DecRef()
- r.addressEndpoint = nil
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if r.mu.localAddressEndpoint != nil {
+ r.mu.localAddressEndpoint.DecRef()
+ r.mu.localAddressEndpoint = nil
}
}
// Clone clones the route.
-func (r *Route) Clone() Route {
- if r.addressEndpoint != nil {
- _ = r.addressEndpoint.IncRef()
+func (r *Route) Clone() *Route {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+
+ newRoute := &Route{
+ RemoteAddress: r.RemoteAddress,
+ LocalAddress: r.LocalAddress,
+ LocalLinkAddress: r.LocalLinkAddress,
+ NextHop: r.NextHop,
+ NetProto: r.NetProto,
+ Loop: r.Loop,
+ localAddressNIC: r.localAddressNIC,
+ outgoingNIC: r.outgoingNIC,
+ linkCache: r.linkCache,
+ linkRes: r.linkRes,
}
- return *r
-}
-// MakeLoopedRoute duplicates the given route with special handling for routes
-// used for sending multicast or broadcast packets. In those cases the
-// multicast/broadcast address is the remote address when sending out, but for
-// incoming (looped) packets it becomes the local address. Similarly, the local
-// interface address that was the local address going out becomes the remote
-// address coming in. This is different to unicast routes where local and
-// remote addresses remain the same as they identify location (local vs remote)
-// not direction (source vs destination).
-func (r *Route) MakeLoopedRoute() Route {
- l := r.Clone()
- if r.RemoteAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(r.RemoteAddress) || header.IsV6MulticastAddress(r.RemoteAddress) {
- l.RemoteAddress, l.LocalAddress = l.LocalAddress, l.RemoteAddress
- l.RemoteLinkAddress = l.LocalLinkAddress
+ newRoute.mu.Lock()
+ defer newRoute.mu.Unlock()
+ newRoute.mu.localAddressEndpoint = r.mu.localAddressEndpoint
+ if newRoute.mu.localAddressEndpoint != nil {
+ if !newRoute.mu.localAddressEndpoint.IncRef() {
+ panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", newRoute.LocalAddress))
+ }
}
- return l
+ newRoute.mu.remoteLinkAddress = r.mu.remoteLinkAddress
+
+ return newRoute
}
// Stack returns the instance of the Stack that owns this route.
func (r *Route) Stack() *Stack {
- return r.nic.stack
+ return r.outgoingNIC.stack
}
func (r *Route) isV4Broadcast(addr tcpip.Address) bool {
@@ -283,7 +471,14 @@ func (r *Route) isV4Broadcast(addr tcpip.Address) bool {
return true
}
- subnet := r.addressEndpoint.AddressWithPrefix().Subnet()
+ r.mu.RLock()
+ localAddressEndpoint := r.mu.localAddressEndpoint
+ r.mu.RUnlock()
+ if localAddressEndpoint == nil {
+ return false
+ }
+
+ subnet := localAddressEndpoint.Subnet()
return subnet.IsBroadcast(addr)
}
@@ -293,26 +488,3 @@ func (r *Route) IsOutboundBroadcast() bool {
// Only IPv4 has a notion of broadcast.
return r.isV4Broadcast(r.RemoteAddress)
}
-
-// IsInboundBroadcast returns true if the route is for an inbound broadcast
-// packet.
-func (r *Route) IsInboundBroadcast() bool {
- // Only IPv4 has a notion of broadcast.
- return r.isV4Broadcast(r.LocalAddress)
-}
-
-// ReverseRoute returns new route with given source and destination address.
-func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route {
- return Route{
- NetProto: r.NetProto,
- LocalAddress: dst,
- LocalLinkAddress: r.RemoteLinkAddress,
- RemoteAddress: src,
- RemoteLinkAddress: r.LocalLinkAddress,
- Loop: r.Loop,
- addressEndpoint: r.addressEndpoint,
- nic: r.nic,
- linkCache: r.linkCache,
- linkRes: r.linkRes,
- }
-}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 3a07577c8..dc4f5b3e7 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -22,6 +22,7 @@ package stack
import (
"bytes"
"encoding/binary"
+ "fmt"
mathrand "math/rand"
"sync/atomic"
"time"
@@ -52,7 +53,7 @@ const (
type transportProtocolState struct {
proto TransportProtocol
- defaultHandler func(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool
+ defaultHandler func(id TransportEndpointID, pkt *PacketBuffer) bool
}
// TCPProbeFunc is the expected function type for a TCP probe function to be
@@ -81,6 +82,7 @@ type TCPRACKState struct {
FACK seqnum.Value
RTT time.Duration
Reord bool
+ DSACKSeen bool
}
// TCPEndpointID is the unique 4 tuple that identifies a given endpoint.
@@ -518,6 +520,10 @@ type Options struct {
//
// RandSource must be thread-safe.
RandSource mathrand.Source
+
+ // IPTables are the initial iptables rules. If nil, iptables will allow
+ // all traffic.
+ IPTables *IPTables
}
// TransportEndpointInfo holds useful information about a transport endpoint
@@ -620,6 +626,10 @@ func New(opts Options) *Stack {
randSrc = &lockedRandomSource{src: mathrand.NewSource(generateRandInt64())}
}
+ if opts.IPTables == nil {
+ opts.IPTables = DefaultTables()
+ }
+
opts.NUDConfigs.resetInvalidFields()
s := &Stack{
@@ -633,7 +643,7 @@ func New(opts Options) *Stack {
clock: clock,
stats: opts.Stats.FillIn(),
handleLocal: opts.HandleLocal,
- tables: DefaultTables(),
+ tables: opts.IPTables,
icmpRateLimiter: NewICMPRateLimiter(),
seed: generateRandUint32(),
nudConfigs: opts.NUDConfigs,
@@ -751,7 +761,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber,
//
// It must be called only during initialization of the stack. Changing it as the
// stack is operating is not supported.
-func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, *PacketBuffer) bool) {
+func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(TransportEndpointID, *PacketBuffer) bool) {
state := s.transportProtocols[p]
if state != nil {
state.defaultHandler = h
@@ -830,6 +840,20 @@ func (s *Stack) AddRoute(route tcpip.Route) {
s.routeTable = append(s.routeTable, route)
}
+// RemoveRoutes removes matching routes from the route table.
+func (s *Stack) RemoveRoutes(match func(tcpip.Route) bool) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ var filteredRoutes []tcpip.Route
+ for _, route := range s.routeTable {
+ if !match(route) {
+ filteredRoutes = append(filteredRoutes, route)
+ }
+ }
+ s.routeTable = filteredRoutes
+}
+
// NewEndpoint creates a new transport layer endpoint of the given protocol.
func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
t, ok := s.transportProtocols[transport]
@@ -1057,7 +1081,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
flags := NICStateFlags{
Up: true, // Netstack interfaces are always up.
Running: nic.Enabled(),
- Promiscuous: nic.isPromiscuousMode(),
+ Promiscuous: nic.Promiscuous(),
Loopback: nic.IsLoopback(),
}
nics[id] = NICInfo{
@@ -1094,6 +1118,16 @@ func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber,
return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint)
}
+// AddAddressWithPrefix is the same as AddAddress, but allows you to specify
+// the address prefix.
+func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.AddressWithPrefix) *tcpip.Error {
+ ap := tcpip.ProtocolAddress{
+ Protocol: protocol,
+ AddressWithPrefix: addr,
+ }
+ return s.AddProtocolAddressWithOptions(id, ap, CanBePrimaryEndpoint)
+}
+
// AddProtocolAddress adds a new network-layer protocol address to the
// specified NIC.
func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress) *tcpip.Error {
@@ -1180,54 +1214,225 @@ func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netP
return nic.findEndpoint(netProto, localAddr, CanBePrimaryEndpoint)
}
+// findLocalRouteFromNICRLocked is like findLocalRouteRLocked but finds a route
+// from the specified NIC.
+//
+// Precondition: s.mu must be read locked.
+func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) *Route {
+ localAddressEndpoint := localAddressNIC.getAddressOrCreateTempInner(netProto, localAddr, false /* createTemp */, NeverPrimaryEndpoint)
+ if localAddressEndpoint == nil {
+ return nil
+ }
+
+ var outgoingNIC *NIC
+ // Prefer a local route to the same interface as the local address.
+ if localAddressNIC.hasAddress(netProto, remoteAddr) {
+ outgoingNIC = localAddressNIC
+ }
+
+ // If the remote address isn't owned by the local address's NIC, check all
+ // NICs.
+ if outgoingNIC == nil {
+ for _, nic := range s.nics {
+ if nic.hasAddress(netProto, remoteAddr) {
+ outgoingNIC = nic
+ break
+ }
+ }
+ }
+
+ // If the remote address is not owned by the stack, we can't return a local
+ // route.
+ if outgoingNIC == nil {
+ localAddressEndpoint.DecRef()
+ return nil
+ }
+
+ r := makeLocalRoute(
+ netProto,
+ localAddr,
+ remoteAddr,
+ outgoingNIC,
+ localAddressNIC,
+ localAddressEndpoint,
+ )
+
+ if r.IsOutboundBroadcast() {
+ r.Release()
+ return nil
+ }
+
+ return r
+}
+
+// findLocalRouteRLocked returns a local route.
+//
+// A local route is a route to some remote address which the stack owns. That
+// is, a local route is a route where packets never have to leave the stack.
+//
+// Precondition: s.mu must be read locked.
+func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) *Route {
+ if len(localAddr) == 0 {
+ localAddr = remoteAddr
+ }
+
+ if localAddressNICID == 0 {
+ for _, localAddressNIC := range s.nics {
+ if r := s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto); r != nil {
+ return r
+ }
+ }
+
+ return nil
+ }
+
+ if localAddressNIC, ok := s.nics[localAddressNICID]; ok {
+ return s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto)
+ }
+
+ return nil
+}
+
// FindRoute creates a route to the given destination address, leaving through
-// the given nic and local address (if provided).
-func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (Route, *tcpip.Error) {
+// the given NIC and local address (if provided).
+//
+// If a NIC is not specified, the returned route will leave through the same
+// NIC as the NIC that has the local address assigned when forwarding is
+// disabled. If forwarding is enabled and the NIC is unspecified, the route may
+// leave through any interface unless the route is link-local.
+//
+// If no local address is provided, the stack will select a local address. If no
+// remote address is provided, the stack wil use a remote address equal to the
+// local address.
+func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (*Route, *tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
+ isLinkLocal := header.IsV6LinkLocalAddress(remoteAddr) || header.IsV6LinkLocalMulticastAddress(remoteAddr)
isLocalBroadcast := remoteAddr == header.IPv4Broadcast
isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)
- needRoute := !(isLocalBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr))
+ isLoopback := header.IsV4LoopbackAddress(remoteAddr) || header.IsV6LoopbackAddress(remoteAddr)
+ needRoute := !(isLocalBroadcast || isMulticast || isLinkLocal || isLoopback)
+
+ if s.handleLocal && !isMulticast && !isLocalBroadcast {
+ if r := s.findLocalRouteRLocked(id, localAddr, remoteAddr, netProto); r != nil {
+ return r, nil
+ }
+ }
+
+ // If the interface is specified and we do not need a route, return a route
+ // through the interface if the interface is valid and enabled.
if id != 0 && !needRoute {
if nic, ok := s.nics[id]; ok && nic.Enabled() {
if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil {
- return makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback()), nil
+ return makeRoute(
+ netProto,
+ localAddr,
+ remoteAddr,
+ nic, /* outboundNIC */
+ nic, /* localAddressNIC*/
+ addressEndpoint,
+ s.handleLocal,
+ multicastLoop,
+ ), nil
}
}
- } else {
- for _, route := range s.routeTable {
- if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) {
- continue
+
+ if isLoopback {
+ return nil, tcpip.ErrBadLocalAddress
+ }
+ return nil, tcpip.ErrNetworkUnreachable
+ }
+
+ canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal
+
+ // Find a route to the remote with the route table.
+ var chosenRoute tcpip.Route
+ for _, route := range s.routeTable {
+ if len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr) {
+ continue
+ }
+
+ nic, ok := s.nics[route.NIC]
+ if !ok || !nic.Enabled() {
+ continue
+ }
+
+ if id == 0 || id == route.NIC {
+ if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil {
+ var gateway tcpip.Address
+ if needRoute {
+ gateway = route.Gateway
+ }
+ r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop)
+ if r == nil {
+ panic(fmt.Sprintf("non-forwarding route validation failed with route table entry = %#v, id = %d, localAddr = %s, remoteAddr = %s", route, id, localAddr, remoteAddr))
+ }
+ return r, nil
}
- if nic, ok := s.nics[route.NIC]; ok && nic.Enabled() {
- if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil {
- if len(remoteAddr) == 0 {
- // If no remote address was provided, then the route
- // provided will refer to the link local address.
- remoteAddr = addressEndpoint.AddressWithPrefix().Address
- }
+ }
+
+ // If the stack has forwarding enabled and we haven't found a valid route to
+ // the remote address yet, keep track of the first valid route. We keep
+ // iterating because we prefer routes that let us use a local address that
+ // is assigned to the outgoing interface. There is no requirement to do this
+ // from any RFC but simply a choice made to better follow a strong host
+ // model which the netstack follows at the time of writing.
+ if canForward && chosenRoute == (tcpip.Route{}) {
+ chosenRoute = route
+ }
+ }
+
+ if chosenRoute != (tcpip.Route{}) {
+ // At this point we know the stack has forwarding enabled since chosenRoute is
+ // only set when forwarding is enabled.
+ nic, ok := s.nics[chosenRoute.NIC]
+ if !ok {
+ // If the route's NIC was invalid, we should not have chosen the route.
+ panic(fmt.Sprintf("chosen route must have a valid NIC with ID = %d", chosenRoute.NIC))
+ }
+
+ var gateway tcpip.Address
+ if needRoute {
+ gateway = chosenRoute.Gateway
+ }
- r := makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback())
- if len(route.Gateway) > 0 {
- if needRoute {
- r.NextHop = route.Gateway
- }
- } else if subnet := addressEndpoint.AddressWithPrefix().Subnet(); subnet.IsBroadcast(remoteAddr) {
- r.RemoteLinkAddress = header.EthernetBroadcastAddress
+ // Use the specified NIC to get the local address endpoint.
+ if id != 0 {
+ if aNIC, ok := s.nics[id]; ok {
+ if addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto); addressEndpoint != nil {
+ if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != nil {
+ return r, nil
}
+ }
+ }
+
+ return nil, tcpip.ErrNoRoute
+ }
+
+ if id == 0 {
+ // If an interface is not specified, try to find a NIC that holds the local
+ // address endpoint to construct a route.
+ for _, aNIC := range s.nics {
+ addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto)
+ if addressEndpoint == nil {
+ continue
+ }
+ if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != nil {
return r, nil
}
}
}
}
- if !needRoute {
- return Route{}, tcpip.ErrNetworkUnreachable
+ if needRoute {
+ return nil, tcpip.ErrNoRoute
}
-
- return Route{}, tcpip.ErrNoRoute
+ if header.IsV6LoopbackAddress(remoteAddr) {
+ return nil, tcpip.ErrBadLocalAddress
+ }
+ return nil, tcpip.ErrNetworkUnreachable
}
// CheckNetworkProtocol checks if a given network protocol is enabled in the
@@ -1323,7 +1528,7 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address,
fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr}
linkRes := s.linkAddrResolvers[protocol]
- return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.LinkEndpoint, waker)
+ return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic, waker)
}
// Neighbors returns all IP to MAC address associations.
@@ -1443,8 +1648,8 @@ func (s *Stack) CompleteTransportEndpointCleanup(ep TransportEndpoint) {
// FindTransportEndpoint finds an endpoint that most closely matches the provided
// id. If no endpoint is found it returns nil.
-func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint {
- return s.demux.findTransportEndpoint(netProto, transProto, id, r)
+func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, nicID tcpip.NICID) TransportEndpoint {
+ return s.demux.findTransportEndpoint(netProto, transProto, id, nicID)
}
// RegisterRawTransportEndpoint registers the given endpoint with the stack
@@ -1615,49 +1820,20 @@ func (s *Stack) unregisterPacketEndpointLocked(nicID tcpip.NICID, netProto tcpip
nic.unregisterPacketEndpoint(netProto, ep)
}
-// WritePacket writes data directly to the specified NIC. It adds an ethernet
-// header based on the arguments.
-func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error {
- s.mu.Lock()
- nic, ok := s.nics[nicID]
- s.mu.Unlock()
- if !ok {
- return tcpip.ErrUnknownDevice
- }
-
- // Add our own fake ethernet header.
- ethFields := header.EthernetFields{
- SrcAddr: nic.LinkEndpoint.LinkAddress(),
- DstAddr: dst,
- Type: netProto,
- }
- fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
- fakeHeader.Encode(&ethFields)
- vv := buffer.View(fakeHeader).ToVectorisedView()
- vv.Append(payload)
-
- if err := nic.LinkEndpoint.WriteRawPacket(vv); err != nil {
- return err
- }
-
- return nil
-}
-
-// WriteRawPacket writes data directly to the specified NIC without adding any
-// headers.
-func (s *Stack) WriteRawPacket(nicID tcpip.NICID, payload buffer.VectorisedView) *tcpip.Error {
+// WritePacketToRemote writes a payload on the specified NIC using the provided
+// network protocol and remote link address.
+func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error {
s.mu.Lock()
nic, ok := s.nics[nicID]
s.mu.Unlock()
if !ok {
return tcpip.ErrUnknownDevice
}
-
- if err := nic.LinkEndpoint.WriteRawPacket(payload); err != nil {
- return err
- }
-
- return nil
+ pkt := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: int(nic.MaxHeaderLength()),
+ Data: payload,
+ })
+ return nic.WritePacketToRemote(remote, nil, netProto, pkt)
}
// NetworkProtocolInstance returns the protocol instance in the stack for the
@@ -1717,7 +1893,6 @@ func (s *Stack) RemoveTCPProbe() {
// JoinGroup joins the given multicast group on the given NIC.
func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error {
- // TODO: notify network of subscription via igmp protocol.
s.mu.RLock()
defer s.mu.RUnlock()
@@ -1896,3 +2071,111 @@ func (s *Stack) FindNICNameFromID(id tcpip.NICID) string {
func (s *Stack) NewJob(l sync.Locker, f func()) *tcpip.Job {
return tcpip.NewJob(s.clock, l, f)
}
+
+// ParseResult indicates the result of a parsing attempt.
+type ParseResult int
+
+const (
+ // ParsedOK indicates that a packet was successfully parsed.
+ ParsedOK ParseResult = iota
+
+ // UnknownNetworkProtocol indicates that the network protocol is unknown.
+ UnknownNetworkProtocol
+
+ // NetworkLayerParseError indicates that the network packet was not
+ // successfully parsed.
+ NetworkLayerParseError
+
+ // UnknownTransportProtocol indicates that the transport protocol is unknown.
+ UnknownTransportProtocol
+
+ // TransportLayerParseError indicates that the transport packet was not
+ // successfully parsed.
+ TransportLayerParseError
+)
+
+// ParsePacketBuffer parses the provided packet buffer.
+func (s *Stack) ParsePacketBuffer(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) ParseResult {
+ netProto, ok := s.networkProtocols[protocol]
+ if !ok {
+ return UnknownNetworkProtocol
+ }
+
+ transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt)
+ if !ok {
+ return NetworkLayerParseError
+ }
+ if !hasTransportHdr {
+ return ParsedOK
+ }
+
+ // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader
+ // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a
+ // full explanation.
+ if transProtoNum == header.ICMPv4ProtocolNumber || transProtoNum == header.ICMPv6ProtocolNumber {
+ return ParsedOK
+ }
+
+ pkt.TransportProtocolNumber = transProtoNum
+ // Parse the transport header if present.
+ state, ok := s.transportProtocols[transProtoNum]
+ if !ok {
+ return UnknownTransportProtocol
+ }
+
+ if !state.proto.Parse(pkt) {
+ return TransportLayerParseError
+ }
+
+ return ParsedOK
+}
+
+// networkProtocolNumbers returns the network protocol numbers the stack is
+// configured with.
+func (s *Stack) networkProtocolNumbers() []tcpip.NetworkProtocolNumber {
+ protos := make([]tcpip.NetworkProtocolNumber, 0, len(s.networkProtocols))
+ for p := range s.networkProtocols {
+ protos = append(protos, p)
+ }
+ return protos
+}
+
+func isSubnetBroadcastOnNIC(nic *NIC, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
+ addressEndpoint := nic.getAddressOrCreateTempInner(protocol, addr, false /* createTemp */, NeverPrimaryEndpoint)
+ if addressEndpoint == nil {
+ return false
+ }
+
+ subnet := addressEndpoint.Subnet()
+ addressEndpoint.DecRef()
+ return subnet.IsBroadcast(addr)
+}
+
+// IsSubnetBroadcast returns true if the provided address is a subnet-local
+// broadcast address on the specified NIC and protocol.
+//
+// Returns false if the NIC is unknown or if the protocol is unknown or does
+// not support addressing.
+//
+// If the NIC is not specified, the stack will check all NICs.
+func (s *Stack) IsSubnetBroadcast(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nicID != 0 {
+ nic, ok := s.nics[nicID]
+ if !ok {
+ return false
+ }
+
+ return isSubnetBroadcastOnNIC(nic, protocol, addr)
+ }
+
+ for _, nic := range s.nics {
+ if isSubnetBroadcastOnNIC(nic, protocol, addr) {
+ return true
+ }
+ }
+
+ return false
+}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index e75f58c64..457990945 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -21,12 +21,12 @@ import (
"bytes"
"fmt"
"math"
+ "net"
"sort"
"testing"
"time"
"github.com/google/go-cmp/cmp"
- "github.com/google/go-cmp/cmp/cmpopts"
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -108,12 +108,21 @@ func (*fakeNetworkEndpoint) DefaultTTL() uint8 {
return 123
}
-func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Increment the received packet count in the protocol descriptor.
- f.proto.packetCount[int(r.LocalAddress[0])%len(f.proto.packetCount)]++
+ netHdr := pkt.NetworkHeader().View()
+
+ dst := tcpip.Address(netHdr[dstAddrOffset:][:1])
+ addressEndpoint := f.AcquireAssignedAddress(dst, f.nic.Promiscuous(), stack.CanBePrimaryEndpoint)
+ if addressEndpoint == nil {
+ return
+ }
+ addressEndpoint.DecRef()
+
+ f.proto.packetCount[int(dst[0])%len(f.proto.packetCount)]++
// Handle control packets.
- if pkt.NetworkHeader().View()[protocolNumberOffset] == uint8(fakeControlProtocol) {
+ if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) {
nb, ok := pkt.Data.PullUp(fakeNetHeaderLen)
if !ok {
return
@@ -129,7 +138,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuff
}
// Dispatch the packet to the transport protocol.
- f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt)
+ f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt)
}
func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
@@ -151,12 +160,13 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
// Add the protocol's header to the packet and send it to the link
// endpoint.
hdr := pkt.NetworkHeader().Push(fakeNetHeaderLen)
+ pkt.NetworkProtocolNumber = fakeNetNumber
hdr[dstAddrOffset] = r.RemoteAddress[0]
hdr[srcAddrOffset] = r.LocalAddress[0]
hdr[protocolNumberOffset] = byte(params.Protocol)
if r.Loop&stack.PacketLoop != 0 {
- f.HandlePacket(r, pkt)
+ f.HandlePacket(pkt.Clone())
}
if r.Loop&stack.PacketOut == 0 {
return nil
@@ -254,6 +264,7 @@ func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProto
if !ok {
return 0, false, false
}
+ pkt.NetworkProtocolNumber = fakeNetNumber
return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true
}
@@ -395,7 +406,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro
return send(r, payload)
}
-func send(r stack.Route, payload buffer.View) *tcpip.Error {
+func send(r *stack.Route, payload buffer.View) *tcpip.Error {
return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()),
Data: payload.ToVectorisedView(),
@@ -413,7 +424,7 @@ func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.En
}
}
-func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View) {
+func testSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View) {
t.Helper()
ep.Drain()
if err := send(r, payload); err != nil {
@@ -424,7 +435,7 @@ func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.
}
}
-func testFailingSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
+func testFailingSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
t.Helper()
if gotErr := send(r, payload); gotErr != wantErr {
t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr)
@@ -1334,6 +1345,106 @@ func TestPromiscuousMode(t *testing.T) {
testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
}
+// TestExternalSendWithHandleLocal tests that the stack creates a non-local
+// route when spoofing or promiscuous mode are enabled.
+//
+// This test makes sure that packets are transmitted from the stack.
+func TestExternalSendWithHandleLocal(t *testing.T) {
+ const (
+ unspecifiedNICID = 0
+ nicID = 1
+
+ localAddr = tcpip.Address("\x01")
+ dstAddr = tcpip.Address("\x03")
+ )
+
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ tests := []struct {
+ name string
+ configureStack func(*testing.T, *stack.Stack)
+ }{
+ {
+ name: "Default",
+ configureStack: func(*testing.T, *stack.Stack) {},
+ },
+ {
+ name: "Spoofing",
+ configureStack: func(t *testing.T, s *stack.Stack) {
+ if err := s.SetSpoofing(nicID, true); err != nil {
+ t.Fatalf("s.SetSpoofing(%d, true): %s", nicID, err)
+ }
+ },
+ },
+ {
+ name: "Promiscuous",
+ configureStack: func(t *testing.T, s *stack.Stack) {
+ if err := s.SetPromiscuousMode(nicID, true); err != nil {
+ t.Fatalf("s.SetPromiscuousMode(%d, true): %s", nicID, err)
+ }
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, handleLocal := range []bool{true, false} {
+ t.Run(fmt.Sprintf("HandleLocal=%t", handleLocal), func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
+ HandleLocal: handleLocal,
+ })
+
+ ep := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID, ep); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: nicID}})
+
+ test.configureStack(t, s)
+
+ r, err := s.FindRoute(unspecifiedNICID, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", unspecifiedNICID, localAddr, dstAddr, fakeNetNumber, err)
+ }
+ defer r.Release()
+
+ if r.LocalAddress != localAddr {
+ t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, localAddr)
+ }
+ if r.RemoteAddress != dstAddr {
+ t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr)
+ }
+
+ if n := ep.Drain(); n != 0 {
+ t.Fatalf("got ep.Drain() = %d, want = 0", n)
+ }
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
+ Protocol: fakeTransNumber,
+ TTL: 123,
+ TOS: stack.DefaultTOS,
+ }, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: buffer.NewView(10).ToVectorisedView(),
+ })); err != nil {
+ t.Fatalf("r.WritePacket(nil, _, _): %s", err)
+ }
+ if n := ep.Drain(); n != 1 {
+ t.Fatalf("got ep.Drain() = %d, want = 1", n)
+ }
+ })
+ }
+ })
+ }
+}
+
func TestSpoofingWithAddress(t *testing.T) {
localAddr := tcpip.Address("\x01")
nonExistentLocalAddr := tcpip.Address("\x02")
@@ -1451,15 +1562,15 @@ func TestSpoofingNoAddress(t *testing.T) {
// testSendTo(t, s, remoteAddr, ep, nil)
}
-func verifyRoute(gotRoute, wantRoute stack.Route) error {
+func verifyRoute(gotRoute, wantRoute *stack.Route) error {
if gotRoute.LocalAddress != wantRoute.LocalAddress {
return fmt.Errorf("bad local address: got %s, want = %s", gotRoute.LocalAddress, wantRoute.LocalAddress)
}
if gotRoute.RemoteAddress != wantRoute.RemoteAddress {
return fmt.Errorf("bad remote address: got %s, want = %s", gotRoute.RemoteAddress, wantRoute.RemoteAddress)
}
- if gotRoute.RemoteLinkAddress != wantRoute.RemoteLinkAddress {
- return fmt.Errorf("bad remote link address: got %s, want = %s", gotRoute.RemoteLinkAddress, wantRoute.RemoteLinkAddress)
+ if got, want := gotRoute.RemoteLinkAddress(), wantRoute.RemoteLinkAddress(); got != want {
+ return fmt.Errorf("bad remote link address: got %s, want = %s", got, want)
}
if gotRoute.NextHop != wantRoute.NextHop {
return fmt.Errorf("bad next-hop address: got %s, want = %s", gotRoute.NextHop, wantRoute.NextHop)
@@ -1491,7 +1602,7 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) {
if err != nil {
t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
}
- if err := verifyRoute(r, stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ if err := verifyRoute(r, &stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil {
t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
}
@@ -1545,7 +1656,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
if err != nil {
t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
}
- if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ if err := verifyRoute(r, &stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
}
@@ -1555,7 +1666,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
if err != nil {
t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err)
}
- if err := verifyRoute(r, stack.Route{LocalAddress: nic2Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ if err := verifyRoute(r, &stack.Route{LocalAddress: nic2Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err)
}
@@ -1571,7 +1682,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
if err != nil {
t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err)
}
- if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ if err := verifyRoute(r, &stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err)
}
}
@@ -2108,88 +2219,6 @@ func TestNICStats(t *testing.T) {
}
}
-func TestNICForwarding(t *testing.T) {
- const nicID1 = 1
- const nicID2 = 2
- const dstAddr = tcpip.Address("\x03")
-
- tests := []struct {
- name string
- headerLen uint16
- }{
- {
- name: "Zero header length",
- },
- {
- name: "Non-zero header length",
- headerLen: 16,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
- s.SetForwarding(fakeNetNumber, true)
-
- ep1 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID1, ep1); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
- }
- if err := s.AddAddress(nicID1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress(%d, %d, 0x01): %s", nicID1, fakeNetNumber, err)
- }
-
- ep2 := channelLinkWithHeaderLength{
- Endpoint: channel.New(10, defaultMTU, ""),
- headerLength: test.headerLen,
- }
- if err := s.CreateNIC(nicID2, &ep2); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
- }
- if err := s.AddAddress(nicID2, fakeNetNumber, "\x02"); err != nil {
- t.Fatalf("AddAddress(%d, %d, 0x02): %s", nicID2, fakeNetNumber, err)
- }
-
- // Route all packets to dstAddr to NIC 2.
- {
- subnet, err := tcpip.NewSubnet(dstAddr, "\xff")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID2}})
- }
-
- // Send a packet to dstAddr.
- buf := buffer.NewView(30)
- buf[dstAddrOffset] = dstAddr[0]
- ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- pkt, ok := ep2.Read()
- if !ok {
- t.Fatal("packet not forwarded")
- }
-
- // Test that the link's MaxHeaderLength is honoured.
- if capacity, want := pkt.Pkt.AvailableHeaderBytes(), int(test.headerLen); capacity != want {
- t.Errorf("got LinkHeader.AvailableLength() = %d, want = %d", capacity, want)
- }
-
- // Test that forwarding increments Tx stats correctly.
- if got, want := s.NICInfo()[nicID2].Stats.Tx.Packets.Value(), uint64(1); got != want {
- t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want)
- }
-
- if got, want := s.NICInfo()[nicID2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want {
- t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
- }
- })
- }
-}
-
// TestNICContextPreservation tests that you can read out via stack.NICInfo the
// Context data you pass via NICContext.Context in stack.CreateNICWithOptions.
func TestNICContextPreservation(t *testing.T) {
@@ -2377,9 +2406,9 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
}
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- AutoGenIPv6LinkLocal: test.autoGen,
- NDPDisp: &ndpDisp,
- OpaqueIIDOpts: test.iidOpts,
+ AutoGenLinkLocal: test.autoGen,
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: test.iidOpts,
})},
}
@@ -2472,8 +2501,8 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- AutoGenIPv6LinkLocal: true,
- OpaqueIIDOpts: test.opaqueIIDOpts,
+ AutoGenLinkLocal: true,
+ OpaqueIIDOpts: test.opaqueIIDOpts,
})},
}
@@ -2506,9 +2535,9 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) {
ndpConfigs := ipv6.DefaultNDPConfigurations()
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ndpConfigs,
- AutoGenIPv6LinkLocal: true,
- NDPDisp: &ndpDisp,
+ NDPConfigs: ndpConfigs,
+ AutoGenLinkLocal: true,
+ NDPDisp: &ndpDisp,
})},
}
@@ -3321,11 +3350,16 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
remNetSubnetBcast := remNetSubnet.Broadcast()
tests := []struct {
- name string
- nicAddr tcpip.ProtocolAddress
- routes []tcpip.Route
- remoteAddr tcpip.Address
- expectedRoute stack.Route
+ name string
+ nicAddr tcpip.ProtocolAddress
+ routes []tcpip.Route
+ remoteAddr tcpip.Address
+ expectedLocalAddress tcpip.Address
+ expectedRemoteAddress tcpip.Address
+ expectedRemoteLinkAddress tcpip.LinkAddress
+ expectedNextHop tcpip.Address
+ expectedNetProto tcpip.NetworkProtocolNumber
+ expectedLoop stack.PacketLooping
}{
// Broadcast to a locally attached subnet populates the broadcast MAC.
{
@@ -3340,14 +3374,12 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: ipv4SubnetBcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv4Addr.Address,
- RemoteAddress: ipv4SubnetBcast,
- RemoteLinkAddress: header.EthernetBroadcastAddress,
- NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut,
- },
+ remoteAddr: ipv4SubnetBcast,
+ expectedLocalAddress: ipv4Addr.Address,
+ expectedRemoteAddress: ipv4SubnetBcast,
+ expectedRemoteLinkAddress: header.EthernetBroadcastAddress,
+ expectedNetProto: header.IPv4ProtocolNumber,
+ expectedLoop: stack.PacketOut | stack.PacketLoop,
},
// Broadcast to a locally attached /31 subnet does not populate the
// broadcast MAC.
@@ -3363,13 +3395,11 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: ipv4Subnet31Bcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv4AddrPrefix31.Address,
- RemoteAddress: ipv4Subnet31Bcast,
- NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut,
- },
+ remoteAddr: ipv4Subnet31Bcast,
+ expectedLocalAddress: ipv4AddrPrefix31.Address,
+ expectedRemoteAddress: ipv4Subnet31Bcast,
+ expectedNetProto: header.IPv4ProtocolNumber,
+ expectedLoop: stack.PacketOut,
},
// Broadcast to a locally attached /32 subnet does not populate the
// broadcast MAC.
@@ -3385,13 +3415,11 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: ipv4Subnet32Bcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv4AddrPrefix32.Address,
- RemoteAddress: ipv4Subnet32Bcast,
- NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut,
- },
+ remoteAddr: ipv4Subnet32Bcast,
+ expectedLocalAddress: ipv4AddrPrefix32.Address,
+ expectedRemoteAddress: ipv4Subnet32Bcast,
+ expectedNetProto: header.IPv4ProtocolNumber,
+ expectedLoop: stack.PacketOut,
},
// IPv6 has no notion of a broadcast.
{
@@ -3406,13 +3434,11 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: ipv6SubnetBcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv6Addr.Address,
- RemoteAddress: ipv6SubnetBcast,
- NetProto: header.IPv6ProtocolNumber,
- Loop: stack.PacketOut,
- },
+ remoteAddr: ipv6SubnetBcast,
+ expectedLocalAddress: ipv6Addr.Address,
+ expectedRemoteAddress: ipv6SubnetBcast,
+ expectedNetProto: header.IPv6ProtocolNumber,
+ expectedLoop: stack.PacketOut,
},
// Broadcast to a remote subnet in the route table is send to the next-hop
// gateway.
@@ -3429,14 +3455,12 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: remNetSubnetBcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv4Addr.Address,
- RemoteAddress: remNetSubnetBcast,
- NextHop: ipv4Gateway,
- NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut,
- },
+ remoteAddr: remNetSubnetBcast,
+ expectedLocalAddress: ipv4Addr.Address,
+ expectedRemoteAddress: remNetSubnetBcast,
+ expectedNextHop: ipv4Gateway,
+ expectedNetProto: header.IPv4ProtocolNumber,
+ expectedLoop: stack.PacketOut,
},
// Broadcast to an unknown subnet follows the default route. Note that this
// is essentially just routing an unknown destination IP, because w/o any
@@ -3454,14 +3478,12 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: remNetSubnetBcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv4Addr.Address,
- RemoteAddress: remNetSubnetBcast,
- NextHop: ipv4Gateway,
- NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut,
- },
+ remoteAddr: remNetSubnetBcast,
+ expectedLocalAddress: ipv4Addr.Address,
+ expectedRemoteAddress: remNetSubnetBcast,
+ expectedNextHop: ipv4Gateway,
+ expectedNetProto: header.IPv4ProtocolNumber,
+ expectedLoop: stack.PacketOut,
},
}
@@ -3490,10 +3512,27 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
t.Fatalf("got unexpected address length = %d bytes", l)
}
- if r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, test.remoteAddr, netProto, false /* multicastLoop */); err != nil {
+ r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, test.remoteAddr, netProto, false /* multicastLoop */)
+ if err != nil {
t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, test.remoteAddr, netProto, err)
- } else if diff := cmp.Diff(r, test.expectedRoute, cmpopts.IgnoreUnexported(r)); diff != "" {
- t.Errorf("route mismatch (-want +got):\n%s", diff)
+ }
+ if r.LocalAddress != test.expectedLocalAddress {
+ t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, test.expectedLocalAddress)
+ }
+ if r.RemoteAddress != test.expectedRemoteAddress {
+ t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress, test.expectedRemoteAddress)
+ }
+ if got := r.RemoteLinkAddress(); got != test.expectedRemoteLinkAddress {
+ t.Errorf("got r.RemoteLinkAddress() = %s, want = %s", got, test.expectedRemoteLinkAddress)
+ }
+ if r.NextHop != test.expectedNextHop {
+ t.Errorf("got r.NextHop = %s, want = %s", r.NextHop, test.expectedNextHop)
+ }
+ if r.NetProto != test.expectedNetProto {
+ t.Errorf("got r.NetProto = %d, want = %d", r.NetProto, test.expectedNetProto)
+ }
+ if r.Loop != test.expectedLoop {
+ t.Errorf("got r.Loop = %x, want = %x", r.Loop, test.expectedLoop)
}
})
}
@@ -3672,3 +3711,515 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) {
t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix)
}
}
+
+// TestAddRoute tests Stack.AddRoute
+func TestAddRoute(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{})
+
+ subnet1, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ subnet2, err := tcpip.NewSubnet("\x01", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ expected := []tcpip.Route{
+ {Destination: subnet1, Gateway: "\x00", NIC: 1},
+ {Destination: subnet2, Gateway: "\x00", NIC: 1},
+ }
+
+ // Initialize the route table with one route.
+ s.SetRouteTable([]tcpip.Route{expected[0]})
+
+ // Add another route.
+ s.AddRoute(expected[1])
+
+ rt := s.GetRouteTable()
+ if got, want := len(rt), len(expected); got != want {
+ t.Fatalf("Unexpected route table length got = %d, want = %d", got, want)
+ }
+ for i, route := range rt {
+ if got, want := route, expected[i]; got != want {
+ t.Fatalf("Unexpected route got = %#v, want = %#v", got, want)
+ }
+ }
+}
+
+// TestRemoveRoutes tests Stack.RemoveRoutes
+func TestRemoveRoutes(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{})
+
+ addressToRemove := tcpip.Address("\x01")
+ subnet1, err := tcpip.NewSubnet(addressToRemove, "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ subnet2, err := tcpip.NewSubnet(addressToRemove, "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ subnet3, err := tcpip.NewSubnet("\x02", "\x02")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Initialize the route table with three routes.
+ s.SetRouteTable([]tcpip.Route{
+ {Destination: subnet1, Gateway: "\x00", NIC: 1},
+ {Destination: subnet2, Gateway: "\x00", NIC: 1},
+ {Destination: subnet3, Gateway: "\x00", NIC: 1},
+ })
+
+ // Remove routes with the specific address.
+ s.RemoveRoutes(func(r tcpip.Route) bool {
+ return r.Destination.ID() == addressToRemove
+ })
+
+ expected := []tcpip.Route{{Destination: subnet3, Gateway: "\x00", NIC: 1}}
+ rt := s.GetRouteTable()
+ if got, want := len(rt), len(expected); got != want {
+ t.Fatalf("Unexpected route table length got = %d, want = %d", got, want)
+ }
+ for i, route := range rt {
+ if got, want := route, expected[i]; got != want {
+ t.Fatalf("Unexpected route got = %#v, want = %#v", got, want)
+ }
+ }
+}
+
+func TestFindRouteWithForwarding(t *testing.T) {
+ const (
+ nicID1 = 1
+ nicID2 = 2
+
+ nic1Addr = tcpip.Address("\x01")
+ nic2Addr = tcpip.Address("\x02")
+ remoteAddr = tcpip.Address("\x03")
+ )
+
+ type netCfg struct {
+ proto tcpip.NetworkProtocolNumber
+ factory stack.NetworkProtocolFactory
+ nic1Addr tcpip.Address
+ nic2Addr tcpip.Address
+ remoteAddr tcpip.Address
+ }
+
+ fakeNetCfg := netCfg{
+ proto: fakeNetNumber,
+ factory: fakeNetFactory,
+ nic1Addr: nic1Addr,
+ nic2Addr: nic2Addr,
+ remoteAddr: remoteAddr,
+ }
+
+ globalIPv6Addr1 := tcpip.Address(net.ParseIP("a::1").To16())
+ globalIPv6Addr2 := tcpip.Address(net.ParseIP("a::2").To16())
+
+ ipv6LinkLocalNIC1WithGlobalRemote := netCfg{
+ proto: ipv6.ProtocolNumber,
+ factory: ipv6.NewProtocol,
+ nic1Addr: llAddr1,
+ nic2Addr: globalIPv6Addr2,
+ remoteAddr: globalIPv6Addr1,
+ }
+ ipv6GlobalNIC1WithLinkLocalRemote := netCfg{
+ proto: ipv6.ProtocolNumber,
+ factory: ipv6.NewProtocol,
+ nic1Addr: globalIPv6Addr1,
+ nic2Addr: llAddr1,
+ remoteAddr: llAddr2,
+ }
+ ipv6GlobalNIC1WithLinkLocalMulticastRemote := netCfg{
+ proto: ipv6.ProtocolNumber,
+ factory: ipv6.NewProtocol,
+ nic1Addr: globalIPv6Addr1,
+ nic2Addr: globalIPv6Addr2,
+ remoteAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ }
+
+ tests := []struct {
+ name string
+
+ netCfg netCfg
+ forwardingEnabled bool
+
+ addrNIC tcpip.NICID
+ localAddr tcpip.Address
+
+ findRouteErr *tcpip.Error
+ dependentOnForwarding bool
+ }{
+ {
+ name: "forwarding disabled and localAddr not on specified NIC but route from different NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: false,
+ addrNIC: nicID1,
+ localAddr: fakeNetCfg.nic2Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and localAddr not on specified NIC but route from different NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: true,
+ addrNIC: nicID1,
+ localAddr: fakeNetCfg.nic2Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and localAddr on specified NIC but route from different NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: false,
+ addrNIC: nicID1,
+ localAddr: fakeNetCfg.nic1Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and localAddr on specified NIC but route from different NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: true,
+ addrNIC: nicID1,
+ localAddr: fakeNetCfg.nic1Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: true,
+ },
+ {
+ name: "forwarding disabled and localAddr on specified NIC and route from same NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: false,
+ addrNIC: nicID2,
+ localAddr: fakeNetCfg.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and localAddr on specified NIC and route from same NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: true,
+ addrNIC: nicID2,
+ localAddr: fakeNetCfg.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and localAddr not on specified NIC but route from same NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: false,
+ addrNIC: nicID2,
+ localAddr: fakeNetCfg.nic1Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and localAddr not on specified NIC but route from same NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: true,
+ addrNIC: nicID2,
+ localAddr: fakeNetCfg.nic1Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and localAddr on same NIC as route",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: false,
+ localAddr: fakeNetCfg.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and localAddr on same NIC as route",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: false,
+ localAddr: fakeNetCfg.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and localAddr on different NIC as route",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: false,
+ localAddr: fakeNetCfg.nic1Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and localAddr on different NIC as route",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: true,
+ localAddr: fakeNetCfg.nic1Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: true,
+ },
+ {
+ name: "forwarding disabled and specified NIC only has link-local addr with route on different NIC",
+ netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
+ forwardingEnabled: false,
+ addrNIC: nicID1,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and specified NIC only has link-local addr with route on different NIC",
+ netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
+ forwardingEnabled: true,
+ addrNIC: nicID1,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and link-local local addr with route on different NIC",
+ netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
+ forwardingEnabled: false,
+ localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and link-local local addr with route on same NIC",
+ netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
+ forwardingEnabled: true,
+ localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and global local addr with route on same NIC",
+ netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
+ forwardingEnabled: true,
+ localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and link-local local addr with route on same NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
+ forwardingEnabled: false,
+ localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and link-local local addr with route on same NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
+ forwardingEnabled: true,
+ localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and global local addr with link-local remote on different NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
+ forwardingEnabled: false,
+ localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr,
+ findRouteErr: tcpip.ErrNetworkUnreachable,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and global local addr with link-local remote on different NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
+ forwardingEnabled: true,
+ localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr,
+ findRouteErr: tcpip.ErrNetworkUnreachable,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and global local addr with link-local multicast remote on different NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
+ forwardingEnabled: false,
+ localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr,
+ findRouteErr: tcpip.ErrNetworkUnreachable,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and global local addr with link-local multicast remote on different NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
+ forwardingEnabled: true,
+ localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr,
+ findRouteErr: tcpip.ErrNetworkUnreachable,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and global local addr with link-local multicast remote on same NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
+ forwardingEnabled: false,
+ localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and global local addr with link-local multicast remote on same NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
+ forwardingEnabled: true,
+ localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{test.netCfg.factory},
+ })
+
+ ep1 := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, ep1); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s:", nicID1, err)
+ }
+
+ ep2 := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID2, ep2); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s:", nicID2, err)
+ }
+
+ if err := s.AddAddress(nicID1, test.netCfg.proto, test.netCfg.nic1Addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, test.netCfg.proto, test.netCfg.nic1Addr, err)
+ }
+
+ if err := s.AddAddress(nicID2, test.netCfg.proto, test.netCfg.nic2Addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, test.netCfg.proto, test.netCfg.nic2Addr, err)
+ }
+
+ if err := s.SetForwarding(test.netCfg.proto, test.forwardingEnabled); err != nil {
+ t.Fatalf("SetForwarding(%d, %t): %s", test.netCfg.proto, test.forwardingEnabled, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}})
+
+ r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */)
+ if r != nil {
+ defer r.Release()
+ }
+ if err != test.findRouteErr {
+ t.Fatalf("FindRoute(%d, %s, %s, %d, false) = %s, want = %s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, err, test.findRouteErr)
+ }
+
+ if test.findRouteErr != nil {
+ return
+ }
+
+ if r.LocalAddress != test.localAddr {
+ t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, test.localAddr)
+ }
+ if r.RemoteAddress != test.netCfg.remoteAddr {
+ t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress, test.netCfg.remoteAddr)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Sending a packet should always go through NIC2 since we only install a
+ // route to test.netCfg.remoteAddr through NIC2.
+ data := buffer.View([]byte{1, 2, 3, 4})
+ if err := send(r, data); err != nil {
+ t.Fatalf("send(_, _): %s", err)
+ }
+ if n := ep1.Drain(); n != 0 {
+ t.Errorf("got %d unexpected packets from ep1", n)
+ }
+ pkt, ok := ep2.Read()
+ if !ok {
+ t.Fatal("packet not sent through ep2")
+ }
+ if pkt.Route.LocalAddress != test.localAddr {
+ t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.localAddr)
+ }
+ if pkt.Route.RemoteAddress != test.netCfg.remoteAddr {
+ t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.netCfg.remoteAddr)
+ }
+
+ if !test.forwardingEnabled || !test.dependentOnForwarding {
+ return
+ }
+
+ // Disabling forwarding when the route is dependent on forwarding being
+ // enabled should make the route invalid.
+ if err := s.SetForwarding(test.netCfg.proto, false); err != nil {
+ t.Fatalf("SetForwarding(%d, false): %s", test.netCfg.proto, err)
+ }
+ if err := send(r, data); err != tcpip.ErrInvalidEndpointState {
+ t.Fatalf("got send(_, _) = %s, want = %s", err, tcpip.ErrInvalidEndpointState)
+ }
+ if n := ep1.Drain(); n != 0 {
+ t.Errorf("got %d unexpected packets from ep1", n)
+ }
+ if n := ep2.Drain(); n != 0 {
+ t.Errorf("got %d unexpected packets from ep2", n)
+ }
+ })
+ }
+}
+
+func TestWritePacketToRemote(t *testing.T) {
+ const nicID = 1
+ const MTU = 1280
+ e := channel.New(1, MTU, linkAddr1)
+ s := stack.New(stack.Options{})
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("CreateNIC(%d) = %s", nicID, err)
+ }
+ tests := []struct {
+ name string
+ protocol tcpip.NetworkProtocolNumber
+ payload []byte
+ }{
+ {
+ name: "SuccessIPv4",
+ protocol: header.IPv4ProtocolNumber,
+ payload: []byte{1, 2, 3, 4},
+ },
+ {
+ name: "SuccessIPv6",
+ protocol: header.IPv6ProtocolNumber,
+ payload: []byte{5, 6, 7, 8},
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ if err := s.WritePacketToRemote(nicID, linkAddr2, test.protocol, buffer.View(test.payload).ToVectorisedView()); err != nil {
+ t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s", err)
+ }
+
+ pkt, ok := e.Read()
+ if got, want := ok, true; got != want {
+ t.Fatalf("e.Read() = %t, want %t", got, want)
+ }
+ if got, want := pkt.Proto, test.protocol; got != want {
+ t.Fatalf("pkt.Proto = %d, want %d", got, want)
+ }
+ if got, want := pkt.Route.RemoteLinkAddress(), linkAddr2; got != want {
+ t.Fatalf("pkt.Route.RemoteAddress = %s, want %s", got, want)
+ }
+ if diff := cmp.Diff(pkt.Pkt.Data.ToView(), buffer.View(test.payload)); diff != "" {
+ t.Errorf("pkt.Pkt.Data mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+
+ t.Run("InvalidNICID", func(t *testing.T) {
+ if got, want := s.WritePacketToRemote(234, linkAddr2, header.IPv4ProtocolNumber, buffer.View([]byte{1}).ToVectorisedView()), tcpip.ErrUnknownDevice; got != want {
+ t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s, want = %s", got, want)
+ }
+ pkt, ok := e.Read()
+ if got, want := ok, false; got != want {
+ t.Fatalf("e.Read() = %t, %v; want %t", got, pkt, want)
+ }
+ })
+}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 35e5b1a2e..f183ec6e4 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -152,10 +152,10 @@ func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) {
+func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *PacketBuffer) {
epsByNIC.mu.RLock()
- mpep, ok := epsByNIC.endpoints[r.nic.ID()]
+ mpep, ok := epsByNIC.endpoints[pkt.NICID]
if !ok {
if mpep, ok = epsByNIC.endpoints[0]; !ok {
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
@@ -165,20 +165,20 @@ func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, p
// If this is a broadcast or multicast datagram, deliver the datagram to all
// endpoints bound to the right device.
- if isInboundMulticastOrBroadcast(r) {
- mpep.handlePacketAll(r, id, pkt)
+ if isInboundMulticastOrBroadcast(pkt, id.LocalAddress) {
+ mpep.handlePacketAll(id, pkt)
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
return
}
// multiPortEndpoints are guaranteed to have at least one element.
transEP := selectEndpoint(id, mpep, epsByNIC.seed)
if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue {
- queuedProtocol.QueuePacket(r, transEP, id, pkt)
+ queuedProtocol.QueuePacket(transEP, id, pkt)
epsByNIC.mu.RUnlock()
return
}
- transEP.HandlePacket(r, id, pkt)
+ transEP.HandlePacket(id, pkt)
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
}
@@ -253,6 +253,8 @@ func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t T
// based on endpoints IDs. It should only be instantiated via
// newTransportDemuxer.
type transportDemuxer struct {
+ stack *Stack
+
// protocol is immutable.
protocol map[protocolIDs]*transportEndpoints
queuedProtocols map[protocolIDs]queuedTransportProtocol
@@ -262,11 +264,12 @@ type transportDemuxer struct {
// the dispatcher to delivery packets to the QueuePacket method instead of
// calling HandlePacket directly on the endpoint.
type queuedTransportProtocol interface {
- QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer)
+ QueuePacket(ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer)
}
func newTransportDemuxer(stack *Stack) *transportDemuxer {
d := &transportDemuxer{
+ stack: stack,
protocol: make(map[protocolIDs]*transportEndpoints),
queuedProtocols: make(map[protocolIDs]queuedTransportProtocol),
}
@@ -377,22 +380,22 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32
return mpep.endpoints[idx]
}
-func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt *PacketBuffer) {
+func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *PacketBuffer) {
ep.mu.RLock()
queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}]
// HandlePacket takes ownership of pkt, so each endpoint needs
// its own copy except for the final one.
for _, endpoint := range ep.endpoints[:len(ep.endpoints)-1] {
if mustQueue {
- queuedProtocol.QueuePacket(r, endpoint, id, pkt.Clone())
+ queuedProtocol.QueuePacket(endpoint, id, pkt.Clone())
} else {
- endpoint.HandlePacket(r, id, pkt.Clone())
+ endpoint.HandlePacket(id, pkt.Clone())
}
}
if endpoint := ep.endpoints[len(ep.endpoints)-1]; mustQueue {
- queuedProtocol.QueuePacket(r, endpoint, id, pkt)
+ queuedProtocol.QueuePacket(endpoint, id, pkt)
} else {
- endpoint.HandlePacket(r, id, pkt)
+ endpoint.HandlePacket(id, pkt)
}
ep.mu.RUnlock() // Don't use defer for performance reasons.
}
@@ -518,29 +521,29 @@ func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolN
// deliverPacket attempts to find one or more matching transport endpoints, and
// then, if matches are found, delivers the packet to them. Returns true if
// the packet no longer needs to be handled.
-func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool {
- eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
+func (d *transportDemuxer) deliverPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool {
+ eps, ok := d.protocol[protocolIDs{pkt.NetworkProtocolNumber, protocol}]
if !ok {
return false
}
// If the packet is a UDP broadcast or multicast, then find all matching
// transport endpoints.
- if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(r) {
+ if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(pkt, id.LocalAddress) {
eps.mu.RLock()
destEPs := eps.findAllEndpointsLocked(id)
eps.mu.RUnlock()
// Fail if we didn't find at least one matching transport endpoint.
if len(destEPs) == 0 {
- r.Stats().UDP.UnknownPortErrors.Increment()
+ d.stack.stats.UDP.UnknownPortErrors.Increment()
return false
}
// handlePacket takes ownership of pkt, so each endpoint needs its own
// copy except for the final one.
for _, ep := range destEPs[:len(destEPs)-1] {
- ep.handlePacket(r, id, pkt.Clone())
+ ep.handlePacket(id, pkt.Clone())
}
- destEPs[len(destEPs)-1].handlePacket(r, id, pkt)
+ destEPs[len(destEPs)-1].handlePacket(id, pkt)
return true
}
@@ -548,10 +551,10 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// destination address, then do nothing further and instruct the caller to do
// the same. The network layer handles address validation for specified source
// addresses.
- if protocol == header.TCPProtocolNumber && (!isSpecified(r.LocalAddress) || !isSpecified(r.RemoteAddress) || isInboundMulticastOrBroadcast(r)) {
+ if protocol == header.TCPProtocolNumber && (!isSpecified(id.LocalAddress) || !isSpecified(id.RemoteAddress) || isInboundMulticastOrBroadcast(pkt, id.LocalAddress)) {
// TCP can only be used to communicate between a single source and a
- // single destination; the addresses must be unicast.
- r.Stats().TCP.InvalidSegmentsReceived.Increment()
+ // single destination; the addresses must be unicast.e
+ d.stack.stats.TCP.InvalidSegmentsReceived.Increment()
return true
}
@@ -560,18 +563,18 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
eps.mu.RUnlock()
if ep == nil {
if protocol == header.UDPProtocolNumber {
- r.Stats().UDP.UnknownPortErrors.Increment()
+ d.stack.stats.UDP.UnknownPortErrors.Increment()
}
return false
}
- ep.handlePacket(r, id, pkt)
+ ep.handlePacket(id, pkt)
return true
}
// deliverRawPacket attempts to deliver the given packet and returns whether it
// was delivered successfully.
-func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool {
- eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
+func (d *transportDemuxer) deliverRawPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool {
+ eps, ok := d.protocol[protocolIDs{pkt.NetworkProtocolNumber, protocol}]
if !ok {
return false
}
@@ -584,7 +587,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr
for _, rawEP := range eps.rawEndpoints {
// Each endpoint gets its own copy of the packet for the sake
// of save/restore.
- rawEP.HandlePacket(r, pkt)
+ rawEP.HandlePacket(pkt.Clone())
foundRaw = true
}
eps.mu.RUnlock()
@@ -612,7 +615,7 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco
}
// findTransportEndpoint find a single endpoint that most closely matches the provided id.
-func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint {
+func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, nicID tcpip.NICID) TransportEndpoint {
eps, ok := d.protocol[protocolIDs{netProto, transProto}]
if !ok {
return nil
@@ -628,7 +631,7 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN
epsByNIC.mu.RLock()
eps.mu.RUnlock()
- mpep, ok := epsByNIC.endpoints[r.nic.ID()]
+ mpep, ok := epsByNIC.endpoints[nicID]
if !ok {
if mpep, ok = epsByNIC.endpoints[0]; !ok {
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
@@ -679,8 +682,8 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN
eps.mu.Unlock()
}
-func isInboundMulticastOrBroadcast(r *Route) bool {
- return r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || header.IsV6MulticastAddress(r.LocalAddress)
+func isInboundMulticastOrBroadcast(pkt *PacketBuffer, localAddr tcpip.Address) bool {
+ return pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(localAddr) || header.IsV6MulticastAddress(localAddr)
}
func isSpecified(addr tcpip.Address) bool {
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 698c8609e..a692af20b 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -102,7 +102,6 @@ func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NI
// Initialize the IP header.
ip := header.IPv4(buf)
ip.Encode(&header.IPv4Fields{
- IHL: header.IPv4MinimumSize,
TOS: 0x80,
TotalLength: uint16(len(buf)),
TTL: 65,
@@ -142,11 +141,11 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI
// Initialize the IP header.
ip := header.IPv6(buf)
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
- NextHeader: uint8(udp.ProtocolNumber),
- HopLimit: 65,
- SrcAddr: testSrcAddrV6,
- DstAddr: testDstAddrV6,
+ PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
+ TransportProtocol: udp.ProtocolNumber,
+ HopLimit: 65,
+ SrcAddr: testSrcAddrV6,
+ DstAddr: testDstAddrV6,
})
// Initialize the UDP header.
@@ -308,9 +307,7 @@ func TestBindToDeviceDistribution(t *testing.T) {
}(ep)
defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.ReusePortOption, endpoint.reuse); err != nil {
- t.Fatalf("SetSockOptBool(ReusePortOption, %t) on endpoint %d failed: %s", endpoint.reuse, i, err)
- }
+ ep.SocketOptions().SetReusePort(endpoint.reuse)
bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice)
if err := ep.SetSockOpt(&bindToDeviceOption); err != nil {
t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", bindToDeviceOption, bindToDeviceOption, i, err)
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 62ab6d92f..66eb562ba 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -20,7 +20,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
@@ -28,7 +27,7 @@ import (
const (
fakeTransNumber tcpip.TransportProtocolNumber = 1
- fakeTransHeaderLen = 3
+ fakeTransHeaderLen int = 3
)
// fakeTransportEndpoint is a transport-layer protocol endpoint. It counts
@@ -39,14 +38,18 @@ const (
// use it.
type fakeTransportEndpoint struct {
stack.TransportEndpointInfo
+ tcpip.DefaultSocketOptionsHandler
proto *fakeTransportProtocol
peerAddr tcpip.Address
- route stack.Route
+ route *stack.Route
uniqueID uint64
// acceptQueue is non-nil iff bound.
- acceptQueue []fakeTransportEndpoint
+ acceptQueue []*fakeTransportEndpoint
+
+ // ops is used to set and get socket options.
+ ops tcpip.SocketOptions
}
func (f *fakeTransportEndpoint) Info() tcpip.EndpointInfo {
@@ -59,8 +62,14 @@ func (*fakeTransportEndpoint) Stats() tcpip.EndpointStats {
func (*fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {}
+func (f *fakeTransportEndpoint) SocketOptions() *tcpip.SocketOptions {
+ return &f.ops
+}
+
func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint {
- return &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
+ ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
+ ep.ops.InitHandler(ep)
+ return ep
}
func (f *fakeTransportEndpoint) Abort() {
@@ -100,8 +109,8 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions
return int64(len(v)), nil, nil
}
-func (*fakeTransportEndpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
- return 0, tcpip.ControlMessages{}, nil
+func (*fakeTransportEndpoint) Peek([][]byte) (int64, *tcpip.Error) {
+ return 0, nil
}
// SetSockOpt sets a socket option. Currently not supported.
@@ -109,21 +118,11 @@ func (*fakeTransportEndpoint) SetSockOpt(tcpip.SettableSocketOption) *tcpip.Erro
return tcpip.ErrInvalidEndpointState
}
-// SetSockOptBool sets a socket option. Currently not supported.
-func (*fakeTransportEndpoint) SetSockOptBool(tcpip.SockOptBool, bool) *tcpip.Error {
- return tcpip.ErrInvalidEndpointState
-}
-
// SetSockOptInt sets a socket option. Currently not supported.
func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
-// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
-func (*fakeTransportEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- return false, tcpip.ErrUnknownProtocolOption
-}
-
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
return -1, tcpip.ErrUnknownProtocolOption
@@ -186,7 +185,7 @@ func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *wai
}
a := f.acceptQueue[0]
f.acceptQueue = f.acceptQueue[1:]
- return &a, nil, nil
+ return a, nil, nil
}
func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error {
@@ -201,7 +200,7 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error {
); err != nil {
return err
}
- f.acceptQueue = []fakeTransportEndpoint{}
+ f.acceptQueue = []*fakeTransportEndpoint{}
return nil
}
@@ -213,20 +212,31 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro
return tcpip.FullAddress{}, nil
}
-func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ *stack.PacketBuffer) {
+func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
// Increment the number of received packets.
f.proto.packetCount++
- if f.acceptQueue != nil {
- f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{
- TransportEndpointInfo: stack.TransportEndpointInfo{
- ID: f.ID,
- NetProto: f.NetProto,
- },
- proto: f.proto,
- peerAddr: r.RemoteAddress,
- route: r.Clone(),
- })
+ if f.acceptQueue == nil {
+ return
+ }
+
+ netHdr := pkt.NetworkHeader().View()
+ route, err := f.proto.stack.FindRoute(pkt.NICID, tcpip.Address(netHdr[dstAddrOffset]), tcpip.Address(netHdr[srcAddrOffset]), pkt.NetworkProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ return
}
+ route.ResolveWith(pkt.SourceLinkAddress())
+
+ ep := &fakeTransportEndpoint{
+ TransportEndpointInfo: stack.TransportEndpointInfo{
+ ID: f.ID,
+ NetProto: f.NetProto,
+ },
+ proto: f.proto,
+ peerAddr: route.RemoteAddress,
+ route: route,
+ }
+ ep.ops.InitHandler(ep)
+ f.acceptQueue = append(f.acceptQueue, ep)
}
func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, *stack.PacketBuffer) {
@@ -288,7 +298,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp
return 0, 0, nil
}
-func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
+func (*fakeTransportProtocol) HandleUnknownDestinationPacket(stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
return stack.UnknownDestinationPacketHandled
}
@@ -544,87 +554,3 @@ func TestTransportOptions(t *testing.T) {
t.Fatalf("got tcpip.TCPModerateReceiveBufferOption = false, want = true")
}
}
-
-func TestTransportForwarding(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory},
- })
- s.SetForwarding(fakeNetNumber, true)
-
- // TODO(b/123449044): Change this to a channel NIC.
- ep1 := loopback.New()
- if err := s.CreateNIC(1, ep1); err != nil {
- t.Fatalf("CreateNIC #1 failed: %v", err)
- }
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress #1 failed: %v", err)
- }
-
- ep2 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(2, ep2); err != nil {
- t.Fatalf("CreateNIC #2 failed: %v", err)
- }
- if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
- t.Fatalf("AddAddress #2 failed: %v", err)
- }
-
- // Route all packets to address 3 to NIC 2 and all packets to address
- // 1 to NIC 1.
- {
- subnet0, err := tcpip.NewSubnet("\x03", "\xff")
- if err != nil {
- t.Fatal(err)
- }
- subnet1, err := tcpip.NewSubnet("\x01", "\xff")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{
- {Destination: subnet0, Gateway: "\x00", NIC: 2},
- {Destination: subnet1, Gateway: "\x00", NIC: 1},
- })
- }
-
- wq := waiter.Queue{}
- ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- if err := ep.Bind(tcpip.FullAddress{Addr: "\x01", NIC: 1}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Send a packet to address 1 from address 3.
- req := buffer.NewView(30)
- req[0] = 1
- req[1] = 3
- req[2] = byte(fakeTransNumber)
- ep2.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: req.ToVectorisedView(),
- }))
-
- aep, _, err := ep.Accept(nil)
- if err != nil || aep == nil {
- t.Fatalf("Accept failed: %v, %v", aep, err)
- }
-
- resp := buffer.NewView(30)
- if _, _, err := aep.Write(tcpip.SlicePayload(resp), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
- }
-
- p, ok := ep2.Read()
- if !ok {
- t.Fatal("Response packet not forwarded")
- }
-
- nh := stack.PayloadSince(p.Pkt.NetworkHeader())
- if dst := nh[0]; dst != 3 {
- t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst)
- }
- if src := nh[1]; src != 1 {
- t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src)
- }
-}