diff options
Diffstat (limited to 'pkg/tcpip/stack')
23 files changed, 3253 insertions, 1748 deletions
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index 4d3acab96..9478f3fb7 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -272,6 +272,9 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address addrState = &addressState{ addressableEndpointState: a, addr: addr, + // Cache the subnet in addrState to avoid calls to addr.Subnet() as that + // results in allocations on every call. + subnet: addr.Subnet(), } a.mu.endpoints[addr.Address] = addrState addrState.mu.Lock() @@ -361,6 +364,8 @@ func (a *AddressableEndpointState) RemovePermanentEndpoint(ep AddressEndpoint) * return tcpip.ErrInvalidEndpointState } + a.mu.Lock() + defer a.mu.Unlock() return a.removePermanentEndpointLocked(addrState) } @@ -664,7 +669,7 @@ var _ AddressEndpoint = (*addressState)(nil) type addressState struct { addressableEndpointState *AddressableEndpointState addr tcpip.AddressWithPrefix - + subnet tcpip.Subnet // Lock ordering (from outer to inner lock ordering): // // AddressableEndpointState.mu @@ -684,6 +689,11 @@ func (a *addressState) AddressWithPrefix() tcpip.AddressWithPrefix { return a.addr } +// Subnet implements AddressEndpoint. +func (a *addressState) Subnet() tcpip.Subnet { + return a.subnet +} + // GetKind implements AddressEndpoint. func (a *addressState) GetKind() AddressKind { a.mu.RLock() diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 0cd1da11f..9a17efcba 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -269,7 +269,7 @@ func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) { return nil, dirOriginal } -func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt *RedirectTarget) *conn { +func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn { tid, err := packetToTupleID(pkt) if err != nil { return nil @@ -282,8 +282,8 @@ func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt *Redire // rule. This tuple will be used to manipulate the packet in // handlePacket. replyTID := tid.reply() - replyTID.srcAddr = rt.Addr - replyTID.srcPort = rt.Port + replyTID.srcAddr = address + replyTID.srcPort = port var manip manipType switch hook { case Prerouting: @@ -401,12 +401,12 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d // Calculate the TCP checksum and set it. tcpHeader.SetChecksum(0) - length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) - xsum := r.PseudoHeaderChecksum(header.TCPProtocolNumber, length) + length := uint16(len(tcpHeader) + pkt.Data.Size()) + xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) if gso != nil && gso.NeedsCsum { tcpHeader.SetChecksum(xsum) - } else if r.Capabilities()&CapabilityTXChecksumOffload == 0 { - xsum = header.ChecksumVVWithOffset(pkt.Data, xsum, int(tcpHeader.DataOffset()), pkt.Data.Size()) + } else if r.RequiresTXTransportChecksum() { + xsum = header.ChecksumVV(pkt.Data, xsum) tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum)) } diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index cf042309e..7a501acdc 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -73,9 +73,9 @@ func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 { return 123 } -func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) { +func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) { // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) + f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) } func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { @@ -178,7 +178,7 @@ func (*fwdTestNetworkProtocol) Close() {} func (*fwdTestNetworkProtocol) Wait() {} -func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error { +func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { if f.onLinkAddressResolved != nil { time.AfterFunc(f.addrResolveDelay, func() { f.onLinkAddressResolved(f.addrCache, f.neigh, addr, remoteLinkAddr) diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 8d6d9a7f1..2d8c883cd 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -22,30 +22,17 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" ) -// tableID is an index into IPTables.tables. -type tableID int +// TableID identifies a specific table. +type TableID int +// Each value identifies a specific table. const ( - natID tableID = iota - mangleID - filterID - numTables + NATID TableID = iota + MangleID + FilterID + NumTables ) -// Table names. -const ( - NATTable = "nat" - MangleTable = "mangle" - FilterTable = "filter" -) - -// nameToID is immutable. -var nameToID = map[string]tableID{ - NATTable: natID, - MangleTable: mangleID, - FilterTable: filterID, -} - // HookUnset indicates that there is no hook set for an entrypoint or // underflow. const HookUnset = -1 @@ -57,8 +44,8 @@ const reaperDelay = 5 * time.Second // all packets. func DefaultTables() *IPTables { return &IPTables{ - v4Tables: [numTables]Table{ - natID: Table{ + v4Tables: [NumTables]Table{ + NATID: Table{ Rules: []Rule{ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, @@ -81,7 +68,7 @@ func DefaultTables() *IPTables { Postrouting: 3, }, }, - mangleID: Table{ + MangleID: Table{ Rules: []Rule{ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, @@ -99,7 +86,7 @@ func DefaultTables() *IPTables { Postrouting: HookUnset, }, }, - filterID: Table{ + FilterID: Table{ Rules: []Rule{ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, @@ -122,8 +109,8 @@ func DefaultTables() *IPTables { }, }, }, - v6Tables: [numTables]Table{ - natID: Table{ + v6Tables: [NumTables]Table{ + NATID: Table{ Rules: []Rule{ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, @@ -146,7 +133,7 @@ func DefaultTables() *IPTables { Postrouting: 3, }, }, - mangleID: Table{ + MangleID: Table{ Rules: []Rule{ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, @@ -164,7 +151,7 @@ func DefaultTables() *IPTables { Postrouting: HookUnset, }, }, - filterID: Table{ + FilterID: Table{ Rules: []Rule{ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, @@ -187,10 +174,10 @@ func DefaultTables() *IPTables { }, }, }, - priorities: [NumHooks][]tableID{ - Prerouting: []tableID{mangleID, natID}, - Input: []tableID{natID, filterID}, - Output: []tableID{mangleID, natID, filterID}, + priorities: [NumHooks][]TableID{ + Prerouting: []TableID{MangleID, NATID}, + Input: []TableID{NATID, FilterID}, + Output: []TableID{MangleID, NATID, FilterID}, }, connections: ConnTrack{ seed: generateRandUint32(), @@ -229,26 +216,20 @@ func EmptyNATTable() Table { } } -// GetTable returns a table by name. -func (it *IPTables) GetTable(name string, ipv6 bool) (Table, bool) { - id, ok := nameToID[name] - if !ok { - return Table{}, false - } +// GetTable returns a table with the given id and IP version. It panics when an +// invalid id is provided. +func (it *IPTables) GetTable(id TableID, ipv6 bool) Table { it.mu.RLock() defer it.mu.RUnlock() if ipv6 { - return it.v6Tables[id], true + return it.v6Tables[id] } - return it.v4Tables[id], true + return it.v4Tables[id] } -// ReplaceTable replaces or inserts table by name. -func (it *IPTables) ReplaceTable(name string, table Table, ipv6 bool) *tcpip.Error { - id, ok := nameToID[name] - if !ok { - return tcpip.ErrInvalidOptionValue - } +// ReplaceTable replaces or inserts table by name. It panics when an invalid id +// is provided. +func (it *IPTables) ReplaceTable(id TableID, table Table, ipv6 bool) *tcpip.Error { it.mu.Lock() defer it.mu.Unlock() // If iptables is being enabled, initialize the conntrack table and @@ -311,7 +292,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, prer for _, tableID := range priorities { // If handlePacket already NATed the packet, we don't need to // check the NAT table. - if tableID == natID && pkt.NatDone { + if tableID == NATID && pkt.NatDone { continue } var table Table diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 538c4625d..d63e9757c 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -15,6 +15,8 @@ package stack import ( + "fmt" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -26,13 +28,6 @@ type AcceptTarget struct { NetworkProtocol tcpip.NetworkProtocolNumber } -// ID implements Target.ID. -func (at *AcceptTarget) ID() TargetID { - return TargetID{ - NetworkProtocol: at.NetworkProtocol, - } -} - // Action implements Target.Action. func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleAccept, 0 @@ -44,22 +39,11 @@ type DropTarget struct { NetworkProtocol tcpip.NetworkProtocolNumber } -// ID implements Target.ID. -func (dt *DropTarget) ID() TargetID { - return TargetID{ - NetworkProtocol: dt.NetworkProtocol, - } -} - // Action implements Target.Action. func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleDrop, 0 } -// ErrorTargetName is used to mark targets as error targets. Error targets -// shouldn't be reached - an error has occurred if we fall through to one. -const ErrorTargetName = "ERROR" - // ErrorTarget logs an error and drops the packet. It represents a target that // should be unreachable. type ErrorTarget struct { @@ -67,14 +51,6 @@ type ErrorTarget struct { NetworkProtocol tcpip.NetworkProtocolNumber } -// ID implements Target.ID. -func (et *ErrorTarget) ID() TargetID { - return TargetID{ - Name: ErrorTargetName, - NetworkProtocol: et.NetworkProtocol, - } -} - // Action implements Target.Action. func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { log.Debugf("ErrorTarget triggered.") @@ -90,14 +66,6 @@ type UserChainTarget struct { NetworkProtocol tcpip.NetworkProtocolNumber } -// ID implements Target.ID. -func (uc *UserChainTarget) ID() TargetID { - return TargetID{ - Name: ErrorTargetName, - NetworkProtocol: uc.NetworkProtocol, - } -} - // Action implements Target.Action. func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { panic("UserChainTarget should never be called.") @@ -110,50 +78,39 @@ type ReturnTarget struct { NetworkProtocol tcpip.NetworkProtocolNumber } -// ID implements Target.ID. -func (rt *ReturnTarget) ID() TargetID { - return TargetID{ - NetworkProtocol: rt.NetworkProtocol, - } -} - // Action implements Target.Action. func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleReturn, 0 } -// RedirectTargetName is used to mark targets as redirect targets. Redirect -// targets should be reached for only NAT and Mangle tables. These targets will -// change the destination port/destination IP for packets. -const RedirectTargetName = "REDIRECT" - -// RedirectTarget redirects the packet by modifying the destination port/IP. +// RedirectTarget redirects the packet to this machine by modifying the +// destination port/IP. Outgoing packets are redirected to the loopback device, +// and incoming packets are redirected to the incoming interface (rather than +// forwarded). +// // TODO(gvisor.dev/issue/170): Other flags need to be added after we support // them. type RedirectTarget struct { - // Addr indicates address used to redirect. - Addr tcpip.Address - - // Port indicates port used to redirect. + // Port indicates port used to redirect. It is immutable. Port uint16 - // NetworkProtocol is the network protocol the target is used with. + // NetworkProtocol is the network protocol the target is used with. It + // is immutable. NetworkProtocol tcpip.NetworkProtocolNumber } -// ID implements Target.ID. -func (rt *RedirectTarget) ID() TargetID { - return TargetID{ - Name: RedirectTargetName, - NetworkProtocol: rt.NetworkProtocol, - } -} - // Action implements Target.Action. // TODO(gvisor.dev/issue/170): Parse headers without copying. The current -// implementation only works for PREROUTING and calls pkt.Clone(), neither +// implementation only works for Prerouting and calls pkt.Clone(), neither // of which should be the case. func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { + // Sanity check. + if rt.NetworkProtocol != pkt.NetworkProtocolNumber { + panic(fmt.Sprintf( + "RedirectTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d", + rt.NetworkProtocol, pkt.NetworkProtocolNumber)) + } + // Packet is already manipulated. if pkt.NatDone { return RuleAccept, 0 @@ -164,17 +121,17 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs return RuleDrop, 0 } - // Change the address to localhost (127.0.0.1 or ::1) in Output and to + // Change the address to loopback (127.0.0.1 or ::1) in Output and to // the primary address of the incoming interface in Prerouting. switch hook { case Output: if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { - rt.Addr = tcpip.Address([]byte{127, 0, 0, 1}) + address = tcpip.Address([]byte{127, 0, 0, 1}) } else { - rt.Addr = header.IPv6Loopback + address = header.IPv6Loopback } case Prerouting: - rt.Addr = address + // No-op, as address is already set correctly. default: panic("redirect target is supported only on output and prerouting hooks") } @@ -189,21 +146,18 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs // Calculate UDP checksum and set it. if hook == Output { udpHeader.SetChecksum(0) + netHeader := pkt.Network() + netHeader.SetDestinationAddress(address) // Only calculate the checksum if offloading isn't supported. - if r.Capabilities()&CapabilityTXChecksumOffload == 0 { + if r.RequiresTXTransportChecksum() { length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) - xsum := r.PseudoHeaderChecksum(protocol, length) - for _, v := range pkt.Data.Views() { - xsum = header.Checksum(v, xsum) - } - udpHeader.SetChecksum(0) + xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) + xsum = header.ChecksumVV(pkt.Data, xsum) udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) } } - pkt.Network().SetDestinationAddress(rt.Addr) - // After modification, IPv4 packets need a valid checksum. if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { netHeader := header.IPv4(pkt.NetworkHeader().View()) @@ -219,7 +173,7 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs // Set up conection for matching NAT rule. Only the first // packet of the connection comes here. Other packets will be // manipulated in connection tracking. - if conn := ct.insertRedirectConn(pkt, hook, rt); conn != nil { + if conn := ct.insertRedirectConn(pkt, hook, rt.Port, address); conn != nil { ct.handlePacket(pkt, hook, gso, r) } default: diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 7b3f3e88b..4b86c1be9 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -37,7 +37,6 @@ import ( // ----->[Prerouting]----->routing----->[Forward]---------[Postrouting]-----> type Hook uint -// These values correspond to values in include/uapi/linux/netfilter.h. const ( // Prerouting happens before a packet is routed to applications or to // be forwarded. @@ -86,8 +85,8 @@ type IPTables struct { mu sync.RWMutex // v4Tables and v6tables map tableIDs to tables. They hold builtin // tables only, not user tables. mu must be locked for accessing. - v4Tables [numTables]Table - v6Tables [numTables]Table + v4Tables [NumTables]Table + v6Tables [NumTables]Table // modified is whether tables have been modified at least once. It is // used to elide the iptables performance overhead for workloads that // don't utilize iptables. @@ -96,7 +95,7 @@ type IPTables struct { // priorities maps each hook to a list of table names. The order of the // list is the order in which each table should be visited for that // hook. It is immutable. - priorities [NumHooks][]tableID + priorities [NumHooks][]TableID connections ConnTrack @@ -104,6 +103,24 @@ type IPTables struct { reaperDone chan struct{} } +// VisitTargets traverses all the targets of all tables and replaces each with +// transform(target). +func (it *IPTables) VisitTargets(transform func(Target) Target) { + it.mu.Lock() + defer it.mu.Unlock() + + for tid := range it.v4Tables { + for i, rule := range it.v4Tables[tid].Rules { + it.v4Tables[tid].Rules[i].Target = transform(rule.Target) + } + } + for tid := range it.v6Tables { + for i, rule := range it.v6Tables[tid].Rules { + it.v6Tables[tid].Rules[i].Target = transform(rule.Target) + } + } +} + // A Table defines a set of chains and hooks into the network stack. // // It is a list of Rules, entry points (BuiltinChains), and error handlers @@ -169,7 +186,6 @@ type IPHeaderFilter struct { // CheckProtocol determines whether the Protocol field should be // checked during matching. - // TODO(gvisor.dev/issue/3549): Check this field during matching. CheckProtocol bool // Dst matches the destination IP address. @@ -309,23 +325,8 @@ type Matcher interface { Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool) } -// A TargetID uniquely identifies a target. -type TargetID struct { - // Name is the target name as stored in the xt_entry_target struct. - Name string - - // NetworkProtocol is the protocol to which the target applies. - NetworkProtocol tcpip.NetworkProtocolNumber - - // Revision is the version of the target. - Revision uint8 -} - // A Target is the interface for taking an action for a packet. type Target interface { - // ID uniquely identifies the Target. - ID() TargetID - // Action takes an action on the packet and returns a verdict on how // traversal should (or should not) continue. If the return value is // Jump, it also returns the index of the rule to jump to. diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index 6f73a0ce4..c9b13cd0e 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -180,7 +180,7 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEnt } // get reports any known link address for k. -func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { +func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { if linkRes != nil { if addr, ok := linkRes.ResolveStaticAddress(k.Addr); ok { return addr, nil, nil @@ -221,7 +221,7 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo } entry.done = make(chan struct{}) - go c.startAddressResolution(k, linkRes, localAddr, linkEP, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. + go c.startAddressResolution(k, linkRes, localAddr, nic, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. } return entry.linkAddr, entry.done, tcpip.ErrWouldBlock @@ -240,11 +240,11 @@ func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) { } } -func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, done <-chan struct{}) { +func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) { for i := 0; ; i++ { // Send link request, then wait for the timeout limit and check // whether the request succeeded. - linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, linkEP) + linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, nic) select { case now := <-time.After(c.resolutionTimeout): diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index 33806340e..d2e37f38d 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -49,8 +49,8 @@ type testLinkAddressResolver struct { onLinkAddressRequest func() } -func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error { - time.AfterFunc(r.delay, func() { r.fakeRequest(addr) }) +func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { + time.AfterFunc(r.delay, func() { r.fakeRequest(targetAddr) }) if f := r.onLinkAddressRequest; f != nil { f() } diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go index 4df288798..177bf5516 100644 --- a/pkg/tcpip/stack/neighbor_cache.go +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -16,7 +16,6 @@ package stack import ( "fmt" - "time" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" @@ -68,7 +67,7 @@ var _ NUDHandler = (*neighborCache)(nil) // reset to state incomplete, and returned. If no matching entry exists and the // cache is not full, a new entry with state incomplete is allocated and // returned. -func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver) *neighborEntry { +func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkAddressResolver) *neighborEntry { n.mu.Lock() defer n.mu.Unlock() @@ -84,7 +83,7 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, li // The entry that needs to be created must be dynamic since all static // entries are directly added to the cache via addStaticEntry. - entry := newNeighborEntry(n.nic, remoteAddr, localAddr, n.state, linkRes) + entry := newNeighborEntry(n.nic, remoteAddr, n.state, linkRes) if n.dynamic.count == neighborCacheSize { e := n.dynamic.lru.Back() e.mu.Lock() @@ -111,28 +110,31 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, li // provided, it will be notified when address resolution is complete (success // or not). // +// If specified, the local address must be an address local to the interface the +// neighbor cache belongs to. The local address is the source address of a +// packet prompting NUD/link address resolution. +// // If address resolution is required, ErrNoLinkAddress and a notification // channel is returned for the top level caller to block. Channel is closed // once address resolution is complete (success or not). func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, w *sleep.Waker) (NeighborEntry, <-chan struct{}, *tcpip.Error) { if linkAddr, ok := linkRes.ResolveStaticAddress(remoteAddr); ok { e := NeighborEntry{ - Addr: remoteAddr, - LocalAddr: localAddr, - LinkAddr: linkAddr, - State: Static, - UpdatedAt: time.Now(), + Addr: remoteAddr, + LinkAddr: linkAddr, + State: Static, + UpdatedAtNanos: 0, } return e, nil, nil } - entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes) + entry := n.getOrCreateEntry(remoteAddr, linkRes) entry.mu.Lock() defer entry.mu.Unlock() switch s := entry.neigh.State; s { case Stale: - entry.handlePacketQueuedLocked() + entry.handlePacketQueuedLocked(localAddr) fallthrough case Reachable, Static, Delay, Probe: // As per RFC 4861 section 7.3.3: @@ -152,7 +154,7 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA entry.done = make(chan struct{}) } - entry.handlePacketQueuedLocked() + entry.handlePacketQueuedLocked(localAddr) return entry.neigh, entry.done, tcpip.ErrWouldBlock case Failed: return entry.neigh, nil, tcpip.ErrNoLinkAddress @@ -207,7 +209,7 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd } else { // Static entry found with the same address but different link address. entry.neigh.LinkAddr = linkAddr - entry.dispatchChangeEventLocked(entry.neigh.State) + entry.dispatchChangeEventLocked() entry.mu.Unlock() return } @@ -220,8 +222,7 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd entry.mu.Unlock() } - entry := newStaticNeighborEntry(n.nic, addr, linkAddr, n.state) - n.cache[addr] = entry + n.cache[addr] = newStaticNeighborEntry(n.nic, addr, linkAddr, n.state) } // removeEntryLocked removes the specified entry from the neighbor cache. @@ -292,8 +293,8 @@ func (n *neighborCache) setConfig(config NUDConfigurations) { // HandleProbe implements NUDHandler.HandleProbe by following the logic defined // in RFC 4861 section 7.2.3. Validation of the probe is expected to be handled // by the caller. -func (n *neighborCache) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) { - entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes) +func (n *neighborCache) HandleProbe(remoteAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) { + entry := n.getOrCreateEntry(remoteAddr, linkRes) entry.mu.Lock() entry.handleProbeLocked(remoteLinkAddr) entry.mu.Unlock() diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index fcd54ed83..ed33418f3 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -61,23 +61,20 @@ const ( ) // entryDiffOpts returns the options passed to cmp.Diff to compare neighbor -// entries. The UpdatedAt field is ignored due to a lack of a deterministic -// method to predict the time that an event will be dispatched. +// entries. The UpdatedAtNanos field is ignored due to a lack of a +// deterministic method to predict the time that an event will be dispatched. func entryDiffOpts() []cmp.Option { return []cmp.Option{ - cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAt"), + cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAtNanos"), } } // entryDiffOptsWithSort is like entryDiffOpts but also includes an option to // sort slices of entries for cases where ordering must be ignored. func entryDiffOptsWithSort() []cmp.Option { - return []cmp.Option{ - cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAt"), - cmpopts.SortSlices(func(a, b NeighborEntry) bool { - return strings.Compare(string(a.Addr), string(b.Addr)) < 0 - }), - } + return append(entryDiffOpts(), cmpopts.SortSlices(func(a, b NeighborEntry) bool { + return strings.Compare(string(a.Addr), string(b.Addr)) < 0 + })) } func newTestNeighborCache(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *neighborCache { @@ -128,9 +125,8 @@ func newTestEntryStore() *testEntryStore { linkAddr := toLinkAddress(i) store.entriesMap[addr] = NeighborEntry{ - Addr: addr, - LocalAddr: testEntryLocalAddr, - LinkAddr: linkAddr, + Addr: addr, + LinkAddr: linkAddr, } } return store @@ -195,10 +191,10 @@ type testNeighborResolver struct { var _ LinkAddressResolver = (*testNeighborResolver)(nil) -func (r *testNeighborResolver) LinkAddressRequest(addr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error { +func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { // Delay handling the request to emulate network latency. r.clock.AfterFunc(r.delay, func() { - r.fakeRequest(addr) + r.fakeRequest(targetAddr) }) // Execute post address resolution action, if available. @@ -294,9 +290,8 @@ func TestNeighborCacheEntry(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) - if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) @@ -305,15 +300,19 @@ func TestNeighborCacheEntry(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -324,8 +323,8 @@ func TestNeighborCacheEntry(t *testing.T) { t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil { + t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) } // No more events should have been dispatched. @@ -354,9 +353,9 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) - if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) @@ -365,15 +364,19 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -391,9 +394,11 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -404,8 +409,8 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { } } - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } } @@ -452,8 +457,8 @@ func (c *testContext) overflowCache(opts overflowOptions) error { if !ok { return fmt.Errorf("c.store.entry(%d) not found", i) } - if _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil); err != tcpip.ErrWouldBlock { - return fmt.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { + return fmt.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(c.neigh.config().RetransmitTimer) @@ -470,23 +475,29 @@ func (c *testContext) overflowCache(opts overflowOptions) error { wantEvents = append(wantEvents, testEntryEventInfo{ EventType: entryTestRemoved, NICID: 1, - Addr: removedEntry.Addr, - LinkAddr: removedEntry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: removedEntry.Addr, + LinkAddr: removedEntry.LinkAddr, + State: Reachable, + }, }) } wantEvents = append(wantEvents, testEntryEventInfo{ EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, testEntryEventInfo{ EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }) c.nudDisp.mu.Lock() @@ -508,10 +519,9 @@ func (c *testContext) overflowCache(opts overflowOptions) error { return fmt.Errorf("c.store.entry(%d) not found", i) } wantEntry := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, } wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } @@ -564,24 +574,27 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { if !ok { t.Fatalf("c.store.entry(0) not found") } - _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil) - if err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(c.neigh.config().RetransmitTimer) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } c.nudDisp.mu.Lock() @@ -600,9 +613,11 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } c.nudDisp.mu.Lock() @@ -640,9 +655,11 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -682,9 +699,11 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -703,9 +722,11 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -740,9 +761,11 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -760,9 +783,11 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -800,24 +825,27 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { if !ok { t.Fatalf("c.store.entry(0) not found") } - _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil) - if err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } c.nudDisp.mu.Lock() @@ -836,16 +864,20 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -861,10 +893,9 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { startAtEntryIndex: 1, wantStaticEntries: []NeighborEntry{ { - Addr: entry.Addr, - LocalAddr: "", // static entries don't need a local address - LinkAddr: staticLinkAddr, - State: Static, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, }, }, } @@ -896,12 +927,12 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, &w) + _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, &w) if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, %s, _, _ = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + t.Fatalf("got neigh.entry(%s, '', _, _ = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } if doneCh == nil { - t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr) + t.Fatalf("expected done channel from neigh.entry(%s, '', _, _)", entry.Addr) } clock.Advance(typicalLatency) @@ -913,7 +944,7 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) { id, ok := s.Fetch(false /* block */) if !ok { - t.Errorf("expected waker to be notified after neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr) + t.Errorf("expected waker to be notified after neigh.entry(%s, '', _, _)", entry.Addr) } if id != wakerID { t.Errorf("got s.Fetch(false) = %d, want = %d", id, wakerID) @@ -923,15 +954,19 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -964,12 +999,12 @@ func TestNeighborCacheRemoveWaker(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, &w) + _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, &w) if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, %s, _, _) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + t.Fatalf("got neigh.entry(%s, '', _, _) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } if doneCh == nil { - t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr) + t.Fatalf("expected done channel from neigh.entry(%s, '', _, _)", entry.Addr) } // Remove the waker before the neighbor cache has the opportunity to send a @@ -991,15 +1026,19 @@ func TestNeighborCacheRemoveWaker(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -1028,10 +1067,9 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: "", // static entries don't need a local address - LinkAddr: entry.LinkAddr, - State: Static, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { t.Errorf("c.neigh.entry(%s, \"\", _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) @@ -1041,9 +1079,11 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -1058,10 +1098,9 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { startAtEntryIndex: 1, wantStaticEntries: []NeighborEntry{ { - Addr: entry.Addr, - LocalAddr: "", // static entries don't need a local address - LinkAddr: entry.LinkAddr, - State: Static, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, }, }, } @@ -1089,9 +1128,8 @@ func TestNeighborCacheClear(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) - if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) @@ -1099,15 +1137,19 @@ func TestNeighborCacheClear(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -1126,9 +1168,11 @@ func TestNeighborCacheClear(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Static, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Static, + }, }, } nudDisp.mu.Lock() @@ -1149,16 +1193,20 @@ func TestNeighborCacheClear(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, { EventType: entryTestRemoved, NICID: 1, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Static, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Static, + }, }, } nudDisp.mu.Lock() @@ -1185,24 +1233,27 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { if !ok { t.Fatalf("c.store.entry(0) not found") } - _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil) - if err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } c.nudDisp.mu.Lock() @@ -1220,9 +1271,11 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } c.nudDisp.mu.Lock() @@ -1274,29 +1327,33 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if !ok { t.Fatalf("store.entry(%d) not found", i) } - _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) select { case <-doneCh: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr) } wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -1312,9 +1369,8 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { for i := neighborCacheSize; i < store.size(); i++ { // Periodically refresh the frequently used entry if i%(neighborCacheSize/2) == 0 { - _, _, err := neigh.entry(frequentlyUsedEntry.Addr, frequentlyUsedEntry.LocalAddr, linkRes, nil) - if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", frequentlyUsedEntry.Addr, frequentlyUsedEntry.LocalAddr, err) + if _, _, err := neigh.entry(frequentlyUsedEntry.Addr, "", linkRes, nil); err != nil { + t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", frequentlyUsedEntry.Addr, err) } } @@ -1322,15 +1378,15 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if !ok { t.Fatalf("store.entry(%d) not found", i) } - _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) select { case <-doneCh: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr) } // An entry should have been removed, as per the LRU eviction strategy @@ -1342,22 +1398,28 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: removedEntry.Addr, - LinkAddr: removedEntry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: removedEntry.Addr, + LinkAddr: removedEntry.LinkAddr, + State: Reachable, + }, }, { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -1374,10 +1436,9 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { // have to be sorted before comparison. wantUnsortedEntries := []NeighborEntry{ { - Addr: frequentlyUsedEntry.Addr, - LocalAddr: frequentlyUsedEntry.LocalAddr, - LinkAddr: frequentlyUsedEntry.LinkAddr, - State: Reachable, + Addr: frequentlyUsedEntry.Addr, + LinkAddr: frequentlyUsedEntry.LinkAddr, + State: Reachable, }, } @@ -1387,10 +1448,9 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { t.Fatalf("store.entry(%d) not found", i) } wantEntry := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, } wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } @@ -1430,9 +1490,8 @@ func TestNeighborCacheConcurrent(t *testing.T) { wg.Add(1) go func(entry NeighborEntry) { defer wg.Done() - e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) - if err != nil && err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, entry.LocalAddr, e, err, tcpip.ErrWouldBlock) + if e, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil && err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, '', _, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, tcpip.ErrWouldBlock) } }(entry) } @@ -1456,10 +1515,9 @@ func TestNeighborCacheConcurrent(t *testing.T) { t.Errorf("store.entry(%d) not found", i) } wantEntry := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, } wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } @@ -1488,37 +1546,36 @@ func TestNeighborCacheReplace(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) select { case <-doneCh: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr) } // Verify the entry exists { - e, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + e, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) } if doneCh != nil { - t.Errorf("unexpected done channel from neigh.entry(%s, %s, _, nil): %v", entry.Addr, entry.LocalAddr, doneCh) + t.Errorf("unexpected done channel from neigh.entry(%s, '', _, nil): %v", entry.Addr, doneCh) } if t.Failed() { t.FailNow() } want := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LinkAddr, diff) + t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } } @@ -1542,37 +1599,35 @@ func TestNeighborCacheReplace(t *testing.T) { // // Verify the entry's new link address and the new state. { - e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != nil { - t.Fatalf("neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + t.Fatalf("neigh.entry(%s, '', _, nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: updatedLinkAddr, - State: Delay, + Addr: entry.Addr, + LinkAddr: updatedLinkAddr, + State: Delay, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff) + t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } clock.Advance(config.DelayFirstProbeTime + typicalLatency) } // Verify that the neighbor is now reachable. { - e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) clock.Advance(typicalLatency) if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: updatedLinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: updatedLinkAddr, + State: Reachable, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff) + t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } } } @@ -1601,35 +1656,34 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) - got, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + got, _, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, } if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff) + t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } // Verify that address resolution for an unknown address returns ErrNoLinkAddress before := atomic.LoadUint32(&requestCount) entry.Addr += "2" - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrNoLinkAddress { + t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrNoLinkAddress) } maxAttempts := neigh.config().MaxUnicastProbes @@ -1659,13 +1713,13 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrNoLinkAddress { + t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrNoLinkAddress) } } @@ -1683,18 +1737,17 @@ func TestNeighborCacheStaticResolution(t *testing.T) { delay: typicalLatency, } - got, _, err := neigh.entry(testEntryBroadcastAddr, testEntryLocalAddr, linkRes, nil) + got, _, err := neigh.entry(testEntryBroadcastAddr, "", linkRes, nil) if err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", testEntryBroadcastAddr, testEntryLocalAddr, err) + t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", testEntryBroadcastAddr, err) } want := NeighborEntry{ - Addr: testEntryBroadcastAddr, - LocalAddr: testEntryLocalAddr, - LinkAddr: testEntryBroadcastLinkAddr, - State: Static, + Addr: testEntryBroadcastAddr, + LinkAddr: testEntryBroadcastLinkAddr, + State: Static, } if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, testEntryLocalAddr, diff) + t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, diff) } } @@ -1719,9 +1772,9 @@ func BenchmarkCacheClear(b *testing.B) { if !ok { b.Fatalf("store.entry(%d) not found", i) } - _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != tcpip.ErrWouldBlock { - b.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + b.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } if doneCh != nil { <-doneCh diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index be61a21af..493e48031 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -24,13 +24,18 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" ) +const ( + // immediateDuration is a duration of zero for scheduling work that needs to + // be done immediately but asynchronously to avoid deadlock. + immediateDuration time.Duration = 0 +) + // NeighborEntry describes a neighboring device in the local network. type NeighborEntry struct { - Addr tcpip.Address - LocalAddr tcpip.Address - LinkAddr tcpip.LinkAddress - State NeighborState - UpdatedAt time.Time + Addr tcpip.Address + LinkAddr tcpip.LinkAddress + State NeighborState + UpdatedAtNanos int64 } // NeighborState defines the state of a NeighborEntry within the Neighbor @@ -106,35 +111,35 @@ type neighborEntry struct { // state, Unknown. Transition out of Unknown by calling either // `handlePacketQueuedLocked` or `handleProbeLocked` on the newly created // neighborEntry. -func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, localAddr tcpip.Address, nudState *NUDState, linkRes LinkAddressResolver) *neighborEntry { +func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, nudState *NUDState, linkRes LinkAddressResolver) *neighborEntry { return &neighborEntry{ nic: nic, linkRes: linkRes, nudState: nudState, neigh: NeighborEntry{ - Addr: remoteAddr, - LocalAddr: localAddr, - State: Unknown, + Addr: remoteAddr, + State: Unknown, }, } } -// newStaticNeighborEntry creates a neighbor cache entry starting at the Static -// state. The entry can only transition out of Static by directly calling -// `setStateLocked`. +// newStaticNeighborEntry creates a neighbor cache entry starting at the +// Static state. The entry can only transition out of Static by directly +// calling `setStateLocked`. func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry { + entry := NeighborEntry{ + Addr: addr, + LinkAddr: linkAddr, + State: Static, + UpdatedAtNanos: nic.stack.clock.NowNanoseconds(), + } if nic.stack.nudDisp != nil { - nic.stack.nudDisp.OnNeighborAdded(nic.id, addr, linkAddr, Static, time.Now()) + nic.stack.nudDisp.OnNeighborAdded(nic.id, entry) } return &neighborEntry{ nic: nic, nudState: state, - neigh: NeighborEntry{ - Addr: addr, - LinkAddr: linkAddr, - State: Static, - UpdatedAt: time.Now(), - }, + neigh: entry, } } @@ -165,17 +170,17 @@ func (e *neighborEntry) notifyWakersLocked() { // dispatchAddEventLocked signals to stack's NUD Dispatcher that the entry has // been added. -func (e *neighborEntry) dispatchAddEventLocked(nextState NeighborState) { +func (e *neighborEntry) dispatchAddEventLocked() { if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborAdded(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, nextState, time.Now()) + nudDisp.OnNeighborAdded(e.nic.id, e.neigh) } } // dispatchChangeEventLocked signals to stack's NUD Dispatcher that the entry // has changed state or link-layer address. -func (e *neighborEntry) dispatchChangeEventLocked(nextState NeighborState) { +func (e *neighborEntry) dispatchChangeEventLocked() { if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborChanged(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, nextState, time.Now()) + nudDisp.OnNeighborChanged(e.nic.id, e.neigh) } } @@ -183,7 +188,7 @@ func (e *neighborEntry) dispatchChangeEventLocked(nextState NeighborState) { // has been removed. func (e *neighborEntry) dispatchRemoveEventLocked() { if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborRemoved(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, e.neigh.State, time.Now()) + nudDisp.OnNeighborRemoved(e.nic.id, e.neigh) } } @@ -201,68 +206,24 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { prev := e.neigh.State e.neigh.State = next - e.neigh.UpdatedAt = time.Now() + e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds() config := e.nudState.Config() switch next { case Incomplete: - var retryCounter uint32 - var sendMulticastProbe func() - - sendMulticastProbe = func() { - if retryCounter == config.MaxMulticastProbes { - // "If no Neighbor Advertisement is received after - // MAX_MULTICAST_SOLICIT solicitations, address resolution has failed. - // The sender MUST return ICMP destination unreachable indications with - // code 3 (Address Unreachable) for each packet queued awaiting address - // resolution." - RFC 4861 section 7.2.2 - // - // There is no need to send an ICMP destination unreachable indication - // since the failure to resolve the address is expected to only occur - // on this node. Thus, redirecting traffic is currently not supported. - // - // "If the error occurs on a node other than the node originating the - // packet, an ICMP error message is generated. If the error occurs on - // the originating node, an implementation is not required to actually - // create and send an ICMP error packet to the source, as long as the - // upper-layer sender is notified through an appropriate mechanism - // (e.g. return value from a procedure call). Note, however, that an - // implementation may find it convenient in some cases to return errors - // to the sender by taking the offending packet, generating an ICMP - // error message, and then delivering it (locally) through the generic - // error-handling routines.' - RFC 4861 section 2.1 - e.dispatchRemoveEventLocked() - e.setStateLocked(Failed) - return - } - - if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.LinkEndpoint); err != nil { - // There is no need to log the error here; the NUD implementation may - // assume a working link. A valid link should be the responsibility of - // the NIC/stack.LinkEndpoint. - e.dispatchRemoveEventLocked() - e.setStateLocked(Failed) - return - } - - retryCounter++ - e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe) - e.job.Schedule(config.RetransmitTimer) - } - - sendMulticastProbe() + panic(fmt.Sprintf("should never transition to Incomplete with setStateLocked; neigh = %#v, prev state = %s", e.neigh, prev)) case Reachable: e.job = e.nic.stack.newJob(&e.mu, func() { - e.dispatchChangeEventLocked(Stale) e.setStateLocked(Stale) + e.dispatchChangeEventLocked() }) e.job.Schedule(e.nudState.ReachableTime()) case Delay: e.job = e.nic.stack.newJob(&e.mu, func() { - e.dispatchChangeEventLocked(Probe) e.setStateLocked(Probe) + e.dispatchChangeEventLocked() }) e.job.Schedule(config.DelayFirstProbeTime) @@ -277,24 +238,23 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { return } - if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.LinkEndpoint); err != nil { + if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, "" /* localAddr */, e.neigh.LinkAddr, e.nic); err != nil { e.dispatchRemoveEventLocked() e.setStateLocked(Failed) return } retryCounter++ - if retryCounter == config.MaxUnicastProbes { - e.dispatchRemoveEventLocked() - e.setStateLocked(Failed) - return - } - e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe) e.job.Schedule(config.RetransmitTimer) } - sendUnicastProbe() + // Send a probe in another gorountine to free this thread of execution + // for finishing the state transition. This is necessary to avoid + // deadlock where sending and processing probes are done synchronously, + // such as loopback and integration tests. + e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe) + e.job.Schedule(immediateDuration) case Failed: e.notifyWakersLocked() @@ -315,15 +275,77 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { // being queued for outgoing transmission. // // Follows the logic defined in RFC 4861 section 7.3.3. -func (e *neighborEntry) handlePacketQueuedLocked() { +func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { switch e.neigh.State { case Unknown: - e.dispatchAddEventLocked(Incomplete) - e.setStateLocked(Incomplete) + e.neigh.State = Incomplete + e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds() + + e.dispatchAddEventLocked() + + config := e.nudState.Config() + + var retryCounter uint32 + var sendMulticastProbe func() + + sendMulticastProbe = func() { + if retryCounter == config.MaxMulticastProbes { + // "If no Neighbor Advertisement is received after + // MAX_MULTICAST_SOLICIT solicitations, address resolution has failed. + // The sender MUST return ICMP destination unreachable indications with + // code 3 (Address Unreachable) for each packet queued awaiting address + // resolution." - RFC 4861 section 7.2.2 + // + // There is no need to send an ICMP destination unreachable indication + // since the failure to resolve the address is expected to only occur + // on this node. Thus, redirecting traffic is currently not supported. + // + // "If the error occurs on a node other than the node originating the + // packet, an ICMP error message is generated. If the error occurs on + // the originating node, an implementation is not required to actually + // create and send an ICMP error packet to the source, as long as the + // upper-layer sender is notified through an appropriate mechanism + // (e.g. return value from a procedure call). Note, however, that an + // implementation may find it convenient in some cases to return errors + // to the sender by taking the offending packet, generating an ICMP + // error message, and then delivering it (locally) through the generic + // error-handling routines.' - RFC 4861 section 2.1 + e.dispatchRemoveEventLocked() + e.setStateLocked(Failed) + return + } + + // As per RFC 4861 section 7.2.2: + // + // If the source address of the packet prompting the solicitation is the + // same as one of the addresses assigned to the outgoing interface, that + // address SHOULD be placed in the IP Source Address of the outgoing + // solicitation. + // + if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, localAddr, "", e.nic); err != nil { + // There is no need to log the error here; the NUD implementation may + // assume a working link. A valid link should be the responsibility of + // the NIC/stack.LinkEndpoint. + e.dispatchRemoveEventLocked() + e.setStateLocked(Failed) + return + } + + retryCounter++ + e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe) + e.job.Schedule(config.RetransmitTimer) + } + + // Send a probe in another gorountine to free this thread of execution + // for finishing the state transition. This is necessary to avoid + // deadlock where sending and processing probes are done synchronously, + // such as loopback and integration tests. + e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe) + e.job.Schedule(immediateDuration) case Stale: - e.dispatchChangeEventLocked(Delay) e.setStateLocked(Delay) + e.dispatchChangeEventLocked() case Incomplete, Reachable, Delay, Probe, Static, Failed: // Do nothing @@ -345,21 +367,21 @@ func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) { switch e.neigh.State { case Unknown, Incomplete, Failed: e.neigh.LinkAddr = remoteLinkAddr - e.dispatchAddEventLocked(Stale) e.setStateLocked(Stale) e.notifyWakersLocked() + e.dispatchAddEventLocked() case Reachable, Delay, Probe: if e.neigh.LinkAddr != remoteLinkAddr { e.neigh.LinkAddr = remoteLinkAddr - e.dispatchChangeEventLocked(Stale) e.setStateLocked(Stale) + e.dispatchChangeEventLocked() } case Stale: if e.neigh.LinkAddr != remoteLinkAddr { e.neigh.LinkAddr = remoteLinkAddr - e.dispatchChangeEventLocked(Stale) + e.dispatchChangeEventLocked() } case Static: @@ -393,12 +415,11 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla e.neigh.LinkAddr = linkAddr if flags.Solicited { - e.dispatchChangeEventLocked(Reachable) e.setStateLocked(Reachable) } else { - e.dispatchChangeEventLocked(Stale) e.setStateLocked(Stale) } + e.dispatchChangeEventLocked() e.isRouter = flags.IsRouter e.notifyWakersLocked() @@ -411,8 +432,8 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla if isLinkAddrDifferent { if !flags.Override { if e.neigh.State == Reachable { - e.dispatchChangeEventLocked(Stale) e.setStateLocked(Stale) + e.dispatchChangeEventLocked() } break } @@ -421,23 +442,24 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla if !flags.Solicited { if e.neigh.State != Stale { - e.dispatchChangeEventLocked(Stale) e.setStateLocked(Stale) + e.dispatchChangeEventLocked() } else { // Notify the LinkAddr change, even though NUD state hasn't changed. - e.dispatchChangeEventLocked(e.neigh.State) + e.dispatchChangeEventLocked() } break } } if flags.Solicited && (flags.Override || !isLinkAddrDifferent) { - if e.neigh.State != Reachable { - e.dispatchChangeEventLocked(Reachable) - } + wasReachable := e.neigh.State == Reachable // Set state to Reachable again to refresh timers. e.setStateLocked(Reachable) e.notifyWakersLocked() + if !wasReachable { + e.dispatchChangeEventLocked() + } } if e.isRouter && !flags.IsRouter && header.IsV6UnicastAddress(e.neigh.Addr) { @@ -475,11 +497,12 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla func (e *neighborEntry) handleUpperLevelConfirmationLocked() { switch e.neigh.State { case Reachable, Stale, Delay, Probe: - if e.neigh.State != Reachable { - e.dispatchChangeEventLocked(Reachable) - // Set state to Reachable again to refresh timers. - } + wasReachable := e.neigh.State == Reachable + // Set state to Reachable again to refresh timers. e.setStateLocked(Reachable) + if !wasReachable { + e.dispatchChangeEventLocked() + } case Unknown, Incomplete, Failed, Static: // Do nothing diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index 3ee2a3b31..c2b763325 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -47,24 +47,27 @@ const ( entryTestNetDefaultMTU = 65536 ) +// runImmediatelyScheduledJobs runs all jobs scheduled to run at the current +// time. +func runImmediatelyScheduledJobs(clock *faketime.ManualClock) { + clock.Advance(immediateDuration) +} + // eventDiffOpts are the options passed to cmp.Diff to compare entry events. -// The UpdatedAt field is ignored due to a lack of a deterministic method to -// predict the time that an event will be dispatched. +// The UpdatedAtNanos field is ignored due to a lack of a deterministic method +// to predict the time that an event will be dispatched. func eventDiffOpts() []cmp.Option { return []cmp.Option{ - cmpopts.IgnoreFields(testEntryEventInfo{}, "UpdatedAt"), + cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAtNanos"), } } // eventDiffOptsWithSort is like eventDiffOpts but also includes an option to // sort slices of events for cases where ordering must be ignored. func eventDiffOptsWithSort() []cmp.Option { - return []cmp.Option{ - cmpopts.IgnoreFields(testEntryEventInfo{}, "UpdatedAt"), - cmpopts.SortSlices(func(a, b testEntryEventInfo) bool { - return strings.Compare(string(a.Addr), string(b.Addr)) < 0 - }), - } + return append(eventDiffOpts(), cmpopts.SortSlices(func(a, b testEntryEventInfo) bool { + return strings.Compare(string(a.Entry.Addr), string(b.Entry.Addr)) < 0 + })) } // The following unit tests exercise every state transition and verify its @@ -125,14 +128,11 @@ func (t testEntryEventType) String() string { type testEntryEventInfo struct { EventType testEntryEventType NICID tcpip.NICID - Addr tcpip.Address - LinkAddr tcpip.LinkAddress - State NeighborState - UpdatedAt time.Time + Entry NeighborEntry } func (e testEntryEventInfo) String() string { - return fmt.Sprintf("%s event for NIC #%d, addr=%q, linkAddr=%q, state=%q", e.EventType, e.NICID, e.Addr, e.LinkAddr, e.State) + return fmt.Sprintf("%s event for NIC #%d, %#v", e.EventType, e.NICID, e.Entry) } // testNUDDispatcher implements NUDDispatcher to validate the dispatching of @@ -150,36 +150,27 @@ func (d *testNUDDispatcher) queueEvent(e testEntryEventInfo) { d.events = append(d.events, e) } -func (d *testNUDDispatcher) OnNeighborAdded(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) { +func (d *testNUDDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry NeighborEntry) { d.queueEvent(testEntryEventInfo{ EventType: entryTestAdded, NICID: nicID, - Addr: addr, - LinkAddr: linkAddr, - State: state, - UpdatedAt: updatedAt, + Entry: entry, }) } -func (d *testNUDDispatcher) OnNeighborChanged(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) { +func (d *testNUDDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry NeighborEntry) { d.queueEvent(testEntryEventInfo{ EventType: entryTestChanged, NICID: nicID, - Addr: addr, - LinkAddr: linkAddr, - State: state, - UpdatedAt: updatedAt, + Entry: entry, }) } -func (d *testNUDDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) { +func (d *testNUDDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry NeighborEntry) { d.queueEvent(testEntryEventInfo{ EventType: entryTestRemoved, NICID: nicID, - Addr: addr, - LinkAddr: linkAddr, - State: state, - UpdatedAt: updatedAt, + Entry: entry, }) } @@ -202,9 +193,9 @@ func (p entryTestProbeInfo) String() string { // LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts // to the local network if linkAddr is the zero value. -func (r *entryTestLinkResolver) LinkAddressRequest(addr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error { +func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { p := entryTestProbeInfo{ - RemoteAddress: addr, + RemoteAddress: targetAddr, RemoteLinkAddress: linkAddr, LocalAddress: localAddr, } @@ -245,7 +236,7 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e rng := rand.New(rand.NewSource(time.Now().UnixNano())) nudState := NewNUDState(c, rng) linkRes := entryTestLinkResolver{} - entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, entryTestAddr2 /* localAddr */, nudState, &linkRes) + entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, nudState, &linkRes) // Stub out the neighbor cache to verify deletion from the cache. nic.neigh = &neighborCache{ @@ -323,15 +314,16 @@ func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) { func TestEntryUnknownToIncomplete(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Incomplete; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -350,9 +342,11 @@ func TestEntryUnknownToIncomplete(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, } { @@ -367,7 +361,7 @@ func TestEntryUnknownToIncomplete(t *testing.T) { func TestEntryUnknownToStale(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() e.handleProbeLocked(entryTestLinkAddr1) @@ -377,6 +371,7 @@ func TestEntryUnknownToStale(t *testing.T) { e.mu.Unlock() // No probes should have been sent. + runImmediatelyScheduledJobs(clock) linkRes.mu.Lock() diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil)) linkRes.mu.Unlock() @@ -388,9 +383,11 @@ func TestEntryUnknownToStale(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -406,11 +403,11 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Incomplete; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } - updatedAt := e.neigh.UpdatedAt + updatedAtNanos := e.neigh.UpdatedAtNanos e.mu.Unlock() clock.Advance(c.RetransmitTimer) @@ -437,7 +434,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { } e.mu.Lock() - if got, want := e.neigh.UpdatedAt, updatedAt; got != want { + if got, want := e.neigh.UpdatedAtNanos, updatedAtNanos; got != want { t.Errorf("got e.neigh.UpdatedAt = %q, want = %q", got, want) } e.mu.Unlock() @@ -468,16 +465,20 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestRemoved, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, } nudDisp.mu.Lock() @@ -487,7 +488,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { nudDisp.mu.Unlock() e.mu.Lock() - if got, notWant := e.neigh.UpdatedAt, updatedAt; got == notWant { + if got, notWant := e.neigh.UpdatedAtNanos, updatedAtNanos; got == notWant { t.Errorf("expected e.neigh.UpdatedAt to change, got = %q", got) } e.mu.Unlock() @@ -495,23 +496,16 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { func TestEntryIncompleteToReachable(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - if got, want := e.neigh.State, Incomplete; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) } e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -526,20 +520,35 @@ func TestEntryIncompleteToReachable(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -555,7 +564,7 @@ func TestEntryIncompleteToReachable(t *testing.T) { // to Reachable. func TestEntryAddsAndClearsWakers(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) w := sleep.Waker{} s := sleep.Sleeper{} @@ -563,7 +572,25 @@ func TestEntryAddsAndClearsWakers(t *testing.T) { defer s.Done() e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() if got := e.wakers; got != nil { t.Errorf("got e.wakers = %v, want = nil", got) } @@ -587,34 +614,24 @@ func TestEntryAddsAndClearsWakers(t *testing.T) { } e.mu.Unlock() - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) - } - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -626,26 +643,16 @@ func TestEntryAddsAndClearsWakers(t *testing.T) { func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - if got, want := e.neigh.State, Incomplete; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: true, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.isRouter, true; got != want { - t.Errorf("got e.isRouter = %t, want = %t", got, want) + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) } e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -659,20 +666,38 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { } linkRes.mu.Unlock() + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: true, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + if !e.isRouter { + t.Errorf("got e.isRouter = %t, want = true", e.isRouter) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -684,23 +709,16 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { func TestEntryIncompleteToStale(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - if got, want := e.neigh.State, Incomplete; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) } e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -715,20 +733,35 @@ func TestEntryIncompleteToStale(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -744,7 +777,7 @@ func TestEntryIncompleteToFailed(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Incomplete; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } @@ -783,16 +816,20 @@ func TestEntryIncompleteToFailed(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestRemoved, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, } nudDisp.mu.Lock() @@ -817,12 +854,30 @@ func (*testLocker) Unlock() {} func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) ipv6EP := e.nic.networkEndpoints[header.IPv6ProtocolNumber].(*testIPv6Endpoint) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: true, Override: false, @@ -848,34 +903,24 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { } e.mu.Unlock() - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) - } - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -893,27 +938,13 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleProbeLocked(entryTestLinkAddr1) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -928,20 +959,42 @@ func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + e.handleProbeLocked(entryTestLinkAddr1) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + if e.neigh.LinkAddr != entryTestLinkAddr1 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -961,17 +1014,10 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -986,29 +1032,46 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + e.mu.Unlock() + clock.Advance(c.BaseReachableTime) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1026,24 +1089,13 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleProbeLocked(entryTestLinkAddr2) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1058,27 +1110,48 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + e.handleProbeLocked(entryTestLinkAddr2) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1086,38 +1159,17 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } nudDisp.mu.Unlock() - - e.mu.Lock() - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.mu.Unlock() } func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1132,27 +1184,52 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1160,38 +1237,17 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } nudDisp.mu.Unlock() - - e.mu.Lock() - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.mu.Unlock() } func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1206,27 +1262,52 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1234,37 +1315,17 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } nudDisp.mu.Unlock() - - e.mu.Lock() - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.mu.Unlock() } func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleProbeLocked(entryTestLinkAddr1) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1279,20 +1340,42 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.handleProbeLocked(entryTestLinkAddr1) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + if e.neigh.LinkAddr != entryTestLinkAddr1 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1304,31 +1387,13 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) { func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: true, - Override: true, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1343,27 +1408,55 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: true, + Override: true, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + if e.neigh.LinkAddr != entryTestLinkAddr2 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -1375,10 +1468,28 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } e.mu.Lock() - e.handlePacketQueuedLocked() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, @@ -1400,41 +1511,33 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing } e.mu.Unlock() - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) - } - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -1446,31 +1549,13 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1485,27 +1570,55 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + if e.neigh.LinkAddr != entryTestLinkAddr2 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1517,27 +1630,13 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleProbeLocked(entryTestLinkAddr2) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1552,27 +1651,51 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.handleProbeLocked(entryTestLinkAddr2) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + if e.neigh.LinkAddr != entryTestLinkAddr2 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1584,24 +1707,13 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { func TestEntryStaleToDelay(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handlePacketQueuedLocked() - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1616,27 +1728,48 @@ func TestEntryStaleToDelay(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, } nudDisp.mu.Lock() @@ -1656,22 +1789,10 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked() - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleUpperLevelConfirmationLocked() - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1686,43 +1807,68 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } - clock.Advance(c.BaseReachableTime) + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + } + e.handleUpperLevelConfirmationLocked() + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + e.mu.Unlock() + clock.Advance(c.BaseReachableTime) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1743,29 +1889,10 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked() - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: true, - Override: true, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1780,43 +1907,75 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } - clock.Advance(c.BaseReachableTime) + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: true, + Override: true, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + if e.neigh.LinkAddr != entryTestLinkAddr2 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2) + } + e.mu.Unlock() + clock.Advance(c.BaseReachableTime) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1837,13 +1996,31 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if e.neigh.State != Delay { t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) } @@ -1860,57 +2037,52 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing } e.mu.Unlock() - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) - } - clock.Advance(c.BaseReachableTime) - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1922,32 +2094,13 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked() - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1962,27 +2115,56 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + } + if e.neigh.LinkAddr != entryTestLinkAddr1 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, } nudDisp.mu.Lock() @@ -1994,25 +2176,13 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked() - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleProbeLocked(entryTestLinkAddr2) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -2027,34 +2197,58 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + } + e.handleProbeLocked(entryTestLinkAddr2) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2066,29 +2260,13 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked() - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -2103,34 +2281,62 @@ func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2145,69 +2351,91 @@ func TestEntryDelayToProbe(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Delay; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The second probe is caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, } nudDisp.mu.Lock() @@ -2228,36 +2456,50 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The second probe is caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } e.mu.Lock() @@ -2274,37 +2516,47 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2312,12 +2564,6 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } nudDisp.mu.Unlock() - - e.mu.Lock() - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.mu.Unlock() } func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { @@ -2325,36 +2571,50 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The second probe is caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } e.mu.Lock() @@ -2375,37 +2635,47 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2413,12 +2683,6 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } nudDisp.mu.Unlock() - - e.mu.Lock() - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.mu.Unlock() } func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { @@ -2426,36 +2690,51 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The second probe is caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + // The second probe is caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } e.mu.Lock() @@ -2479,30 +2758,38 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, } nudDisp.mu.Lock() @@ -2529,17 +2816,14 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { e.mu.Lock() e.handleProbeLocked(entryTestLinkAddr1) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - wantProbes := []entryTestProbeInfo{ - // Probe caused by the Delay-to-Probe transition { RemoteAddress: entryTestAddr1, RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, }, } linkRes.mu.Lock() @@ -2567,42 +2851,51 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { e.mu.Unlock() clock.Advance(c.BaseReachableTime) - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2622,36 +2915,50 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The second probe is caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } e.mu.Lock() @@ -2672,49 +2979,60 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { e.mu.Unlock() clock.Advance(c.BaseReachableTime) - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2734,36 +3052,50 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The second probe is caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } e.mu.Lock() @@ -2781,49 +3113,60 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin e.mu.Unlock() clock.Advance(c.BaseReachableTime) - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2843,36 +3186,50 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The second probe is caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } e.mu.Lock() @@ -2890,49 +3247,60 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing e.mu.Unlock() clock.Advance(c.BaseReachableTime) - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2946,87 +3314,116 @@ func TestEntryProbeToFailed(t *testing.T) { c := DefaultNUDConfigurations() c.MaxMulticastProbes = 3 c.MaxUnicastProbes = 3 + c.DelayFirstProbeTime = c.RetransmitTimer e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() - waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) - clock.Advance(waitFor) + // Observe each probe sent while in the Probe state. + for i := uint32(0); i < c.MaxUnicastProbes; i++ { + clock.Advance(c.RetransmitTimer) + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probe #%d mismatch (-got, +want):\n%s", i+1, diff) + } - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The next three probe are caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, + e.mu.Lock() + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) + } + e.mu.Unlock() } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + + // Wait for the last probe to expire, causing a transition to Failed. + clock.Advance(c.RetransmitTimer) + e.mu.Lock() + if e.neigh.State != Failed { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Failed) } + e.mu.Unlock() wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestRemoved, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, } nudDisp.mu.Lock() @@ -3034,12 +3431,6 @@ func TestEntryProbeToFailed(t *testing.T) { t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } nudDisp.mu.Unlock() - - e.mu.Lock() - if got, want := e.neigh.State, Failed; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.mu.Unlock() } func TestEntryFailedGetsDeleted(t *testing.T) { @@ -3054,84 +3445,106 @@ func TestEntryFailedGetsDeleted(t *testing.T) { } e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + c.UnreachableTime clock.Advance(waitFor) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The next three probe are caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + // The next three probe are sent in Probe. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestRemoved, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, } nudDisp.mu.Lock() diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index dcd4319bf..60c81a3aa 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -273,6 +273,15 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb return n.writePacket(r, gso, protocol, pkt) } +// WritePacketToRemote implements NetworkInterface. +func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { + r := Route{ + NetProto: protocol, + RemoteLinkAddress: remoteLinkAddr, + } + return n.writePacket(&r, gso, protocol, pkt) +} + func (n *NIC) writePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { // WritePacket takes ownership of pkt, calculate numBytes first. numBytes := pkt.Size() @@ -339,6 +348,16 @@ func (n *NIC) getAddress(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address return n.getAddressOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, promiscuous) } +func (n *NIC) hasAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { + ep := n.getAddressOrCreateTempInner(protocol, addr, false, NeverPrimaryEndpoint) + if ep != nil { + ep.DecRef() + return true + } + + return false +} + // findEndpoint finds the endpoint, if any, with the given address. func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) AssignableAddressEndpoint { return n.getAddressOrCreateTemp(protocol, address, peb, spoofing) @@ -546,10 +565,10 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool { } func (n *NIC) handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, remotelinkAddr tcpip.LinkAddress, addressEndpoint AssignableAddressEndpoint, pkt *PacketBuffer) { - r := makeRoute(protocol, dst, src, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */) + r := makeRoute(protocol, dst, src, n, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */) defer r.Release() - r.RemoteLinkAddress = remotelinkAddr - n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt) + r.PopulatePacketInfo(pkt) + n.getNetworkEndpoint(protocol).HandlePacket(pkt) } // DeliverNetworkPacket finds the appropriate network protocol endpoint and @@ -585,6 +604,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp if local == "" { local = n.LinkEndpoint.LinkAddress() } + pkt.RXTransportChecksumValidated = n.LinkEndpoint.Capabilities()&CapabilityRXChecksumOffload != 0 // Are any packet type sockets listening for this network protocol? packetEPs := n.mu.packetEPs[protocol] @@ -660,14 +680,13 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } // Found a NIC. - n := r.nic + n := r.localAddressNIC if addressEndpoint := n.getAddressOrCreateTempInner(protocol, dst, false, NeverPrimaryEndpoint); addressEndpoint != nil { if n.isValidForOutgoing(addressEndpoint) { - r.LocalLinkAddress = n.LinkEndpoint.LinkAddress() - r.RemoteLinkAddress = remote + pkt.NICID = n.ID() r.RemoteAddress = src - // TODO(b/123449044): Update the source NIC as well. - n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt) + pkt.NetworkPacketInfo = r.networkPacketInfo() + n.getNetworkEndpoint(protocol).HandlePacket(pkt) addressEndpoint.DecRef() r.Release() return @@ -678,7 +697,8 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp // n doesn't have a destination endpoint. // Send the packet out of n. - // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6. + // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease + // the TTL field for ipv4/ipv6. // pkt may have set its header and may not have enough headroom for // link-layer header for the other link to prepend. Here we create a new @@ -725,7 +745,7 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. -func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition { +func (n *NIC) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition { state, ok := n.stack.transportProtocols[protocol] if !ok { n.stack.stats.UnknownProtocolRcvdPackets.Increment() @@ -737,7 +757,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // Raw socket packets are delivered based solely on the transport // protocol number. We do not inspect the payload to ensure it's // validly formed. - n.stack.demux.deliverRawPacket(r, protocol, pkt) + n.stack.demux.deliverRawPacket(protocol, pkt) // TransportHeader is empty only when pkt is an ICMP packet or was reassembled // from fragments. @@ -766,14 +786,25 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN return TransportPacketHandled } - id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress} - if n.stack.demux.deliverPacket(r, protocol, pkt, id) { + netProto, ok := n.stack.networkProtocols[pkt.NetworkProtocolNumber] + if !ok { + panic(fmt.Sprintf("expected network protocol = %d, have = %#v", pkt.NetworkProtocolNumber, n.stack.networkProtocolNumbers())) + } + + src, dst := netProto.ParseAddresses(pkt.NetworkHeader().View()) + id := TransportEndpointID{ + LocalPort: dstPort, + LocalAddress: dst, + RemotePort: srcPort, + RemoteAddress: src, + } + if n.stack.demux.deliverPacket(protocol, pkt, id) { return TransportPacketHandled } // Try to deliver to per-stack default handler. if state.defaultHandler != nil { - if state.defaultHandler(r, id, pkt) { + if state.defaultHandler(id, pkt) { return TransportPacketHandled } } @@ -781,7 +812,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // We could not find an appropriate destination for this packet so // give the protocol specific error handler a chance to handle it. // If it doesn't handle it then we should do so. - switch res := transProto.HandleUnknownDestinationPacket(r, id, pkt); res { + switch res := transProto.HandleUnknownDestinationPacket(id, pkt); res { case UnknownDestinationPacketMalformed: n.stack.stats.MalformedRcvdPackets.Increment() return TransportPacketHandled @@ -885,7 +916,7 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep } // isValidForOutgoing returns true if the endpoint can be used to send out a -// packet. It requires the endpoint to not be marked expired (i.e., its address) +// packet. It requires the endpoint to not be marked expired (i.e., its address // has been removed) unless the NIC is in spoofing mode, or temporary. func (n *NIC) isValidForOutgoing(ep AssignableAddressEndpoint) bool { n.mu.RLock() diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 97a96af62..5b5c58afb 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -83,8 +83,7 @@ func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip } // HandlePacket implements NetworkEndpoint.HandlePacket. -func (*testIPv6Endpoint) HandlePacket(*Route, *PacketBuffer) { -} +func (*testIPv6Endpoint) HandlePacket(*PacketBuffer) {} // Close implements NetworkEndpoint.Close. func (e *testIPv6Endpoint) Close() { @@ -169,7 +168,7 @@ func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { } // LinkAddressRequest implements LinkAddressResolver. -func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error { +func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { return nil } diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go index e1ec15487..ab629b3a4 100644 --- a/pkg/tcpip/stack/nud.go +++ b/pkg/tcpip/stack/nud.go @@ -129,7 +129,7 @@ type NUDDispatcher interface { // the stack's operation. // // May be called concurrently. - OnNeighborAdded(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) + OnNeighborAdded(tcpip.NICID, NeighborEntry) // OnNeighborChanged will be called when an entry in a NIC's (with ID nicID) // neighbor table changes state and/or link address. @@ -138,7 +138,7 @@ type NUDDispatcher interface { // the stack's operation. // // May be called concurrently. - OnNeighborChanged(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) + OnNeighborChanged(tcpip.NICID, NeighborEntry) // OnNeighborRemoved will be called when an entry is removed from a NIC's // (with ID nicID) neighbor table. @@ -147,7 +147,7 @@ type NUDDispatcher interface { // the stack's operation. // // May be called concurrently. - OnNeighborRemoved(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) + OnNeighborRemoved(tcpip.NICID, NeighborEntry) } // ReachabilityConfirmationFlags describes the flags used within a reachability @@ -177,7 +177,7 @@ type NUDHandler interface { // Neighbor Solicitation for ARP or NDP, respectively). Validation of the // probe needs to be performed before calling this function since the // Neighbor Cache doesn't have access to view the NIC's assigned addresses. - HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) + HandleProbe(remoteAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) // HandleConfirmation processes an incoming neighbor confirmation (e.g. ARP // reply or Neighbor Advertisement for ARP or NDP, respectively). diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 7f54a6de8..664cc6fa0 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -112,6 +112,16 @@ type PacketBuffer struct { // PktType indicates the SockAddrLink.PacketType of the packet as defined in // https://www.man7.org/linux/man-pages/man7/packet.7.html. PktType tcpip.PacketType + + // NICID is the ID of the interface the network packet was received at. + NICID tcpip.NICID + + // RXTransportChecksumValidated indicates that transport checksum verification + // may be safely skipped. + RXTransportChecksumValidated bool + + // NetworkPacketInfo holds an incoming packet's network-layer information. + NetworkPacketInfo NetworkPacketInfo } // NewPacketBuffer creates a new PacketBuffer with opts. @@ -240,20 +250,33 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consum // Clone should be called in such cases so that no modifications is done to // underlying packet payload. func (pk *PacketBuffer) Clone() *PacketBuffer { - newPk := &PacketBuffer{ - PacketBufferEntry: pk.PacketBufferEntry, - Data: pk.Data.Clone(nil), - headers: pk.headers, - header: pk.header, - Hash: pk.Hash, - Owner: pk.Owner, - EgressRoute: pk.EgressRoute, - GSOOptions: pk.GSOOptions, - NetworkProtocolNumber: pk.NetworkProtocolNumber, - NatDone: pk.NatDone, - TransportProtocolNumber: pk.TransportProtocolNumber, + return &PacketBuffer{ + PacketBufferEntry: pk.PacketBufferEntry, + Data: pk.Data.Clone(nil), + headers: pk.headers, + header: pk.header, + Hash: pk.Hash, + Owner: pk.Owner, + GSOOptions: pk.GSOOptions, + NetworkProtocolNumber: pk.NetworkProtocolNumber, + NatDone: pk.NatDone, + TransportProtocolNumber: pk.TransportProtocolNumber, + PktType: pk.PktType, + NICID: pk.NICID, + RXTransportChecksumValidated: pk.RXTransportChecksumValidated, + NetworkPacketInfo: pk.NetworkPacketInfo, } - return newPk +} + +// SourceLinkAddress returns the source link address of the packet. +func (pk *PacketBuffer) SourceLinkAddress() tcpip.LinkAddress { + link := pk.LinkHeader().View() + + if link.IsEmpty() { + return "" + } + + return header.Ethernet(link).SourceAddress() } // Network returns the network header as a header.Network. @@ -270,6 +293,17 @@ func (pk *PacketBuffer) Network() header.Network { } } +// CloneToInbound makes a shallow copy of the packet buffer to be used as an +// inbound packet. +// +// See PacketBuffer.Data for details about how a packet buffer holds an inbound +// packet. +func (pk *PacketBuffer) CloneToInbound() *PacketBuffer { + return NewPacketBuffer(PacketBufferOptions{ + Data: buffer.NewVectorisedView(pk.Size(), pk.Views()), + }) +} + // headerInfo stores metadata about a header in a packet. type headerInfo struct { // buf is the memorized slice for both prepended and consumed header. diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go index f838eda8d..5d364a2b0 100644 --- a/pkg/tcpip/stack/pending_packets.go +++ b/pkg/tcpip/stack/pending_packets.go @@ -106,7 +106,7 @@ func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, pro } else if _, err := p.route.Resolve(nil); err != nil { p.route.Stats().IP.OutgoingPacketErrors.Increment() } else { - p.route.nic.writePacket(p.route, nil /* gso */, p.proto, p.pkt) + p.route.outgoingNIC.writePacket(p.route, nil /* gso */, p.proto, p.pkt) } p.route.Release() } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index defb9129b..b8f333057 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -63,17 +63,28 @@ const ( ControlUnknown ) +// NetworkPacketInfo holds information about a network layer packet. +type NetworkPacketInfo struct { + // RemoteAddressBroadcast is true if the packet's remote address is a + // broadcast address. + RemoteAddressBroadcast bool + + // LocalAddressBroadcast is true if the packet's local address is a broadcast + // address. + LocalAddressBroadcast bool +} + // TransportEndpoint is the interface that needs to be implemented by transport // protocol (e.g., tcp, udp) endpoints that can handle packets. type TransportEndpoint interface { // UniqueID returns an unique ID for this transport endpoint. UniqueID() uint64 - // HandlePacket is called by the stack when new packets arrive to - // this transport endpoint. It sets pkt.TransportHeader. + // HandlePacket is called by the stack when new packets arrive to this + // transport endpoint. It sets the packet buffer's transport header. // - // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) + // HandlePacket takes ownership of the packet. + HandlePacket(TransportEndpointID, *PacketBuffer) // HandleControlPacket is called by the stack when new control (e.g. // ICMP) packets arrive to this transport endpoint. @@ -105,8 +116,8 @@ type RawTransportEndpoint interface { // this transport endpoint. The packet contains all data from the link // layer up. // - // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, pkt *PacketBuffer) + // HandlePacket takes ownership of the packet. + HandlePacket(*PacketBuffer) } // PacketEndpoint is the interface that needs to be implemented by packet @@ -172,9 +183,9 @@ type TransportProtocol interface { // protocol that don't match any existing endpoint. For example, // it is targeted at a port that has no listeners. // - // HandleUnknownDestinationPacket takes ownership of pkt if it handles + // HandleUnknownDestinationPacket takes ownership of the packet if it handles // the issue. - HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) UnknownDestinationPacketDisposition + HandleUnknownDestinationPacket(TransportEndpointID, *PacketBuffer) UnknownDestinationPacketDisposition // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -227,8 +238,8 @@ type TransportDispatcher interface { // // pkt.NetworkHeader must be set before calling DeliverTransportPacket. // - // DeliverTransportPacket takes ownership of pkt. - DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition + // DeliverTransportPacket takes ownership of the packet. + DeliverTransportPacket(tcpip.TransportProtocolNumber, *PacketBuffer) TransportPacketDisposition // DeliverTransportControlPacket delivers control packets to the // appropriate transport protocol endpoint. @@ -329,6 +340,9 @@ type AssignableAddressEndpoint interface { // AddressWithPrefix returns the endpoint's address. AddressWithPrefix() tcpip.AddressWithPrefix + // Subnet returns the subnet of the endpoint's address. + Subnet() tcpip.Subnet + // IsAssigned returns whether or not the endpoint is considered bound // to its NetworkEndpoint. IsAssigned(allowExpired bool) bool @@ -490,6 +504,9 @@ type NetworkInterface interface { // Enabled returns true if the interface is enabled. Enabled() bool + + // WritePacketToRemote writes the packet to the given remote link address. + WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error } // NetworkEndpoint is the interface that needs to be implemented by endpoints @@ -544,7 +561,7 @@ type NetworkEndpoint interface { // this network endpoint. It sets pkt.NetworkHeader. // // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, pkt *PacketBuffer) + HandlePacket(pkt *PacketBuffer) // Close is called when the endpoint is reomved from a stack. Close() @@ -764,13 +781,13 @@ type InjectableLinkEndpoint interface { // A LinkAddressResolver is an extension to a NetworkProtocol that // can resolve link addresses. type LinkAddressResolver interface { - // LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts - // the request on the local network if remoteLinkAddr is the zero value. The - // request is sent on linkEP with localAddr as the source. + // LinkAddressRequest sends a request for the link address of the target + // address. The request is broadcasted on the local network if a remote link + // address is not provided. // - // A valid response will cause the discovery protocol's network - // endpoint to call AddLinkAddress. - LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error + // The request is sent from the passed network interface. If the interface + // local address is unspecified, any interface local address may be used. + LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic NetworkInterface) *tcpip.Error // ResolveStaticAddress attempts to resolve address without sending // requests. It either resolves the name immediately or returns the diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index b76e2d37b..15ff437c7 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -15,6 +15,8 @@ package stack import ( + "fmt" + "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -45,11 +47,16 @@ type Route struct { // Loop controls where WritePacket should send packets. Loop PacketLooping - // nic is the NIC the route goes through. - nic *NIC + // localAddressNIC is the interface the address is associated with. + // TODO(gvisor.dev/issue/4548): Remove this field once we can query the + // address's assigned status without the NIC. + localAddressNIC *NIC + + // localAddressEndpoint is the local address this route is associated with. + localAddressEndpoint AssignableAddressEndpoint - // addressEndpoint is the local address this route is associated with. - addressEndpoint AssignableAddressEndpoint + // outgoingNIC is the interface this route uses to write packets. + outgoingNIC *NIC // linkCache is set if link address resolution is enabled for this protocol on // the route's NIC. @@ -60,51 +67,144 @@ type Route struct { linkRes LinkAddressResolver } +// constructAndValidateRoute validates and initializes a route. It takes +// ownership of the provided local address. +// +// Returns an empty route if validation fails. +func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) Route { + addrWithPrefix := addressEndpoint.AddressWithPrefix() + + if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(addrWithPrefix.Address) { + addressEndpoint.DecRef() + return Route{} + } + + // If no remote address is provided, use the local address. + if len(remoteAddr) == 0 { + remoteAddr = addrWithPrefix.Address + } + + r := makeRoute( + netProto, + addrWithPrefix.Address, + remoteAddr, + outgoingNIC, + localAddressNIC, + addressEndpoint, + handleLocal, + multicastLoop, + ) + + // If the route requires us to send a packet through some gateway, do not + // broadcast it. + if len(gateway) > 0 { + r.NextHop = gateway + } else if subnet := addrWithPrefix.Subnet(); subnet.IsBroadcast(remoteAddr) { + r.RemoteLinkAddress = header.EthernetBroadcastAddress + } + + return r +} + // makeRoute initializes a new route. It takes ownership of the provided // AssignableAddressEndpoint. -func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, nic *NIC, addressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route { +func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route { + if localAddressNIC.stack != outgoingNIC.stack { + panic(fmt.Sprintf("cannot create a route with NICs from different stacks")) + } + loop := PacketOut - if handleLocal && localAddr != "" && remoteAddr == localAddr { - loop = PacketLoop - } else if multicastLoop && (header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)) { - loop |= PacketLoop - } else if remoteAddr == header.IPv4Broadcast { - loop |= PacketLoop + + // TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the + // link endpoint level. We can remove this check once loopback interfaces + // loop back packets at the network layer. + if !outgoingNIC.IsLoopback() { + if handleLocal && localAddr != "" && remoteAddr == localAddr { + loop = PacketLoop + } else if multicastLoop && (header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)) { + loop |= PacketLoop + } else if remoteAddr == header.IPv4Broadcast { + loop |= PacketLoop + } else if subnet := localAddressEndpoint.AddressWithPrefix().Subnet(); subnet.IsBroadcast(remoteAddr) { + loop |= PacketLoop + } } + return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop) +} + +func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) Route { r := Route{ - NetProto: netProto, - LocalAddress: localAddr, - LocalLinkAddress: nic.LinkEndpoint.LinkAddress(), - RemoteAddress: remoteAddr, - addressEndpoint: addressEndpoint, - nic: nic, - Loop: loop, + NetProto: netProto, + LocalAddress: localAddr, + LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(), + RemoteAddress: remoteAddr, + localAddressNIC: localAddressNIC, + localAddressEndpoint: localAddressEndpoint, + outgoingNIC: outgoingNIC, + Loop: loop, } - if r.nic.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 { - if linkRes, ok := r.nic.stack.linkAddrResolvers[r.NetProto]; ok { + if r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 { + if linkRes, ok := r.outgoingNIC.stack.linkAddrResolvers[r.NetProto]; ok { r.linkRes = linkRes - r.linkCache = r.nic.stack + r.linkCache = r.outgoingNIC.stack } } return r } +// makeLocalRoute initializes a new local route. It takes ownership of the +// provided AssignableAddressEndpoint. +// +// A local route is a route to a destination that is local to the stack. +func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint) Route { + loop := PacketLoop + // TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the + // link endpoint level. We can remove this check once loopback interfaces + // loop back packets at the network layer. + if outgoingNIC.IsLoopback() { + loop = PacketOut + } + return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop) +} + +// PopulatePacketInfo populates a packet buffer's packet information fields. +// +// TODO(gvisor.dev/issue/4688): Remove this once network packets are handled by +// the network layer. +func (r *Route) PopulatePacketInfo(pkt *PacketBuffer) { + if r.local() { + pkt.RXTransportChecksumValidated = true + } + pkt.NetworkPacketInfo = r.networkPacketInfo() +} + +// networkPacketInfo returns the network packet information of the route. +// +// TODO(gvisor.dev/issue/4688): Remove this once network packets are handled by +// the network layer. +func (r *Route) networkPacketInfo() NetworkPacketInfo { + return NetworkPacketInfo{ + RemoteAddressBroadcast: r.IsOutboundBroadcast(), + LocalAddressBroadcast: r.isInboundBroadcast(), + } +} + // NICID returns the id of the NIC from which this route originates. func (r *Route) NICID() tcpip.NICID { - return r.nic.ID() + return r.outgoingNIC.ID() } // MaxHeaderLength forwards the call to the network endpoint's implementation. func (r *Route) MaxHeaderLength() uint16 { - return r.nic.getNetworkEndpoint(r.NetProto).MaxHeaderLength() + return r.outgoingNIC.getNetworkEndpoint(r.NetProto).MaxHeaderLength() } // Stats returns a mutable copy of current stats. func (r *Route) Stats() tcpip.Stats { - return r.nic.stack.Stats() + return r.outgoingNIC.stack.Stats() } // PseudoHeaderChecksum forwards the call to the network endpoint's @@ -113,14 +213,38 @@ func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, tot return header.PseudoHeaderChecksum(protocol, r.LocalAddress, r.RemoteAddress, totalLen) } -// Capabilities returns the link-layer capabilities of the route. -func (r *Route) Capabilities() LinkEndpointCapabilities { - return r.nic.LinkEndpoint.Capabilities() +// RequiresTXTransportChecksum returns false if the route does not require +// transport checksums to be populated. +func (r *Route) RequiresTXTransportChecksum() bool { + if r.local() { + return false + } + return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityTXChecksumOffload == 0 +} + +// HasSoftwareGSOCapability returns true if the route supports software GSO. +func (r *Route) HasSoftwareGSOCapability() bool { + return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilitySoftwareGSO != 0 +} + +// HasHardwareGSOCapability returns true if the route supports hardware GSO. +func (r *Route) HasHardwareGSOCapability() bool { + return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityHardwareGSO != 0 +} + +// HasSaveRestoreCapability returns true if the route supports save/restore. +func (r *Route) HasSaveRestoreCapability() bool { + return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilitySaveRestore != 0 +} + +// HasDisconncetOkCapability returns true if the route supports disconnecting. +func (r *Route) HasDisconncetOkCapability() bool { + return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityDisconnectOk != 0 } // GSOMaxSize returns the maximum GSO packet size. func (r *Route) GSOMaxSize() uint32 { - if gso, ok := r.nic.LinkEndpoint.(GSOEndpoint); ok { + if gso, ok := r.outgoingNIC.LinkEndpoint.(GSOEndpoint); ok { return gso.GSOMaxSize() } return 0 @@ -158,8 +282,15 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { nextAddr = r.RemoteAddress } - if neigh := r.nic.neigh; neigh != nil { - entry, ch, err := neigh.entry(nextAddr, r.LocalAddress, r.linkRes, waker) + // If specified, the local address used for link address resolution must be an + // address on the outgoing interface. + var linkAddressResolutionRequestLocalAddr tcpip.Address + if r.localAddressNIC == r.outgoingNIC { + linkAddressResolutionRequestLocalAddr = r.LocalAddress + } + + if neigh := r.outgoingNIC.neigh; neigh != nil { + entry, ch, err := neigh.entry(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, waker) if err != nil { return ch, err } @@ -167,7 +298,7 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { return nil, nil } - linkAddr, ch, err := r.linkCache.GetLinkAddress(r.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker) + linkAddr, ch, err := r.linkCache.GetLinkAddress(r.outgoingNIC.ID(), nextAddr, linkAddressResolutionRequestLocalAddr, r.NetProto, waker) if err != nil { return ch, err } @@ -182,76 +313,102 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) { nextAddr = r.RemoteAddress } - if neigh := r.nic.neigh; neigh != nil { + if neigh := r.outgoingNIC.neigh; neigh != nil { neigh.removeWaker(nextAddr, waker) return } - r.linkCache.RemoveWaker(r.nic.ID(), nextAddr, waker) + r.linkCache.RemoveWaker(r.outgoingNIC.ID(), nextAddr, waker) +} + +// local returns true if the route is a local route. +func (r *Route) local() bool { + return r.Loop == PacketLoop || r.outgoingNIC.IsLoopback() } // IsResolutionRequired returns true if Resolve() must be called to resolve -// the link address before the this route can be written to. +// the link address before the route can be written to. // -// The NIC r uses must not be locked. +// The NICs the route is associated with must not be locked. func (r *Route) IsResolutionRequired() bool { - if r.nic.neigh != nil { - return r.nic.isValidForOutgoing(r.addressEndpoint) && r.linkRes != nil && r.RemoteLinkAddress == "" + if !r.isValidForOutgoing() || r.RemoteLinkAddress != "" || r.local() { + return false } - return r.nic.isValidForOutgoing(r.addressEndpoint) && r.linkCache != nil && r.RemoteLinkAddress == "" + + return (r.outgoingNIC.neigh != nil && r.linkRes != nil) || r.linkCache != nil +} + +func (r *Route) isValidForOutgoing() bool { + if !r.outgoingNIC.Enabled() { + return false + } + + if !r.localAddressNIC.isValidForOutgoing(r.localAddressEndpoint) { + return false + } + + // If the source NIC and outgoing NIC are different, make sure the stack has + // forwarding enabled, or the packet will be handled locally. + if r.outgoingNIC != r.localAddressNIC && !r.outgoingNIC.stack.Forwarding(r.NetProto) && (!r.outgoingNIC.stack.handleLocal || !r.outgoingNIC.hasAddress(r.NetProto, r.RemoteAddress)) { + return false + } + + return true } // WritePacket writes the packet through the given route. func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { - if !r.nic.isValidForOutgoing(r.addressEndpoint) { + if !r.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } - return r.nic.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt) + return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt) } // WritePackets writes a list of n packets through the given route and returns // the number of packets written. func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) { - if !r.nic.isValidForOutgoing(r.addressEndpoint) { + if !r.isValidForOutgoing() { return 0, tcpip.ErrInvalidEndpointState } - return r.nic.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params) + return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params) } // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error { - if !r.nic.isValidForOutgoing(r.addressEndpoint) { + if !r.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } - return r.nic.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt) + return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt) } // DefaultTTL returns the default TTL of the underlying network endpoint. func (r *Route) DefaultTTL() uint8 { - return r.nic.getNetworkEndpoint(r.NetProto).DefaultTTL() + return r.outgoingNIC.getNetworkEndpoint(r.NetProto).DefaultTTL() } // MTU returns the MTU of the underlying network endpoint. func (r *Route) MTU() uint32 { - return r.nic.getNetworkEndpoint(r.NetProto).MTU() + return r.outgoingNIC.getNetworkEndpoint(r.NetProto).MTU() } // Release frees all resources associated with the route. func (r *Route) Release() { - if r.addressEndpoint != nil { - r.addressEndpoint.DecRef() - r.addressEndpoint = nil + if r.localAddressEndpoint != nil { + r.localAddressEndpoint.DecRef() + r.localAddressEndpoint = nil } } // Clone clones the route. func (r *Route) Clone() Route { - if r.addressEndpoint != nil { - _ = r.addressEndpoint.IncRef() + if r.localAddressEndpoint != nil { + if !r.localAddressEndpoint.IncRef() { + panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", r.LocalAddress)) + } } return *r } @@ -275,7 +432,7 @@ func (r *Route) MakeLoopedRoute() Route { // Stack returns the instance of the Stack that owns this route. func (r *Route) Stack() *Stack { - return r.nic.stack + return r.outgoingNIC.stack } func (r *Route) isV4Broadcast(addr tcpip.Address) bool { @@ -283,7 +440,7 @@ func (r *Route) isV4Broadcast(addr tcpip.Address) bool { return true } - subnet := r.addressEndpoint.AddressWithPrefix().Subnet() + subnet := r.localAddressEndpoint.Subnet() return subnet.IsBroadcast(addr) } @@ -294,9 +451,9 @@ func (r *Route) IsOutboundBroadcast() bool { return r.isV4Broadcast(r.RemoteAddress) } -// IsInboundBroadcast returns true if the route is for an inbound broadcast +// isInboundBroadcast returns true if the route is for an inbound broadcast // packet. -func (r *Route) IsInboundBroadcast() bool { +func (r *Route) isInboundBroadcast() bool { // Only IPv4 has a notion of broadcast. return r.isV4Broadcast(r.LocalAddress) } @@ -304,15 +461,16 @@ func (r *Route) IsInboundBroadcast() bool { // ReverseRoute returns new route with given source and destination address. func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route { return Route{ - NetProto: r.NetProto, - LocalAddress: dst, - LocalLinkAddress: r.RemoteLinkAddress, - RemoteAddress: src, - RemoteLinkAddress: r.LocalLinkAddress, - Loop: r.Loop, - addressEndpoint: r.addressEndpoint, - nic: r.nic, - linkCache: r.linkCache, - linkRes: r.linkRes, + NetProto: r.NetProto, + LocalAddress: dst, + LocalLinkAddress: r.RemoteLinkAddress, + RemoteAddress: src, + RemoteLinkAddress: r.LocalLinkAddress, + Loop: r.Loop, + localAddressNIC: r.localAddressNIC, + localAddressEndpoint: r.localAddressEndpoint, + outgoingNIC: r.outgoingNIC, + linkCache: r.linkCache, + linkRes: r.linkRes, } } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 3a07577c8..a23fb97ff 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -22,6 +22,7 @@ package stack import ( "bytes" "encoding/binary" + "fmt" mathrand "math/rand" "sync/atomic" "time" @@ -52,7 +53,7 @@ const ( type transportProtocolState struct { proto TransportProtocol - defaultHandler func(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool + defaultHandler func(id TransportEndpointID, pkt *PacketBuffer) bool } // TCPProbeFunc is the expected function type for a TCP probe function to be @@ -518,6 +519,10 @@ type Options struct { // // RandSource must be thread-safe. RandSource mathrand.Source + + // IPTables are the initial iptables rules. If nil, iptables will allow + // all traffic. + IPTables *IPTables } // TransportEndpointInfo holds useful information about a transport endpoint @@ -620,6 +625,10 @@ func New(opts Options) *Stack { randSrc = &lockedRandomSource{src: mathrand.NewSource(generateRandInt64())} } + if opts.IPTables == nil { + opts.IPTables = DefaultTables() + } + opts.NUDConfigs.resetInvalidFields() s := &Stack{ @@ -633,7 +642,7 @@ func New(opts Options) *Stack { clock: clock, stats: opts.Stats.FillIn(), handleLocal: opts.HandleLocal, - tables: DefaultTables(), + tables: opts.IPTables, icmpRateLimiter: NewICMPRateLimiter(), seed: generateRandUint32(), nudConfigs: opts.NUDConfigs, @@ -751,7 +760,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, // // It must be called only during initialization of the stack. Changing it as the // stack is operating is not supported. -func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, *PacketBuffer) bool) { +func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(TransportEndpointID, *PacketBuffer) bool) { state := s.transportProtocols[p] if state != nil { state.defaultHandler = h @@ -830,6 +839,20 @@ func (s *Stack) AddRoute(route tcpip.Route) { s.routeTable = append(s.routeTable, route) } +// RemoveRoutes removes matching routes from the route table. +func (s *Stack) RemoveRoutes(match func(tcpip.Route) bool) { + s.mu.Lock() + defer s.mu.Unlock() + + var filteredRoutes []tcpip.Route + for _, route := range s.routeTable { + if !match(route) { + filteredRoutes = append(filteredRoutes, route) + } + } + s.routeTable = filteredRoutes +} + // NewEndpoint creates a new transport layer endpoint of the given protocol. func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { t, ok := s.transportProtocols[transport] @@ -1180,54 +1203,225 @@ func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netP return nic.findEndpoint(netProto, localAddr, CanBePrimaryEndpoint) } +// findLocalRouteFromNICRLocked is like findLocalRouteRLocked but finds a route +// from the specified NIC. +// +// Precondition: s.mu must be read locked. +func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) { + localAddressEndpoint := localAddressNIC.getAddressOrCreateTempInner(netProto, localAddr, false /* createTemp */, NeverPrimaryEndpoint) + if localAddressEndpoint == nil { + return Route{}, false + } + + var outgoingNIC *NIC + // Prefer a local route to the same interface as the local address. + if localAddressNIC.hasAddress(netProto, remoteAddr) { + outgoingNIC = localAddressNIC + } + + // If the remote address isn't owned by the local address's NIC, check all + // NICs. + if outgoingNIC == nil { + for _, nic := range s.nics { + if nic.hasAddress(netProto, remoteAddr) { + outgoingNIC = nic + break + } + } + } + + // If the remote address is not owned by the stack, we can't return a local + // route. + if outgoingNIC == nil { + localAddressEndpoint.DecRef() + return Route{}, false + } + + r := makeLocalRoute( + netProto, + localAddressEndpoint.AddressWithPrefix().Address, + remoteAddr, + outgoingNIC, + localAddressNIC, + localAddressEndpoint, + ) + + if r.IsOutboundBroadcast() { + r.Release() + return Route{}, false + } + + return r, true +} + +// findLocalRouteRLocked returns a local route. +// +// A local route is a route to some remote address which the stack owns. That +// is, a local route is a route where packets never have to leave the stack. +// +// Precondition: s.mu must be read locked. +func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) { + if len(localAddr) == 0 { + localAddr = remoteAddr + } + + if localAddressNICID == 0 { + for _, localAddressNIC := range s.nics { + if r, ok := s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto); ok { + return r, true + } + } + + return Route{}, false + } + + if localAddressNIC, ok := s.nics[localAddressNICID]; ok { + return s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto) + } + + return Route{}, false +} + // FindRoute creates a route to the given destination address, leaving through -// the given nic and local address (if provided). +// the given NIC and local address (if provided). +// +// If a NIC is not specified, the returned route will leave through the same +// NIC as the NIC that has the local address assigned when forwarding is +// disabled. If forwarding is enabled and the NIC is unspecified, the route may +// leave through any interface unless the route is link-local. +// +// If no local address is provided, the stack will select a local address. If no +// remote address is provided, the stack wil use a remote address equal to the +// local address. func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (Route, *tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() + isLinkLocal := header.IsV6LinkLocalAddress(remoteAddr) || header.IsV6LinkLocalMulticastAddress(remoteAddr) isLocalBroadcast := remoteAddr == header.IPv4Broadcast isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr) - needRoute := !(isLocalBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr)) + isLoopback := header.IsV4LoopbackAddress(remoteAddr) || header.IsV6LoopbackAddress(remoteAddr) + needRoute := !(isLocalBroadcast || isMulticast || isLinkLocal || isLoopback) + + if s.handleLocal && !isMulticast && !isLocalBroadcast { + if r, ok := s.findLocalRouteRLocked(id, localAddr, remoteAddr, netProto); ok { + return r, nil + } + } + + // If the interface is specified and we do not need a route, return a route + // through the interface if the interface is valid and enabled. if id != 0 && !needRoute { if nic, ok := s.nics[id]; ok && nic.Enabled() { if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { - return makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback()), nil + return makeRoute( + netProto, + addressEndpoint.AddressWithPrefix().Address, + remoteAddr, + nic, /* outboundNIC */ + nic, /* localAddressNIC*/ + addressEndpoint, + s.handleLocal, + multicastLoop, + ), nil } } - } else { - for _, route := range s.routeTable { - if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) { - continue + + if isLoopback { + return Route{}, tcpip.ErrBadLocalAddress + } + return Route{}, tcpip.ErrNetworkUnreachable + } + + canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal + + // Find a route to the remote with the route table. + var chosenRoute tcpip.Route + for _, route := range s.routeTable { + if len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr) { + continue + } + + nic, ok := s.nics[route.NIC] + if !ok || !nic.Enabled() { + continue + } + + if id == 0 || id == route.NIC { + if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { + var gateway tcpip.Address + if needRoute { + gateway = route.Gateway + } + r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop) + if r == (Route{}) { + panic(fmt.Sprintf("non-forwarding route validation failed with route table entry = %#v, id = %d, localAddr = %s, remoteAddr = %s", route, id, localAddr, remoteAddr)) + } + return r, nil } - if nic, ok := s.nics[route.NIC]; ok && nic.Enabled() { - if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { - if len(remoteAddr) == 0 { - // If no remote address was provided, then the route - // provided will refer to the link local address. - remoteAddr = addressEndpoint.AddressWithPrefix().Address - } + } - r := makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback()) - if len(route.Gateway) > 0 { - if needRoute { - r.NextHop = route.Gateway - } - } else if subnet := addressEndpoint.AddressWithPrefix().Subnet(); subnet.IsBroadcast(remoteAddr) { - r.RemoteLinkAddress = header.EthernetBroadcastAddress + // If the stack has forwarding enabled and we haven't found a valid route to + // the remote address yet, keep track of the first valid route. We keep + // iterating because we prefer routes that let us use a local address that + // is assigned to the outgoing interface. There is no requirement to do this + // from any RFC but simply a choice made to better follow a strong host + // model which the netstack follows at the time of writing. + if canForward && chosenRoute == (tcpip.Route{}) { + chosenRoute = route + } + } + + if chosenRoute != (tcpip.Route{}) { + // At this point we know the stack has forwarding enabled since chosenRoute is + // only set when forwarding is enabled. + nic, ok := s.nics[chosenRoute.NIC] + if !ok { + // If the route's NIC was invalid, we should not have chosen the route. + panic(fmt.Sprintf("chosen route must have a valid NIC with ID = %d", chosenRoute.NIC)) + } + + var gateway tcpip.Address + if needRoute { + gateway = chosenRoute.Gateway + } + + // Use the specified NIC to get the local address endpoint. + if id != 0 { + if aNIC, ok := s.nics[id]; ok { + if addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto); addressEndpoint != nil { + if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) { + return r, nil } + } + } + + return Route{}, tcpip.ErrNoRoute + } + if id == 0 { + // If an interface is not specified, try to find a NIC that holds the local + // address endpoint to construct a route. + for _, aNIC := range s.nics { + addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto) + if addressEndpoint == nil { + continue + } + + if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) { return r, nil } } } } - if !needRoute { - return Route{}, tcpip.ErrNetworkUnreachable + if needRoute { + return Route{}, tcpip.ErrNoRoute } - - return Route{}, tcpip.ErrNoRoute + if isLoopback { + return Route{}, tcpip.ErrBadLocalAddress + } + return Route{}, tcpip.ErrNetworkUnreachable } // CheckNetworkProtocol checks if a given network protocol is enabled in the @@ -1323,7 +1517,7 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} linkRes := s.linkAddrResolvers[protocol] - return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.LinkEndpoint, waker) + return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic, waker) } // Neighbors returns all IP to MAC address associations. @@ -1443,8 +1637,8 @@ func (s *Stack) CompleteTransportEndpointCleanup(ep TransportEndpoint) { // FindTransportEndpoint finds an endpoint that most closely matches the provided // id. If no endpoint is found it returns nil. -func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint { - return s.demux.findTransportEndpoint(netProto, transProto, id, r) +func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, nicID tcpip.NICID) TransportEndpoint { + return s.demux.findTransportEndpoint(netProto, transProto, id, nicID) } // RegisterRawTransportEndpoint registers the given endpoint with the stack @@ -1896,3 +2090,71 @@ func (s *Stack) FindNICNameFromID(id tcpip.NICID) string { func (s *Stack) NewJob(l sync.Locker, f func()) *tcpip.Job { return tcpip.NewJob(s.clock, l, f) } + +// ParseResult indicates the result of a parsing attempt. +type ParseResult int + +const ( + // ParsedOK indicates that a packet was successfully parsed. + ParsedOK ParseResult = iota + + // UnknownNetworkProtocol indicates that the network protocol is unknown. + UnknownNetworkProtocol + + // NetworkLayerParseError indicates that the network packet was not + // successfully parsed. + NetworkLayerParseError + + // UnknownTransportProtocol indicates that the transport protocol is unknown. + UnknownTransportProtocol + + // TransportLayerParseError indicates that the transport packet was not + // successfully parsed. + TransportLayerParseError +) + +// ParsePacketBuffer parses the provided packet buffer. +func (s *Stack) ParsePacketBuffer(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) ParseResult { + netProto, ok := s.networkProtocols[protocol] + if !ok { + return UnknownNetworkProtocol + } + + transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt) + if !ok { + return NetworkLayerParseError + } + if !hasTransportHdr { + return ParsedOK + } + + // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader + // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a + // full explanation. + if transProtoNum == header.ICMPv4ProtocolNumber || transProtoNum == header.ICMPv6ProtocolNumber { + return ParsedOK + } + + pkt.TransportProtocolNumber = transProtoNum + // Parse the transport header if present. + state, ok := s.transportProtocols[transProtoNum] + if !ok { + return UnknownTransportProtocol + } + + if !state.proto.Parse(pkt) { + return TransportLayerParseError + } + + return ParsedOK +} + +// networkProtocolNumbers returns the network protocol numbers the stack is +// configured with. +func (s *Stack) networkProtocolNumbers() []tcpip.NetworkProtocolNumber { + protos := make([]tcpip.NetworkProtocolNumber, 0, len(s.networkProtocols)) + for p := range s.networkProtocols { + protos = append(protos, p) + } + return protos +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index e75f58c64..dedfdd435 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -21,6 +21,7 @@ import ( "bytes" "fmt" "math" + "net" "sort" "testing" "time" @@ -108,12 +109,13 @@ func (*fakeNetworkEndpoint) DefaultTTL() uint8 { return 123 } -func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { +func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { // Increment the received packet count in the protocol descriptor. - f.proto.packetCount[int(r.LocalAddress[0])%len(f.proto.packetCount)]++ + netHdr := pkt.NetworkHeader().View() + f.proto.packetCount[int(netHdr[dstAddrOffset])%len(f.proto.packetCount)]++ // Handle control packets. - if pkt.NetworkHeader().View()[protocolNumberOffset] == uint8(fakeControlProtocol) { + if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) { nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) if !ok { return @@ -129,7 +131,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuff } // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) + f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { @@ -151,12 +153,15 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params // Add the protocol's header to the packet and send it to the link // endpoint. hdr := pkt.NetworkHeader().Push(fakeNetHeaderLen) + pkt.NetworkProtocolNumber = fakeNetNumber hdr[dstAddrOffset] = r.RemoteAddress[0] hdr[srcAddrOffset] = r.LocalAddress[0] hdr[protocolNumberOffset] = byte(params.Protocol) if r.Loop&stack.PacketLoop != 0 { - f.HandlePacket(r, pkt) + pkt := pkt.Clone() + r.PopulatePacketInfo(pkt) + f.HandlePacket(pkt) } if r.Loop&stack.PacketOut == 0 { return nil @@ -254,6 +259,7 @@ func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProto if !ok { return 0, false, false } + pkt.NetworkProtocolNumber = fakeNetNumber return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true } @@ -1334,6 +1340,106 @@ func TestPromiscuousMode(t *testing.T) { testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } +// TestExternalSendWithHandleLocal tests that the stack creates a non-local +// route when spoofing or promiscuous mode are enabled. +// +// This test makes sure that packets are transmitted from the stack. +func TestExternalSendWithHandleLocal(t *testing.T) { + const ( + unspecifiedNICID = 0 + nicID = 1 + + localAddr = tcpip.Address("\x01") + dstAddr = tcpip.Address("\x03") + ) + + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + + tests := []struct { + name string + configureStack func(*testing.T, *stack.Stack) + }{ + { + name: "Default", + configureStack: func(*testing.T, *stack.Stack) {}, + }, + { + name: "Spoofing", + configureStack: func(t *testing.T, s *stack.Stack) { + if err := s.SetSpoofing(nicID, true); err != nil { + t.Fatalf("s.SetSpoofing(%d, true): %s", nicID, err) + } + }, + }, + { + name: "Promiscuous", + configureStack: func(t *testing.T, s *stack.Stack) { + if err := s.SetPromiscuousMode(nicID, true); err != nil { + t.Fatalf("s.SetPromiscuousMode(%d, true): %s", nicID, err) + } + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, handleLocal := range []bool{true, false} { + t.Run(fmt.Sprintf("HandleLocal=%t", handleLocal), func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + HandleLocal: handleLocal, + }) + + ep := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID, ep); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err) + } + + s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: nicID}}) + + test.configureStack(t, s) + + r, err := s.FindRoute(unspecifiedNICID, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", unspecifiedNICID, localAddr, dstAddr, fakeNetNumber, err) + } + defer r.Release() + + if r.LocalAddress != localAddr { + t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, localAddr) + } + if r.RemoteAddress != dstAddr { + t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr) + } + + if n := ep.Drain(); n != 0 { + t.Fatalf("got ep.Drain() = %d, want = 0", n) + } + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + Protocol: fakeTransNumber, + TTL: 123, + TOS: stack.DefaultTOS, + }, stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: buffer.NewView(10).ToVectorisedView(), + })); err != nil { + t.Fatalf("r.WritePacket(nil, _, _): %s", err) + } + if n := ep.Drain(); n != 1 { + t.Fatalf("got ep.Drain() = %d, want = 1", n) + } + }) + } + }) + } +} + func TestSpoofingWithAddress(t *testing.T) { localAddr := tcpip.Address("\x01") nonExistentLocalAddr := tcpip.Address("\x02") @@ -3346,7 +3452,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { RemoteAddress: ipv4SubnetBcast, RemoteLinkAddress: header.EthernetBroadcastAddress, NetProto: header.IPv4ProtocolNumber, - Loop: stack.PacketOut, + Loop: stack.PacketOut | stack.PacketLoop, }, }, // Broadcast to a locally attached /31 subnet does not populate the @@ -3672,3 +3778,453 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) { t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix) } } + +// TestAddRoute tests Stack.AddRoute +func TestAddRoute(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{}) + + subnet1, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + + subnet2, err := tcpip.NewSubnet("\x01", "\x01") + if err != nil { + t.Fatal(err) + } + + expected := []tcpip.Route{ + {Destination: subnet1, Gateway: "\x00", NIC: 1}, + {Destination: subnet2, Gateway: "\x00", NIC: 1}, + } + + // Initialize the route table with one route. + s.SetRouteTable([]tcpip.Route{expected[0]}) + + // Add another route. + s.AddRoute(expected[1]) + + rt := s.GetRouteTable() + if got, want := len(rt), len(expected); got != want { + t.Fatalf("Unexpected route table length got = %d, want = %d", got, want) + } + for i, route := range rt { + if got, want := route, expected[i]; got != want { + t.Fatalf("Unexpected route got = %#v, want = %#v", got, want) + } + } +} + +// TestRemoveRoutes tests Stack.RemoveRoutes +func TestRemoveRoutes(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{}) + + addressToRemove := tcpip.Address("\x01") + subnet1, err := tcpip.NewSubnet(addressToRemove, "\x01") + if err != nil { + t.Fatal(err) + } + + subnet2, err := tcpip.NewSubnet(addressToRemove, "\x01") + if err != nil { + t.Fatal(err) + } + + subnet3, err := tcpip.NewSubnet("\x02", "\x02") + if err != nil { + t.Fatal(err) + } + + // Initialize the route table with three routes. + s.SetRouteTable([]tcpip.Route{ + {Destination: subnet1, Gateway: "\x00", NIC: 1}, + {Destination: subnet2, Gateway: "\x00", NIC: 1}, + {Destination: subnet3, Gateway: "\x00", NIC: 1}, + }) + + // Remove routes with the specific address. + s.RemoveRoutes(func(r tcpip.Route) bool { + return r.Destination.ID() == addressToRemove + }) + + expected := []tcpip.Route{{Destination: subnet3, Gateway: "\x00", NIC: 1}} + rt := s.GetRouteTable() + if got, want := len(rt), len(expected); got != want { + t.Fatalf("Unexpected route table length got = %d, want = %d", got, want) + } + for i, route := range rt { + if got, want := route, expected[i]; got != want { + t.Fatalf("Unexpected route got = %#v, want = %#v", got, want) + } + } +} + +func TestFindRouteWithForwarding(t *testing.T) { + const ( + nicID1 = 1 + nicID2 = 2 + + nic1Addr = tcpip.Address("\x01") + nic2Addr = tcpip.Address("\x02") + remoteAddr = tcpip.Address("\x03") + ) + + type netCfg struct { + proto tcpip.NetworkProtocolNumber + factory stack.NetworkProtocolFactory + nic1Addr tcpip.Address + nic2Addr tcpip.Address + remoteAddr tcpip.Address + } + + fakeNetCfg := netCfg{ + proto: fakeNetNumber, + factory: fakeNetFactory, + nic1Addr: nic1Addr, + nic2Addr: nic2Addr, + remoteAddr: remoteAddr, + } + + globalIPv6Addr1 := tcpip.Address(net.ParseIP("a::1").To16()) + globalIPv6Addr2 := tcpip.Address(net.ParseIP("a::2").To16()) + + ipv6LinkLocalNIC1WithGlobalRemote := netCfg{ + proto: ipv6.ProtocolNumber, + factory: ipv6.NewProtocol, + nic1Addr: llAddr1, + nic2Addr: globalIPv6Addr2, + remoteAddr: globalIPv6Addr1, + } + ipv6GlobalNIC1WithLinkLocalRemote := netCfg{ + proto: ipv6.ProtocolNumber, + factory: ipv6.NewProtocol, + nic1Addr: globalIPv6Addr1, + nic2Addr: llAddr1, + remoteAddr: llAddr2, + } + ipv6GlobalNIC1WithLinkLocalMulticastRemote := netCfg{ + proto: ipv6.ProtocolNumber, + factory: ipv6.NewProtocol, + nic1Addr: globalIPv6Addr1, + nic2Addr: globalIPv6Addr2, + remoteAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + } + + tests := []struct { + name string + + netCfg netCfg + forwardingEnabled bool + + addrNIC tcpip.NICID + localAddr tcpip.Address + + findRouteErr *tcpip.Error + dependentOnForwarding bool + }{ + { + name: "forwarding disabled and localAddr not on specified NIC but route from different NIC", + netCfg: fakeNetCfg, + forwardingEnabled: false, + addrNIC: nicID1, + localAddr: fakeNetCfg.nic2Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and localAddr not on specified NIC but route from different NIC", + netCfg: fakeNetCfg, + forwardingEnabled: true, + addrNIC: nicID1, + localAddr: fakeNetCfg.nic2Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and localAddr on specified NIC but route from different NIC", + netCfg: fakeNetCfg, + forwardingEnabled: false, + addrNIC: nicID1, + localAddr: fakeNetCfg.nic1Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and localAddr on specified NIC but route from different NIC", + netCfg: fakeNetCfg, + forwardingEnabled: true, + addrNIC: nicID1, + localAddr: fakeNetCfg.nic1Addr, + findRouteErr: nil, + dependentOnForwarding: true, + }, + { + name: "forwarding disabled and localAddr on specified NIC and route from same NIC", + netCfg: fakeNetCfg, + forwardingEnabled: false, + addrNIC: nicID2, + localAddr: fakeNetCfg.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and localAddr on specified NIC and route from same NIC", + netCfg: fakeNetCfg, + forwardingEnabled: true, + addrNIC: nicID2, + localAddr: fakeNetCfg.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and localAddr not on specified NIC but route from same NIC", + netCfg: fakeNetCfg, + forwardingEnabled: false, + addrNIC: nicID2, + localAddr: fakeNetCfg.nic1Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and localAddr not on specified NIC but route from same NIC", + netCfg: fakeNetCfg, + forwardingEnabled: true, + addrNIC: nicID2, + localAddr: fakeNetCfg.nic1Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and localAddr on same NIC as route", + netCfg: fakeNetCfg, + forwardingEnabled: false, + localAddr: fakeNetCfg.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and localAddr on same NIC as route", + netCfg: fakeNetCfg, + forwardingEnabled: false, + localAddr: fakeNetCfg.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and localAddr on different NIC as route", + netCfg: fakeNetCfg, + forwardingEnabled: false, + localAddr: fakeNetCfg.nic1Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and localAddr on different NIC as route", + netCfg: fakeNetCfg, + forwardingEnabled: true, + localAddr: fakeNetCfg.nic1Addr, + findRouteErr: nil, + dependentOnForwarding: true, + }, + { + name: "forwarding disabled and specified NIC only has link-local addr with route on different NIC", + netCfg: ipv6LinkLocalNIC1WithGlobalRemote, + forwardingEnabled: false, + addrNIC: nicID1, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and specified NIC only has link-local addr with route on different NIC", + netCfg: ipv6LinkLocalNIC1WithGlobalRemote, + forwardingEnabled: true, + addrNIC: nicID1, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and link-local local addr with route on different NIC", + netCfg: ipv6LinkLocalNIC1WithGlobalRemote, + forwardingEnabled: false, + localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and link-local local addr with route on same NIC", + netCfg: ipv6LinkLocalNIC1WithGlobalRemote, + forwardingEnabled: true, + localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and global local addr with route on same NIC", + netCfg: ipv6LinkLocalNIC1WithGlobalRemote, + forwardingEnabled: true, + localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and link-local local addr with route on same NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalRemote, + forwardingEnabled: false, + localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and link-local local addr with route on same NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalRemote, + forwardingEnabled: true, + localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and global local addr with link-local remote on different NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalRemote, + forwardingEnabled: false, + localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr, + findRouteErr: tcpip.ErrNetworkUnreachable, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and global local addr with link-local remote on different NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalRemote, + forwardingEnabled: true, + localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr, + findRouteErr: tcpip.ErrNetworkUnreachable, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and global local addr with link-local multicast remote on different NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, + forwardingEnabled: false, + localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr, + findRouteErr: tcpip.ErrNetworkUnreachable, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and global local addr with link-local multicast remote on different NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, + forwardingEnabled: true, + localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr, + findRouteErr: tcpip.ErrNetworkUnreachable, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and global local addr with link-local multicast remote on same NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, + forwardingEnabled: false, + localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and global local addr with link-local multicast remote on same NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, + forwardingEnabled: true, + localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{test.netCfg.factory}, + }) + + ep1 := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID1, ep1); err != nil { + t.Fatalf("CreateNIC(%d, _): %s:", nicID1, err) + } + + ep2 := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID2, ep2); err != nil { + t.Fatalf("CreateNIC(%d, _): %s:", nicID2, err) + } + + if err := s.AddAddress(nicID1, test.netCfg.proto, test.netCfg.nic1Addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, test.netCfg.proto, test.netCfg.nic1Addr, err) + } + + if err := s.AddAddress(nicID2, test.netCfg.proto, test.netCfg.nic2Addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, test.netCfg.proto, test.netCfg.nic2Addr, err) + } + + if err := s.SetForwarding(test.netCfg.proto, test.forwardingEnabled); err != nil { + t.Fatalf("SetForwarding(%d, %t): %s", test.netCfg.proto, test.forwardingEnabled, err) + } + + s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}}) + + r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */) + if err != test.findRouteErr { + t.Fatalf("FindRoute(%d, %s, %s, %d, false) = %s, want = %s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, err, test.findRouteErr) + } + defer r.Release() + + if test.findRouteErr != nil { + return + } + + if r.LocalAddress != test.localAddr { + t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, test.localAddr) + } + if r.RemoteAddress != test.netCfg.remoteAddr { + t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress, test.netCfg.remoteAddr) + } + + if t.Failed() { + t.FailNow() + } + + // Sending a packet should always go through NIC2 since we only install a + // route to test.netCfg.remoteAddr through NIC2. + data := buffer.View([]byte{1, 2, 3, 4}) + if err := send(r, data); err != nil { + t.Fatalf("send(_, _): %s", err) + } + if n := ep1.Drain(); n != 0 { + t.Errorf("got %d unexpected packets from ep1", n) + } + pkt, ok := ep2.Read() + if !ok { + t.Fatal("packet not sent through ep2") + } + if pkt.Route.LocalAddress != test.localAddr { + t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.localAddr) + } + if pkt.Route.RemoteAddress != test.netCfg.remoteAddr { + t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.netCfg.remoteAddr) + } + + if !test.forwardingEnabled || !test.dependentOnForwarding { + return + } + + // Disabling forwarding when the route is dependent on forwarding being + // enabled should make the route invalid. + if err := s.SetForwarding(test.netCfg.proto, false); err != nil { + t.Fatalf("SetForwarding(%d, false): %s", test.netCfg.proto, err) + } + if err := send(r, data); err != tcpip.ErrInvalidEndpointState { + t.Fatalf("got send(_, _) = %s, want = %s", err, tcpip.ErrInvalidEndpointState) + } + if n := ep1.Drain(); n != 0 { + t.Errorf("got %d unexpected packets from ep1", n) + } + if n := ep2.Drain(); n != 0 { + t.Errorf("got %d unexpected packets from ep2", n) + } + }) + } +} diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 35e5b1a2e..f183ec6e4 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -152,10 +152,10 @@ func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) { +func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *PacketBuffer) { epsByNIC.mu.RLock() - mpep, ok := epsByNIC.endpoints[r.nic.ID()] + mpep, ok := epsByNIC.endpoints[pkt.NICID] if !ok { if mpep, ok = epsByNIC.endpoints[0]; !ok { epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. @@ -165,20 +165,20 @@ func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, p // If this is a broadcast or multicast datagram, deliver the datagram to all // endpoints bound to the right device. - if isInboundMulticastOrBroadcast(r) { - mpep.handlePacketAll(r, id, pkt) + if isInboundMulticastOrBroadcast(pkt, id.LocalAddress) { + mpep.handlePacketAll(id, pkt) epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. return } // multiPortEndpoints are guaranteed to have at least one element. transEP := selectEndpoint(id, mpep, epsByNIC.seed) if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue { - queuedProtocol.QueuePacket(r, transEP, id, pkt) + queuedProtocol.QueuePacket(transEP, id, pkt) epsByNIC.mu.RUnlock() return } - transEP.HandlePacket(r, id, pkt) + transEP.HandlePacket(id, pkt) epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. } @@ -253,6 +253,8 @@ func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t T // based on endpoints IDs. It should only be instantiated via // newTransportDemuxer. type transportDemuxer struct { + stack *Stack + // protocol is immutable. protocol map[protocolIDs]*transportEndpoints queuedProtocols map[protocolIDs]queuedTransportProtocol @@ -262,11 +264,12 @@ type transportDemuxer struct { // the dispatcher to delivery packets to the QueuePacket method instead of // calling HandlePacket directly on the endpoint. type queuedTransportProtocol interface { - QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer) + QueuePacket(ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer) } func newTransportDemuxer(stack *Stack) *transportDemuxer { d := &transportDemuxer{ + stack: stack, protocol: make(map[protocolIDs]*transportEndpoints), queuedProtocols: make(map[protocolIDs]queuedTransportProtocol), } @@ -377,22 +380,22 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 return mpep.endpoints[idx] } -func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt *PacketBuffer) { +func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *PacketBuffer) { ep.mu.RLock() queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}] // HandlePacket takes ownership of pkt, so each endpoint needs // its own copy except for the final one. for _, endpoint := range ep.endpoints[:len(ep.endpoints)-1] { if mustQueue { - queuedProtocol.QueuePacket(r, endpoint, id, pkt.Clone()) + queuedProtocol.QueuePacket(endpoint, id, pkt.Clone()) } else { - endpoint.HandlePacket(r, id, pkt.Clone()) + endpoint.HandlePacket(id, pkt.Clone()) } } if endpoint := ep.endpoints[len(ep.endpoints)-1]; mustQueue { - queuedProtocol.QueuePacket(r, endpoint, id, pkt) + queuedProtocol.QueuePacket(endpoint, id, pkt) } else { - endpoint.HandlePacket(r, id, pkt) + endpoint.HandlePacket(id, pkt) } ep.mu.RUnlock() // Don't use defer for performance reasons. } @@ -518,29 +521,29 @@ func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolN // deliverPacket attempts to find one or more matching transport endpoints, and // then, if matches are found, delivers the packet to them. Returns true if // the packet no longer needs to be handled. -func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool { - eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] +func (d *transportDemuxer) deliverPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool { + eps, ok := d.protocol[protocolIDs{pkt.NetworkProtocolNumber, protocol}] if !ok { return false } // If the packet is a UDP broadcast or multicast, then find all matching // transport endpoints. - if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(r) { + if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(pkt, id.LocalAddress) { eps.mu.RLock() destEPs := eps.findAllEndpointsLocked(id) eps.mu.RUnlock() // Fail if we didn't find at least one matching transport endpoint. if len(destEPs) == 0 { - r.Stats().UDP.UnknownPortErrors.Increment() + d.stack.stats.UDP.UnknownPortErrors.Increment() return false } // handlePacket takes ownership of pkt, so each endpoint needs its own // copy except for the final one. for _, ep := range destEPs[:len(destEPs)-1] { - ep.handlePacket(r, id, pkt.Clone()) + ep.handlePacket(id, pkt.Clone()) } - destEPs[len(destEPs)-1].handlePacket(r, id, pkt) + destEPs[len(destEPs)-1].handlePacket(id, pkt) return true } @@ -548,10 +551,10 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // destination address, then do nothing further and instruct the caller to do // the same. The network layer handles address validation for specified source // addresses. - if protocol == header.TCPProtocolNumber && (!isSpecified(r.LocalAddress) || !isSpecified(r.RemoteAddress) || isInboundMulticastOrBroadcast(r)) { + if protocol == header.TCPProtocolNumber && (!isSpecified(id.LocalAddress) || !isSpecified(id.RemoteAddress) || isInboundMulticastOrBroadcast(pkt, id.LocalAddress)) { // TCP can only be used to communicate between a single source and a - // single destination; the addresses must be unicast. - r.Stats().TCP.InvalidSegmentsReceived.Increment() + // single destination; the addresses must be unicast.e + d.stack.stats.TCP.InvalidSegmentsReceived.Increment() return true } @@ -560,18 +563,18 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto eps.mu.RUnlock() if ep == nil { if protocol == header.UDPProtocolNumber { - r.Stats().UDP.UnknownPortErrors.Increment() + d.stack.stats.UDP.UnknownPortErrors.Increment() } return false } - ep.handlePacket(r, id, pkt) + ep.handlePacket(id, pkt) return true } // deliverRawPacket attempts to deliver the given packet and returns whether it // was delivered successfully. -func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool { - eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] +func (d *transportDemuxer) deliverRawPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool { + eps, ok := d.protocol[protocolIDs{pkt.NetworkProtocolNumber, protocol}] if !ok { return false } @@ -584,7 +587,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr for _, rawEP := range eps.rawEndpoints { // Each endpoint gets its own copy of the packet for the sake // of save/restore. - rawEP.HandlePacket(r, pkt) + rawEP.HandlePacket(pkt.Clone()) foundRaw = true } eps.mu.RUnlock() @@ -612,7 +615,7 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco } // findTransportEndpoint find a single endpoint that most closely matches the provided id. -func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint { +func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, nicID tcpip.NICID) TransportEndpoint { eps, ok := d.protocol[protocolIDs{netProto, transProto}] if !ok { return nil @@ -628,7 +631,7 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN epsByNIC.mu.RLock() eps.mu.RUnlock() - mpep, ok := epsByNIC.endpoints[r.nic.ID()] + mpep, ok := epsByNIC.endpoints[nicID] if !ok { if mpep, ok = epsByNIC.endpoints[0]; !ok { epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. @@ -679,8 +682,8 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN eps.mu.Unlock() } -func isInboundMulticastOrBroadcast(r *Route) bool { - return r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || header.IsV6MulticastAddress(r.LocalAddress) +func isInboundMulticastOrBroadcast(pkt *PacketBuffer, localAddr tcpip.Address) bool { + return pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(localAddr) || header.IsV6MulticastAddress(localAddr) } func isSpecified(addr tcpip.Address) bool { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 62ab6d92f..c457b67a2 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -28,7 +28,7 @@ import ( const ( fakeTransNumber tcpip.TransportProtocolNumber = 1 - fakeTransHeaderLen = 3 + fakeTransHeaderLen int = 3 ) // fakeTransportEndpoint is a transport-layer protocol endpoint. It counts @@ -213,20 +213,29 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro return tcpip.FullAddress{}, nil } -func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ *stack.PacketBuffer) { +func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Increment the number of received packets. f.proto.packetCount++ - if f.acceptQueue != nil { - f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{ - TransportEndpointInfo: stack.TransportEndpointInfo{ - ID: f.ID, - NetProto: f.NetProto, - }, - proto: f.proto, - peerAddr: r.RemoteAddress, - route: r.Clone(), - }) + if f.acceptQueue == nil { + return } + + netHdr := pkt.NetworkHeader().View() + route, err := f.proto.stack.FindRoute(pkt.NICID, tcpip.Address(netHdr[dstAddrOffset]), tcpip.Address(netHdr[srcAddrOffset]), pkt.NetworkProtocolNumber, false /* multicastLoop */) + if err != nil { + return + } + route.ResolveWith(pkt.SourceLinkAddress()) + + f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{ + TransportEndpointInfo: stack.TransportEndpointInfo{ + ID: f.ID, + NetProto: f.NetProto, + }, + proto: f.proto, + peerAddr: route.RemoteAddress, + route: route, + }) } func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, *stack.PacketBuffer) { @@ -288,7 +297,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp return 0, 0, nil } -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { +func (*fakeTransportProtocol) HandleUnknownDestinationPacket(stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { return stack.UnknownDestinationPacketHandled } |