diff options
Diffstat (limited to 'pkg/tcpip/stack')
25 files changed, 1647 insertions, 1879 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index d09ebe7fa..bb30556cf 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test", "most_shards") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) @@ -112,7 +112,7 @@ go_test( "transport_demuxer_test.go", "transport_test.go", ], - shard_count = 20, + shard_count = most_shards, deps = [ ":stack", "//pkg/rand", @@ -120,6 +120,7 @@ go_test( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", @@ -131,7 +132,6 @@ go_test( "//pkg/tcpip/transport/udp", "//pkg/waiter", "@com_github_google_go_cmp//cmp:go_default_library", - "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], ) @@ -148,7 +148,6 @@ go_test( ], library = ":stack", deps = [ - "//pkg/sleep", "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index 9478f3fb7..cd423bf71 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -21,7 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" ) -var _ GroupAddressableEndpoint = (*AddressableEndpointState)(nil) var _ AddressableEndpoint = (*AddressableEndpointState)(nil) // AddressableEndpointState is an implementation of an AddressableEndpoint. @@ -37,10 +36,6 @@ type AddressableEndpointState struct { endpoints map[tcpip.Address]*addressState primary []*addressState - - // groups holds the mapping between group addresses and the number of times - // they have been joined. - groups map[tcpip.Address]uint32 } } @@ -53,65 +48,33 @@ func (a *AddressableEndpointState) Init(networkEndpoint NetworkEndpoint) { a.mu.Lock() defer a.mu.Unlock() a.mu.endpoints = make(map[tcpip.Address]*addressState) - a.mu.groups = make(map[tcpip.Address]uint32) -} - -// ReadOnlyAddressableEndpointState provides read-only access to an -// AddressableEndpointState. -type ReadOnlyAddressableEndpointState struct { - inner *AddressableEndpointState } -// AddrOrMatching returns an endpoint for the passed address that is consisdered -// bound to the wrapped AddressableEndpointState. +// GetAddress returns the AddressEndpoint for the passed address. // -// If addr is an exact match with an existing address, that address is returned. -// Otherwise, f is called with each address and the address that f returns true -// for is returned. -// -// Returns nil of no address matches. -func (m ReadOnlyAddressableEndpointState) AddrOrMatching(addr tcpip.Address, spoofingOrPrimiscuous bool, f func(AddressEndpoint) bool) AddressEndpoint { - m.inner.mu.RLock() - defer m.inner.mu.RUnlock() - - if ep, ok := m.inner.mu.endpoints[addr]; ok { - if ep.IsAssigned(spoofingOrPrimiscuous) && ep.IncRef() { - return ep - } - } - - for _, ep := range m.inner.mu.endpoints { - if ep.IsAssigned(spoofingOrPrimiscuous) && f(ep) && ep.IncRef() { - return ep - } - } - - return nil -} - -// Lookup returns the AddressEndpoint for the passed address. +// GetAddress does not increment the address's reference count or check if the +// address is considered bound to the endpoint. // -// Returns nil if the passed address is not associated with the -// AddressableEndpointState. -func (m ReadOnlyAddressableEndpointState) Lookup(addr tcpip.Address) AddressEndpoint { - m.inner.mu.RLock() - defer m.inner.mu.RUnlock() +// Returns nil if the passed address is not associated with the endpoint. +func (a *AddressableEndpointState) GetAddress(addr tcpip.Address) AddressEndpoint { + a.mu.RLock() + defer a.mu.RUnlock() - ep, ok := m.inner.mu.endpoints[addr] + ep, ok := a.mu.endpoints[addr] if !ok { return nil } return ep } -// ForEach calls f for each address pair. +// ForEachEndpoint calls f for each address. // -// If f returns false, f is no longer be called. -func (m ReadOnlyAddressableEndpointState) ForEach(f func(AddressEndpoint) bool) { - m.inner.mu.RLock() - defer m.inner.mu.RUnlock() +// Once f returns false, f will no longer be called. +func (a *AddressableEndpointState) ForEachEndpoint(f func(AddressEndpoint) bool) { + a.mu.RLock() + defer a.mu.RUnlock() - for _, ep := range m.inner.mu.endpoints { + for _, ep := range a.mu.endpoints { if !f(ep) { return } @@ -120,18 +83,16 @@ func (m ReadOnlyAddressableEndpointState) ForEach(f func(AddressEndpoint) bool) // ForEachPrimaryEndpoint calls f for each primary address. // -// If f returns false, f is no longer be called. -func (m ReadOnlyAddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint)) { - m.inner.mu.RLock() - defer m.inner.mu.RUnlock() - for _, ep := range m.inner.mu.primary { - f(ep) - } -} +// Once f returns false, f will no longer be called. +func (a *AddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint) bool) { + a.mu.RLock() + defer a.mu.RUnlock() -// ReadOnly returns a readonly reference to a. -func (a *AddressableEndpointState) ReadOnly() ReadOnlyAddressableEndpointState { - return ReadOnlyAddressableEndpointState{inner: a} + for _, ep := range a.mu.primary { + if !f(ep) { + return + } + } } func (a *AddressableEndpointState) releaseAddressState(addrState *addressState) { @@ -335,11 +296,6 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error { a.mu.Lock() defer a.mu.Unlock() - - if _, ok := a.mu.groups[addr]; ok { - panic(fmt.Sprintf("group address = %s must be removed with LeaveGroup", addr)) - } - return a.removePermanentAddressLocked(addr) } @@ -471,8 +427,19 @@ func (a *AddressableEndpointState) acquirePrimaryAddressRLocked(isValid func(*ad return deprecatedEndpoint } -// AcquireAssignedAddress implements AddressableEndpoint. -func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint { +// AcquireAssignedAddressOrMatching returns an address endpoint that is +// considered assigned to the addressable endpoint. +// +// If the address is an exact match with an existing address, that address is +// returned. Otherwise, if f is provided, f is called with each address and +// the address that f returns true for is returned. +// +// If there is no matching address, a temporary address will be returned if +// allowTemp is true. +// +// Regardless how the address was obtained, it will be acquired before it is +// returned. +func (a *AddressableEndpointState) AcquireAssignedAddressOrMatching(localAddr tcpip.Address, f func(AddressEndpoint) bool, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint { a.mu.Lock() defer a.mu.Unlock() @@ -488,6 +455,14 @@ func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Addres return addrState } + if f != nil { + for _, addrState := range a.mu.endpoints { + if addrState.IsAssigned(allowTemp) && f(addrState) && addrState.IncRef() { + return addrState + } + } + } + if !allowTemp { return nil } @@ -520,6 +495,11 @@ func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Addres return ep } +// AcquireAssignedAddress implements AddressableEndpoint. +func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint { + return a.AcquireAssignedAddressOrMatching(localAddr, nil, allowTemp, tempPEB) +} + // AcquireOutgoingPrimaryAddress implements AddressableEndpoint. func (a *AddressableEndpointState) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint { a.mu.RLock() @@ -588,72 +568,11 @@ func (a *AddressableEndpointState) PermanentAddresses() []tcpip.AddressWithPrefi return addrs } -// JoinGroup implements GroupAddressableEndpoint. -func (a *AddressableEndpointState) JoinGroup(group tcpip.Address) (bool, *tcpip.Error) { - a.mu.Lock() - defer a.mu.Unlock() - - joins, ok := a.mu.groups[group] - if !ok { - ep, err := a.addAndAcquireAddressLocked(group.WithPrefix(), NeverPrimaryEndpoint, AddressConfigStatic, false /* deprecated */, true /* permanent */) - if err != nil { - return false, err - } - // We have no need for the address endpoint. - a.decAddressRefLocked(ep) - } - - a.mu.groups[group] = joins + 1 - return !ok, nil -} - -// LeaveGroup implements GroupAddressableEndpoint. -func (a *AddressableEndpointState) LeaveGroup(group tcpip.Address) (bool, *tcpip.Error) { - a.mu.Lock() - defer a.mu.Unlock() - - joins, ok := a.mu.groups[group] - if !ok { - return false, tcpip.ErrBadLocalAddress - } - - if joins == 1 { - a.removeGroupAddressLocked(group) - delete(a.mu.groups, group) - return true, nil - } - - a.mu.groups[group] = joins - 1 - return false, nil -} - -// IsInGroup implements GroupAddressableEndpoint. -func (a *AddressableEndpointState) IsInGroup(group tcpip.Address) bool { - a.mu.RLock() - defer a.mu.RUnlock() - _, ok := a.mu.groups[group] - return ok -} - -func (a *AddressableEndpointState) removeGroupAddressLocked(group tcpip.Address) { - if err := a.removePermanentAddressLocked(group); err != nil { - // removePermanentEndpointLocked would only return an error if group is - // not bound to the addressable endpoint, but we know it MUST be assigned - // since we have group in our map of groups. - panic(fmt.Sprintf("error removing group address = %s: %s", group, err)) - } -} - // Cleanup forcefully leaves all groups and removes all permanent addresses. func (a *AddressableEndpointState) Cleanup() { a.mu.Lock() defer a.mu.Unlock() - for group := range a.mu.groups { - a.removeGroupAddressLocked(group) - } - a.mu.groups = make(map[tcpip.Address]uint32) - for _, ep := range a.mu.endpoints { // removePermanentEndpointLocked returns tcpip.ErrBadLocalAddress if ep is // not a permanent address. diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go index 26787d0a3..140f146f6 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state_test.go +++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go @@ -53,25 +53,9 @@ func TestAddressableEndpointStateCleanup(t *testing.T) { ep.DecRef() } - group := tcpip.Address("\x02") - if added, err := s.JoinGroup(group); err != nil { - t.Fatalf("s.JoinGroup(%s): %s", group, err) - } else if !added { - t.Fatalf("got s.JoinGroup(%s) = false, want = true", group) - } - if !s.IsInGroup(group) { - t.Fatalf("got s.IsInGroup(%s) = false, want = true", group) - } - s.Cleanup() - { - ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint) - if ep != nil { - ep.DecRef() - t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix()) - } - } - if s.IsInGroup(group) { - t.Fatalf("got s.IsInGroup(%s) = true, want = false", group) + if ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint); ep != nil { + ep.DecRef() + t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix()) } } diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 9a17efcba..5e649cca6 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -142,19 +142,19 @@ func (cn *conn) timedOut(now time.Time) bool { // update the connection tracking state. // -// Precondition: ct.mu must be held. -func (ct *conn) updateLocked(tcpHeader header.TCP, hook Hook) { +// Precondition: cn.mu must be held. +func (cn *conn) updateLocked(tcpHeader header.TCP, hook Hook) { // Update the state of tcb. tcb assumes it's always initialized on the // client. However, we only need to know whether the connection is // established or not, so the client/server distinction isn't important. // TODO(gvisor.dev/issue/170): Add support in tcpconntrack to handle // other tcp states. - if ct.tcb.IsEmpty() { - ct.tcb.Init(tcpHeader) - } else if hook == ct.tcbHook { - ct.tcb.UpdateStateOutbound(tcpHeader) + if cn.tcb.IsEmpty() { + cn.tcb.Init(tcpHeader) + } else if hook == cn.tcbHook { + cn.tcb.UpdateStateOutbound(tcpHeader) } else { - ct.tcb.UpdateStateInbound(tcpHeader) + cn.tcb.UpdateStateInbound(tcpHeader) } } diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 7a501acdc..93e8e1c51 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -74,8 +74,30 @@ func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 { } func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) { - // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) + netHdr := pkt.NetworkHeader().View() + _, dst := f.proto.ParseAddresses(netHdr) + + addressEndpoint := f.AcquireAssignedAddress(dst, f.nic.Promiscuous(), CanBePrimaryEndpoint) + if addressEndpoint != nil { + addressEndpoint.DecRef() + // Dispatch the packet to the transport protocol. + f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(netHdr[protocolNumberOffset]), pkt) + return + } + + r, err := f.proto.stack.FindRoute(0, "", dst, fwdTestNetNumber, false /* multicastLoop */) + if err != nil { + return + } + defer r.Release() + + vv := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) + pkt = NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: vv.ToView().ToVectorisedView(), + }) + // TODO(b/143425874) Decrease the TTL field in forwarded packets. + _ = r.WriteHeaderIncludedPacket(pkt) } func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { @@ -106,8 +128,13 @@ func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuf panic("not implemented") } -func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error { - return tcpip.ErrNotSupported +func (f *fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error { + // The network header should not already be populated. + if _, ok := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen); !ok { + return tcpip.ErrMalformedHeader + } + + return f.nic.WritePacket(r, nil /* gso */, fwdTestNetNumber, pkt) } func (f *fwdTestNetworkEndpoint) Close() { @@ -117,6 +144,8 @@ func (f *fwdTestNetworkEndpoint) Close() { // fwdTestNetworkProtocol is a network-layer protocol that implements Address // resolution. type fwdTestNetworkProtocol struct { + stack *Stack + addrCache *linkAddrCache neigh *neighborCache addrResolveDelay time.Duration @@ -280,7 +309,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { p := fwdTestPacketInfo{ - RemoteLinkAddress: r.RemoteLinkAddress, + RemoteLinkAddress: r.RemoteLinkAddress(), LocalLinkAddress: r.LocalLinkAddress, Pkt: pkt, } @@ -304,20 +333,6 @@ func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuffer return n, nil } -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - p := fwdTestPacketInfo{ - Pkt: NewPacketBuffer(PacketBufferOptions{Data: vv}), - } - - select { - case e.C <- p: - default: - } - - return nil -} - // Wait implements stack.LinkEndpoint.Wait. func (*fwdTestLinkEndpoint) Wait() {} @@ -334,7 +349,10 @@ func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protoco func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborCache bool) (ep1, ep2 *fwdTestLinkEndpoint) { // Create a stack with the network protocol and two NICs. s := New(Options{ - NetworkProtocols: []NetworkProtocolFactory{func(*Stack) NetworkProtocol { return proto }}, + NetworkProtocols: []NetworkProtocolFactory{func(s *Stack) NetworkProtocol { + proto.stack = s + return proto + }}, UseNeighborCache: useNeighborCache, }) @@ -542,6 +560,38 @@ func TestForwardingWithNoResolver(t *testing.T) { } } +func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) { + proto := &fwdTestNetworkProtocol{ + addrResolveDelay: 50 * time.Millisecond, + onLinkAddressResolved: func(*linkAddrCache, *neighborCache, tcpip.Address, tcpip.LinkAddress) { + // Don't resolve the link address. + }, + } + + ep1, ep2 := fwdTestNetFactory(t, proto, true /* useNeighborCache */) + + const numPackets int = 5 + // These packets will all be enqueued in the packet queue to wait for link + // address resolution. + for i := 0; i < numPackets; i++ { + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + } + + // All packets should fail resolution. + // TODO(gvisor.dev/issue/5141): Use a fake clock. + for i := 0; i < numPackets; i++ { + select { + case got := <-ep2.C: + t.Fatalf("got %#v; packets should have failed resolution and not been forwarded", got) + case <-time.After(100 * time.Millisecond): + } + } +} + func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { tests := []struct { name string diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 2d8c883cd..09c7811fa 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -45,13 +45,13 @@ const reaperDelay = 5 * time.Second func DefaultTables() *IPTables { return &IPTables{ v4Tables: [NumTables]Table{ - NATID: Table{ + NATID: { Rules: []Rule{ - Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, - Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, - Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, - Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, - Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + {Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + {Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + {Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + {Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + {Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, }, BuiltinChains: [NumHooks]int{ Prerouting: 0, @@ -68,11 +68,11 @@ func DefaultTables() *IPTables { Postrouting: 3, }, }, - MangleID: Table{ + MangleID: { Rules: []Rule{ - Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, - Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, - Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + {Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + {Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + {Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, }, BuiltinChains: [NumHooks]int{ Prerouting: 0, @@ -86,12 +86,12 @@ func DefaultTables() *IPTables { Postrouting: HookUnset, }, }, - FilterID: Table{ + FilterID: { Rules: []Rule{ - Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, - Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, - Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, - Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + {Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + {Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + {Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + {Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, }, BuiltinChains: [NumHooks]int{ Prerouting: HookUnset, @@ -110,13 +110,13 @@ func DefaultTables() *IPTables { }, }, v6Tables: [NumTables]Table{ - NATID: Table{ + NATID: { Rules: []Rule{ - Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, - Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, - Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, - Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, - Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + {Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + {Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + {Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + {Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + {Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, }, BuiltinChains: [NumHooks]int{ Prerouting: 0, @@ -133,11 +133,11 @@ func DefaultTables() *IPTables { Postrouting: 3, }, }, - MangleID: Table{ + MangleID: { Rules: []Rule{ - Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, - Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, - Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + {Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + {Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + {Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, }, BuiltinChains: [NumHooks]int{ Prerouting: 0, @@ -151,12 +151,12 @@ func DefaultTables() *IPTables { Postrouting: HookUnset, }, }, - FilterID: Table{ + FilterID: { Rules: []Rule{ - Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, - Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, - Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, - Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + {Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + {Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + {Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + {Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, }, BuiltinChains: [NumHooks]int{ Prerouting: HookUnset, @@ -175,9 +175,9 @@ func DefaultTables() *IPTables { }, }, priorities: [NumHooks][]TableID{ - Prerouting: []TableID{MangleID, NATID}, - Input: []TableID{NATID, FilterID}, - Output: []TableID{MangleID, NATID, FilterID}, + Prerouting: {MangleID, NATID}, + Input: {NATID, FilterID}, + Output: {MangleID, NATID, FilterID}, }, connections: ConnTrack{ seed: generateRandUint32(), diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 4b86c1be9..56a3e7861 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -56,7 +56,7 @@ const ( // Postrouting happens just before a packet goes out on the wire. Postrouting - // The total number of hooks. + // NumHooks is the total number of hooks. NumHooks ) @@ -273,14 +273,12 @@ func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, nicName string) boo return true } - // If the interface name ends with '+', any interface which begins - // with the name should be matched. + // If the interface name ends with '+', any interface which + // begins with the name should be matched. ifName := fl.OutputInterface - matches := true + matches := nicName == ifName if strings.HasSuffix(ifName, "+") { matches = strings.HasPrefix(nicName, ifName[:n-1]) - } else { - matches = nicName == ifName } return fl.OutputInterfaceInvert != matches } diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index c9b13cd0e..792f4f170 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -18,7 +18,6 @@ import ( "fmt" "time" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -58,9 +57,6 @@ const ( incomplete entryState = iota // ready means that the address has been resolved and can be used. ready - // failed means that address resolution timed out and the address - // could not be resolved. - failed ) // String implements Stringer. @@ -70,8 +66,6 @@ func (s entryState) String() string { return "incomplete" case ready: return "ready" - case failed: - return "failed" default: return fmt.Sprintf("unknown(%d)", s) } @@ -80,40 +74,48 @@ func (s entryState) String() string { // A linkAddrEntry is an entry in the linkAddrCache. // This struct is thread-compatible. type linkAddrEntry struct { + // linkAddrEntryEntry access is synchronized by the linkAddrCache lock. linkAddrEntryEntry + // TODO(gvisor.dev/issue/5150): move these fields under mu. + // mu protects the fields below. + mu sync.RWMutex + addr tcpip.FullAddress linkAddr tcpip.LinkAddress expiration time.Time s entryState - // wakers is a set of waiters for address resolution result. Anytime - // state transitions out of incomplete these waiters are notified. - wakers map[*sleep.Waker]struct{} - - // done is used to allow callers to wait on address resolution. It is nil iff - // s is incomplete and resolution is not yet in progress. + // done is closed when address resolution is complete. It is nil iff s is + // incomplete and resolution is not yet in progress. done chan struct{} + + // onResolve is called with the result of address resolution. + onResolve []func(tcpip.LinkAddress, bool) } -// changeState sets the entry's state to ns, notifying any waiters. +func (e *linkAddrEntry) notifyCompletionLocked(linkAddr tcpip.LinkAddress) { + for _, callback := range e.onResolve { + callback(linkAddr, len(linkAddr) != 0) + } + e.onResolve = nil + if ch := e.done; ch != nil { + close(ch) + e.done = nil + } +} + +// changeStateLocked sets the entry's state to ns. // // The entry's expiration is bumped up to the greater of itself and the passed // expiration; the zero value indicates immediate expiration, and is set // unconditionally - this is an implementation detail that allows for entries // to be reused. -func (e *linkAddrEntry) changeState(ns entryState, expiration time.Time) { - // Notify whoever is waiting on address resolution when transitioning - // out of incomplete. - if e.s == incomplete && ns != incomplete { - for w := range e.wakers { - w.Assert() - } - e.wakers = nil - if ch := e.done; ch != nil { - close(ch) - } - e.done = nil +// +// Precondition: e.mu must be locked +func (e *linkAddrEntry) changeStateLocked(ns entryState, expiration time.Time) { + if e.s == incomplete && ns == ready { + e.notifyCompletionLocked(e.linkAddr) } if expiration.IsZero() || expiration.After(e.expiration) { @@ -122,10 +124,6 @@ func (e *linkAddrEntry) changeState(ns entryState, expiration time.Time) { e.s = ns } -func (e *linkAddrEntry) removeWaker(w *sleep.Waker) { - delete(e.wakers, w) -} - // add adds a k -> v mapping to the cache. func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) { // Calculate expiration time before acquiring the lock, since expiration is @@ -135,10 +133,12 @@ func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) { c.cache.Lock() entry := c.getOrCreateEntryLocked(k) - entry.linkAddr = v - - entry.changeState(ready, expiration) c.cache.Unlock() + + entry.mu.Lock() + defer entry.mu.Unlock() + entry.linkAddr = v + entry.changeStateLocked(ready, expiration) } // getOrCreateEntryLocked retrieves a cache entry associated with k. The @@ -159,13 +159,14 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEnt var entry *linkAddrEntry if len(c.cache.table) == linkAddrCacheSize { entry = c.cache.lru.Back() + entry.mu.Lock() delete(c.cache.table, entry.addr) c.cache.lru.Remove(entry) - // Wake waiters and mark the soon-to-be-reused entry as expired. Note - // that the state passed doesn't matter when the zero time is passed. - entry.changeState(failed, time.Time{}) + // Wake waiters and mark the soon-to-be-reused entry as expired. + entry.notifyCompletionLocked("" /* linkAddr */) + entry.mu.Unlock() } else { entry = new(linkAddrEntry) } @@ -180,9 +181,12 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEnt } // get reports any known link address for k. -func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { +func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { if linkRes != nil { if addr, ok := linkRes.ResolveStaticAddress(k.Addr); ok { + if onResolve != nil { + onResolve(addr, true) + } return addr, nil, nil } } @@ -190,56 +194,35 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo c.cache.Lock() defer c.cache.Unlock() entry := c.getOrCreateEntryLocked(k) + entry.mu.Lock() + defer entry.mu.Unlock() + switch s := entry.s; s { - case ready, failed: + case ready: if !time.Now().After(entry.expiration) { // Not expired. - switch s { - case ready: - return entry.linkAddr, nil, nil - case failed: - return entry.linkAddr, nil, tcpip.ErrNoLinkAddress - default: - panic(fmt.Sprintf("invalid cache entry state: %s", s)) + if onResolve != nil { + onResolve(entry.linkAddr, true) } + return entry.linkAddr, nil, nil } - entry.changeState(incomplete, time.Time{}) + entry.changeStateLocked(incomplete, time.Time{}) fallthrough case incomplete: - if waker != nil { - if entry.wakers == nil { - entry.wakers = make(map[*sleep.Waker]struct{}) - } - entry.wakers[waker] = struct{}{} + if onResolve != nil { + entry.onResolve = append(entry.onResolve, onResolve) } - if entry.done == nil { - // Address resolution needs to be initiated. - if linkRes == nil { - return entry.linkAddr, nil, tcpip.ErrNoLinkAddress - } - entry.done = make(chan struct{}) go c.startAddressResolution(k, linkRes, localAddr, nic, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. } - return entry.linkAddr, entry.done, tcpip.ErrWouldBlock default: panic(fmt.Sprintf("invalid cache entry state: %s", s)) } } -// removeWaker removes a waker previously added through get(). -func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) { - c.cache.Lock() - defer c.cache.Unlock() - - if entry, ok := c.cache.table[k]; ok { - entry.removeWaker(waker) - } -} - func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) { for i := 0; ; i++ { // Send link request, then wait for the timeout limit and check @@ -257,9 +240,9 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link } } -// checkLinkRequest checks whether previous attempt to resolve address has succeeded -// and mark the entry accordingly, e.g. ready, failed, etc. Return true if request -// can stop, false if another request should be sent. +// checkLinkRequest checks whether previous attempt to resolve address has +// succeeded and mark the entry accordingly. Returns true if request can stop, +// false if another request should be sent. func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, attempt int) bool { c.cache.Lock() defer c.cache.Unlock() @@ -268,16 +251,20 @@ func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, att // Entry was evicted from the cache. return true } + entry.mu.Lock() + defer entry.mu.Unlock() + switch s := entry.s; s { - case ready, failed: - // Entry was made ready by resolver or failed. Either way we're done. + case ready: + // Entry was made ready by resolver. case incomplete: if attempt+1 < c.resolutionAttempts { // No response yet, need to send another ARP request. return false } - // Max number of retries reached, mark entry as failed. - entry.changeState(failed, now.Add(c.ageLimit)) + // Max number of retries reached, delete entry. + entry.notifyCompletionLocked("" /* linkAddr */) + delete(c.cache.table, k) default: panic(fmt.Sprintf("invalid cache entry state: %s", s)) } diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index d2e37f38d..6883045b5 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -21,7 +21,6 @@ import ( "testing" "time" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -50,6 +49,7 @@ type testLinkAddressResolver struct { } func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { + // TODO(gvisor.dev/issue/5141): Use a fake clock. time.AfterFunc(r.delay, func() { r.fakeRequest(targetAddr) }) if f := r.onLinkAddressRequest; f != nil { f() @@ -78,16 +78,18 @@ func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumbe } func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressResolver) (tcpip.LinkAddress, *tcpip.Error) { - w := sleep.Waker{} - s := sleep.Sleeper{} - s.AddWaker(&w, 123) - defer s.Done() - + var attemptedResolution bool for { - if got, _, err := c.get(addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock { - return got, err + got, ch, err := c.get(addr, linkRes, "", nil, nil) + if err == tcpip.ErrWouldBlock { + if attemptedResolution { + return got, tcpip.ErrNoLinkAddress + } + attemptedResolution = true + <-ch + continue } - s.Fetch(true) + return got, err } } @@ -116,16 +118,19 @@ func TestCacheOverflow(t *testing.T) { } } // The earliest entries should no longer be in the cache. + c.cache.Lock() + defer c.cache.Unlock() for i := len(testAddrs) - 1; i >= len(testAddrs)-linkAddrCacheSize; i-- { e := testAddrs[i] - if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { - t.Errorf("check %d, c.get(%q), got error: %v, want: error ErrNoLinkAddress", i, string(e.addr.Addr), err) + if entry, ok := c.cache.table[e.addr]; ok { + t.Errorf("unexpected entry at c.cache.table[%q]: %#v", string(e.addr.Addr), entry) } } } func TestCacheConcurrent(t *testing.T) { c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) + linkRes := &testLinkAddressResolver{cache: c} var wg sync.WaitGroup for r := 0; r < 16; r++ { @@ -133,7 +138,6 @@ func TestCacheConcurrent(t *testing.T) { go func() { for _, e := range testAddrs { c.add(e.addr, e.linkAddr) - c.get(e.addr, nil, "", nil, nil) // make work for gotsan } wg.Done() }() @@ -144,7 +148,7 @@ func TestCacheConcurrent(t *testing.T) { // can fit in the cache, so our eviction strategy requires that // the last entry be present and the first be missing. e := testAddrs[len(testAddrs)-1] - got, _, err := c.get(e.addr, nil, "", nil, nil) + got, _, err := c.get(e.addr, linkRes, "", nil, nil) if err != nil { t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) } @@ -153,18 +157,22 @@ func TestCacheConcurrent(t *testing.T) { } e = testAddrs[0] - if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { - t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) + c.cache.Lock() + defer c.cache.Unlock() + if entry, ok := c.cache.table[e.addr]; ok { + t.Errorf("unexpected entry at c.cache.table[%q]: %#v", string(e.addr.Addr), entry) } } func TestCacheAgeLimit(t *testing.T) { c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3) + linkRes := &testLinkAddressResolver{cache: c} + e := testAddrs[0] c.add(e.addr, e.linkAddr) time.Sleep(50 * time.Millisecond) - if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { - t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) + if _, _, err := c.get(e.addr, linkRes, "", nil, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got c.get(%q) = %s, want = ErrWouldBlock", string(e.addr.Addr), err) } } @@ -282,71 +290,3 @@ func TestStaticResolution(t *testing.T) { t.Errorf("c.get(%q)=%q, want %q", string(addr), string(got), string(want)) } } - -// TestCacheWaker verifies that RemoveWaker removes a waker previously added -// through get(). -func TestCacheWaker(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) - - // First, sanity check that wakers are working. - { - linkRes := &testLinkAddressResolver{cache: c} - s := sleep.Sleeper{} - defer s.Done() - - const wakerID = 1 - w := sleep.Waker{} - s.AddWaker(&w, wakerID) - - e := testAddrs[0] - - if _, _, err := c.get(e.addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.get(%q, _, _, _, _) = %s, want = %s", e.addr.Addr, err, tcpip.ErrWouldBlock) - } - id, ok := s.Fetch(true /* block */) - if !ok { - t.Fatal("got s.Fetch(true) = (_, false), want = (_, true)") - } - if id != wakerID { - t.Fatalf("got s.Fetch(true) = (%d, %t), want = (%d, true)", id, ok, wakerID) - } - - if got, _, err := c.get(e.addr, linkRes, "", nil, nil); err != nil { - t.Fatalf("c.get(%q, _, _, _, _): %s", e.addr.Addr, err) - } else if got != e.linkAddr { - t.Fatalf("got c.get(%q) = %q, want = %q", e.addr.Addr, got, e.linkAddr) - } - } - - // Check that RemoveWaker works. - { - linkRes := &testLinkAddressResolver{cache: c} - s := sleep.Sleeper{} - defer s.Done() - - const wakerID = 2 // different than the ID used in the sanity check - w := sleep.Waker{} - s.AddWaker(&w, wakerID) - - e := testAddrs[1] - linkRes.onLinkAddressRequest = func() { - // Remove the waker before the linkAddrCache has the opportunity to send - // a notification. - c.removeWaker(e.addr, &w) - } - - if _, _, err := c.get(e.addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.get(%q, _, _, _, _) = %s, want = %s", e.addr.Addr, err, tcpip.ErrWouldBlock) - } - - if got, err := getBlocking(c, e.addr, linkRes); err != nil { - t.Fatalf("c.get(%q, _, _, _, _): %s", e.addr.Addr, err) - } else if got != e.linkAddr { - t.Fatalf("c.get(%q) = %q, want = %q", e.addr.Addr, got, e.linkAddr) - } - - if id, ok := s.Fetch(false /* block */); ok { - t.Fatalf("unexpected notification from waker with id %d", id) - } - } -} diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 73a01c2dd..61636cae5 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -26,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" @@ -352,7 +353,7 @@ func TestDADDisabled(t *testing.T) { } // We should not have sent any NDP NS messages. - if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != 0 { + if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got != 0 { t.Fatalf("got NeighborSolicit = %d, want = 0", got) } } @@ -465,14 +466,18 @@ func TestDADResolve(t *testing.T) { if err != tcpip.ErrNoRoute { t.Errorf("got FindRoute(%d, '', %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute) } - r.Release() + if r != nil { + r.Release() + } } { r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false) if err != tcpip.ErrNoRoute { t.Errorf("got FindRoute(%d, %s, %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr1, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute) } - r.Release() + if r != nil { + r.Release() + } } if t.Failed() { @@ -510,7 +515,9 @@ func TestDADResolve(t *testing.T) { } else if r.LocalAddress != addr1 { t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, addr1) } - r.Release() + if r != nil { + r.Release() + } } if t.Failed() { @@ -518,7 +525,7 @@ func TestDADResolve(t *testing.T) { } // Should not have sent any more NS messages. - if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != uint64(test.dupAddrDetectTransmits) { + if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got != uint64(test.dupAddrDetectTransmits) { t.Fatalf("got NeighborSolicit = %d, want = %d", got, test.dupAddrDetectTransmits) } @@ -563,18 +570,18 @@ func rxNDPSolicit(e *channel.Endpoint, tgt tcpip.Address) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize) pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns := header.NDPNeighborSolicit(pkt.MessageBody()) ns.SetTargetAddress(tgt) snmc := header.SolicitedNodeAddr(tgt) pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{})) payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: 255, - SrcAddr: header.IPv6Any, - DstAddr: snmc, + PayloadLength: uint16(payloadLength), + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: 255, + SrcAddr: header.IPv6Any, + DstAddr: snmc, }) e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()})) } @@ -605,7 +612,7 @@ func TestDADFail(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize) pkt := header.ICMPv6(hdr.Prepend(naSize)) pkt.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(pkt.NDPPayload()) + na := header.NDPNeighborAdvert(pkt.MessageBody()) na.SetSolicitedFlag(true) na.SetOverrideFlag(true) na.SetTargetAddress(tgt) @@ -616,11 +623,11 @@ func TestDADFail(t *testing.T) { payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: 255, - SrcAddr: tgt, - DstAddr: header.IPv6AllNodesMulticastAddress, + PayloadLength: uint16(payloadLength), + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: 255, + SrcAddr: tgt, + DstAddr: header.IPv6AllNodesMulticastAddress, }) e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()})) }, @@ -666,7 +673,7 @@ func TestDADFail(t *testing.T) { // Receive a packet to simulate an address conflict. test.rxPkt(e, addr1) - stat := test.getStat(s.Stats().ICMP.V6PacketsReceived) + stat := test.getStat(s.Stats().ICMP.V6.PacketsReceived) if got := stat.Value(); got != 1 { t.Fatalf("got stat = %d, want = 1", got) } @@ -803,7 +810,7 @@ func TestDADStop(t *testing.T) { } // Should not have sent more than 1 NS message. - if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 { + if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got > 1 { t.Errorf("got NeighborSolicit = %d, want <= 1", got) } }) @@ -982,7 +989,7 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo pkt := header.ICMPv6(hdr.Prepend(icmpSize)) pkt.SetType(header.ICMPv6RouterAdvert) pkt.SetCode(0) - raPayload := pkt.NDPPayload() + raPayload := pkt.MessageBody() ra := header.NDPRouterAdvert(raPayload) // Populate the Router Lifetime. binary.BigEndian.PutUint16(raPayload[2:], rl) @@ -1004,11 +1011,11 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo payloadLength := hdr.UsedLength() iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) iph.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: header.NDPHopLimit, - SrcAddr: ip, - DstAddr: header.IPv6AllNodesMulticastAddress, + PayloadLength: uint16(payloadLength), + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: header.NDPHopLimit, + SrcAddr: ip, + DstAddr: header.IPv6AllNodesMulticastAddress, }) return stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -2162,8 +2169,8 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { NDPConfigs: ipv6.NDPConfigurations{ AutoGenTempGlobalAddresses: true, }, - NDPDisp: &ndpDisp, - AutoGenIPv6LinkLocal: true, + NDPDisp: &ndpDisp, + AutoGenLinkLocal: true, })}, }) @@ -2843,9 +2850,7 @@ func addrForNewConnectionTo(t *testing.T, s *stack.Stack, addr tcpip.FullAddress t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) } defer ep.Close() - if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) - } + ep.SocketOptions().SetV6Only(true) if err := ep.Connect(addr); err != nil { t.Fatalf("ep.Connect(%+v): %s", addr, err) } @@ -2879,9 +2884,7 @@ func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullA t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) } defer ep.Close() - if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) - } + ep.SocketOptions().SetV6Only(true) if err := ep.Bind(addr); err != nil { t.Fatalf("ep.Bind(%+v): %s", addr, err) } @@ -3250,9 +3253,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) } defer ep.Close() - if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) - } + ep.SocketOptions().SetV6Only(true) if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute { t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, tcpip.ErrNoRoute) @@ -4044,9 +4045,9 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { ndpConfigs.AutoGenAddressConflictRetries = maxRetries s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, + AutoGenLinkLocal: addrType.autoGenLinkLocal, + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: func(_ tcpip.NICID, nicName string) string { return nicName @@ -4179,9 +4180,9 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, - NDPConfigs: addrType.ndpConfigs, - NDPDisp: &ndpDisp, + AutoGenLinkLocal: addrType.autoGenLinkLocal, + NDPConfigs: addrType.ndpConfigs, + NDPDisp: &ndpDisp, })}, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -4708,7 +4709,7 @@ func TestCleanupNDPState(t *testing.T) { } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenIPv6LinkLocal: true, + AutoGenLinkLocal: true, NDPConfigs: ipv6.NDPConfigurations{ HandleRAs: true, DiscoverDefaultRouters: true, @@ -5174,113 +5175,99 @@ func TestRouterSolicitation(t *testing.T) { }, } - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the - // parallel tests complete. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. - t.Run("group", func(t *testing.T) { - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + e := channelLinkWithHeaderLength{ + Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr), + headerLength: test.linkHeaderLen, + } + e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired + waitForPkt := func(timeout time.Duration) { + t.Helper() - e := channelLinkWithHeaderLength{ - Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr), - headerLength: test.linkHeaderLen, + clock.Advance(timeout) + p, ok := e.Read() + if !ok { + t.Fatal("expected router solicitation packet") } - e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired - waitForPkt := func(timeout time.Duration) { - t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - p, ok := e.ReadContext(ctx) - if !ok { - t.Fatal("timed out waiting for packet") - return - } - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } + if p.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } - // Make sure the right remote link address is used. - if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want { - t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) - } + // Make sure the right remote link address is used. + if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want { + t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) + } - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(test.expectedSrcAddr), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), - checker.TTL(header.NDPHopLimit), - checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), - ) + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(test.expectedSrcAddr), + checker.DstAddr(header.IPv6AllRoutersMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), + ) - if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { - t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) - } - } - waitForNothing := func(timeout time.Duration) { - t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - if _, ok := e.ReadContext(ctx); ok { - t.Fatal("unexpectedly got a packet") - } - } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - MaxRtrSolicitations: test.maxRtrSolicit, - RtrSolicitationInterval: test.rtrSolicitInt, - MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, - }, - })}, - }) - if err := s.CreateNIC(nicID, &e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { + t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) } + } + waitForNothing := func(timeout time.Duration) { + t.Helper() - if addr := test.nicAddr; addr != "" { - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err) - } + clock.Advance(timeout) + if p, ok := e.Read(); ok { + t.Fatalf("unexpectedly got a packet = %#v", p) } + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + MaxRtrSolicitations: test.maxRtrSolicit, + RtrSolicitationInterval: test.rtrSolicitInt, + MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, + }, + })}, + Clock: clock, + }) + if err := s.CreateNIC(nicID, &e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } - // Make sure each RS is sent at the right time. - remaining := test.maxRtrSolicit - if remaining > 0 { - waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncPositiveEventTimeout) - remaining-- + if addr := test.nicAddr; addr != "" { + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err) } + } - for ; remaining > 0; remaining-- { - if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout { - waitForNothing(test.effectiveRtrSolicitInt - defaultAsyncNegativeEventTimeout) - waitForPkt(defaultAsyncPositiveEventTimeout) - } else { - waitForPkt(test.effectiveRtrSolicitInt + defaultAsyncPositiveEventTimeout) - } - } + // Make sure each RS is sent at the right time. + remaining := test.maxRtrSolicit + if remaining > 0 { + waitForPkt(test.effectiveMaxRtrSolicitDelay) + remaining-- + } - // Make sure no more RS. - if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay { - waitForNothing(test.effectiveRtrSolicitInt + defaultAsyncNegativeEventTimeout) + for ; remaining > 0; remaining-- { + if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout { + waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond) + waitForPkt(time.Nanosecond) } else { - waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultAsyncNegativeEventTimeout) + waitForPkt(test.effectiveRtrSolicitInt) } + } - // Make sure the counter got properly - // incremented. - if got, want := s.Stats().ICMP.V6PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want { - t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want) - } - }) - } - }) + // Make sure no more RS. + if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay { + waitForNothing(test.effectiveRtrSolicitInt) + } else { + waitForNothing(test.effectiveMaxRtrSolicitDelay) + } + + if got, want := s.Stats().ICMP.V6.PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want { + t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want) + } + }) + } } func TestStopStartSolicitingRouters(t *testing.T) { diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go index 177bf5516..c15f10e76 100644 --- a/pkg/tcpip/stack/neighbor_cache.go +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -17,16 +17,22 @@ package stack import ( "fmt" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) const neighborCacheSize = 512 // max entries per interface +// NeighborStats holds metrics for the neighbor table. +type NeighborStats struct { + // FailedEntryLookups counts the number of lookups performed on an entry in + // Failed state. + FailedEntryLookups *tcpip.StatCounter +} + // neighborCache maps IP addresses to link addresses. It uses the Least // Recently Used (LRU) eviction strategy to implement a bounded cache for -// dynmically acquired entries. It contains the state machine and configuration +// dynamically acquired entries. It contains the state machine and configuration // for running Neighbor Unreachability Detection (NUD). // // There are two types of entries in the neighbor cache: @@ -92,9 +98,7 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA n.dynamic.lru.Remove(e) n.dynamic.count-- - e.dispatchRemoveEventLocked() - e.setStateLocked(Unknown) - e.notifyWakersLocked() + e.removeLocked() e.mu.Unlock() } n.cache[remoteAddr] = entry @@ -103,21 +107,27 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA return entry } -// entry looks up the neighbor cache for translating address to link address -// (e.g. IP -> MAC). If the LinkEndpoint requests address resolution and there -// is a LinkAddressResolver registered with the network protocol, the cache -// attempts to resolve the address and returns ErrWouldBlock. If a Waker is -// provided, it will be notified when address resolution is complete (success -// or not). +// entry looks up neighbor information matching the remote address, and returns +// it if readily available. +// +// Returns ErrWouldBlock if the link address is not readily available, along +// with a notification channel for the caller to block on. Triggers address +// resolution asynchronously. +// +// If onResolve is provided, it will be called either immediately, if resolution +// is not required, or when address resolution is complete, with the resolved +// link address and whether resolution succeeded. After any callbacks have been +// called, the returned notification channel is closed. +// +// NB: if a callback is provided, it should not call into the neighbor cache. // // If specified, the local address must be an address local to the interface the // neighbor cache belongs to. The local address is the source address of a // packet prompting NUD/link address resolution. // -// If address resolution is required, ErrNoLinkAddress and a notification -// channel is returned for the top level caller to block. Channel is closed -// once address resolution is complete (success or not). -func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, w *sleep.Waker) (NeighborEntry, <-chan struct{}, *tcpip.Error) { +// TODO(gvisor.dev/issue/5151): Don't return the neighbor entry. +func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(tcpip.LinkAddress, bool)) (NeighborEntry, <-chan struct{}, *tcpip.Error) { + // TODO(gvisor.dev/issue/5149): Handle static resolution in route.Resolve. if linkAddr, ok := linkRes.ResolveStaticAddress(remoteAddr); ok { e := NeighborEntry{ Addr: remoteAddr, @@ -125,6 +135,9 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA State: Static, UpdatedAtNanos: 0, } + if onResolve != nil { + onResolve(linkAddr, true) + } return e, nil, nil } @@ -142,47 +155,36 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA // of packets to a neighbor. While reasserting a neighbor's reachability, // a node continues sending packets to that neighbor using the cached // link-layer address." + if onResolve != nil { + onResolve(entry.neigh.LinkAddr, true) + } return entry.neigh, nil, nil - case Unknown, Incomplete: - entry.addWakerLocked(w) - + case Unknown, Incomplete, Failed: + if onResolve != nil { + entry.onResolve = append(entry.onResolve, onResolve) + } if entry.done == nil { // Address resolution needs to be initiated. - if linkRes == nil { - return entry.neigh, nil, tcpip.ErrNoLinkAddress - } entry.done = make(chan struct{}) } - entry.handlePacketQueuedLocked(localAddr) return entry.neigh, entry.done, tcpip.ErrWouldBlock - case Failed: - return entry.neigh, nil, tcpip.ErrNoLinkAddress default: panic(fmt.Sprintf("Invalid cache entry state: %s", s)) } } -// removeWaker removes a waker that has been added when link resolution for -// addr was requested. -func (n *neighborCache) removeWaker(addr tcpip.Address, waker *sleep.Waker) { - n.mu.Lock() - if entry, ok := n.cache[addr]; ok { - delete(entry.wakers, waker) - } - n.mu.Unlock() -} - // entries returns all entries in the neighbor cache. func (n *neighborCache) entries() []NeighborEntry { - entries := make([]NeighborEntry, 0, len(n.cache)) n.mu.RLock() + defer n.mu.RUnlock() + + entries := make([]NeighborEntry, 0, len(n.cache)) for _, entry := range n.cache { entry.mu.RLock() entries = append(entries, entry.neigh) entry.mu.RUnlock() } - n.mu.RUnlock() return entries } @@ -214,32 +216,13 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd return } - // Notify that resolution has been interrupted, just in case the entry was - // in the Incomplete or Probe state. - entry.dispatchRemoveEventLocked() - entry.setStateLocked(Unknown) - entry.notifyWakersLocked() + entry.removeLocked() entry.mu.Unlock() } n.cache[addr] = newStaticNeighborEntry(n.nic, addr, linkAddr, n.state) } -// removeEntryLocked removes the specified entry from the neighbor cache. -func (n *neighborCache) removeEntryLocked(entry *neighborEntry) { - if entry.neigh.State != Static { - n.dynamic.lru.Remove(entry) - n.dynamic.count-- - } - if entry.neigh.State != Failed { - entry.dispatchRemoveEventLocked() - } - entry.setStateLocked(Unknown) - entry.notifyWakersLocked() - - delete(n.cache, entry.neigh.Addr) -} - // removeEntry removes a dynamic or static entry by address from the neighbor // cache. Returns true if the entry was found and deleted. func (n *neighborCache) removeEntry(addr tcpip.Address) bool { @@ -254,7 +237,13 @@ func (n *neighborCache) removeEntry(addr tcpip.Address) bool { entry.mu.Lock() defer entry.mu.Unlock() - n.removeEntryLocked(entry) + if entry.neigh.State != Static { + n.dynamic.lru.Remove(entry) + n.dynamic.count-- + } + + entry.removeLocked() + delete(n.cache, entry.neigh.Addr) return true } @@ -265,9 +254,7 @@ func (n *neighborCache) clear() { for _, entry := range n.cache { entry.mu.Lock() - entry.dispatchRemoveEventLocked() - entry.setStateLocked(Unknown) - entry.notifyWakersLocked() + entry.removeLocked() entry.mu.Unlock() } diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index ed33418f3..a2ed6ae2a 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -28,7 +28,6 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/faketime" ) @@ -80,17 +79,20 @@ func entryDiffOptsWithSort() []cmp.Option { func newTestNeighborCache(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *neighborCache { config.resetInvalidFields() rng := rand.New(rand.NewSource(time.Now().UnixNano())) - return &neighborCache{ + neigh := &neighborCache{ nic: &NIC{ stack: &Stack{ clock: clock, nudDisp: nudDisp, }, - id: 1, + id: 1, + stats: makeNICStats(), }, state: NewNUDState(config, rng), cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), } + neigh.nic.neigh = neigh + return neigh } // testEntryStore contains a set of IP to NeighborEntry mappings. @@ -187,15 +189,18 @@ type testNeighborResolver struct { entries *testEntryStore delay time.Duration onLinkAddressRequest func() + dropReplies bool } var _ LinkAddressResolver = (*testNeighborResolver)(nil) func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { - // Delay handling the request to emulate network latency. - r.clock.AfterFunc(r.delay, func() { - r.fakeRequest(targetAddr) - }) + if !r.dropReplies { + // Delay handling the request to emulate network latency. + r.clock.AfterFunc(r.delay, func() { + r.fakeRequest(targetAddr) + }) + } // Execute post address resolution action, if available. if f := r.onLinkAddressRequest; f != nil { @@ -288,10 +293,10 @@ func TestNeighborCacheEntry(t *testing.T) { entry, ok := store.entry(0) if !ok { - t.Fatalf("store.entry(0) not found") + t.Fatal("store.entry(0) not found") } if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) @@ -324,7 +329,7 @@ func TestNeighborCacheEntry(t *testing.T) { } if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) + t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) } // No more events should have been dispatched. @@ -351,11 +356,11 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { entry, ok := store.entry(0) if !ok { - t.Fatalf("store.entry(0) not found") + t.Fatal("store.entry(0) not found") } if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) @@ -410,7 +415,7 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { } if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } } @@ -458,7 +463,7 @@ func (c *testContext) overflowCache(opts overflowOptions) error { return fmt.Errorf("c.store.entry(%d) not found", i) } if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { - return fmt.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + return fmt.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(c.neigh.config().RetransmitTimer) @@ -510,7 +515,7 @@ func (c *testContext) overflowCache(opts overflowOptions) error { } // Expect to find only the most recent entries. The order of entries reported - // by entries() is undeterministic, so entries have to be sorted before + // by entries() is nondeterministic, so entries have to be sorted before // comparison. wantUnsortedEntries := opts.wantStaticEntries for i := c.store.size() - neighborCacheSize; i < c.store.size(); i++ { @@ -572,10 +577,10 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { // Add a dynamic entry entry, ok := c.store.entry(0) if !ok { - t.Fatalf("c.store.entry(0) not found") + t.Fatal("c.store.entry(0) not found") } if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(c.neigh.config().RetransmitTimer) wantEvents := []testEntryEventInfo{ @@ -647,7 +652,7 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { // Add a static entry entry, ok := c.store.entry(0) if !ok { - t.Fatalf("c.store.entry(0) not found") + t.Fatal("c.store.entry(0) not found") } staticLinkAddr := entry.LinkAddr + "static" c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) @@ -691,7 +696,7 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) // Add a static entry entry, ok := c.store.entry(0) if !ok { - t.Fatalf("c.store.entry(0) not found") + t.Fatal("c.store.entry(0) not found") } staticLinkAddr := entry.LinkAddr + "static" c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) @@ -753,7 +758,7 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { // Add a static entry entry, ok := c.store.entry(0) if !ok { - t.Fatalf("c.store.entry(0) not found") + t.Fatal("c.store.entry(0) not found") } staticLinkAddr := entry.LinkAddr + "static" c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) @@ -823,10 +828,10 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { // Add a dynamic entry entry, ok := c.store.entry(0) if !ok { - t.Fatalf("c.store.entry(0) not found") + t.Fatal("c.store.entry(0) not found") } if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ @@ -904,150 +909,6 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { } } -func TestNeighborCacheNotifiesWaker(t *testing.T) { - config := DefaultNUDConfigurations() - - nudDisp := testNUDDispatcher{} - clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } - - w := sleep.Waker{} - s := sleep.Sleeper{} - const wakerID = 1 - s.AddWaker(&w, wakerID) - - entry, ok := store.entry(0) - if !ok { - t.Fatalf("store.entry(0) not found") - } - _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, &w) - if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, _ = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) - } - if doneCh == nil { - t.Fatalf("expected done channel from neigh.entry(%s, '', _, _)", entry.Addr) - } - clock.Advance(typicalLatency) - - select { - case <-doneCh: - default: - t.Fatal("expected notification from done channel") - } - - id, ok := s.Fetch(false /* block */) - if !ok { - t.Errorf("expected waker to be notified after neigh.entry(%s, '', _, _)", entry.Addr) - } - if id != wakerID { - t.Errorf("got s.Fetch(false) = %d, want = %d", id, wakerID) - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }, - } - nudDisp.mu.Lock() - defer nudDisp.mu.Unlock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) - } -} - -func TestNeighborCacheRemoveWaker(t *testing.T) { - config := DefaultNUDConfigurations() - - nudDisp := testNUDDispatcher{} - clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } - - w := sleep.Waker{} - s := sleep.Sleeper{} - const wakerID = 1 - s.AddWaker(&w, wakerID) - - entry, ok := store.entry(0) - if !ok { - t.Fatalf("store.entry(0) not found") - } - _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, &w) - if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, _) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) - } - if doneCh == nil { - t.Fatalf("expected done channel from neigh.entry(%s, '', _, _)", entry.Addr) - } - - // Remove the waker before the neighbor cache has the opportunity to send a - // notification. - neigh.removeWaker(entry.Addr, &w) - clock.Advance(typicalLatency) - - select { - case <-doneCh: - default: - t.Fatal("expected notification from done channel") - } - - if id, ok := s.Fetch(false /* block */); ok { - t.Errorf("unexpected notification from waker with id %d", id) - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }, - } - nudDisp.mu.Lock() - defer nudDisp.mu.Unlock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) - } -} - func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { config := DefaultNUDConfigurations() // Stay in Reachable so the cache can overflow @@ -1059,12 +920,12 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { entry, ok := c.store.entry(0) if !ok { - t.Fatalf("c.store.entry(0) not found") + t.Fatal("c.store.entry(0) not found") } c.neigh.addStaticEntry(entry.Addr, entry.LinkAddr) e, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) if err != nil { - t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil): %s", entry.Addr, err) + t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil, nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -1072,7 +933,7 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { State: Static, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("c.neigh.entry(%s, \"\", _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + t.Errorf("c.neigh.entry(%s, \"\", _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } wantEvents := []testEntryEventInfo{ @@ -1126,10 +987,10 @@ func TestNeighborCacheClear(t *testing.T) { // Add a dynamic entry. entry, ok := store.entry(0) if !ok { - t.Fatalf("store.entry(0) not found") + t.Fatal("store.entry(0) not found") } if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) @@ -1184,7 +1045,7 @@ func TestNeighborCacheClear(t *testing.T) { } } - // Clear shoud remove both dynamic and static entries. + // Clear should remove both dynamic and static entries. neigh.clear() // Remove events dispatched from clear() have no deterministic order so they @@ -1231,10 +1092,10 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { // Add a dynamic entry entry, ok := c.store.entry(0) if !ok { - t.Fatalf("c.store.entry(0) not found") + t.Fatal("c.store.entry(0) not found") } if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ @@ -1315,7 +1176,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { frequentlyUsedEntry, ok := store.entry(0) if !ok { - t.Fatalf("store.entry(0) not found") + t.Fatal("store.entry(0) not found") } // The following logic is very similar to overflowCache, but @@ -1327,15 +1188,22 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if !ok { t.Fatalf("store.entry(%d) not found", i) } - _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if !ok { + t.Fatal("expected successful address resolution") + } + if linkAddr != entry.LinkAddr { + t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + } + }) if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) select { - case <-doneCh: + case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) } wantEvents := []testEntryEventInfo{ { @@ -1370,7 +1238,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { // Periodically refresh the frequently used entry if i%(neighborCacheSize/2) == 0 { if _, _, err := neigh.entry(frequentlyUsedEntry.Addr, "", linkRes, nil); err != nil { - t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", frequentlyUsedEntry.Addr, err) + t.Errorf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", frequentlyUsedEntry.Addr, err) } } @@ -1378,15 +1246,23 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if !ok { t.Fatalf("store.entry(%d) not found", i) } - _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) + + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if !ok { + t.Fatal("expected successful address resolution") + } + if linkAddr != entry.LinkAddr { + t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + } + }) if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) select { - case <-doneCh: + case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) } // An entry should have been removed, as per the LRU eviction strategy @@ -1432,7 +1308,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { } // Expect to find only the frequently used entry and the most recent entries. - // The order of entries reported by entries() is undeterministic, so entries + // The order of entries reported by entries() is nondeterministic, so entries // have to be sorted before comparison. wantUnsortedEntries := []NeighborEntry{ { @@ -1491,12 +1367,12 @@ func TestNeighborCacheConcurrent(t *testing.T) { go func(entry NeighborEntry) { defer wg.Done() if e, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil && err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, tcpip.ErrWouldBlock) } }(entry) } - // Wait for all gorountines to send a request + // Wait for all goroutines to send a request wg.Wait() // Process all the requests for a single entry concurrently @@ -1506,7 +1382,7 @@ func TestNeighborCacheConcurrent(t *testing.T) { // All goroutines add in the same order and add more values than can fit in // the cache. Our eviction strategy requires that the last entries are // present, up to the size of the neighbor cache, and the rest are missing. - // The order of entries reported by entries() is undeterministic, so entries + // The order of entries reported by entries() is nondeterministic, so entries // have to be sorted before comparison. var wantUnsortedEntries []NeighborEntry for i := store.size() - neighborCacheSize; i < store.size(); i++ { @@ -1544,27 +1420,32 @@ func TestNeighborCacheReplace(t *testing.T) { // Add an entry entry, ok := store.entry(0) if !ok { - t.Fatalf("store.entry(0) not found") + t.Fatal("store.entry(0) not found") } - _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) + + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if !ok { + t.Fatal("expected successful address resolution") + } + if linkAddr != entry.LinkAddr { + t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + } + }) if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) select { - case <-doneCh: + case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) } // Verify the entry exists { - e, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) + e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) - } - if doneCh != nil { - t.Errorf("unexpected done channel from neigh.entry(%s, '', _, nil): %v", entry.Addr, doneCh) + t.Errorf("unexpected error from neigh.entry(%s, '', _, _, nil): %s", entry.Addr, err) } if t.Failed() { t.FailNow() @@ -1575,7 +1456,7 @@ func TestNeighborCacheReplace(t *testing.T) { State: Reachable, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + t.Errorf("neigh.entry(%s, '', _, _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } } @@ -1584,7 +1465,7 @@ func TestNeighborCacheReplace(t *testing.T) { { entry, ok := store.entry(1) if !ok { - t.Fatalf("store.entry(1) not found") + t.Fatal("store.entry(1) not found") } updatedLinkAddr = entry.LinkAddr } @@ -1601,7 +1482,7 @@ func TestNeighborCacheReplace(t *testing.T) { { e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != nil { - t.Fatalf("neigh.entry(%s, '', _, nil): %s", entry.Addr, err) + t.Fatalf("neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -1609,7 +1490,7 @@ func TestNeighborCacheReplace(t *testing.T) { State: Delay, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } clock.Advance(config.DelayFirstProbeTime + typicalLatency) } @@ -1619,7 +1500,7 @@ func TestNeighborCacheReplace(t *testing.T) { e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) clock.Advance(typicalLatency) if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) + t.Errorf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -1627,7 +1508,7 @@ func TestNeighborCacheReplace(t *testing.T) { State: Reachable, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } } } @@ -1651,18 +1532,35 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { }, } - // First, sanity check that resolution is working entry, ok := store.entry(0) if !ok { - t.Fatalf("store.entry(0) not found") + t.Fatal("store.entry(0) not found") } - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + + // First, sanity check that resolution is working + { + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if !ok { + t.Fatal("expected successful address resolution") + } + if linkAddr != entry.LinkAddr { + t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + } + }) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + } + clock.Advance(typicalLatency) + select { + case <-ch: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + } } - clock.Advance(typicalLatency) + got, _, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) + t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -1670,20 +1568,35 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { State: Reachable, } if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } - // Verify that address resolution for an unknown address returns ErrNoLinkAddress + // Verify address resolution fails for an unknown address. before := atomic.LoadUint32(&requestCount) entry.Addr += "2" - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) - } - waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes) - clock.Advance(waitFor) - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrNoLinkAddress) + { + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if ok { + t.Error("expected unsuccessful address resolution") + } + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr) + } + if t.Failed() { + t.FailNow() + } + }) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + } + waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes) + clock.Advance(waitFor) + select { + case <-ch: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + } } maxAttempts := neigh.config().MaxUnicastProbes @@ -1711,15 +1624,129 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { entry, ok := store.entry(0) if !ok { - t.Fatalf("store.entry(0) not found") + t.Fatal("store.entry(0) not found") } - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if ok { + t.Error("expected unsuccessful address resolution") + } + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr) + } + if t.Failed() { + t.FailNow() + } + }) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrNoLinkAddress) + + select { + case <-ch: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + } +} + +// TestNeighborCacheRetryResolution simulates retrying communication after +// failing to perform address resolution. +func TestNeighborCacheRetryResolution(t *testing.T) { + config := DefaultNUDConfigurations() + clock := faketime.NewManualClock() + neigh := newTestNeighborCache(nil, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + // Simulate a faulty link. + dropReplies: true, + } + + entry, ok := store.entry(0) + if !ok { + t.Fatal("store.entry(0) not found") + } + + // Perform address resolution with a faulty link, which will fail. + { + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if ok { + t.Error("expected unsuccessful address resolution") + } + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr) + } + if t.Failed() { + t.FailNow() + } + }) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + } + waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) + clock.Advance(waitFor) + + select { + case <-ch: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + } + } + + // Verify the entry is in Failed state. + wantEntries := []NeighborEntry{ + { + Addr: entry.Addr, + LinkAddr: "", + State: Failed, + }, + } + if diff := cmp.Diff(neigh.entries(), wantEntries, entryDiffOptsWithSort()...); diff != "" { + t.Fatalf("neighbor entries mismatch (-got, +want):\n%s", diff) + } + + // Retry address resolution with a working link. + linkRes.dropReplies = false + { + incompleteEntry, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if linkAddr != entry.LinkAddr { + t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + } + }) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + } + if incompleteEntry.State != Incomplete { + t.Fatalf("got entry.State = %s, want = %s", incompleteEntry.State, Incomplete) + } + clock.Advance(typicalLatency) + + select { + case <-ch: + if !ok { + t.Fatal("expected successful address resolution") + } + reachableEntry, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + if err != nil { + t.Fatalf("neigh.entry(%s, '', _, _, nil): %v", entry.Addr, err) + } + if reachableEntry.Addr != entry.Addr { + t.Fatalf("got entry.Addr = %s, want = %s", reachableEntry.Addr, entry.Addr) + } + if reachableEntry.LinkAddr != entry.LinkAddr { + t.Fatalf("got entry.LinkAddr = %s, want = %s", reachableEntry.LinkAddr, entry.LinkAddr) + } + if reachableEntry.State != Reachable { + t.Fatalf("got entry.State = %s, want = %s", reachableEntry.State.String(), Reachable.String()) + } + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + } } } @@ -1739,7 +1766,7 @@ func TestNeighborCacheStaticResolution(t *testing.T) { got, _, err := neigh.entry(testEntryBroadcastAddr, "", linkRes, nil) if err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", testEntryBroadcastAddr, err) + t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", testEntryBroadcastAddr, err) } want := NeighborEntry{ Addr: testEntryBroadcastAddr, @@ -1747,7 +1774,7 @@ func TestNeighborCacheStaticResolution(t *testing.T) { State: Static, } if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, diff) + t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, diff) } } @@ -1772,12 +1799,23 @@ func BenchmarkCacheClear(b *testing.B) { if !ok { b.Fatalf("store.entry(%d) not found", i) } - _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) + + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if !ok { + b.Fatal("expected successful address resolution") + } + if linkAddr != entry.LinkAddr { + b.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + } + }) if err != tcpip.ErrWouldBlock { - b.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + b.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } - if doneCh != nil { - <-doneCh + + select { + case <-ch: + default: + b.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) } } diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index 493e48031..75afb3001 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -19,7 +19,6 @@ import ( "sync" "time" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -67,8 +66,7 @@ const ( // Static describes entries that have been explicitly added by the user. They // do not expire and are not deleted until explicitly removed. Static - // Failed means traffic should not be sent to this neighbor since attempts of - // reachability have returned inconclusive. + // Failed means recent attempts of reachability have returned inconclusive. Failed ) @@ -93,16 +91,13 @@ type neighborEntry struct { neigh NeighborEntry - // wakers is a set of waiters for address resolution result. Anytime state - // transitions out of incomplete these waiters are notified. It is nil iff - // address resolution is ongoing and no clients are waiting for the result. - wakers map[*sleep.Waker]struct{} - - // done is used to allow callers to wait on address resolution. It is nil - // iff nudState is not Reachable and address resolution is not yet in - // progress. + // done is closed when address resolution is complete. It is nil iff s is + // incomplete and resolution is not yet in progress. done chan struct{} + // onResolve is called with the result of address resolution. + onResolve []func(tcpip.LinkAddress, bool) + isRouter bool job *tcpip.Job } @@ -143,25 +138,15 @@ func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAdd } } -// addWaker adds w to the list of wakers waiting for address resolution. -// Assumes the entry has already been appropriately locked. -func (e *neighborEntry) addWakerLocked(w *sleep.Waker) { - if w == nil { - return - } - if e.wakers == nil { - e.wakers = make(map[*sleep.Waker]struct{}) - } - e.wakers[w] = struct{}{} -} - -// notifyWakersLocked notifies those waiting for address resolution, whether it -// succeeded or failed. Assumes the entry has already been appropriately locked. -func (e *neighborEntry) notifyWakersLocked() { - for w := range e.wakers { - w.Assert() +// notifyCompletionLocked notifies those waiting for address resolution, with +// the link address if resolution completed successfully. +// +// Precondition: e.mu MUST be locked. +func (e *neighborEntry) notifyCompletionLocked(succeeded bool) { + for _, callback := range e.onResolve { + callback(e.neigh.LinkAddr, succeeded) } - e.wakers = nil + e.onResolve = nil if ch := e.done; ch != nil { close(ch) e.done = nil @@ -170,6 +155,8 @@ func (e *neighborEntry) notifyWakersLocked() { // dispatchAddEventLocked signals to stack's NUD Dispatcher that the entry has // been added. +// +// Precondition: e.mu MUST be locked. func (e *neighborEntry) dispatchAddEventLocked() { if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { nudDisp.OnNeighborAdded(e.nic.id, e.neigh) @@ -178,6 +165,8 @@ func (e *neighborEntry) dispatchAddEventLocked() { // dispatchChangeEventLocked signals to stack's NUD Dispatcher that the entry // has changed state or link-layer address. +// +// Precondition: e.mu MUST be locked. func (e *neighborEntry) dispatchChangeEventLocked() { if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { nudDisp.OnNeighborChanged(e.nic.id, e.neigh) @@ -186,23 +175,41 @@ func (e *neighborEntry) dispatchChangeEventLocked() { // dispatchRemoveEventLocked signals to stack's NUD Dispatcher that the entry // has been removed. +// +// Precondition: e.mu MUST be locked. func (e *neighborEntry) dispatchRemoveEventLocked() { if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { nudDisp.OnNeighborRemoved(e.nic.id, e.neigh) } } +// cancelJobLocked cancels the currently scheduled action, if there is one. +// Entries in Unknown, Stale, or Static state do not have a scheduled action. +// +// Precondition: e.mu MUST be locked. +func (e *neighborEntry) cancelJobLocked() { + if job := e.job; job != nil { + job.Cancel() + } +} + +// removeLocked prepares the entry for removal. +// +// Precondition: e.mu MUST be locked. +func (e *neighborEntry) removeLocked() { + e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds() + e.dispatchRemoveEventLocked() + e.cancelJobLocked() + e.notifyCompletionLocked(false /* succeeded */) +} + // setStateLocked transitions the entry to the specified state immediately. // // Follows the logic defined in RFC 4861 section 7.3.3. // -// e.mu MUST be locked. +// Precondition: e.mu MUST be locked. func (e *neighborEntry) setStateLocked(next NeighborState) { - // Cancel the previously scheduled action, if there is one. Entries in - // Unknown, Stale, or Static state do not have scheduled actions. - if timer := e.job; timer != nil { - timer.Cancel() - } + e.cancelJobLocked() prev := e.neigh.State e.neigh.State = next @@ -257,11 +264,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { e.job.Schedule(immediateDuration) case Failed: - e.notifyWakersLocked() - e.job = e.nic.stack.newJob(&e.mu, func() { - e.nic.neigh.removeEntryLocked(e) - }) - e.job.Schedule(config.UnreachableTime) + e.notifyCompletionLocked(false /* succeeded */) case Unknown, Stale, Static: // Do nothing @@ -275,8 +278,14 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { // being queued for outgoing transmission. // // Follows the logic defined in RFC 4861 section 7.3.3. +// +// Precondition: e.mu MUST be locked. func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { switch e.neigh.State { + case Failed: + e.nic.stats.Neighbor.FailedEntryLookups.Increment() + + fallthrough case Unknown: e.neigh.State = Incomplete e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds() @@ -309,7 +318,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { // implementation may find it convenient in some cases to return errors // to the sender by taking the offending packet, generating an ICMP // error message, and then delivering it (locally) through the generic - // error-handling routines.' - RFC 4861 section 2.1 + // error-handling routines." - RFC 4861 section 2.1 e.dispatchRemoveEventLocked() e.setStateLocked(Failed) return @@ -347,9 +356,8 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { e.setStateLocked(Delay) e.dispatchChangeEventLocked() - case Incomplete, Reachable, Delay, Probe, Static, Failed: + case Incomplete, Reachable, Delay, Probe, Static: // Do nothing - default: panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) } @@ -359,18 +367,30 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { // Neighbor Solicitation for ARP or NDP, respectively). // // Follows the logic defined in RFC 4861 section 7.2.3. +// +// Precondition: e.mu MUST be locked. func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) { // Probes MUST be silently discarded if the target address is tentative, does // not exist, or not bound to the NIC as per RFC 4861 section 7.2.3. These // checks MUST be done by the NetworkEndpoint. switch e.neigh.State { - case Unknown, Incomplete, Failed: + case Unknown, Failed: e.neigh.LinkAddr = remoteLinkAddr e.setStateLocked(Stale) - e.notifyWakersLocked() e.dispatchAddEventLocked() + case Incomplete: + // "If an entry already exists, and the cached link-layer address + // differs from the one in the received Source Link-Layer option, the + // cached address should be replaced by the received address, and the + // entry's reachability state MUST be set to STALE." + // - RFC 4861 section 7.2.3 + e.neigh.LinkAddr = remoteLinkAddr + e.setStateLocked(Stale) + e.notifyCompletionLocked(true /* succeeded */) + e.dispatchChangeEventLocked() + case Reachable, Delay, Probe: if e.neigh.LinkAddr != remoteLinkAddr { e.neigh.LinkAddr = remoteLinkAddr @@ -403,6 +423,8 @@ func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) { // not be possible. SEND uses RSA key pairs to produce Cryptographically // Generated Addresses (CGA), as defined in RFC 3972. This ensures that the // claimed source of an NDP message is the owner of the claimed address. +// +// Precondition: e.mu MUST be locked. func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) { switch e.neigh.State { case Incomplete: @@ -421,7 +443,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla } e.dispatchChangeEventLocked() e.isRouter = flags.IsRouter - e.notifyWakersLocked() + e.notifyCompletionLocked(true /* succeeded */) // "Note that the Override flag is ignored if the entry is in the // INCOMPLETE state." - RFC 4861 section 7.2.5 @@ -456,7 +478,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla wasReachable := e.neigh.State == Reachable // Set state to Reachable again to refresh timers. e.setStateLocked(Reachable) - e.notifyWakersLocked() + e.notifyCompletionLocked(true /* succeeded */) if !wasReachable { e.dispatchChangeEventLocked() } @@ -494,6 +516,8 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla // handleUpperLevelConfirmationLocked processes an incoming upper-level protocol // (e.g. TCP acknowledgements) reachability confirmation. +// +// Precondition: e.mu MUST be locked. func (e *neighborEntry) handleUpperLevelConfirmationLocked() { switch e.neigh.State { case Reachable, Stale, Delay, Probe: diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index c2b763325..ec34ffa5a 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -25,7 +25,6 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -73,35 +72,36 @@ func eventDiffOptsWithSort() []cmp.Option { // The following unit tests exercise every state transition and verify its // behavior with RFC 4681. // -// | From | To | Cause | Action | Event | -// | ========== | ========== | ========================================== | =============== | ======= | -// | Unknown | Unknown | Confirmation w/ unknown address | | Added | -// | Unknown | Incomplete | Packet queued to unknown address | Send probe | Added | -// | Unknown | Stale | Probe w/ unknown address | | Added | -// | Incomplete | Incomplete | Retransmit timer expired | Send probe | Changed | -// | Incomplete | Reachable | Solicited confirmation | Notify wakers | Changed | -// | Incomplete | Stale | Unsolicited confirmation | Notify wakers | Changed | -// | Incomplete | Failed | Max probes sent without reply | Notify wakers | Removed | -// | Reachable | Reachable | Confirmation w/ different isRouter flag | Update IsRouter | | -// | Reachable | Stale | Reachable timer expired | | Changed | -// | Reachable | Stale | Probe or confirmation w/ different address | | Changed | -// | Stale | Reachable | Solicited override confirmation | Update LinkAddr | Changed | -// | Stale | Reachable | Solicited confirmation w/o address | Notify wakers | Changed | -// | Stale | Stale | Override confirmation | Update LinkAddr | Changed | -// | Stale | Stale | Probe w/ different address | Update LinkAddr | Changed | -// | Stale | Delay | Packet sent | | Changed | -// | Delay | Reachable | Upper-layer confirmation | | Changed | -// | Delay | Reachable | Solicited override confirmation | Update LinkAddr | Changed | -// | Delay | Reachable | Solicited confirmation w/o address | Notify wakers | Changed | -// | Delay | Stale | Probe or confirmation w/ different address | | Changed | -// | Delay | Probe | Delay timer expired | Send probe | Changed | -// | Probe | Reachable | Solicited override confirmation | Update LinkAddr | Changed | -// | Probe | Reachable | Solicited confirmation w/ same address | Notify wakers | Changed | -// | Probe | Reachable | Solicited confirmation w/o address | Notify wakers | Changed | -// | Probe | Stale | Probe or confirmation w/ different address | | Changed | -// | Probe | Probe | Retransmit timer expired | Send probe | Changed | -// | Probe | Failed | Max probes sent without reply | Notify wakers | Removed | -// | Failed | | Unreachability timer expired | Delete entry | | +// | From | To | Cause | Update | Action | Event | +// | ========== | ========== | ========================================== | ======== | ===========| ======= | +// | Unknown | Unknown | Confirmation w/ unknown address | | | Added | +// | Unknown | Incomplete | Packet queued to unknown address | | Send probe | Added | +// | Unknown | Stale | Probe | | | Added | +// | Incomplete | Incomplete | Retransmit timer expired | | Send probe | Changed | +// | Incomplete | Reachable | Solicited confirmation | LinkAddr | Notify | Changed | +// | Incomplete | Stale | Unsolicited confirmation | LinkAddr | Notify | Changed | +// | Incomplete | Stale | Probe | LinkAddr | Notify | Changed | +// | Incomplete | Failed | Max probes sent without reply | | Notify | Removed | +// | Reachable | Reachable | Confirmation w/ different isRouter flag | IsRouter | | | +// | Reachable | Stale | Reachable timer expired | | | Changed | +// | Reachable | Stale | Probe or confirmation w/ different address | | | Changed | +// | Stale | Reachable | Solicited override confirmation | LinkAddr | | Changed | +// | Stale | Reachable | Solicited confirmation w/o address | | Notify | Changed | +// | Stale | Stale | Override confirmation | LinkAddr | | Changed | +// | Stale | Stale | Probe w/ different address | LinkAddr | | Changed | +// | Stale | Delay | Packet sent | | | Changed | +// | Delay | Reachable | Upper-layer confirmation | | | Changed | +// | Delay | Reachable | Solicited override confirmation | LinkAddr | | Changed | +// | Delay | Reachable | Solicited confirmation w/o address | | Notify | Changed | +// | Delay | Stale | Probe or confirmation w/ different address | | | Changed | +// | Delay | Probe | Delay timer expired | | Send probe | Changed | +// | Probe | Reachable | Solicited override confirmation | LinkAddr | | Changed | +// | Probe | Reachable | Solicited confirmation w/ same address | | Notify | Changed | +// | Probe | Reachable | Solicited confirmation w/o address | | Notify | Changed | +// | Probe | Stale | Probe or confirmation w/ different address | | | Changed | +// | Probe | Probe | Retransmit timer expired | | | Changed | +// | Probe | Failed | Max probes sent without reply | | Notify | Removed | +// | Failed | Incomplete | Packet queued | | Send probe | Added | type testEntryEventType uint8 @@ -228,6 +228,7 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e clock: clock, nudDisp: &disp, }, + stats: makeNICStats(), } nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{ header.IPv6ProtocolNumber: (&testIPv6Protocol{}).NewEndpoint(&nic, nil, nil, nil), @@ -256,8 +257,8 @@ func TestEntryInitiallyUnknown(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - if got, want := e.neigh.State, Unknown; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Unknown { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Unknown) } e.mu.Unlock() @@ -289,8 +290,8 @@ func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) { Override: false, IsRouter: false, }) - if got, want := e.neigh.State, Unknown; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Unknown { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Unknown) } e.mu.Unlock() @@ -318,8 +319,8 @@ func TestEntryUnknownToIncomplete(t *testing.T) { e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - if got, want := e.neigh.State, Incomplete; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) } e.mu.Unlock() @@ -365,8 +366,8 @@ func TestEntryUnknownToStale(t *testing.T) { e.mu.Lock() e.handleProbeLocked(entryTestLinkAddr1) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) } e.mu.Unlock() @@ -404,8 +405,8 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - if got, want := e.neigh.State, Incomplete; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) } updatedAtNanos := e.neigh.UpdatedAtNanos e.mu.Unlock() @@ -558,21 +559,15 @@ func TestEntryIncompleteToReachable(t *testing.T) { nudDisp.mu.Unlock() } -// TestEntryAddsAndClearsWakers verifies that wakers are added when -// addWakerLocked is called and cleared when address resolution finishes. In -// this case, address resolution will finish when transitioning from Incomplete -// to Reachable. -func TestEntryAddsAndClearsWakers(t *testing.T) { +func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { c := DefaultNUDConfigurations() e, nudDisp, linkRes, clock := entryTestSetup(c) - w := sleep.Waker{} - s := sleep.Sleeper{} - s.AddWaker(&w, 123) - defer s.Done() - e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) + } e.mu.Unlock() runImmediatelyScheduledJobs(clock) @@ -591,26 +586,16 @@ func TestEntryAddsAndClearsWakers(t *testing.T) { } e.mu.Lock() - if got := e.wakers; got != nil { - t.Errorf("got e.wakers = %v, want = nil", got) - } - e.addWakerLocked(&w) - if got, want := w.IsAsserted(), false; got != want { - t.Errorf("waker.IsAsserted() = %t, want = %t", got, want) - } - if e.wakers == nil { - t.Error("expected e.wakers to be non-nil") - } e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: true, Override: false, - IsRouter: false, + IsRouter: true, }) - if e.wakers != nil { - t.Errorf("got e.wakers = %v, want = nil", e.wakers) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) } - if got, want := w.IsAsserted(), true; got != want { - t.Errorf("waker.IsAsserted() = %t, want = %t", got, want) + if !e.isRouter { + t.Errorf("got e.isRouter = %t, want = true", e.isRouter) } e.mu.Unlock() @@ -641,7 +626,7 @@ func TestEntryAddsAndClearsWakers(t *testing.T) { nudDisp.mu.Unlock() } -func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { +func TestEntryIncompleteToStaleWhenUnsolicitedConfirmation(t *testing.T) { c := DefaultNUDConfigurations() e, nudDisp, linkRes, clock := entryTestSetup(c) @@ -661,22 +646,20 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { }, } linkRes.mu.Lock() - if diff := cmp.Diff(linkRes.probes, wantProbes); diff != "" { + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } - linkRes.mu.Unlock() e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, + Solicited: false, Override: false, - IsRouter: true, + IsRouter: false, }) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) - } - if !e.isRouter { - t.Errorf("got e.isRouter = %t, want = true", e.isRouter) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) } e.mu.Unlock() @@ -696,7 +679,7 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { Entry: NeighborEntry{ Addr: entryTestAddr1, LinkAddr: entryTestLinkAddr1, - State: Reachable, + State: Stale, }, }, } @@ -707,7 +690,7 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { nudDisp.mu.Unlock() } -func TestEntryIncompleteToStale(t *testing.T) { +func TestEntryIncompleteToStaleWhenProbe(t *testing.T) { c := DefaultNUDConfigurations() e, nudDisp, linkRes, clock := entryTestSetup(c) @@ -734,11 +717,7 @@ func TestEntryIncompleteToStale(t *testing.T) { } e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) + e.handleProbeLocked(entryTestLinkAddr1) if e.neigh.State != Stale { t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) } @@ -778,8 +757,8 @@ func TestEntryIncompleteToFailed(t *testing.T) { e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - if got, want := e.neigh.State, Incomplete; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) } e.mu.Unlock() @@ -839,8 +818,8 @@ func TestEntryIncompleteToFailed(t *testing.T) { nudDisp.mu.Unlock() e.mu.Lock() - if got, want := e.neigh.State, Failed; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Failed { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Failed) } e.mu.Unlock() } @@ -883,8 +862,8 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { Override: false, IsRouter: true, }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) } if got, want := e.isRouter, true; got != want { t.Errorf("got e.isRouter = %t, want = %t", got, want) @@ -930,8 +909,8 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { nudDisp.mu.Unlock() e.mu.Lock() - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) } e.mu.Unlock() } @@ -1081,8 +1060,8 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { nudDisp.mu.Unlock() e.mu.Lock() - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) } e.mu.Unlock() } @@ -2379,8 +2358,8 @@ func TestEntryDelayToProbe(t *testing.T) { IsRouter: false, }) e.handlePacketQueuedLocked(entryTestAddr2) - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) } e.mu.Unlock() @@ -2445,8 +2424,8 @@ func TestEntryDelayToProbe(t *testing.T) { nudDisp.mu.Unlock() e.mu.Lock() - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) } e.mu.Unlock() } @@ -2503,12 +2482,12 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { } e.mu.Lock() - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) } e.handleProbeLocked(entryTestLinkAddr2) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) } e.mu.Unlock() @@ -2618,16 +2597,16 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { } e.mu.Lock() - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) } e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ Solicited: false, Override: true, IsRouter: false, }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) } e.mu.Unlock() @@ -2738,16 +2717,16 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { } e.mu.Lock() - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) } e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: true, IsRouter: false, }) - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) } if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) @@ -2834,16 +2813,16 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { } e.mu.Lock() - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) } e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ Solicited: true, Override: true, IsRouter: false, }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) } if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) @@ -2962,16 +2941,16 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { } e.mu.Lock() - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) } e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ Solicited: true, Override: true, IsRouter: false, }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) } if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) @@ -3099,16 +3078,16 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin } e.mu.Lock() - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) } e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: true, Override: false, IsRouter: false, }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) } e.mu.Unlock() @@ -3433,72 +3412,61 @@ func TestEntryProbeToFailed(t *testing.T) { nudDisp.mu.Unlock() } -func TestEntryFailedGetsDeleted(t *testing.T) { +func TestEntryFailedToIncomplete(t *testing.T) { c := DefaultNUDConfigurations() c.MaxMulticastProbes = 3 - c.MaxUnicastProbes = 3 e, nudDisp, linkRes, clock := entryTestSetup(c) - // Verify the cache contains the entry. - if _, ok := e.nic.neigh.cache[entryTestAddr1]; !ok { - t.Errorf("expected entry %q to exist in the neighbor cache", entryTestAddr1) - } - + // TODO(gvisor.dev/issue/4872): Use helper functions to start entry tests in + // their expected state. e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) + } e.mu.Unlock() - runImmediatelyScheduledJobs(clock) - { - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.probes = nil - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) - } + waitFor := c.RetransmitTimer * time.Duration(c.MaxMulticastProbes) + clock.Advance(waitFor) + + wantProbes := []entryTestProbeInfo{ + // The Incomplete-to-Incomplete state transition is tested here by + // verifying that 3 reachability probes were sent. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Failed { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Failed) + } e.mu.Unlock() - waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + c.UnreachableTime - clock.Advance(waitFor) - { - wantProbes := []entryTestProbeInfo{ - // The next three probe are sent in Probe. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) - } + e.mu.Lock() + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) } + e.mu.Unlock() wantEvents := []testEntryEventInfo{ { @@ -3511,39 +3479,21 @@ func TestEntryFailedGetsDeleted(t *testing.T) { }, }, { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, - }, - }, - { - EventType: entryTestChanged, + EventType: entryTestRemoved, NICID: entryTestNICID, Entry: NeighborEntry{ Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, }, }, { - EventType: entryTestRemoved, + EventType: entryTestAdded, NICID: entryTestNICID, Entry: NeighborEntry{ Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, }, }, } @@ -3552,9 +3502,4 @@ func TestEntryFailedGetsDeleted(t *testing.T) { t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } nudDisp.mu.Unlock() - - // Verify the cache no longer contains the entry. - if _, ok := e.nic.neigh.cache[entryTestAddr1]; ok { - t.Errorf("entry %q should have been deleted from the neighbor cache", entryTestAddr1) - } } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 60c81a3aa..4a34805b5 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -20,7 +20,6 @@ import ( "reflect" "sync/atomic" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -54,18 +53,20 @@ type NIC struct { sync.RWMutex spoofing bool promiscuous bool - // packetEPs is protected by mu, but the contained PacketEndpoint - // values are not. - packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint + // packetEPs is protected by mu, but the contained packetEndpointList are + // not. + packetEPs map[tcpip.NetworkProtocolNumber]*packetEndpointList } } -// NICStats includes transmitted and received stats. +// NICStats hold statistics for a NIC. type NICStats struct { Tx DirectionStats Rx DirectionStats DisabledRx DirectionStats + + Neighbor NeighborStats } func makeNICStats() NICStats { @@ -80,6 +81,39 @@ type DirectionStats struct { Bytes *tcpip.StatCounter } +type packetEndpointList struct { + mu sync.RWMutex + + // eps is protected by mu, but the contained PacketEndpoint values are not. + eps []PacketEndpoint +} + +func (p *packetEndpointList) add(ep PacketEndpoint) { + p.mu.Lock() + defer p.mu.Unlock() + p.eps = append(p.eps, ep) +} + +func (p *packetEndpointList) remove(ep PacketEndpoint) { + p.mu.Lock() + defer p.mu.Unlock() + for i, epOther := range p.eps { + if epOther == ep { + p.eps = append(p.eps[:i], p.eps[i+1:]...) + break + } + } +} + +// forEach calls fn with each endpoints in p while holding the read lock on p. +func (p *packetEndpointList) forEach(fn func(PacketEndpoint)) { + p.mu.RLock() + defer p.mu.RUnlock() + for _, ep := range p.eps { + fn(ep) + } +} + // newNIC returns a new NIC using the default NDP configurations from stack. func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *NIC { // TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For @@ -100,7 +134,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC stats: makeNICStats(), networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), } - nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint) + nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList) // Check for Neighbor Unreachability Detection support. var nud NUDHandler @@ -123,11 +157,11 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC // Register supported packet and network endpoint protocols. for _, netProto := range header.Ethertypes { - nic.mu.packetEPs[netProto] = []PacketEndpoint{} + nic.mu.packetEPs[netProto] = new(packetEndpointList) } for _, netProto := range stack.networkProtocols { netNum := netProto.Number() - nic.mu.packetEPs[netNum] = nil + nic.mu.packetEPs[netNum] = new(packetEndpointList) nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic) } @@ -170,7 +204,7 @@ func (n *NIC) disable() { // // n MUST be locked. func (n *NIC) disableLocked() { - if !n.setEnabled(false) { + if !n.Enabled() { return } @@ -182,6 +216,10 @@ func (n *NIC) disableLocked() { for _, ep := range n.networkEndpoints { ep.Disable() } + + if !n.setEnabled(false) { + panic("should have only done work to disable the NIC if it was enabled") + } } // enable enables n. @@ -232,7 +270,8 @@ func (n *NIC) setPromiscuousMode(enable bool) { n.mu.Unlock() } -func (n *NIC) isPromiscuousMode() bool { +// Promiscuous implements NetworkInterface. +func (n *NIC) Promiscuous() bool { n.mu.RLock() rv := n.mu.promiscuous n.mu.RUnlock() @@ -255,16 +294,18 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb // the same unresolved IP address, and transmit the saved // packet when the address has been resolved. // - // RFC 4861 section 5.2 (for IPv6): - // Once the IP address of the next-hop node is known, the sender - // examines the Neighbor Cache for link-layer information about that - // neighbor. If no entry exists, the sender creates one, sets its state - // to INCOMPLETE, initiates Address Resolution, and then queues the data - // packet pending completion of address resolution. + // RFC 4861 section 7.2.2 (for IPv6): + // While waiting for address resolution to complete, the sender MUST, for + // each neighbor, retain a small queue of packets waiting for address + // resolution to complete. The queue MUST hold at least one packet, and MAY + // contain more. However, the number of queued packets per neighbor SHOULD + // be limited to some small value. When a queue overflows, the new arrival + // SHOULD replace the oldest entry. Once address resolution completes, the + // node transmits any queued packets. if ch, err := r.Resolve(nil); err != nil { if err == tcpip.ErrWouldBlock { - r := r.Clone() - n.stack.linkResQueue.enqueue(ch, &r, protocol, pkt) + r.Acquire() + n.stack.linkResQueue.enqueue(ch, r, protocol, pkt) return nil } return err @@ -276,9 +317,11 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb // WritePacketToRemote implements NetworkInterface. func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { r := Route{ - NetProto: protocol, - RemoteLinkAddress: remoteLinkAddr, + routeInfo: routeInfo{ + NetProto: protocol, + }, } + r.ResolveWith(remoteLinkAddr) return n.writePacket(&r, gso, protocol, pkt) } @@ -320,16 +363,21 @@ func (n *NIC) setSpoofing(enable bool) { // primaryAddress returns an address that can be used to communicate with // remoteAddr. func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) AssignableAddressEndpoint { - n.mu.RLock() - spoofing := n.mu.spoofing - n.mu.RUnlock() - ep, ok := n.networkEndpoints[protocol] if !ok { return nil } - return ep.AcquireOutgoingPrimaryAddress(remoteAddr, spoofing) + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + return nil + } + + n.mu.RLock() + spoofing := n.mu.spoofing + n.mu.RUnlock() + + return addressableEndpoint.AcquireOutgoingPrimaryAddress(remoteAddr, spoofing) } type getAddressBehaviour int @@ -388,11 +436,17 @@ func (n *NIC) getAddressOrCreateTemp(protocol tcpip.NetworkProtocolNumber, addre // getAddressOrCreateTempInner is like getAddressEpOrCreateTemp except a boolean // is passed to indicate whether or not we should generate temporary endpoints. func (n *NIC) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, createTemp bool, peb PrimaryEndpointBehavior) AssignableAddressEndpoint { - if ep, ok := n.networkEndpoints[protocol]; ok { - return ep.AcquireAssignedAddress(address, createTemp, peb) + ep, ok := n.networkEndpoints[protocol] + if !ok { + return nil } - return nil + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + return nil + } + + return addressableEndpoint.AcquireAssignedAddress(address, createTemp, peb) } // addAddress adds a new address to n, so that it starts accepting packets @@ -403,7 +457,12 @@ func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo return tcpip.ErrUnknownProtocol } - addressEndpoint, err := ep.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */) + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + return tcpip.ErrNotSupported + } + + addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */) if err == nil { // We have no need for the address endpoint. addressEndpoint.DecRef() @@ -416,7 +475,12 @@ func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo func (n *NIC) allPermanentAddresses() []tcpip.ProtocolAddress { var addrs []tcpip.ProtocolAddress for p, ep := range n.networkEndpoints { - for _, a := range ep.PermanentAddresses() { + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + continue + } + + for _, a := range addressableEndpoint.PermanentAddresses() { addrs = append(addrs, tcpip.ProtocolAddress{Protocol: p, AddressWithPrefix: a}) } } @@ -427,7 +491,12 @@ func (n *NIC) allPermanentAddresses() []tcpip.ProtocolAddress { func (n *NIC) primaryAddresses() []tcpip.ProtocolAddress { var addrs []tcpip.ProtocolAddress for p, ep := range n.networkEndpoints { - for _, a := range ep.PrimaryAddresses() { + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + continue + } + + for _, a := range addressableEndpoint.PrimaryAddresses() { addrs = append(addrs, tcpip.ProtocolAddress{Protocol: p, AddressWithPrefix: a}) } } @@ -445,13 +514,23 @@ func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWit return tcpip.AddressWithPrefix{} } - return ep.MainAddress() + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + return tcpip.AddressWithPrefix{} + } + + return addressableEndpoint.MainAddress() } // removeAddress removes an address from n. func (n *NIC) removeAddress(addr tcpip.Address) *tcpip.Error { for _, ep := range n.networkEndpoints { - if err := ep.RemovePermanentAddress(addr); err == tcpip.ErrBadLocalAddress { + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + continue + } + + if err := addressableEndpoint.RemovePermanentAddress(addr); err == tcpip.ErrBadLocalAddress { continue } else { return err @@ -469,14 +548,6 @@ func (n *NIC) neighbors() ([]NeighborEntry, *tcpip.Error) { return n.neigh.entries(), nil } -func (n *NIC) removeWaker(addr tcpip.Address, w *sleep.Waker) { - if n.neigh == nil { - return - } - - n.neigh.removeWaker(addr, w) -} - func (n *NIC) addStaticNeighbor(addr tcpip.Address, linkAddress tcpip.LinkAddress) *tcpip.Error { if n.neigh == nil { return tcpip.ErrNotSupported @@ -524,8 +595,7 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address return tcpip.ErrNotSupported } - _, err := gep.JoinGroup(addr) - return err + return gep.JoinGroup(addr) } // leaveGroup decrements the count for the given multicast address, and when it @@ -541,11 +611,7 @@ func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Addres return tcpip.ErrNotSupported } - if _, err := gep.LeaveGroup(addr); err != nil { - return err - } - - return nil + return gep.LeaveGroup(addr) } // isInGroup returns true if n has joined the multicast group addr. @@ -564,13 +630,6 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool { return false } -func (n *NIC) handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, remotelinkAddr tcpip.LinkAddress, addressEndpoint AssignableAddressEndpoint, pkt *PacketBuffer) { - r := makeRoute(protocol, dst, src, n, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */) - defer r.Release() - r.PopulatePacketInfo(pkt) - n.getNetworkEndpoint(protocol).HandlePacket(pkt) -} - // DeliverNetworkPacket finds the appropriate network protocol endpoint and // hands the packet over for further processing. This function is called when // the NIC receives a packet from the link endpoint. @@ -592,7 +651,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp n.stats.Rx.Packets.Increment() n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data.Size())) - netProto, ok := n.stack.networkProtocols[protocol] + networkEndpoint, ok := n.networkEndpoints[protocol] if !ok { n.mu.RUnlock() n.stack.stats.UnknownProtocolRcvdPackets.Increment() @@ -607,21 +666,26 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp pkt.RXTransportChecksumValidated = n.LinkEndpoint.Capabilities()&CapabilityRXChecksumOffload != 0 // Are any packet type sockets listening for this network protocol? - packetEPs := n.mu.packetEPs[protocol] - // Add any other packet type sockets that may be listening for all protocols. - packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...) + protoEPs := n.mu.packetEPs[protocol] + // Other packet type sockets that are listening for all protocols. + anyEPs := n.mu.packetEPs[header.EthernetProtocolAll] n.mu.RUnlock() - for _, ep := range packetEPs { + + // Deliver to interested packet endpoints without holding NIC lock. + deliverPacketEPs := func(ep PacketEndpoint) { p := pkt.Clone() p.PktType = tcpip.PacketHost ep.HandlePacket(n.id, local, protocol, p) } - - if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber { - n.stack.stats.IP.PacketsReceived.Increment() + if protoEPs != nil { + protoEPs.forEach(deliverPacketEPs) + } + if anyEPs != nil { + anyEPs.forEach(deliverPacketEPs) } // Parse headers. + netProto := n.stack.NetworkProtocolInstance(protocol) transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt) if !ok { // The packet is too small to contain a network header. @@ -636,9 +700,8 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } } - src, dst := netProto.ParseAddresses(pkt.NetworkHeader().View()) - if n.stack.handleLocal && !n.IsLoopback() { + src, _ := netProto.ParseAddresses(pkt.NetworkHeader().View()) if r := n.getAddress(protocol, src); r != nil { r.DecRef() @@ -651,78 +714,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } } - // Loopback traffic skips the prerouting chain. - if !n.IsLoopback() { - // iptables filtering. - ipt := n.stack.IPTables() - address := n.primaryAddress(protocol) - if ok := ipt.Check(Prerouting, pkt, nil, nil, address.Address, ""); !ok { - // iptables is telling us to drop the packet. - n.stack.stats.IP.IPTablesPreroutingDropped.Increment() - return - } - } - - if addressEndpoint := n.getAddress(protocol, dst); addressEndpoint != nil { - n.handlePacket(protocol, dst, src, remote, addressEndpoint, pkt) - return - } - - // This NIC doesn't care about the packet. Find a NIC that cares about the - // packet and forward it to the NIC. - // - // TODO: Should we be forwarding the packet even if promiscuous? - if n.stack.Forwarding(protocol) { - r, err := n.stack.FindRoute(0, "", dst, protocol, false /* multicastLoop */) - if err != nil { - n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() - return - } - - // Found a NIC. - n := r.localAddressNIC - if addressEndpoint := n.getAddressOrCreateTempInner(protocol, dst, false, NeverPrimaryEndpoint); addressEndpoint != nil { - if n.isValidForOutgoing(addressEndpoint) { - pkt.NICID = n.ID() - r.RemoteAddress = src - pkt.NetworkPacketInfo = r.networkPacketInfo() - n.getNetworkEndpoint(protocol).HandlePacket(pkt) - addressEndpoint.DecRef() - r.Release() - return - } - - addressEndpoint.DecRef() - } - - // n doesn't have a destination endpoint. - // Send the packet out of n. - // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease - // the TTL field for ipv4/ipv6. - - // pkt may have set its header and may not have enough headroom for - // link-layer header for the other link to prepend. Here we create a new - // packet to forward. - fwdPkt := NewPacketBuffer(PacketBufferOptions{ - ReserveHeaderBytes: int(n.LinkEndpoint.MaxHeaderLength()), - // We need to do a deep copy of the IP packet because WritePacket (and - // friends) take ownership of the packet buffer, but we do not own it. - Data: PayloadSince(pkt.NetworkHeader()).ToVectorisedView(), - }) - - // TODO(b/143425874) Decrease the TTL field in forwarded packets. - if err := n.WritePacket(&r, nil, protocol, fwdPkt); err != nil { - n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() - } - - r.Release() - return - } - - // If a packet socket handled the packet, don't treat it as invalid. - if len(packetEPs) == 0 { - n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() - } + networkEndpoint.HandlePacket(pkt) } // DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket. @@ -731,16 +723,17 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc // We do not deliver to protocol specific packet endpoints as on Linux // only ETH_P_ALL endpoints get outbound packets. // Add any other packet sockets that maybe listening for all protocols. - packetEPs := n.mu.packetEPs[header.EthernetProtocolAll] + eps := n.mu.packetEPs[header.EthernetProtocolAll] n.mu.RUnlock() - for _, ep := range packetEPs { + + eps.forEach(func(ep PacketEndpoint) { p := pkt.Clone() p.PktType = tcpip.PacketOutgoing // Add the link layer header as outgoing packets are intercepted // before the link layer header is created. n.LinkEndpoint.AddHeader(local, remote, protocol, p) ep.HandlePacket(n.id, local, protocol, p) - } + }) } // DeliverTransportPacket delivers the packets to the appropriate transport @@ -893,7 +886,7 @@ func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep Pa if !ok { return tcpip.ErrNotSupported } - n.mu.packetEPs[netProto] = append(eps, ep) + eps.add(ep) return nil } @@ -906,13 +899,7 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep if !ok { return } - - for i, epOther := range eps { - if epOther == ep { - n.mu.packetEPs[netProto] = append(eps[:i], eps[i+1:]...) - return - } - } + eps.remove(ep) } // isValidForOutgoing returns true if the endpoint can be used to send out a diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go index ab629b3a4..12d67409a 100644 --- a/pkg/tcpip/stack/nud.go +++ b/pkg/tcpip/stack/nud.go @@ -109,14 +109,6 @@ const ( // // Default taken from MAX_NEIGHBOR_ADVERTISEMENT of RFC 4861 section 10. defaultMaxReachbilityConfirmations = 3 - - // defaultUnreachableTime is the default duration for how long an entry will - // remain in the FAILED state before being removed from the neighbor cache. - // - // Note, there is no equivalent protocol constant defined in RFC 4861. It - // leaves the specifics of any garbage collection mechanism up to the - // implementation. - defaultUnreachableTime = 5 * time.Second ) // NUDDispatcher is the interface integrators of netstack must implement to @@ -278,10 +270,6 @@ type NUDConfigurations struct { // TODO(gvisor.dev/issue/2246): Discuss if implementation of this NUD // configuration option is necessary. MaxReachabilityConfirmations uint32 - - // UnreachableTime describes how long an entry will remain in the FAILED - // state before being removed from the neighbor cache. - UnreachableTime time.Duration } // DefaultNUDConfigurations returns a NUDConfigurations populated with default @@ -299,7 +287,6 @@ func DefaultNUDConfigurations() NUDConfigurations { MaxUnicastProbes: defaultMaxUnicastProbes, MaxAnycastDelayTime: defaultMaxAnycastDelayTime, MaxReachabilityConfirmations: defaultMaxReachbilityConfirmations, - UnreachableTime: defaultUnreachableTime, } } @@ -329,9 +316,6 @@ func (c *NUDConfigurations) resetInvalidFields() { if c.MaxUnicastProbes == 0 { c.MaxUnicastProbes = defaultMaxUnicastProbes } - if c.UnreachableTime == 0 { - c.UnreachableTime = defaultUnreachableTime - } } // calcMaxRandomFactor calculates the maximum value of the random factor used @@ -416,7 +400,7 @@ func (s *NUDState) ReachableTime() time.Duration { s.config.BaseReachableTime != s.prevBaseReachableTime || s.config.MinRandomFactor != s.prevMinRandomFactor || s.config.MaxRandomFactor != s.prevMaxRandomFactor { - return s.recomputeReachableTimeLocked() + s.recomputeReachableTimeLocked() } return s.reachableTime } @@ -442,7 +426,7 @@ func (s *NUDState) ReachableTime() time.Duration { // random value gets re-computed at least once every few hours. // // s.mu MUST be locked for writing. -func (s *NUDState) recomputeReachableTimeLocked() time.Duration { +func (s *NUDState) recomputeReachableTimeLocked() { s.prevBaseReachableTime = s.config.BaseReachableTime s.prevMinRandomFactor = s.config.MinRandomFactor s.prevMaxRandomFactor = s.config.MaxRandomFactor @@ -462,5 +446,4 @@ func (s *NUDState) recomputeReachableTimeLocked() time.Duration { } s.expiration = time.Now().Add(2 * time.Hour) - return s.reachableTime } diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go index 8cffb9fc6..7bca1373e 100644 --- a/pkg/tcpip/stack/nud_test.go +++ b/pkg/tcpip/stack/nud_test.go @@ -37,7 +37,6 @@ const ( defaultMaxUnicastProbes = 3 defaultMaxAnycastDelayTime = time.Second defaultMaxReachbilityConfirmations = 3 - defaultUnreachableTime = 5 * time.Second defaultFakeRandomNum = 0.5 ) @@ -565,58 +564,6 @@ func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) { } } -func TestNUDConfigurationsUnreachableTime(t *testing.T) { - tests := []struct { - name string - unreachableTime time.Duration - want time.Duration - }{ - // Invalid cases - { - name: "EqualToZero", - unreachableTime: 0, - want: defaultUnreachableTime, - }, - // Valid cases - { - name: "MoreThanZero", - unreachableTime: time.Millisecond, - want: time.Millisecond, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - const nicID = 1 - - c := stack.DefaultNUDConfigurations() - c.UnreachableTime = test.unreachableTime - - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - s := stack.New(stack.Options{ - // A neighbor cache is required to store NUDConfigurations. The - // networking stack will only allocate neighbor caches if a protocol - // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NUDConfigs: c, - UseNeighborCache: true, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - sc, err := s.NUDConfigurations(nicID) - if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) - } - if got := sc.UnreachableTime; got != test.want { - t.Errorf("got UnreachableTime = %q, want = %q", got, test.want) - } - }) - } -} - // TestNUDStateReachableTime verifies the correctness of the ReachableTime // computation. func TestNUDStateReachableTime(t *testing.T) { diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go index 5d364a2b0..4a3adcf33 100644 --- a/pkg/tcpip/stack/pending_packets.go +++ b/pkg/tcpip/stack/pending_packets.go @@ -103,7 +103,7 @@ func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, pro for _, p := range packets { if cancelled { p.route.Stats().IP.OutgoingPacketErrors.Increment() - } else if _, err := p.route.Resolve(nil); err != nil { + } else if p.route.IsResolutionRequired() { p.route.Stats().IP.OutgoingPacketErrors.Increment() } else { p.route.outgoingNIC.writePacket(p.route, nil /* gso */, p.proto, p.pkt) diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 00e9a82ae..4795208b4 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -17,7 +17,6 @@ package stack import ( "fmt" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -65,10 +64,6 @@ const ( // NetworkPacketInfo holds information about a network layer packet. type NetworkPacketInfo struct { - // RemoteAddressBroadcast is true if the packet's remote address is a - // broadcast address. - RemoteAddressBroadcast bool - // LocalAddressBroadcast is true if the packet's local address is a broadcast // address. LocalAddressBroadcast bool @@ -89,7 +84,7 @@ type TransportEndpoint interface { // HandleControlPacket is called by the stack when new control (e.g. // ICMP) packets arrive to this transport endpoint. // HandleControlPacket takes ownership of pkt. - HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt *PacketBuffer) + HandleControlPacket(typ ControlType, extra uint32, pkt *PacketBuffer) // Abort initiates an expedited endpoint teardown. It puts the endpoint // in a closed state and frees all resources associated with it. This @@ -263,15 +258,6 @@ const ( PacketLoop ) -// NetOptions is an interface that allows us to pass network protocol specific -// options through the Stack layer code. -type NetOptions interface { - // AllocationSize returns the amount of memory that must be allocated to - // hold the options given that the value must be rounded up to the next - // multiple of 4 bytes. - AllocationSize() int -} - // NetworkHeaderParams are the header parameters given as input by the // transport endpoint to the network. type NetworkHeaderParams struct { @@ -283,10 +269,6 @@ type NetworkHeaderParams struct { // TOS refers to TypeOfService or TrafficClass field of the IP-header. TOS uint8 - - // Options is a set of options to add to a network header (or nil). - // It will be protocol specific opaque information from higher layers. - Options NetOptions } // GroupAddressableEndpoint is an endpoint that supports group addressing. @@ -295,14 +277,10 @@ type NetworkHeaderParams struct { // endpoints may associate themselves with the same identifier (group address). type GroupAddressableEndpoint interface { // JoinGroup joins the specified group. - // - // Returns true if the group was newly joined. - JoinGroup(group tcpip.Address) (bool, *tcpip.Error) + JoinGroup(group tcpip.Address) *tcpip.Error // LeaveGroup attempts to leave the specified group. - // - // Returns tcpip.ErrBadLocalAddress if the endpoint has not joined the group. - LeaveGroup(group tcpip.Address) (bool, *tcpip.Error) + LeaveGroup(group tcpip.Address) *tcpip.Error // IsInGroup returns true if the endpoint is a member of the specified group. IsInGroup(group tcpip.Address) bool @@ -518,6 +496,9 @@ type NetworkInterface interface { // Enabled returns true if the interface is enabled. Enabled() bool + // Promiscuous returns true if the interface is in promiscuous mode. + Promiscuous() bool + // WritePacketToRemote writes the packet to the given remote link address. WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error } @@ -525,8 +506,6 @@ type NetworkInterface interface { // NetworkEndpoint is the interface that needs to be implemented by endpoints // of network layer protocols (e.g., ipv4, ipv6). type NetworkEndpoint interface { - AddressableEndpoint - // Enable enables the endpoint. // // Must only be called when the stack is in a state that allows the endpoint @@ -742,10 +721,6 @@ type LinkEndpoint interface { // endpoint. Capabilities() LinkEndpointCapabilities - // WriteRawPacket writes a packet directly to the link. The packet - // should already have an ethernet header. It takes ownership of vv. - WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error - // Attach attaches the data link layer endpoint to the network-layer // dispatcher of the stack. // @@ -823,19 +798,26 @@ type LinkAddressCache interface { // AddLinkAddress adds a link address to the cache. AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) - // GetLinkAddress looks up the cache to translate address to link address (e.g. IP -> MAC). - // If the LinkEndpoint requests address resolution and there is a LinkAddressResolver - // registered with the network protocol, the cache attempts to resolve the address - // and returns ErrWouldBlock. Waker is notified when address resolution is - // complete (success or not). + // GetLinkAddress finds the link address corresponding to the remote address + // (e.g. IP -> MAC). // - // If address resolution is required, ErrNoLinkAddress and a notification channel is - // returned for the top level caller to block. Channel is closed once address resolution - // is complete (success or not). - GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) - - // RemoveWaker removes a waker that has been added in GetLinkAddress(). - RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) + // Returns a link address for the remote address, if readily available. + // + // Returns ErrWouldBlock if the link address is not readily available, along + // with a notification channel for the caller to block on. Triggers address + // resolution asynchronously. + // + // If onResolve is provided, it will be called either immediately, if + // resolution is not required, or when address resolution is complete, with + // the resolved link address and whether resolution succeeded. After any + // callbacks have been called, the returned notification channel is closed. + // + // If specified, the local address must be an address local to the interface + // the neighbor cache belongs to. The local address is the source address of + // a packet prompting NUD/link address resolution. + // + // TODO(gvisor.dev/issue/5151): Don't return the link address. + GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) } // RawFactory produces endpoints for writing various types of raw packets. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 15ff437c7..b0251d0b4 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -17,20 +17,53 @@ package stack import ( "fmt" - "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" ) // Route represents a route through the networking stack to a given destination. +// +// It is safe to call Route's methods from multiple goroutines. +// +// The exported fields are immutable. +// +// TODO(gvisor.dev/issue/4902): Unexpose immutable fields. type Route struct { + routeInfo + + // localAddressNIC is the interface the address is associated with. + // TODO(gvisor.dev/issue/4548): Remove this field once we can query the + // address's assigned status without the NIC. + localAddressNIC *NIC + + mu struct { + sync.RWMutex + + // localAddressEndpoint is the local address this route is associated with. + localAddressEndpoint AssignableAddressEndpoint + + // remoteLinkAddress is the link-layer (MAC) address of the next hop in the + // route. + remoteLinkAddress tcpip.LinkAddress + } + + // outgoingNIC is the interface this route uses to write packets. + outgoingNIC *NIC + + // linkCache is set if link address resolution is enabled for this protocol on + // the route's NIC. + linkCache LinkAddressCache + + // linkRes is set if link address resolution is enabled for this protocol on + // the route's NIC. + linkRes LinkAddressResolver +} + +type routeInfo struct { // RemoteAddress is the final destination of the route. RemoteAddress tcpip.Address - // RemoteLinkAddress is the link-layer (MAC) address of the - // final destination of the route. - RemoteLinkAddress tcpip.LinkAddress - // LocalAddress is the local address where the route starts. LocalAddress tcpip.Address @@ -46,47 +79,48 @@ type Route struct { // Loop controls where WritePacket should send packets. Loop PacketLooping +} - // localAddressNIC is the interface the address is associated with. - // TODO(gvisor.dev/issue/4548): Remove this field once we can query the - // address's assigned status without the NIC. - localAddressNIC *NIC - - // localAddressEndpoint is the local address this route is associated with. - localAddressEndpoint AssignableAddressEndpoint - - // outgoingNIC is the interface this route uses to write packets. - outgoingNIC *NIC +// RouteInfo contains all of Route's exported fields. +type RouteInfo struct { + routeInfo - // linkCache is set if link address resolution is enabled for this protocol on - // the route's NIC. - linkCache LinkAddressCache + // RemoteLinkAddress is the link-layer (MAC) address of the next hop in the + // route. + RemoteLinkAddress tcpip.LinkAddress +} - // linkRes is set if link address resolution is enabled for this protocol on - // the route's NIC. - linkRes LinkAddressResolver +// GetFields returns a RouteInfo with all of r's exported fields. This allows +// callers to store the route's fields without retaining a reference to it. +func (r *Route) GetFields() RouteInfo { + return RouteInfo{ + routeInfo: r.routeInfo, + RemoteLinkAddress: r.RemoteLinkAddress(), + } } // constructAndValidateRoute validates and initializes a route. It takes // ownership of the provided local address. // // Returns an empty route if validation fails. -func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) Route { - addrWithPrefix := addressEndpoint.AddressWithPrefix() +func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, localAddr, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) *Route { + if len(localAddr) == 0 { + localAddr = addressEndpoint.AddressWithPrefix().Address + } - if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(addrWithPrefix.Address) { + if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(localAddr) { addressEndpoint.DecRef() - return Route{} + return nil } // If no remote address is provided, use the local address. if len(remoteAddr) == 0 { - remoteAddr = addrWithPrefix.Address + remoteAddr = localAddr } r := makeRoute( netProto, - addrWithPrefix.Address, + localAddr, remoteAddr, outgoingNIC, localAddressNIC, @@ -99,8 +133,8 @@ func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndp // broadcast it. if len(gateway) > 0 { r.NextHop = gateway - } else if subnet := addrWithPrefix.Subnet(); subnet.IsBroadcast(remoteAddr) { - r.RemoteLinkAddress = header.EthernetBroadcastAddress + } else if subnet := addressEndpoint.Subnet(); subnet.IsBroadcast(remoteAddr) { + r.ResolveWith(header.EthernetBroadcastAddress) } return r @@ -108,11 +142,15 @@ func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndp // makeRoute initializes a new route. It takes ownership of the provided // AssignableAddressEndpoint. -func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route { +func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) *Route { if localAddressNIC.stack != outgoingNIC.stack { panic(fmt.Sprintf("cannot create a route with NICs from different stacks")) } + if len(localAddr) == 0 { + localAddr = localAddressEndpoint.AddressWithPrefix().Address + } + loop := PacketOut // TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the @@ -133,18 +171,23 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop) } -func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) Route { - r := Route{ - NetProto: netProto, - LocalAddress: localAddr, - LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(), - RemoteAddress: remoteAddr, - localAddressNIC: localAddressNIC, - localAddressEndpoint: localAddressEndpoint, - outgoingNIC: outgoingNIC, - Loop: loop, +func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) *Route { + r := &Route{ + routeInfo: routeInfo{ + NetProto: netProto, + LocalAddress: localAddr, + LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(), + RemoteAddress: remoteAddr, + Loop: loop, + }, + localAddressNIC: localAddressNIC, + outgoingNIC: outgoingNIC, } + r.mu.Lock() + r.mu.localAddressEndpoint = localAddressEndpoint + r.mu.Unlock() + if r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 { if linkRes, ok := r.outgoingNIC.stack.linkAddrResolvers[r.NetProto]; ok { r.linkRes = linkRes @@ -159,7 +202,7 @@ func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr // provided AssignableAddressEndpoint. // // A local route is a route to a destination that is local to the stack. -func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint) Route { +func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint) *Route { loop := PacketLoop // TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the // link endpoint level. We can remove this check once loopback interfaces @@ -170,26 +213,12 @@ func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop) } -// PopulatePacketInfo populates a packet buffer's packet information fields. -// -// TODO(gvisor.dev/issue/4688): Remove this once network packets are handled by -// the network layer. -func (r *Route) PopulatePacketInfo(pkt *PacketBuffer) { - if r.local() { - pkt.RXTransportChecksumValidated = true - } - pkt.NetworkPacketInfo = r.networkPacketInfo() -} - -// networkPacketInfo returns the network packet information of the route. -// -// TODO(gvisor.dev/issue/4688): Remove this once network packets are handled by -// the network layer. -func (r *Route) networkPacketInfo() NetworkPacketInfo { - return NetworkPacketInfo{ - RemoteAddressBroadcast: r.IsOutboundBroadcast(), - LocalAddressBroadcast: r.isInboundBroadcast(), - } +// RemoteLinkAddress returns the link-layer (MAC) address of the next hop in +// the route. +func (r *Route) RemoteLinkAddress() tcpip.LinkAddress { + r.mu.RLock() + defer r.mu.RUnlock() + return r.mu.remoteLinkAddress } // NICID returns the id of the NIC from which this route originates. @@ -253,22 +282,26 @@ func (r *Route) GSOMaxSize() uint32 { // ResolveWith immediately resolves a route with the specified remote link // address. func (r *Route) ResolveWith(addr tcpip.LinkAddress) { - r.RemoteLinkAddress = addr + r.mu.Lock() + defer r.mu.Unlock() + r.mu.remoteLinkAddress = addr } -// Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in -// case address resolution requires blocking, e.g. wait for ARP reply. Waker is -// notified when address resolution is complete (success or not). -// -// If address resolution is required, ErrNoLinkAddress and a notification channel is -// returned for the top level caller to block. Channel is closed once address resolution -// is complete (success or not). +// Resolve attempts to resolve the link address if necessary. // -// The NIC r uses must not be locked. -func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { - if !r.IsResolutionRequired() { +// Returns tcpip.ErrWouldBlock if address resolution requires blocking (e.g. +// waiting for ARP reply). If address resolution is required, a notification +// channel is also returned for the caller to block on. The channel is closed +// once address resolution is complete (successful or not). If a callback is +// provided, it will be called when address resolution is complete, regardless +// of success or failure. +func (r *Route) Resolve(afterResolve func()) (<-chan struct{}, *tcpip.Error) { + r.mu.Lock() + + if !r.isResolutionRequiredRLocked() { // Nothing to do if there is no cache (which does the resolution on cache miss) or // link address is already known. + r.mu.Unlock() return nil, nil } @@ -276,7 +309,8 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { if nextAddr == "" { // Local link address is already known. if r.RemoteAddress == r.LocalAddress { - r.RemoteLinkAddress = r.LocalLinkAddress + r.mu.remoteLinkAddress = r.LocalLinkAddress + r.mu.Unlock() return nil, nil } nextAddr = r.RemoteAddress @@ -289,38 +323,36 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { linkAddressResolutionRequestLocalAddr = r.LocalAddress } + // Increment the route's reference count because finishResolution retains a + // reference to the route and releases it when called. + r.acquireLocked() + r.mu.Unlock() + + finishResolution := func(linkAddress tcpip.LinkAddress, ok bool) { + if ok { + r.ResolveWith(linkAddress) + } + if afterResolve != nil { + afterResolve() + } + r.Release() + } + if neigh := r.outgoingNIC.neigh; neigh != nil { - entry, ch, err := neigh.entry(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, waker) + _, ch, err := neigh.entry(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, finishResolution) if err != nil { return ch, err } - r.RemoteLinkAddress = entry.LinkAddr return nil, nil } - linkAddr, ch, err := r.linkCache.GetLinkAddress(r.outgoingNIC.ID(), nextAddr, linkAddressResolutionRequestLocalAddr, r.NetProto, waker) + _, ch, err := r.linkCache.GetLinkAddress(r.outgoingNIC.ID(), nextAddr, linkAddressResolutionRequestLocalAddr, r.NetProto, finishResolution) if err != nil { return ch, err } - r.RemoteLinkAddress = linkAddr return nil, nil } -// RemoveWaker removes a waker that has been added in Resolve(). -func (r *Route) RemoveWaker(waker *sleep.Waker) { - nextAddr := r.NextHop - if nextAddr == "" { - nextAddr = r.RemoteAddress - } - - if neigh := r.outgoingNIC.neigh; neigh != nil { - neigh.removeWaker(nextAddr, waker) - return - } - - r.linkCache.RemoveWaker(r.outgoingNIC.ID(), nextAddr, waker) -} - // local returns true if the route is a local route. func (r *Route) local() bool { return r.Loop == PacketLoop || r.outgoingNIC.IsLoopback() @@ -331,7 +363,13 @@ func (r *Route) local() bool { // // The NICs the route is associated with must not be locked. func (r *Route) IsResolutionRequired() bool { - if !r.isValidForOutgoing() || r.RemoteLinkAddress != "" || r.local() { + r.mu.RLock() + defer r.mu.RUnlock() + return r.isResolutionRequiredRLocked() +} + +func (r *Route) isResolutionRequiredRLocked() bool { + if !r.isValidForOutgoingRLocked() || r.mu.remoteLinkAddress != "" || r.local() { return false } @@ -339,11 +377,18 @@ func (r *Route) IsResolutionRequired() bool { } func (r *Route) isValidForOutgoing() bool { + r.mu.RLock() + defer r.mu.RUnlock() + return r.isValidForOutgoingRLocked() +} + +func (r *Route) isValidForOutgoingRLocked() bool { if !r.outgoingNIC.Enabled() { return false } - if !r.localAddressNIC.isValidForOutgoing(r.localAddressEndpoint) { + localAddressEndpoint := r.mu.localAddressEndpoint + if localAddressEndpoint == nil || !r.localAddressNIC.isValidForOutgoing(localAddressEndpoint) { return false } @@ -395,39 +440,31 @@ func (r *Route) MTU() uint32 { return r.outgoingNIC.getNetworkEndpoint(r.NetProto).MTU() } -// Release frees all resources associated with the route. +// Release decrements the reference counter of the resources associated with the +// route. func (r *Route) Release() { - if r.localAddressEndpoint != nil { - r.localAddressEndpoint.DecRef() - r.localAddressEndpoint = nil + r.mu.Lock() + defer r.mu.Unlock() + + if ep := r.mu.localAddressEndpoint; ep != nil { + ep.DecRef() } } -// Clone clones the route. -func (r *Route) Clone() Route { - if r.localAddressEndpoint != nil { - if !r.localAddressEndpoint.IncRef() { +// Acquire increments the reference counter of the resources associated with the +// route. +func (r *Route) Acquire() { + r.mu.RLock() + defer r.mu.RUnlock() + r.acquireLocked() +} + +func (r *Route) acquireLocked() { + if ep := r.mu.localAddressEndpoint; ep != nil { + if !ep.IncRef() { panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", r.LocalAddress)) } } - return *r -} - -// MakeLoopedRoute duplicates the given route with special handling for routes -// used for sending multicast or broadcast packets. In those cases the -// multicast/broadcast address is the remote address when sending out, but for -// incoming (looped) packets it becomes the local address. Similarly, the local -// interface address that was the local address going out becomes the remote -// address coming in. This is different to unicast routes where local and -// remote addresses remain the same as they identify location (local vs remote) -// not direction (source vs destination). -func (r *Route) MakeLoopedRoute() Route { - l := r.Clone() - if r.RemoteAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(r.RemoteAddress) || header.IsV6MulticastAddress(r.RemoteAddress) { - l.RemoteAddress, l.LocalAddress = l.LocalAddress, l.RemoteAddress - l.RemoteLinkAddress = l.LocalLinkAddress - } - return l } // Stack returns the instance of the Stack that owns this route. @@ -440,7 +477,14 @@ func (r *Route) isV4Broadcast(addr tcpip.Address) bool { return true } - subnet := r.localAddressEndpoint.Subnet() + r.mu.RLock() + localAddressEndpoint := r.mu.localAddressEndpoint + r.mu.RUnlock() + if localAddressEndpoint == nil { + return false + } + + subnet := localAddressEndpoint.Subnet() return subnet.IsBroadcast(addr) } @@ -450,27 +494,3 @@ func (r *Route) IsOutboundBroadcast() bool { // Only IPv4 has a notion of broadcast. return r.isV4Broadcast(r.RemoteAddress) } - -// isInboundBroadcast returns true if the route is for an inbound broadcast -// packet. -func (r *Route) isInboundBroadcast() bool { - // Only IPv4 has a notion of broadcast. - return r.isV4Broadcast(r.LocalAddress) -} - -// ReverseRoute returns new route with given source and destination address. -func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route { - return Route{ - NetProto: r.NetProto, - LocalAddress: dst, - LocalLinkAddress: r.RemoteLinkAddress, - RemoteAddress: src, - RemoteLinkAddress: r.LocalLinkAddress, - Loop: r.Loop, - localAddressNIC: r.localAddressNIC, - localAddressEndpoint: r.localAddressEndpoint, - outgoingNIC: r.outgoingNIC, - linkCache: r.linkCache, - linkRes: r.linkRes, - } -} diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 0fe157128..114643b03 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -29,7 +29,6 @@ import ( "golang.org/x/time/rate" "gvisor.dev/gvisor/pkg/rand" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -82,6 +81,7 @@ type TCPRACKState struct { FACK seqnum.Value RTT time.Duration Reord bool + DSACKSeen bool } // TCPEndpointID is the unique 4 tuple that identifies a given endpoint. @@ -170,6 +170,9 @@ type TCPSenderState struct { // Outstanding is the number of packets in flight. Outstanding int + // SackedOut is the number of packets which have been selectively acked. + SackedOut int + // SndWnd is the send window size in bytes. SndWnd seqnum.Size @@ -1080,7 +1083,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { flags := NICStateFlags{ Up: true, // Netstack interfaces are always up. Running: nic.Enabled(), - Promiscuous: nic.isPromiscuousMode(), + Promiscuous: nic.Promiscuous(), Loopback: nic.IsLoopback(), } nics[id] = NICInfo{ @@ -1117,6 +1120,16 @@ func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint) } +// AddAddressWithPrefix is the same as AddAddress, but allows you to specify +// the address prefix. +func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.AddressWithPrefix) *tcpip.Error { + ap := tcpip.ProtocolAddress{ + Protocol: protocol, + AddressWithPrefix: addr, + } + return s.AddProtocolAddressWithOptions(id, ap, CanBePrimaryEndpoint) +} + // AddProtocolAddress adds a new network-layer protocol address to the // specified NIC. func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress) *tcpip.Error { @@ -1207,10 +1220,10 @@ func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netP // from the specified NIC. // // Precondition: s.mu must be read locked. -func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) { +func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) *Route { localAddressEndpoint := localAddressNIC.getAddressOrCreateTempInner(netProto, localAddr, false /* createTemp */, NeverPrimaryEndpoint) if localAddressEndpoint == nil { - return Route{}, false + return nil } var outgoingNIC *NIC @@ -1234,12 +1247,12 @@ func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, re // route. if outgoingNIC == nil { localAddressEndpoint.DecRef() - return Route{}, false + return nil } r := makeLocalRoute( netProto, - localAddressEndpoint.AddressWithPrefix().Address, + localAddr, remoteAddr, outgoingNIC, localAddressNIC, @@ -1248,10 +1261,10 @@ func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, re if r.IsOutboundBroadcast() { r.Release() - return Route{}, false + return nil } - return r, true + return r } // findLocalRouteRLocked returns a local route. @@ -1260,26 +1273,26 @@ func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, re // is, a local route is a route where packets never have to leave the stack. // // Precondition: s.mu must be read locked. -func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) { +func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) *Route { if len(localAddr) == 0 { localAddr = remoteAddr } if localAddressNICID == 0 { for _, localAddressNIC := range s.nics { - if r, ok := s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto); ok { - return r, true + if r := s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto); r != nil { + return r } } - return Route{}, false + return nil } if localAddressNIC, ok := s.nics[localAddressNICID]; ok { return s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto) } - return Route{}, false + return nil } // FindRoute creates a route to the given destination address, leaving through @@ -1293,7 +1306,7 @@ func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, // If no local address is provided, the stack will select a local address. If no // remote address is provided, the stack wil use a remote address equal to the // local address. -func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (Route, *tcpip.Error) { +func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (*Route, *tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() @@ -1304,7 +1317,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n needRoute := !(isLocalBroadcast || isMulticast || isLinkLocal || isLoopback) if s.handleLocal && !isMulticast && !isLocalBroadcast { - if r, ok := s.findLocalRouteRLocked(id, localAddr, remoteAddr, netProto); ok { + if r := s.findLocalRouteRLocked(id, localAddr, remoteAddr, netProto); r != nil { return r, nil } } @@ -1316,7 +1329,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { return makeRoute( netProto, - addressEndpoint.AddressWithPrefix().Address, + localAddr, remoteAddr, nic, /* outboundNIC */ nic, /* localAddressNIC*/ @@ -1328,9 +1341,9 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n } if isLoopback { - return Route{}, tcpip.ErrBadLocalAddress + return nil, tcpip.ErrBadLocalAddress } - return Route{}, tcpip.ErrNetworkUnreachable + return nil, tcpip.ErrNetworkUnreachable } canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal @@ -1353,8 +1366,8 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n if needRoute { gateway = route.Gateway } - r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop) - if r == (Route{}) { + r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop) + if r == nil { panic(fmt.Sprintf("non-forwarding route validation failed with route table entry = %#v, id = %d, localAddr = %s, remoteAddr = %s", route, id, localAddr, remoteAddr)) } return r, nil @@ -1390,13 +1403,13 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n if id != 0 { if aNIC, ok := s.nics[id]; ok { if addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto); addressEndpoint != nil { - if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) { + if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != nil { return r, nil } } } - return Route{}, tcpip.ErrNoRoute + return nil, tcpip.ErrNoRoute } if id == 0 { @@ -1408,7 +1421,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n continue } - if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) { + if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != nil { return r, nil } } @@ -1416,12 +1429,12 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n } if needRoute { - return Route{}, tcpip.ErrNoRoute + return nil, tcpip.ErrNoRoute } if header.IsV6LoopbackAddress(remoteAddr) { - return Route{}, tcpip.ErrBadLocalAddress + return nil, tcpip.ErrBadLocalAddress } - return Route{}, tcpip.ErrNetworkUnreachable + return nil, tcpip.ErrNetworkUnreachable } // CheckNetworkProtocol checks if a given network protocol is enabled in the @@ -1506,7 +1519,7 @@ func (s *Stack) AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr t } // GetLinkAddress implements LinkAddressCache.GetLinkAddress. -func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { +func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { s.mu.RLock() nic := s.nics[nicID] if nic == nil { @@ -1517,7 +1530,7 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} linkRes := s.linkAddrResolvers[protocol] - return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic, waker) + return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic, onResolve) } // Neighbors returns all IP to MAC address associations. @@ -1533,29 +1546,6 @@ func (s *Stack) Neighbors(nicID tcpip.NICID) ([]NeighborEntry, *tcpip.Error) { return nic.neighbors() } -// RemoveWaker removes a waker that has been added when link resolution for -// addr was requested. -func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) { - if s.useNeighborCache { - s.mu.RLock() - nic, ok := s.nics[nicID] - s.mu.RUnlock() - - if ok { - nic.removeWaker(addr, waker) - } - return - } - - s.mu.RLock() - defer s.mu.RUnlock() - - if nic := s.nics[nicID]; nic == nil { - fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} - s.linkAddrCache.removeWaker(fullAddr, waker) - } -} - // AddStaticNeighbor statically associates an IP address to a MAC address. func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) *tcpip.Error { s.mu.RLock() @@ -1809,49 +1799,20 @@ func (s *Stack) unregisterPacketEndpointLocked(nicID tcpip.NICID, netProto tcpip nic.unregisterPacketEndpoint(netProto, ep) } -// WritePacket writes data directly to the specified NIC. It adds an ethernet -// header based on the arguments. -func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error { +// WritePacketToRemote writes a payload on the specified NIC using the provided +// network protocol and remote link address. +func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error { s.mu.Lock() nic, ok := s.nics[nicID] s.mu.Unlock() if !ok { return tcpip.ErrUnknownDevice } - - // Add our own fake ethernet header. - ethFields := header.EthernetFields{ - SrcAddr: nic.LinkEndpoint.LinkAddress(), - DstAddr: dst, - Type: netProto, - } - fakeHeader := make(header.Ethernet, header.EthernetMinimumSize) - fakeHeader.Encode(ðFields) - vv := buffer.View(fakeHeader).ToVectorisedView() - vv.Append(payload) - - if err := nic.LinkEndpoint.WriteRawPacket(vv); err != nil { - return err - } - - return nil -} - -// WriteRawPacket writes data directly to the specified NIC without adding any -// headers. -func (s *Stack) WriteRawPacket(nicID tcpip.NICID, payload buffer.VectorisedView) *tcpip.Error { - s.mu.Lock() - nic, ok := s.nics[nicID] - s.mu.Unlock() - if !ok { - return tcpip.ErrUnknownDevice - } - - if err := nic.LinkEndpoint.WriteRawPacket(payload); err != nil { - return err - } - - return nil + pkt := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: int(nic.MaxHeaderLength()), + Data: payload, + }) + return nic.WritePacketToRemote(remote, nil, netProto, pkt) } // NetworkProtocolInstance returns the protocol instance in the stack for the @@ -1911,7 +1872,6 @@ func (s *Stack) RemoveTCPProbe() { // JoinGroup joins the given multicast group on the given NIC. func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error { - // TODO: notify network of subscription via igmp protocol. s.mu.RLock() defer s.mu.RUnlock() @@ -2158,3 +2118,43 @@ func (s *Stack) networkProtocolNumbers() []tcpip.NetworkProtocolNumber { } return protos } + +func isSubnetBroadcastOnNIC(nic *NIC, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { + addressEndpoint := nic.getAddressOrCreateTempInner(protocol, addr, false /* createTemp */, NeverPrimaryEndpoint) + if addressEndpoint == nil { + return false + } + + subnet := addressEndpoint.Subnet() + addressEndpoint.DecRef() + return subnet.IsBroadcast(addr) +} + +// IsSubnetBroadcast returns true if the provided address is a subnet-local +// broadcast address on the specified NIC and protocol. +// +// Returns false if the NIC is unknown or if the protocol is unknown or does +// not support addressing. +// +// If the NIC is not specified, the stack will check all NICs. +func (s *Stack) IsSubnetBroadcast(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { + s.mu.RLock() + defer s.mu.RUnlock() + + if nicID != 0 { + nic, ok := s.nics[nicID] + if !ok { + return false + } + + return isSubnetBroadcastOnNIC(nic, protocol, addr) + } + + for _, nic := range s.nics { + if isSubnetBroadcastOnNIC(nic, protocol, addr) { + return true + } + } + + return false +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index dedfdd435..856ebf6d4 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -27,7 +27,6 @@ import ( "time" "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -112,7 +111,15 @@ func (*fakeNetworkEndpoint) DefaultTTL() uint8 { func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { // Increment the received packet count in the protocol descriptor. netHdr := pkt.NetworkHeader().View() - f.proto.packetCount[int(netHdr[dstAddrOffset])%len(f.proto.packetCount)]++ + + dst := tcpip.Address(netHdr[dstAddrOffset:][:1]) + addressEndpoint := f.AcquireAssignedAddress(dst, f.nic.Promiscuous(), stack.CanBePrimaryEndpoint) + if addressEndpoint == nil { + return + } + addressEndpoint.DecRef() + + f.proto.packetCount[int(dst[0])%len(f.proto.packetCount)]++ // Handle control packets. if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) { @@ -159,9 +166,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params hdr[protocolNumberOffset] = byte(params.Protocol) if r.Loop&stack.PacketLoop != 0 { - pkt := pkt.Clone() - r.PopulatePacketInfo(pkt) - f.HandlePacket(pkt) + f.HandlePacket(pkt.Clone()) } if r.Loop&stack.PacketOut == 0 { return nil @@ -401,7 +406,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro return send(r, payload) } -func send(r stack.Route, payload buffer.View) *tcpip.Error { +func send(r *stack.Route, payload buffer.View) *tcpip.Error { return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: payload.ToVectorisedView(), @@ -419,7 +424,7 @@ func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.En } } -func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View) { +func testSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View) { t.Helper() ep.Drain() if err := send(r, payload); err != nil { @@ -430,7 +435,7 @@ func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer. } } -func testFailingSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { +func testFailingSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { t.Helper() if gotErr := send(r, payload); gotErr != wantErr { t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr) @@ -1557,15 +1562,15 @@ func TestSpoofingNoAddress(t *testing.T) { // testSendTo(t, s, remoteAddr, ep, nil) } -func verifyRoute(gotRoute, wantRoute stack.Route) error { +func verifyRoute(gotRoute, wantRoute *stack.Route) error { if gotRoute.LocalAddress != wantRoute.LocalAddress { return fmt.Errorf("bad local address: got %s, want = %s", gotRoute.LocalAddress, wantRoute.LocalAddress) } if gotRoute.RemoteAddress != wantRoute.RemoteAddress { return fmt.Errorf("bad remote address: got %s, want = %s", gotRoute.RemoteAddress, wantRoute.RemoteAddress) } - if gotRoute.RemoteLinkAddress != wantRoute.RemoteLinkAddress { - return fmt.Errorf("bad remote link address: got %s, want = %s", gotRoute.RemoteLinkAddress, wantRoute.RemoteLinkAddress) + if got, want := gotRoute.RemoteLinkAddress(), wantRoute.RemoteLinkAddress(); got != want { + return fmt.Errorf("bad remote link address: got %s, want = %s", got, want) } if gotRoute.NextHop != wantRoute.NextHop { return fmt.Errorf("bad next-hop address: got %s, want = %s", gotRoute.NextHop, wantRoute.NextHop) @@ -1597,7 +1602,10 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { if err != nil { t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) } - if err := verifyRoute(r, stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil { + var wantRoute stack.Route + wantRoute.LocalAddress = header.IPv4Any + wantRoute.RemoteAddress = header.IPv4Broadcast + if err := verifyRoute(r, &wantRoute); err != nil { t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) } @@ -1651,7 +1659,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { if err != nil { t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) } - if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + var wantRoute stack.Route + wantRoute.LocalAddress = nic1Addr.Address + wantRoute.RemoteAddress = header.IPv4Broadcast + if err := verifyRoute(r, &wantRoute); err != nil { t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) } @@ -1661,7 +1672,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { if err != nil { t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) } - if err := verifyRoute(r, stack.Route{LocalAddress: nic2Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + wantRoute = stack.Route{} + wantRoute.LocalAddress = nic2Addr.Address + wantRoute.RemoteAddress = header.IPv4Broadcast + if err := verifyRoute(r, &wantRoute); err != nil { t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err) } @@ -1677,7 +1691,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { if err != nil { t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) } - if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + wantRoute = stack.Route{} + wantRoute.LocalAddress = nic1Addr.Address + wantRoute.RemoteAddress = header.IPv4Broadcast + if err := verifyRoute(r, &wantRoute); err != nil { t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err) } } @@ -2214,88 +2231,6 @@ func TestNICStats(t *testing.T) { } } -func TestNICForwarding(t *testing.T) { - const nicID1 = 1 - const nicID2 = 2 - const dstAddr = tcpip.Address("\x03") - - tests := []struct { - name string - headerLen uint16 - }{ - { - name: "Zero header length", - }, - { - name: "Non-zero header length", - headerLen: 16, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - s.SetForwarding(fakeNetNumber, true) - - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID1, ep1); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) - } - if err := s.AddAddress(nicID1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress(%d, %d, 0x01): %s", nicID1, fakeNetNumber, err) - } - - ep2 := channelLinkWithHeaderLength{ - Endpoint: channel.New(10, defaultMTU, ""), - headerLength: test.headerLen, - } - if err := s.CreateNIC(nicID2, &ep2); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) - } - if err := s.AddAddress(nicID2, fakeNetNumber, "\x02"); err != nil { - t.Fatalf("AddAddress(%d, %d, 0x02): %s", nicID2, fakeNetNumber, err) - } - - // Route all packets to dstAddr to NIC 2. - { - subnet, err := tcpip.NewSubnet(dstAddr, "\xff") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID2}}) - } - - // Send a packet to dstAddr. - buf := buffer.NewView(30) - buf[dstAddrOffset] = dstAddr[0] - ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - pkt, ok := ep2.Read() - if !ok { - t.Fatal("packet not forwarded") - } - - // Test that the link's MaxHeaderLength is honoured. - if capacity, want := pkt.Pkt.AvailableHeaderBytes(), int(test.headerLen); capacity != want { - t.Errorf("got LinkHeader.AvailableLength() = %d, want = %d", capacity, want) - } - - // Test that forwarding increments Tx stats correctly. - if got, want := s.NICInfo()[nicID2].Stats.Tx.Packets.Value(), uint64(1); got != want { - t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want) - } - - if got, want := s.NICInfo()[nicID2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want { - t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) - } - }) - } -} - // TestNICContextPreservation tests that you can read out via stack.NICInfo the // Context data you pass via NICContext.Context in stack.CreateNICWithOptions. func TestNICContextPreservation(t *testing.T) { @@ -2483,9 +2418,9 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { } opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenIPv6LinkLocal: test.autoGen, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: test.iidOpts, + AutoGenLinkLocal: test.autoGen, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: test.iidOpts, })}, } @@ -2578,8 +2513,8 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) { t.Run(test.name, func(t *testing.T) { opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenIPv6LinkLocal: true, - OpaqueIIDOpts: test.opaqueIIDOpts, + AutoGenLinkLocal: true, + OpaqueIIDOpts: test.opaqueIIDOpts, })}, } @@ -2612,9 +2547,9 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { ndpConfigs := ipv6.DefaultNDPConfigurations() opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ndpConfigs, - AutoGenIPv6LinkLocal: true, - NDPDisp: &ndpDisp, + NDPConfigs: ndpConfigs, + AutoGenLinkLocal: true, + NDPDisp: &ndpDisp, })}, } @@ -2803,8 +2738,16 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - nicID = 1 - lifetimeSeconds = 9999 + globalAddr3 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") + ipv4MappedIPv6Addr1 = tcpip.Address("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x01") + ipv4MappedIPv6Addr2 = tcpip.Address("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x02") + toredoAddr1 = tcpip.Address("\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + toredoAddr2 = tcpip.Address("\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + ipv6ToIPv4Addr1 = tcpip.Address("\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + ipv6ToIPv4Addr2 = tcpip.Address("\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + + nicID = 1 + lifetimeSeconds = 9999 ) prefix1, _, stableGlobalAddr1 := prefixSubnetAddr(0, linkAddr1) @@ -2821,139 +2764,191 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { slaacPrefixForTempAddrBeforeNICAddrAdd tcpip.AddressWithPrefix nicAddrs []tcpip.Address slaacPrefixForTempAddrAfterNICAddrAdd tcpip.AddressWithPrefix - connectAddr tcpip.Address + remoteAddr tcpip.Address expectedLocalAddr tcpip.Address }{ - // Test Rule 1 of RFC 6724 section 5. + // Test Rule 1 of RFC 6724 section 5 (prefer same address). { name: "Same Global most preferred (last address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, - connectAddr: globalAddr1, + nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1}, + remoteAddr: globalAddr1, expectedLocalAddr: globalAddr1, }, { name: "Same Global most preferred (first address)", - nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, - connectAddr: globalAddr1, + nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1}, + remoteAddr: globalAddr1, expectedLocalAddr: globalAddr1, }, { name: "Same Link Local most preferred (last address)", - nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1}, - connectAddr: linkLocalAddr1, + nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1}, + remoteAddr: linkLocalAddr1, expectedLocalAddr: linkLocalAddr1, }, { name: "Same Link Local most preferred (first address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, - connectAddr: linkLocalAddr1, + nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1}, + remoteAddr: linkLocalAddr1, expectedLocalAddr: linkLocalAddr1, }, { name: "Same Unique Local most preferred (last address)", - nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1}, - connectAddr: uniqueLocalAddr1, + nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1}, + remoteAddr: uniqueLocalAddr1, expectedLocalAddr: uniqueLocalAddr1, }, { name: "Same Unique Local most preferred (first address)", - nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, - connectAddr: uniqueLocalAddr1, + nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1}, + remoteAddr: uniqueLocalAddr1, expectedLocalAddr: uniqueLocalAddr1, }, - // Test Rule 2 of RFC 6724 section 5. + // Test Rule 2 of RFC 6724 section 5 (prefer appropriate scope). { name: "Global most preferred (last address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, - connectAddr: globalAddr2, + nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1}, + remoteAddr: globalAddr2, expectedLocalAddr: globalAddr1, }, { name: "Global most preferred (first address)", - nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, - connectAddr: globalAddr2, + nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1}, + remoteAddr: globalAddr2, expectedLocalAddr: globalAddr1, }, { name: "Link Local most preferred (last address)", - nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1}, - connectAddr: linkLocalAddr2, + nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1}, + remoteAddr: linkLocalAddr2, expectedLocalAddr: linkLocalAddr1, }, { name: "Link Local most preferred (first address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, - connectAddr: linkLocalAddr2, + nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1}, + remoteAddr: linkLocalAddr2, expectedLocalAddr: linkLocalAddr1, }, { name: "Link Local most preferred for link local multicast (last address)", - nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1}, - connectAddr: linkLocalMulticastAddr, + nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1}, + remoteAddr: linkLocalMulticastAddr, expectedLocalAddr: linkLocalAddr1, }, { name: "Link Local most preferred for link local multicast (first address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, - connectAddr: linkLocalMulticastAddr, + nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1}, + remoteAddr: linkLocalMulticastAddr, expectedLocalAddr: linkLocalAddr1, }, + + // Test Rule 6 of 6724 section 5 (prefer matching label). { name: "Unique Local most preferred (last address)", - nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1}, - connectAddr: uniqueLocalAddr2, + nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, ipv4MappedIPv6Addr1, toredoAddr1, ipv6ToIPv4Addr1}, + remoteAddr: uniqueLocalAddr2, expectedLocalAddr: uniqueLocalAddr1, }, { name: "Unique Local most preferred (first address)", - nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, - connectAddr: uniqueLocalAddr2, + nicAddrs: []tcpip.Address{globalAddr1, ipv4MappedIPv6Addr1, toredoAddr1, ipv6ToIPv4Addr1, uniqueLocalAddr1}, + remoteAddr: uniqueLocalAddr2, expectedLocalAddr: uniqueLocalAddr1, }, + { + name: "Toredo most preferred (first address)", + nicAddrs: []tcpip.Address{toredoAddr1, uniqueLocalAddr1, globalAddr1, ipv4MappedIPv6Addr1, ipv6ToIPv4Addr1}, + remoteAddr: toredoAddr2, + expectedLocalAddr: toredoAddr1, + }, + { + name: "Toredo most preferred (last address)", + nicAddrs: []tcpip.Address{globalAddr1, ipv4MappedIPv6Addr1, ipv6ToIPv4Addr1, uniqueLocalAddr1, toredoAddr1}, + remoteAddr: toredoAddr2, + expectedLocalAddr: toredoAddr1, + }, + { + name: "6To4 most preferred (first address)", + nicAddrs: []tcpip.Address{ipv6ToIPv4Addr1, toredoAddr1, uniqueLocalAddr1, globalAddr1, ipv4MappedIPv6Addr1}, + remoteAddr: ipv6ToIPv4Addr2, + expectedLocalAddr: ipv6ToIPv4Addr1, + }, + { + name: "6To4 most preferred (last address)", + nicAddrs: []tcpip.Address{globalAddr1, ipv4MappedIPv6Addr1, uniqueLocalAddr1, toredoAddr1, ipv6ToIPv4Addr1}, + remoteAddr: ipv6ToIPv4Addr2, + expectedLocalAddr: ipv6ToIPv4Addr1, + }, + { + name: "IPv4 mapped IPv6 most preferred (first address)", + nicAddrs: []tcpip.Address{ipv4MappedIPv6Addr1, ipv6ToIPv4Addr1, toredoAddr1, uniqueLocalAddr1, globalAddr1}, + remoteAddr: ipv4MappedIPv6Addr2, + expectedLocalAddr: ipv4MappedIPv6Addr1, + }, + { + name: "IPv4 mapped IPv6 most preferred (last address)", + nicAddrs: []tcpip.Address{globalAddr1, ipv6ToIPv4Addr1, uniqueLocalAddr1, toredoAddr1, ipv4MappedIPv6Addr1}, + remoteAddr: ipv4MappedIPv6Addr2, + expectedLocalAddr: ipv4MappedIPv6Addr1, + }, - // Test Rule 7 of RFC 6724 section 5. + // Test Rule 7 of RFC 6724 section 5 (prefer temporary addresses). { name: "Temp Global most preferred (last address)", slaacPrefixForTempAddrBeforeNICAddrAdd: prefix1, nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, - connectAddr: globalAddr2, + remoteAddr: globalAddr2, expectedLocalAddr: tempGlobalAddr1, }, { name: "Temp Global most preferred (first address)", nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, slaacPrefixForTempAddrAfterNICAddrAdd: prefix1, - connectAddr: globalAddr2, + remoteAddr: globalAddr2, expectedLocalAddr: tempGlobalAddr1, }, + // Test Rule 8 of RFC 6724 section 5 (use longest matching prefix). + { + name: "Longest prefix matched most preferred (first address)", + nicAddrs: []tcpip.Address{globalAddr2, globalAddr1}, + remoteAddr: globalAddr3, + expectedLocalAddr: globalAddr2, + }, + { + name: "Longest prefix matched most preferred (last address)", + nicAddrs: []tcpip.Address{globalAddr1, globalAddr2}, + remoteAddr: globalAddr3, + expectedLocalAddr: globalAddr2, + }, + // Test returning the endpoint that is closest to the front when // candidate addresses are "equal" from the perspective of RFC 6724 // section 5. { name: "Unique Local for Global", nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, uniqueLocalAddr2}, - connectAddr: globalAddr2, + remoteAddr: globalAddr2, expectedLocalAddr: uniqueLocalAddr1, }, { name: "Link Local for Global", nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2}, - connectAddr: globalAddr2, + remoteAddr: globalAddr2, expectedLocalAddr: linkLocalAddr1, }, { name: "Link Local for Unique Local", nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2}, - connectAddr: uniqueLocalAddr2, + remoteAddr: uniqueLocalAddr2, expectedLocalAddr: linkLocalAddr1, }, { name: "Temp Global for Global", slaacPrefixForTempAddrBeforeNICAddrAdd: prefix1, slaacPrefixForTempAddrAfterNICAddrAdd: prefix2, - connectAddr: globalAddr1, + remoteAddr: globalAddr1, expectedLocalAddr: tempGlobalAddr2, }, } @@ -2975,12 +2970,6 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv6EmptySubnet, - Gateway: llAddr3, - NIC: nicID, - }}) - s.AddLinkAddress(nicID, llAddr3, linkAddr3) if test.slaacPrefixForTempAddrBeforeNICAddrAdd != (tcpip.AddressWithPrefix{}) { e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, test.slaacPrefixForTempAddrBeforeNICAddrAdd, true, true, lifetimeSeconds, lifetimeSeconds)) @@ -3000,7 +2989,23 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { t.FailNow() } - if got := addrForNewConnectionTo(t, s, tcpip.FullAddress{Addr: test.connectAddr, NIC: nicID, Port: 1234}); got != test.expectedLocalAddr { + netEP, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } + + addressableEndpoint, ok := netEP.(stack.AddressableEndpoint) + if !ok { + t.Fatal("network endpoint is not addressable") + } + + addressEP := addressableEndpoint.AcquireOutgoingPrimaryAddress(test.remoteAddr, false /* allowExpired */) + if addressEP == nil { + t.Fatal("expected a non-nil address endpoint") + } + defer addressEP.DecRef() + + if got := addressEP.AddressWithPrefix().Address; got != test.expectedLocalAddr { t.Errorf("got local address = %s, want = %s", got, test.expectedLocalAddr) } }) @@ -3427,11 +3432,16 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { remNetSubnetBcast := remNetSubnet.Broadcast() tests := []struct { - name string - nicAddr tcpip.ProtocolAddress - routes []tcpip.Route - remoteAddr tcpip.Address - expectedRoute stack.Route + name string + nicAddr tcpip.ProtocolAddress + routes []tcpip.Route + remoteAddr tcpip.Address + expectedLocalAddress tcpip.Address + expectedRemoteAddress tcpip.Address + expectedRemoteLinkAddress tcpip.LinkAddress + expectedNextHop tcpip.Address + expectedNetProto tcpip.NetworkProtocolNumber + expectedLoop stack.PacketLooping }{ // Broadcast to a locally attached subnet populates the broadcast MAC. { @@ -3446,14 +3456,12 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { NIC: nicID1, }, }, - remoteAddr: ipv4SubnetBcast, - expectedRoute: stack.Route{ - LocalAddress: ipv4Addr.Address, - RemoteAddress: ipv4SubnetBcast, - RemoteLinkAddress: header.EthernetBroadcastAddress, - NetProto: header.IPv4ProtocolNumber, - Loop: stack.PacketOut | stack.PacketLoop, - }, + remoteAddr: ipv4SubnetBcast, + expectedLocalAddress: ipv4Addr.Address, + expectedRemoteAddress: ipv4SubnetBcast, + expectedRemoteLinkAddress: header.EthernetBroadcastAddress, + expectedNetProto: header.IPv4ProtocolNumber, + expectedLoop: stack.PacketOut | stack.PacketLoop, }, // Broadcast to a locally attached /31 subnet does not populate the // broadcast MAC. @@ -3469,13 +3477,11 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { NIC: nicID1, }, }, - remoteAddr: ipv4Subnet31Bcast, - expectedRoute: stack.Route{ - LocalAddress: ipv4AddrPrefix31.Address, - RemoteAddress: ipv4Subnet31Bcast, - NetProto: header.IPv4ProtocolNumber, - Loop: stack.PacketOut, - }, + remoteAddr: ipv4Subnet31Bcast, + expectedLocalAddress: ipv4AddrPrefix31.Address, + expectedRemoteAddress: ipv4Subnet31Bcast, + expectedNetProto: header.IPv4ProtocolNumber, + expectedLoop: stack.PacketOut, }, // Broadcast to a locally attached /32 subnet does not populate the // broadcast MAC. @@ -3491,13 +3497,11 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { NIC: nicID1, }, }, - remoteAddr: ipv4Subnet32Bcast, - expectedRoute: stack.Route{ - LocalAddress: ipv4AddrPrefix32.Address, - RemoteAddress: ipv4Subnet32Bcast, - NetProto: header.IPv4ProtocolNumber, - Loop: stack.PacketOut, - }, + remoteAddr: ipv4Subnet32Bcast, + expectedLocalAddress: ipv4AddrPrefix32.Address, + expectedRemoteAddress: ipv4Subnet32Bcast, + expectedNetProto: header.IPv4ProtocolNumber, + expectedLoop: stack.PacketOut, }, // IPv6 has no notion of a broadcast. { @@ -3512,13 +3516,11 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { NIC: nicID1, }, }, - remoteAddr: ipv6SubnetBcast, - expectedRoute: stack.Route{ - LocalAddress: ipv6Addr.Address, - RemoteAddress: ipv6SubnetBcast, - NetProto: header.IPv6ProtocolNumber, - Loop: stack.PacketOut, - }, + remoteAddr: ipv6SubnetBcast, + expectedLocalAddress: ipv6Addr.Address, + expectedRemoteAddress: ipv6SubnetBcast, + expectedNetProto: header.IPv6ProtocolNumber, + expectedLoop: stack.PacketOut, }, // Broadcast to a remote subnet in the route table is send to the next-hop // gateway. @@ -3535,14 +3537,12 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { NIC: nicID1, }, }, - remoteAddr: remNetSubnetBcast, - expectedRoute: stack.Route{ - LocalAddress: ipv4Addr.Address, - RemoteAddress: remNetSubnetBcast, - NextHop: ipv4Gateway, - NetProto: header.IPv4ProtocolNumber, - Loop: stack.PacketOut, - }, + remoteAddr: remNetSubnetBcast, + expectedLocalAddress: ipv4Addr.Address, + expectedRemoteAddress: remNetSubnetBcast, + expectedNextHop: ipv4Gateway, + expectedNetProto: header.IPv4ProtocolNumber, + expectedLoop: stack.PacketOut, }, // Broadcast to an unknown subnet follows the default route. Note that this // is essentially just routing an unknown destination IP, because w/o any @@ -3560,14 +3560,12 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { NIC: nicID1, }, }, - remoteAddr: remNetSubnetBcast, - expectedRoute: stack.Route{ - LocalAddress: ipv4Addr.Address, - RemoteAddress: remNetSubnetBcast, - NextHop: ipv4Gateway, - NetProto: header.IPv4ProtocolNumber, - Loop: stack.PacketOut, - }, + remoteAddr: remNetSubnetBcast, + expectedLocalAddress: ipv4Addr.Address, + expectedRemoteAddress: remNetSubnetBcast, + expectedNextHop: ipv4Gateway, + expectedNetProto: header.IPv4ProtocolNumber, + expectedLoop: stack.PacketOut, }, } @@ -3596,10 +3594,27 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { t.Fatalf("got unexpected address length = %d bytes", l) } - if r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, test.remoteAddr, netProto, false /* multicastLoop */); err != nil { + r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, test.remoteAddr, netProto, false /* multicastLoop */) + if err != nil { t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, test.remoteAddr, netProto, err) - } else if diff := cmp.Diff(r, test.expectedRoute, cmpopts.IgnoreUnexported(r)); diff != "" { - t.Errorf("route mismatch (-want +got):\n%s", diff) + } + if r.LocalAddress != test.expectedLocalAddress { + t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, test.expectedLocalAddress) + } + if r.RemoteAddress != test.expectedRemoteAddress { + t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress, test.expectedRemoteAddress) + } + if got := r.RemoteLinkAddress(); got != test.expectedRemoteLinkAddress { + t.Errorf("got r.RemoteLinkAddress() = %s, want = %s", got, test.expectedRemoteLinkAddress) + } + if r.NextHop != test.expectedNextHop { + t.Errorf("got r.NextHop = %s, want = %s", r.NextHop, test.expectedNextHop) + } + if r.NetProto != test.expectedNetProto { + t.Errorf("got r.NetProto = %d, want = %d", r.NetProto, test.expectedNetProto) + } + if r.Loop != test.expectedLoop { + t.Errorf("got r.Loop = %x, want = %x", r.Loop, test.expectedLoop) } }) } @@ -4167,10 +4182,12 @@ func TestFindRouteWithForwarding(t *testing.T) { s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}}) r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */) + if r != nil { + defer r.Release() + } if err != test.findRouteErr { t.Fatalf("FindRoute(%d, %s, %s, %d, false) = %s, want = %s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, err, test.findRouteErr) } - defer r.Release() if test.findRouteErr != nil { return @@ -4228,3 +4245,63 @@ func TestFindRouteWithForwarding(t *testing.T) { }) } } + +func TestWritePacketToRemote(t *testing.T) { + const nicID = 1 + const MTU = 1280 + e := channel.New(1, MTU, linkAddr1) + s := stack.New(stack.Options{}) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("CreateNIC(%d) = %s", nicID, err) + } + tests := []struct { + name string + protocol tcpip.NetworkProtocolNumber + payload []byte + }{ + { + name: "SuccessIPv4", + protocol: header.IPv4ProtocolNumber, + payload: []byte{1, 2, 3, 4}, + }, + { + name: "SuccessIPv6", + protocol: header.IPv6ProtocolNumber, + payload: []byte{5, 6, 7, 8}, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if err := s.WritePacketToRemote(nicID, linkAddr2, test.protocol, buffer.View(test.payload).ToVectorisedView()); err != nil { + t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s", err) + } + + pkt, ok := e.Read() + if got, want := ok, true; got != want { + t.Fatalf("e.Read() = %t, want %t", got, want) + } + if got, want := pkt.Proto, test.protocol; got != want { + t.Fatalf("pkt.Proto = %d, want %d", got, want) + } + if pkt.Route.RemoteLinkAddress != linkAddr2 { + t.Fatalf("pkt.Route.RemoteAddress = %s, want %s", pkt.Route.RemoteLinkAddress, linkAddr2) + } + if diff := cmp.Diff(pkt.Pkt.Data.ToView(), buffer.View(test.payload)); diff != "" { + t.Errorf("pkt.Pkt.Data mismatch (-want +got):\n%s", diff) + } + }) + } + + t.Run("InvalidNICID", func(t *testing.T) { + if got, want := s.WritePacketToRemote(234, linkAddr2, header.IPv4ProtocolNumber, buffer.View([]byte{1}).ToVectorisedView()), tcpip.ErrUnknownDevice; got != want { + t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s, want = %s", got, want) + } + pkt, ok := e.Read() + if got, want := ok, false; got != want { + t.Fatalf("e.Read() = %t, %v; want %t", got, pkt, want) + } + }) +} diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index f183ec6e4..07b2818d2 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -182,7 +182,8 @@ func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *Packet epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. } -// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. +// handleControlPacket delivers a control packet to the transport endpoint +// identified by id. func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt *PacketBuffer) { epsByNIC.mu.RLock() defer epsByNIC.mu.RUnlock() @@ -199,7 +200,7 @@ func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpoint // broadcast like we are doing with handlePacket above? // multiPortEndpoints are guaranteed to have at least one element. - selectEndpoint(id, mpep, epsByNIC.seed).HandleControlPacket(id, typ, extra, pkt) + selectEndpoint(id, mpep, epsByNIC.seed).HandleControlPacket(typ, extra, pkt) } // registerEndpoint returns true if it succeeds. It fails and returns diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 41a8e5ad0..859278f0b 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -15,6 +15,7 @@ package stack_test import ( + "io/ioutil" "math" "math/rand" "testing" @@ -141,11 +142,11 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI // Initialize the IP header. ip := header.IPv6(buf) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(header.UDPMinimumSize + len(payload)), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: 65, - SrcAddr: testSrcAddrV6, - DstAddr: testDstAddrV6, + PayloadLength: uint16(header.UDPMinimumSize + len(payload)), + TransportProtocol: udp.ProtocolNumber, + HopLimit: 65, + SrcAddr: testSrcAddrV6, + DstAddr: testDstAddrV6, }) // Initialize the UDP header. @@ -307,12 +308,9 @@ func TestBindToDeviceDistribution(t *testing.T) { }(ep) defer ep.Close() - if err := ep.SetSockOptBool(tcpip.ReusePortOption, endpoint.reuse); err != nil { - t.Fatalf("SetSockOptBool(ReusePortOption, %t) on endpoint %d failed: %s", endpoint.reuse, i, err) - } - bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice) - if err := ep.SetSockOpt(&bindToDeviceOption); err != nil { - t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", bindToDeviceOption, bindToDeviceOption, i, err) + ep.SocketOptions().SetReusePort(endpoint.reuse) + if err := ep.SocketOptions().SetBindToDevice(int32(endpoint.bindToDevice)); err != nil { + t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", endpoint.bindToDevice, endpoint.bindToDevice, i, err) } var dstAddr tcpip.Address @@ -354,7 +352,7 @@ func TestBindToDeviceDistribution(t *testing.T) { } ep := <-pollChannel - if _, _, err := ep.Read(nil); err != nil { + if _, err := ep.Read(ioutil.Discard, math.MaxUint16, tcpip.ReadOptions{}); err != nil { t.Fatalf("Read on endpoint %d failed: %s", eps[ep], err) } stats[ep]++ diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index c457b67a2..0ff32c6ea 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -15,12 +15,12 @@ package stack_test import ( + "io" "testing" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" @@ -39,14 +39,18 @@ const ( // use it. type fakeTransportEndpoint struct { stack.TransportEndpointInfo + tcpip.DefaultSocketOptionsHandler proto *fakeTransportProtocol peerAddr tcpip.Address - route stack.Route + route *stack.Route uniqueID uint64 // acceptQueue is non-nil iff bound. - acceptQueue []fakeTransportEndpoint + acceptQueue []*fakeTransportEndpoint + + // ops is used to set and get socket options. + ops tcpip.SocketOptions } func (f *fakeTransportEndpoint) Info() tcpip.EndpointInfo { @@ -59,8 +63,14 @@ func (*fakeTransportEndpoint) Stats() tcpip.EndpointStats { func (*fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {} +func (f *fakeTransportEndpoint) SocketOptions() *tcpip.SocketOptions { + return &f.ops +} + func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { - return &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} + ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} + ep.ops.InitHandler(ep) + return ep } func (f *fakeTransportEndpoint) Abort() { @@ -68,6 +78,7 @@ func (f *fakeTransportEndpoint) Abort() { } func (f *fakeTransportEndpoint) Close() { + // TODO(gvisor.dev/issue/5153): Consider retaining the route. f.route.Release() } @@ -75,8 +86,8 @@ func (*fakeTransportEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask return mask } -func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { - return buffer.View{}, tcpip.ControlMessages{}, nil +func (*fakeTransportEndpoint) Read(io.Writer, int, tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { + return tcpip.ReadResult{}, nil } func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { @@ -100,30 +111,16 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions return int64(len(v)), nil, nil } -func (*fakeTransportEndpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { - return 0, tcpip.ControlMessages{}, nil -} - // SetSockOpt sets a socket option. Currently not supported. func (*fakeTransportEndpoint) SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error { return tcpip.ErrInvalidEndpointState } -// SetSockOptBool sets a socket option. Currently not supported. -func (*fakeTransportEndpoint) SetSockOptBool(tcpip.SockOptBool, bool) *tcpip.Error { - return tcpip.ErrInvalidEndpointState -} - // SetSockOptInt sets a socket option. Currently not supported. func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) *tcpip.Error { return tcpip.ErrInvalidEndpointState } -// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. -func (*fakeTransportEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - return false, tcpip.ErrUnknownProtocolOption -} - // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return -1, tcpip.ErrUnknownProtocolOption @@ -147,16 +144,16 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { if err != nil { return tcpip.ErrNoRoute } - defer r.Release() // Try to register so that we can start receiving packets. f.ID.RemoteAddress = addr.Addr err = f.proto.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */) if err != nil { + r.Release() return err } - f.route = r.Clone() + f.route = r return nil } @@ -186,7 +183,7 @@ func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *wai } a := f.acceptQueue[0] f.acceptQueue = f.acceptQueue[1:] - return &a, nil, nil + return a, nil, nil } func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { @@ -201,7 +198,7 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { ); err != nil { return err } - f.acceptQueue = []fakeTransportEndpoint{} + f.acceptQueue = []*fakeTransportEndpoint{} return nil } @@ -227,7 +224,7 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt * } route.ResolveWith(pkt.SourceLinkAddress()) - f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{ + ep := &fakeTransportEndpoint{ TransportEndpointInfo: stack.TransportEndpointInfo{ ID: f.ID, NetProto: f.NetProto, @@ -235,10 +232,12 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt * proto: f.proto, peerAddr: route.RemoteAddress, route: route, - }) + } + ep.ops.InitHandler(ep) + f.acceptQueue = append(f.acceptQueue, ep) } -func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, *stack.PacketBuffer) { +func (f *fakeTransportEndpoint) HandleControlPacket(stack.ControlType, uint32, *stack.PacketBuffer) { // Increment the number of received control packets. f.proto.controlCount++ } @@ -553,87 +552,3 @@ func TestTransportOptions(t *testing.T) { t.Fatalf("got tcpip.TCPModerateReceiveBufferOption = false, want = true") } } - -func TestTransportForwarding(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, - }) - s.SetForwarding(fakeNetNumber, true) - - // TODO(b/123449044): Change this to a channel NIC. - ep1 := loopback.New() - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatalf("CreateNIC #1 failed: %v", err) - } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress #1 failed: %v", err) - } - - ep2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, ep2); err != nil { - t.Fatalf("CreateNIC #2 failed: %v", err) - } - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatalf("AddAddress #2 failed: %v", err) - } - - // Route all packets to address 3 to NIC 2 and all packets to address - // 1 to NIC 1. - { - subnet0, err := tcpip.NewSubnet("\x03", "\xff") - if err != nil { - t.Fatal(err) - } - subnet1, err := tcpip.NewSubnet("\x01", "\xff") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{ - {Destination: subnet0, Gateway: "\x00", NIC: 2}, - {Destination: subnet1, Gateway: "\x00", NIC: 1}, - }) - } - - wq := waiter.Queue{} - ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := ep.Bind(tcpip.FullAddress{Addr: "\x01", NIC: 1}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Send a packet to address 1 from address 3. - req := buffer.NewView(30) - req[0] = 1 - req[1] = 3 - req[2] = byte(fakeTransNumber) - ep2.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: req.ToVectorisedView(), - })) - - aep, _, err := ep.Accept(nil) - if err != nil || aep == nil { - t.Fatalf("Accept failed: %v, %v", aep, err) - } - - resp := buffer.NewView(30) - if _, _, err := aep.Write(tcpip.SlicePayload(resp), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - p, ok := ep2.Read() - if !ok { - t.Fatal("Response packet not forwarded") - } - - nh := stack.PayloadSince(p.Pkt.NetworkHeader()) - if dst := nh[0]; dst != 3 { - t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) - } - if src := nh[1]; src != 1 { - t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) - } -} |