diff options
Diffstat (limited to 'pkg/tcpip/stack')
25 files changed, 1568 insertions, 1287 deletions
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index cd423bf71..e5590ecc0 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -117,7 +117,7 @@ func (a *AddressableEndpointState) releaseAddressStateLocked(addrState *addressS } // AddAndAcquirePermanentAddress implements AddressableEndpoint. -func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, *tcpip.Error) { +func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, tcpip.Error) { a.mu.Lock() defer a.mu.Unlock() ep, err := a.addAndAcquireAddressLocked(addr, peb, configType, deprecated, true /* permanent */) @@ -143,10 +143,10 @@ func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.Addr // AddAndAcquireTemporaryAddress adds a temporary address. // -// Returns tcpip.ErrDuplicateAddress if the address exists. +// Returns *tcpip.ErrDuplicateAddress if the address exists. // // The temporary address's endpoint is acquired and returned. -func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) (AddressEndpoint, *tcpip.Error) { +func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) (AddressEndpoint, tcpip.Error) { a.mu.Lock() defer a.mu.Unlock() ep, err := a.addAndAcquireAddressLocked(addr, peb, AddressConfigStatic, false /* deprecated */, false /* permanent */) @@ -176,11 +176,11 @@ func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.Addr // If the addressable endpoint already has the address in a non-permanent state, // and addAndAcquireAddressLocked is adding a permanent address, that address is // promoted in place and its properties set to the properties provided. If the -// address already exists in any other state, then tcpip.ErrDuplicateAddress is +// address already exists in any other state, then *tcpip.ErrDuplicateAddress is // returned, regardless the kind of address that is being added. // // Precondition: a.mu must be write locked. -func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated, permanent bool) (*addressState, *tcpip.Error) { +func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated, permanent bool) (*addressState, tcpip.Error) { // attemptAddToPrimary is false when the address is already in the primary // address list. attemptAddToPrimary := true @@ -190,7 +190,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address // We are adding a non-permanent address but the address exists. No need // to go any further since we can only promote existing temporary/expired // addresses to permanent. - return nil, tcpip.ErrDuplicateAddress + return nil, &tcpip.ErrDuplicateAddress{} } addrState.mu.Lock() @@ -198,7 +198,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address addrState.mu.Unlock() // We are adding a permanent address but a permanent address already // exists. - return nil, tcpip.ErrDuplicateAddress + return nil, &tcpip.ErrDuplicateAddress{} } if addrState.mu.refs == 0 { @@ -293,7 +293,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address } // RemovePermanentAddress implements AddressableEndpoint. -func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error { +func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) tcpip.Error { a.mu.Lock() defer a.mu.Unlock() return a.removePermanentAddressLocked(addr) @@ -303,10 +303,10 @@ func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) *t // requirements. // // Precondition: a.mu must be write locked. -func (a *AddressableEndpointState) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { +func (a *AddressableEndpointState) removePermanentAddressLocked(addr tcpip.Address) tcpip.Error { addrState, ok := a.mu.endpoints[addr] if !ok { - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } return a.removePermanentEndpointLocked(addrState) @@ -314,10 +314,10 @@ func (a *AddressableEndpointState) removePermanentAddressLocked(addr tcpip.Addre // RemovePermanentEndpoint removes the passed endpoint if it is associated with // a and permanent. -func (a *AddressableEndpointState) RemovePermanentEndpoint(ep AddressEndpoint) *tcpip.Error { +func (a *AddressableEndpointState) RemovePermanentEndpoint(ep AddressEndpoint) tcpip.Error { addrState, ok := ep.(*addressState) if !ok || addrState.addressableEndpointState != a { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } a.mu.Lock() @@ -329,9 +329,9 @@ func (a *AddressableEndpointState) RemovePermanentEndpoint(ep AddressEndpoint) * // requirements. // // Precondition: a.mu must be write locked. -func (a *AddressableEndpointState) removePermanentEndpointLocked(addrState *addressState) *tcpip.Error { +func (a *AddressableEndpointState) removePermanentEndpointLocked(addrState *addressState) tcpip.Error { if !addrState.GetKind().IsPermanent() { - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } addrState.SetKind(PermanentExpired) @@ -574,9 +574,11 @@ func (a *AddressableEndpointState) Cleanup() { defer a.mu.Unlock() for _, ep := range a.mu.endpoints { - // removePermanentEndpointLocked returns tcpip.ErrBadLocalAddress if ep is + // removePermanentEndpointLocked returns *tcpip.ErrBadLocalAddress if ep is // not a permanent address. - if err := a.removePermanentEndpointLocked(ep); err != nil && err != tcpip.ErrBadLocalAddress { + switch err := a.removePermanentEndpointLocked(ep); err.(type) { + case nil, *tcpip.ErrBadLocalAddress: + default: panic(fmt.Sprintf("unexpected error from removePermanentEndpointLocked(%s): %s", ep.addr, err)) } } diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 5e649cca6..54617f2e6 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -198,15 +198,15 @@ type bucket struct { // TCP header. // // Preconditions: pkt.NetworkHeader() is valid. -func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) { +func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) { netHeader := pkt.Network() if netHeader.TransportProtocol() != header.TCPProtocolNumber { - return tupleID{}, tcpip.ErrUnknownProtocol + return tupleID{}, &tcpip.ErrUnknownProtocol{} } tcpHeader := header.TCP(pkt.TransportHeader().View()) if len(tcpHeader) < header.TCPMinimumSize { - return tupleID{}, tcpip.ErrUnknownProtocol + return tupleID{}, &tcpip.ErrUnknownProtocol{} } return tupleID{ @@ -617,7 +617,7 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bo return true } -func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, *tcpip.Error) { +func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { // Lookup the connection. The reply's original destination // describes the original address. tid := tupleID{ @@ -631,10 +631,10 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ conn, _ := ct.connForTID(tid) if conn == nil { // Not a tracked connection. - return "", 0, tcpip.ErrNotConnected + return "", 0, &tcpip.ErrNotConnected{} } else if conn.manip == manipNone { // Unmanipulated connection. - return "", 0, tcpip.ErrInvalidOptionValue + return "", 0, &tcpip.ErrInvalidOptionValue{} } return conn.original.dstAddr, conn.original.dstPort, nil diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 4908848e9..63a42a2ea 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -41,6 +41,8 @@ const ( protocolNumberOffset = 2 ) +var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil) + // fwdTestNetworkEndpoint is a network-layer protocol endpoint. // Headers of this protocol are fwdTestNetHeaderLen bytes, but we currently only // use the first three: destination address, source address, and transport @@ -53,9 +55,7 @@ type fwdTestNetworkEndpoint struct { dispatcher TransportDispatcher } -var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil) - -func (*fwdTestNetworkEndpoint) Enable() *tcpip.Error { +func (*fwdTestNetworkEndpoint) Enable() tcpip.Error { return nil } @@ -104,7 +104,7 @@ func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { return f.nic.MaxHeaderLength() + fwdTestNetHeaderLen } -func (f *fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { +func (*fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { return 0 } @@ -112,7 +112,7 @@ func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNu return f.proto.Number() } -func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { +func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error { // Add the protocol's header to the packet and send it to the link // endpoint. b := pkt.NetworkHeader().Push(fwdTestNetHeaderLen) @@ -124,14 +124,14 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkH } // WritePackets implements LinkEndpoint.WritePackets. -func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) { +func (*fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) { panic("not implemented") } -func (f *fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error { +func (f *fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) tcpip.Error { // The network header should not already be populated. if _, ok := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen); !ok { - return tcpip.ErrMalformedHeader + return &tcpip.ErrMalformedHeader{} } return f.nic.WritePacket(r, nil /* gso */, fwdTestNetNumber, pkt) @@ -141,6 +141,21 @@ func (f *fwdTestNetworkEndpoint) Close() { f.AddressableEndpointState.Cleanup() } +// Stats implements stack.NetworkEndpoint. +func (*fwdTestNetworkEndpoint) Stats() NetworkEndpointStats { + return &fwdTestNetworkEndpointStats{} +} + +var _ NetworkEndpointStats = (*fwdTestNetworkEndpointStats)(nil) + +type fwdTestNetworkEndpointStats struct{} + +// IsNetworkEndpointStats implements stack.NetworkEndpointStats. +func (*fwdTestNetworkEndpointStats) IsNetworkEndpointStats() {} + +var _ LinkAddressResolver = (*fwdTestNetworkProtocol)(nil) +var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil) + // fwdTestNetworkProtocol is a network-layer protocol that implements Address // resolution. type fwdTestNetworkProtocol struct { @@ -158,18 +173,15 @@ type fwdTestNetworkProtocol struct { } } -var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil) -var _ LinkAddressResolver = (*fwdTestNetworkProtocol)(nil) - -func (f *fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber { +func (*fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber { return fwdTestNetNumber } -func (f *fwdTestNetworkProtocol) MinimumPacketSize() int { +func (*fwdTestNetworkProtocol) MinimumPacketSize() int { return fwdTestNetHeaderLen } -func (f *fwdTestNetworkProtocol) DefaultPrefixLen() int { +func (*fwdTestNetworkProtocol) DefaultPrefixLen() int { return fwdTestNetDefaultPrefixLen } @@ -195,19 +207,19 @@ func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, _ LinkAddress return e } -func (*fwdTestNetworkProtocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (*fwdTestNetworkProtocol) SetOption(tcpip.SettableNetworkProtocolOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} } -func (*fwdTestNetworkProtocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (*fwdTestNetworkProtocol) Option(tcpip.GettableNetworkProtocolOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} } func (*fwdTestNetworkProtocol) Close() {} func (*fwdTestNetworkProtocol) Wait() {} -func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { +func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress, _ NetworkInterface) tcpip.Error { if f.onLinkAddressResolved != nil { time.AfterFunc(f.addrResolveDelay, func() { f.onLinkAddressResolved(f.addrCache, f.neigh, addr, remoteLinkAddr) @@ -307,7 +319,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { +func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { p := fwdTestPacketInfo{ RemoteLinkAddress: r.RemoteLinkAddress, LocalLinkAddress: r.LocalLinkAddress, @@ -323,7 +335,7 @@ func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, gso *GSO, protocol tcpip.N } // WritePackets stores outbound packets into the channel. -func (e *fwdTestLinkEndpoint) WritePackets(r RouteInfo, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *fwdTestLinkEndpoint) WritePackets(r RouteInfo, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { n := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { e.WritePacket(r, gso, protocol, pkt) @@ -356,10 +368,6 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborC UseNeighborCache: useNeighborCache, }) - if !useNeighborCache { - proto.addrCache = s.linkAddrCache - } - // Enable forwarding. s.SetForwarding(proto.Number(), true) @@ -389,13 +397,15 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborC t.Fatal("AddAddress #2 failed:", err) } + nic, ok := s.nics[2] + if !ok { + t.Fatal("NIC 2 does not exist") + } if useNeighborCache { // Control the neighbor cache for NIC 2. - nic, ok := s.nics[2] - if !ok { - t.Fatal("failed to get the neighbor cache for NIC 2") - } proto.neigh = nic.neigh + } else { + proto.addrCache = nic.linkAddrCache } // Route all packets to NIC 2. @@ -481,7 +491,7 @@ func TestForwardingWithFakeResolver(t *testing.T) { addrResolveDelay: 500 * time.Millisecond, onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { // Any address will be resolved to the link address "c". - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + cache.AddLinkAddress(addr, "c") }, }, }, @@ -607,7 +617,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { // Only packets to address 3 will be resolved to the // link address "c". if addr == "\x03" { - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + cache.AddLinkAddress(addr, "c") } }, }, @@ -692,7 +702,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { addrResolveDelay: 500 * time.Millisecond, onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { // Any packets will be resolved to the link address "c". - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + cache.AddLinkAddress(addr, "c") }, }, }, @@ -768,7 +778,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { addrResolveDelay: 500 * time.Millisecond, onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { // Any packets will be resolved to the link address "c". - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + cache.AddLinkAddress(addr, "c") }, }, }, @@ -858,7 +868,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { addrResolveDelay: 500 * time.Millisecond, onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { // Any packets will be resolved to the link address "c". - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + cache.AddLinkAddress(addr, "c") }, }, }, diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 09c7811fa..63832c200 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -229,7 +229,7 @@ func (it *IPTables) GetTable(id TableID, ipv6 bool) Table { // ReplaceTable replaces or inserts table by name. It panics when an invalid id // is provided. -func (it *IPTables) ReplaceTable(id TableID, table Table, ipv6 bool) *tcpip.Error { +func (it *IPTables) ReplaceTable(id TableID, table Table, ipv6 bool) tcpip.Error { it.mu.Lock() defer it.mu.Unlock() // If iptables is being enabled, initialize the conntrack table and @@ -267,11 +267,11 @@ const ( // dropped. // // TODO(gvisor.dev/issue/170): PacketBuffer should hold the GSO and route, from -// which address and nicName can be gathered. Currently, address is only -// needed for prerouting and nicName is only needed for output. +// which address can be gathered. Currently, address is only needed for +// prerouting. // // Precondition: pkt.NetworkHeader is set. -func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) bool { +func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool { if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber { return true } @@ -302,7 +302,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, prer table = it.v4Tables[tableID] } ruleIdx := table.BuiltinChains[hook] - switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, inNicName, outNicName); verdict { // If the table returns Accept, move on to the next table. case chainAccept: continue @@ -385,10 +385,10 @@ func (it *IPTables) startReaper(interval time.Duration) { // // NOTE: unlike the Check API the returned map contains packets that should be // dropped. -func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route, nicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { +func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route, inNicName, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { if !pkt.NatDone { - if ok := it.Check(hook, pkt, gso, r, "", nicName); !ok { + if ok := it.Check(hook, pkt, gso, r, "", inNicName, outNicName); !ok { if drop == nil { drop = make(map[*PacketBuffer]struct{}) } @@ -408,11 +408,11 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r * // Preconditions: // * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // * pkt.NetworkHeader is not nil. -func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) chainVerdict { +func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. for ruleIdx < len(table.Rules) { - switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict { + switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, inNicName, outNicName); verdict { case RuleAccept: return chainAccept @@ -429,7 +429,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId ruleIdx++ continue } - switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, preroutingAddr, nicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, preroutingAddr, inNicName, outNicName); verdict { case chainAccept: return chainAccept case chainDrop: @@ -455,11 +455,11 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId // Preconditions: // * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // * pkt.NetworkHeader is not nil. -func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) (RuleVerdict, int) { +func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // Check whether the packet matches the IP header filter. - if !rule.Filter.match(pkt, hook, nicName) { + if !rule.Filter.match(pkt, hook, inNicName, outNicName) { // Continue on to the next rule. return RuleJump, ruleIdx + 1 } @@ -467,7 +467,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx // Go through each rule matcher. If they all match, run // the rule target. for _, matcher := range rule.Matchers { - matches, hotdrop := matcher.Match(hook, pkt, "") + matches, hotdrop := matcher.Match(hook, pkt, inNicName, outNicName) if hotdrop { return RuleDrop, 0 } @@ -483,11 +483,11 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx // OriginalDst returns the original destination of redirected connections. It // returns an error if the connection doesn't exist or isn't redirected. -func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, *tcpip.Error) { +func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { it.mu.RLock() defer it.mu.RUnlock() if !it.modified { - return "", 0, tcpip.ErrNotConnected + return "", 0, &tcpip.ErrNotConnected{} } return it.connections.originalDst(epID, netProto) } diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 56a3e7861..fd9d61e39 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -210,8 +210,19 @@ type IPHeaderFilter struct { // filter will match packets that fail the source comparison. SrcInvert bool - // OutputInterface matches the name of the outgoing interface for the - // packet. + // InputInterface matches the name of the incoming interface for the packet. + InputInterface string + + // InputInterfaceMask masks the characters of the interface name when + // comparing with InputInterface. + InputInterfaceMask string + + // InputInterfaceInvert inverts the meaning of incoming interface check, + // i.e. when true the filter will match packets that fail the incoming + // interface comparison. + InputInterfaceInvert bool + + // OutputInterface matches the name of the outgoing interface for the packet. OutputInterface string // OutputInterfaceMask masks the characters of the interface name when @@ -228,7 +239,7 @@ type IPHeaderFilter struct { // // Preconditions: pkt.NetworkHeader is set and is at least of the minimal IPv4 // or IPv6 header length. -func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, nicName string) bool { +func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicName string) bool { // Extract header fields. var ( // TODO(gvisor.dev/issue/170): Support other filter fields. @@ -264,26 +275,35 @@ func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, nicName string) boo return false } - // Check the output interface. - // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING - // hooks after supported. - if hook == Output { - n := len(fl.OutputInterface) - if n == 0 { - return true - } - - // If the interface name ends with '+', any interface which - // begins with the name should be matched. - ifName := fl.OutputInterface - matches := nicName == ifName - if strings.HasSuffix(ifName, "+") { - matches = strings.HasPrefix(nicName, ifName[:n-1]) - } - return fl.OutputInterfaceInvert != matches + switch hook { + case Prerouting, Input: + return matchIfName(inNicName, fl.InputInterface, fl.InputInterfaceInvert) + case Output: + return matchIfName(outNicName, fl.OutputInterface, fl.OutputInterfaceInvert) + case Forward, Postrouting: + // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING + // hooks after supported. + return true + default: + panic(fmt.Sprintf("unknown hook: %d", hook)) } +} - return true +func matchIfName(nicName string, ifName string, invert bool) bool { + n := len(ifName) + if n == 0 { + // If the interface name is omitted in the filter, any interface will match. + return true + } + // If the interface name ends with '+', any interface which begins with the + // name should be matched. + var matches bool + if strings.HasSuffix(ifName, "+") { + matches = strings.HasPrefix(nicName, ifName[:n-1]) + } else { + matches = nicName == ifName + } + return matches != invert } // NetworkProtocol returns the protocol (IPv4 or IPv6) on to which the header @@ -320,7 +340,7 @@ type Matcher interface { // used for suspicious packets. // // Precondition: packet.NetworkHeader is set. - Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool) + Match(hook Hook, packet *PacketBuffer, inputInterfaceName, outputInterfaceName string) (matches bool, hotdrop bool) } // A Target is the interface for taking an action for a packet. diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index b600a1cab..930b8f795 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -24,12 +24,16 @@ import ( const linkAddrCacheSize = 512 // max cache entries +var _ LinkAddressCache = (*linkAddrCache)(nil) + // linkAddrCache is a fixed-sized cache mapping IP addresses to link addresses. // // The entries are stored in a ring buffer, oldest entry replaced first. // // This struct is safe for concurrent use. type linkAddrCache struct { + nic *NIC + // ageLimit is how long a cache entry is valid for. ageLimit time.Duration @@ -41,9 +45,9 @@ type linkAddrCache struct { // resolved before failing. resolutionAttempts int - cache struct { + mu struct { sync.Mutex - table map[tcpip.FullAddress]*linkAddrEntry + table map[tcpip.Address]*linkAddrEntry lru linkAddrEntryList } } @@ -77,31 +81,42 @@ type linkAddrEntry struct { // linkAddrEntryEntry access is synchronized by the linkAddrCache lock. linkAddrEntryEntry - // TODO(gvisor.dev/issue/5150): move these fields under mu. - // mu protects the fields below. - mu sync.RWMutex + cache *linkAddrCache - addr tcpip.FullAddress - linkAddr tcpip.LinkAddress - expiration time.Time - s entryState + mu struct { + sync.RWMutex - // done is closed when address resolution is complete. It is nil iff s is - // incomplete and resolution is not yet in progress. - done chan struct{} + addr tcpip.Address + linkAddr tcpip.LinkAddress + expiration time.Time + s entryState - // onResolve is called with the result of address resolution. - onResolve []func(tcpip.LinkAddress, bool) + // done is closed when address resolution is complete. It is nil iff s is + // incomplete and resolution is not yet in progress. + done chan struct{} + + // onResolve is called with the result of address resolution. + onResolve []func(LinkResolutionResult) + } } func (e *linkAddrEntry) notifyCompletionLocked(linkAddr tcpip.LinkAddress) { - for _, callback := range e.onResolve { - callback(linkAddr, len(linkAddr) != 0) + res := LinkResolutionResult{LinkAddress: linkAddr, Success: len(linkAddr) != 0} + for _, callback := range e.mu.onResolve { + callback(res) } - e.onResolve = nil - if ch := e.done; ch != nil { + e.mu.onResolve = nil + if ch := e.mu.done; ch != nil { close(ch) - e.done = nil + e.mu.done = nil + // Dequeue the pending packets in a new goroutine to not hold up the current + // goroutine as writing packets may be a costly operation. + // + // At the time of writing, when writing packets, a neighbor's link address + // is resolved (which ends up obtaining the entry's lock) while holding the + // link resolution queue's lock. Dequeuing packets in a new goroutine avoids + // a lock ordering violation. + go e.cache.nic.linkResQueue.dequeue(ch, linkAddr, len(linkAddr) != 0) } } @@ -114,30 +129,30 @@ func (e *linkAddrEntry) notifyCompletionLocked(linkAddr tcpip.LinkAddress) { // // Precondition: e.mu must be locked func (e *linkAddrEntry) changeStateLocked(ns entryState, expiration time.Time) { - if e.s == incomplete && ns == ready { - e.notifyCompletionLocked(e.linkAddr) + if e.mu.s == incomplete && ns == ready { + e.notifyCompletionLocked(e.mu.linkAddr) } - if expiration.IsZero() || expiration.After(e.expiration) { - e.expiration = expiration + if expiration.IsZero() || expiration.After(e.mu.expiration) { + e.mu.expiration = expiration } - e.s = ns + e.mu.s = ns } // add adds a k -> v mapping to the cache. -func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) { +func (c *linkAddrCache) AddLinkAddress(k tcpip.Address, v tcpip.LinkAddress) { // Calculate expiration time before acquiring the lock, since expiration is // relative to the time when information was learned, rather than when it // happened to be inserted into the cache. expiration := time.Now().Add(c.ageLimit) - c.cache.Lock() + c.mu.Lock() entry := c.getOrCreateEntryLocked(k) - c.cache.Unlock() - entry.mu.Lock() defer entry.mu.Unlock() - entry.linkAddr = v + c.mu.Unlock() + + entry.mu.linkAddr = v entry.changeStateLocked(ready, expiration) } @@ -150,19 +165,19 @@ func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) { // reset to state incomplete, and returned. If no matching entry exists and the // cache is not full, a new entry with state incomplete is allocated and // returned. -func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEntry { - if entry, ok := c.cache.table[k]; ok { - c.cache.lru.Remove(entry) - c.cache.lru.PushFront(entry) +func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.Address) *linkAddrEntry { + if entry, ok := c.mu.table[k]; ok { + c.mu.lru.Remove(entry) + c.mu.lru.PushFront(entry) return entry } var entry *linkAddrEntry - if len(c.cache.table) == linkAddrCacheSize { - entry = c.cache.lru.Back() + if len(c.mu.table) == linkAddrCacheSize { + entry = c.mu.lru.Back() entry.mu.Lock() - delete(c.cache.table, entry.addr) - c.cache.lru.Remove(entry) + delete(c.mu.table, entry.mu.addr) + c.mu.lru.Remove(entry) // Wake waiters and mark the soon-to-be-reused entry as expired. entry.notifyCompletionLocked("" /* linkAddr */) @@ -172,53 +187,56 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEnt } *entry = linkAddrEntry{ - addr: k, - s: incomplete, + cache: c, } - c.cache.table[k] = entry - c.cache.lru.PushFront(entry) + entry.mu.Lock() + entry.mu.addr = k + entry.mu.s = incomplete + entry.mu.Unlock() + c.mu.table[k] = entry + c.mu.lru.PushFront(entry) return entry } // get reports any known link address for k. -func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { - c.cache.Lock() - defer c.cache.Unlock() +func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { + c.mu.Lock() entry := c.getOrCreateEntryLocked(k) entry.mu.Lock() defer entry.mu.Unlock() + c.mu.Unlock() - switch s := entry.s; s { + switch s := entry.mu.s; s { case ready: - if !time.Now().After(entry.expiration) { + if !time.Now().After(entry.mu.expiration) { // Not expired. if onResolve != nil { - onResolve(entry.linkAddr, true) + onResolve(LinkResolutionResult{LinkAddress: entry.mu.linkAddr, Success: true}) } - return entry.linkAddr, nil, nil + return entry.mu.linkAddr, nil, nil } entry.changeStateLocked(incomplete, time.Time{}) fallthrough case incomplete: if onResolve != nil { - entry.onResolve = append(entry.onResolve, onResolve) + entry.mu.onResolve = append(entry.mu.onResolve, onResolve) } - if entry.done == nil { - entry.done = make(chan struct{}) - go c.startAddressResolution(k, linkRes, localAddr, nic, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. + if entry.mu.done == nil { + entry.mu.done = make(chan struct{}) + go c.startAddressResolution(k, linkRes, localAddr, nic, entry.mu.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. } - return entry.linkAddr, entry.done, tcpip.ErrWouldBlock + return entry.mu.linkAddr, entry.mu.done, &tcpip.ErrWouldBlock{} default: panic(fmt.Sprintf("invalid cache entry state: %s", s)) } } -func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) { +func (c *linkAddrCache) startAddressResolution(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) { for i := 0; ; i++ { // Send link request, then wait for the timeout limit and check // whether the request succeeded. - linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, nic) + linkRes.LinkAddressRequest(k, localAddr, "" /* linkAddr */, nic) select { case now := <-time.After(c.resolutionTimeout): @@ -234,10 +252,10 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link // checkLinkRequest checks whether previous attempt to resolve address has // succeeded and mark the entry accordingly. Returns true if request can stop, // false if another request should be sent. -func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, attempt int) bool { - c.cache.Lock() - defer c.cache.Unlock() - entry, ok := c.cache.table[k] +func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.Address, attempt int) bool { + c.mu.Lock() + defer c.mu.Unlock() + entry, ok := c.mu.table[k] if !ok { // Entry was evicted from the cache. return true @@ -245,7 +263,7 @@ func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, att entry.mu.Lock() defer entry.mu.Unlock() - switch s := entry.s; s { + switch s := entry.mu.s; s { case ready: // Entry was made ready by resolver. case incomplete: @@ -255,19 +273,20 @@ func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, att } // Max number of retries reached, delete entry. entry.notifyCompletionLocked("" /* linkAddr */) - delete(c.cache.table, k) + delete(c.mu.table, k) default: panic(fmt.Sprintf("invalid cache entry state: %s", s)) } return true } -func newLinkAddrCache(ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache { +func newLinkAddrCache(nic *NIC, ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache { c := &linkAddrCache{ + nic: nic, ageLimit: ageLimit, resolutionTimeout: resolutionTimeout, resolutionAttempts: resolutionAttempts, } - c.cache.table = make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize) + c.mu.table = make(map[tcpip.Address]*linkAddrEntry, linkAddrCacheSize) return c } diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index d7ac6cf5f..466a5e8d9 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -26,7 +26,7 @@ import ( ) type testaddr struct { - addr tcpip.FullAddress + addr tcpip.Address linkAddr tcpip.LinkAddress } @@ -35,7 +35,7 @@ var testAddrs = func() []testaddr { for i := 0; i < 4*linkAddrCacheSize; i++ { addr := fmt.Sprintf("Addr%06d", i) addrs = append(addrs, testaddr{ - addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)}, + addr: tcpip.Address(addr), linkAddr: tcpip.LinkAddress("Link" + addr), }) } @@ -48,7 +48,7 @@ type testLinkAddressResolver struct { onLinkAddressRequest func() } -func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { +func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) tcpip.Error { // TODO(gvisor.dev/issue/5141): Use a fake clock. time.AfterFunc(r.delay, func() { r.fakeRequest(targetAddr) }) if f := r.onLinkAddressRequest; f != nil { @@ -59,8 +59,8 @@ func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) { for _, ta := range testAddrs { - if ta.addr.Addr == addr { - r.cache.add(ta.addr, ta.linkAddr) + if ta.addr == addr { + r.cache.AddLinkAddress(ta.addr, ta.linkAddr) break } } @@ -77,13 +77,13 @@ func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumbe return 1 } -func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressResolver) (tcpip.LinkAddress, *tcpip.Error) { +func getBlocking(c *linkAddrCache, addr tcpip.Address, linkRes LinkAddressResolver) (tcpip.LinkAddress, tcpip.Error) { var attemptedResolution bool for { got, ch, err := c.get(addr, linkRes, "", nil, nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { if attemptedResolution { - return got, tcpip.ErrTimeout + return got, &tcpip.ErrTimeout{} } attemptedResolution = true <-ch @@ -93,17 +93,23 @@ func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressRe } } +func newEmptyNIC() *NIC { + n := &NIC{} + n.linkResQueue.init(n) + return n +} + func TestCacheOverflow(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) + c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3) for i := len(testAddrs) - 1; i >= 0; i-- { e := testAddrs[i] - c.add(e.addr, e.linkAddr) + c.AddLinkAddress(e.addr, e.linkAddr) got, _, err := c.get(e.addr, nil, "", nil, nil) if err != nil { - t.Errorf("insert %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err) + t.Errorf("insert %d, c.get(%s, nil, '', nil, nil): %s", i, e.addr, err) } if got != e.linkAddr { - t.Errorf("insert %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr) + t.Errorf("insert %d, got c.get(%s, nil, '', nil, nil) = %s, want = %s", i, e.addr, got, e.linkAddr) } } // Expect to find at least half of the most recent entries. @@ -111,25 +117,25 @@ func TestCacheOverflow(t *testing.T) { e := testAddrs[i] got, _, err := c.get(e.addr, nil, "", nil, nil) if err != nil { - t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err) + t.Errorf("check %d, c.get(%s, nil, '', nil, nil): %s", i, e.addr, err) } if got != e.linkAddr { - t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr) + t.Errorf("check %d, got c.get(%s, nil, '', nil, nil) = %s, want = %s", i, e.addr, got, e.linkAddr) } } // The earliest entries should no longer be in the cache. - c.cache.Lock() - defer c.cache.Unlock() + c.mu.Lock() + defer c.mu.Unlock() for i := len(testAddrs) - 1; i >= len(testAddrs)-linkAddrCacheSize; i-- { e := testAddrs[i] - if entry, ok := c.cache.table[e.addr]; ok { - t.Errorf("unexpected entry at c.cache.table[%q]: %#v", string(e.addr.Addr), entry) + if entry, ok := c.mu.table[e.addr]; ok { + t.Errorf("unexpected entry at c.mu.table[%s]: %#v", e.addr, entry) } } } func TestCacheConcurrent(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) + c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3) linkRes := &testLinkAddressResolver{cache: c} var wg sync.WaitGroup @@ -137,7 +143,7 @@ func TestCacheConcurrent(t *testing.T) { wg.Add(1) go func() { for _, e := range testAddrs { - c.add(e.addr, e.linkAddr) + c.AddLinkAddress(e.addr, e.linkAddr) } wg.Done() }() @@ -150,52 +156,53 @@ func TestCacheConcurrent(t *testing.T) { e := testAddrs[len(testAddrs)-1] got, _, err := c.get(e.addr, linkRes, "", nil, nil) if err != nil { - t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) + t.Errorf("c.get(%s, _, '', nil, nil): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) + t.Errorf("got c.get(%s, _, '', nil, nil) = %s, want = %s", e.addr, got, e.linkAddr) } e = testAddrs[0] - c.cache.Lock() - defer c.cache.Unlock() - if entry, ok := c.cache.table[e.addr]; ok { - t.Errorf("unexpected entry at c.cache.table[%q]: %#v", string(e.addr.Addr), entry) + c.mu.Lock() + defer c.mu.Unlock() + if entry, ok := c.mu.table[e.addr]; ok { + t.Errorf("unexpected entry at c.mu.table[%s]: %#v", e.addr, entry) } } func TestCacheAgeLimit(t *testing.T) { - c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3) + c := newLinkAddrCache(newEmptyNIC(), 1*time.Millisecond, 1*time.Second, 3) linkRes := &testLinkAddressResolver{cache: c} e := testAddrs[0] - c.add(e.addr, e.linkAddr) + c.AddLinkAddress(e.addr, e.linkAddr) time.Sleep(50 * time.Millisecond) - if _, _, err := c.get(e.addr, linkRes, "", nil, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got c.get(%q) = %s, want = ErrWouldBlock", string(e.addr.Addr), err) + _, _, err := c.get(e.addr, linkRes, "", nil, nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got c.get(%s, _, '', nil, nil) = %s, want = ErrWouldBlock", e.addr, err) } } func TestCacheReplace(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) + c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3) e := testAddrs[0] l2 := e.linkAddr + "2" - c.add(e.addr, e.linkAddr) + c.AddLinkAddress(e.addr, e.linkAddr) got, _, err := c.get(e.addr, nil, "", nil, nil) if err != nil { - t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) + t.Errorf("c.get(%s, nil, '', nil, nil): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) + t.Errorf("got c.get(%s, nil, '', nil, nil) = %s, want = %s", e.addr, got, e.linkAddr) } - c.add(e.addr, l2) + c.AddLinkAddress(e.addr, l2) got, _, err = c.get(e.addr, nil, "", nil, nil) if err != nil { - t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) + t.Errorf("c.get(%s, nil, '', nil, nil): %s", e.addr, err) } if got != l2 { - t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, l2) + t.Errorf("got c.get(%s, nil, '', nil, nil) = %s, want = %s", e.addr, got, l2) } } @@ -206,15 +213,15 @@ func TestCacheResolution(t *testing.T) { // // Using a large resolution timeout decreases the probability of experiencing // this race condition and does not affect how long this test takes to run. - c := newLinkAddrCache(1<<63-1, math.MaxInt64, 1) + c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, math.MaxInt64, 1) linkRes := &testLinkAddressResolver{cache: c} for i, ta := range testAddrs { got, err := getBlocking(c, ta.addr, linkRes) if err != nil { - t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(ta.addr.Addr), got, err) + t.Errorf("check %d, getBlocking(_, %s, _): %s", i, ta.addr, err) } if got != ta.linkAddr { - t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(ta.addr.Addr), got, ta.linkAddr) + t.Errorf("check %d, got getBlocking(_, %s, _) = %s, want = %s", i, ta.addr, got, ta.linkAddr) } } @@ -223,16 +230,16 @@ func TestCacheResolution(t *testing.T) { e := testAddrs[len(testAddrs)-1] got, _, err := c.get(e.addr, linkRes, "", nil, nil) if err != nil { - t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) + t.Errorf("c.get(%s, _, '', nil, nil): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) + t.Errorf("got c.get(%s, _, '', nil, nil) = %s, want = %s", e.addr, got, e.linkAddr) } } } func TestCacheResolutionFailed(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 10*time.Millisecond, 5) + c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 10*time.Millisecond, 5) linkRes := &testLinkAddressResolver{cache: c} var requestCount uint32 @@ -244,17 +251,18 @@ func TestCacheResolutionFailed(t *testing.T) { e := testAddrs[0] got, err := getBlocking(c, e.addr, linkRes) if err != nil { - t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) + t.Errorf("getBlocking(_, %s, _): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) + t.Errorf("got getBlocking(_, %s, _) = %s, want = %s", e.addr, got, e.linkAddr) } before := atomic.LoadUint32(&requestCount) - e.addr.Addr += "2" - if a, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrTimeout { - t.Errorf("got getBlocking(_, %#v, _) = (%s, %s), want = (_, %s)", e.addr, a, err, tcpip.ErrTimeout) + e.addr += "2" + a, err := getBlocking(c, e.addr, linkRes) + if _, ok := err.(*tcpip.ErrTimeout); !ok { + t.Errorf("got getBlocking(_, %s, _) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{}) } if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want { @@ -265,11 +273,12 @@ func TestCacheResolutionFailed(t *testing.T) { func TestCacheResolutionTimeout(t *testing.T) { resolverDelay := 500 * time.Millisecond expiration := resolverDelay / 10 - c := newLinkAddrCache(expiration, 1*time.Millisecond, 3) + c := newLinkAddrCache(newEmptyNIC(), expiration, 1*time.Millisecond, 3) linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay} e := testAddrs[0] - if a, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrTimeout { - t.Errorf("got getBlocking(_, %#v, _) = (%s, %s), want = (_, %s)", e.addr, a, err, tcpip.ErrTimeout) + a, err := getBlocking(c, e.addr, linkRes) + if _, ok := err.(*tcpip.ErrTimeout); !ok { + t.Errorf("got getBlocking(_, %s, _) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{}) } } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 61636cae5..64383bc7c 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -45,6 +45,8 @@ const ( linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08") linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09") + defaultPrefixLen = 128 + // Extra time to use when waiting for an async event to occur. defaultAsyncPositiveEventTimeout = 10 * time.Second @@ -102,7 +104,7 @@ type ndpDADEvent struct { nicID tcpip.NICID addr tcpip.Address resolved bool - err *tcpip.Error + err tcpip.Error } type ndpRouterEvent struct { @@ -172,7 +174,7 @@ type ndpDispatcher struct { } // Implements ipv6.NDPDispatcher.OnDuplicateAddressDetectionStatus. -func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) { +func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err tcpip.Error) { if n.dadC != nil { n.dadC <- ndpDADEvent{ nicID, @@ -309,7 +311,7 @@ func (l *channelLinkWithHeaderLength) MaxHeaderLength() uint16 { // Check e to make sure that the event is for addr on nic with ID 1, and the // resolved flag set to resolved with the specified err. -func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) string { +func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, resolved bool, err tcpip.Error) string { return cmp.Diff(ndpDADEvent{nicID: nicID, addr: addr, resolved: resolved, err: err}, e, cmp.AllowUnexported(e)) } @@ -330,8 +332,12 @@ func TestDADDisabled(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) + addrWithPrefix := tcpip.AddressWithPrefix{ + Address: addr1, + PrefixLen: defaultPrefixLen, + } + if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { + t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err) } // Should get the address immediately since we should not have performed @@ -344,12 +350,8 @@ func TestDADDisabled(t *testing.T) { default: t.Fatal("expected DAD event") } - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(%d, %d) err = %s", nicID, header.IPv6ProtocolNumber, err) - } - if addr.Address != addr1 { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { + t.Fatal(err) } // We should not have sent any NDP NS messages. @@ -440,31 +442,31 @@ func TestDADResolve(t *testing.T) { NIC: nicID, }}) - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) + addrWithPrefix := tcpip.AddressWithPrefix{ + Address: addr1, + PrefixLen: defaultPrefixLen, + } + if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { + t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err) } // Address should not be considered bound to the NIC yet (DAD ongoing). - if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } else if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } // Make sure the address does not resolve before the resolution time has // passed. time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncNegativeEventTimeout) - if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } else if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Error(err) } // Should not get a route even if we specify the local address as the // tentative address. { r, err := s.FindRoute(nicID, "", addr2, header.IPv6ProtocolNumber, false) - if err != tcpip.ErrNoRoute { - t.Errorf("got FindRoute(%d, '', %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute) + if _, ok := err.(*tcpip.ErrNoRoute); !ok { + t.Errorf("got FindRoute(%d, '', %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr2, header.IPv6ProtocolNumber, r, err, &tcpip.ErrNoRoute{}) } if r != nil { r.Release() @@ -472,8 +474,8 @@ func TestDADResolve(t *testing.T) { } { r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false) - if err != tcpip.ErrNoRoute { - t.Errorf("got FindRoute(%d, %s, %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr1, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute) + if _, ok := err.(*tcpip.ErrNoRoute); !ok { + t.Errorf("got FindRoute(%d, %s, %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr1, addr2, header.IPv6ProtocolNumber, r, err, &tcpip.ErrNoRoute{}) } if r != nil { r.Release() @@ -493,10 +495,8 @@ func TestDADResolve(t *testing.T) { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } - if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } else if addr.Address != addr1 { - t.Errorf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { + t.Error(err) } // Should get a route using the address now that it is resolved. { @@ -662,12 +662,8 @@ func TestDADFail(t *testing.T) { // Address should not be considered bound to the NIC yet // (DAD ongoing). - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } // Receive a packet to simulate an address conflict. @@ -691,12 +687,8 @@ func TestDADFail(t *testing.T) { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } - addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } // Attempting to add the address again should not fail if the address's @@ -777,12 +769,8 @@ func TestDADStop(t *testing.T) { } // Address should not be considered bound to the NIC yet (DAD ongoing). - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } test.stopFn(t, s) @@ -800,12 +788,8 @@ func TestDADStop(t *testing.T) { } if !test.skipFinalAddrCheck { - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } } @@ -901,26 +885,25 @@ func TestSetNDPConfigurations(t *testing.T) { } // Add addresses for each NIC. - if err := s.AddAddress(nicID1, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID1, header.IPv6ProtocolNumber, addr1, err) + addrWithPrefix1 := tcpip.AddressWithPrefix{Address: addr1, PrefixLen: defaultPrefixLen} + if err := s.AddAddressWithPrefix(nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil { + t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID1, header.IPv6ProtocolNumber, addrWithPrefix1, err) } - if err := s.AddAddress(nicID2, header.IPv6ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID2, header.IPv6ProtocolNumber, addr2, err) + addrWithPrefix2 := tcpip.AddressWithPrefix{Address: addr2, PrefixLen: defaultPrefixLen} + if err := s.AddAddressWithPrefix(nicID2, header.IPv6ProtocolNumber, addrWithPrefix2); err != nil { + t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID2, header.IPv6ProtocolNumber, addrWithPrefix2, err) } expectDADEvent(nicID2, addr2) - if err := s.AddAddress(nicID3, header.IPv6ProtocolNumber, addr3); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID3, header.IPv6ProtocolNumber, addr3, err) + addrWithPrefix3 := tcpip.AddressWithPrefix{Address: addr3, PrefixLen: defaultPrefixLen} + if err := s.AddAddressWithPrefix(nicID3, header.IPv6ProtocolNumber, addrWithPrefix3); err != nil { + t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID3, header.IPv6ProtocolNumber, addrWithPrefix3, err) } expectDADEvent(nicID3, addr3) // Address should not be considered bound to NIC(1) yet // (DAD ongoing). - addr, err := s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID1, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } // Should get the address on NIC(2) and NIC(3) @@ -928,31 +911,19 @@ func TestSetNDPConfigurations(t *testing.T) { // it as the stack was configured to not do DAD by // default and we only updated the NDP configurations on // NIC(1). - addr, err = s.GetMainNICAddress(nicID2, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID2, header.IPv6ProtocolNumber, err) - } - if addr.Address != addr2 { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID2, header.IPv6ProtocolNumber, addr, addr2) + if err := checkGetMainNICAddress(s, nicID2, header.IPv6ProtocolNumber, addrWithPrefix2); err != nil { + t.Fatal(err) } - addr, err = s.GetMainNICAddress(nicID3, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID3, header.IPv6ProtocolNumber, err) - } - if addr.Address != addr3 { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID3, header.IPv6ProtocolNumber, addr, addr3) + if err := checkGetMainNICAddress(s, nicID3, header.IPv6ProtocolNumber, addrWithPrefix3); err != nil { + t.Fatal(err) } // Sleep until right (500ms before) before resolution to // make sure the address didn't resolve on NIC(1) yet. const delta = 500 * time.Millisecond time.Sleep(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta) - addr, err = s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID1, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } // Wait for DAD to resolve. @@ -970,12 +941,8 @@ func TestSetNDPConfigurations(t *testing.T) { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } - addr, err = s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err) - } - if addr.Address != addr1 { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID1, header.IPv6ProtocolNumber, addr, addr1) + if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil { + t.Fatal(err) } }) } @@ -2808,6 +2775,7 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useN autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -2827,10 +2795,15 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useN Gateway: llAddr3, NIC: nicID, }}) + if useNeighborCache { - s.AddStaticNeighbor(nicID, llAddr3, linkAddr3) + if err := s.AddStaticNeighbor(nicID, llAddr3, linkAddr3); err != nil { + t.Fatalf("s.AddStaticNeighbor(%d, %s, %s): %s", nicID, llAddr3, linkAddr3, err) + } } else { - s.AddLinkAddress(nicID, llAddr3, linkAddr3) + if err := s.AddLinkAddress(nicID, llAddr3, linkAddr3); err != nil { + t.Fatalf("s.AddLinkAddress(%d, %s, %s): %s", nicID, llAddr3, linkAddr3, err) + } } return ndpDisp, e, s } @@ -2940,10 +2913,8 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) { expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { t.Helper() - if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } else if got != addr { - t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { + t.Fatal(err) } if got := addrForNewConnection(t, s); got != addr.Address { @@ -3088,10 +3059,8 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { t.Helper() - if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } else if got != addr { - t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { + t.Fatal(err) } if got := addrForNewConnection(t, s); got != addr.Address { @@ -3238,10 +3207,8 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { t.Fatalf("should not have %s in the list of addresses", addr2) } // Should not have any primary endpoints. - if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } else if want := (tcpip.AddressWithPrefix{}); got != want { - t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } wq := waiter.Queue{} we, ch := waiter.NewChannelEntry(nil) @@ -3255,8 +3222,11 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { defer ep.Close() ep.SocketOptions().SetV6Only(true) - if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute { - t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, tcpip.ErrNoRoute) + { + err := ep.Connect(dstAddr) + if _, ok := err.(*tcpip.ErrNoRoute); !ok { + t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, &tcpip.ErrNoRoute{}) + } } }) } @@ -3615,10 +3585,8 @@ func TestAutoGenAddrAfterRemoval(t *testing.T) { expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { t.Helper() - if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } else if got != addr { - t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { + t.Fatal(err) } if got := addrForNewConnection(t, s); got != addr.Address { diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go index acee72572..88a3ff776 100644 --- a/pkg/tcpip/stack/neighbor_cache.go +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -126,7 +126,7 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA // packet prompting NUD/link address resolution. // // TODO(gvisor.dev/issue/5151): Don't return the neighbor entry. -func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(tcpip.LinkAddress, bool)) (NeighborEntry, <-chan struct{}, *tcpip.Error) { +func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(LinkResolutionResult)) (NeighborEntry, <-chan struct{}, tcpip.Error) { entry := n.getOrCreateEntry(remoteAddr, linkRes) entry.mu.Lock() defer entry.mu.Unlock() @@ -142,7 +142,7 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA // a node continues sending packets to that neighbor using the cached // link-layer address." if onResolve != nil { - onResolve(entry.neigh.LinkAddr, true) + onResolve(LinkResolutionResult{LinkAddress: entry.neigh.LinkAddr, Success: true}) } return entry.neigh, nil, nil case Unknown, Incomplete, Failed: @@ -154,7 +154,7 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA entry.done = make(chan struct{}) } entry.handlePacketQueuedLocked(localAddr) - return entry.neigh, entry.done, tcpip.ErrWouldBlock + return entry.neigh, entry.done, &tcpip.ErrWouldBlock{} default: panic(fmt.Sprintf("Invalid cache entry state: %s", s)) } @@ -297,10 +297,9 @@ func (n *neighborCache) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.Li // no matching entry for the remote address. } -// HandleUpperLevelConfirmation implements -// NUDHandler.HandleUpperLevelConfirmation by following the logic defined in -// RFC 4861 section 7.3.1. -func (n *neighborCache) HandleUpperLevelConfirmation(addr tcpip.Address) { +// handleUpperLevelConfirmation processes a confirmation of reachablity from +// some protocol that operates at a layer above the IP/link layer. +func (n *neighborCache) handleUpperLevelConfirmation(addr tcpip.Address) { n.mu.RLock() entry, ok := n.cache[addr] n.mu.RUnlock() diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index b96a56612..2870e4f66 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -194,7 +194,7 @@ type testNeighborResolver struct { var _ LinkAddressResolver = (*testNeighborResolver)(nil) -func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { +func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) tcpip.Error { if !r.dropReplies { // Delay handling the request to emulate network latency. r.clock.AfterFunc(r.delay, func() { @@ -251,8 +251,8 @@ func TestNeighborCacheGetConfig(t *testing.T) { // No events should have been dispatched. nudDisp.mu.Lock() defer nudDisp.mu.Unlock() - if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -273,8 +273,8 @@ func TestNeighborCacheSetConfig(t *testing.T) { // No events should have been dispatched. nudDisp.mu.Lock() defer nudDisp.mu.Unlock() - if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -295,8 +295,9 @@ func TestNeighborCacheEntry(t *testing.T) { if !ok { t.Fatal("store.entry(0) not found") } - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + _, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) @@ -321,11 +322,11 @@ func TestNeighborCacheEntry(t *testing.T) { }, } nudDisp.mu.Lock() - diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) nudDisp.events = nil nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil { @@ -335,8 +336,8 @@ func TestNeighborCacheEntry(t *testing.T) { // No more events should have been dispatched. nudDisp.mu.Lock() defer nudDisp.mu.Unlock() - if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -359,8 +360,9 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { t.Fatal("store.entry(0) not found") } - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + _, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) @@ -385,11 +387,11 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { }, } nudDisp.mu.Lock() - diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) nudDisp.events = nil nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } neigh.removeEntry(entry.Addr) @@ -407,15 +409,18 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { }, } nudDisp.mu.Lock() - diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + { + _, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + } } } @@ -462,8 +467,9 @@ func (c *testContext) overflowCache(opts overflowOptions) error { if !ok { return fmt.Errorf("c.store.entry(%d) not found", i) } - if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { - return fmt.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + return fmt.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } c.clock.Advance(c.neigh.config().RetransmitTimer) @@ -506,11 +512,11 @@ func (c *testContext) overflowCache(opts overflowOptions) error { }) c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - return fmt.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -531,15 +537,15 @@ func (c *testContext) overflowCache(opts overflowOptions) error { wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } - if diff := cmp.Diff(c.neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" { - return fmt.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantUnsortedEntries, c.neigh.entries(), entryDiffOptsWithSort()...); diff != "" { + return fmt.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) } // No more events should have been dispatched. c.nudDisp.mu.Lock() defer c.nudDisp.mu.Unlock() - if diff := cmp.Diff(c.nudDisp.events, []testEntryEventInfo(nil)); diff != "" { - return fmt.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff([]testEntryEventInfo(nil), c.nudDisp.events); diff != "" { + return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } return nil @@ -579,8 +585,9 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { if !ok { t.Fatal("c.store.entry(0) not found") } - if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } c.clock.Advance(c.neigh.config().RetransmitTimer) wantEvents := []testEntryEventInfo{ @@ -603,11 +610,11 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } // Remove the entry @@ -626,11 +633,11 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -668,11 +675,11 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } // Remove the static entry that was just added @@ -681,8 +688,8 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { // No more events should have been dispatched. c.nudDisp.mu.Lock() defer c.nudDisp.mu.Unlock() - if diff := cmp.Diff(c.nudDisp.events, []testEntryEventInfo(nil)); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff([]testEntryEventInfo(nil), c.nudDisp.events); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -712,11 +719,11 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } // Add a duplicate entry with a different link address @@ -736,8 +743,8 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) } c.nudDisp.mu.Lock() defer c.nudDisp.mu.Unlock() - if diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } } @@ -774,11 +781,11 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } // Remove the static entry that was just added @@ -796,11 +803,11 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -830,8 +837,9 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { if !ok { t.Fatal("c.store.entry(0) not found") } - if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ @@ -854,11 +862,11 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } // Override the entry with a static one using the same address @@ -886,11 +894,11 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -932,8 +940,8 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { LinkAddr: entry.LinkAddr, State: Static, } - if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("c.neigh.entry(%s, \"\", _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { + t.Errorf("c.neigh.entry(%s, \"\", _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } wantEvents := []testEntryEventInfo{ @@ -948,11 +956,11 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } opts := overflowOptions{ @@ -989,8 +997,9 @@ func TestNeighborCacheClear(t *testing.T) { if !ok { t.Fatal("store.entry(0) not found") } - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + _, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) @@ -1014,11 +1023,11 @@ func TestNeighborCacheClear(t *testing.T) { }, } nudDisp.mu.Lock() - diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) nudDisp.events = nil nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } // Add a static entry. @@ -1037,11 +1046,11 @@ func TestNeighborCacheClear(t *testing.T) { }, } nudDisp.mu.Lock() - diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) nudDisp.events = nil nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -1072,8 +1081,8 @@ func TestNeighborCacheClear(t *testing.T) { } nudDisp.mu.Lock() defer nudDisp.mu.Unlock() - if diff := cmp.Diff(nudDisp.events, wantUnsortedEvents, eventDiffOptsWithSort()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantUnsortedEvents, nudDisp.events, eventDiffOptsWithSort()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -1094,8 +1103,9 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { if !ok { t.Fatal("c.store.entry(0) not found") } - if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ @@ -1118,11 +1128,11 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } // Clear the cache. @@ -1140,11 +1150,11 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -1188,16 +1198,13 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if !ok { t.Fatalf("store.entry(%d) not found", i) } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if !ok { - t.Fatal("expected successful address resolution") - } - if linkAddr != entry.LinkAddr { - t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) - if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) select { @@ -1225,11 +1232,11 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { }, } nudDisp.mu.Lock() - diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) nudDisp.events = nil nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -1247,16 +1254,13 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { t.Fatalf("store.entry(%d) not found", i) } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if !ok { - t.Fatal("expected successful address resolution") - } - if linkAddr != entry.LinkAddr { - t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) - if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) select { @@ -1299,11 +1303,11 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { }, } nudDisp.mu.Lock() - diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) nudDisp.events = nil nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -1331,15 +1335,15 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } - if diff := cmp.Diff(neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" { - t.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantUnsortedEntries, neigh.entries(), entryDiffOptsWithSort()...); diff != "" { + t.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) } // No more events should have been dispatched. nudDisp.mu.Lock() defer nudDisp.mu.Unlock() - if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -1366,8 +1370,10 @@ func TestNeighborCacheConcurrent(t *testing.T) { wg.Add(1) go func(entry NeighborEntry) { defer wg.Done() - if e, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil && err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, tcpip.ErrWouldBlock) + switch e, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err.(type) { + case nil, *tcpip.ErrWouldBlock: + default: + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, &tcpip.ErrWouldBlock{}) } }(entry) } @@ -1398,8 +1404,8 @@ func TestNeighborCacheConcurrent(t *testing.T) { wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } - if diff := cmp.Diff(neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" { - t.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantUnsortedEntries, neigh.entries(), entryDiffOptsWithSort()...); diff != "" { + t.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) } } @@ -1423,16 +1429,13 @@ func TestNeighborCacheReplace(t *testing.T) { t.Fatal("store.entry(0) not found") } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if !ok { - t.Fatal("expected successful address resolution") - } - if linkAddr != entry.LinkAddr { - t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) - if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) select { @@ -1455,8 +1458,8 @@ func TestNeighborCacheReplace(t *testing.T) { LinkAddr: entry.LinkAddr, State: Reachable, } - if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { + t.Errorf("neigh.entry(%s, '', _, _, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } } @@ -1489,8 +1492,8 @@ func TestNeighborCacheReplace(t *testing.T) { LinkAddr: updatedLinkAddr, State: Delay, } - if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { + t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } clock.Advance(config.DelayFirstProbeTime + typicalLatency) } @@ -1507,8 +1510,8 @@ func TestNeighborCacheReplace(t *testing.T) { LinkAddr: updatedLinkAddr, State: Reachable, } - if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { + t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } } } @@ -1539,16 +1542,13 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { // First, sanity check that resolution is working { - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if !ok { - t.Fatal("expected successful address resolution") - } - if linkAddr != entry.LinkAddr { - t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) - if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) select { @@ -1567,8 +1567,8 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { LinkAddr: entry.LinkAddr, State: Reachable, } - if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + if diff := cmp.Diff(want, got, entryDiffOpts()...); diff != "" { + t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } // Verify address resolution fails for an unknown address. @@ -1576,19 +1576,13 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { entry.Addr += "2" { - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if ok { - t.Error("expected unsuccessful address resolution") - } - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr) - } - if t.Failed() { - t.FailNow() + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) - if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) @@ -1627,19 +1621,13 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { t.Fatal("store.entry(0) not found") } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if ok { - t.Error("expected unsuccessful address resolution") - } - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr) - } - if t.Failed() { - t.FailNow() + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) - if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) @@ -1674,19 +1662,13 @@ func TestNeighborCacheRetryResolution(t *testing.T) { // Perform address resolution with a faulty link, which will fail. { - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if ok { - t.Error("expected unsuccessful address resolution") - } - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr) - } - if t.Failed() { - t.FailNow() + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) - if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) @@ -1713,13 +1695,13 @@ func TestNeighborCacheRetryResolution(t *testing.T) { // Retry address resolution with a working link. linkRes.dropReplies = false { - incompleteEntry, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if linkAddr != entry.LinkAddr { - t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + incompleteEntry, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) - if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } if incompleteEntry.State != Incomplete { t.Fatalf("got entry.State = %s, want = %s", incompleteEntry.State, Incomplete) @@ -1772,16 +1754,13 @@ func BenchmarkCacheClear(b *testing.B) { b.Fatalf("store.entry(%d) not found", i) } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if !ok { - b.Fatal("expected successful address resolution") - } - if linkAddr != entry.LinkAddr { - b.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { + b.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) - if err != tcpip.ErrWouldBlock { - b.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + b.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } select { diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index 75afb3001..53ac9bb6e 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -96,7 +96,7 @@ type neighborEntry struct { done chan struct{} // onResolve is called with the result of address resolution. - onResolve []func(tcpip.LinkAddress, bool) + onResolve []func(LinkResolutionResult) isRouter bool job *tcpip.Job @@ -143,13 +143,22 @@ func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAdd // // Precondition: e.mu MUST be locked. func (e *neighborEntry) notifyCompletionLocked(succeeded bool) { + res := LinkResolutionResult{LinkAddress: e.neigh.LinkAddr, Success: succeeded} for _, callback := range e.onResolve { - callback(e.neigh.LinkAddr, succeeded) + callback(res) } e.onResolve = nil if ch := e.done; ch != nil { close(ch) e.done = nil + // Dequeue the pending packets in a new goroutine to not hold up the current + // goroutine as writing packets may be a costly operation. + // + // At the time of writing, when writing packets, a neighbor's link address + // is resolved (which ends up obtaining the entry's lock) while holding the + // link resolution queue's lock. Dequeuing packets in a new goroutine avoids + // a lock ordering violation. + go e.nic.linkResQueue.dequeue(ch, e.neigh.LinkAddr, succeeded) } } diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index ec34ffa5a..140b8ca00 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -193,7 +193,7 @@ func (p entryTestProbeInfo) String() string { // LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts // to the local network if linkAddr is the zero value. -func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { +func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, _ NetworkInterface) tcpip.Error { p := entryTestProbeInfo{ RemoteAddress: targetAddr, RemoteLinkAddress: linkAddr, @@ -266,16 +266,16 @@ func TestEntryInitiallyUnknown(t *testing.T) { // No probes should have been sent. linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil)) + diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } // No events should have been dispatched. nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -299,16 +299,16 @@ func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) { // No probes should have been sent. linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil)) + diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } // No events should have been dispatched. nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -333,10 +333,10 @@ func TestEntryUnknownToIncomplete(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } wantEvents := []testEntryEventInfo{ @@ -352,10 +352,10 @@ func TestEntryUnknownToIncomplete(t *testing.T) { } { nudDisp.mu.Lock() - diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } } @@ -374,10 +374,10 @@ func TestEntryUnknownToStale(t *testing.T) { // No probes should have been sent. runImmediatelyScheduledJobs(clock) linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil)) + diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } wantEvents := []testEntryEventInfo{ @@ -392,8 +392,8 @@ func TestEntryUnknownToStale(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -427,11 +427,11 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -453,10 +453,10 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -483,8 +483,8 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() @@ -515,10 +515,10 @@ func TestEntryIncompleteToReachable(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -553,8 +553,8 @@ func TestEntryIncompleteToReachable(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -579,10 +579,10 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -620,8 +620,8 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -646,10 +646,10 @@ func TestEntryIncompleteToStaleWhenUnsolicitedConfirmation(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -684,8 +684,8 @@ func TestEntryIncompleteToStaleWhenUnsolicitedConfirmation(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -710,10 +710,10 @@ func TestEntryIncompleteToStaleWhenProbe(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -744,8 +744,8 @@ func TestEntryIncompleteToStaleWhenProbe(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -785,10 +785,10 @@ func TestEntryIncompleteToFailed(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } wantEvents := []testEntryEventInfo{ @@ -812,8 +812,8 @@ func TestEntryIncompleteToFailed(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() @@ -850,10 +850,10 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -903,8 +903,8 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() @@ -932,10 +932,10 @@ func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -977,8 +977,8 @@ func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1005,10 +1005,10 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1054,8 +1054,8 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() @@ -1083,10 +1083,10 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1134,8 +1134,8 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1157,10 +1157,10 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1212,8 +1212,8 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1235,10 +1235,10 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1290,8 +1290,8 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1313,10 +1313,10 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1358,8 +1358,8 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1381,10 +1381,10 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1439,8 +1439,8 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1462,10 +1462,10 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1520,8 +1520,8 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1543,10 +1543,10 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1601,8 +1601,8 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1624,10 +1624,10 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1678,8 +1678,8 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1701,10 +1701,10 @@ func TestEntryStaleToDelay(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1752,8 +1752,8 @@ func TestEntryStaleToDelay(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1780,10 +1780,10 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1851,8 +1851,8 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1880,10 +1880,10 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1958,8 +1958,8 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1987,10 +1987,10 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -2065,8 +2065,8 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -2088,10 +2088,10 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -2147,8 +2147,8 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -2170,10 +2170,10 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -2231,8 +2231,8 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -2254,10 +2254,10 @@ func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -2319,8 +2319,8 @@ func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -2343,11 +2343,11 @@ func TestEntryDelayToProbe(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -2372,10 +2372,10 @@ func TestEntryDelayToProbe(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -2418,8 +2418,8 @@ func TestEntryDelayToProbe(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() @@ -2448,11 +2448,11 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -2474,10 +2474,10 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -2539,8 +2539,8 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -2563,11 +2563,11 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -2589,10 +2589,10 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -2658,8 +2658,8 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -2682,11 +2682,11 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -2709,10 +2709,10 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -2772,8 +2772,8 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -2806,10 +2806,10 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -2878,8 +2878,8 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -2907,11 +2907,11 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -2933,10 +2933,10 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -3015,8 +3015,8 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -3044,11 +3044,11 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -3070,10 +3070,10 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -3149,8 +3149,8 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -3178,11 +3178,11 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -3204,10 +3204,10 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -3283,8 +3283,8 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -3309,11 +3309,11 @@ func TestEntryProbeToFailed(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -3336,11 +3336,11 @@ func TestEntryProbeToFailed(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probe #%d mismatch (-got, +want):\n%s", i+1, diff) + t.Fatalf("link address resolver probe #%d mismatch (-want, +got):\n%s", i+1, diff) } e.mu.Lock() @@ -3406,8 +3406,8 @@ func TestEntryProbeToFailed(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -3449,10 +3449,10 @@ func TestEntryFailedToIncomplete(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -3498,8 +3498,8 @@ func TestEntryFailedToIncomplete(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 0f545f255..e56a624fe 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -53,6 +53,8 @@ type NIC struct { // complete. linkResQueue packetsPendingLinkResolution + linkAddrCache *linkAddrCache + mu struct { sync.RWMutex spoofing bool @@ -138,7 +140,8 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC stats: makeNICStats(), networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), } - nic.linkResQueue.init() + nic.linkResQueue.init(nic) + nic.linkAddrCache = newLinkAddrCache(nic, ageLimit, resolutionTimeout, resolutionAttempts) nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList) // Check for Neighbor Unreachability Detection support. @@ -167,7 +170,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC for _, netProto := range stack.networkProtocols { netNum := netProto.Number() nic.mu.packetEPs[netNum] = new(packetEndpointList) - nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic) + nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, nic.linkAddrCache, nud, nic) } nic.LinkEndpoint.Attach(nic) @@ -228,7 +231,9 @@ func (n *NIC) disableLocked() { // // This matches linux's behaviour at the time of writing: // https://github.com/torvalds/linux/blob/71c061d2443814de15e177489d5cc00a4a253ef3/net/core/neighbour.c#L371 - if err := n.clearNeighbors(); err != nil && err != tcpip.ErrNotSupported { + switch err := n.clearNeighbors(); err.(type) { + case nil, *tcpip.ErrNotSupported: + default: panic(fmt.Sprintf("n.clearNeighbors(): %s", err)) } @@ -243,7 +248,7 @@ func (n *NIC) disableLocked() { // address (ff02::1), start DAD for permanent addresses, and start soliciting // routers if the stack is not operating as a router. If the stack is also // configured to auto-generate a link-local address, one will be generated. -func (n *NIC) enable() *tcpip.Error { +func (n *NIC) enable() tcpip.Error { n.mu.Lock() defer n.mu.Unlock() @@ -263,7 +268,7 @@ func (n *NIC) enable() *tcpip.Error { // remove detaches NIC from the link endpoint and releases network endpoint // resources. This guarantees no packets between this NIC and the network // stack. -func (n *NIC) remove() *tcpip.Error { +func (n *NIC) remove() tcpip.Error { n.mu.Lock() defer n.mu.Unlock() @@ -299,48 +304,69 @@ func (n *NIC) IsLoopback() bool { } // WritePacket implements NetworkLinkEndpoint. -func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { - // As per relevant RFCs, we should queue packets while we wait for link - // resolution to complete. - // - // RFC 1122 section 2.3.2.2 (for IPv4): - // The link layer SHOULD save (rather than discard) at least - // one (the latest) packet of each set of packets destined to - // the same unresolved IP address, and transmit the saved - // packet when the address has been resolved. - // - // RFC 4861 section 7.2.2 (for IPv6): - // While waiting for address resolution to complete, the sender MUST, for - // each neighbor, retain a small queue of packets waiting for address - // resolution to complete. The queue MUST hold at least one packet, and MAY - // contain more. However, the number of queued packets per neighbor SHOULD - // be limited to some small value. When a queue overflows, the new arrival - // SHOULD replace the oldest entry. Once address resolution completes, the - // node transmits any queued packets. - if ch, err := r.Resolve(nil); err != nil { - if err == tcpip.ErrWouldBlock { - r.Acquire() - n.linkResQueue.enqueue(ch, r, protocol, pkt) - return nil +func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { + _, err := n.enqueuePacketBuffer(r, gso, protocol, pkt) + return err +} + +func (n *NIC) writePacketBuffer(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) { + switch pkt := pkt.(type) { + case *PacketBuffer: + if err := n.writePacket(r, gso, protocol, pkt); err != nil { + return 0, err } - return err + return 1, nil + case *PacketBufferList: + return n.writePackets(r, gso, protocol, *pkt) + default: + panic(fmt.Sprintf("unrecognized pending packet buffer type = %T", pkt)) } +} - return n.writePacket(r.Fields(), gso, protocol, pkt) +func (n *NIC) enqueuePacketBuffer(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) { + routeInfo, _, err := r.resolvedFields(nil) + switch err.(type) { + case nil: + return n.writePacketBuffer(routeInfo, gso, protocol, pkt) + case *tcpip.ErrWouldBlock: + // As per relevant RFCs, we should queue packets while we wait for link + // resolution to complete. + // + // RFC 1122 section 2.3.2.2 (for IPv4): + // The link layer SHOULD save (rather than discard) at least + // one (the latest) packet of each set of packets destined to + // the same unresolved IP address, and transmit the saved + // packet when the address has been resolved. + // + // RFC 4861 section 7.2.2 (for IPv6): + // While waiting for address resolution to complete, the sender MUST, for + // each neighbor, retain a small queue of packets waiting for address + // resolution to complete. The queue MUST hold at least one packet, and + // MAY contain more. However, the number of queued packets per neighbor + // SHOULD be limited to some small value. When a queue overflows, the new + // arrival SHOULD replace the oldest entry. Once address resolution + // completes, the node transmits any queued packets. + return n.linkResQueue.enqueue(r, gso, protocol, pkt) + default: + return 0, err + } } // WritePacketToRemote implements NetworkInterface. -func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { +func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { var r RouteInfo r.NetProto = protocol r.RemoteLinkAddress = remoteLinkAddr return n.writePacket(r, gso, protocol, pkt) } -func (n *NIC) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { +func (n *NIC) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { // WritePacket takes ownership of pkt, calculate numBytes first. numBytes := pkt.Size() + pkt.EgressRoute = r + pkt.GSOOptions = gso + pkt.NetworkProtocolNumber = protocol if err := n.LinkEndpoint.WritePacket(r, gso, protocol, pkt); err != nil { return err } @@ -351,10 +377,18 @@ func (n *NIC) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolN } // WritePackets implements NetworkLinkEndpoint. -func (n *NIC) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - // TODO(gvisor.dev/issue/4458): Queue packets whie link address resolution - // is being peformed like WritePacket. - writtenPackets, err := n.LinkEndpoint.WritePackets(r.Fields(), gso, pkts, protocol) +func (n *NIC) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { + return n.enqueuePacketBuffer(r, gso, protocol, &pkts) +} + +func (n *NIC) writePackets(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkts PacketBufferList) (int, tcpip.Error) { + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + pkt.EgressRoute = r + pkt.GSOOptions = gso + pkt.NetworkProtocolNumber = protocol + } + + writtenPackets, err := n.LinkEndpoint.WritePackets(r, gso, pkts, protocol) n.stats.Tx.Packets.IncrementBy(uint64(writtenPackets)) writtenBytes := 0 for i, pb := 0, pkts.Front(); i < writtenPackets && pb != nil; i, pb = i+1, pb.Next() { @@ -463,15 +497,15 @@ func (n *NIC) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber, // addAddress adds a new address to n, so that it starts accepting packets // targeted at the given address (and network protocol). -func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error { +func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error { ep, ok := n.networkEndpoints[protocolAddress.Protocol] if !ok { - return tcpip.ErrUnknownProtocol + return &tcpip.ErrUnknownProtocol{} } addressableEndpoint, ok := ep.(AddressableEndpoint) if !ok { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */) @@ -535,63 +569,70 @@ func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWit } // removeAddress removes an address from n. -func (n *NIC) removeAddress(addr tcpip.Address) *tcpip.Error { +func (n *NIC) removeAddress(addr tcpip.Address) tcpip.Error { for _, ep := range n.networkEndpoints { addressableEndpoint, ok := ep.(AddressableEndpoint) if !ok { continue } - if err := addressableEndpoint.RemovePermanentAddress(addr); err == tcpip.ErrBadLocalAddress { + switch err := addressableEndpoint.RemovePermanentAddress(addr); err.(type) { + case *tcpip.ErrBadLocalAddress: continue - } else { + default: return err } } - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} +} + +func (n *NIC) confirmReachable(addr tcpip.Address) { + if n := n.neigh; n != nil { + n.handleUpperLevelConfirmation(addr) + } } -func (n *NIC) getNeighborLinkAddress(addr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { +func (n *NIC) getNeighborLinkAddress(addr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { if n.neigh != nil { entry, ch, err := n.neigh.entry(addr, localAddr, linkRes, onResolve) return entry.LinkAddr, ch, err } - return n.stack.linkAddrCache.get(tcpip.FullAddress{NIC: n.ID(), Addr: addr}, linkRes, localAddr, n, onResolve) + return n.linkAddrCache.get(addr, linkRes, localAddr, n, onResolve) } -func (n *NIC) neighbors() ([]NeighborEntry, *tcpip.Error) { +func (n *NIC) neighbors() ([]NeighborEntry, tcpip.Error) { if n.neigh == nil { - return nil, tcpip.ErrNotSupported + return nil, &tcpip.ErrNotSupported{} } return n.neigh.entries(), nil } -func (n *NIC) addStaticNeighbor(addr tcpip.Address, linkAddress tcpip.LinkAddress) *tcpip.Error { +func (n *NIC) addStaticNeighbor(addr tcpip.Address, linkAddress tcpip.LinkAddress) tcpip.Error { if n.neigh == nil { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } n.neigh.addStaticEntry(addr, linkAddress) return nil } -func (n *NIC) removeNeighbor(addr tcpip.Address) *tcpip.Error { +func (n *NIC) removeNeighbor(addr tcpip.Address) tcpip.Error { if n.neigh == nil { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } if !n.neigh.removeEntry(addr) { - return tcpip.ErrBadAddress + return &tcpip.ErrBadAddress{} } return nil } -func (n *NIC) clearNeighbors() *tcpip.Error { +func (n *NIC) clearNeighbors() tcpip.Error { if n.neigh == nil { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } n.neigh.clear() @@ -600,7 +641,7 @@ func (n *NIC) clearNeighbors() *tcpip.Error { // joinGroup adds a new endpoint for the given multicast address, if none // exists yet. Otherwise it just increments its count. -func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { +func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { // TODO(b/143102137): When implementing MLD, make sure MLD packets are // not sent unless a valid link-local address is available for use on n // as an MLD packet's source address must be a link-local address as @@ -608,12 +649,12 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address ep, ok := n.networkEndpoints[protocol] if !ok { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } gep, ok := ep.(GroupAddressableEndpoint) if !ok { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } return gep.JoinGroup(addr) @@ -621,15 +662,15 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address // leaveGroup decrements the count for the given multicast address, and when it // reaches zero removes the endpoint for this address. -func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { +func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { ep, ok := n.networkEndpoints[protocol] if !ok { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } gep, ok := ep.(GroupAddressableEndpoint) if !ok { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } return gep.LeaveGroup(addr) @@ -879,9 +920,9 @@ func (n *NIC) Name() string { } // nudConfigs gets the NUD configurations for n. -func (n *NIC) nudConfigs() (NUDConfigurations, *tcpip.Error) { +func (n *NIC) nudConfigs() (NUDConfigurations, tcpip.Error) { if n.neigh == nil { - return NUDConfigurations{}, tcpip.ErrNotSupported + return NUDConfigurations{}, &tcpip.ErrNotSupported{} } return n.neigh.config(), nil } @@ -890,22 +931,22 @@ func (n *NIC) nudConfigs() (NUDConfigurations, *tcpip.Error) { // // Note, if c contains invalid NUD configuration values, it will be fixed to // use default values for the erroneous values. -func (n *NIC) setNUDConfigs(c NUDConfigurations) *tcpip.Error { +func (n *NIC) setNUDConfigs(c NUDConfigurations) tcpip.Error { if n.neigh == nil { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } c.resetInvalidFields() n.neigh.setConfig(c) return nil } -func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error { +func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) tcpip.Error { n.mu.Lock() defer n.mu.Unlock() eps, ok := n.mu.packetEPs[netProto] if !ok { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } eps.add(ep) diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 5b5c58afb..2f719fbe5 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -39,7 +39,7 @@ type testIPv6Endpoint struct { invalidatedRtr tcpip.Address } -func (*testIPv6Endpoint) Enable() *tcpip.Error { +func (*testIPv6Endpoint) Enable() tcpip.Error { return nil } @@ -65,21 +65,21 @@ func (e *testIPv6Endpoint) MaxHeaderLength() uint16 { } // WritePacket implements NetworkEndpoint.WritePacket. -func (*testIPv6Endpoint) WritePacket(*Route, *GSO, NetworkHeaderParams, *PacketBuffer) *tcpip.Error { +func (*testIPv6Endpoint) WritePacket(*Route, *GSO, NetworkHeaderParams, *PacketBuffer) tcpip.Error { return nil } // WritePackets implements NetworkEndpoint.WritePackets. -func (*testIPv6Endpoint) WritePackets(*Route, *GSO, PacketBufferList, NetworkHeaderParams) (int, *tcpip.Error) { +func (*testIPv6Endpoint) WritePackets(*Route, *GSO, PacketBufferList, NetworkHeaderParams) (int, tcpip.Error) { // Our tests don't use this so we don't support it. - return 0, tcpip.ErrNotSupported + return 0, &tcpip.ErrNotSupported{} } // WriteHeaderIncludedPacket implements // NetworkEndpoint.WriteHeaderIncludedPacket. -func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip.Error { +func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) tcpip.Error { // Our tests don't use this so we don't support it. - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } // HandlePacket implements NetworkEndpoint.HandlePacket. @@ -99,11 +99,20 @@ func (e *testIPv6Endpoint) InvalidateDefaultRouter(rtr tcpip.Address) { e.invalidatedRtr = rtr } -var _ NetworkProtocol = (*testIPv6Protocol)(nil) +// Stats implements NetworkEndpoint. +func (*testIPv6Endpoint) Stats() NetworkEndpointStats { + return &testIPv6EndpointStats{} +} + +var _ NetworkEndpointStats = (*testIPv6EndpointStats)(nil) + +type testIPv6EndpointStats struct{} + +// IsNetworkEndpointStats implements stack.NetworkEndpointStats. +func (*testIPv6EndpointStats) IsNetworkEndpointStats() {} + +var _ LinkAddressResolver = (*testIPv6Protocol)(nil) -// An IPv6 NetworkProtocol that supports the bare minimum to make a stack -// believe it supports IPv6. -// // We use this instead of ipv6.protocol because the ipv6 package depends on // the stack package which this test lives in, causing a cyclic dependency. type testIPv6Protocol struct{} @@ -140,12 +149,12 @@ func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, } // SetOption implements NetworkProtocol.SetOption. -func (*testIPv6Protocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error { +func (*testIPv6Protocol) SetOption(tcpip.SettableNetworkProtocolOption) tcpip.Error { return nil } // Option implements NetworkProtocol.Option. -func (*testIPv6Protocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error { +func (*testIPv6Protocol) Option(tcpip.GettableNetworkProtocolOption) tcpip.Error { return nil } @@ -160,15 +169,13 @@ func (*testIPv6Protocol) Parse(*PacketBuffer) (tcpip.TransportProtocolNumber, bo return 0, false, false } -var _ LinkAddressResolver = (*testIPv6Protocol)(nil) - // LinkAddressProtocol implements LinkAddressResolver. func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return header.IPv6ProtocolNumber } // LinkAddressRequest implements LinkAddressResolver. -func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { +func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) tcpip.Error { return nil } diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go index 12d67409a..77926e289 100644 --- a/pkg/tcpip/stack/nud.go +++ b/pkg/tcpip/stack/nud.go @@ -174,10 +174,6 @@ type NUDHandler interface { // HandleConfirmation processes an incoming neighbor confirmation (e.g. ARP // reply or Neighbor Advertisement for ARP or NDP, respectively). HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) - - // HandleUpperLevelConfirmation processes an incoming upper-level protocol - // (e.g. TCP acknowledgements) reachability confirmation. - HandleUpperLevelConfirmation(addr tcpip.Address) } // NUDConfigurations is the NUD configurations for the netstack. This is used diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go index 7bca1373e..ebfd5eb45 100644 --- a/pkg/tcpip/stack/nud_test.go +++ b/pkg/tcpip/stack/nud_test.go @@ -65,8 +65,9 @@ func TestSetNUDConfigurationFailsForBadNICID(t *testing.T) { // No NIC with ID 1 yet. config := stack.NUDConfigurations{} - if err := s.SetNUDConfigurations(1, config); err != tcpip.ErrUnknownNICID { - t.Fatalf("got s.SetNDPConfigurations(1, %+v) = %v, want = %s", config, err, tcpip.ErrUnknownNICID) + err := s.SetNUDConfigurations(1, config) + if _, ok := err.(*tcpip.ErrUnknownNICID); !ok { + t.Fatalf("got s.SetNDPConfigurations(1, %+v) = %v, want = %s", config, err, &tcpip.ErrUnknownNICID{}) } } @@ -90,8 +91,9 @@ func TestNUDConfigurationFailsForNotSupported(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if _, err := s.NUDConfigurations(nicID); err != tcpip.ErrNotSupported { - t.Fatalf("got s.NDPConfigurations(%d) = %v, want = %s", nicID, err, tcpip.ErrNotSupported) + _, err := s.NUDConfigurations(nicID) + if _, ok := err.(*tcpip.ErrNotSupported); !ok { + t.Fatalf("got s.NDPConfigurations(%d) = %v, want = %s", nicID, err, &tcpip.ErrNotSupported{}) } } @@ -117,8 +119,9 @@ func TestSetNUDConfigurationFailsForNotSupported(t *testing.T) { } config := stack.NUDConfigurations{} - if err := s.SetNUDConfigurations(nicID, config); err != tcpip.ErrNotSupported { - t.Fatalf("got s.SetNDPConfigurations(%d, %+v) = %v, want = %s", nicID, config, err, tcpip.ErrNotSupported) + err := s.SetNUDConfigurations(nicID, config) + if _, ok := err.(*tcpip.ErrNotSupported); !ok { + t.Fatalf("got s.SetNDPConfigurations(%d, %+v) = %v, want = %s", nicID, config, err, &tcpip.ErrNotSupported{}) } } diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go index 41529ffd5..1c651e216 100644 --- a/pkg/tcpip/stack/pending_packets.go +++ b/pkg/tcpip/stack/pending_packets.go @@ -28,108 +28,205 @@ const ( maxPendingPacketsPerResolution = 256 ) +// pendingPacketBuffer is a pending packet buffer. +// +// TODO(gvisor.dev/issue/5331): Drop this when we drop WritePacket and only use +// WritePackets so we can use a PacketBufferList everywhere. +type pendingPacketBuffer interface { + len() int +} + +func (*PacketBuffer) len() int { + return 1 +} + +func (p *PacketBufferList) len() int { + return p.Len() +} + type pendingPacket struct { - route *Route - proto tcpip.NetworkProtocolNumber - pkt *PacketBuffer + routeInfo RouteInfo + gso *GSO + proto tcpip.NetworkProtocolNumber + pkt pendingPacketBuffer } // packetsPendingLinkResolution is a queue of packets pending link resolution. // // Once link resolution completes successfully, the packets will be written. type packetsPendingLinkResolution struct { - sync.Mutex + nic *NIC - // The packets to send once the resolver completes. - packets map[<-chan struct{}][]pendingPacket + mu struct { + sync.Mutex - // FIFO of channels used to cancel the oldest goroutine waiting for - // link-address resolution. - cancelChans []chan struct{} -} + // The packets to send once the resolver completes. + // + // The link resolution channel is used as the key for this map. + packets map[<-chan struct{}][]pendingPacket -func (f *packetsPendingLinkResolution) init() { - f.Lock() - defer f.Unlock() - f.packets = make(map[<-chan struct{}][]pendingPacket) + // FIFO of channels used to cancel the oldest goroutine waiting for + // link-address resolution. + // + // cancelChans holds the same channels that are used as keys to packets. + cancelChans []<-chan struct{} + } } -func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, proto tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { - f.Lock() - defer f.Unlock() +func (f *packetsPendingLinkResolution) incrementOutgoingPacketErrors(proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) { + n := uint64(pkt.len()) + f.nic.stack.stats.IP.OutgoingPacketErrors.IncrementBy(n) - packets, ok := f.packets[ch] - if len(packets) == maxPendingPacketsPerResolution { - p := packets[0] - packets[0] = pendingPacket{} - packets = packets[1:] - p.route.Stats().IP.OutgoingPacketErrors.Increment() - p.route.Release() + if ipEndpointStats, ok := f.nic.getNetworkEndpoint(proto).Stats().(IPNetworkEndpointStats); ok { + ipEndpointStats.IPStats().OutgoingPacketErrors.IncrementBy(n) } +} - if l := len(packets); l >= maxPendingPacketsPerResolution { - panic(fmt.Sprintf("max pending packets for resolution reached; got %d packets, max = %d", l, maxPendingPacketsPerResolution)) +func (f *packetsPendingLinkResolution) init(nic *NIC) { + f.mu.Lock() + defer f.mu.Unlock() + f.nic = nic + f.mu.packets = make(map[<-chan struct{}][]pendingPacket) +} + +// dequeue any pending packets associated with ch. +// +// If success is true, packets will be written and sent to the given remote link +// address. +func (f *packetsPendingLinkResolution) dequeue(ch <-chan struct{}, linkAddr tcpip.LinkAddress, success bool) { + f.mu.Lock() + packets, ok := f.mu.packets[ch] + delete(f.mu.packets, ch) + + if ok { + for i, cancelChan := range f.mu.cancelChans { + if cancelChan == ch { + f.mu.cancelChans = append(f.mu.cancelChans[:i], f.mu.cancelChans[i+1:]...) + break + } + } } - f.packets[ch] = append(packets, pendingPacket{ - route: r, - proto: proto, - pkt: pkt, - }) + f.mu.Unlock() if ok { - return + f.dequeuePackets(packets, linkAddr, success) } +} - // Wait for the link-address resolution to complete. - cancel := f.newCancelChannelLocked() - go func() { - cancelled := false - select { - case <-ch: - case <-cancel: - cancelled = true - } +// enqueue a packet to be sent once link resolution completes. +// +// If the maximum number of pending resolutions is reached, the packets +// associated with the oldest link resolution will be dequeued as if they failed +// link resolution. +func (f *packetsPendingLinkResolution) enqueue(r *Route, gso *GSO, proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) { + f.mu.Lock() + // Make sure we attempt resolution while holding f's lock so that we avoid + // a race where link resolution completes before we enqueue the packets. + // + // A @ T1: Call ResolvedFields (get link resolution channel) + // B @ T2: Complete link resolution, dequeue pending packets + // C @ T1: Enqueue packet that already completed link resolution (which will + // never dequeue) + // + // To make sure B does not interleave with A and C, we make sure A and C are + // done while holding the lock. + routeInfo, ch, err := r.resolvedFields(nil) + switch err.(type) { + case nil: + // The route resolved immediately, so we don't need to wait for link + // resolution to send the packet. + f.mu.Unlock() + return f.nic.writePacketBuffer(routeInfo, gso, proto, pkt) + case *tcpip.ErrWouldBlock: + // We need to wait for link resolution to complete. + default: + f.mu.Unlock() + return 0, err + } - f.Lock() - packets, ok := f.packets[ch] - delete(f.packets, ch) - f.Unlock() + defer f.mu.Unlock() - if !ok { - panic(fmt.Sprintf("link-resolution goroutine woke up but no entry exists in the queue of packets")) - } + packets, ok := f.mu.packets[ch] + packets = append(packets, pendingPacket{ + routeInfo: routeInfo, + gso: gso, + proto: proto, + pkt: pkt, + }) - for _, p := range packets { - if cancelled || p.route.IsResolutionRequired() { - p.route.Stats().IP.OutgoingPacketErrors.Increment() + if len(packets) > maxPendingPacketsPerResolution { + f.incrementOutgoingPacketErrors(packets[0].proto, packets[0].pkt) + packets[0] = pendingPacket{} + packets = packets[1:] - if linkResolvableEP, ok := p.route.outgoingNIC.getNetworkEndpoint(p.route.NetProto).(LinkResolvableNetworkEndpoint); ok { - linkResolvableEP.HandleLinkResolutionFailure(pkt) - } - } else { - p.route.outgoingNIC.writePacket(p.route.Fields(), nil /* gso */, p.proto, p.pkt) - } - p.route.Release() + if numPackets := len(packets); numPackets != maxPendingPacketsPerResolution { + panic(fmt.Sprintf("holding more queued packets than expected; got = %d, want <= %d", numPackets, maxPendingPacketsPerResolution)) } - }() + } + + f.mu.packets[ch] = packets + + if ok { + return pkt.len(), nil + } + + cancelledPackets := f.newCancelChannelLocked(ch) + + if len(cancelledPackets) != 0 { + // Dequeue the pending packets in a new goroutine to not hold up the current + // goroutine as handing link resolution failures may be a costly operation. + go f.dequeuePackets(cancelledPackets, "" /* linkAddr */, false /* success */) + } + + return pkt.len(), nil } -// newCancelChannel creates a channel that can cancel a pending forwarding -// activity. The oldest channel is closed if the number of open channels would -// exceed maxPendingResolutions. -func (f *packetsPendingLinkResolution) newCancelChannelLocked() chan struct{} { - if len(f.cancelChans) == maxPendingResolutions { - ch := f.cancelChans[0] - f.cancelChans[0] = nil - f.cancelChans = f.cancelChans[1:] - close(ch) +// newCancelChannelLocked appends the link resolution channel to a FIFO. If the +// maximum number of pending resolutions is reached, the oldest channel will be +// removed and its associated pending packets will be returned. +func (f *packetsPendingLinkResolution) newCancelChannelLocked(newCH <-chan struct{}) []pendingPacket { + f.mu.cancelChans = append(f.mu.cancelChans, newCH) + if len(f.mu.cancelChans) <= maxPendingResolutions { + return nil } - if l := len(f.cancelChans); l >= maxPendingResolutions { + + ch := f.mu.cancelChans[0] + f.mu.cancelChans[0] = nil + f.mu.cancelChans = f.mu.cancelChans[1:] + if l := len(f.mu.cancelChans); l > maxPendingResolutions { panic(fmt.Sprintf("max pending resolutions reached; got %d active resolutions, max = %d", l, maxPendingResolutions)) } - ch := make(chan struct{}) - f.cancelChans = append(f.cancelChans, ch) - return ch + packets, ok := f.mu.packets[ch] + if !ok { + panic("must have a packet queue for an uncancelled channel") + } + delete(f.mu.packets, ch) + + return packets +} + +func (f *packetsPendingLinkResolution) dequeuePackets(packets []pendingPacket, linkAddr tcpip.LinkAddress, success bool) { + for _, p := range packets { + if success { + p.routeInfo.RemoteLinkAddress = linkAddr + _, _ = f.nic.writePacketBuffer(p.routeInfo, p.gso, p.proto, p.pkt) + } else { + f.incrementOutgoingPacketErrors(p.proto, p.pkt) + + if linkResolvableEP, ok := f.nic.getNetworkEndpoint(p.proto).(LinkResolvableNetworkEndpoint); ok { + switch pkt := p.pkt.(type) { + case *PacketBuffer: + linkResolvableEP.HandleLinkResolutionFailure(pkt) + case *PacketBufferList: + for pb := pkt.Front(); pb != nil; pb = pb.Next() { + linkResolvableEP.HandleLinkResolutionFailure(pb) + } + default: + panic(fmt.Sprintf("unrecognized pending packet buffer type = %T", p.pkt)) + } + } + } + } } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 68c113b6a..510da8689 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -172,10 +172,10 @@ type TransportProtocol interface { Number() tcpip.TransportProtocolNumber // NewEndpoint creates a new endpoint of the transport protocol. - NewEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) + NewEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) // NewRawEndpoint creates a new raw endpoint of the transport protocol. - NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) + NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) // MinimumPacketSize returns the minimum valid packet size of this // transport protocol. The stack automatically drops any packets smaller @@ -184,7 +184,7 @@ type TransportProtocol interface { // ParsePorts returns the source and destination ports stored in a // packet of this protocol. - ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) + ParsePorts(v buffer.View) (src, dst uint16, err tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this // protocol that don't match any existing endpoint. For example, @@ -197,12 +197,12 @@ type TransportProtocol interface { // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the // provided option value is invalid. - SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error + SetOption(option tcpip.SettableTransportProtocolOption) tcpip.Error // Option allows retrieving protocol specific option values. // Option returns an error if the option is not supported or the // provided option value is invalid. - Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error + Option(option tcpip.GettableTransportProtocolOption) tcpip.Error // Close requests that any worker goroutines owned by the protocol // stop. @@ -289,10 +289,10 @@ type NetworkHeaderParams struct { // endpoints may associate themselves with the same identifier (group address). type GroupAddressableEndpoint interface { // JoinGroup joins the specified group. - JoinGroup(group tcpip.Address) *tcpip.Error + JoinGroup(group tcpip.Address) tcpip.Error // LeaveGroup attempts to leave the specified group. - LeaveGroup(group tcpip.Address) *tcpip.Error + LeaveGroup(group tcpip.Address) tcpip.Error // IsInGroup returns true if the endpoint is a member of the specified group. IsInGroup(group tcpip.Address) bool @@ -440,17 +440,17 @@ func (k AddressKind) IsPermanent() bool { type AddressableEndpoint interface { // AddAndAcquirePermanentAddress adds the passed permanent address. // - // Returns tcpip.ErrDuplicateAddress if the address exists. + // Returns *tcpip.ErrDuplicateAddress if the address exists. // // Acquires and returns the AddressEndpoint for the added address. - AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, *tcpip.Error) + AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, tcpip.Error) // RemovePermanentAddress removes the passed address if it is a permanent // address. // - // Returns tcpip.ErrBadLocalAddress if the endpoint does not have the passed + // Returns *tcpip.ErrBadLocalAddress if the endpoint does not have the passed // permanent address. - RemovePermanentAddress(addr tcpip.Address) *tcpip.Error + RemovePermanentAddress(addr tcpip.Address) tcpip.Error // MainAddress returns the endpoint's primary permanent address. MainAddress() tcpip.AddressWithPrefix @@ -512,14 +512,14 @@ type NetworkInterface interface { Promiscuous() bool // WritePacketToRemote writes the packet to the given remote link address. - WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error + WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error // WritePacket writes a packet with the given protocol through the given // route. // // WritePacket takes ownership of the packet buffer. The packet buffer's // network and transport header must be set. - WritePacket(*Route, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error + WritePacket(*Route, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error // WritePackets writes packets with the given protocol through the given // route. Must not be called with an empty list of packet buffers. @@ -529,7 +529,7 @@ type NetworkInterface interface { // Right now, WritePackets is used only when the software segmentation // offload is enabled. If it will be used for something else, syscall filters // may need to be updated. - WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) + WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) } // LinkResolvableNetworkEndpoint handles link resolution events. @@ -547,8 +547,8 @@ type NetworkEndpoint interface { // Must only be called when the stack is in a state that allows the endpoint // to send and receive packets. // - // Returns tcpip.ErrNotPermitted if the endpoint cannot be enabled. - Enable() *tcpip.Error + // Returns *tcpip.ErrNotPermitted if the endpoint cannot be enabled. + Enable() tcpip.Error // Enabled returns true if the endpoint is enabled. Enabled() bool @@ -574,16 +574,16 @@ type NetworkEndpoint interface { // WritePacket writes a packet to the given destination address and // protocol. It takes ownership of pkt. pkt.TransportHeader must have // already been set. - WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error + WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error // WritePackets writes packets to the given destination address and // protocol. pkts must not be zero length. It takes ownership of pkts and // underlying packets. - WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) + WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) // WriteHeaderIncludedPacket writes a packet that includes a network // header to the given destination address. It takes ownership of pkt. - WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error + WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) tcpip.Error // HandlePacket is called by the link layer when new packets arrive to // this network endpoint. It sets pkt.NetworkHeader. @@ -597,6 +597,26 @@ type NetworkEndpoint interface { // NetworkProtocolNumber returns the tcpip.NetworkProtocolNumber for // this endpoint. NetworkProtocolNumber() tcpip.NetworkProtocolNumber + + // Stats returns a reference to the network endpoint stats. + Stats() NetworkEndpointStats +} + +// NetworkEndpointStats is the interface implemented by each network endpoint +// stats struct. +type NetworkEndpointStats interface { + // IsNetworkEndpointStats is an empty method to implement the + // NetworkEndpointStats marker interface. + IsNetworkEndpointStats() +} + +// IPNetworkEndpointStats is a NetworkEndpointStats that tracks IP-related +// statistics. +type IPNetworkEndpointStats interface { + NetworkEndpointStats + + // IPStats returns the IP statistics of a network endpoint. + IPStats() *tcpip.IPStats } // ForwardingNetworkProtocol is a NetworkProtocol that may forward packets. @@ -634,12 +654,12 @@ type NetworkProtocol interface { // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the // provided option value is invalid. - SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error + SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.Error // Option allows retrieving protocol specific option values. // Option returns an error if the option is not supported or the // provided option value is invalid. - Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error + Option(option tcpip.GettableNetworkProtocolOption) tcpip.Error // Close requests that any worker goroutines owned by the protocol // stop. @@ -776,7 +796,7 @@ type LinkEndpoint interface { // To participate in transparent bridging, a LinkEndpoint implementation // should call eth.Encode with header.EthernetFields.SrcAddr set to // r.LocalLinkAddress if it is provided. - WritePacket(RouteInfo, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error + WritePacket(RouteInfo, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error // WritePackets writes packets with the given protocol and route. Must not be // called with an empty list of packet buffers. @@ -786,7 +806,7 @@ type LinkEndpoint interface { // Right now, WritePackets is used only when the software segmentation // offload is enabled. If it will be used for something else, syscall filters // may need to be updated. - WritePackets(RouteInfo, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) + WritePackets(RouteInfo, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) } // InjectableLinkEndpoint is a LinkEndpoint where inbound packets are @@ -801,7 +821,7 @@ type InjectableLinkEndpoint interface { // link. // // dest is used by endpoints with multiple raw destinations. - InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error + InjectOutbound(dest tcpip.Address, packet []byte) tcpip.Error } // A LinkAddressResolver is an extension to a NetworkProtocol that @@ -813,7 +833,7 @@ type LinkAddressResolver interface { // // The request is sent from the passed network interface. If the interface // local address is unspecified, any interface local address may be used. - LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic NetworkInterface) *tcpip.Error + LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic NetworkInterface) tcpip.Error // ResolveStaticAddress attempts to resolve address without sending // requests. It either resolves the name immediately or returns the @@ -829,12 +849,8 @@ type LinkAddressResolver interface { // A LinkAddressCache caches link addresses. type LinkAddressCache interface { - // CheckLocalAddress determines if the given local address exists, and if it - // does not exist. - CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID - // AddLinkAddress adds a link address to the cache. - AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) + AddLinkAddress(addr tcpip.Address, linkAddr tcpip.LinkAddress) } // RawFactory produces endpoints for writing various types of raw packets. @@ -842,11 +858,11 @@ type RawFactory interface { // NewUnassociatedEndpoint produces endpoints for writing packets not // associated with a particular transport protocol. Such endpoints can // be used to write arbitrary packets that include the network header. - NewUnassociatedEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) + NewUnassociatedEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) // NewPacketEndpoint produces endpoints for reading and writing packets // that include network and (when cooked is false) link layer headers. - NewPacketEndpoint(stack *Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) + NewPacketEndpoint(stack *Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) } // GSOType is the type of GSO segments. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 1ff7b3a37..4ae0f2a1a 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -86,12 +86,21 @@ type RouteInfo struct { RemoteLinkAddress tcpip.LinkAddress } -// Fields returns a RouteInfo with all of r's exported fields. This allows -// callers to store the route's fields without retaining a reference to it. +// Fields returns a RouteInfo with all of the known values for the route's +// fields. +// +// If any fields are unknown (e.g. remote link address when it is waiting for +// link address resolution), they will be unset. func (r *Route) Fields() RouteInfo { + r.mu.RLock() + defer r.mu.RUnlock() + return r.fieldsLocked() +} + +func (r *Route) fieldsLocked() RouteInfo { return RouteInfo{ routeInfo: r.routeInfo, - RemoteLinkAddress: r.RemoteLinkAddress(), + RemoteLinkAddress: r.mu.remoteLinkAddress, } } @@ -306,32 +315,43 @@ func (r *Route) ResolveWith(addr tcpip.LinkAddress) { r.mu.remoteLinkAddress = addr } -// Resolve attempts to resolve the link address if necessary. +// ResolvedFieldsResult is the result of a route resolution attempt. +type ResolvedFieldsResult struct { + RouteInfo RouteInfo + Success bool +} + +// ResolvedFields attempts to resolve the remote link address if it is not +// known. // -// Returns tcpip.ErrWouldBlock if address resolution requires blocking (e.g. -// waiting for ARP reply). If address resolution is required, a notification -// channel is also returned for the caller to block on. The channel is closed -// once address resolution is complete (successful or not). If a callback is -// provided, it will be called when address resolution is complete, regardless +// If a callback is provided, it will be called before ResolvedFields returns +// when address resolution is not required. If address resolution is required, +// the callback will be called once address resolution is complete, regardless // of success or failure. -func (r *Route) Resolve(afterResolve func()) (<-chan struct{}, *tcpip.Error) { - r.mu.Lock() - - if !r.isResolutionRequiredRLocked() { - // Nothing to do if there is no cache (which does the resolution on cache miss) or - // link address is already known. - r.mu.Unlock() - return nil, nil - } - - // Increment the route's reference count because finishResolution retains a - // reference to the route and releases it when called. - r.acquireLocked() - r.mu.Unlock() +// +// Note, the route will not cache the remote link address when address +// resolution completes. +func (r *Route) ResolvedFields(afterResolve func(ResolvedFieldsResult)) tcpip.Error { + _, _, err := r.resolvedFields(afterResolve) + return err +} - nextAddr := r.NextHop - if nextAddr == "" { - nextAddr = r.RemoteAddress +// resolvedFields is like ResolvedFields but also returns a notification channel +// when address resolution is required. This channel will become readable once +// address resolution is complete. +// +// The route's fields will also be returned, regardless of whether address +// resolution is required or not. +func (r *Route) resolvedFields(afterResolve func(ResolvedFieldsResult)) (RouteInfo, <-chan struct{}, tcpip.Error) { + r.mu.RLock() + fields := r.fieldsLocked() + resolutionRequired := r.isResolutionRequiredRLocked() + r.mu.RUnlock() + if !resolutionRequired { + if afterResolve != nil { + afterResolve(ResolvedFieldsResult{RouteInfo: fields, Success: true}) + } + return fields, nil, nil } // If specified, the local address used for link address resolution must be an @@ -341,18 +361,27 @@ func (r *Route) Resolve(afterResolve func()) (<-chan struct{}, *tcpip.Error) { linkAddressResolutionRequestLocalAddr = r.LocalAddress } - finishResolution := func(linkAddress tcpip.LinkAddress, ok bool) { - if ok { - r.ResolveWith(linkAddress) - } + afterResolveFields := fields + linkAddr, ch, err := r.outgoingNIC.getNeighborLinkAddress(r.nextHop(), linkAddressResolutionRequestLocalAddr, r.linkRes, func(r LinkResolutionResult) { if afterResolve != nil { - afterResolve() + if r.Success { + afterResolveFields.RemoteLinkAddress = r.LinkAddress + } + + afterResolve(ResolvedFieldsResult{RouteInfo: afterResolveFields, Success: r.Success}) } - r.Release() + }) + if err == nil { + fields.RemoteLinkAddress = linkAddr } + return fields, ch, err +} - _, ch, err := r.outgoingNIC.getNeighborLinkAddress(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, finishResolution) - return ch, err +func (r *Route) nextHop() tcpip.Address { + if len(r.NextHop) == 0 { + return r.RemoteAddress + } + return r.NextHop } // local returns true if the route is a local route. @@ -371,11 +400,7 @@ func (r *Route) IsResolutionRequired() bool { } func (r *Route) isResolutionRequiredRLocked() bool { - if !r.isValidForOutgoingRLocked() || r.mu.remoteLinkAddress != "" || r.local() { - return false - } - - return r.linkRes != nil + return len(r.mu.remoteLinkAddress) == 0 && r.linkRes != nil && r.isValidForOutgoingRLocked() && !r.local() } func (r *Route) isValidForOutgoing() bool { @@ -404,9 +429,9 @@ func (r *Route) isValidForOutgoingRLocked() bool { } // WritePacket writes the packet through the given route. -func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { +func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error { if !r.isValidForOutgoing() { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt) @@ -414,9 +439,9 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuf // WritePackets writes a list of n packets through the given route and returns // the number of packets written. -func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) { +func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) { if !r.isValidForOutgoing() { - return 0, tcpip.ErrInvalidEndpointState + return 0, &tcpip.ErrInvalidEndpointState{} } return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params) @@ -424,9 +449,9 @@ func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHead // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. -func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error { +func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) tcpip.Error { if !r.isValidForOutgoing() { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt) @@ -496,3 +521,12 @@ func (r *Route) IsOutboundBroadcast() bool { // Only IPv4 has a notion of broadcast. return r.isV4Broadcast(r.RemoteAddress) } + +// ConfirmReachable informs the network/link layer that the neighbour used for +// the route is reachable. +// +// "Reachable" is defined as having full-duplex communication between the +// local and remote ends of the route. +func (r *Route) ConfirmReachable() { + r.outgoingNIC.confirmReachable(r.nextHop()) +} diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index c0aec61a6..119c4c505 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -76,12 +76,16 @@ type TCPCubicState struct { // TCPRACKState is used to hold a copy of the internal RACK state when the // TCPProbeFunc is invoked. type TCPRACKState struct { - XmitTime time.Time - EndSequence seqnum.Value - FACK seqnum.Value - RTT time.Duration - Reord bool - DSACKSeen bool + XmitTime time.Time + EndSequence seqnum.Value + FACK seqnum.Value + RTT time.Duration + Reord bool + DSACKSeen bool + ReoWnd time.Duration + ReoWndIncr uint8 + ReoWndPersist int8 + RTTSeq seqnum.Value } // TCPEndpointID is the unique 4 tuple that identifies a given endpoint. @@ -382,8 +386,6 @@ type Stack struct { stats tcpip.Stats - linkAddrCache *linkAddrCache - mu sync.RWMutex nics map[tcpip.NICID]*NIC @@ -446,7 +448,7 @@ type Stack struct { // sendBufferSize holds the min/default/max send buffer sizes for // endpoints other than TCP. - sendBufferSize SendBufferSizeOption + sendBufferSize tcpip.SendBufferSizeOption // receiveBufferSize holds the min/default/max receive buffer sizes for // endpoints other than TCP. @@ -554,7 +556,7 @@ type TransportEndpointInfo struct { // incompatible with the receiver. // // Preconditon: the parent endpoint mu must be held while calling this method. -func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { +func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { netProto := t.NetProto switch len(addr.Addr) { case header.IPv4AddressSize: @@ -572,11 +574,11 @@ func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6onl switch len(t.ID.LocalAddress) { case header.IPv4AddressSize: if len(addr.Addr) == header.IPv6AddressSize { - return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState + return tcpip.FullAddress{}, 0, &tcpip.ErrInvalidEndpointState{} } case header.IPv6AddressSize: if len(addr.Addr) == header.IPv4AddressSize { - return tcpip.FullAddress{}, 0, tcpip.ErrNetworkUnreachable + return tcpip.FullAddress{}, 0, &tcpip.ErrNetworkUnreachable{} } } @@ -584,10 +586,10 @@ func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6onl case netProto == t.NetProto: case netProto == header.IPv4ProtocolNumber && t.NetProto == header.IPv6ProtocolNumber: if v6only { - return tcpip.FullAddress{}, 0, tcpip.ErrNoRoute + return tcpip.FullAddress{}, 0, &tcpip.ErrNoRoute{} } default: - return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState + return tcpip.FullAddress{}, 0, &tcpip.ErrInvalidEndpointState{} } return addr, netProto, nil @@ -636,7 +638,6 @@ func New(opts Options) *Stack { linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver), nics: make(map[tcpip.NICID]*NIC), cleanupEndpoints: make(map[TransportEndpoint]struct{}), - linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts), PortManager: ports.NewPortManager(), clock: clock, stats: opts.Stats.FillIn(), @@ -649,7 +650,7 @@ func New(opts Options) *Stack { uniqueIDGenerator: opts.UniqueID, nudDisp: opts.NUDDisp, randomGenerator: mathrand.New(randSrc), - sendBufferSize: SendBufferSizeOption{ + sendBufferSize: tcpip.SendBufferSizeOption{ Min: MinBufferSize, Default: DefaultBufferSize, Max: DefaultMaxBufferSize, @@ -701,10 +702,10 @@ func (s *Stack) UniqueID() uint64 { // options. This method returns an error if the protocol is not supported or // option is not supported by the protocol implementation or the provided value // is incorrect. -func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, option tcpip.SettableNetworkProtocolOption) *tcpip.Error { +func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, option tcpip.SettableNetworkProtocolOption) tcpip.Error { netProto, ok := s.networkProtocols[network] if !ok { - return tcpip.ErrUnknownProtocol + return &tcpip.ErrUnknownProtocol{} } return netProto.SetOption(option) } @@ -718,10 +719,10 @@ func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, op // if err != nil { // ... // } -func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, option tcpip.GettableNetworkProtocolOption) *tcpip.Error { +func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, option tcpip.GettableNetworkProtocolOption) tcpip.Error { netProto, ok := s.networkProtocols[network] if !ok { - return tcpip.ErrUnknownProtocol + return &tcpip.ErrUnknownProtocol{} } return netProto.Option(option) } @@ -730,10 +731,10 @@ func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, optio // options. This method returns an error if the protocol is not supported or // option is not supported by the protocol implementation or the provided value // is incorrect. -func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.SettableTransportProtocolOption) *tcpip.Error { +func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.SettableTransportProtocolOption) tcpip.Error { transProtoState, ok := s.transportProtocols[transport] if !ok { - return tcpip.ErrUnknownProtocol + return &tcpip.ErrUnknownProtocol{} } return transProtoState.proto.SetOption(option) } @@ -745,10 +746,10 @@ func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumb // if err := s.TransportProtocolOption(tcpip.TCPProtocolNumber, &v); err != nil { // ... // } -func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.GettableTransportProtocolOption) *tcpip.Error { +func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.GettableTransportProtocolOption) tcpip.Error { transProtoState, ok := s.transportProtocols[transport] if !ok { - return tcpip.ErrUnknownProtocol + return &tcpip.ErrUnknownProtocol{} } return transProtoState.proto.Option(option) } @@ -781,15 +782,15 @@ func (s *Stack) Stats() tcpip.Stats { // SetForwarding enables or disables packet forwarding between NICs for the // passed protocol. -func (s *Stack) SetForwarding(protocolNum tcpip.NetworkProtocolNumber, enable bool) *tcpip.Error { +func (s *Stack) SetForwarding(protocolNum tcpip.NetworkProtocolNumber, enable bool) tcpip.Error { protocol, ok := s.networkProtocols[protocolNum] if !ok { - return tcpip.ErrUnknownProtocol + return &tcpip.ErrUnknownProtocol{} } forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol) if !ok { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } forwardingProtocol.SetForwarding(enable) @@ -852,10 +853,10 @@ func (s *Stack) RemoveRoutes(match func(tcpip.Route) bool) { } // NewEndpoint creates a new transport layer endpoint of the given protocol. -func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { t, ok := s.transportProtocols[transport] if !ok { - return nil, tcpip.ErrUnknownProtocol + return nil, &tcpip.ErrUnknownProtocol{} } return t.proto.NewEndpoint(network, waiterQueue) @@ -864,9 +865,9 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp // NewRawEndpoint creates a new raw transport layer endpoint of the given // protocol. Raw endpoints receive all traffic for a given protocol regardless // of address. -func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) { +func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, tcpip.Error) { if s.rawFactory == nil { - return nil, tcpip.ErrNotPermitted + return nil, &tcpip.ErrNotPermitted{} } if !associated { @@ -875,7 +876,7 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network t, ok := s.transportProtocols[transport] if !ok { - return nil, tcpip.ErrUnknownProtocol + return nil, &tcpip.ErrUnknownProtocol{} } return t.proto.NewRawEndpoint(network, waiterQueue) @@ -883,9 +884,9 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network // NewPacketEndpoint creates a new packet endpoint listening for the given // netProto. -func (s *Stack) NewPacketEndpoint(cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (s *Stack) NewPacketEndpoint(cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { if s.rawFactory == nil { - return nil, tcpip.ErrNotPermitted + return nil, &tcpip.ErrNotPermitted{} } return s.rawFactory.NewPacketEndpoint(s, cooked, netProto, waiterQueue) @@ -916,20 +917,20 @@ type NICOptions struct { // NICs can be configured. // // LinkEndpoint.Attach will be called to bind ep with a NetworkDispatcher. -func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOptions) *tcpip.Error { +func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOptions) tcpip.Error { s.mu.Lock() defer s.mu.Unlock() // Make sure id is unique. if _, ok := s.nics[id]; ok { - return tcpip.ErrDuplicateNICID + return &tcpip.ErrDuplicateNICID{} } // Make sure name is unique, unless unnamed. if opts.Name != "" { for _, n := range s.nics { if n.Name() == opts.Name { - return tcpip.ErrDuplicateNICID + return &tcpip.ErrDuplicateNICID{} } } } @@ -945,7 +946,7 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp // CreateNIC creates a NIC with the provided id and LinkEndpoint and calls // LinkEndpoint.Attach to bind ep with a NetworkDispatcher. -func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { +func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) tcpip.Error { return s.CreateNICWithOptions(id, ep, NICOptions{}) } @@ -963,26 +964,26 @@ func (s *Stack) GetLinkEndpointByName(name string) LinkEndpoint { // EnableNIC enables the given NIC so that the link-layer endpoint can start // delivering packets to it. -func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error { +func (s *Stack) EnableNIC(id tcpip.NICID) tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() nic, ok := s.nics[id] if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } return nic.enable() } // DisableNIC disables the given NIC. -func (s *Stack) DisableNIC(id tcpip.NICID) *tcpip.Error { +func (s *Stack) DisableNIC(id tcpip.NICID) tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() nic, ok := s.nics[id] if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } nic.disable() @@ -1003,7 +1004,7 @@ func (s *Stack) CheckNIC(id tcpip.NICID) bool { } // RemoveNIC removes NIC and all related routes from the network stack. -func (s *Stack) RemoveNIC(id tcpip.NICID) *tcpip.Error { +func (s *Stack) RemoveNIC(id tcpip.NICID) tcpip.Error { s.mu.Lock() defer s.mu.Unlock() @@ -1013,10 +1014,10 @@ func (s *Stack) RemoveNIC(id tcpip.NICID) *tcpip.Error { // removeNICLocked removes NIC and all related routes from the network stack. // // s.mu must be locked. -func (s *Stack) removeNICLocked(id tcpip.NICID) *tcpip.Error { +func (s *Stack) removeNICLocked(id tcpip.NICID) tcpip.Error { nic, ok := s.nics[id] if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } delete(s.nics, id) @@ -1050,6 +1051,9 @@ type NICInfo struct { Stats NICStats + // NetworkStats holds the stats of each NetworkEndpoint bound to the NIC. + NetworkStats map[tcpip.NetworkProtocolNumber]NetworkEndpointStats + // Context is user-supplied data optionally supplied in CreateNICWithOptions. // See type NICOptions for more details. Context NICContext @@ -1081,6 +1085,12 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { Promiscuous: nic.Promiscuous(), Loopback: nic.IsLoopback(), } + + netStats := make(map[tcpip.NetworkProtocolNumber]NetworkEndpointStats) + for proto, netEP := range nic.networkEndpoints { + netStats[proto] = netEP.Stats() + } + nics[id] = NICInfo{ Name: nic.name, LinkAddress: nic.LinkEndpoint.LinkAddress(), @@ -1088,6 +1098,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { Flags: flags, MTU: nic.LinkEndpoint.MTU(), Stats: nic.stats, + NetworkStats: netStats, Context: nic.context, ARPHardwareType: nic.LinkEndpoint.ARPHardwareType(), } @@ -1111,13 +1122,13 @@ type NICStateFlags struct { } // AddAddress adds a new network-layer address to the specified NIC. -func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { +func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint) } // AddAddressWithPrefix is the same as AddAddress, but allows you to specify // the address prefix. -func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.AddressWithPrefix) *tcpip.Error { +func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.AddressWithPrefix) tcpip.Error { ap := tcpip.ProtocolAddress{ Protocol: protocol, AddressWithPrefix: addr, @@ -1127,16 +1138,16 @@ func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProto // AddProtocolAddress adds a new network-layer protocol address to the // specified NIC. -func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress) *tcpip.Error { +func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress) tcpip.Error { return s.AddProtocolAddressWithOptions(id, protocolAddress, CanBePrimaryEndpoint) } // AddAddressWithOptions is the same as AddAddress, but allows you to specify // whether the new endpoint can be primary or not. -func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) *tcpip.Error { +func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) tcpip.Error { netProto, ok := s.networkProtocols[protocol] if !ok { - return tcpip.ErrUnknownProtocol + return &tcpip.ErrUnknownProtocol{} } return s.AddProtocolAddressWithOptions(id, tcpip.ProtocolAddress{ Protocol: protocol, @@ -1149,13 +1160,13 @@ func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProt // AddProtocolAddressWithOptions is the same as AddProtocolAddress, but allows // you to specify whether the new endpoint can be primary or not. -func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error { +func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() nic, ok := s.nics[id] if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } return nic.addAddress(protocolAddress, peb) @@ -1163,7 +1174,7 @@ func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tc // RemoveAddress removes an existing network-layer address from the specified // NIC. -func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error { +func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() @@ -1171,7 +1182,7 @@ func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error { return nic.removeAddress(addr) } - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } // AllAddresses returns a map of NICIDs to their protocol addresses (primary @@ -1189,19 +1200,19 @@ func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress { // GetMainNICAddress returns the first non-deprecated primary address and prefix // for the given NIC and protocol. If no non-deprecated primary address exists, -// a deprecated primary address and prefix will be returned. Returns an error if +// a deprecated primary address and prefix will be returned. Returns false if // the NIC doesn't exist and an empty value if the NIC doesn't have a primary // address for the given protocol. -func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) { +func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, bool) { s.mu.RLock() defer s.mu.RUnlock() nic, ok := s.nics[id] if !ok { - return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID + return tcpip.AddressWithPrefix{}, false } - return nic.primaryAddress(protocol), nil + return nic.primaryAddress(protocol), true } func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) AssignableAddressEndpoint { @@ -1301,7 +1312,7 @@ func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, // If no local address is provided, the stack will select a local address. If no // remote address is provided, the stack wil use a remote address equal to the // local address. -func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (*Route, *tcpip.Error) { +func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (*Route, tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() @@ -1337,9 +1348,9 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n } if isLoopback { - return nil, tcpip.ErrBadLocalAddress + return nil, &tcpip.ErrBadLocalAddress{} } - return nil, tcpip.ErrNetworkUnreachable + return nil, &tcpip.ErrNetworkUnreachable{} } canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal @@ -1405,7 +1416,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n } } - return nil, tcpip.ErrNoRoute + return nil, &tcpip.ErrNoRoute{} } if id == 0 { @@ -1425,12 +1436,12 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n } if needRoute { - return nil, tcpip.ErrNoRoute + return nil, &tcpip.ErrNoRoute{} } if header.IsV6LoopbackAddress(remoteAddr) { - return nil, tcpip.ErrBadLocalAddress + return nil, &tcpip.ErrBadLocalAddress{} } - return nil, tcpip.ErrNetworkUnreachable + return nil, &tcpip.ErrNetworkUnreachable{} } // CheckNetworkProtocol checks if a given network protocol is enabled in the @@ -1476,13 +1487,13 @@ func (s *Stack) CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProto } // SetPromiscuousMode enables or disables promiscuous mode in the given NIC. -func (s *Stack) SetPromiscuousMode(nicID tcpip.NICID, enable bool) *tcpip.Error { +func (s *Stack) SetPromiscuousMode(nicID tcpip.NICID, enable bool) tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() nic, ok := s.nics[nicID] if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } nic.setPromiscuousMode(enable) @@ -1492,13 +1503,13 @@ func (s *Stack) SetPromiscuousMode(nicID tcpip.NICID, enable bool) *tcpip.Error // SetSpoofing enables or disables address spoofing in the given NIC, allowing // endpoints to bind to any address in the NIC. -func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error { +func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() nic, ok := s.nics[nicID] if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } nic.setSpoofing(enable) @@ -1506,17 +1517,27 @@ func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error { return nil } -// AddLinkAddress adds a link address to the stack link cache. -func (s *Stack) AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} - s.linkAddrCache.add(fullAddr, linkAddr) - // TODO: provide a way for a transport endpoint to receive a signal - // that AddLinkAddress for a particular address has been called. +// AddLinkAddress adds a link address for the neighbor on the specified NIC. +func (s *Stack) AddLinkAddress(nicID tcpip.NICID, neighbor tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error { + s.mu.RLock() + defer s.mu.RUnlock() + + nic, ok := s.nics[nicID] + if !ok { + return &tcpip.ErrUnknownNICID{} + } + + nic.linkAddrCache.AddLinkAddress(neighbor, linkAddr) + return nil +} + +// LinkResolutionResult is the result of a link address resolution attempt. +type LinkResolutionResult struct { + LinkAddress tcpip.LinkAddress + Success bool } -// GetLinkAddress finds the link address corresponding to a neighbor's address. -// -// Returns a link address for the remote address, if readily available. +// GetLinkAddress finds the link address corresponding to a network address. // // Returns ErrNotSupported if the stack is not configured with a link address // resolver for the specified network protocol. @@ -1525,53 +1546,56 @@ func (s *Stack) AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr t // with a notification channel for the caller to block on. Triggers address // resolution asynchronously. // -// If onResolve is provided, it will be called either immediately, if -// resolution is not required, or when address resolution is complete, with -// the resolved link address and whether resolution succeeded. After any -// callbacks have been called, the returned notification channel is closed. +// onResolve will be called either immediately, if resolution is not required, +// or when address resolution is complete, with the resolved link address and +// whether resolution succeeded. // // If specified, the local address must be an address local to the interface // the neighbor cache belongs to. The local address is the source address of // a packet prompting NUD/link address resolution. -// -// TODO(gvisor.dev/issue/5151): Don't return the link address. -func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { +func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(LinkResolutionResult)) tcpip.Error { s.mu.RLock() nic, ok := s.nics[nicID] s.mu.RUnlock() if !ok { - return "", nil, tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } linkRes, ok := s.linkAddrResolvers[protocol] if !ok { - return "", nil, tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} + } + + if linkAddr, ok := linkRes.ResolveStaticAddress(addr); ok { + onResolve(LinkResolutionResult{LinkAddress: linkAddr, Success: true}) + return nil } - return nic.getNeighborLinkAddress(addr, localAddr, linkRes, onResolve) + _, _, err := nic.getNeighborLinkAddress(addr, localAddr, linkRes, onResolve) + return err } // Neighbors returns all IP to MAC address associations. -func (s *Stack) Neighbors(nicID tcpip.NICID) ([]NeighborEntry, *tcpip.Error) { +func (s *Stack) Neighbors(nicID tcpip.NICID) ([]NeighborEntry, tcpip.Error) { s.mu.RLock() nic, ok := s.nics[nicID] s.mu.RUnlock() if !ok { - return nil, tcpip.ErrUnknownNICID + return nil, &tcpip.ErrUnknownNICID{} } return nic.neighbors() } // AddStaticNeighbor statically associates an IP address to a MAC address. -func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) *tcpip.Error { +func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error { s.mu.RLock() nic, ok := s.nics[nicID] s.mu.RUnlock() if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } return nic.addStaticNeighbor(addr, linkAddr) @@ -1580,26 +1604,26 @@ func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAdd // RemoveNeighbor removes an IP to MAC address association previously created // either automically or by AddStaticNeighbor. Returns ErrBadAddress if there // is no association with the provided address. -func (s *Stack) RemoveNeighbor(nicID tcpip.NICID, addr tcpip.Address) *tcpip.Error { +func (s *Stack) RemoveNeighbor(nicID tcpip.NICID, addr tcpip.Address) tcpip.Error { s.mu.RLock() nic, ok := s.nics[nicID] s.mu.RUnlock() if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } return nic.removeNeighbor(addr) } // ClearNeighbors removes all IP to MAC address associations. -func (s *Stack) ClearNeighbors(nicID tcpip.NICID) *tcpip.Error { +func (s *Stack) ClearNeighbors(nicID tcpip.NICID) tcpip.Error { s.mu.RLock() nic, ok := s.nics[nicID] s.mu.RUnlock() if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } return nic.clearNeighbors() @@ -1609,25 +1633,25 @@ func (s *Stack) ClearNeighbors(nicID tcpip.NICID) *tcpip.Error { // transport dispatcher. Received packets that match the provided id will be // delivered to the given endpoint; specifying a nic is optional, but // nic-specific IDs have precedence over global ones. -func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { +func (s *Stack) RegisterTransportEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { return s.demux.registerEndpoint(netProtos, protocol, id, ep, flags, bindToDevice) } // CheckRegisterTransportEndpoint checks if an endpoint can be registered with // the stack transport dispatcher. -func (s *Stack) CheckRegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { +func (s *Stack) CheckRegisterTransportEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { return s.demux.checkEndpoint(netProtos, protocol, id, flags, bindToDevice) } // UnregisterTransportEndpoint removes the endpoint with the given id from the // stack transport dispatcher. -func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { +func (s *Stack) UnregisterTransportEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { s.demux.unregisterEndpoint(netProtos, protocol, id, ep, flags, bindToDevice) } // StartTransportEndpointCleanup removes the endpoint with the given id from // the stack transport dispatcher. It also transitions it to the cleanup stage. -func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { +func (s *Stack) StartTransportEndpointCleanup(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { s.cleanupEndpointsMu.Lock() s.cleanupEndpoints[ep] = struct{}{} s.cleanupEndpointsMu.Unlock() @@ -1652,13 +1676,13 @@ func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, tran // RegisterRawTransportEndpoint registers the given endpoint with the stack // transport dispatcher. Received packets that match the provided transport // protocol will be delivered to the given endpoint. -func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error { +func (s *Stack) RegisterRawTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) tcpip.Error { return s.demux.registerRawEndpoint(netProto, transProto, ep) } // UnregisterRawTransportEndpoint removes the endpoint for the transport // protocol from the stack transport dispatcher. -func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) { +func (s *Stack) UnregisterRawTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) { s.demux.unregisterRawEndpoint(netProto, transProto, ep) } @@ -1762,7 +1786,7 @@ func (s *Stack) Resume() { // RegisterPacketEndpoint registers ep with the stack, causing it to receive // all traffic of the specified netProto on the given NIC. If nicID is 0, it // receives traffic from every NIC. -func (s *Stack) RegisterPacketEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error { +func (s *Stack) RegisterPacketEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) tcpip.Error { s.mu.Lock() defer s.mu.Unlock() @@ -1781,7 +1805,7 @@ func (s *Stack) RegisterPacketEndpoint(nicID tcpip.NICID, netProto tcpip.Network // Capture on a specific device. nic, ok := s.nics[nicID] if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } if err := nic.registerPacketEndpoint(netProto, ep); err != nil { return err @@ -1819,12 +1843,12 @@ func (s *Stack) unregisterPacketEndpointLocked(nicID tcpip.NICID, netProto tcpip // WritePacketToRemote writes a payload on the specified NIC using the provided // network protocol and remote link address. -func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error { +func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) tcpip.Error { s.mu.Lock() nic, ok := s.nics[nicID] s.mu.Unlock() if !ok { - return tcpip.ErrUnknownDevice + return &tcpip.ErrUnknownDevice{} } pkt := NewPacketBuffer(PacketBufferOptions{ ReserveHeaderBytes: int(nic.MaxHeaderLength()), @@ -1889,37 +1913,37 @@ func (s *Stack) RemoveTCPProbe() { } // JoinGroup joins the given multicast group on the given NIC. -func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error { +func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() if nic, ok := s.nics[nicID]; ok { return nic.joinGroup(protocol, multicastAddr) } - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } // LeaveGroup leaves the given multicast group on the given NIC. -func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error { +func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() if nic, ok := s.nics[nicID]; ok { return nic.leaveGroup(protocol, multicastAddr) } - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } // IsInGroup returns true if the NIC with ID nicID has joined the multicast // group multicastAddr. -func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool, *tcpip.Error) { +func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool, tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() if nic, ok := s.nics[nicID]; ok { return nic.isInGroup(multicastAddr), nil } - return false, tcpip.ErrUnknownNICID + return false, &tcpip.ErrUnknownNICID{} } // IPTables returns the stack's iptables. @@ -1959,26 +1983,26 @@ func (s *Stack) AllowICMPMessage() bool { // GetNetworkEndpoint returns the NetworkEndpoint with the specified protocol // number installed on the specified NIC. -func (s *Stack) GetNetworkEndpoint(nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber) (NetworkEndpoint, *tcpip.Error) { +func (s *Stack) GetNetworkEndpoint(nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber) (NetworkEndpoint, tcpip.Error) { s.mu.Lock() defer s.mu.Unlock() nic, ok := s.nics[nicID] if !ok { - return nil, tcpip.ErrUnknownNICID + return nil, &tcpip.ErrUnknownNICID{} } return nic.getNetworkEndpoint(proto), nil } // NUDConfigurations gets the per-interface NUD configurations. -func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, *tcpip.Error) { +func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, tcpip.Error) { s.mu.RLock() nic, ok := s.nics[id] s.mu.RUnlock() if !ok { - return NUDConfigurations{}, tcpip.ErrUnknownNICID + return NUDConfigurations{}, &tcpip.ErrUnknownNICID{} } return nic.nudConfigs() @@ -1988,13 +2012,13 @@ func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, *tcpip.Err // // Note, if c contains invalid NUD configuration values, it will be fixed to // use default values for the erroneous values. -func (s *Stack) SetNUDConfigurations(id tcpip.NICID, c NUDConfigurations) *tcpip.Error { +func (s *Stack) SetNUDConfigurations(id tcpip.NICID, c NUDConfigurations) tcpip.Error { s.mu.RLock() nic, ok := s.nics[id] s.mu.RUnlock() if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } return nic.setNUDConfigs(c) @@ -2036,7 +2060,7 @@ func generateRandInt64() int64 { } // FindNetworkEndpoint returns the network endpoint for the given address. -func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) (NetworkEndpoint, *tcpip.Error) { +func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) (NetworkEndpoint, tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() @@ -2048,7 +2072,7 @@ func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, addres addressEndpoint.DecRef() return nic.getNetworkEndpoint(netProto), nil } - return nil, tcpip.ErrBadAddress + return nil, &tcpip.ErrBadAddress{} } // FindNICNameFromID returns the name of the NIC for the given NICID. diff --git a/pkg/tcpip/stack/stack_options.go b/pkg/tcpip/stack/stack_options.go index 0b093e6c5..8d9b20b7e 100644 --- a/pkg/tcpip/stack/stack_options.go +++ b/pkg/tcpip/stack/stack_options.go @@ -14,7 +14,9 @@ package stack -import "gvisor.dev/gvisor/pkg/tcpip" +import ( + "gvisor.dev/gvisor/pkg/tcpip" +) const ( // MinBufferSize is the smallest size of a receive or send buffer. @@ -29,14 +31,6 @@ const ( DefaultMaxBufferSize = 4 << 20 // 4 MiB ) -// SendBufferSizeOption is used by stack.(Stack*).Option/SetOption to -// get/set the default, min and max send buffer sizes. -type SendBufferSizeOption struct { - Min int - Default int - Max int -} - // ReceiveBufferSizeOption is used by stack.(Stack*).Option/SetOption to // get/set the default, min and max receive buffer sizes. type ReceiveBufferSizeOption struct { @@ -46,17 +40,17 @@ type ReceiveBufferSizeOption struct { } // SetOption allows setting stack wide options. -func (s *Stack) SetOption(option interface{}) *tcpip.Error { +func (s *Stack) SetOption(option interface{}) tcpip.Error { switch v := option.(type) { - case SendBufferSizeOption: + case tcpip.SendBufferSizeOption: // Make sure we don't allow lowering the buffer below minimum // required for stack to work. if v.Min < MinBufferSize { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } if v.Default < v.Min || v.Default > v.Max { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } s.mu.Lock() @@ -68,11 +62,11 @@ func (s *Stack) SetOption(option interface{}) *tcpip.Error { // Make sure we don't allow lowering the buffer below minimum // required for stack to work. if v.Min < MinBufferSize { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } if v.Default < v.Min || v.Default > v.Max { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } s.mu.Lock() @@ -81,14 +75,14 @@ func (s *Stack) SetOption(option interface{}) *tcpip.Error { return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } // Option allows retrieving stack wide options. -func (s *Stack) Option(option interface{}) *tcpip.Error { +func (s *Stack) Option(option interface{}) tcpip.Error { switch v := option.(type) { - case *SendBufferSizeOption: + case *tcpip.SendBufferSizeOption: s.mu.RLock() *v = s.sendBufferSize s.mu.RUnlock() @@ -101,6 +95,6 @@ func (s *Stack) Option(option interface{}) *tcpip.Error { return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 7e935ddff..41f95811f 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -60,6 +60,15 @@ const ( protocolNumberOffset = 2 ) +func checkGetMainNICAddress(s *stack.Stack, nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber, want tcpip.AddressWithPrefix) error { + if addr, ok := s.GetMainNICAddress(nicID, proto); !ok { + return fmt.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, false), want = (_, true)", nicID, proto) + } else if addr != want { + return fmt.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, true), want = (%s, true)", nicID, proto, addr, want) + } + return nil +} + // fakeNetworkEndpoint is a network-layer protocol endpoint. It counts sent and // received packets; the counts of all endpoints are aggregated in the protocol // descriptor. @@ -81,7 +90,7 @@ type fakeNetworkEndpoint struct { dispatcher stack.TransportDispatcher } -func (f *fakeNetworkEndpoint) Enable() *tcpip.Error { +func (f *fakeNetworkEndpoint) Enable() tcpip.Error { f.mu.Lock() defer f.mu.Unlock() f.mu.enabled = true @@ -145,7 +154,7 @@ func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { return f.nic.MaxHeaderLength() + fakeNetHeaderLen } -func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { +func (*fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { return 0 } @@ -153,7 +162,7 @@ func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumbe return f.proto.Number() } -func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { +func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error { // Increment the sent packet count in the protocol descriptor. f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++ @@ -176,18 +185,30 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (*fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { panic("not implemented") } -func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { - return tcpip.ErrNotSupported +func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) tcpip.Error { + return &tcpip.ErrNotSupported{} } func (f *fakeNetworkEndpoint) Close() { f.AddressableEndpointState.Cleanup() } +// Stats implements NetworkEndpoint. +func (*fakeNetworkEndpoint) Stats() stack.NetworkEndpointStats { + return &fakeNetworkEndpointStats{} +} + +var _ stack.NetworkEndpointStats = (*fakeNetworkEndpointStats)(nil) + +type fakeNetworkEndpointStats struct{} + +// IsNetworkEndpointStats implements stack.NetworkEndpointStats. +func (*fakeNetworkEndpointStats) IsNetworkEndpointStats() {} + // fakeNetworkProtocol is a network-layer protocol descriptor. It aggregates the // number of packets sent and received via endpoints of this protocol. The index // where packets are added is given by the packet's destination address MOD 10. @@ -202,15 +223,15 @@ type fakeNetworkProtocol struct { } } -func (f *fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber { +func (*fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber { return fakeNetNumber } -func (f *fakeNetworkProtocol) MinimumPacketSize() int { +func (*fakeNetworkProtocol) MinimumPacketSize() int { return fakeNetHeaderLen } -func (f *fakeNetworkProtocol) DefaultPrefixLen() int { +func (*fakeNetworkProtocol) DefaultPrefixLen() int { return fakeDefaultPrefixLen } @@ -232,23 +253,23 @@ func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, _ stack.Li return e } -func (f *fakeNetworkProtocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error { +func (f *fakeNetworkProtocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.DefaultTTLOption: f.defaultTTL = uint8(*v) return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } -func (f *fakeNetworkProtocol) Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error { +func (f *fakeNetworkProtocol) Option(option tcpip.GettableNetworkProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.DefaultTTLOption: *v = tcpip.DefaultTTLOption(f.defaultTTL) return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } @@ -397,7 +418,7 @@ func TestNetworkReceive(t *testing.T) { } } -func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Error { +func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) tcpip.Error { r, err := s.FindRoute(0, "", addr, fakeNetNumber, false /* multicastLoop */) if err != nil { return err @@ -406,7 +427,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro return send(r, payload) } -func send(r *stack.Route, payload buffer.View) *tcpip.Error { +func send(r *stack.Route, payload buffer.View) tcpip.Error { return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: payload.ToVectorisedView(), @@ -435,14 +456,14 @@ func testSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer } } -func testFailingSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { +func testFailingSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr tcpip.Error) { t.Helper() if gotErr := send(r, payload); gotErr != wantErr { t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr) } } -func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { +func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr tcpip.Error) { t.Helper() if gotErr := sendTo(s, addr, payload); gotErr != wantErr { t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr) @@ -579,8 +600,8 @@ func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr tcpip.Address) { _, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err != tcpip.ErrNoRoute { - t.Fatalf("FindRoute returned unexpected error, got = %v, want = %s", err, tcpip.ErrNoRoute) + if _, ok := err.(*tcpip.ErrNoRoute); !ok { + t.Fatalf("FindRoute returned unexpected error, got = %v, want = %s", err, &tcpip.ErrNoRoute{}) } } @@ -628,8 +649,9 @@ func TestDisableUnknownNIC(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) - if err := s.DisableNIC(1); err != tcpip.ErrUnknownNICID { - t.Fatalf("got s.DisableNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID) + err := s.DisableNIC(1) + if _, ok := err.(*tcpip.ErrUnknownNICID); !ok { + t.Fatalf("got s.DisableNIC(1) = %v, want = %s", err, &tcpip.ErrUnknownNICID{}) } } @@ -687,8 +709,9 @@ func TestRemoveUnknownNIC(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) - if err := s.RemoveNIC(1); err != tcpip.ErrUnknownNICID { - t.Fatalf("got s.RemoveNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID) + err := s.RemoveNIC(1) + if _, ok := err.(*tcpip.ErrUnknownNICID); !ok { + t.Fatalf("got s.RemoveNIC(1) = %v, want = %s", err, &tcpip.ErrUnknownNICID{}) } } @@ -731,8 +754,8 @@ func TestRemoveNIC(t *testing.T) { func TestRouteWithDownNIC(t *testing.T) { tests := []struct { name string - downFn func(s *stack.Stack, nicID tcpip.NICID) *tcpip.Error - upFn func(s *stack.Stack, nicID tcpip.NICID) *tcpip.Error + downFn func(s *stack.Stack, nicID tcpip.NICID) tcpip.Error + upFn func(s *stack.Stack, nicID tcpip.NICID) tcpip.Error }{ { name: "Disabled NIC", @@ -890,15 +913,15 @@ func TestRouteWithDownNIC(t *testing.T) { if err := test.downFn(s, nicID1); err != nil { t.Fatalf("test.downFn(_, %d): %s", nicID1, err) } - testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState) + testFailingSend(t, r1, ep1, buf, &tcpip.ErrInvalidEndpointState{}) testSend(t, r2, ep2, buf) // Writes with Routes that use NIC2 after being brought down should fail. if err := test.downFn(s, nicID2); err != nil { t.Fatalf("test.downFn(_, %d): %s", nicID2, err) } - testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState) - testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState) + testFailingSend(t, r1, ep1, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r2, ep2, buf, &tcpip.ErrInvalidEndpointState{}) if upFn := test.upFn; upFn != nil { // Writes with Routes that use NIC1 after being brought up should @@ -911,7 +934,7 @@ func TestRouteWithDownNIC(t *testing.T) { t.Fatalf("test.upFn(_, %d): %s", nicID1, err) } testSend(t, r1, ep1, buf) - testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState) + testFailingSend(t, r2, ep2, buf, &tcpip.ErrInvalidEndpointState{}) } }) } @@ -1036,11 +1059,12 @@ func TestAddressRemoval(t *testing.T) { t.Fatal("RemoveAddress failed:", err) } testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) // Check that removing the same address fails. - if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress { - t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress) + err := s.RemoveAddress(1, localAddr) + if _, ok := err.(*tcpip.ErrBadLocalAddress); !ok { + t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, &tcpip.ErrBadLocalAddress{}) } } @@ -1087,12 +1111,15 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { t.Fatal("RemoveAddress failed:", err) } testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState) - testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) + testFailingSend(t, r, ep, nil, &tcpip.ErrInvalidEndpointState{}) + testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) // Check that removing the same address fails. - if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress { - t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress) + { + err := s.RemoveAddress(1, localAddr) + if _, ok := err.(*tcpip.ErrBadLocalAddress); !ok { + t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, &tcpip.ErrBadLocalAddress{}) + } } } @@ -1186,7 +1213,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) } // 2. Add Address, everything should work. @@ -1214,7 +1241,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) } // 4. Add Address back, everything should work again. @@ -1253,8 +1280,8 @@ func TestEndpointExpiration(t *testing.T) { testSend(t, r, ep, nil) testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState) - testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) + testFailingSend(t, r, ep, nil, &tcpip.ErrInvalidEndpointState{}) + testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) } // 7. Add Address back, everything should work again. @@ -1290,7 +1317,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) } }) } @@ -1333,8 +1360,8 @@ func TestPromiscuousMode(t *testing.T) { // Check that we can't get a route as there is no local address. _, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */) - if err != tcpip.ErrNoRoute { - t.Fatalf("FindRoute returned unexpected error: got = %v, want = %s", err, tcpip.ErrNoRoute) + if _, ok := err.(*tcpip.ErrNoRoute); !ok { + t.Fatalf("FindRoute returned unexpected error: got = %v, want = %s", err, &tcpip.ErrNoRoute{}) } // Set promiscuous mode to false, then check that packet can't be @@ -1540,7 +1567,7 @@ func TestSpoofingNoAddress(t *testing.T) { t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) } // Sending a packet fails. - testFailingSendTo(t, s, dstAddr, ep, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, dstAddr, ep, nil, &tcpip.ErrNoRoute{}) // With address spoofing enabled, FindRoute permits any address to be used // as the source. @@ -1590,8 +1617,11 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { s.SetRouteTable([]tcpip.Route{}) // If there is no endpoint, it won't work. - if _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { - t.Fatalf("got FindRoute(1, %s, %s, %d) = %s, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) + { + _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) + if _, ok := err.(*tcpip.ErrNetworkUnreachable); !ok { + t.Fatalf("got FindRoute(1, %s, %s, %d) = %s, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, &tcpip.ErrNetworkUnreachable{}) + } } protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}} @@ -1610,8 +1640,11 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { } // If the NIC doesn't exist, it won't work. - if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { - t.Fatalf("got FindRoute(2, %v, %v, %d) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) + { + _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) + if _, ok := err.(*tcpip.ErrNetworkUnreachable); !ok { + t.Fatalf("got FindRoute(2, %v, %v, %d) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, &tcpip.ErrNetworkUnreachable{}) + } } } @@ -1753,9 +1786,9 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { anyAddr = header.IPv6Any } - want := tcpip.ErrNetworkUnreachable + var want tcpip.Error = &tcpip.ErrNetworkUnreachable{} if tc.routeNeeded { - want = tcpip.ErrNoRoute + want = &tcpip.ErrNoRoute{} } // If there is no endpoint, it won't work. @@ -1769,8 +1802,8 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { if r, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); tc.routeNeeded { // Route table is empty but we need a route, this should cause an error. - if err != tcpip.ErrNoRoute { - t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, tcpip.ErrNoRoute) + if _, ok := err.(*tcpip.ErrNoRoute); !ok { + t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, &tcpip.ErrNoRoute{}) } } else { if err != nil { @@ -1861,20 +1894,20 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { // Check that GetMainNICAddress returns an address if at least // one primary address was added. In that case make sure the // address/prefixLen matches what we added. - gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber) - if err != nil { - t.Fatal("GetMainNICAddress failed:", err) + gotAddr, ok := s.GetMainNICAddress(1, fakeNetNumber) + if !ok { + t.Fatalf("got GetMainNICAddress(1, %d) = (_, false), want = (_, true)", fakeNetNumber) } if len(primaryAddrAdded) == 0 { // No primary addresses present. if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr { - t.Fatalf("GetMainNICAddress: got addr = %s, want = %s", gotAddr, wantAddr) + t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, gotAddr, wantAddr) } } else { // At least one primary address was added, verify the returned // address is in the list of primary addresses we added. if _, ok := primaryAddrAdded[gotAddr]; !ok { - t.Fatalf("GetMainNICAddress: got = %s, want any in {%v}", gotAddr, primaryAddrAdded) + t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, gotAddr, primaryAddrAdded) } } }) @@ -1915,12 +1948,8 @@ func TestGetMainNICAddressAddRemove(t *testing.T) { } // Check that we get the right initial address and prefix length. - gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber) - if err != nil { - t.Fatal("GetMainNICAddress failed:", err) - } - if wantAddr := protocolAddress.AddressWithPrefix; gotAddr != wantAddr { - t.Fatalf("got s.GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr) + if err := checkGetMainNICAddress(s, 1, fakeNetNumber, protocolAddress.AddressWithPrefix); err != nil { + t.Fatal(err) } if err := s.RemoveAddress(1, protocolAddress.AddressWithPrefix.Address); err != nil { @@ -1928,12 +1957,8 @@ func TestGetMainNICAddressAddRemove(t *testing.T) { } // Check that we get no address after removal. - gotAddr, err = s.GetMainNICAddress(1, fakeNetNumber) - if err != nil { - t.Fatal("GetMainNICAddress failed:", err) - } - if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr { - t.Fatalf("got GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr) + if err := checkGetMainNICAddress(s, 1, fakeNetNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } }) } @@ -2102,7 +2127,7 @@ func TestCreateNICWithOptions(t *testing.T) { type callArgsAndExpect struct { nicID tcpip.NICID opts stack.NICOptions - err *tcpip.Error + err tcpip.Error } tests := []struct { @@ -2120,7 +2145,7 @@ func TestCreateNICWithOptions(t *testing.T) { { nicID: tcpip.NICID(1), opts: stack.NICOptions{Name: "eth2"}, - err: tcpip.ErrDuplicateNICID, + err: &tcpip.ErrDuplicateNICID{}, }, }, }, @@ -2135,7 +2160,7 @@ func TestCreateNICWithOptions(t *testing.T) { { nicID: tcpip.NICID(2), opts: stack.NICOptions{Name: "lo"}, - err: tcpip.ErrDuplicateNICID, + err: &tcpip.ErrDuplicateNICID{}, }, }, }, @@ -2165,7 +2190,7 @@ func TestCreateNICWithOptions(t *testing.T) { { nicID: tcpip.NICID(1), opts: stack.NICOptions{}, - err: tcpip.ErrDuplicateNICID, + err: &tcpip.ErrDuplicateNICID{}, }, }, }, @@ -2474,12 +2499,12 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { } } - gotMainAddr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + // Check that we get no address after removal. + if err := checkGetMainNICAddress(s, 1, fakeNetNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } - if gotMainAddr != expectedMainAddr { - t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", gotMainAddr, expectedMainAddr) + if err := checkGetMainNICAddress(s, 1, header.IPv6ProtocolNumber, expectedMainAddr); err != nil { + t.Fatal(err) } }) } @@ -2525,12 +2550,8 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) { t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err) } - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Errorf("got stack.GetMainNICAddress(%d, _) = %s, want = %s", nicID, addr, want) + if err := checkGetMainNICAddress(s, 1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } }) } @@ -2561,12 +2582,8 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { // Address should not be considered bound to the // NIC yet (DAD ongoing). - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } linkLocalAddr := header.LinkLocalAddr(linkAddr1) @@ -2584,12 +2601,8 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } - addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); err != nil { + t.Fatal(err) } } @@ -2621,17 +2634,17 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", pi); err != nil { t.Fatal("AddAddressWithOptions failed:", err) } - addr, err := s.GetMainNICAddress(1, fakeNetNumber) - if err != nil { - t.Fatal("s.GetMainNICAddress failed:", err) + addr, ok := s.GetMainNICAddress(1, fakeNetNumber) + if !ok { + t.Fatalf("GetMainNICAddress(1, %d) = (_, false), want = (_, true)", fakeNetNumber) } if pi == stack.NeverPrimaryEndpoint { if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got GetMainNICAddress = %s, want = %s", addr, want) + t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, addr, want) } } else if addr.Address != "\x01" { - t.Fatalf("got GetMainNICAddress = %s, want = 1", addr.Address) + t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (1, true)", fakeNetNumber, addr.Address) } { @@ -2710,18 +2723,17 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { if err := s.RemoveAddress(1, "\x03"); err != nil { t.Fatalf("RemoveAddress failed: %v", err) } - addr, err = s.GetMainNICAddress(1, fakeNetNumber) - if err != nil { - t.Fatalf("s.GetMainNICAddress failed: %v", err) + addr, ok = s.GetMainNICAddress(1, fakeNetNumber) + if !ok { + t.Fatalf("got GetMainNICAddress(1, %d) = (_, false), want = (_, true)", fakeNetNumber) } if ps == stack.NeverPrimaryEndpoint { if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got GetMainNICAddress = %s, want = %s", addr, want) - + t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, addr, want) } } else { if addr.Address != "\x01" { - t.Fatalf("got GetMainNICAddress = %s, want = 1", addr.Address) + t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (1, true)", fakeNetNumber, addr.Address) } } }) @@ -3247,12 +3259,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) { } // Address should be tentative so it should not be a main address. - got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); got != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } // Enabling the NIC should start DAD for the address. @@ -3264,12 +3272,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) { } // Address should not be considered bound to the NIC yet (DAD ongoing). - got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); got != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, want) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) } // Wait for DAD to resolve. @@ -3284,12 +3288,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) { if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) } - got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if got != addr.AddressWithPrefix { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr.AddressWithPrefix); err != nil { + t.Fatal(err) } // Enabling the NIC again should be a no-op. @@ -3299,12 +3299,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) { if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) } - got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if got != addr.AddressWithPrefix { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix) + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr.AddressWithPrefix); err != nil { + t.Fatal(err) } } @@ -3313,14 +3309,14 @@ func TestStackReceiveBufferSizeOption(t *testing.T) { testCases := []struct { name string rs stack.ReceiveBufferSizeOption - err *tcpip.Error + err tcpip.Error }{ // Invalid configurations. - {"min_below_zero", stack.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, - {"min_zero", stack.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, - {"default_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, - {"default_above_max", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue}, - {"max_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, + {"min_below_zero", stack.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, + {"min_zero", stack.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, + {"default_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, + {"default_above_max", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, + {"max_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, // Valid Configurations {"in_ascending_order", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, @@ -3352,31 +3348,32 @@ func TestStackSendBufferSizeOption(t *testing.T) { const sMin = stack.MinBufferSize testCases := []struct { name string - ss stack.SendBufferSizeOption - err *tcpip.Error + ss tcpip.SendBufferSizeOption + err tcpip.Error }{ // Invalid configurations. - {"min_below_zero", stack.SendBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, - {"min_zero", stack.SendBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, - {"default_below_min", stack.SendBufferSizeOption{Min: 0, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, - {"default_above_max", stack.SendBufferSizeOption{Min: 0, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue}, - {"max_below_min", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, + {"min_below_zero", tcpip.SendBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, + {"min_zero", tcpip.SendBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, + {"default_below_min", tcpip.SendBufferSizeOption{Min: 0, Default: sMin - 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, + {"default_above_max", tcpip.SendBufferSizeOption{Min: 0, Default: sMin + 1, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, + {"max_below_min", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, // Valid Configurations - {"in_ascending_order", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, - {"all_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil}, - {"min_default_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil}, - {"default_max_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil}, + {"in_ascending_order", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, + {"all_equal", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil}, + {"min_default_equal", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil}, + {"default_max_equal", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { s := stack.New(stack.Options{}) defer s.Close() - if err := s.SetOption(tc.ss); err != tc.err { - t.Fatalf("s.SetOption(%+v) = %v, want: %v", tc.ss, err, tc.err) + err := s.SetOption(tc.ss) + if diff := cmp.Diff(tc.err, err); diff != "" { + t.Fatalf("unexpected error from s.SetOption(%+v), (-want, +got):\n%s", tc.ss, diff) } - var ss stack.SendBufferSizeOption if tc.err == nil { + var ss tcpip.SendBufferSizeOption if err := s.Option(&ss); err != nil { t.Fatalf("s.Option(%+v) = %v, want: nil", ss, err) } @@ -3778,20 +3775,16 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) { } // Check that we get the right initial address and prefix length. - if gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber); err != nil { - t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err) - } else if gotAddr != protocolAddress.AddressWithPrefix { - t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix) + if err := checkGetMainNICAddress(s, nicID, fakeNetNumber, protocolAddress.AddressWithPrefix); err != nil { + t.Fatal(err) } // Should still get the address when the NIC is diabled. if err := s.DisableNIC(nicID); err != nil { t.Fatalf("DisableNIC(%d): %s", nicID, err) } - if gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber); err != nil { - t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err) - } else if gotAddr != protocolAddress.AddressWithPrefix { - t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix) + if err := checkGetMainNICAddress(s, nicID, fakeNetNumber, protocolAddress.AddressWithPrefix); err != nil { + t.Fatal(err) } } @@ -3939,7 +3932,7 @@ func TestFindRouteWithForwarding(t *testing.T) { addrNIC tcpip.NICID localAddr tcpip.Address - findRouteErr *tcpip.Error + findRouteErr tcpip.Error dependentOnForwarding bool }{ { @@ -3948,7 +3941,7 @@ func TestFindRouteWithForwarding(t *testing.T) { forwardingEnabled: false, addrNIC: nicID1, localAddr: fakeNetCfg.nic2Addr, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -3957,7 +3950,7 @@ func TestFindRouteWithForwarding(t *testing.T) { forwardingEnabled: true, addrNIC: nicID1, localAddr: fakeNetCfg.nic2Addr, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -3966,7 +3959,7 @@ func TestFindRouteWithForwarding(t *testing.T) { forwardingEnabled: false, addrNIC: nicID1, localAddr: fakeNetCfg.nic1Addr, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -4002,7 +3995,7 @@ func TestFindRouteWithForwarding(t *testing.T) { forwardingEnabled: false, addrNIC: nicID2, localAddr: fakeNetCfg.nic1Addr, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -4011,7 +4004,7 @@ func TestFindRouteWithForwarding(t *testing.T) { forwardingEnabled: true, addrNIC: nicID2, localAddr: fakeNetCfg.nic1Addr, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -4035,7 +4028,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: false, localAddr: fakeNetCfg.nic1Addr, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -4051,7 +4044,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: ipv6LinkLocalNIC1WithGlobalRemote, forwardingEnabled: false, addrNIC: nicID1, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -4059,7 +4052,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: ipv6LinkLocalNIC1WithGlobalRemote, forwardingEnabled: true, addrNIC: nicID1, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -4067,7 +4060,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: ipv6LinkLocalNIC1WithGlobalRemote, forwardingEnabled: false, localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -4075,7 +4068,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: ipv6LinkLocalNIC1WithGlobalRemote, forwardingEnabled: true, localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -4107,7 +4100,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: ipv6GlobalNIC1WithLinkLocalRemote, forwardingEnabled: false, localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr, - findRouteErr: tcpip.ErrNetworkUnreachable, + findRouteErr: &tcpip.ErrNetworkUnreachable{}, dependentOnForwarding: false, }, { @@ -4115,7 +4108,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: ipv6GlobalNIC1WithLinkLocalRemote, forwardingEnabled: true, localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr, - findRouteErr: tcpip.ErrNetworkUnreachable, + findRouteErr: &tcpip.ErrNetworkUnreachable{}, dependentOnForwarding: false, }, { @@ -4123,7 +4116,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, forwardingEnabled: false, localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr, - findRouteErr: tcpip.ErrNetworkUnreachable, + findRouteErr: &tcpip.ErrNetworkUnreachable{}, dependentOnForwarding: false, }, { @@ -4131,7 +4124,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, forwardingEnabled: true, localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr, - findRouteErr: tcpip.ErrNetworkUnreachable, + findRouteErr: &tcpip.ErrNetworkUnreachable{}, dependentOnForwarding: false, }, { @@ -4186,8 +4179,8 @@ func TestFindRouteWithForwarding(t *testing.T) { if r != nil { defer r.Release() } - if err != test.findRouteErr { - t.Fatalf("FindRoute(%d, %s, %s, %d, false) = %s, want = %s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, err, test.findRouteErr) + if diff := cmp.Diff(test.findRouteErr, err); diff != "" { + t.Fatalf("unexpected error from FindRoute(%d, %s, %s, %d, false), (-want, +got):\n%s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, diff) } if test.findRouteErr != nil { @@ -4234,8 +4227,11 @@ func TestFindRouteWithForwarding(t *testing.T) { if err := s.SetForwarding(test.netCfg.proto, false); err != nil { t.Fatalf("SetForwarding(%d, false): %s", test.netCfg.proto, err) } - if err := send(r, data); err != tcpip.ErrInvalidEndpointState { - t.Fatalf("got send(_, _) = %s, want = %s", err, tcpip.ErrInvalidEndpointState) + { + err := send(r, data) + if _, ok := err.(*tcpip.ErrInvalidEndpointState); !ok { + t.Fatalf("got send(_, _) = %s, want = %s", err, &tcpip.ErrInvalidEndpointState{}) + } } if n := ep1.Drain(); n != 0 { t.Errorf("got %d unexpected packets from ep1", n) @@ -4297,8 +4293,9 @@ func TestWritePacketToRemote(t *testing.T) { } t.Run("InvalidNICID", func(t *testing.T) { - if got, want := s.WritePacketToRemote(234, linkAddr2, header.IPv4ProtocolNumber, buffer.View([]byte{1}).ToVectorisedView()), tcpip.ErrUnknownDevice; got != want { - t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s, want = %s", got, want) + err := s.WritePacketToRemote(234, linkAddr2, header.IPv4ProtocolNumber, buffer.View([]byte{1}).ToVectorisedView()) + if _, ok := err.(*tcpip.ErrUnknownDevice); !ok { + t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s, want = %s", err, &tcpip.ErrUnknownDevice{}) } pkt, ok := e.Read() if got, want := ok, false; got != want { @@ -4372,10 +4369,64 @@ func TestGetLinkAddressErrors(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if addr, _, err := s.GetLinkAddress(unknownNICID, "", "", ipv4.ProtocolNumber, nil); err != tcpip.ErrUnknownNICID { - t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = (%s, _, %s), want = (_, _, %s)", unknownNICID, ipv4.ProtocolNumber, addr, err, tcpip.ErrUnknownNICID) + { + err := s.GetLinkAddress(unknownNICID, "", "", ipv4.ProtocolNumber, nil) + if _, ok := err.(*tcpip.ErrUnknownNICID); !ok { + t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = %s, want = %s", unknownNICID, ipv4.ProtocolNumber, err, &tcpip.ErrUnknownNICID{}) + } + } + { + err := s.GetLinkAddress(nicID, "", "", ipv4.ProtocolNumber, nil) + if _, ok := err.(*tcpip.ErrNotSupported); !ok { + t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = %s, want = %s", unknownNICID, ipv4.ProtocolNumber, err, &tcpip.ErrNotSupported{}) + } + } +} + +func TestStaticGetLinkAddress(t *testing.T) { + const ( + nicID = 1 + ) + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + }) + if err := s.CreateNIC(nicID, channel.New(0, 0, "")); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + tests := []struct { + name string + proto tcpip.NetworkProtocolNumber + addr tcpip.Address + expectedLinkAddr tcpip.LinkAddress + }{ + { + name: "IPv4", + proto: ipv4.ProtocolNumber, + addr: header.IPv4Broadcast, + expectedLinkAddr: header.EthernetBroadcastAddress, + }, + { + name: "IPv6", + proto: ipv6.ProtocolNumber, + addr: header.IPv6AllNodesMulticastAddress, + expectedLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllNodesMulticastAddress), + }, } - if addr, _, err := s.GetLinkAddress(nicID, "", "", ipv4.ProtocolNumber, nil); err != tcpip.ErrNotSupported { - t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = (%s, _, %s), want = (_, _, %s)", unknownNICID, ipv4.ProtocolNumber, addr, err, tcpip.ErrNotSupported) + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ch := make(chan stack.LinkResolutionResult, 1) + if err := s.GetLinkAddress(nicID, test.addr, "", test.proto, func(r stack.LinkResolutionResult) { + ch <- r + }); err != nil { + t.Fatalf("s.GetLinkAddress(%d, %s, '', %d, _): %s", nicID, test.addr, test.proto, err) + } + + if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Success: true}, <-ch); diff != "" { + t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff) + } + }) } } diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 07b2818d2..26eceb804 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -205,7 +205,7 @@ func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpoint // registerEndpoint returns true if it succeeds. It fails and returns // false if ep already has an element with the same key. -func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { +func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { epsByNIC.mu.Lock() defer epsByNIC.mu.Unlock() @@ -222,7 +222,7 @@ func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto t return multiPortEp.singleRegisterEndpoint(t, flags) } -func (epsByNIC *endpointsByNIC) checkEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { +func (epsByNIC *endpointsByNIC) checkEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { epsByNIC.mu.RLock() defer epsByNIC.mu.RUnlock() @@ -294,7 +294,7 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer { // registerEndpoint registers the given endpoint with the dispatcher such that // packets that match the endpoint ID are delivered to it. -func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { +func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { for i, n := range netProtos { if err := d.singleRegisterEndpoint(n, protocol, id, ep, flags, bindToDevice); err != nil { d.unregisterEndpoint(netProtos[:i], protocol, id, ep, flags, bindToDevice) @@ -306,7 +306,7 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum } // checkEndpoint checks if an endpoint can be registered with the dispatcher. -func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { +func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { for _, n := range netProtos { if err := d.singleCheckEndpoint(n, protocol, id, flags, bindToDevice); err != nil { return err @@ -403,7 +403,7 @@ func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *Packet // singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint // list. The list might be empty already. -func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags ports.Flags) *tcpip.Error { +func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags ports.Flags) tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() @@ -412,7 +412,7 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags p if len(ep.endpoints) != 0 { // If it was previously bound, we need to check if we can bind again. if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 { - return tcpip.ErrPortInUse + return &tcpip.ErrPortInUse{} } } @@ -422,7 +422,7 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags p return nil } -func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) *tcpip.Error { +func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) tcpip.Error { ep.mu.RLock() defer ep.mu.RUnlock() @@ -431,7 +431,7 @@ func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) *tcpip.Error if len(ep.endpoints) != 0 { // If it was previously bound, we need to check if we can bind again. if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 { - return tcpip.ErrPortInUse + return &tcpip.ErrPortInUse{} } } @@ -456,7 +456,7 @@ func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint, flags ports return len(ep.endpoints) == 0 } -func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { +func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { if id.RemotePort != 0 { // SO_REUSEPORT only applies to bound/listening endpoints. flags.LoadBalanced = false @@ -464,7 +464,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol eps, ok := d.protocol[protocolIDs{netProto, protocol}] if !ok { - return tcpip.ErrUnknownProtocol + return &tcpip.ErrUnknownProtocol{} } eps.mu.Lock() @@ -482,7 +482,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol return epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice) } -func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { +func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { if id.RemotePort != 0 { // SO_REUSEPORT only applies to bound/listening endpoints. flags.LoadBalanced = false @@ -490,7 +490,7 @@ func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNum eps, ok := d.protocol[protocolIDs{netProto, protocol}] if !ok { - return tcpip.ErrUnknownProtocol + return &tcpip.ErrUnknownProtocol{} } eps.mu.RLock() @@ -649,10 +649,10 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN // that packets of the appropriate protocol are delivered to it. A single // packet can be sent to one or more raw endpoints along with a non-raw // endpoint. -func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error { +func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) tcpip.Error { eps, ok := d.protocol[protocolIDs{netProto, transProto}] if !ok { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } eps.mu.Lock() diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 57e1f8354..10cbbe589 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -175,9 +175,9 @@ func TestTransportDemuxerRegister(t *testing.T) { for _, test := range []struct { name string proto tcpip.NetworkProtocolNumber - want *tcpip.Error + want tcpip.Error }{ - {"failure", ipv6.ProtocolNumber, tcpip.ErrUnknownProtocol}, + {"failure", ipv6.ProtocolNumber, &tcpip.ErrUnknownProtocol{}}, {"success", ipv4.ProtocolNumber, nil}, } { t.Run(test.name, func(t *testing.T) { @@ -194,7 +194,7 @@ func TestTransportDemuxerRegister(t *testing.T) { if !ok { t.Fatalf("%T does not implement stack.TransportEndpoint", ep) } - if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, ports.Flags{}, 0), test.want; got != want { + if got, want := s.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, ports.Flags{}, 0), test.want; got != want { t.Fatalf("s.RegisterTransportEndpoint(...) = %s, want %s", got, want) } }) @@ -294,7 +294,7 @@ func TestBindToDeviceDistribution(t *testing.T) { defer wq.EventUnregister(&we) defer close(ch) - var err *tcpip.Error + var err tcpip.Error ep, err := c.s.NewEndpoint(udp.ProtocolNumber, netProtoNum, &wq) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 9d39533a1..cf5de747b 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -15,6 +15,7 @@ package stack_test import ( + "bytes" "io" "testing" @@ -67,9 +68,9 @@ func (f *fakeTransportEndpoint) SocketOptions() *tcpip.SocketOptions { return &f.ops } -func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { - ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} - ep.ops.InitHandler(ep) +func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, s *stack.Stack) tcpip.Endpoint { + ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: s.UniqueID()} + ep.ops.InitHandler(ep, s, tcpip.GetStackSendBufferLimits) return ep } @@ -86,19 +87,20 @@ func (*fakeTransportEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask return mask } -func (*fakeTransportEndpoint) Read(io.Writer, tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { +func (*fakeTransportEndpoint) Read(io.Writer, tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { return tcpip.ReadResult{}, nil } -func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { if len(f.route.RemoteAddress) == 0 { - return 0, tcpip.ErrNoRoute + return 0, &tcpip.ErrNoRoute{} } - v, err := p.FullPayload() - if err != nil { - return 0, err + v := make([]byte, p.Len()) + if _, err := io.ReadFull(p, v); err != nil { + return 0, &tcpip.ErrBadBuffer{} } + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(f.route.MaxHeaderLength()) + fakeTransHeaderLen, Data: buffer.View(v).ToVectorisedView(), @@ -112,42 +114,42 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions } // SetSockOpt sets a socket option. Currently not supported. -func (*fakeTransportEndpoint) SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error { - return tcpip.ErrInvalidEndpointState +func (*fakeTransportEndpoint) SetSockOpt(tcpip.SettableSocketOption) tcpip.Error { + return &tcpip.ErrInvalidEndpointState{} } // SetSockOptInt sets a socket option. Currently not supported. -func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) *tcpip.Error { - return tcpip.ErrInvalidEndpointState +func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error { + return &tcpip.ErrInvalidEndpointState{} } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { - return -1, tcpip.ErrUnknownProtocolOption +func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { + return -1, &tcpip.ErrUnknownProtocolOption{} } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (*fakeTransportEndpoint) GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error { - return tcpip.ErrInvalidEndpointState +func (*fakeTransportEndpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error { + return &tcpip.ErrInvalidEndpointState{} } // Disconnect implements tcpip.Endpoint.Disconnect. -func (*fakeTransportEndpoint) Disconnect() *tcpip.Error { - return tcpip.ErrNotSupported +func (*fakeTransportEndpoint) Disconnect() tcpip.Error { + return &tcpip.ErrNotSupported{} } -func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { +func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) tcpip.Error { f.peerAddr = addr.Addr // Find the route. r, err := f.proto.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber, false /* multicastLoop */) if err != nil { - return tcpip.ErrNoRoute + return &tcpip.ErrNoRoute{} } // Try to register so that we can start receiving packets. f.ID.RemoteAddress = addr.Addr - err = f.proto.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */) + err = f.proto.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */) if err != nil { r.Release() return err @@ -162,22 +164,22 @@ func (f *fakeTransportEndpoint) UniqueID() uint64 { return f.uniqueID } -func (*fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error { +func (*fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) tcpip.Error { return nil } -func (*fakeTransportEndpoint) Shutdown(tcpip.ShutdownFlags) *tcpip.Error { +func (*fakeTransportEndpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error { return nil } func (*fakeTransportEndpoint) Reset() { } -func (*fakeTransportEndpoint) Listen(int) *tcpip.Error { +func (*fakeTransportEndpoint) Listen(int) tcpip.Error { return nil } -func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) { if len(f.acceptQueue) == 0 { return nil, nil, nil } @@ -186,9 +188,8 @@ func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *wai return a, nil, nil } -func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { +func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) tcpip.Error { if err := f.proto.stack.RegisterTransportEndpoint( - a.NIC, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, stack.TransportEndpointID{LocalAddress: a.Addr}, @@ -202,11 +203,11 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { return nil } -func (*fakeTransportEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*fakeTransportEndpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { return tcpip.FullAddress{}, nil } -func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { return tcpip.FullAddress{}, nil } @@ -232,7 +233,7 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt * peerAddr: route.RemoteAddress, route: route, } - ep.ops.InitHandler(ep) + ep.ops.InitHandler(ep, f.proto.stack, tcpip.GetStackSendBufferLimits) f.acceptQueue = append(f.acceptQueue, ep) } @@ -251,7 +252,7 @@ func (*fakeTransportEndpoint) Resume(*stack.Stack) {} func (*fakeTransportEndpoint) Wait() {} -func (*fakeTransportEndpoint) LastError() *tcpip.Error { +func (*fakeTransportEndpoint) LastError() tcpip.Error { return nil } @@ -279,19 +280,19 @@ func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber { return fakeTransNumber } -func (f *fakeTransportProtocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return newFakeTransportEndpoint(f, netProto, f.stack.UniqueID()), nil +func (f *fakeTransportProtocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { + return newFakeTransportEndpoint(f, netProto, f.stack), nil } -func (*fakeTransportProtocol) NewRawEndpoint(tcpip.NetworkProtocolNumber, *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return nil, tcpip.ErrUnknownProtocol +func (*fakeTransportProtocol) NewRawEndpoint(tcpip.NetworkProtocolNumber, *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { + return nil, &tcpip.ErrUnknownProtocol{} } func (*fakeTransportProtocol) MinimumPacketSize() int { return fakeTransHeaderLen } -func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcpip.Error) { +func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err tcpip.Error) { return 0, 0, nil } @@ -299,23 +300,23 @@ func (*fakeTransportProtocol) HandleUnknownDestinationPacket(stack.TransportEndp return stack.UnknownDestinationPacketHandled } -func (f *fakeTransportProtocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error { +func (f *fakeTransportProtocol) SetOption(option tcpip.SettableTransportProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.TCPModerateReceiveBufferOption: f.opts.good = bool(*v) return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } -func (f *fakeTransportProtocol) Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error { +func (f *fakeTransportProtocol) Option(option tcpip.GettableTransportProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.TCPModerateReceiveBufferOption: *v = tcpip.TCPModerateReceiveBufferOption(f.opts.good) return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } @@ -520,8 +521,10 @@ func TestTransportSend(t *testing.T) { } // Create buffer that will hold the payload. - view := buffer.NewView(30) - if _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + b := make([]byte, 30) + var r bytes.Reader + r.Reset(b) + if _, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("write failed: %v", err) } |