summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/BUILD7
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state.go179
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state_test.go22
-rw-r--r--pkg/tcpip/stack/conntrack.go14
-rw-r--r--pkg/tcpip/stack/forwarding_test.go90
-rw-r--r--pkg/tcpip/stack/iptables.go66
-rw-r--r--pkg/tcpip/stack/iptables_types.go10
-rw-r--r--pkg/tcpip/stack/linkaddrcache.go135
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go110
-rw-r--r--pkg/tcpip/stack/ndp_test.go255
-rw-r--r--pkg/tcpip/stack/neighbor_cache.go107
-rw-r--r--pkg/tcpip/stack/neighbor_cache_test.go498
-rw-r--r--pkg/tcpip/stack/neighbor_entry.go118
-rw-r--r--pkg/tcpip/stack/neighbor_entry_test.go347
-rw-r--r--pkg/tcpip/stack/nic.go283
-rw-r--r--pkg/tcpip/stack/nud.go21
-rw-r--r--pkg/tcpip/stack/nud_test.go53
-rw-r--r--pkg/tcpip/stack/pending_packets.go2
-rw-r--r--pkg/tcpip/stack/registration.go68
-rw-r--r--pkg/tcpip/stack/route.go310
-rw-r--r--pkg/tcpip/stack/stack.go178
-rw-r--r--pkg/tcpip/stack/stack_test.go489
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go5
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go22
-rw-r--r--pkg/tcpip/stack/transport_test.go137
25 files changed, 1647 insertions, 1879 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index d09ebe7fa..bb30556cf 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test", "most_shards")
load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
@@ -112,7 +112,7 @@ go_test(
"transport_demuxer_test.go",
"transport_test.go",
],
- shard_count = 20,
+ shard_count = most_shards,
deps = [
":stack",
"//pkg/rand",
@@ -120,6 +120,7 @@ go_test(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
@@ -131,7 +132,6 @@ go_test(
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
"@com_github_google_go_cmp//cmp:go_default_library",
- "@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
],
)
@@ -148,7 +148,6 @@ go_test(
],
library = ":stack",
deps = [
- "//pkg/sleep",
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go
index 9478f3fb7..cd423bf71 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state.go
@@ -21,7 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
)
-var _ GroupAddressableEndpoint = (*AddressableEndpointState)(nil)
var _ AddressableEndpoint = (*AddressableEndpointState)(nil)
// AddressableEndpointState is an implementation of an AddressableEndpoint.
@@ -37,10 +36,6 @@ type AddressableEndpointState struct {
endpoints map[tcpip.Address]*addressState
primary []*addressState
-
- // groups holds the mapping between group addresses and the number of times
- // they have been joined.
- groups map[tcpip.Address]uint32
}
}
@@ -53,65 +48,33 @@ func (a *AddressableEndpointState) Init(networkEndpoint NetworkEndpoint) {
a.mu.Lock()
defer a.mu.Unlock()
a.mu.endpoints = make(map[tcpip.Address]*addressState)
- a.mu.groups = make(map[tcpip.Address]uint32)
-}
-
-// ReadOnlyAddressableEndpointState provides read-only access to an
-// AddressableEndpointState.
-type ReadOnlyAddressableEndpointState struct {
- inner *AddressableEndpointState
}
-// AddrOrMatching returns an endpoint for the passed address that is consisdered
-// bound to the wrapped AddressableEndpointState.
+// GetAddress returns the AddressEndpoint for the passed address.
//
-// If addr is an exact match with an existing address, that address is returned.
-// Otherwise, f is called with each address and the address that f returns true
-// for is returned.
-//
-// Returns nil of no address matches.
-func (m ReadOnlyAddressableEndpointState) AddrOrMatching(addr tcpip.Address, spoofingOrPrimiscuous bool, f func(AddressEndpoint) bool) AddressEndpoint {
- m.inner.mu.RLock()
- defer m.inner.mu.RUnlock()
-
- if ep, ok := m.inner.mu.endpoints[addr]; ok {
- if ep.IsAssigned(spoofingOrPrimiscuous) && ep.IncRef() {
- return ep
- }
- }
-
- for _, ep := range m.inner.mu.endpoints {
- if ep.IsAssigned(spoofingOrPrimiscuous) && f(ep) && ep.IncRef() {
- return ep
- }
- }
-
- return nil
-}
-
-// Lookup returns the AddressEndpoint for the passed address.
+// GetAddress does not increment the address's reference count or check if the
+// address is considered bound to the endpoint.
//
-// Returns nil if the passed address is not associated with the
-// AddressableEndpointState.
-func (m ReadOnlyAddressableEndpointState) Lookup(addr tcpip.Address) AddressEndpoint {
- m.inner.mu.RLock()
- defer m.inner.mu.RUnlock()
+// Returns nil if the passed address is not associated with the endpoint.
+func (a *AddressableEndpointState) GetAddress(addr tcpip.Address) AddressEndpoint {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
- ep, ok := m.inner.mu.endpoints[addr]
+ ep, ok := a.mu.endpoints[addr]
if !ok {
return nil
}
return ep
}
-// ForEach calls f for each address pair.
+// ForEachEndpoint calls f for each address.
//
-// If f returns false, f is no longer be called.
-func (m ReadOnlyAddressableEndpointState) ForEach(f func(AddressEndpoint) bool) {
- m.inner.mu.RLock()
- defer m.inner.mu.RUnlock()
+// Once f returns false, f will no longer be called.
+func (a *AddressableEndpointState) ForEachEndpoint(f func(AddressEndpoint) bool) {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
- for _, ep := range m.inner.mu.endpoints {
+ for _, ep := range a.mu.endpoints {
if !f(ep) {
return
}
@@ -120,18 +83,16 @@ func (m ReadOnlyAddressableEndpointState) ForEach(f func(AddressEndpoint) bool)
// ForEachPrimaryEndpoint calls f for each primary address.
//
-// If f returns false, f is no longer be called.
-func (m ReadOnlyAddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint)) {
- m.inner.mu.RLock()
- defer m.inner.mu.RUnlock()
- for _, ep := range m.inner.mu.primary {
- f(ep)
- }
-}
+// Once f returns false, f will no longer be called.
+func (a *AddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint) bool) {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
-// ReadOnly returns a readonly reference to a.
-func (a *AddressableEndpointState) ReadOnly() ReadOnlyAddressableEndpointState {
- return ReadOnlyAddressableEndpointState{inner: a}
+ for _, ep := range a.mu.primary {
+ if !f(ep) {
+ return
+ }
+ }
}
func (a *AddressableEndpointState) releaseAddressState(addrState *addressState) {
@@ -335,11 +296,6 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error {
a.mu.Lock()
defer a.mu.Unlock()
-
- if _, ok := a.mu.groups[addr]; ok {
- panic(fmt.Sprintf("group address = %s must be removed with LeaveGroup", addr))
- }
-
return a.removePermanentAddressLocked(addr)
}
@@ -471,8 +427,19 @@ func (a *AddressableEndpointState) acquirePrimaryAddressRLocked(isValid func(*ad
return deprecatedEndpoint
}
-// AcquireAssignedAddress implements AddressableEndpoint.
-func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint {
+// AcquireAssignedAddressOrMatching returns an address endpoint that is
+// considered assigned to the addressable endpoint.
+//
+// If the address is an exact match with an existing address, that address is
+// returned. Otherwise, if f is provided, f is called with each address and
+// the address that f returns true for is returned.
+//
+// If there is no matching address, a temporary address will be returned if
+// allowTemp is true.
+//
+// Regardless how the address was obtained, it will be acquired before it is
+// returned.
+func (a *AddressableEndpointState) AcquireAssignedAddressOrMatching(localAddr tcpip.Address, f func(AddressEndpoint) bool, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint {
a.mu.Lock()
defer a.mu.Unlock()
@@ -488,6 +455,14 @@ func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Addres
return addrState
}
+ if f != nil {
+ for _, addrState := range a.mu.endpoints {
+ if addrState.IsAssigned(allowTemp) && f(addrState) && addrState.IncRef() {
+ return addrState
+ }
+ }
+ }
+
if !allowTemp {
return nil
}
@@ -520,6 +495,11 @@ func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Addres
return ep
}
+// AcquireAssignedAddress implements AddressableEndpoint.
+func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint {
+ return a.AcquireAssignedAddressOrMatching(localAddr, nil, allowTemp, tempPEB)
+}
+
// AcquireOutgoingPrimaryAddress implements AddressableEndpoint.
func (a *AddressableEndpointState) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint {
a.mu.RLock()
@@ -588,72 +568,11 @@ func (a *AddressableEndpointState) PermanentAddresses() []tcpip.AddressWithPrefi
return addrs
}
-// JoinGroup implements GroupAddressableEndpoint.
-func (a *AddressableEndpointState) JoinGroup(group tcpip.Address) (bool, *tcpip.Error) {
- a.mu.Lock()
- defer a.mu.Unlock()
-
- joins, ok := a.mu.groups[group]
- if !ok {
- ep, err := a.addAndAcquireAddressLocked(group.WithPrefix(), NeverPrimaryEndpoint, AddressConfigStatic, false /* deprecated */, true /* permanent */)
- if err != nil {
- return false, err
- }
- // We have no need for the address endpoint.
- a.decAddressRefLocked(ep)
- }
-
- a.mu.groups[group] = joins + 1
- return !ok, nil
-}
-
-// LeaveGroup implements GroupAddressableEndpoint.
-func (a *AddressableEndpointState) LeaveGroup(group tcpip.Address) (bool, *tcpip.Error) {
- a.mu.Lock()
- defer a.mu.Unlock()
-
- joins, ok := a.mu.groups[group]
- if !ok {
- return false, tcpip.ErrBadLocalAddress
- }
-
- if joins == 1 {
- a.removeGroupAddressLocked(group)
- delete(a.mu.groups, group)
- return true, nil
- }
-
- a.mu.groups[group] = joins - 1
- return false, nil
-}
-
-// IsInGroup implements GroupAddressableEndpoint.
-func (a *AddressableEndpointState) IsInGroup(group tcpip.Address) bool {
- a.mu.RLock()
- defer a.mu.RUnlock()
- _, ok := a.mu.groups[group]
- return ok
-}
-
-func (a *AddressableEndpointState) removeGroupAddressLocked(group tcpip.Address) {
- if err := a.removePermanentAddressLocked(group); err != nil {
- // removePermanentEndpointLocked would only return an error if group is
- // not bound to the addressable endpoint, but we know it MUST be assigned
- // since we have group in our map of groups.
- panic(fmt.Sprintf("error removing group address = %s: %s", group, err))
- }
-}
-
// Cleanup forcefully leaves all groups and removes all permanent addresses.
func (a *AddressableEndpointState) Cleanup() {
a.mu.Lock()
defer a.mu.Unlock()
- for group := range a.mu.groups {
- a.removeGroupAddressLocked(group)
- }
- a.mu.groups = make(map[tcpip.Address]uint32)
-
for _, ep := range a.mu.endpoints {
// removePermanentEndpointLocked returns tcpip.ErrBadLocalAddress if ep is
// not a permanent address.
diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go
index 26787d0a3..140f146f6 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state_test.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go
@@ -53,25 +53,9 @@ func TestAddressableEndpointStateCleanup(t *testing.T) {
ep.DecRef()
}
- group := tcpip.Address("\x02")
- if added, err := s.JoinGroup(group); err != nil {
- t.Fatalf("s.JoinGroup(%s): %s", group, err)
- } else if !added {
- t.Fatalf("got s.JoinGroup(%s) = false, want = true", group)
- }
- if !s.IsInGroup(group) {
- t.Fatalf("got s.IsInGroup(%s) = false, want = true", group)
- }
-
s.Cleanup()
- {
- ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint)
- if ep != nil {
- ep.DecRef()
- t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix())
- }
- }
- if s.IsInGroup(group) {
- t.Fatalf("got s.IsInGroup(%s) = true, want = false", group)
+ if ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint); ep != nil {
+ ep.DecRef()
+ t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix())
}
}
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 9a17efcba..5e649cca6 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -142,19 +142,19 @@ func (cn *conn) timedOut(now time.Time) bool {
// update the connection tracking state.
//
-// Precondition: ct.mu must be held.
-func (ct *conn) updateLocked(tcpHeader header.TCP, hook Hook) {
+// Precondition: cn.mu must be held.
+func (cn *conn) updateLocked(tcpHeader header.TCP, hook Hook) {
// Update the state of tcb. tcb assumes it's always initialized on the
// client. However, we only need to know whether the connection is
// established or not, so the client/server distinction isn't important.
// TODO(gvisor.dev/issue/170): Add support in tcpconntrack to handle
// other tcp states.
- if ct.tcb.IsEmpty() {
- ct.tcb.Init(tcpHeader)
- } else if hook == ct.tcbHook {
- ct.tcb.UpdateStateOutbound(tcpHeader)
+ if cn.tcb.IsEmpty() {
+ cn.tcb.Init(tcpHeader)
+ } else if hook == cn.tcbHook {
+ cn.tcb.UpdateStateOutbound(tcpHeader)
} else {
- ct.tcb.UpdateStateInbound(tcpHeader)
+ cn.tcb.UpdateStateInbound(tcpHeader)
}
}
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go
index 7a501acdc..93e8e1c51 100644
--- a/pkg/tcpip/stack/forwarding_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -74,8 +74,30 @@ func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 {
}
func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) {
- // Dispatch the packet to the transport protocol.
- f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt)
+ netHdr := pkt.NetworkHeader().View()
+ _, dst := f.proto.ParseAddresses(netHdr)
+
+ addressEndpoint := f.AcquireAssignedAddress(dst, f.nic.Promiscuous(), CanBePrimaryEndpoint)
+ if addressEndpoint != nil {
+ addressEndpoint.DecRef()
+ // Dispatch the packet to the transport protocol.
+ f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(netHdr[protocolNumberOffset]), pkt)
+ return
+ }
+
+ r, err := f.proto.stack.FindRoute(0, "", dst, fwdTestNetNumber, false /* multicastLoop */)
+ if err != nil {
+ return
+ }
+ defer r.Release()
+
+ vv := buffer.NewVectorisedView(pkt.Size(), pkt.Views())
+ pkt = NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: vv.ToView().ToVectorisedView(),
+ })
+ // TODO(b/143425874) Decrease the TTL field in forwarded packets.
+ _ = r.WriteHeaderIncludedPacket(pkt)
}
func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 {
@@ -106,8 +128,13 @@ func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuf
panic("not implemented")
}
-func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error {
- return tcpip.ErrNotSupported
+func (f *fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error {
+ // The network header should not already be populated.
+ if _, ok := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen); !ok {
+ return tcpip.ErrMalformedHeader
+ }
+
+ return f.nic.WritePacket(r, nil /* gso */, fwdTestNetNumber, pkt)
}
func (f *fwdTestNetworkEndpoint) Close() {
@@ -117,6 +144,8 @@ func (f *fwdTestNetworkEndpoint) Close() {
// fwdTestNetworkProtocol is a network-layer protocol that implements Address
// resolution.
type fwdTestNetworkProtocol struct {
+ stack *Stack
+
addrCache *linkAddrCache
neigh *neighborCache
addrResolveDelay time.Duration
@@ -280,7 +309,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress {
func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
p := fwdTestPacketInfo{
- RemoteLinkAddress: r.RemoteLinkAddress,
+ RemoteLinkAddress: r.RemoteLinkAddress(),
LocalLinkAddress: r.LocalLinkAddress,
Pkt: pkt,
}
@@ -304,20 +333,6 @@ func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuffer
return n, nil
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- p := fwdTestPacketInfo{
- Pkt: NewPacketBuffer(PacketBufferOptions{Data: vv}),
- }
-
- select {
- case e.C <- p:
- default:
- }
-
- return nil
-}
-
// Wait implements stack.LinkEndpoint.Wait.
func (*fwdTestLinkEndpoint) Wait() {}
@@ -334,7 +349,10 @@ func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protoco
func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborCache bool) (ep1, ep2 *fwdTestLinkEndpoint) {
// Create a stack with the network protocol and two NICs.
s := New(Options{
- NetworkProtocols: []NetworkProtocolFactory{func(*Stack) NetworkProtocol { return proto }},
+ NetworkProtocols: []NetworkProtocolFactory{func(s *Stack) NetworkProtocol {
+ proto.stack = s
+ return proto
+ }},
UseNeighborCache: useNeighborCache,
})
@@ -542,6 +560,38 @@ func TestForwardingWithNoResolver(t *testing.T) {
}
}
+func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) {
+ proto := &fwdTestNetworkProtocol{
+ addrResolveDelay: 50 * time.Millisecond,
+ onLinkAddressResolved: func(*linkAddrCache, *neighborCache, tcpip.Address, tcpip.LinkAddress) {
+ // Don't resolve the link address.
+ },
+ }
+
+ ep1, ep2 := fwdTestNetFactory(t, proto, true /* useNeighborCache */)
+
+ const numPackets int = 5
+ // These packets will all be enqueued in the packet queue to wait for link
+ // address resolution.
+ for i := 0; i < numPackets; i++ {
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+ }
+
+ // All packets should fail resolution.
+ // TODO(gvisor.dev/issue/5141): Use a fake clock.
+ for i := 0; i < numPackets; i++ {
+ select {
+ case got := <-ep2.C:
+ t.Fatalf("got %#v; packets should have failed resolution and not been forwarded", got)
+ case <-time.After(100 * time.Millisecond):
+ }
+ }
+}
+
func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
tests := []struct {
name string
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 2d8c883cd..09c7811fa 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -45,13 +45,13 @@ const reaperDelay = 5 * time.Second
func DefaultTables() *IPTables {
return &IPTables{
v4Tables: [NumTables]Table{
- NATID: Table{
+ NATID: {
Rules: []Rule{
- Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
- Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
- Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
- Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
- Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ {Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ {Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ {Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ {Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ {Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
},
BuiltinChains: [NumHooks]int{
Prerouting: 0,
@@ -68,11 +68,11 @@ func DefaultTables() *IPTables {
Postrouting: 3,
},
},
- MangleID: Table{
+ MangleID: {
Rules: []Rule{
- Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
- Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
- Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ {Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ {Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ {Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
},
BuiltinChains: [NumHooks]int{
Prerouting: 0,
@@ -86,12 +86,12 @@ func DefaultTables() *IPTables {
Postrouting: HookUnset,
},
},
- FilterID: Table{
+ FilterID: {
Rules: []Rule{
- Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
- Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
- Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
- Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ {Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ {Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ {Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ {Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
},
BuiltinChains: [NumHooks]int{
Prerouting: HookUnset,
@@ -110,13 +110,13 @@ func DefaultTables() *IPTables {
},
},
v6Tables: [NumTables]Table{
- NATID: Table{
+ NATID: {
Rules: []Rule{
- Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
- Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
- Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
- Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
- Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ {Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ {Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ {Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ {Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ {Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
},
BuiltinChains: [NumHooks]int{
Prerouting: 0,
@@ -133,11 +133,11 @@ func DefaultTables() *IPTables {
Postrouting: 3,
},
},
- MangleID: Table{
+ MangleID: {
Rules: []Rule{
- Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
- Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
- Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ {Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ {Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ {Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
},
BuiltinChains: [NumHooks]int{
Prerouting: 0,
@@ -151,12 +151,12 @@ func DefaultTables() *IPTables {
Postrouting: HookUnset,
},
},
- FilterID: Table{
+ FilterID: {
Rules: []Rule{
- Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
- Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
- Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
- Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ {Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ {Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ {Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
+ {Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
},
BuiltinChains: [NumHooks]int{
Prerouting: HookUnset,
@@ -175,9 +175,9 @@ func DefaultTables() *IPTables {
},
},
priorities: [NumHooks][]TableID{
- Prerouting: []TableID{MangleID, NATID},
- Input: []TableID{NATID, FilterID},
- Output: []TableID{MangleID, NATID, FilterID},
+ Prerouting: {MangleID, NATID},
+ Input: {NATID, FilterID},
+ Output: {MangleID, NATID, FilterID},
},
connections: ConnTrack{
seed: generateRandUint32(),
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index 4b86c1be9..56a3e7861 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -56,7 +56,7 @@ const (
// Postrouting happens just before a packet goes out on the wire.
Postrouting
- // The total number of hooks.
+ // NumHooks is the total number of hooks.
NumHooks
)
@@ -273,14 +273,12 @@ func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, nicName string) boo
return true
}
- // If the interface name ends with '+', any interface which begins
- // with the name should be matched.
+ // If the interface name ends with '+', any interface which
+ // begins with the name should be matched.
ifName := fl.OutputInterface
- matches := true
+ matches := nicName == ifName
if strings.HasSuffix(ifName, "+") {
matches = strings.HasPrefix(nicName, ifName[:n-1])
- } else {
- matches = nicName == ifName
}
return fl.OutputInterfaceInvert != matches
}
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go
index c9b13cd0e..792f4f170 100644
--- a/pkg/tcpip/stack/linkaddrcache.go
+++ b/pkg/tcpip/stack/linkaddrcache.go
@@ -18,7 +18,6 @@ import (
"fmt"
"time"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -58,9 +57,6 @@ const (
incomplete entryState = iota
// ready means that the address has been resolved and can be used.
ready
- // failed means that address resolution timed out and the address
- // could not be resolved.
- failed
)
// String implements Stringer.
@@ -70,8 +66,6 @@ func (s entryState) String() string {
return "incomplete"
case ready:
return "ready"
- case failed:
- return "failed"
default:
return fmt.Sprintf("unknown(%d)", s)
}
@@ -80,40 +74,48 @@ func (s entryState) String() string {
// A linkAddrEntry is an entry in the linkAddrCache.
// This struct is thread-compatible.
type linkAddrEntry struct {
+ // linkAddrEntryEntry access is synchronized by the linkAddrCache lock.
linkAddrEntryEntry
+ // TODO(gvisor.dev/issue/5150): move these fields under mu.
+ // mu protects the fields below.
+ mu sync.RWMutex
+
addr tcpip.FullAddress
linkAddr tcpip.LinkAddress
expiration time.Time
s entryState
- // wakers is a set of waiters for address resolution result. Anytime
- // state transitions out of incomplete these waiters are notified.
- wakers map[*sleep.Waker]struct{}
-
- // done is used to allow callers to wait on address resolution. It is nil iff
- // s is incomplete and resolution is not yet in progress.
+ // done is closed when address resolution is complete. It is nil iff s is
+ // incomplete and resolution is not yet in progress.
done chan struct{}
+
+ // onResolve is called with the result of address resolution.
+ onResolve []func(tcpip.LinkAddress, bool)
}
-// changeState sets the entry's state to ns, notifying any waiters.
+func (e *linkAddrEntry) notifyCompletionLocked(linkAddr tcpip.LinkAddress) {
+ for _, callback := range e.onResolve {
+ callback(linkAddr, len(linkAddr) != 0)
+ }
+ e.onResolve = nil
+ if ch := e.done; ch != nil {
+ close(ch)
+ e.done = nil
+ }
+}
+
+// changeStateLocked sets the entry's state to ns.
//
// The entry's expiration is bumped up to the greater of itself and the passed
// expiration; the zero value indicates immediate expiration, and is set
// unconditionally - this is an implementation detail that allows for entries
// to be reused.
-func (e *linkAddrEntry) changeState(ns entryState, expiration time.Time) {
- // Notify whoever is waiting on address resolution when transitioning
- // out of incomplete.
- if e.s == incomplete && ns != incomplete {
- for w := range e.wakers {
- w.Assert()
- }
- e.wakers = nil
- if ch := e.done; ch != nil {
- close(ch)
- }
- e.done = nil
+//
+// Precondition: e.mu must be locked
+func (e *linkAddrEntry) changeStateLocked(ns entryState, expiration time.Time) {
+ if e.s == incomplete && ns == ready {
+ e.notifyCompletionLocked(e.linkAddr)
}
if expiration.IsZero() || expiration.After(e.expiration) {
@@ -122,10 +124,6 @@ func (e *linkAddrEntry) changeState(ns entryState, expiration time.Time) {
e.s = ns
}
-func (e *linkAddrEntry) removeWaker(w *sleep.Waker) {
- delete(e.wakers, w)
-}
-
// add adds a k -> v mapping to the cache.
func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) {
// Calculate expiration time before acquiring the lock, since expiration is
@@ -135,10 +133,12 @@ func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) {
c.cache.Lock()
entry := c.getOrCreateEntryLocked(k)
- entry.linkAddr = v
-
- entry.changeState(ready, expiration)
c.cache.Unlock()
+
+ entry.mu.Lock()
+ defer entry.mu.Unlock()
+ entry.linkAddr = v
+ entry.changeStateLocked(ready, expiration)
}
// getOrCreateEntryLocked retrieves a cache entry associated with k. The
@@ -159,13 +159,14 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEnt
var entry *linkAddrEntry
if len(c.cache.table) == linkAddrCacheSize {
entry = c.cache.lru.Back()
+ entry.mu.Lock()
delete(c.cache.table, entry.addr)
c.cache.lru.Remove(entry)
- // Wake waiters and mark the soon-to-be-reused entry as expired. Note
- // that the state passed doesn't matter when the zero time is passed.
- entry.changeState(failed, time.Time{})
+ // Wake waiters and mark the soon-to-be-reused entry as expired.
+ entry.notifyCompletionLocked("" /* linkAddr */)
+ entry.mu.Unlock()
} else {
entry = new(linkAddrEntry)
}
@@ -180,9 +181,12 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEnt
}
// get reports any known link address for k.
-func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
+func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
if linkRes != nil {
if addr, ok := linkRes.ResolveStaticAddress(k.Addr); ok {
+ if onResolve != nil {
+ onResolve(addr, true)
+ }
return addr, nil, nil
}
}
@@ -190,56 +194,35 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo
c.cache.Lock()
defer c.cache.Unlock()
entry := c.getOrCreateEntryLocked(k)
+ entry.mu.Lock()
+ defer entry.mu.Unlock()
+
switch s := entry.s; s {
- case ready, failed:
+ case ready:
if !time.Now().After(entry.expiration) {
// Not expired.
- switch s {
- case ready:
- return entry.linkAddr, nil, nil
- case failed:
- return entry.linkAddr, nil, tcpip.ErrNoLinkAddress
- default:
- panic(fmt.Sprintf("invalid cache entry state: %s", s))
+ if onResolve != nil {
+ onResolve(entry.linkAddr, true)
}
+ return entry.linkAddr, nil, nil
}
- entry.changeState(incomplete, time.Time{})
+ entry.changeStateLocked(incomplete, time.Time{})
fallthrough
case incomplete:
- if waker != nil {
- if entry.wakers == nil {
- entry.wakers = make(map[*sleep.Waker]struct{})
- }
- entry.wakers[waker] = struct{}{}
+ if onResolve != nil {
+ entry.onResolve = append(entry.onResolve, onResolve)
}
-
if entry.done == nil {
- // Address resolution needs to be initiated.
- if linkRes == nil {
- return entry.linkAddr, nil, tcpip.ErrNoLinkAddress
- }
-
entry.done = make(chan struct{})
go c.startAddressResolution(k, linkRes, localAddr, nic, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously.
}
-
return entry.linkAddr, entry.done, tcpip.ErrWouldBlock
default:
panic(fmt.Sprintf("invalid cache entry state: %s", s))
}
}
-// removeWaker removes a waker previously added through get().
-func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) {
- c.cache.Lock()
- defer c.cache.Unlock()
-
- if entry, ok := c.cache.table[k]; ok {
- entry.removeWaker(waker)
- }
-}
-
func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) {
for i := 0; ; i++ {
// Send link request, then wait for the timeout limit and check
@@ -257,9 +240,9 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link
}
}
-// checkLinkRequest checks whether previous attempt to resolve address has succeeded
-// and mark the entry accordingly, e.g. ready, failed, etc. Return true if request
-// can stop, false if another request should be sent.
+// checkLinkRequest checks whether previous attempt to resolve address has
+// succeeded and mark the entry accordingly. Returns true if request can stop,
+// false if another request should be sent.
func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, attempt int) bool {
c.cache.Lock()
defer c.cache.Unlock()
@@ -268,16 +251,20 @@ func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, att
// Entry was evicted from the cache.
return true
}
+ entry.mu.Lock()
+ defer entry.mu.Unlock()
+
switch s := entry.s; s {
- case ready, failed:
- // Entry was made ready by resolver or failed. Either way we're done.
+ case ready:
+ // Entry was made ready by resolver.
case incomplete:
if attempt+1 < c.resolutionAttempts {
// No response yet, need to send another ARP request.
return false
}
- // Max number of retries reached, mark entry as failed.
- entry.changeState(failed, now.Add(c.ageLimit))
+ // Max number of retries reached, delete entry.
+ entry.notifyCompletionLocked("" /* linkAddr */)
+ delete(c.cache.table, k)
default:
panic(fmt.Sprintf("invalid cache entry state: %s", s))
}
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
index d2e37f38d..6883045b5 100644
--- a/pkg/tcpip/stack/linkaddrcache_test.go
+++ b/pkg/tcpip/stack/linkaddrcache_test.go
@@ -21,7 +21,6 @@ import (
"testing"
"time"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -50,6 +49,7 @@ type testLinkAddressResolver struct {
}
func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error {
+ // TODO(gvisor.dev/issue/5141): Use a fake clock.
time.AfterFunc(r.delay, func() { r.fakeRequest(targetAddr) })
if f := r.onLinkAddressRequest; f != nil {
f()
@@ -78,16 +78,18 @@ func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumbe
}
func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressResolver) (tcpip.LinkAddress, *tcpip.Error) {
- w := sleep.Waker{}
- s := sleep.Sleeper{}
- s.AddWaker(&w, 123)
- defer s.Done()
-
+ var attemptedResolution bool
for {
- if got, _, err := c.get(addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock {
- return got, err
+ got, ch, err := c.get(addr, linkRes, "", nil, nil)
+ if err == tcpip.ErrWouldBlock {
+ if attemptedResolution {
+ return got, tcpip.ErrNoLinkAddress
+ }
+ attemptedResolution = true
+ <-ch
+ continue
}
- s.Fetch(true)
+ return got, err
}
}
@@ -116,16 +118,19 @@ func TestCacheOverflow(t *testing.T) {
}
}
// The earliest entries should no longer be in the cache.
+ c.cache.Lock()
+ defer c.cache.Unlock()
for i := len(testAddrs) - 1; i >= len(testAddrs)-linkAddrCacheSize; i-- {
e := testAddrs[i]
- if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
- t.Errorf("check %d, c.get(%q), got error: %v, want: error ErrNoLinkAddress", i, string(e.addr.Addr), err)
+ if entry, ok := c.cache.table[e.addr]; ok {
+ t.Errorf("unexpected entry at c.cache.table[%q]: %#v", string(e.addr.Addr), entry)
}
}
}
func TestCacheConcurrent(t *testing.T) {
c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
+ linkRes := &testLinkAddressResolver{cache: c}
var wg sync.WaitGroup
for r := 0; r < 16; r++ {
@@ -133,7 +138,6 @@ func TestCacheConcurrent(t *testing.T) {
go func() {
for _, e := range testAddrs {
c.add(e.addr, e.linkAddr)
- c.get(e.addr, nil, "", nil, nil) // make work for gotsan
}
wg.Done()
}()
@@ -144,7 +148,7 @@ func TestCacheConcurrent(t *testing.T) {
// can fit in the cache, so our eviction strategy requires that
// the last entry be present and the first be missing.
e := testAddrs[len(testAddrs)-1]
- got, _, err := c.get(e.addr, nil, "", nil, nil)
+ got, _, err := c.get(e.addr, linkRes, "", nil, nil)
if err != nil {
t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
}
@@ -153,18 +157,22 @@ func TestCacheConcurrent(t *testing.T) {
}
e = testAddrs[0]
- if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
- t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
+ c.cache.Lock()
+ defer c.cache.Unlock()
+ if entry, ok := c.cache.table[e.addr]; ok {
+ t.Errorf("unexpected entry at c.cache.table[%q]: %#v", string(e.addr.Addr), entry)
}
}
func TestCacheAgeLimit(t *testing.T) {
c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3)
+ linkRes := &testLinkAddressResolver{cache: c}
+
e := testAddrs[0]
c.add(e.addr, e.linkAddr)
time.Sleep(50 * time.Millisecond)
- if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
- t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
+ if _, _, err := c.get(e.addr, linkRes, "", nil, nil); err != tcpip.ErrWouldBlock {
+ t.Errorf("got c.get(%q) = %s, want = ErrWouldBlock", string(e.addr.Addr), err)
}
}
@@ -282,71 +290,3 @@ func TestStaticResolution(t *testing.T) {
t.Errorf("c.get(%q)=%q, want %q", string(addr), string(got), string(want))
}
}
-
-// TestCacheWaker verifies that RemoveWaker removes a waker previously added
-// through get().
-func TestCacheWaker(t *testing.T) {
- c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
-
- // First, sanity check that wakers are working.
- {
- linkRes := &testLinkAddressResolver{cache: c}
- s := sleep.Sleeper{}
- defer s.Done()
-
- const wakerID = 1
- w := sleep.Waker{}
- s.AddWaker(&w, wakerID)
-
- e := testAddrs[0]
-
- if _, _, err := c.get(e.addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.get(%q, _, _, _, _) = %s, want = %s", e.addr.Addr, err, tcpip.ErrWouldBlock)
- }
- id, ok := s.Fetch(true /* block */)
- if !ok {
- t.Fatal("got s.Fetch(true) = (_, false), want = (_, true)")
- }
- if id != wakerID {
- t.Fatalf("got s.Fetch(true) = (%d, %t), want = (%d, true)", id, ok, wakerID)
- }
-
- if got, _, err := c.get(e.addr, linkRes, "", nil, nil); err != nil {
- t.Fatalf("c.get(%q, _, _, _, _): %s", e.addr.Addr, err)
- } else if got != e.linkAddr {
- t.Fatalf("got c.get(%q) = %q, want = %q", e.addr.Addr, got, e.linkAddr)
- }
- }
-
- // Check that RemoveWaker works.
- {
- linkRes := &testLinkAddressResolver{cache: c}
- s := sleep.Sleeper{}
- defer s.Done()
-
- const wakerID = 2 // different than the ID used in the sanity check
- w := sleep.Waker{}
- s.AddWaker(&w, wakerID)
-
- e := testAddrs[1]
- linkRes.onLinkAddressRequest = func() {
- // Remove the waker before the linkAddrCache has the opportunity to send
- // a notification.
- c.removeWaker(e.addr, &w)
- }
-
- if _, _, err := c.get(e.addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.get(%q, _, _, _, _) = %s, want = %s", e.addr.Addr, err, tcpip.ErrWouldBlock)
- }
-
- if got, err := getBlocking(c, e.addr, linkRes); err != nil {
- t.Fatalf("c.get(%q, _, _, _, _): %s", e.addr.Addr, err)
- } else if got != e.linkAddr {
- t.Fatalf("c.get(%q) = %q, want = %q", e.addr.Addr, got, e.linkAddr)
- }
-
- if id, ok := s.Fetch(false /* block */); ok {
- t.Fatalf("unexpected notification from waker with id %d", id)
- }
- }
-}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 73a01c2dd..61636cae5 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -26,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
@@ -352,7 +353,7 @@ func TestDADDisabled(t *testing.T) {
}
// We should not have sent any NDP NS messages.
- if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != 0 {
+ if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got != 0 {
t.Fatalf("got NeighborSolicit = %d, want = 0", got)
}
}
@@ -465,14 +466,18 @@ func TestDADResolve(t *testing.T) {
if err != tcpip.ErrNoRoute {
t.Errorf("got FindRoute(%d, '', %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute)
}
- r.Release()
+ if r != nil {
+ r.Release()
+ }
}
{
r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false)
if err != tcpip.ErrNoRoute {
t.Errorf("got FindRoute(%d, %s, %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr1, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute)
}
- r.Release()
+ if r != nil {
+ r.Release()
+ }
}
if t.Failed() {
@@ -510,7 +515,9 @@ func TestDADResolve(t *testing.T) {
} else if r.LocalAddress != addr1 {
t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, addr1)
}
- r.Release()
+ if r != nil {
+ r.Release()
+ }
}
if t.Failed() {
@@ -518,7 +525,7 @@ func TestDADResolve(t *testing.T) {
}
// Should not have sent any more NS messages.
- if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != uint64(test.dupAddrDetectTransmits) {
+ if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got != uint64(test.dupAddrDetectTransmits) {
t.Fatalf("got NeighborSolicit = %d, want = %d", got, test.dupAddrDetectTransmits)
}
@@ -563,18 +570,18 @@ func rxNDPSolicit(e *channel.Endpoint, tgt tcpip.Address) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize))
pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns := header.NDPNeighborSolicit(pkt.MessageBody())
ns.SetTargetAddress(tgt)
snmc := header.SolicitedNodeAddr(tgt)
pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{}))
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: 255,
- SrcAddr: header.IPv6Any,
- DstAddr: snmc,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: 255,
+ SrcAddr: header.IPv6Any,
+ DstAddr: snmc,
})
e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()}))
}
@@ -605,7 +612,7 @@ func TestDADFail(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize)
pkt := header.ICMPv6(hdr.Prepend(naSize))
pkt.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(pkt.NDPPayload())
+ na := header.NDPNeighborAdvert(pkt.MessageBody())
na.SetSolicitedFlag(true)
na.SetOverrideFlag(true)
na.SetTargetAddress(tgt)
@@ -616,11 +623,11 @@ func TestDADFail(t *testing.T) {
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: 255,
- SrcAddr: tgt,
- DstAddr: header.IPv6AllNodesMulticastAddress,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: 255,
+ SrcAddr: tgt,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
})
e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()}))
},
@@ -666,7 +673,7 @@ func TestDADFail(t *testing.T) {
// Receive a packet to simulate an address conflict.
test.rxPkt(e, addr1)
- stat := test.getStat(s.Stats().ICMP.V6PacketsReceived)
+ stat := test.getStat(s.Stats().ICMP.V6.PacketsReceived)
if got := stat.Value(); got != 1 {
t.Fatalf("got stat = %d, want = 1", got)
}
@@ -803,7 +810,7 @@ func TestDADStop(t *testing.T) {
}
// Should not have sent more than 1 NS message.
- if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 {
+ if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got > 1 {
t.Errorf("got NeighborSolicit = %d, want <= 1", got)
}
})
@@ -982,7 +989,7 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo
pkt := header.ICMPv6(hdr.Prepend(icmpSize))
pkt.SetType(header.ICMPv6RouterAdvert)
pkt.SetCode(0)
- raPayload := pkt.NDPPayload()
+ raPayload := pkt.MessageBody()
ra := header.NDPRouterAdvert(raPayload)
// Populate the Router Lifetime.
binary.BigEndian.PutUint16(raPayload[2:], rl)
@@ -1004,11 +1011,11 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo
payloadLength := hdr.UsedLength()
iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
iph.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: header.NDPHopLimit,
- SrcAddr: ip,
- DstAddr: header.IPv6AllNodesMulticastAddress,
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: ip,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
})
return stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -2162,8 +2169,8 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) {
NDPConfigs: ipv6.NDPConfigurations{
AutoGenTempGlobalAddresses: true,
},
- NDPDisp: &ndpDisp,
- AutoGenIPv6LinkLocal: true,
+ NDPDisp: &ndpDisp,
+ AutoGenLinkLocal: true,
})},
})
@@ -2843,9 +2850,7 @@ func addrForNewConnectionTo(t *testing.T, s *stack.Stack, addr tcpip.FullAddress
t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
}
defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
- }
+ ep.SocketOptions().SetV6Only(true)
if err := ep.Connect(addr); err != nil {
t.Fatalf("ep.Connect(%+v): %s", addr, err)
}
@@ -2879,9 +2884,7 @@ func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullA
t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
}
defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
- }
+ ep.SocketOptions().SetV6Only(true)
if err := ep.Bind(addr); err != nil {
t.Fatalf("ep.Bind(%+v): %s", addr, err)
}
@@ -3250,9 +3253,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
}
defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
- }
+ ep.SocketOptions().SetV6Only(true)
if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute {
t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, tcpip.ErrNoRoute)
@@ -4044,9 +4045,9 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
ndpConfigs.AutoGenAddressConflictRetries = maxRetries
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal,
- NDPConfigs: ndpConfigs,
- NDPDisp: &ndpDisp,
+ AutoGenLinkLocal: addrType.autoGenLinkLocal,
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{
NICNameFromID: func(_ tcpip.NICID, nicName string) string {
return nicName
@@ -4179,9 +4180,9 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal,
- NDPConfigs: addrType.ndpConfigs,
- NDPDisp: &ndpDisp,
+ AutoGenLinkLocal: addrType.autoGenLinkLocal,
+ NDPConfigs: addrType.ndpConfigs,
+ NDPDisp: &ndpDisp,
})},
})
if err := s.CreateNIC(nicID, e); err != nil {
@@ -4708,7 +4709,7 @@ func TestCleanupNDPState(t *testing.T) {
}
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- AutoGenIPv6LinkLocal: true,
+ AutoGenLinkLocal: true,
NDPConfigs: ipv6.NDPConfigurations{
HandleRAs: true,
DiscoverDefaultRouters: true,
@@ -5174,113 +5175,99 @@ func TestRouterSolicitation(t *testing.T) {
},
}
- // This Run will not return until the parallel tests finish.
- //
- // We need this because we need to do some teardown work after the
- // parallel tests complete.
- //
- // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
- // more details.
- t.Run("group", func(t *testing.T) {
- for _, test := range tests {
- test := test
-
- t.Run(test.name, func(t *testing.T) {
- t.Parallel()
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
+ e := channelLinkWithHeaderLength{
+ Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr),
+ headerLength: test.linkHeaderLen,
+ }
+ e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ waitForPkt := func(timeout time.Duration) {
+ t.Helper()
- e := channelLinkWithHeaderLength{
- Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr),
- headerLength: test.linkHeaderLen,
+ clock.Advance(timeout)
+ p, ok := e.Read()
+ if !ok {
+ t.Fatal("expected router solicitation packet")
}
- e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- waitForPkt := func(timeout time.Duration) {
- t.Helper()
- ctx, cancel := context.WithTimeout(context.Background(), timeout)
- defer cancel()
- p, ok := e.ReadContext(ctx)
- if !ok {
- t.Fatal("timed out waiting for packet")
- return
- }
- if p.Proto != header.IPv6ProtocolNumber {
- t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
- }
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
+ }
- // Make sure the right remote link address is used.
- if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want {
- t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
- }
+ // Make sure the right remote link address is used.
+ if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want {
+ t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
+ }
- checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
- checker.SrcAddr(test.expectedSrcAddr),
- checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
- checker.TTL(header.NDPHopLimit),
- checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)),
- )
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(test.expectedSrcAddr),
+ checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)),
+ )
- if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want {
- t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want)
- }
- }
- waitForNothing := func(timeout time.Duration) {
- t.Helper()
- ctx, cancel := context.WithTimeout(context.Background(), timeout)
- defer cancel()
- if _, ok := e.ReadContext(ctx); ok {
- t.Fatal("unexpectedly got a packet")
- }
- }
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- MaxRtrSolicitations: test.maxRtrSolicit,
- RtrSolicitationInterval: test.rtrSolicitInt,
- MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
- },
- })},
- })
- if err := s.CreateNIC(nicID, &e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want {
+ t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want)
}
+ }
+ waitForNothing := func(timeout time.Duration) {
+ t.Helper()
- if addr := test.nicAddr; addr != "" {
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err)
- }
+ clock.Advance(timeout)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("unexpectedly got a packet = %#v", p)
}
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ MaxRtrSolicitations: test.maxRtrSolicit,
+ RtrSolicitationInterval: test.rtrSolicitInt,
+ MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
+ },
+ })},
+ Clock: clock,
+ })
+ if err := s.CreateNIC(nicID, &e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
- // Make sure each RS is sent at the right time.
- remaining := test.maxRtrSolicit
- if remaining > 0 {
- waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncPositiveEventTimeout)
- remaining--
+ if addr := test.nicAddr; addr != "" {
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err)
}
+ }
- for ; remaining > 0; remaining-- {
- if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout {
- waitForNothing(test.effectiveRtrSolicitInt - defaultAsyncNegativeEventTimeout)
- waitForPkt(defaultAsyncPositiveEventTimeout)
- } else {
- waitForPkt(test.effectiveRtrSolicitInt + defaultAsyncPositiveEventTimeout)
- }
- }
+ // Make sure each RS is sent at the right time.
+ remaining := test.maxRtrSolicit
+ if remaining > 0 {
+ waitForPkt(test.effectiveMaxRtrSolicitDelay)
+ remaining--
+ }
- // Make sure no more RS.
- if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay {
- waitForNothing(test.effectiveRtrSolicitInt + defaultAsyncNegativeEventTimeout)
+ for ; remaining > 0; remaining-- {
+ if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout {
+ waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond)
+ waitForPkt(time.Nanosecond)
} else {
- waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultAsyncNegativeEventTimeout)
+ waitForPkt(test.effectiveRtrSolicitInt)
}
+ }
- // Make sure the counter got properly
- // incremented.
- if got, want := s.Stats().ICMP.V6PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want {
- t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want)
- }
- })
- }
- })
+ // Make sure no more RS.
+ if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay {
+ waitForNothing(test.effectiveRtrSolicitInt)
+ } else {
+ waitForNothing(test.effectiveMaxRtrSolicitDelay)
+ }
+
+ if got, want := s.Stats().ICMP.V6.PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want {
+ t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want)
+ }
+ })
+ }
}
func TestStopStartSolicitingRouters(t *testing.T) {
diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go
index 177bf5516..c15f10e76 100644
--- a/pkg/tcpip/stack/neighbor_cache.go
+++ b/pkg/tcpip/stack/neighbor_cache.go
@@ -17,16 +17,22 @@ package stack
import (
"fmt"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
)
const neighborCacheSize = 512 // max entries per interface
+// NeighborStats holds metrics for the neighbor table.
+type NeighborStats struct {
+ // FailedEntryLookups counts the number of lookups performed on an entry in
+ // Failed state.
+ FailedEntryLookups *tcpip.StatCounter
+}
+
// neighborCache maps IP addresses to link addresses. It uses the Least
// Recently Used (LRU) eviction strategy to implement a bounded cache for
-// dynmically acquired entries. It contains the state machine and configuration
+// dynamically acquired entries. It contains the state machine and configuration
// for running Neighbor Unreachability Detection (NUD).
//
// There are two types of entries in the neighbor cache:
@@ -92,9 +98,7 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA
n.dynamic.lru.Remove(e)
n.dynamic.count--
- e.dispatchRemoveEventLocked()
- e.setStateLocked(Unknown)
- e.notifyWakersLocked()
+ e.removeLocked()
e.mu.Unlock()
}
n.cache[remoteAddr] = entry
@@ -103,21 +107,27 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA
return entry
}
-// entry looks up the neighbor cache for translating address to link address
-// (e.g. IP -> MAC). If the LinkEndpoint requests address resolution and there
-// is a LinkAddressResolver registered with the network protocol, the cache
-// attempts to resolve the address and returns ErrWouldBlock. If a Waker is
-// provided, it will be notified when address resolution is complete (success
-// or not).
+// entry looks up neighbor information matching the remote address, and returns
+// it if readily available.
+//
+// Returns ErrWouldBlock if the link address is not readily available, along
+// with a notification channel for the caller to block on. Triggers address
+// resolution asynchronously.
+//
+// If onResolve is provided, it will be called either immediately, if resolution
+// is not required, or when address resolution is complete, with the resolved
+// link address and whether resolution succeeded. After any callbacks have been
+// called, the returned notification channel is closed.
+//
+// NB: if a callback is provided, it should not call into the neighbor cache.
//
// If specified, the local address must be an address local to the interface the
// neighbor cache belongs to. The local address is the source address of a
// packet prompting NUD/link address resolution.
//
-// If address resolution is required, ErrNoLinkAddress and a notification
-// channel is returned for the top level caller to block. Channel is closed
-// once address resolution is complete (success or not).
-func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, w *sleep.Waker) (NeighborEntry, <-chan struct{}, *tcpip.Error) {
+// TODO(gvisor.dev/issue/5151): Don't return the neighbor entry.
+func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(tcpip.LinkAddress, bool)) (NeighborEntry, <-chan struct{}, *tcpip.Error) {
+ // TODO(gvisor.dev/issue/5149): Handle static resolution in route.Resolve.
if linkAddr, ok := linkRes.ResolveStaticAddress(remoteAddr); ok {
e := NeighborEntry{
Addr: remoteAddr,
@@ -125,6 +135,9 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA
State: Static,
UpdatedAtNanos: 0,
}
+ if onResolve != nil {
+ onResolve(linkAddr, true)
+ }
return e, nil, nil
}
@@ -142,47 +155,36 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA
// of packets to a neighbor. While reasserting a neighbor's reachability,
// a node continues sending packets to that neighbor using the cached
// link-layer address."
+ if onResolve != nil {
+ onResolve(entry.neigh.LinkAddr, true)
+ }
return entry.neigh, nil, nil
- case Unknown, Incomplete:
- entry.addWakerLocked(w)
-
+ case Unknown, Incomplete, Failed:
+ if onResolve != nil {
+ entry.onResolve = append(entry.onResolve, onResolve)
+ }
if entry.done == nil {
// Address resolution needs to be initiated.
- if linkRes == nil {
- return entry.neigh, nil, tcpip.ErrNoLinkAddress
- }
entry.done = make(chan struct{})
}
-
entry.handlePacketQueuedLocked(localAddr)
return entry.neigh, entry.done, tcpip.ErrWouldBlock
- case Failed:
- return entry.neigh, nil, tcpip.ErrNoLinkAddress
default:
panic(fmt.Sprintf("Invalid cache entry state: %s", s))
}
}
-// removeWaker removes a waker that has been added when link resolution for
-// addr was requested.
-func (n *neighborCache) removeWaker(addr tcpip.Address, waker *sleep.Waker) {
- n.mu.Lock()
- if entry, ok := n.cache[addr]; ok {
- delete(entry.wakers, waker)
- }
- n.mu.Unlock()
-}
-
// entries returns all entries in the neighbor cache.
func (n *neighborCache) entries() []NeighborEntry {
- entries := make([]NeighborEntry, 0, len(n.cache))
n.mu.RLock()
+ defer n.mu.RUnlock()
+
+ entries := make([]NeighborEntry, 0, len(n.cache))
for _, entry := range n.cache {
entry.mu.RLock()
entries = append(entries, entry.neigh)
entry.mu.RUnlock()
}
- n.mu.RUnlock()
return entries
}
@@ -214,32 +216,13 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd
return
}
- // Notify that resolution has been interrupted, just in case the entry was
- // in the Incomplete or Probe state.
- entry.dispatchRemoveEventLocked()
- entry.setStateLocked(Unknown)
- entry.notifyWakersLocked()
+ entry.removeLocked()
entry.mu.Unlock()
}
n.cache[addr] = newStaticNeighborEntry(n.nic, addr, linkAddr, n.state)
}
-// removeEntryLocked removes the specified entry from the neighbor cache.
-func (n *neighborCache) removeEntryLocked(entry *neighborEntry) {
- if entry.neigh.State != Static {
- n.dynamic.lru.Remove(entry)
- n.dynamic.count--
- }
- if entry.neigh.State != Failed {
- entry.dispatchRemoveEventLocked()
- }
- entry.setStateLocked(Unknown)
- entry.notifyWakersLocked()
-
- delete(n.cache, entry.neigh.Addr)
-}
-
// removeEntry removes a dynamic or static entry by address from the neighbor
// cache. Returns true if the entry was found and deleted.
func (n *neighborCache) removeEntry(addr tcpip.Address) bool {
@@ -254,7 +237,13 @@ func (n *neighborCache) removeEntry(addr tcpip.Address) bool {
entry.mu.Lock()
defer entry.mu.Unlock()
- n.removeEntryLocked(entry)
+ if entry.neigh.State != Static {
+ n.dynamic.lru.Remove(entry)
+ n.dynamic.count--
+ }
+
+ entry.removeLocked()
+ delete(n.cache, entry.neigh.Addr)
return true
}
@@ -265,9 +254,7 @@ func (n *neighborCache) clear() {
for _, entry := range n.cache {
entry.mu.Lock()
- entry.dispatchRemoveEventLocked()
- entry.setStateLocked(Unknown)
- entry.notifyWakersLocked()
+ entry.removeLocked()
entry.mu.Unlock()
}
diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go
index ed33418f3..a2ed6ae2a 100644
--- a/pkg/tcpip/stack/neighbor_cache_test.go
+++ b/pkg/tcpip/stack/neighbor_cache_test.go
@@ -28,7 +28,6 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/faketime"
)
@@ -80,17 +79,20 @@ func entryDiffOptsWithSort() []cmp.Option {
func newTestNeighborCache(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *neighborCache {
config.resetInvalidFields()
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
- return &neighborCache{
+ neigh := &neighborCache{
nic: &NIC{
stack: &Stack{
clock: clock,
nudDisp: nudDisp,
},
- id: 1,
+ id: 1,
+ stats: makeNICStats(),
},
state: NewNUDState(config, rng),
cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize),
}
+ neigh.nic.neigh = neigh
+ return neigh
}
// testEntryStore contains a set of IP to NeighborEntry mappings.
@@ -187,15 +189,18 @@ type testNeighborResolver struct {
entries *testEntryStore
delay time.Duration
onLinkAddressRequest func()
+ dropReplies bool
}
var _ LinkAddressResolver = (*testNeighborResolver)(nil)
func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error {
- // Delay handling the request to emulate network latency.
- r.clock.AfterFunc(r.delay, func() {
- r.fakeRequest(targetAddr)
- })
+ if !r.dropReplies {
+ // Delay handling the request to emulate network latency.
+ r.clock.AfterFunc(r.delay, func() {
+ r.fakeRequest(targetAddr)
+ })
+ }
// Execute post address resolution action, if available.
if f := r.onLinkAddressRequest; f != nil {
@@ -288,10 +293,10 @@ func TestNeighborCacheEntry(t *testing.T) {
entry, ok := store.entry(0)
if !ok {
- t.Fatalf("store.entry(0) not found")
+ t.Fatal("store.entry(0) not found")
}
if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
@@ -324,7 +329,7 @@ func TestNeighborCacheEntry(t *testing.T) {
}
if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil {
- t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
+ t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err)
}
// No more events should have been dispatched.
@@ -351,11 +356,11 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
entry, ok := store.entry(0)
if !ok {
- t.Fatalf("store.entry(0) not found")
+ t.Fatal("store.entry(0) not found")
}
if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
@@ -410,7 +415,7 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
}
if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
}
@@ -458,7 +463,7 @@ func (c *testContext) overflowCache(opts overflowOptions) error {
return fmt.Errorf("c.store.entry(%d) not found", i)
}
if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
- return fmt.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ return fmt.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
c.clock.Advance(c.neigh.config().RetransmitTimer)
@@ -510,7 +515,7 @@ func (c *testContext) overflowCache(opts overflowOptions) error {
}
// Expect to find only the most recent entries. The order of entries reported
- // by entries() is undeterministic, so entries have to be sorted before
+ // by entries() is nondeterministic, so entries have to be sorted before
// comparison.
wantUnsortedEntries := opts.wantStaticEntries
for i := c.store.size() - neighborCacheSize; i < c.store.size(); i++ {
@@ -572,10 +577,10 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) {
// Add a dynamic entry
entry, ok := c.store.entry(0)
if !ok {
- t.Fatalf("c.store.entry(0) not found")
+ t.Fatal("c.store.entry(0) not found")
}
if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
c.clock.Advance(c.neigh.config().RetransmitTimer)
wantEvents := []testEntryEventInfo{
@@ -647,7 +652,7 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) {
// Add a static entry
entry, ok := c.store.entry(0)
if !ok {
- t.Fatalf("c.store.entry(0) not found")
+ t.Fatal("c.store.entry(0) not found")
}
staticLinkAddr := entry.LinkAddr + "static"
c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
@@ -691,7 +696,7 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T)
// Add a static entry
entry, ok := c.store.entry(0)
if !ok {
- t.Fatalf("c.store.entry(0) not found")
+ t.Fatal("c.store.entry(0) not found")
}
staticLinkAddr := entry.LinkAddr + "static"
c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
@@ -753,7 +758,7 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) {
// Add a static entry
entry, ok := c.store.entry(0)
if !ok {
- t.Fatalf("c.store.entry(0) not found")
+ t.Fatal("c.store.entry(0) not found")
}
staticLinkAddr := entry.LinkAddr + "static"
c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
@@ -823,10 +828,10 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
// Add a dynamic entry
entry, ok := c.store.entry(0)
if !ok {
- t.Fatalf("c.store.entry(0) not found")
+ t.Fatal("c.store.entry(0) not found")
}
if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
c.clock.Advance(typicalLatency)
wantEvents := []testEntryEventInfo{
@@ -904,150 +909,6 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
}
}
-func TestNeighborCacheNotifiesWaker(t *testing.T) {
- config := DefaultNUDConfigurations()
-
- nudDisp := testNUDDispatcher{}
- clock := faketime.NewManualClock()
- neigh := newTestNeighborCache(&nudDisp, config, clock)
- store := newTestEntryStore()
- linkRes := &testNeighborResolver{
- clock: clock,
- neigh: neigh,
- entries: store,
- delay: typicalLatency,
- }
-
- w := sleep.Waker{}
- s := sleep.Sleeper{}
- const wakerID = 1
- s.AddWaker(&w, wakerID)
-
- entry, ok := store.entry(0)
- if !ok {
- t.Fatalf("store.entry(0) not found")
- }
- _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, &w)
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, '', _, _ = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
- }
- if doneCh == nil {
- t.Fatalf("expected done channel from neigh.entry(%s, '', _, _)", entry.Addr)
- }
- clock.Advance(typicalLatency)
-
- select {
- case <-doneCh:
- default:
- t.Fatal("expected notification from done channel")
- }
-
- id, ok := s.Fetch(false /* block */)
- if !ok {
- t.Errorf("expected waker to be notified after neigh.entry(%s, '', _, _)", entry.Addr)
- }
- if id != wakerID {
- t.Errorf("got s.Fetch(false) = %d, want = %d", id, wakerID)
- }
-
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestAdded,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- State: Incomplete,
- },
- },
- {
- EventType: entryTestChanged,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- },
- },
- }
- nudDisp.mu.Lock()
- defer nudDisp.mu.Unlock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
- }
-}
-
-func TestNeighborCacheRemoveWaker(t *testing.T) {
- config := DefaultNUDConfigurations()
-
- nudDisp := testNUDDispatcher{}
- clock := faketime.NewManualClock()
- neigh := newTestNeighborCache(&nudDisp, config, clock)
- store := newTestEntryStore()
- linkRes := &testNeighborResolver{
- clock: clock,
- neigh: neigh,
- entries: store,
- delay: typicalLatency,
- }
-
- w := sleep.Waker{}
- s := sleep.Sleeper{}
- const wakerID = 1
- s.AddWaker(&w, wakerID)
-
- entry, ok := store.entry(0)
- if !ok {
- t.Fatalf("store.entry(0) not found")
- }
- _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, &w)
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, '', _, _) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
- }
- if doneCh == nil {
- t.Fatalf("expected done channel from neigh.entry(%s, '', _, _)", entry.Addr)
- }
-
- // Remove the waker before the neighbor cache has the opportunity to send a
- // notification.
- neigh.removeWaker(entry.Addr, &w)
- clock.Advance(typicalLatency)
-
- select {
- case <-doneCh:
- default:
- t.Fatal("expected notification from done channel")
- }
-
- if id, ok := s.Fetch(false /* block */); ok {
- t.Errorf("unexpected notification from waker with id %d", id)
- }
-
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestAdded,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- State: Incomplete,
- },
- },
- {
- EventType: entryTestChanged,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- },
- },
- }
- nudDisp.mu.Lock()
- defer nudDisp.mu.Unlock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
- }
-}
-
func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
config := DefaultNUDConfigurations()
// Stay in Reachable so the cache can overflow
@@ -1059,12 +920,12 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
entry, ok := c.store.entry(0)
if !ok {
- t.Fatalf("c.store.entry(0) not found")
+ t.Fatal("c.store.entry(0) not found")
}
c.neigh.addStaticEntry(entry.Addr, entry.LinkAddr)
e, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil)
if err != nil {
- t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil): %s", entry.Addr, err)
+ t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil, nil): %s", entry.Addr, err)
}
want := NeighborEntry{
Addr: entry.Addr,
@@ -1072,7 +933,7 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
State: Static,
}
if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("c.neigh.entry(%s, \"\", _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
+ t.Errorf("c.neigh.entry(%s, \"\", _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
}
wantEvents := []testEntryEventInfo{
@@ -1126,10 +987,10 @@ func TestNeighborCacheClear(t *testing.T) {
// Add a dynamic entry.
entry, ok := store.entry(0)
if !ok {
- t.Fatalf("store.entry(0) not found")
+ t.Fatal("store.entry(0) not found")
}
if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
@@ -1184,7 +1045,7 @@ func TestNeighborCacheClear(t *testing.T) {
}
}
- // Clear shoud remove both dynamic and static entries.
+ // Clear should remove both dynamic and static entries.
neigh.clear()
// Remove events dispatched from clear() have no deterministic order so they
@@ -1231,10 +1092,10 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) {
// Add a dynamic entry
entry, ok := c.store.entry(0)
if !ok {
- t.Fatalf("c.store.entry(0) not found")
+ t.Fatal("c.store.entry(0) not found")
}
if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
c.clock.Advance(typicalLatency)
wantEvents := []testEntryEventInfo{
@@ -1315,7 +1176,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
frequentlyUsedEntry, ok := store.entry(0)
if !ok {
- t.Fatalf("store.entry(0) not found")
+ t.Fatal("store.entry(0) not found")
}
// The following logic is very similar to overflowCache, but
@@ -1327,15 +1188,22 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
if !ok {
t.Fatalf("store.entry(%d) not found", i)
}
- _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
+ if !ok {
+ t.Fatal("expected successful address resolution")
+ }
+ if linkAddr != entry.LinkAddr {
+ t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ }
+ })
if err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
select {
- case <-doneCh:
+ case <-ch:
default:
- t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr)
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
}
wantEvents := []testEntryEventInfo{
{
@@ -1370,7 +1238,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
// Periodically refresh the frequently used entry
if i%(neighborCacheSize/2) == 0 {
if _, _, err := neigh.entry(frequentlyUsedEntry.Addr, "", linkRes, nil); err != nil {
- t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", frequentlyUsedEntry.Addr, err)
+ t.Errorf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", frequentlyUsedEntry.Addr, err)
}
}
@@ -1378,15 +1246,23 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
if !ok {
t.Fatalf("store.entry(%d) not found", i)
}
- _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
+
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
+ if !ok {
+ t.Fatal("expected successful address resolution")
+ }
+ if linkAddr != entry.LinkAddr {
+ t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ }
+ })
if err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
select {
- case <-doneCh:
+ case <-ch:
default:
- t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr)
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
}
// An entry should have been removed, as per the LRU eviction strategy
@@ -1432,7 +1308,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
}
// Expect to find only the frequently used entry and the most recent entries.
- // The order of entries reported by entries() is undeterministic, so entries
+ // The order of entries reported by entries() is nondeterministic, so entries
// have to be sorted before comparison.
wantUnsortedEntries := []NeighborEntry{
{
@@ -1491,12 +1367,12 @@ func TestNeighborCacheConcurrent(t *testing.T) {
go func(entry NeighborEntry) {
defer wg.Done()
if e, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil && err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, tcpip.ErrWouldBlock)
+ t.Errorf("got neigh.entry(%s, '', _, nil, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, tcpip.ErrWouldBlock)
}
}(entry)
}
- // Wait for all gorountines to send a request
+ // Wait for all goroutines to send a request
wg.Wait()
// Process all the requests for a single entry concurrently
@@ -1506,7 +1382,7 @@ func TestNeighborCacheConcurrent(t *testing.T) {
// All goroutines add in the same order and add more values than can fit in
// the cache. Our eviction strategy requires that the last entries are
// present, up to the size of the neighbor cache, and the rest are missing.
- // The order of entries reported by entries() is undeterministic, so entries
+ // The order of entries reported by entries() is nondeterministic, so entries
// have to be sorted before comparison.
var wantUnsortedEntries []NeighborEntry
for i := store.size() - neighborCacheSize; i < store.size(); i++ {
@@ -1544,27 +1420,32 @@ func TestNeighborCacheReplace(t *testing.T) {
// Add an entry
entry, ok := store.entry(0)
if !ok {
- t.Fatalf("store.entry(0) not found")
+ t.Fatal("store.entry(0) not found")
}
- _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
+
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
+ if !ok {
+ t.Fatal("expected successful address resolution")
+ }
+ if linkAddr != entry.LinkAddr {
+ t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ }
+ })
if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
select {
- case <-doneCh:
+ case <-ch:
default:
- t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr)
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
}
// Verify the entry exists
{
- e, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
+ e, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != nil {
- t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
- }
- if doneCh != nil {
- t.Errorf("unexpected done channel from neigh.entry(%s, '', _, nil): %v", entry.Addr, doneCh)
+ t.Errorf("unexpected error from neigh.entry(%s, '', _, _, nil): %s", entry.Addr, err)
}
if t.Failed() {
t.FailNow()
@@ -1575,7 +1456,7 @@ func TestNeighborCacheReplace(t *testing.T) {
State: Reachable,
}
if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
+ t.Errorf("neigh.entry(%s, '', _, _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
}
}
@@ -1584,7 +1465,7 @@ func TestNeighborCacheReplace(t *testing.T) {
{
entry, ok := store.entry(1)
if !ok {
- t.Fatalf("store.entry(1) not found")
+ t.Fatal("store.entry(1) not found")
}
updatedLinkAddr = entry.LinkAddr
}
@@ -1601,7 +1482,7 @@ func TestNeighborCacheReplace(t *testing.T) {
{
e, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != nil {
- t.Fatalf("neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
+ t.Fatalf("neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err)
}
want := NeighborEntry{
Addr: entry.Addr,
@@ -1609,7 +1490,7 @@ func TestNeighborCacheReplace(t *testing.T) {
State: Delay,
}
if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
+ t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
}
clock.Advance(config.DelayFirstProbeTime + typicalLatency)
}
@@ -1619,7 +1500,7 @@ func TestNeighborCacheReplace(t *testing.T) {
e, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
clock.Advance(typicalLatency)
if err != nil {
- t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
+ t.Errorf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err)
}
want := NeighborEntry{
Addr: entry.Addr,
@@ -1627,7 +1508,7 @@ func TestNeighborCacheReplace(t *testing.T) {
State: Reachable,
}
if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
+ t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
}
}
}
@@ -1651,18 +1532,35 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
},
}
- // First, sanity check that resolution is working
entry, ok := store.entry(0)
if !ok {
- t.Fatalf("store.entry(0) not found")
+ t.Fatal("store.entry(0) not found")
}
- if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+
+ // First, sanity check that resolution is working
+ {
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
+ if !ok {
+ t.Fatal("expected successful address resolution")
+ }
+ if linkAddr != entry.LinkAddr {
+ t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ }
+ })
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ }
+ clock.Advance(typicalLatency)
+ select {
+ case <-ch:
+ default:
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
+ }
}
- clock.Advance(typicalLatency)
+
got, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != nil {
- t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
+ t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err)
}
want := NeighborEntry{
Addr: entry.Addr,
@@ -1670,20 +1568,35 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
State: Reachable,
}
if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
+ t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
}
- // Verify that address resolution for an unknown address returns ErrNoLinkAddress
+ // Verify address resolution fails for an unknown address.
before := atomic.LoadUint32(&requestCount)
entry.Addr += "2"
- if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
- }
- waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes)
- clock.Advance(waitFor)
- if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrNoLinkAddress {
- t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrNoLinkAddress)
+ {
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
+ if ok {
+ t.Error("expected unsuccessful address resolution")
+ }
+ if len(linkAddr) != 0 {
+ t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ })
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ }
+ waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes)
+ clock.Advance(waitFor)
+ select {
+ case <-ch:
+ default:
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
+ }
}
maxAttempts := neigh.config().MaxUnicastProbes
@@ -1711,15 +1624,129 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) {
entry, ok := store.entry(0)
if !ok {
- t.Fatalf("store.entry(0) not found")
+ t.Fatal("store.entry(0) not found")
}
- if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
+ if ok {
+ t.Error("expected unsuccessful address resolution")
+ }
+ if len(linkAddr) != 0 {
+ t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ })
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes)
clock.Advance(waitFor)
- if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrNoLinkAddress {
- t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrNoLinkAddress)
+
+ select {
+ case <-ch:
+ default:
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
+ }
+}
+
+// TestNeighborCacheRetryResolution simulates retrying communication after
+// failing to perform address resolution.
+func TestNeighborCacheRetryResolution(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ clock := faketime.NewManualClock()
+ neigh := newTestNeighborCache(nil, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ // Simulate a faulty link.
+ dropReplies: true,
+ }
+
+ entry, ok := store.entry(0)
+ if !ok {
+ t.Fatal("store.entry(0) not found")
+ }
+
+ // Perform address resolution with a faulty link, which will fail.
+ {
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
+ if ok {
+ t.Error("expected unsuccessful address resolution")
+ }
+ if len(linkAddr) != 0 {
+ t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ })
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ }
+ waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes)
+ clock.Advance(waitFor)
+
+ select {
+ case <-ch:
+ default:
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
+ }
+ }
+
+ // Verify the entry is in Failed state.
+ wantEntries := []NeighborEntry{
+ {
+ Addr: entry.Addr,
+ LinkAddr: "",
+ State: Failed,
+ },
+ }
+ if diff := cmp.Diff(neigh.entries(), wantEntries, entryDiffOptsWithSort()...); diff != "" {
+ t.Fatalf("neighbor entries mismatch (-got, +want):\n%s", diff)
+ }
+
+ // Retry address resolution with a working link.
+ linkRes.dropReplies = false
+ {
+ incompleteEntry, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
+ if linkAddr != entry.LinkAddr {
+ t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ }
+ })
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ }
+ if incompleteEntry.State != Incomplete {
+ t.Fatalf("got entry.State = %s, want = %s", incompleteEntry.State, Incomplete)
+ }
+ clock.Advance(typicalLatency)
+
+ select {
+ case <-ch:
+ if !ok {
+ t.Fatal("expected successful address resolution")
+ }
+ reachableEntry, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
+ if err != nil {
+ t.Fatalf("neigh.entry(%s, '', _, _, nil): %v", entry.Addr, err)
+ }
+ if reachableEntry.Addr != entry.Addr {
+ t.Fatalf("got entry.Addr = %s, want = %s", reachableEntry.Addr, entry.Addr)
+ }
+ if reachableEntry.LinkAddr != entry.LinkAddr {
+ t.Fatalf("got entry.LinkAddr = %s, want = %s", reachableEntry.LinkAddr, entry.LinkAddr)
+ }
+ if reachableEntry.State != Reachable {
+ t.Fatalf("got entry.State = %s, want = %s", reachableEntry.State.String(), Reachable.String())
+ }
+ default:
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
+ }
}
}
@@ -1739,7 +1766,7 @@ func TestNeighborCacheStaticResolution(t *testing.T) {
got, _, err := neigh.entry(testEntryBroadcastAddr, "", linkRes, nil)
if err != nil {
- t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", testEntryBroadcastAddr, err)
+ t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", testEntryBroadcastAddr, err)
}
want := NeighborEntry{
Addr: testEntryBroadcastAddr,
@@ -1747,7 +1774,7 @@ func TestNeighborCacheStaticResolution(t *testing.T) {
State: Static,
}
if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, diff)
+ t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, diff)
}
}
@@ -1772,12 +1799,23 @@ func BenchmarkCacheClear(b *testing.B) {
if !ok {
b.Fatalf("store.entry(%d) not found", i)
}
- _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
+
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
+ if !ok {
+ b.Fatal("expected successful address resolution")
+ }
+ if linkAddr != entry.LinkAddr {
+ b.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ }
+ })
if err != tcpip.ErrWouldBlock {
- b.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ b.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
- if doneCh != nil {
- <-doneCh
+
+ select {
+ case <-ch:
+ default:
+ b.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr)
}
}
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
index 493e48031..75afb3001 100644
--- a/pkg/tcpip/stack/neighbor_entry.go
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -19,7 +19,6 @@ import (
"sync"
"time"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
@@ -67,8 +66,7 @@ const (
// Static describes entries that have been explicitly added by the user. They
// do not expire and are not deleted until explicitly removed.
Static
- // Failed means traffic should not be sent to this neighbor since attempts of
- // reachability have returned inconclusive.
+ // Failed means recent attempts of reachability have returned inconclusive.
Failed
)
@@ -93,16 +91,13 @@ type neighborEntry struct {
neigh NeighborEntry
- // wakers is a set of waiters for address resolution result. Anytime state
- // transitions out of incomplete these waiters are notified. It is nil iff
- // address resolution is ongoing and no clients are waiting for the result.
- wakers map[*sleep.Waker]struct{}
-
- // done is used to allow callers to wait on address resolution. It is nil
- // iff nudState is not Reachable and address resolution is not yet in
- // progress.
+ // done is closed when address resolution is complete. It is nil iff s is
+ // incomplete and resolution is not yet in progress.
done chan struct{}
+ // onResolve is called with the result of address resolution.
+ onResolve []func(tcpip.LinkAddress, bool)
+
isRouter bool
job *tcpip.Job
}
@@ -143,25 +138,15 @@ func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAdd
}
}
-// addWaker adds w to the list of wakers waiting for address resolution.
-// Assumes the entry has already been appropriately locked.
-func (e *neighborEntry) addWakerLocked(w *sleep.Waker) {
- if w == nil {
- return
- }
- if e.wakers == nil {
- e.wakers = make(map[*sleep.Waker]struct{})
- }
- e.wakers[w] = struct{}{}
-}
-
-// notifyWakersLocked notifies those waiting for address resolution, whether it
-// succeeded or failed. Assumes the entry has already been appropriately locked.
-func (e *neighborEntry) notifyWakersLocked() {
- for w := range e.wakers {
- w.Assert()
+// notifyCompletionLocked notifies those waiting for address resolution, with
+// the link address if resolution completed successfully.
+//
+// Precondition: e.mu MUST be locked.
+func (e *neighborEntry) notifyCompletionLocked(succeeded bool) {
+ for _, callback := range e.onResolve {
+ callback(e.neigh.LinkAddr, succeeded)
}
- e.wakers = nil
+ e.onResolve = nil
if ch := e.done; ch != nil {
close(ch)
e.done = nil
@@ -170,6 +155,8 @@ func (e *neighborEntry) notifyWakersLocked() {
// dispatchAddEventLocked signals to stack's NUD Dispatcher that the entry has
// been added.
+//
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) dispatchAddEventLocked() {
if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
nudDisp.OnNeighborAdded(e.nic.id, e.neigh)
@@ -178,6 +165,8 @@ func (e *neighborEntry) dispatchAddEventLocked() {
// dispatchChangeEventLocked signals to stack's NUD Dispatcher that the entry
// has changed state or link-layer address.
+//
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) dispatchChangeEventLocked() {
if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
nudDisp.OnNeighborChanged(e.nic.id, e.neigh)
@@ -186,23 +175,41 @@ func (e *neighborEntry) dispatchChangeEventLocked() {
// dispatchRemoveEventLocked signals to stack's NUD Dispatcher that the entry
// has been removed.
+//
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) dispatchRemoveEventLocked() {
if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
nudDisp.OnNeighborRemoved(e.nic.id, e.neigh)
}
}
+// cancelJobLocked cancels the currently scheduled action, if there is one.
+// Entries in Unknown, Stale, or Static state do not have a scheduled action.
+//
+// Precondition: e.mu MUST be locked.
+func (e *neighborEntry) cancelJobLocked() {
+ if job := e.job; job != nil {
+ job.Cancel()
+ }
+}
+
+// removeLocked prepares the entry for removal.
+//
+// Precondition: e.mu MUST be locked.
+func (e *neighborEntry) removeLocked() {
+ e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds()
+ e.dispatchRemoveEventLocked()
+ e.cancelJobLocked()
+ e.notifyCompletionLocked(false /* succeeded */)
+}
+
// setStateLocked transitions the entry to the specified state immediately.
//
// Follows the logic defined in RFC 4861 section 7.3.3.
//
-// e.mu MUST be locked.
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) setStateLocked(next NeighborState) {
- // Cancel the previously scheduled action, if there is one. Entries in
- // Unknown, Stale, or Static state do not have scheduled actions.
- if timer := e.job; timer != nil {
- timer.Cancel()
- }
+ e.cancelJobLocked()
prev := e.neigh.State
e.neigh.State = next
@@ -257,11 +264,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
e.job.Schedule(immediateDuration)
case Failed:
- e.notifyWakersLocked()
- e.job = e.nic.stack.newJob(&e.mu, func() {
- e.nic.neigh.removeEntryLocked(e)
- })
- e.job.Schedule(config.UnreachableTime)
+ e.notifyCompletionLocked(false /* succeeded */)
case Unknown, Stale, Static:
// Do nothing
@@ -275,8 +278,14 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
// being queued for outgoing transmission.
//
// Follows the logic defined in RFC 4861 section 7.3.3.
+//
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
switch e.neigh.State {
+ case Failed:
+ e.nic.stats.Neighbor.FailedEntryLookups.Increment()
+
+ fallthrough
case Unknown:
e.neigh.State = Incomplete
e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds()
@@ -309,7 +318,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
// implementation may find it convenient in some cases to return errors
// to the sender by taking the offending packet, generating an ICMP
// error message, and then delivering it (locally) through the generic
- // error-handling routines.' - RFC 4861 section 2.1
+ // error-handling routines." - RFC 4861 section 2.1
e.dispatchRemoveEventLocked()
e.setStateLocked(Failed)
return
@@ -347,9 +356,8 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
e.setStateLocked(Delay)
e.dispatchChangeEventLocked()
- case Incomplete, Reachable, Delay, Probe, Static, Failed:
+ case Incomplete, Reachable, Delay, Probe, Static:
// Do nothing
-
default:
panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State))
}
@@ -359,18 +367,30 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
// Neighbor Solicitation for ARP or NDP, respectively).
//
// Follows the logic defined in RFC 4861 section 7.2.3.
+//
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) {
// Probes MUST be silently discarded if the target address is tentative, does
// not exist, or not bound to the NIC as per RFC 4861 section 7.2.3. These
// checks MUST be done by the NetworkEndpoint.
switch e.neigh.State {
- case Unknown, Incomplete, Failed:
+ case Unknown, Failed:
e.neigh.LinkAddr = remoteLinkAddr
e.setStateLocked(Stale)
- e.notifyWakersLocked()
e.dispatchAddEventLocked()
+ case Incomplete:
+ // "If an entry already exists, and the cached link-layer address
+ // differs from the one in the received Source Link-Layer option, the
+ // cached address should be replaced by the received address, and the
+ // entry's reachability state MUST be set to STALE."
+ // - RFC 4861 section 7.2.3
+ e.neigh.LinkAddr = remoteLinkAddr
+ e.setStateLocked(Stale)
+ e.notifyCompletionLocked(true /* succeeded */)
+ e.dispatchChangeEventLocked()
+
case Reachable, Delay, Probe:
if e.neigh.LinkAddr != remoteLinkAddr {
e.neigh.LinkAddr = remoteLinkAddr
@@ -403,6 +423,8 @@ func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) {
// not be possible. SEND uses RSA key pairs to produce Cryptographically
// Generated Addresses (CGA), as defined in RFC 3972. This ensures that the
// claimed source of an NDP message is the owner of the claimed address.
+//
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) {
switch e.neigh.State {
case Incomplete:
@@ -421,7 +443,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
}
e.dispatchChangeEventLocked()
e.isRouter = flags.IsRouter
- e.notifyWakersLocked()
+ e.notifyCompletionLocked(true /* succeeded */)
// "Note that the Override flag is ignored if the entry is in the
// INCOMPLETE state." - RFC 4861 section 7.2.5
@@ -456,7 +478,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
wasReachable := e.neigh.State == Reachable
// Set state to Reachable again to refresh timers.
e.setStateLocked(Reachable)
- e.notifyWakersLocked()
+ e.notifyCompletionLocked(true /* succeeded */)
if !wasReachable {
e.dispatchChangeEventLocked()
}
@@ -494,6 +516,8 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
// handleUpperLevelConfirmationLocked processes an incoming upper-level protocol
// (e.g. TCP acknowledgements) reachability confirmation.
+//
+// Precondition: e.mu MUST be locked.
func (e *neighborEntry) handleUpperLevelConfirmationLocked() {
switch e.neigh.State {
case Reachable, Stale, Delay, Probe:
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
index c2b763325..ec34ffa5a 100644
--- a/pkg/tcpip/stack/neighbor_entry_test.go
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -25,7 +25,6 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -73,35 +72,36 @@ func eventDiffOptsWithSort() []cmp.Option {
// The following unit tests exercise every state transition and verify its
// behavior with RFC 4681.
//
-// | From | To | Cause | Action | Event |
-// | ========== | ========== | ========================================== | =============== | ======= |
-// | Unknown | Unknown | Confirmation w/ unknown address | | Added |
-// | Unknown | Incomplete | Packet queued to unknown address | Send probe | Added |
-// | Unknown | Stale | Probe w/ unknown address | | Added |
-// | Incomplete | Incomplete | Retransmit timer expired | Send probe | Changed |
-// | Incomplete | Reachable | Solicited confirmation | Notify wakers | Changed |
-// | Incomplete | Stale | Unsolicited confirmation | Notify wakers | Changed |
-// | Incomplete | Failed | Max probes sent without reply | Notify wakers | Removed |
-// | Reachable | Reachable | Confirmation w/ different isRouter flag | Update IsRouter | |
-// | Reachable | Stale | Reachable timer expired | | Changed |
-// | Reachable | Stale | Probe or confirmation w/ different address | | Changed |
-// | Stale | Reachable | Solicited override confirmation | Update LinkAddr | Changed |
-// | Stale | Reachable | Solicited confirmation w/o address | Notify wakers | Changed |
-// | Stale | Stale | Override confirmation | Update LinkAddr | Changed |
-// | Stale | Stale | Probe w/ different address | Update LinkAddr | Changed |
-// | Stale | Delay | Packet sent | | Changed |
-// | Delay | Reachable | Upper-layer confirmation | | Changed |
-// | Delay | Reachable | Solicited override confirmation | Update LinkAddr | Changed |
-// | Delay | Reachable | Solicited confirmation w/o address | Notify wakers | Changed |
-// | Delay | Stale | Probe or confirmation w/ different address | | Changed |
-// | Delay | Probe | Delay timer expired | Send probe | Changed |
-// | Probe | Reachable | Solicited override confirmation | Update LinkAddr | Changed |
-// | Probe | Reachable | Solicited confirmation w/ same address | Notify wakers | Changed |
-// | Probe | Reachable | Solicited confirmation w/o address | Notify wakers | Changed |
-// | Probe | Stale | Probe or confirmation w/ different address | | Changed |
-// | Probe | Probe | Retransmit timer expired | Send probe | Changed |
-// | Probe | Failed | Max probes sent without reply | Notify wakers | Removed |
-// | Failed | | Unreachability timer expired | Delete entry | |
+// | From | To | Cause | Update | Action | Event |
+// | ========== | ========== | ========================================== | ======== | ===========| ======= |
+// | Unknown | Unknown | Confirmation w/ unknown address | | | Added |
+// | Unknown | Incomplete | Packet queued to unknown address | | Send probe | Added |
+// | Unknown | Stale | Probe | | | Added |
+// | Incomplete | Incomplete | Retransmit timer expired | | Send probe | Changed |
+// | Incomplete | Reachable | Solicited confirmation | LinkAddr | Notify | Changed |
+// | Incomplete | Stale | Unsolicited confirmation | LinkAddr | Notify | Changed |
+// | Incomplete | Stale | Probe | LinkAddr | Notify | Changed |
+// | Incomplete | Failed | Max probes sent without reply | | Notify | Removed |
+// | Reachable | Reachable | Confirmation w/ different isRouter flag | IsRouter | | |
+// | Reachable | Stale | Reachable timer expired | | | Changed |
+// | Reachable | Stale | Probe or confirmation w/ different address | | | Changed |
+// | Stale | Reachable | Solicited override confirmation | LinkAddr | | Changed |
+// | Stale | Reachable | Solicited confirmation w/o address | | Notify | Changed |
+// | Stale | Stale | Override confirmation | LinkAddr | | Changed |
+// | Stale | Stale | Probe w/ different address | LinkAddr | | Changed |
+// | Stale | Delay | Packet sent | | | Changed |
+// | Delay | Reachable | Upper-layer confirmation | | | Changed |
+// | Delay | Reachable | Solicited override confirmation | LinkAddr | | Changed |
+// | Delay | Reachable | Solicited confirmation w/o address | | Notify | Changed |
+// | Delay | Stale | Probe or confirmation w/ different address | | | Changed |
+// | Delay | Probe | Delay timer expired | | Send probe | Changed |
+// | Probe | Reachable | Solicited override confirmation | LinkAddr | | Changed |
+// | Probe | Reachable | Solicited confirmation w/ same address | | Notify | Changed |
+// | Probe | Reachable | Solicited confirmation w/o address | | Notify | Changed |
+// | Probe | Stale | Probe or confirmation w/ different address | | | Changed |
+// | Probe | Probe | Retransmit timer expired | | | Changed |
+// | Probe | Failed | Max probes sent without reply | | Notify | Removed |
+// | Failed | Incomplete | Packet queued | | Send probe | Added |
type testEntryEventType uint8
@@ -228,6 +228,7 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e
clock: clock,
nudDisp: &disp,
},
+ stats: makeNICStats(),
}
nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{
header.IPv6ProtocolNumber: (&testIPv6Protocol{}).NewEndpoint(&nic, nil, nil, nil),
@@ -256,8 +257,8 @@ func TestEntryInitiallyUnknown(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- if got, want := e.neigh.State, Unknown; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Unknown {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Unknown)
}
e.mu.Unlock()
@@ -289,8 +290,8 @@ func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) {
Override: false,
IsRouter: false,
})
- if got, want := e.neigh.State, Unknown; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Unknown {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Unknown)
}
e.mu.Unlock()
@@ -318,8 +319,8 @@ func TestEntryUnknownToIncomplete(t *testing.T) {
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- if got, want := e.neigh.State, Incomplete; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
}
e.mu.Unlock()
@@ -365,8 +366,8 @@ func TestEntryUnknownToStale(t *testing.T) {
e.mu.Lock()
e.handleProbeLocked(entryTestLinkAddr1)
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
}
e.mu.Unlock()
@@ -404,8 +405,8 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- if got, want := e.neigh.State, Incomplete; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
}
updatedAtNanos := e.neigh.UpdatedAtNanos
e.mu.Unlock()
@@ -558,21 +559,15 @@ func TestEntryIncompleteToReachable(t *testing.T) {
nudDisp.mu.Unlock()
}
-// TestEntryAddsAndClearsWakers verifies that wakers are added when
-// addWakerLocked is called and cleared when address resolution finishes. In
-// this case, address resolution will finish when transitioning from Incomplete
-// to Reachable.
-func TestEntryAddsAndClearsWakers(t *testing.T) {
+func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
c := DefaultNUDConfigurations()
e, nudDisp, linkRes, clock := entryTestSetup(c)
- w := sleep.Waker{}
- s := sleep.Sleeper{}
- s.AddWaker(&w, 123)
- defer s.Done()
-
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
+ }
e.mu.Unlock()
runImmediatelyScheduledJobs(clock)
@@ -591,26 +586,16 @@ func TestEntryAddsAndClearsWakers(t *testing.T) {
}
e.mu.Lock()
- if got := e.wakers; got != nil {
- t.Errorf("got e.wakers = %v, want = nil", got)
- }
- e.addWakerLocked(&w)
- if got, want := w.IsAsserted(), false; got != want {
- t.Errorf("waker.IsAsserted() = %t, want = %t", got, want)
- }
- if e.wakers == nil {
- t.Error("expected e.wakers to be non-nil")
- }
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: true,
Override: false,
- IsRouter: false,
+ IsRouter: true,
})
- if e.wakers != nil {
- t.Errorf("got e.wakers = %v, want = nil", e.wakers)
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
}
- if got, want := w.IsAsserted(), true; got != want {
- t.Errorf("waker.IsAsserted() = %t, want = %t", got, want)
+ if !e.isRouter {
+ t.Errorf("got e.isRouter = %t, want = true", e.isRouter)
}
e.mu.Unlock()
@@ -641,7 +626,7 @@ func TestEntryAddsAndClearsWakers(t *testing.T) {
nudDisp.mu.Unlock()
}
-func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
+func TestEntryIncompleteToStaleWhenUnsolicitedConfirmation(t *testing.T) {
c := DefaultNUDConfigurations()
e, nudDisp, linkRes, clock := entryTestSetup(c)
@@ -661,22 +646,20 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
},
}
linkRes.mu.Lock()
- if diff := cmp.Diff(linkRes.probes, wantProbes); diff != "" {
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
- linkRes.mu.Unlock()
e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
+ Solicited: false,
Override: false,
- IsRouter: true,
+ IsRouter: false,
})
- if e.neigh.State != Reachable {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
- }
- if !e.isRouter {
- t.Errorf("got e.isRouter = %t, want = true", e.isRouter)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
}
e.mu.Unlock()
@@ -696,7 +679,7 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
Entry: NeighborEntry{
Addr: entryTestAddr1,
LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ State: Stale,
},
},
}
@@ -707,7 +690,7 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
nudDisp.mu.Unlock()
}
-func TestEntryIncompleteToStale(t *testing.T) {
+func TestEntryIncompleteToStaleWhenProbe(t *testing.T) {
c := DefaultNUDConfigurations()
e, nudDisp, linkRes, clock := entryTestSetup(c)
@@ -734,11 +717,7 @@ func TestEntryIncompleteToStale(t *testing.T) {
}
e.mu.Lock()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
+ e.handleProbeLocked(entryTestLinkAddr1)
if e.neigh.State != Stale {
t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
}
@@ -778,8 +757,8 @@ func TestEntryIncompleteToFailed(t *testing.T) {
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- if got, want := e.neigh.State, Incomplete; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
}
e.mu.Unlock()
@@ -839,8 +818,8 @@ func TestEntryIncompleteToFailed(t *testing.T) {
nudDisp.mu.Unlock()
e.mu.Lock()
- if got, want := e.neigh.State, Failed; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Failed {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Failed)
}
e.mu.Unlock()
}
@@ -883,8 +862,8 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
Override: false,
IsRouter: true,
})
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
}
if got, want := e.isRouter, true; got != want {
t.Errorf("got e.isRouter = %t, want = %t", got, want)
@@ -930,8 +909,8 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
nudDisp.mu.Unlock()
e.mu.Lock()
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
}
e.mu.Unlock()
}
@@ -1081,8 +1060,8 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) {
nudDisp.mu.Unlock()
e.mu.Lock()
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
}
e.mu.Unlock()
}
@@ -2379,8 +2358,8 @@ func TestEntryDelayToProbe(t *testing.T) {
IsRouter: false,
})
e.handlePacketQueuedLocked(entryTestAddr2)
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
}
e.mu.Unlock()
@@ -2445,8 +2424,8 @@ func TestEntryDelayToProbe(t *testing.T) {
nudDisp.mu.Unlock()
e.mu.Lock()
- if got, want := e.neigh.State, Probe; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Probe {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
}
e.mu.Unlock()
}
@@ -2503,12 +2482,12 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
}
e.mu.Lock()
- if got, want := e.neigh.State, Probe; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Probe {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
}
e.handleProbeLocked(entryTestLinkAddr2)
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
}
e.mu.Unlock()
@@ -2618,16 +2597,16 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
}
e.mu.Lock()
- if got, want := e.neigh.State, Probe; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Probe {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
}
e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
Solicited: false,
Override: true,
IsRouter: false,
})
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
}
e.mu.Unlock()
@@ -2738,16 +2717,16 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) {
}
e.mu.Lock()
- if got, want := e.neigh.State, Probe; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Probe {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
}
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: true,
IsRouter: false,
})
- if got, want := e.neigh.State, Probe; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Probe {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
}
if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want {
t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
@@ -2834,16 +2813,16 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) {
}
e.mu.Lock()
- if got, want := e.neigh.State, Probe; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Probe {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
}
e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
Solicited: true,
Override: true,
IsRouter: false,
})
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
}
if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
@@ -2962,16 +2941,16 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
}
e.mu.Lock()
- if got, want := e.neigh.State, Probe; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Probe {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
}
e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
Solicited: true,
Override: true,
IsRouter: false,
})
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
}
if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
@@ -3099,16 +3078,16 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin
}
e.mu.Lock()
- if got, want := e.neigh.State, Probe; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Probe {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
}
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: true,
Override: false,
IsRouter: false,
})
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
}
e.mu.Unlock()
@@ -3433,72 +3412,61 @@ func TestEntryProbeToFailed(t *testing.T) {
nudDisp.mu.Unlock()
}
-func TestEntryFailedGetsDeleted(t *testing.T) {
+func TestEntryFailedToIncomplete(t *testing.T) {
c := DefaultNUDConfigurations()
c.MaxMulticastProbes = 3
- c.MaxUnicastProbes = 3
e, nudDisp, linkRes, clock := entryTestSetup(c)
- // Verify the cache contains the entry.
- if _, ok := e.nic.neigh.cache[entryTestAddr1]; !ok {
- t.Errorf("expected entry %q to exist in the neighbor cache", entryTestAddr1)
- }
-
+ // TODO(gvisor.dev/issue/4872): Use helper functions to start entry tests in
+ // their expected state.
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
+ }
e.mu.Unlock()
- runImmediatelyScheduledJobs(clock)
- {
- wantProbes := []entryTestProbeInfo{
- {
- RemoteAddress: entryTestAddr1,
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.probes = nil
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
- }
+ waitFor := c.RetransmitTimer * time.Duration(c.MaxMulticastProbes)
+ clock.Advance(waitFor)
+
+ wantProbes := []entryTestProbeInfo{
+ // The Incomplete-to-Incomplete state transition is tested here by
+ // verifying that 3 reachability probes were sent.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
e.mu.Lock()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Failed {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Failed)
+ }
e.mu.Unlock()
- waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + c.UnreachableTime
- clock.Advance(waitFor)
- {
- wantProbes := []entryTestProbeInfo{
- // The next three probe are sent in Probe.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
- }
+ e.mu.Lock()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
}
+ e.mu.Unlock()
wantEvents := []testEntryEventInfo{
{
@@ -3511,39 +3479,21 @@ func TestEntryFailedGetsDeleted(t *testing.T) {
},
},
{
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
- },
- },
- {
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
- },
- },
- {
- EventType: entryTestChanged,
+ EventType: entryTestRemoved,
NICID: entryTestNICID,
Entry: NeighborEntry{
Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
},
},
{
- EventType: entryTestRemoved,
+ EventType: entryTestAdded,
NICID: entryTestNICID,
Entry: NeighborEntry{
Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
},
},
}
@@ -3552,9 +3502,4 @@ func TestEntryFailedGetsDeleted(t *testing.T) {
t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
}
nudDisp.mu.Unlock()
-
- // Verify the cache no longer contains the entry.
- if _, ok := e.nic.neigh.cache[entryTestAddr1]; ok {
- t.Errorf("entry %q should have been deleted from the neighbor cache", entryTestAddr1)
- }
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 60c81a3aa..4a34805b5 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -20,7 +20,6 @@ import (
"reflect"
"sync/atomic"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -54,18 +53,20 @@ type NIC struct {
sync.RWMutex
spoofing bool
promiscuous bool
- // packetEPs is protected by mu, but the contained PacketEndpoint
- // values are not.
- packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint
+ // packetEPs is protected by mu, but the contained packetEndpointList are
+ // not.
+ packetEPs map[tcpip.NetworkProtocolNumber]*packetEndpointList
}
}
-// NICStats includes transmitted and received stats.
+// NICStats hold statistics for a NIC.
type NICStats struct {
Tx DirectionStats
Rx DirectionStats
DisabledRx DirectionStats
+
+ Neighbor NeighborStats
}
func makeNICStats() NICStats {
@@ -80,6 +81,39 @@ type DirectionStats struct {
Bytes *tcpip.StatCounter
}
+type packetEndpointList struct {
+ mu sync.RWMutex
+
+ // eps is protected by mu, but the contained PacketEndpoint values are not.
+ eps []PacketEndpoint
+}
+
+func (p *packetEndpointList) add(ep PacketEndpoint) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ p.eps = append(p.eps, ep)
+}
+
+func (p *packetEndpointList) remove(ep PacketEndpoint) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ for i, epOther := range p.eps {
+ if epOther == ep {
+ p.eps = append(p.eps[:i], p.eps[i+1:]...)
+ break
+ }
+ }
+}
+
+// forEach calls fn with each endpoints in p while holding the read lock on p.
+func (p *packetEndpointList) forEach(fn func(PacketEndpoint)) {
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+ for _, ep := range p.eps {
+ fn(ep)
+ }
+}
+
// newNIC returns a new NIC using the default NDP configurations from stack.
func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *NIC {
// TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For
@@ -100,7 +134,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
stats: makeNICStats(),
networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint),
}
- nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint)
+ nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList)
// Check for Neighbor Unreachability Detection support.
var nud NUDHandler
@@ -123,11 +157,11 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
// Register supported packet and network endpoint protocols.
for _, netProto := range header.Ethertypes {
- nic.mu.packetEPs[netProto] = []PacketEndpoint{}
+ nic.mu.packetEPs[netProto] = new(packetEndpointList)
}
for _, netProto := range stack.networkProtocols {
netNum := netProto.Number()
- nic.mu.packetEPs[netNum] = nil
+ nic.mu.packetEPs[netNum] = new(packetEndpointList)
nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic)
}
@@ -170,7 +204,7 @@ func (n *NIC) disable() {
//
// n MUST be locked.
func (n *NIC) disableLocked() {
- if !n.setEnabled(false) {
+ if !n.Enabled() {
return
}
@@ -182,6 +216,10 @@ func (n *NIC) disableLocked() {
for _, ep := range n.networkEndpoints {
ep.Disable()
}
+
+ if !n.setEnabled(false) {
+ panic("should have only done work to disable the NIC if it was enabled")
+ }
}
// enable enables n.
@@ -232,7 +270,8 @@ func (n *NIC) setPromiscuousMode(enable bool) {
n.mu.Unlock()
}
-func (n *NIC) isPromiscuousMode() bool {
+// Promiscuous implements NetworkInterface.
+func (n *NIC) Promiscuous() bool {
n.mu.RLock()
rv := n.mu.promiscuous
n.mu.RUnlock()
@@ -255,16 +294,18 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb
// the same unresolved IP address, and transmit the saved
// packet when the address has been resolved.
//
- // RFC 4861 section 5.2 (for IPv6):
- // Once the IP address of the next-hop node is known, the sender
- // examines the Neighbor Cache for link-layer information about that
- // neighbor. If no entry exists, the sender creates one, sets its state
- // to INCOMPLETE, initiates Address Resolution, and then queues the data
- // packet pending completion of address resolution.
+ // RFC 4861 section 7.2.2 (for IPv6):
+ // While waiting for address resolution to complete, the sender MUST, for
+ // each neighbor, retain a small queue of packets waiting for address
+ // resolution to complete. The queue MUST hold at least one packet, and MAY
+ // contain more. However, the number of queued packets per neighbor SHOULD
+ // be limited to some small value. When a queue overflows, the new arrival
+ // SHOULD replace the oldest entry. Once address resolution completes, the
+ // node transmits any queued packets.
if ch, err := r.Resolve(nil); err != nil {
if err == tcpip.ErrWouldBlock {
- r := r.Clone()
- n.stack.linkResQueue.enqueue(ch, &r, protocol, pkt)
+ r.Acquire()
+ n.stack.linkResQueue.enqueue(ch, r, protocol, pkt)
return nil
}
return err
@@ -276,9 +317,11 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb
// WritePacketToRemote implements NetworkInterface.
func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
r := Route{
- NetProto: protocol,
- RemoteLinkAddress: remoteLinkAddr,
+ routeInfo: routeInfo{
+ NetProto: protocol,
+ },
}
+ r.ResolveWith(remoteLinkAddr)
return n.writePacket(&r, gso, protocol, pkt)
}
@@ -320,16 +363,21 @@ func (n *NIC) setSpoofing(enable bool) {
// primaryAddress returns an address that can be used to communicate with
// remoteAddr.
func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) AssignableAddressEndpoint {
- n.mu.RLock()
- spoofing := n.mu.spoofing
- n.mu.RUnlock()
-
ep, ok := n.networkEndpoints[protocol]
if !ok {
return nil
}
- return ep.AcquireOutgoingPrimaryAddress(remoteAddr, spoofing)
+ addressableEndpoint, ok := ep.(AddressableEndpoint)
+ if !ok {
+ return nil
+ }
+
+ n.mu.RLock()
+ spoofing := n.mu.spoofing
+ n.mu.RUnlock()
+
+ return addressableEndpoint.AcquireOutgoingPrimaryAddress(remoteAddr, spoofing)
}
type getAddressBehaviour int
@@ -388,11 +436,17 @@ func (n *NIC) getAddressOrCreateTemp(protocol tcpip.NetworkProtocolNumber, addre
// getAddressOrCreateTempInner is like getAddressEpOrCreateTemp except a boolean
// is passed to indicate whether or not we should generate temporary endpoints.
func (n *NIC) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, createTemp bool, peb PrimaryEndpointBehavior) AssignableAddressEndpoint {
- if ep, ok := n.networkEndpoints[protocol]; ok {
- return ep.AcquireAssignedAddress(address, createTemp, peb)
+ ep, ok := n.networkEndpoints[protocol]
+ if !ok {
+ return nil
}
- return nil
+ addressableEndpoint, ok := ep.(AddressableEndpoint)
+ if !ok {
+ return nil
+ }
+
+ return addressableEndpoint.AcquireAssignedAddress(address, createTemp, peb)
}
// addAddress adds a new address to n, so that it starts accepting packets
@@ -403,7 +457,12 @@ func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo
return tcpip.ErrUnknownProtocol
}
- addressEndpoint, err := ep.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */)
+ addressableEndpoint, ok := ep.(AddressableEndpoint)
+ if !ok {
+ return tcpip.ErrNotSupported
+ }
+
+ addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */)
if err == nil {
// We have no need for the address endpoint.
addressEndpoint.DecRef()
@@ -416,7 +475,12 @@ func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo
func (n *NIC) allPermanentAddresses() []tcpip.ProtocolAddress {
var addrs []tcpip.ProtocolAddress
for p, ep := range n.networkEndpoints {
- for _, a := range ep.PermanentAddresses() {
+ addressableEndpoint, ok := ep.(AddressableEndpoint)
+ if !ok {
+ continue
+ }
+
+ for _, a := range addressableEndpoint.PermanentAddresses() {
addrs = append(addrs, tcpip.ProtocolAddress{Protocol: p, AddressWithPrefix: a})
}
}
@@ -427,7 +491,12 @@ func (n *NIC) allPermanentAddresses() []tcpip.ProtocolAddress {
func (n *NIC) primaryAddresses() []tcpip.ProtocolAddress {
var addrs []tcpip.ProtocolAddress
for p, ep := range n.networkEndpoints {
- for _, a := range ep.PrimaryAddresses() {
+ addressableEndpoint, ok := ep.(AddressableEndpoint)
+ if !ok {
+ continue
+ }
+
+ for _, a := range addressableEndpoint.PrimaryAddresses() {
addrs = append(addrs, tcpip.ProtocolAddress{Protocol: p, AddressWithPrefix: a})
}
}
@@ -445,13 +514,23 @@ func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWit
return tcpip.AddressWithPrefix{}
}
- return ep.MainAddress()
+ addressableEndpoint, ok := ep.(AddressableEndpoint)
+ if !ok {
+ return tcpip.AddressWithPrefix{}
+ }
+
+ return addressableEndpoint.MainAddress()
}
// removeAddress removes an address from n.
func (n *NIC) removeAddress(addr tcpip.Address) *tcpip.Error {
for _, ep := range n.networkEndpoints {
- if err := ep.RemovePermanentAddress(addr); err == tcpip.ErrBadLocalAddress {
+ addressableEndpoint, ok := ep.(AddressableEndpoint)
+ if !ok {
+ continue
+ }
+
+ if err := addressableEndpoint.RemovePermanentAddress(addr); err == tcpip.ErrBadLocalAddress {
continue
} else {
return err
@@ -469,14 +548,6 @@ func (n *NIC) neighbors() ([]NeighborEntry, *tcpip.Error) {
return n.neigh.entries(), nil
}
-func (n *NIC) removeWaker(addr tcpip.Address, w *sleep.Waker) {
- if n.neigh == nil {
- return
- }
-
- n.neigh.removeWaker(addr, w)
-}
-
func (n *NIC) addStaticNeighbor(addr tcpip.Address, linkAddress tcpip.LinkAddress) *tcpip.Error {
if n.neigh == nil {
return tcpip.ErrNotSupported
@@ -524,8 +595,7 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address
return tcpip.ErrNotSupported
}
- _, err := gep.JoinGroup(addr)
- return err
+ return gep.JoinGroup(addr)
}
// leaveGroup decrements the count for the given multicast address, and when it
@@ -541,11 +611,7 @@ func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Addres
return tcpip.ErrNotSupported
}
- if _, err := gep.LeaveGroup(addr); err != nil {
- return err
- }
-
- return nil
+ return gep.LeaveGroup(addr)
}
// isInGroup returns true if n has joined the multicast group addr.
@@ -564,13 +630,6 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool {
return false
}
-func (n *NIC) handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, remotelinkAddr tcpip.LinkAddress, addressEndpoint AssignableAddressEndpoint, pkt *PacketBuffer) {
- r := makeRoute(protocol, dst, src, n, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */)
- defer r.Release()
- r.PopulatePacketInfo(pkt)
- n.getNetworkEndpoint(protocol).HandlePacket(pkt)
-}
-
// DeliverNetworkPacket finds the appropriate network protocol endpoint and
// hands the packet over for further processing. This function is called when
// the NIC receives a packet from the link endpoint.
@@ -592,7 +651,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
n.stats.Rx.Packets.Increment()
n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data.Size()))
- netProto, ok := n.stack.networkProtocols[protocol]
+ networkEndpoint, ok := n.networkEndpoints[protocol]
if !ok {
n.mu.RUnlock()
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
@@ -607,21 +666,26 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
pkt.RXTransportChecksumValidated = n.LinkEndpoint.Capabilities()&CapabilityRXChecksumOffload != 0
// Are any packet type sockets listening for this network protocol?
- packetEPs := n.mu.packetEPs[protocol]
- // Add any other packet type sockets that may be listening for all protocols.
- packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...)
+ protoEPs := n.mu.packetEPs[protocol]
+ // Other packet type sockets that are listening for all protocols.
+ anyEPs := n.mu.packetEPs[header.EthernetProtocolAll]
n.mu.RUnlock()
- for _, ep := range packetEPs {
+
+ // Deliver to interested packet endpoints without holding NIC lock.
+ deliverPacketEPs := func(ep PacketEndpoint) {
p := pkt.Clone()
p.PktType = tcpip.PacketHost
ep.HandlePacket(n.id, local, protocol, p)
}
-
- if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber {
- n.stack.stats.IP.PacketsReceived.Increment()
+ if protoEPs != nil {
+ protoEPs.forEach(deliverPacketEPs)
+ }
+ if anyEPs != nil {
+ anyEPs.forEach(deliverPacketEPs)
}
// Parse headers.
+ netProto := n.stack.NetworkProtocolInstance(protocol)
transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt)
if !ok {
// The packet is too small to contain a network header.
@@ -636,9 +700,8 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
}
}
- src, dst := netProto.ParseAddresses(pkt.NetworkHeader().View())
-
if n.stack.handleLocal && !n.IsLoopback() {
+ src, _ := netProto.ParseAddresses(pkt.NetworkHeader().View())
if r := n.getAddress(protocol, src); r != nil {
r.DecRef()
@@ -651,78 +714,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
}
}
- // Loopback traffic skips the prerouting chain.
- if !n.IsLoopback() {
- // iptables filtering.
- ipt := n.stack.IPTables()
- address := n.primaryAddress(protocol)
- if ok := ipt.Check(Prerouting, pkt, nil, nil, address.Address, ""); !ok {
- // iptables is telling us to drop the packet.
- n.stack.stats.IP.IPTablesPreroutingDropped.Increment()
- return
- }
- }
-
- if addressEndpoint := n.getAddress(protocol, dst); addressEndpoint != nil {
- n.handlePacket(protocol, dst, src, remote, addressEndpoint, pkt)
- return
- }
-
- // This NIC doesn't care about the packet. Find a NIC that cares about the
- // packet and forward it to the NIC.
- //
- // TODO: Should we be forwarding the packet even if promiscuous?
- if n.stack.Forwarding(protocol) {
- r, err := n.stack.FindRoute(0, "", dst, protocol, false /* multicastLoop */)
- if err != nil {
- n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment()
- return
- }
-
- // Found a NIC.
- n := r.localAddressNIC
- if addressEndpoint := n.getAddressOrCreateTempInner(protocol, dst, false, NeverPrimaryEndpoint); addressEndpoint != nil {
- if n.isValidForOutgoing(addressEndpoint) {
- pkt.NICID = n.ID()
- r.RemoteAddress = src
- pkt.NetworkPacketInfo = r.networkPacketInfo()
- n.getNetworkEndpoint(protocol).HandlePacket(pkt)
- addressEndpoint.DecRef()
- r.Release()
- return
- }
-
- addressEndpoint.DecRef()
- }
-
- // n doesn't have a destination endpoint.
- // Send the packet out of n.
- // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease
- // the TTL field for ipv4/ipv6.
-
- // pkt may have set its header and may not have enough headroom for
- // link-layer header for the other link to prepend. Here we create a new
- // packet to forward.
- fwdPkt := NewPacketBuffer(PacketBufferOptions{
- ReserveHeaderBytes: int(n.LinkEndpoint.MaxHeaderLength()),
- // We need to do a deep copy of the IP packet because WritePacket (and
- // friends) take ownership of the packet buffer, but we do not own it.
- Data: PayloadSince(pkt.NetworkHeader()).ToVectorisedView(),
- })
-
- // TODO(b/143425874) Decrease the TTL field in forwarded packets.
- if err := n.WritePacket(&r, nil, protocol, fwdPkt); err != nil {
- n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment()
- }
-
- r.Release()
- return
- }
-
- // If a packet socket handled the packet, don't treat it as invalid.
- if len(packetEPs) == 0 {
- n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment()
- }
+ networkEndpoint.HandlePacket(pkt)
}
// DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket.
@@ -731,16 +723,17 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc
// We do not deliver to protocol specific packet endpoints as on Linux
// only ETH_P_ALL endpoints get outbound packets.
// Add any other packet sockets that maybe listening for all protocols.
- packetEPs := n.mu.packetEPs[header.EthernetProtocolAll]
+ eps := n.mu.packetEPs[header.EthernetProtocolAll]
n.mu.RUnlock()
- for _, ep := range packetEPs {
+
+ eps.forEach(func(ep PacketEndpoint) {
p := pkt.Clone()
p.PktType = tcpip.PacketOutgoing
// Add the link layer header as outgoing packets are intercepted
// before the link layer header is created.
n.LinkEndpoint.AddHeader(local, remote, protocol, p)
ep.HandlePacket(n.id, local, protocol, p)
- }
+ })
}
// DeliverTransportPacket delivers the packets to the appropriate transport
@@ -893,7 +886,7 @@ func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep Pa
if !ok {
return tcpip.ErrNotSupported
}
- n.mu.packetEPs[netProto] = append(eps, ep)
+ eps.add(ep)
return nil
}
@@ -906,13 +899,7 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep
if !ok {
return
}
-
- for i, epOther := range eps {
- if epOther == ep {
- n.mu.packetEPs[netProto] = append(eps[:i], eps[i+1:]...)
- return
- }
- }
+ eps.remove(ep)
}
// isValidForOutgoing returns true if the endpoint can be used to send out a
diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go
index ab629b3a4..12d67409a 100644
--- a/pkg/tcpip/stack/nud.go
+++ b/pkg/tcpip/stack/nud.go
@@ -109,14 +109,6 @@ const (
//
// Default taken from MAX_NEIGHBOR_ADVERTISEMENT of RFC 4861 section 10.
defaultMaxReachbilityConfirmations = 3
-
- // defaultUnreachableTime is the default duration for how long an entry will
- // remain in the FAILED state before being removed from the neighbor cache.
- //
- // Note, there is no equivalent protocol constant defined in RFC 4861. It
- // leaves the specifics of any garbage collection mechanism up to the
- // implementation.
- defaultUnreachableTime = 5 * time.Second
)
// NUDDispatcher is the interface integrators of netstack must implement to
@@ -278,10 +270,6 @@ type NUDConfigurations struct {
// TODO(gvisor.dev/issue/2246): Discuss if implementation of this NUD
// configuration option is necessary.
MaxReachabilityConfirmations uint32
-
- // UnreachableTime describes how long an entry will remain in the FAILED
- // state before being removed from the neighbor cache.
- UnreachableTime time.Duration
}
// DefaultNUDConfigurations returns a NUDConfigurations populated with default
@@ -299,7 +287,6 @@ func DefaultNUDConfigurations() NUDConfigurations {
MaxUnicastProbes: defaultMaxUnicastProbes,
MaxAnycastDelayTime: defaultMaxAnycastDelayTime,
MaxReachabilityConfirmations: defaultMaxReachbilityConfirmations,
- UnreachableTime: defaultUnreachableTime,
}
}
@@ -329,9 +316,6 @@ func (c *NUDConfigurations) resetInvalidFields() {
if c.MaxUnicastProbes == 0 {
c.MaxUnicastProbes = defaultMaxUnicastProbes
}
- if c.UnreachableTime == 0 {
- c.UnreachableTime = defaultUnreachableTime
- }
}
// calcMaxRandomFactor calculates the maximum value of the random factor used
@@ -416,7 +400,7 @@ func (s *NUDState) ReachableTime() time.Duration {
s.config.BaseReachableTime != s.prevBaseReachableTime ||
s.config.MinRandomFactor != s.prevMinRandomFactor ||
s.config.MaxRandomFactor != s.prevMaxRandomFactor {
- return s.recomputeReachableTimeLocked()
+ s.recomputeReachableTimeLocked()
}
return s.reachableTime
}
@@ -442,7 +426,7 @@ func (s *NUDState) ReachableTime() time.Duration {
// random value gets re-computed at least once every few hours.
//
// s.mu MUST be locked for writing.
-func (s *NUDState) recomputeReachableTimeLocked() time.Duration {
+func (s *NUDState) recomputeReachableTimeLocked() {
s.prevBaseReachableTime = s.config.BaseReachableTime
s.prevMinRandomFactor = s.config.MinRandomFactor
s.prevMaxRandomFactor = s.config.MaxRandomFactor
@@ -462,5 +446,4 @@ func (s *NUDState) recomputeReachableTimeLocked() time.Duration {
}
s.expiration = time.Now().Add(2 * time.Hour)
- return s.reachableTime
}
diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go
index 8cffb9fc6..7bca1373e 100644
--- a/pkg/tcpip/stack/nud_test.go
+++ b/pkg/tcpip/stack/nud_test.go
@@ -37,7 +37,6 @@ const (
defaultMaxUnicastProbes = 3
defaultMaxAnycastDelayTime = time.Second
defaultMaxReachbilityConfirmations = 3
- defaultUnreachableTime = 5 * time.Second
defaultFakeRandomNum = 0.5
)
@@ -565,58 +564,6 @@ func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) {
}
}
-func TestNUDConfigurationsUnreachableTime(t *testing.T) {
- tests := []struct {
- name string
- unreachableTime time.Duration
- want time.Duration
- }{
- // Invalid cases
- {
- name: "EqualToZero",
- unreachableTime: 0,
- want: defaultUnreachableTime,
- },
- // Valid cases
- {
- name: "MoreThanZero",
- unreachableTime: time.Millisecond,
- want: time.Millisecond,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- const nicID = 1
-
- c := stack.DefaultNUDConfigurations()
- c.UnreachableTime = test.unreachableTime
-
- e := channel.New(0, 1280, linkAddr1)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
-
- s := stack.New(stack.Options{
- // A neighbor cache is required to store NUDConfigurations. The
- // networking stack will only allocate neighbor caches if a protocol
- // providing link address resolution is specified (e.g. ARP or IPv6).
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NUDConfigs: c,
- UseNeighborCache: true,
- })
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- sc, err := s.NUDConfigurations(nicID)
- if err != nil {
- t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
- }
- if got := sc.UnreachableTime; got != test.want {
- t.Errorf("got UnreachableTime = %q, want = %q", got, test.want)
- }
- })
- }
-}
-
// TestNUDStateReachableTime verifies the correctness of the ReachableTime
// computation.
func TestNUDStateReachableTime(t *testing.T) {
diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go
index 5d364a2b0..4a3adcf33 100644
--- a/pkg/tcpip/stack/pending_packets.go
+++ b/pkg/tcpip/stack/pending_packets.go
@@ -103,7 +103,7 @@ func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, pro
for _, p := range packets {
if cancelled {
p.route.Stats().IP.OutgoingPacketErrors.Increment()
- } else if _, err := p.route.Resolve(nil); err != nil {
+ } else if p.route.IsResolutionRequired() {
p.route.Stats().IP.OutgoingPacketErrors.Increment()
} else {
p.route.outgoingNIC.writePacket(p.route, nil /* gso */, p.proto, p.pkt)
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 00e9a82ae..4795208b4 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -17,7 +17,6 @@ package stack
import (
"fmt"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -65,10 +64,6 @@ const (
// NetworkPacketInfo holds information about a network layer packet.
type NetworkPacketInfo struct {
- // RemoteAddressBroadcast is true if the packet's remote address is a
- // broadcast address.
- RemoteAddressBroadcast bool
-
// LocalAddressBroadcast is true if the packet's local address is a broadcast
// address.
LocalAddressBroadcast bool
@@ -89,7 +84,7 @@ type TransportEndpoint interface {
// HandleControlPacket is called by the stack when new control (e.g.
// ICMP) packets arrive to this transport endpoint.
// HandleControlPacket takes ownership of pkt.
- HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt *PacketBuffer)
+ HandleControlPacket(typ ControlType, extra uint32, pkt *PacketBuffer)
// Abort initiates an expedited endpoint teardown. It puts the endpoint
// in a closed state and frees all resources associated with it. This
@@ -263,15 +258,6 @@ const (
PacketLoop
)
-// NetOptions is an interface that allows us to pass network protocol specific
-// options through the Stack layer code.
-type NetOptions interface {
- // AllocationSize returns the amount of memory that must be allocated to
- // hold the options given that the value must be rounded up to the next
- // multiple of 4 bytes.
- AllocationSize() int
-}
-
// NetworkHeaderParams are the header parameters given as input by the
// transport endpoint to the network.
type NetworkHeaderParams struct {
@@ -283,10 +269,6 @@ type NetworkHeaderParams struct {
// TOS refers to TypeOfService or TrafficClass field of the IP-header.
TOS uint8
-
- // Options is a set of options to add to a network header (or nil).
- // It will be protocol specific opaque information from higher layers.
- Options NetOptions
}
// GroupAddressableEndpoint is an endpoint that supports group addressing.
@@ -295,14 +277,10 @@ type NetworkHeaderParams struct {
// endpoints may associate themselves with the same identifier (group address).
type GroupAddressableEndpoint interface {
// JoinGroup joins the specified group.
- //
- // Returns true if the group was newly joined.
- JoinGroup(group tcpip.Address) (bool, *tcpip.Error)
+ JoinGroup(group tcpip.Address) *tcpip.Error
// LeaveGroup attempts to leave the specified group.
- //
- // Returns tcpip.ErrBadLocalAddress if the endpoint has not joined the group.
- LeaveGroup(group tcpip.Address) (bool, *tcpip.Error)
+ LeaveGroup(group tcpip.Address) *tcpip.Error
// IsInGroup returns true if the endpoint is a member of the specified group.
IsInGroup(group tcpip.Address) bool
@@ -518,6 +496,9 @@ type NetworkInterface interface {
// Enabled returns true if the interface is enabled.
Enabled() bool
+ // Promiscuous returns true if the interface is in promiscuous mode.
+ Promiscuous() bool
+
// WritePacketToRemote writes the packet to the given remote link address.
WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error
}
@@ -525,8 +506,6 @@ type NetworkInterface interface {
// NetworkEndpoint is the interface that needs to be implemented by endpoints
// of network layer protocols (e.g., ipv4, ipv6).
type NetworkEndpoint interface {
- AddressableEndpoint
-
// Enable enables the endpoint.
//
// Must only be called when the stack is in a state that allows the endpoint
@@ -742,10 +721,6 @@ type LinkEndpoint interface {
// endpoint.
Capabilities() LinkEndpointCapabilities
- // WriteRawPacket writes a packet directly to the link. The packet
- // should already have an ethernet header. It takes ownership of vv.
- WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error
-
// Attach attaches the data link layer endpoint to the network-layer
// dispatcher of the stack.
//
@@ -823,19 +798,26 @@ type LinkAddressCache interface {
// AddLinkAddress adds a link address to the cache.
AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress)
- // GetLinkAddress looks up the cache to translate address to link address (e.g. IP -> MAC).
- // If the LinkEndpoint requests address resolution and there is a LinkAddressResolver
- // registered with the network protocol, the cache attempts to resolve the address
- // and returns ErrWouldBlock. Waker is notified when address resolution is
- // complete (success or not).
+ // GetLinkAddress finds the link address corresponding to the remote address
+ // (e.g. IP -> MAC).
//
- // If address resolution is required, ErrNoLinkAddress and a notification channel is
- // returned for the top level caller to block. Channel is closed once address resolution
- // is complete (success or not).
- GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error)
-
- // RemoveWaker removes a waker that has been added in GetLinkAddress().
- RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker)
+ // Returns a link address for the remote address, if readily available.
+ //
+ // Returns ErrWouldBlock if the link address is not readily available, along
+ // with a notification channel for the caller to block on. Triggers address
+ // resolution asynchronously.
+ //
+ // If onResolve is provided, it will be called either immediately, if
+ // resolution is not required, or when address resolution is complete, with
+ // the resolved link address and whether resolution succeeded. After any
+ // callbacks have been called, the returned notification channel is closed.
+ //
+ // If specified, the local address must be an address local to the interface
+ // the neighbor cache belongs to. The local address is the source address of
+ // a packet prompting NUD/link address resolution.
+ //
+ // TODO(gvisor.dev/issue/5151): Don't return the link address.
+ GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error)
}
// RawFactory produces endpoints for writing various types of raw packets.
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 15ff437c7..b0251d0b4 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -17,20 +17,53 @@ package stack
import (
"fmt"
- "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
// Route represents a route through the networking stack to a given destination.
+//
+// It is safe to call Route's methods from multiple goroutines.
+//
+// The exported fields are immutable.
+//
+// TODO(gvisor.dev/issue/4902): Unexpose immutable fields.
type Route struct {
+ routeInfo
+
+ // localAddressNIC is the interface the address is associated with.
+ // TODO(gvisor.dev/issue/4548): Remove this field once we can query the
+ // address's assigned status without the NIC.
+ localAddressNIC *NIC
+
+ mu struct {
+ sync.RWMutex
+
+ // localAddressEndpoint is the local address this route is associated with.
+ localAddressEndpoint AssignableAddressEndpoint
+
+ // remoteLinkAddress is the link-layer (MAC) address of the next hop in the
+ // route.
+ remoteLinkAddress tcpip.LinkAddress
+ }
+
+ // outgoingNIC is the interface this route uses to write packets.
+ outgoingNIC *NIC
+
+ // linkCache is set if link address resolution is enabled for this protocol on
+ // the route's NIC.
+ linkCache LinkAddressCache
+
+ // linkRes is set if link address resolution is enabled for this protocol on
+ // the route's NIC.
+ linkRes LinkAddressResolver
+}
+
+type routeInfo struct {
// RemoteAddress is the final destination of the route.
RemoteAddress tcpip.Address
- // RemoteLinkAddress is the link-layer (MAC) address of the
- // final destination of the route.
- RemoteLinkAddress tcpip.LinkAddress
-
// LocalAddress is the local address where the route starts.
LocalAddress tcpip.Address
@@ -46,47 +79,48 @@ type Route struct {
// Loop controls where WritePacket should send packets.
Loop PacketLooping
+}
- // localAddressNIC is the interface the address is associated with.
- // TODO(gvisor.dev/issue/4548): Remove this field once we can query the
- // address's assigned status without the NIC.
- localAddressNIC *NIC
-
- // localAddressEndpoint is the local address this route is associated with.
- localAddressEndpoint AssignableAddressEndpoint
-
- // outgoingNIC is the interface this route uses to write packets.
- outgoingNIC *NIC
+// RouteInfo contains all of Route's exported fields.
+type RouteInfo struct {
+ routeInfo
- // linkCache is set if link address resolution is enabled for this protocol on
- // the route's NIC.
- linkCache LinkAddressCache
+ // RemoteLinkAddress is the link-layer (MAC) address of the next hop in the
+ // route.
+ RemoteLinkAddress tcpip.LinkAddress
+}
- // linkRes is set if link address resolution is enabled for this protocol on
- // the route's NIC.
- linkRes LinkAddressResolver
+// GetFields returns a RouteInfo with all of r's exported fields. This allows
+// callers to store the route's fields without retaining a reference to it.
+func (r *Route) GetFields() RouteInfo {
+ return RouteInfo{
+ routeInfo: r.routeInfo,
+ RemoteLinkAddress: r.RemoteLinkAddress(),
+ }
}
// constructAndValidateRoute validates and initializes a route. It takes
// ownership of the provided local address.
//
// Returns an empty route if validation fails.
-func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) Route {
- addrWithPrefix := addressEndpoint.AddressWithPrefix()
+func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, localAddr, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) *Route {
+ if len(localAddr) == 0 {
+ localAddr = addressEndpoint.AddressWithPrefix().Address
+ }
- if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(addrWithPrefix.Address) {
+ if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(localAddr) {
addressEndpoint.DecRef()
- return Route{}
+ return nil
}
// If no remote address is provided, use the local address.
if len(remoteAddr) == 0 {
- remoteAddr = addrWithPrefix.Address
+ remoteAddr = localAddr
}
r := makeRoute(
netProto,
- addrWithPrefix.Address,
+ localAddr,
remoteAddr,
outgoingNIC,
localAddressNIC,
@@ -99,8 +133,8 @@ func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndp
// broadcast it.
if len(gateway) > 0 {
r.NextHop = gateway
- } else if subnet := addrWithPrefix.Subnet(); subnet.IsBroadcast(remoteAddr) {
- r.RemoteLinkAddress = header.EthernetBroadcastAddress
+ } else if subnet := addressEndpoint.Subnet(); subnet.IsBroadcast(remoteAddr) {
+ r.ResolveWith(header.EthernetBroadcastAddress)
}
return r
@@ -108,11 +142,15 @@ func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndp
// makeRoute initializes a new route. It takes ownership of the provided
// AssignableAddressEndpoint.
-func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route {
+func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) *Route {
if localAddressNIC.stack != outgoingNIC.stack {
panic(fmt.Sprintf("cannot create a route with NICs from different stacks"))
}
+ if len(localAddr) == 0 {
+ localAddr = localAddressEndpoint.AddressWithPrefix().Address
+ }
+
loop := PacketOut
// TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the
@@ -133,18 +171,23 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip
return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop)
}
-func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) Route {
- r := Route{
- NetProto: netProto,
- LocalAddress: localAddr,
- LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(),
- RemoteAddress: remoteAddr,
- localAddressNIC: localAddressNIC,
- localAddressEndpoint: localAddressEndpoint,
- outgoingNIC: outgoingNIC,
- Loop: loop,
+func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) *Route {
+ r := &Route{
+ routeInfo: routeInfo{
+ NetProto: netProto,
+ LocalAddress: localAddr,
+ LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(),
+ RemoteAddress: remoteAddr,
+ Loop: loop,
+ },
+ localAddressNIC: localAddressNIC,
+ outgoingNIC: outgoingNIC,
}
+ r.mu.Lock()
+ r.mu.localAddressEndpoint = localAddressEndpoint
+ r.mu.Unlock()
+
if r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 {
if linkRes, ok := r.outgoingNIC.stack.linkAddrResolvers[r.NetProto]; ok {
r.linkRes = linkRes
@@ -159,7 +202,7 @@ func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr
// provided AssignableAddressEndpoint.
//
// A local route is a route to a destination that is local to the stack.
-func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint) Route {
+func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint) *Route {
loop := PacketLoop
// TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the
// link endpoint level. We can remove this check once loopback interfaces
@@ -170,26 +213,12 @@ func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr
return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop)
}
-// PopulatePacketInfo populates a packet buffer's packet information fields.
-//
-// TODO(gvisor.dev/issue/4688): Remove this once network packets are handled by
-// the network layer.
-func (r *Route) PopulatePacketInfo(pkt *PacketBuffer) {
- if r.local() {
- pkt.RXTransportChecksumValidated = true
- }
- pkt.NetworkPacketInfo = r.networkPacketInfo()
-}
-
-// networkPacketInfo returns the network packet information of the route.
-//
-// TODO(gvisor.dev/issue/4688): Remove this once network packets are handled by
-// the network layer.
-func (r *Route) networkPacketInfo() NetworkPacketInfo {
- return NetworkPacketInfo{
- RemoteAddressBroadcast: r.IsOutboundBroadcast(),
- LocalAddressBroadcast: r.isInboundBroadcast(),
- }
+// RemoteLinkAddress returns the link-layer (MAC) address of the next hop in
+// the route.
+func (r *Route) RemoteLinkAddress() tcpip.LinkAddress {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ return r.mu.remoteLinkAddress
}
// NICID returns the id of the NIC from which this route originates.
@@ -253,22 +282,26 @@ func (r *Route) GSOMaxSize() uint32 {
// ResolveWith immediately resolves a route with the specified remote link
// address.
func (r *Route) ResolveWith(addr tcpip.LinkAddress) {
- r.RemoteLinkAddress = addr
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.mu.remoteLinkAddress = addr
}
-// Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in
-// case address resolution requires blocking, e.g. wait for ARP reply. Waker is
-// notified when address resolution is complete (success or not).
-//
-// If address resolution is required, ErrNoLinkAddress and a notification channel is
-// returned for the top level caller to block. Channel is closed once address resolution
-// is complete (success or not).
+// Resolve attempts to resolve the link address if necessary.
//
-// The NIC r uses must not be locked.
-func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
- if !r.IsResolutionRequired() {
+// Returns tcpip.ErrWouldBlock if address resolution requires blocking (e.g.
+// waiting for ARP reply). If address resolution is required, a notification
+// channel is also returned for the caller to block on. The channel is closed
+// once address resolution is complete (successful or not). If a callback is
+// provided, it will be called when address resolution is complete, regardless
+// of success or failure.
+func (r *Route) Resolve(afterResolve func()) (<-chan struct{}, *tcpip.Error) {
+ r.mu.Lock()
+
+ if !r.isResolutionRequiredRLocked() {
// Nothing to do if there is no cache (which does the resolution on cache miss) or
// link address is already known.
+ r.mu.Unlock()
return nil, nil
}
@@ -276,7 +309,8 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
if nextAddr == "" {
// Local link address is already known.
if r.RemoteAddress == r.LocalAddress {
- r.RemoteLinkAddress = r.LocalLinkAddress
+ r.mu.remoteLinkAddress = r.LocalLinkAddress
+ r.mu.Unlock()
return nil, nil
}
nextAddr = r.RemoteAddress
@@ -289,38 +323,36 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
linkAddressResolutionRequestLocalAddr = r.LocalAddress
}
+ // Increment the route's reference count because finishResolution retains a
+ // reference to the route and releases it when called.
+ r.acquireLocked()
+ r.mu.Unlock()
+
+ finishResolution := func(linkAddress tcpip.LinkAddress, ok bool) {
+ if ok {
+ r.ResolveWith(linkAddress)
+ }
+ if afterResolve != nil {
+ afterResolve()
+ }
+ r.Release()
+ }
+
if neigh := r.outgoingNIC.neigh; neigh != nil {
- entry, ch, err := neigh.entry(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, waker)
+ _, ch, err := neigh.entry(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, finishResolution)
if err != nil {
return ch, err
}
- r.RemoteLinkAddress = entry.LinkAddr
return nil, nil
}
- linkAddr, ch, err := r.linkCache.GetLinkAddress(r.outgoingNIC.ID(), nextAddr, linkAddressResolutionRequestLocalAddr, r.NetProto, waker)
+ _, ch, err := r.linkCache.GetLinkAddress(r.outgoingNIC.ID(), nextAddr, linkAddressResolutionRequestLocalAddr, r.NetProto, finishResolution)
if err != nil {
return ch, err
}
- r.RemoteLinkAddress = linkAddr
return nil, nil
}
-// RemoveWaker removes a waker that has been added in Resolve().
-func (r *Route) RemoveWaker(waker *sleep.Waker) {
- nextAddr := r.NextHop
- if nextAddr == "" {
- nextAddr = r.RemoteAddress
- }
-
- if neigh := r.outgoingNIC.neigh; neigh != nil {
- neigh.removeWaker(nextAddr, waker)
- return
- }
-
- r.linkCache.RemoveWaker(r.outgoingNIC.ID(), nextAddr, waker)
-}
-
// local returns true if the route is a local route.
func (r *Route) local() bool {
return r.Loop == PacketLoop || r.outgoingNIC.IsLoopback()
@@ -331,7 +363,13 @@ func (r *Route) local() bool {
//
// The NICs the route is associated with must not be locked.
func (r *Route) IsResolutionRequired() bool {
- if !r.isValidForOutgoing() || r.RemoteLinkAddress != "" || r.local() {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ return r.isResolutionRequiredRLocked()
+}
+
+func (r *Route) isResolutionRequiredRLocked() bool {
+ if !r.isValidForOutgoingRLocked() || r.mu.remoteLinkAddress != "" || r.local() {
return false
}
@@ -339,11 +377,18 @@ func (r *Route) IsResolutionRequired() bool {
}
func (r *Route) isValidForOutgoing() bool {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ return r.isValidForOutgoingRLocked()
+}
+
+func (r *Route) isValidForOutgoingRLocked() bool {
if !r.outgoingNIC.Enabled() {
return false
}
- if !r.localAddressNIC.isValidForOutgoing(r.localAddressEndpoint) {
+ localAddressEndpoint := r.mu.localAddressEndpoint
+ if localAddressEndpoint == nil || !r.localAddressNIC.isValidForOutgoing(localAddressEndpoint) {
return false
}
@@ -395,39 +440,31 @@ func (r *Route) MTU() uint32 {
return r.outgoingNIC.getNetworkEndpoint(r.NetProto).MTU()
}
-// Release frees all resources associated with the route.
+// Release decrements the reference counter of the resources associated with the
+// route.
func (r *Route) Release() {
- if r.localAddressEndpoint != nil {
- r.localAddressEndpoint.DecRef()
- r.localAddressEndpoint = nil
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if ep := r.mu.localAddressEndpoint; ep != nil {
+ ep.DecRef()
}
}
-// Clone clones the route.
-func (r *Route) Clone() Route {
- if r.localAddressEndpoint != nil {
- if !r.localAddressEndpoint.IncRef() {
+// Acquire increments the reference counter of the resources associated with the
+// route.
+func (r *Route) Acquire() {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ r.acquireLocked()
+}
+
+func (r *Route) acquireLocked() {
+ if ep := r.mu.localAddressEndpoint; ep != nil {
+ if !ep.IncRef() {
panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", r.LocalAddress))
}
}
- return *r
-}
-
-// MakeLoopedRoute duplicates the given route with special handling for routes
-// used for sending multicast or broadcast packets. In those cases the
-// multicast/broadcast address is the remote address when sending out, but for
-// incoming (looped) packets it becomes the local address. Similarly, the local
-// interface address that was the local address going out becomes the remote
-// address coming in. This is different to unicast routes where local and
-// remote addresses remain the same as they identify location (local vs remote)
-// not direction (source vs destination).
-func (r *Route) MakeLoopedRoute() Route {
- l := r.Clone()
- if r.RemoteAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(r.RemoteAddress) || header.IsV6MulticastAddress(r.RemoteAddress) {
- l.RemoteAddress, l.LocalAddress = l.LocalAddress, l.RemoteAddress
- l.RemoteLinkAddress = l.LocalLinkAddress
- }
- return l
}
// Stack returns the instance of the Stack that owns this route.
@@ -440,7 +477,14 @@ func (r *Route) isV4Broadcast(addr tcpip.Address) bool {
return true
}
- subnet := r.localAddressEndpoint.Subnet()
+ r.mu.RLock()
+ localAddressEndpoint := r.mu.localAddressEndpoint
+ r.mu.RUnlock()
+ if localAddressEndpoint == nil {
+ return false
+ }
+
+ subnet := localAddressEndpoint.Subnet()
return subnet.IsBroadcast(addr)
}
@@ -450,27 +494,3 @@ func (r *Route) IsOutboundBroadcast() bool {
// Only IPv4 has a notion of broadcast.
return r.isV4Broadcast(r.RemoteAddress)
}
-
-// isInboundBroadcast returns true if the route is for an inbound broadcast
-// packet.
-func (r *Route) isInboundBroadcast() bool {
- // Only IPv4 has a notion of broadcast.
- return r.isV4Broadcast(r.LocalAddress)
-}
-
-// ReverseRoute returns new route with given source and destination address.
-func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route {
- return Route{
- NetProto: r.NetProto,
- LocalAddress: dst,
- LocalLinkAddress: r.RemoteLinkAddress,
- RemoteAddress: src,
- RemoteLinkAddress: r.LocalLinkAddress,
- Loop: r.Loop,
- localAddressNIC: r.localAddressNIC,
- localAddressEndpoint: r.localAddressEndpoint,
- outgoingNIC: r.outgoingNIC,
- linkCache: r.linkCache,
- linkRes: r.linkRes,
- }
-}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 0fe157128..114643b03 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -29,7 +29,6 @@ import (
"golang.org/x/time/rate"
"gvisor.dev/gvisor/pkg/rand"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -82,6 +81,7 @@ type TCPRACKState struct {
FACK seqnum.Value
RTT time.Duration
Reord bool
+ DSACKSeen bool
}
// TCPEndpointID is the unique 4 tuple that identifies a given endpoint.
@@ -170,6 +170,9 @@ type TCPSenderState struct {
// Outstanding is the number of packets in flight.
Outstanding int
+ // SackedOut is the number of packets which have been selectively acked.
+ SackedOut int
+
// SndWnd is the send window size in bytes.
SndWnd seqnum.Size
@@ -1080,7 +1083,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
flags := NICStateFlags{
Up: true, // Netstack interfaces are always up.
Running: nic.Enabled(),
- Promiscuous: nic.isPromiscuousMode(),
+ Promiscuous: nic.Promiscuous(),
Loopback: nic.IsLoopback(),
}
nics[id] = NICInfo{
@@ -1117,6 +1120,16 @@ func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber,
return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint)
}
+// AddAddressWithPrefix is the same as AddAddress, but allows you to specify
+// the address prefix.
+func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.AddressWithPrefix) *tcpip.Error {
+ ap := tcpip.ProtocolAddress{
+ Protocol: protocol,
+ AddressWithPrefix: addr,
+ }
+ return s.AddProtocolAddressWithOptions(id, ap, CanBePrimaryEndpoint)
+}
+
// AddProtocolAddress adds a new network-layer protocol address to the
// specified NIC.
func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress) *tcpip.Error {
@@ -1207,10 +1220,10 @@ func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netP
// from the specified NIC.
//
// Precondition: s.mu must be read locked.
-func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) {
+func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) *Route {
localAddressEndpoint := localAddressNIC.getAddressOrCreateTempInner(netProto, localAddr, false /* createTemp */, NeverPrimaryEndpoint)
if localAddressEndpoint == nil {
- return Route{}, false
+ return nil
}
var outgoingNIC *NIC
@@ -1234,12 +1247,12 @@ func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, re
// route.
if outgoingNIC == nil {
localAddressEndpoint.DecRef()
- return Route{}, false
+ return nil
}
r := makeLocalRoute(
netProto,
- localAddressEndpoint.AddressWithPrefix().Address,
+ localAddr,
remoteAddr,
outgoingNIC,
localAddressNIC,
@@ -1248,10 +1261,10 @@ func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, re
if r.IsOutboundBroadcast() {
r.Release()
- return Route{}, false
+ return nil
}
- return r, true
+ return r
}
// findLocalRouteRLocked returns a local route.
@@ -1260,26 +1273,26 @@ func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, re
// is, a local route is a route where packets never have to leave the stack.
//
// Precondition: s.mu must be read locked.
-func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) {
+func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) *Route {
if len(localAddr) == 0 {
localAddr = remoteAddr
}
if localAddressNICID == 0 {
for _, localAddressNIC := range s.nics {
- if r, ok := s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto); ok {
- return r, true
+ if r := s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto); r != nil {
+ return r
}
}
- return Route{}, false
+ return nil
}
if localAddressNIC, ok := s.nics[localAddressNICID]; ok {
return s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto)
}
- return Route{}, false
+ return nil
}
// FindRoute creates a route to the given destination address, leaving through
@@ -1293,7 +1306,7 @@ func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr,
// If no local address is provided, the stack will select a local address. If no
// remote address is provided, the stack wil use a remote address equal to the
// local address.
-func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (Route, *tcpip.Error) {
+func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (*Route, *tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -1304,7 +1317,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
needRoute := !(isLocalBroadcast || isMulticast || isLinkLocal || isLoopback)
if s.handleLocal && !isMulticast && !isLocalBroadcast {
- if r, ok := s.findLocalRouteRLocked(id, localAddr, remoteAddr, netProto); ok {
+ if r := s.findLocalRouteRLocked(id, localAddr, remoteAddr, netProto); r != nil {
return r, nil
}
}
@@ -1316,7 +1329,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil {
return makeRoute(
netProto,
- addressEndpoint.AddressWithPrefix().Address,
+ localAddr,
remoteAddr,
nic, /* outboundNIC */
nic, /* localAddressNIC*/
@@ -1328,9 +1341,9 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
}
if isLoopback {
- return Route{}, tcpip.ErrBadLocalAddress
+ return nil, tcpip.ErrBadLocalAddress
}
- return Route{}, tcpip.ErrNetworkUnreachable
+ return nil, tcpip.ErrNetworkUnreachable
}
canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal
@@ -1353,8 +1366,8 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
if needRoute {
gateway = route.Gateway
}
- r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop)
- if r == (Route{}) {
+ r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop)
+ if r == nil {
panic(fmt.Sprintf("non-forwarding route validation failed with route table entry = %#v, id = %d, localAddr = %s, remoteAddr = %s", route, id, localAddr, remoteAddr))
}
return r, nil
@@ -1390,13 +1403,13 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
if id != 0 {
if aNIC, ok := s.nics[id]; ok {
if addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto); addressEndpoint != nil {
- if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) {
+ if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != nil {
return r, nil
}
}
}
- return Route{}, tcpip.ErrNoRoute
+ return nil, tcpip.ErrNoRoute
}
if id == 0 {
@@ -1408,7 +1421,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
continue
}
- if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) {
+ if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != nil {
return r, nil
}
}
@@ -1416,12 +1429,12 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
}
if needRoute {
- return Route{}, tcpip.ErrNoRoute
+ return nil, tcpip.ErrNoRoute
}
if header.IsV6LoopbackAddress(remoteAddr) {
- return Route{}, tcpip.ErrBadLocalAddress
+ return nil, tcpip.ErrBadLocalAddress
}
- return Route{}, tcpip.ErrNetworkUnreachable
+ return nil, tcpip.ErrNetworkUnreachable
}
// CheckNetworkProtocol checks if a given network protocol is enabled in the
@@ -1506,7 +1519,7 @@ func (s *Stack) AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr t
}
// GetLinkAddress implements LinkAddressCache.GetLinkAddress.
-func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
+func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
s.mu.RLock()
nic := s.nics[nicID]
if nic == nil {
@@ -1517,7 +1530,7 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address,
fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr}
linkRes := s.linkAddrResolvers[protocol]
- return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic, waker)
+ return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic, onResolve)
}
// Neighbors returns all IP to MAC address associations.
@@ -1533,29 +1546,6 @@ func (s *Stack) Neighbors(nicID tcpip.NICID) ([]NeighborEntry, *tcpip.Error) {
return nic.neighbors()
}
-// RemoveWaker removes a waker that has been added when link resolution for
-// addr was requested.
-func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) {
- if s.useNeighborCache {
- s.mu.RLock()
- nic, ok := s.nics[nicID]
- s.mu.RUnlock()
-
- if ok {
- nic.removeWaker(addr, waker)
- }
- return
- }
-
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- if nic := s.nics[nicID]; nic == nil {
- fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr}
- s.linkAddrCache.removeWaker(fullAddr, waker)
- }
-}
-
// AddStaticNeighbor statically associates an IP address to a MAC address.
func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) *tcpip.Error {
s.mu.RLock()
@@ -1809,49 +1799,20 @@ func (s *Stack) unregisterPacketEndpointLocked(nicID tcpip.NICID, netProto tcpip
nic.unregisterPacketEndpoint(netProto, ep)
}
-// WritePacket writes data directly to the specified NIC. It adds an ethernet
-// header based on the arguments.
-func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error {
+// WritePacketToRemote writes a payload on the specified NIC using the provided
+// network protocol and remote link address.
+func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error {
s.mu.Lock()
nic, ok := s.nics[nicID]
s.mu.Unlock()
if !ok {
return tcpip.ErrUnknownDevice
}
-
- // Add our own fake ethernet header.
- ethFields := header.EthernetFields{
- SrcAddr: nic.LinkEndpoint.LinkAddress(),
- DstAddr: dst,
- Type: netProto,
- }
- fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
- fakeHeader.Encode(&ethFields)
- vv := buffer.View(fakeHeader).ToVectorisedView()
- vv.Append(payload)
-
- if err := nic.LinkEndpoint.WriteRawPacket(vv); err != nil {
- return err
- }
-
- return nil
-}
-
-// WriteRawPacket writes data directly to the specified NIC without adding any
-// headers.
-func (s *Stack) WriteRawPacket(nicID tcpip.NICID, payload buffer.VectorisedView) *tcpip.Error {
- s.mu.Lock()
- nic, ok := s.nics[nicID]
- s.mu.Unlock()
- if !ok {
- return tcpip.ErrUnknownDevice
- }
-
- if err := nic.LinkEndpoint.WriteRawPacket(payload); err != nil {
- return err
- }
-
- return nil
+ pkt := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: int(nic.MaxHeaderLength()),
+ Data: payload,
+ })
+ return nic.WritePacketToRemote(remote, nil, netProto, pkt)
}
// NetworkProtocolInstance returns the protocol instance in the stack for the
@@ -1911,7 +1872,6 @@ func (s *Stack) RemoveTCPProbe() {
// JoinGroup joins the given multicast group on the given NIC.
func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error {
- // TODO: notify network of subscription via igmp protocol.
s.mu.RLock()
defer s.mu.RUnlock()
@@ -2158,3 +2118,43 @@ func (s *Stack) networkProtocolNumbers() []tcpip.NetworkProtocolNumber {
}
return protos
}
+
+func isSubnetBroadcastOnNIC(nic *NIC, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
+ addressEndpoint := nic.getAddressOrCreateTempInner(protocol, addr, false /* createTemp */, NeverPrimaryEndpoint)
+ if addressEndpoint == nil {
+ return false
+ }
+
+ subnet := addressEndpoint.Subnet()
+ addressEndpoint.DecRef()
+ return subnet.IsBroadcast(addr)
+}
+
+// IsSubnetBroadcast returns true if the provided address is a subnet-local
+// broadcast address on the specified NIC and protocol.
+//
+// Returns false if the NIC is unknown or if the protocol is unknown or does
+// not support addressing.
+//
+// If the NIC is not specified, the stack will check all NICs.
+func (s *Stack) IsSubnetBroadcast(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nicID != 0 {
+ nic, ok := s.nics[nicID]
+ if !ok {
+ return false
+ }
+
+ return isSubnetBroadcastOnNIC(nic, protocol, addr)
+ }
+
+ for _, nic := range s.nics {
+ if isSubnetBroadcastOnNIC(nic, protocol, addr) {
+ return true
+ }
+ }
+
+ return false
+}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index dedfdd435..856ebf6d4 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -27,7 +27,6 @@ import (
"time"
"github.com/google/go-cmp/cmp"
- "github.com/google/go-cmp/cmp/cmpopts"
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -112,7 +111,15 @@ func (*fakeNetworkEndpoint) DefaultTTL() uint8 {
func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Increment the received packet count in the protocol descriptor.
netHdr := pkt.NetworkHeader().View()
- f.proto.packetCount[int(netHdr[dstAddrOffset])%len(f.proto.packetCount)]++
+
+ dst := tcpip.Address(netHdr[dstAddrOffset:][:1])
+ addressEndpoint := f.AcquireAssignedAddress(dst, f.nic.Promiscuous(), stack.CanBePrimaryEndpoint)
+ if addressEndpoint == nil {
+ return
+ }
+ addressEndpoint.DecRef()
+
+ f.proto.packetCount[int(dst[0])%len(f.proto.packetCount)]++
// Handle control packets.
if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) {
@@ -159,9 +166,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
hdr[protocolNumberOffset] = byte(params.Protocol)
if r.Loop&stack.PacketLoop != 0 {
- pkt := pkt.Clone()
- r.PopulatePacketInfo(pkt)
- f.HandlePacket(pkt)
+ f.HandlePacket(pkt.Clone())
}
if r.Loop&stack.PacketOut == 0 {
return nil
@@ -401,7 +406,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro
return send(r, payload)
}
-func send(r stack.Route, payload buffer.View) *tcpip.Error {
+func send(r *stack.Route, payload buffer.View) *tcpip.Error {
return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()),
Data: payload.ToVectorisedView(),
@@ -419,7 +424,7 @@ func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.En
}
}
-func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View) {
+func testSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View) {
t.Helper()
ep.Drain()
if err := send(r, payload); err != nil {
@@ -430,7 +435,7 @@ func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.
}
}
-func testFailingSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
+func testFailingSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
t.Helper()
if gotErr := send(r, payload); gotErr != wantErr {
t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr)
@@ -1557,15 +1562,15 @@ func TestSpoofingNoAddress(t *testing.T) {
// testSendTo(t, s, remoteAddr, ep, nil)
}
-func verifyRoute(gotRoute, wantRoute stack.Route) error {
+func verifyRoute(gotRoute, wantRoute *stack.Route) error {
if gotRoute.LocalAddress != wantRoute.LocalAddress {
return fmt.Errorf("bad local address: got %s, want = %s", gotRoute.LocalAddress, wantRoute.LocalAddress)
}
if gotRoute.RemoteAddress != wantRoute.RemoteAddress {
return fmt.Errorf("bad remote address: got %s, want = %s", gotRoute.RemoteAddress, wantRoute.RemoteAddress)
}
- if gotRoute.RemoteLinkAddress != wantRoute.RemoteLinkAddress {
- return fmt.Errorf("bad remote link address: got %s, want = %s", gotRoute.RemoteLinkAddress, wantRoute.RemoteLinkAddress)
+ if got, want := gotRoute.RemoteLinkAddress(), wantRoute.RemoteLinkAddress(); got != want {
+ return fmt.Errorf("bad remote link address: got %s, want = %s", got, want)
}
if gotRoute.NextHop != wantRoute.NextHop {
return fmt.Errorf("bad next-hop address: got %s, want = %s", gotRoute.NextHop, wantRoute.NextHop)
@@ -1597,7 +1602,10 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) {
if err != nil {
t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
}
- if err := verifyRoute(r, stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ var wantRoute stack.Route
+ wantRoute.LocalAddress = header.IPv4Any
+ wantRoute.RemoteAddress = header.IPv4Broadcast
+ if err := verifyRoute(r, &wantRoute); err != nil {
t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
}
@@ -1651,7 +1659,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
if err != nil {
t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
}
- if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ var wantRoute stack.Route
+ wantRoute.LocalAddress = nic1Addr.Address
+ wantRoute.RemoteAddress = header.IPv4Broadcast
+ if err := verifyRoute(r, &wantRoute); err != nil {
t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
}
@@ -1661,7 +1672,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
if err != nil {
t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err)
}
- if err := verifyRoute(r, stack.Route{LocalAddress: nic2Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ wantRoute = stack.Route{}
+ wantRoute.LocalAddress = nic2Addr.Address
+ wantRoute.RemoteAddress = header.IPv4Broadcast
+ if err := verifyRoute(r, &wantRoute); err != nil {
t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err)
}
@@ -1677,7 +1691,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
if err != nil {
t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err)
}
- if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ wantRoute = stack.Route{}
+ wantRoute.LocalAddress = nic1Addr.Address
+ wantRoute.RemoteAddress = header.IPv4Broadcast
+ if err := verifyRoute(r, &wantRoute); err != nil {
t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err)
}
}
@@ -2214,88 +2231,6 @@ func TestNICStats(t *testing.T) {
}
}
-func TestNICForwarding(t *testing.T) {
- const nicID1 = 1
- const nicID2 = 2
- const dstAddr = tcpip.Address("\x03")
-
- tests := []struct {
- name string
- headerLen uint16
- }{
- {
- name: "Zero header length",
- },
- {
- name: "Non-zero header length",
- headerLen: 16,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
- s.SetForwarding(fakeNetNumber, true)
-
- ep1 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID1, ep1); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
- }
- if err := s.AddAddress(nicID1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress(%d, %d, 0x01): %s", nicID1, fakeNetNumber, err)
- }
-
- ep2 := channelLinkWithHeaderLength{
- Endpoint: channel.New(10, defaultMTU, ""),
- headerLength: test.headerLen,
- }
- if err := s.CreateNIC(nicID2, &ep2); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
- }
- if err := s.AddAddress(nicID2, fakeNetNumber, "\x02"); err != nil {
- t.Fatalf("AddAddress(%d, %d, 0x02): %s", nicID2, fakeNetNumber, err)
- }
-
- // Route all packets to dstAddr to NIC 2.
- {
- subnet, err := tcpip.NewSubnet(dstAddr, "\xff")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID2}})
- }
-
- // Send a packet to dstAddr.
- buf := buffer.NewView(30)
- buf[dstAddrOffset] = dstAddr[0]
- ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- pkt, ok := ep2.Read()
- if !ok {
- t.Fatal("packet not forwarded")
- }
-
- // Test that the link's MaxHeaderLength is honoured.
- if capacity, want := pkt.Pkt.AvailableHeaderBytes(), int(test.headerLen); capacity != want {
- t.Errorf("got LinkHeader.AvailableLength() = %d, want = %d", capacity, want)
- }
-
- // Test that forwarding increments Tx stats correctly.
- if got, want := s.NICInfo()[nicID2].Stats.Tx.Packets.Value(), uint64(1); got != want {
- t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want)
- }
-
- if got, want := s.NICInfo()[nicID2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want {
- t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
- }
- })
- }
-}
-
// TestNICContextPreservation tests that you can read out via stack.NICInfo the
// Context data you pass via NICContext.Context in stack.CreateNICWithOptions.
func TestNICContextPreservation(t *testing.T) {
@@ -2483,9 +2418,9 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
}
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- AutoGenIPv6LinkLocal: test.autoGen,
- NDPDisp: &ndpDisp,
- OpaqueIIDOpts: test.iidOpts,
+ AutoGenLinkLocal: test.autoGen,
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: test.iidOpts,
})},
}
@@ -2578,8 +2513,8 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- AutoGenIPv6LinkLocal: true,
- OpaqueIIDOpts: test.opaqueIIDOpts,
+ AutoGenLinkLocal: true,
+ OpaqueIIDOpts: test.opaqueIIDOpts,
})},
}
@@ -2612,9 +2547,9 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) {
ndpConfigs := ipv6.DefaultNDPConfigurations()
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ndpConfigs,
- AutoGenIPv6LinkLocal: true,
- NDPDisp: &ndpDisp,
+ NDPConfigs: ndpConfigs,
+ AutoGenLinkLocal: true,
+ NDPDisp: &ndpDisp,
})},
}
@@ -2803,8 +2738,16 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- nicID = 1
- lifetimeSeconds = 9999
+ globalAddr3 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
+ ipv4MappedIPv6Addr1 = tcpip.Address("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x01")
+ ipv4MappedIPv6Addr2 = tcpip.Address("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x02")
+ toredoAddr1 = tcpip.Address("\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ toredoAddr2 = tcpip.Address("\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ ipv6ToIPv4Addr1 = tcpip.Address("\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ ipv6ToIPv4Addr2 = tcpip.Address("\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+
+ nicID = 1
+ lifetimeSeconds = 9999
)
prefix1, _, stableGlobalAddr1 := prefixSubnetAddr(0, linkAddr1)
@@ -2821,139 +2764,191 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
slaacPrefixForTempAddrBeforeNICAddrAdd tcpip.AddressWithPrefix
nicAddrs []tcpip.Address
slaacPrefixForTempAddrAfterNICAddrAdd tcpip.AddressWithPrefix
- connectAddr tcpip.Address
+ remoteAddr tcpip.Address
expectedLocalAddr tcpip.Address
}{
- // Test Rule 1 of RFC 6724 section 5.
+ // Test Rule 1 of RFC 6724 section 5 (prefer same address).
{
name: "Same Global most preferred (last address)",
- nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
- connectAddr: globalAddr1,
+ nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1},
+ remoteAddr: globalAddr1,
expectedLocalAddr: globalAddr1,
},
{
name: "Same Global most preferred (first address)",
- nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1},
- connectAddr: globalAddr1,
+ nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1},
+ remoteAddr: globalAddr1,
expectedLocalAddr: globalAddr1,
},
{
name: "Same Link Local most preferred (last address)",
- nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1},
- connectAddr: linkLocalAddr1,
+ nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1},
+ remoteAddr: linkLocalAddr1,
expectedLocalAddr: linkLocalAddr1,
},
{
name: "Same Link Local most preferred (first address)",
- nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
- connectAddr: linkLocalAddr1,
+ nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1},
+ remoteAddr: linkLocalAddr1,
expectedLocalAddr: linkLocalAddr1,
},
{
name: "Same Unique Local most preferred (last address)",
- nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1},
- connectAddr: uniqueLocalAddr1,
+ nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1},
+ remoteAddr: uniqueLocalAddr1,
expectedLocalAddr: uniqueLocalAddr1,
},
{
name: "Same Unique Local most preferred (first address)",
- nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1},
- connectAddr: uniqueLocalAddr1,
+ nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1},
+ remoteAddr: uniqueLocalAddr1,
expectedLocalAddr: uniqueLocalAddr1,
},
- // Test Rule 2 of RFC 6724 section 5.
+ // Test Rule 2 of RFC 6724 section 5 (prefer appropriate scope).
{
name: "Global most preferred (last address)",
- nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
- connectAddr: globalAddr2,
+ nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1},
+ remoteAddr: globalAddr2,
expectedLocalAddr: globalAddr1,
},
{
name: "Global most preferred (first address)",
- nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1},
- connectAddr: globalAddr2,
+ nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1},
+ remoteAddr: globalAddr2,
expectedLocalAddr: globalAddr1,
},
{
name: "Link Local most preferred (last address)",
- nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1},
- connectAddr: linkLocalAddr2,
+ nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1},
+ remoteAddr: linkLocalAddr2,
expectedLocalAddr: linkLocalAddr1,
},
{
name: "Link Local most preferred (first address)",
- nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
- connectAddr: linkLocalAddr2,
+ nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1},
+ remoteAddr: linkLocalAddr2,
expectedLocalAddr: linkLocalAddr1,
},
{
name: "Link Local most preferred for link local multicast (last address)",
- nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1},
- connectAddr: linkLocalMulticastAddr,
+ nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1},
+ remoteAddr: linkLocalMulticastAddr,
expectedLocalAddr: linkLocalAddr1,
},
{
name: "Link Local most preferred for link local multicast (first address)",
- nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
- connectAddr: linkLocalMulticastAddr,
+ nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1},
+ remoteAddr: linkLocalMulticastAddr,
expectedLocalAddr: linkLocalAddr1,
},
+
+ // Test Rule 6 of 6724 section 5 (prefer matching label).
{
name: "Unique Local most preferred (last address)",
- nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1},
- connectAddr: uniqueLocalAddr2,
+ nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, ipv4MappedIPv6Addr1, toredoAddr1, ipv6ToIPv4Addr1},
+ remoteAddr: uniqueLocalAddr2,
expectedLocalAddr: uniqueLocalAddr1,
},
{
name: "Unique Local most preferred (first address)",
- nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1},
- connectAddr: uniqueLocalAddr2,
+ nicAddrs: []tcpip.Address{globalAddr1, ipv4MappedIPv6Addr1, toredoAddr1, ipv6ToIPv4Addr1, uniqueLocalAddr1},
+ remoteAddr: uniqueLocalAddr2,
expectedLocalAddr: uniqueLocalAddr1,
},
+ {
+ name: "Toredo most preferred (first address)",
+ nicAddrs: []tcpip.Address{toredoAddr1, uniqueLocalAddr1, globalAddr1, ipv4MappedIPv6Addr1, ipv6ToIPv4Addr1},
+ remoteAddr: toredoAddr2,
+ expectedLocalAddr: toredoAddr1,
+ },
+ {
+ name: "Toredo most preferred (last address)",
+ nicAddrs: []tcpip.Address{globalAddr1, ipv4MappedIPv6Addr1, ipv6ToIPv4Addr1, uniqueLocalAddr1, toredoAddr1},
+ remoteAddr: toredoAddr2,
+ expectedLocalAddr: toredoAddr1,
+ },
+ {
+ name: "6To4 most preferred (first address)",
+ nicAddrs: []tcpip.Address{ipv6ToIPv4Addr1, toredoAddr1, uniqueLocalAddr1, globalAddr1, ipv4MappedIPv6Addr1},
+ remoteAddr: ipv6ToIPv4Addr2,
+ expectedLocalAddr: ipv6ToIPv4Addr1,
+ },
+ {
+ name: "6To4 most preferred (last address)",
+ nicAddrs: []tcpip.Address{globalAddr1, ipv4MappedIPv6Addr1, uniqueLocalAddr1, toredoAddr1, ipv6ToIPv4Addr1},
+ remoteAddr: ipv6ToIPv4Addr2,
+ expectedLocalAddr: ipv6ToIPv4Addr1,
+ },
+ {
+ name: "IPv4 mapped IPv6 most preferred (first address)",
+ nicAddrs: []tcpip.Address{ipv4MappedIPv6Addr1, ipv6ToIPv4Addr1, toredoAddr1, uniqueLocalAddr1, globalAddr1},
+ remoteAddr: ipv4MappedIPv6Addr2,
+ expectedLocalAddr: ipv4MappedIPv6Addr1,
+ },
+ {
+ name: "IPv4 mapped IPv6 most preferred (last address)",
+ nicAddrs: []tcpip.Address{globalAddr1, ipv6ToIPv4Addr1, uniqueLocalAddr1, toredoAddr1, ipv4MappedIPv6Addr1},
+ remoteAddr: ipv4MappedIPv6Addr2,
+ expectedLocalAddr: ipv4MappedIPv6Addr1,
+ },
- // Test Rule 7 of RFC 6724 section 5.
+ // Test Rule 7 of RFC 6724 section 5 (prefer temporary addresses).
{
name: "Temp Global most preferred (last address)",
slaacPrefixForTempAddrBeforeNICAddrAdd: prefix1,
nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
- connectAddr: globalAddr2,
+ remoteAddr: globalAddr2,
expectedLocalAddr: tempGlobalAddr1,
},
{
name: "Temp Global most preferred (first address)",
nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
slaacPrefixForTempAddrAfterNICAddrAdd: prefix1,
- connectAddr: globalAddr2,
+ remoteAddr: globalAddr2,
expectedLocalAddr: tempGlobalAddr1,
},
+ // Test Rule 8 of RFC 6724 section 5 (use longest matching prefix).
+ {
+ name: "Longest prefix matched most preferred (first address)",
+ nicAddrs: []tcpip.Address{globalAddr2, globalAddr1},
+ remoteAddr: globalAddr3,
+ expectedLocalAddr: globalAddr2,
+ },
+ {
+ name: "Longest prefix matched most preferred (last address)",
+ nicAddrs: []tcpip.Address{globalAddr1, globalAddr2},
+ remoteAddr: globalAddr3,
+ expectedLocalAddr: globalAddr2,
+ },
+
// Test returning the endpoint that is closest to the front when
// candidate addresses are "equal" from the perspective of RFC 6724
// section 5.
{
name: "Unique Local for Global",
nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, uniqueLocalAddr2},
- connectAddr: globalAddr2,
+ remoteAddr: globalAddr2,
expectedLocalAddr: uniqueLocalAddr1,
},
{
name: "Link Local for Global",
nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2},
- connectAddr: globalAddr2,
+ remoteAddr: globalAddr2,
expectedLocalAddr: linkLocalAddr1,
},
{
name: "Link Local for Unique Local",
nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2},
- connectAddr: uniqueLocalAddr2,
+ remoteAddr: uniqueLocalAddr2,
expectedLocalAddr: linkLocalAddr1,
},
{
name: "Temp Global for Global",
slaacPrefixForTempAddrBeforeNICAddrAdd: prefix1,
slaacPrefixForTempAddrAfterNICAddrAdd: prefix2,
- connectAddr: globalAddr1,
+ remoteAddr: globalAddr1,
expectedLocalAddr: tempGlobalAddr2,
},
}
@@ -2975,12 +2970,6 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- s.SetRouteTable([]tcpip.Route{{
- Destination: header.IPv6EmptySubnet,
- Gateway: llAddr3,
- NIC: nicID,
- }})
- s.AddLinkAddress(nicID, llAddr3, linkAddr3)
if test.slaacPrefixForTempAddrBeforeNICAddrAdd != (tcpip.AddressWithPrefix{}) {
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, test.slaacPrefixForTempAddrBeforeNICAddrAdd, true, true, lifetimeSeconds, lifetimeSeconds))
@@ -3000,7 +2989,23 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
t.FailNow()
}
- if got := addrForNewConnectionTo(t, s, tcpip.FullAddress{Addr: test.connectAddr, NIC: nicID, Port: 1234}); got != test.expectedLocalAddr {
+ netEP, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ }
+
+ addressableEndpoint, ok := netEP.(stack.AddressableEndpoint)
+ if !ok {
+ t.Fatal("network endpoint is not addressable")
+ }
+
+ addressEP := addressableEndpoint.AcquireOutgoingPrimaryAddress(test.remoteAddr, false /* allowExpired */)
+ if addressEP == nil {
+ t.Fatal("expected a non-nil address endpoint")
+ }
+ defer addressEP.DecRef()
+
+ if got := addressEP.AddressWithPrefix().Address; got != test.expectedLocalAddr {
t.Errorf("got local address = %s, want = %s", got, test.expectedLocalAddr)
}
})
@@ -3427,11 +3432,16 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
remNetSubnetBcast := remNetSubnet.Broadcast()
tests := []struct {
- name string
- nicAddr tcpip.ProtocolAddress
- routes []tcpip.Route
- remoteAddr tcpip.Address
- expectedRoute stack.Route
+ name string
+ nicAddr tcpip.ProtocolAddress
+ routes []tcpip.Route
+ remoteAddr tcpip.Address
+ expectedLocalAddress tcpip.Address
+ expectedRemoteAddress tcpip.Address
+ expectedRemoteLinkAddress tcpip.LinkAddress
+ expectedNextHop tcpip.Address
+ expectedNetProto tcpip.NetworkProtocolNumber
+ expectedLoop stack.PacketLooping
}{
// Broadcast to a locally attached subnet populates the broadcast MAC.
{
@@ -3446,14 +3456,12 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: ipv4SubnetBcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv4Addr.Address,
- RemoteAddress: ipv4SubnetBcast,
- RemoteLinkAddress: header.EthernetBroadcastAddress,
- NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut | stack.PacketLoop,
- },
+ remoteAddr: ipv4SubnetBcast,
+ expectedLocalAddress: ipv4Addr.Address,
+ expectedRemoteAddress: ipv4SubnetBcast,
+ expectedRemoteLinkAddress: header.EthernetBroadcastAddress,
+ expectedNetProto: header.IPv4ProtocolNumber,
+ expectedLoop: stack.PacketOut | stack.PacketLoop,
},
// Broadcast to a locally attached /31 subnet does not populate the
// broadcast MAC.
@@ -3469,13 +3477,11 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: ipv4Subnet31Bcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv4AddrPrefix31.Address,
- RemoteAddress: ipv4Subnet31Bcast,
- NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut,
- },
+ remoteAddr: ipv4Subnet31Bcast,
+ expectedLocalAddress: ipv4AddrPrefix31.Address,
+ expectedRemoteAddress: ipv4Subnet31Bcast,
+ expectedNetProto: header.IPv4ProtocolNumber,
+ expectedLoop: stack.PacketOut,
},
// Broadcast to a locally attached /32 subnet does not populate the
// broadcast MAC.
@@ -3491,13 +3497,11 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: ipv4Subnet32Bcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv4AddrPrefix32.Address,
- RemoteAddress: ipv4Subnet32Bcast,
- NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut,
- },
+ remoteAddr: ipv4Subnet32Bcast,
+ expectedLocalAddress: ipv4AddrPrefix32.Address,
+ expectedRemoteAddress: ipv4Subnet32Bcast,
+ expectedNetProto: header.IPv4ProtocolNumber,
+ expectedLoop: stack.PacketOut,
},
// IPv6 has no notion of a broadcast.
{
@@ -3512,13 +3516,11 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: ipv6SubnetBcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv6Addr.Address,
- RemoteAddress: ipv6SubnetBcast,
- NetProto: header.IPv6ProtocolNumber,
- Loop: stack.PacketOut,
- },
+ remoteAddr: ipv6SubnetBcast,
+ expectedLocalAddress: ipv6Addr.Address,
+ expectedRemoteAddress: ipv6SubnetBcast,
+ expectedNetProto: header.IPv6ProtocolNumber,
+ expectedLoop: stack.PacketOut,
},
// Broadcast to a remote subnet in the route table is send to the next-hop
// gateway.
@@ -3535,14 +3537,12 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: remNetSubnetBcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv4Addr.Address,
- RemoteAddress: remNetSubnetBcast,
- NextHop: ipv4Gateway,
- NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut,
- },
+ remoteAddr: remNetSubnetBcast,
+ expectedLocalAddress: ipv4Addr.Address,
+ expectedRemoteAddress: remNetSubnetBcast,
+ expectedNextHop: ipv4Gateway,
+ expectedNetProto: header.IPv4ProtocolNumber,
+ expectedLoop: stack.PacketOut,
},
// Broadcast to an unknown subnet follows the default route. Note that this
// is essentially just routing an unknown destination IP, because w/o any
@@ -3560,14 +3560,12 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: remNetSubnetBcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv4Addr.Address,
- RemoteAddress: remNetSubnetBcast,
- NextHop: ipv4Gateway,
- NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut,
- },
+ remoteAddr: remNetSubnetBcast,
+ expectedLocalAddress: ipv4Addr.Address,
+ expectedRemoteAddress: remNetSubnetBcast,
+ expectedNextHop: ipv4Gateway,
+ expectedNetProto: header.IPv4ProtocolNumber,
+ expectedLoop: stack.PacketOut,
},
}
@@ -3596,10 +3594,27 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
t.Fatalf("got unexpected address length = %d bytes", l)
}
- if r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, test.remoteAddr, netProto, false /* multicastLoop */); err != nil {
+ r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, test.remoteAddr, netProto, false /* multicastLoop */)
+ if err != nil {
t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, test.remoteAddr, netProto, err)
- } else if diff := cmp.Diff(r, test.expectedRoute, cmpopts.IgnoreUnexported(r)); diff != "" {
- t.Errorf("route mismatch (-want +got):\n%s", diff)
+ }
+ if r.LocalAddress != test.expectedLocalAddress {
+ t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, test.expectedLocalAddress)
+ }
+ if r.RemoteAddress != test.expectedRemoteAddress {
+ t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress, test.expectedRemoteAddress)
+ }
+ if got := r.RemoteLinkAddress(); got != test.expectedRemoteLinkAddress {
+ t.Errorf("got r.RemoteLinkAddress() = %s, want = %s", got, test.expectedRemoteLinkAddress)
+ }
+ if r.NextHop != test.expectedNextHop {
+ t.Errorf("got r.NextHop = %s, want = %s", r.NextHop, test.expectedNextHop)
+ }
+ if r.NetProto != test.expectedNetProto {
+ t.Errorf("got r.NetProto = %d, want = %d", r.NetProto, test.expectedNetProto)
+ }
+ if r.Loop != test.expectedLoop {
+ t.Errorf("got r.Loop = %x, want = %x", r.Loop, test.expectedLoop)
}
})
}
@@ -4167,10 +4182,12 @@ func TestFindRouteWithForwarding(t *testing.T) {
s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}})
r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */)
+ if r != nil {
+ defer r.Release()
+ }
if err != test.findRouteErr {
t.Fatalf("FindRoute(%d, %s, %s, %d, false) = %s, want = %s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, err, test.findRouteErr)
}
- defer r.Release()
if test.findRouteErr != nil {
return
@@ -4228,3 +4245,63 @@ func TestFindRouteWithForwarding(t *testing.T) {
})
}
}
+
+func TestWritePacketToRemote(t *testing.T) {
+ const nicID = 1
+ const MTU = 1280
+ e := channel.New(1, MTU, linkAddr1)
+ s := stack.New(stack.Options{})
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("CreateNIC(%d) = %s", nicID, err)
+ }
+ tests := []struct {
+ name string
+ protocol tcpip.NetworkProtocolNumber
+ payload []byte
+ }{
+ {
+ name: "SuccessIPv4",
+ protocol: header.IPv4ProtocolNumber,
+ payload: []byte{1, 2, 3, 4},
+ },
+ {
+ name: "SuccessIPv6",
+ protocol: header.IPv6ProtocolNumber,
+ payload: []byte{5, 6, 7, 8},
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ if err := s.WritePacketToRemote(nicID, linkAddr2, test.protocol, buffer.View(test.payload).ToVectorisedView()); err != nil {
+ t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s", err)
+ }
+
+ pkt, ok := e.Read()
+ if got, want := ok, true; got != want {
+ t.Fatalf("e.Read() = %t, want %t", got, want)
+ }
+ if got, want := pkt.Proto, test.protocol; got != want {
+ t.Fatalf("pkt.Proto = %d, want %d", got, want)
+ }
+ if pkt.Route.RemoteLinkAddress != linkAddr2 {
+ t.Fatalf("pkt.Route.RemoteAddress = %s, want %s", pkt.Route.RemoteLinkAddress, linkAddr2)
+ }
+ if diff := cmp.Diff(pkt.Pkt.Data.ToView(), buffer.View(test.payload)); diff != "" {
+ t.Errorf("pkt.Pkt.Data mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+
+ t.Run("InvalidNICID", func(t *testing.T) {
+ if got, want := s.WritePacketToRemote(234, linkAddr2, header.IPv4ProtocolNumber, buffer.View([]byte{1}).ToVectorisedView()), tcpip.ErrUnknownDevice; got != want {
+ t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s, want = %s", got, want)
+ }
+ pkt, ok := e.Read()
+ if got, want := ok, false; got != want {
+ t.Fatalf("e.Read() = %t, %v; want %t", got, pkt, want)
+ }
+ })
+}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index f183ec6e4..07b2818d2 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -182,7 +182,8 @@ func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *Packet
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
}
-// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
+// handleControlPacket delivers a control packet to the transport endpoint
+// identified by id.
func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt *PacketBuffer) {
epsByNIC.mu.RLock()
defer epsByNIC.mu.RUnlock()
@@ -199,7 +200,7 @@ func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpoint
// broadcast like we are doing with handlePacket above?
// multiPortEndpoints are guaranteed to have at least one element.
- selectEndpoint(id, mpep, epsByNIC.seed).HandleControlPacket(id, typ, extra, pkt)
+ selectEndpoint(id, mpep, epsByNIC.seed).HandleControlPacket(typ, extra, pkt)
}
// registerEndpoint returns true if it succeeds. It fails and returns
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 41a8e5ad0..859278f0b 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -15,6 +15,7 @@
package stack_test
import (
+ "io/ioutil"
"math"
"math/rand"
"testing"
@@ -141,11 +142,11 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI
// Initialize the IP header.
ip := header.IPv6(buf)
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
- NextHeader: uint8(udp.ProtocolNumber),
- HopLimit: 65,
- SrcAddr: testSrcAddrV6,
- DstAddr: testDstAddrV6,
+ PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
+ TransportProtocol: udp.ProtocolNumber,
+ HopLimit: 65,
+ SrcAddr: testSrcAddrV6,
+ DstAddr: testDstAddrV6,
})
// Initialize the UDP header.
@@ -307,12 +308,9 @@ func TestBindToDeviceDistribution(t *testing.T) {
}(ep)
defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.ReusePortOption, endpoint.reuse); err != nil {
- t.Fatalf("SetSockOptBool(ReusePortOption, %t) on endpoint %d failed: %s", endpoint.reuse, i, err)
- }
- bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice)
- if err := ep.SetSockOpt(&bindToDeviceOption); err != nil {
- t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", bindToDeviceOption, bindToDeviceOption, i, err)
+ ep.SocketOptions().SetReusePort(endpoint.reuse)
+ if err := ep.SocketOptions().SetBindToDevice(int32(endpoint.bindToDevice)); err != nil {
+ t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", endpoint.bindToDevice, endpoint.bindToDevice, i, err)
}
var dstAddr tcpip.Address
@@ -354,7 +352,7 @@ func TestBindToDeviceDistribution(t *testing.T) {
}
ep := <-pollChannel
- if _, _, err := ep.Read(nil); err != nil {
+ if _, err := ep.Read(ioutil.Discard, math.MaxUint16, tcpip.ReadOptions{}); err != nil {
t.Fatalf("Read on endpoint %d failed: %s", eps[ep], err)
}
stats[ep]++
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index c457b67a2..0ff32c6ea 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -15,12 +15,12 @@
package stack_test
import (
+ "io"
"testing"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
@@ -39,14 +39,18 @@ const (
// use it.
type fakeTransportEndpoint struct {
stack.TransportEndpointInfo
+ tcpip.DefaultSocketOptionsHandler
proto *fakeTransportProtocol
peerAddr tcpip.Address
- route stack.Route
+ route *stack.Route
uniqueID uint64
// acceptQueue is non-nil iff bound.
- acceptQueue []fakeTransportEndpoint
+ acceptQueue []*fakeTransportEndpoint
+
+ // ops is used to set and get socket options.
+ ops tcpip.SocketOptions
}
func (f *fakeTransportEndpoint) Info() tcpip.EndpointInfo {
@@ -59,8 +63,14 @@ func (*fakeTransportEndpoint) Stats() tcpip.EndpointStats {
func (*fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {}
+func (f *fakeTransportEndpoint) SocketOptions() *tcpip.SocketOptions {
+ return &f.ops
+}
+
func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint {
- return &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
+ ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
+ ep.ops.InitHandler(ep)
+ return ep
}
func (f *fakeTransportEndpoint) Abort() {
@@ -68,6 +78,7 @@ func (f *fakeTransportEndpoint) Abort() {
}
func (f *fakeTransportEndpoint) Close() {
+ // TODO(gvisor.dev/issue/5153): Consider retaining the route.
f.route.Release()
}
@@ -75,8 +86,8 @@ func (*fakeTransportEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask
return mask
}
-func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
- return buffer.View{}, tcpip.ControlMessages{}, nil
+func (*fakeTransportEndpoint) Read(io.Writer, int, tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
+ return tcpip.ReadResult{}, nil
}
func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
@@ -100,30 +111,16 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions
return int64(len(v)), nil, nil
}
-func (*fakeTransportEndpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
- return 0, tcpip.ControlMessages{}, nil
-}
-
// SetSockOpt sets a socket option. Currently not supported.
func (*fakeTransportEndpoint) SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
-// SetSockOptBool sets a socket option. Currently not supported.
-func (*fakeTransportEndpoint) SetSockOptBool(tcpip.SockOptBool, bool) *tcpip.Error {
- return tcpip.ErrInvalidEndpointState
-}
-
// SetSockOptInt sets a socket option. Currently not supported.
func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
-// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
-func (*fakeTransportEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- return false, tcpip.ErrUnknownProtocolOption
-}
-
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
return -1, tcpip.ErrUnknownProtocolOption
@@ -147,16 +144,16 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
if err != nil {
return tcpip.ErrNoRoute
}
- defer r.Release()
// Try to register so that we can start receiving packets.
f.ID.RemoteAddress = addr.Addr
err = f.proto.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */)
if err != nil {
+ r.Release()
return err
}
- f.route = r.Clone()
+ f.route = r
return nil
}
@@ -186,7 +183,7 @@ func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *wai
}
a := f.acceptQueue[0]
f.acceptQueue = f.acceptQueue[1:]
- return &a, nil, nil
+ return a, nil, nil
}
func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error {
@@ -201,7 +198,7 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error {
); err != nil {
return err
}
- f.acceptQueue = []fakeTransportEndpoint{}
+ f.acceptQueue = []*fakeTransportEndpoint{}
return nil
}
@@ -227,7 +224,7 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *
}
route.ResolveWith(pkt.SourceLinkAddress())
- f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{
+ ep := &fakeTransportEndpoint{
TransportEndpointInfo: stack.TransportEndpointInfo{
ID: f.ID,
NetProto: f.NetProto,
@@ -235,10 +232,12 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *
proto: f.proto,
peerAddr: route.RemoteAddress,
route: route,
- })
+ }
+ ep.ops.InitHandler(ep)
+ f.acceptQueue = append(f.acceptQueue, ep)
}
-func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, *stack.PacketBuffer) {
+func (f *fakeTransportEndpoint) HandleControlPacket(stack.ControlType, uint32, *stack.PacketBuffer) {
// Increment the number of received control packets.
f.proto.controlCount++
}
@@ -553,87 +552,3 @@ func TestTransportOptions(t *testing.T) {
t.Fatalf("got tcpip.TCPModerateReceiveBufferOption = false, want = true")
}
}
-
-func TestTransportForwarding(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory},
- })
- s.SetForwarding(fakeNetNumber, true)
-
- // TODO(b/123449044): Change this to a channel NIC.
- ep1 := loopback.New()
- if err := s.CreateNIC(1, ep1); err != nil {
- t.Fatalf("CreateNIC #1 failed: %v", err)
- }
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress #1 failed: %v", err)
- }
-
- ep2 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(2, ep2); err != nil {
- t.Fatalf("CreateNIC #2 failed: %v", err)
- }
- if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
- t.Fatalf("AddAddress #2 failed: %v", err)
- }
-
- // Route all packets to address 3 to NIC 2 and all packets to address
- // 1 to NIC 1.
- {
- subnet0, err := tcpip.NewSubnet("\x03", "\xff")
- if err != nil {
- t.Fatal(err)
- }
- subnet1, err := tcpip.NewSubnet("\x01", "\xff")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{
- {Destination: subnet0, Gateway: "\x00", NIC: 2},
- {Destination: subnet1, Gateway: "\x00", NIC: 1},
- })
- }
-
- wq := waiter.Queue{}
- ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- if err := ep.Bind(tcpip.FullAddress{Addr: "\x01", NIC: 1}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Send a packet to address 1 from address 3.
- req := buffer.NewView(30)
- req[0] = 1
- req[1] = 3
- req[2] = byte(fakeTransNumber)
- ep2.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: req.ToVectorisedView(),
- }))
-
- aep, _, err := ep.Accept(nil)
- if err != nil || aep == nil {
- t.Fatalf("Accept failed: %v, %v", aep, err)
- }
-
- resp := buffer.NewView(30)
- if _, _, err := aep.Write(tcpip.SlicePayload(resp), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
- }
-
- p, ok := ep2.Read()
- if !ok {
- t.Fatal("Response packet not forwarded")
- }
-
- nh := stack.PayloadSince(p.Pkt.NetworkHeader())
- if dst := nh[0]; dst != 3 {
- t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst)
- }
- if src := nh[1]; src != 1 {
- t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src)
- }
-}