summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/BUILD6
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state.go175
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state_test.go22
-rw-r--r--pkg/tcpip/stack/forwarding_test.go16
-rw-r--r--pkg/tcpip/stack/ndp_test.go229
-rw-r--r--pkg/tcpip/stack/neighbor_cache.go16
-rw-r--r--pkg/tcpip/stack/neighbor_cache_test.go7
-rw-r--r--pkg/tcpip/stack/neighbor_entry.go27
-rw-r--r--pkg/tcpip/stack/neighbor_entry_test.go144
-rw-r--r--pkg/tcpip/stack/nic.go19
-rw-r--r--pkg/tcpip/stack/registration.go25
-rw-r--r--pkg/tcpip/stack/route.go200
-rw-r--r--pkg/tcpip/stack/stack.go144
-rw-r--r--pkg/tcpip/stack/stack_test.go217
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go4
-rw-r--r--pkg/tcpip/stack/transport_test.go30
16 files changed, 713 insertions, 568 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index d09ebe7fa..9cc6074da 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test", "most_shards")
load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
@@ -112,7 +112,7 @@ go_test(
"transport_demuxer_test.go",
"transport_test.go",
],
- shard_count = 20,
+ shard_count = most_shards,
deps = [
":stack",
"//pkg/rand",
@@ -120,6 +120,7 @@ go_test(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
@@ -131,7 +132,6 @@ go_test(
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
"@com_github_google_go_cmp//cmp:go_default_library",
- "@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
],
)
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go
index 9478f3fb7..6e4f5fa46 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state.go
@@ -21,7 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
)
-var _ GroupAddressableEndpoint = (*AddressableEndpointState)(nil)
var _ AddressableEndpoint = (*AddressableEndpointState)(nil)
// AddressableEndpointState is an implementation of an AddressableEndpoint.
@@ -37,10 +36,6 @@ type AddressableEndpointState struct {
endpoints map[tcpip.Address]*addressState
primary []*addressState
-
- // groups holds the mapping between group addresses and the number of times
- // they have been joined.
- groups map[tcpip.Address]uint32
}
}
@@ -53,65 +48,33 @@ func (a *AddressableEndpointState) Init(networkEndpoint NetworkEndpoint) {
a.mu.Lock()
defer a.mu.Unlock()
a.mu.endpoints = make(map[tcpip.Address]*addressState)
- a.mu.groups = make(map[tcpip.Address]uint32)
-}
-
-// ReadOnlyAddressableEndpointState provides read-only access to an
-// AddressableEndpointState.
-type ReadOnlyAddressableEndpointState struct {
- inner *AddressableEndpointState
}
-// AddrOrMatching returns an endpoint for the passed address that is consisdered
-// bound to the wrapped AddressableEndpointState.
+// GetAddress returns the AddressEndpoint for the passed address.
//
-// If addr is an exact match with an existing address, that address is returned.
-// Otherwise, f is called with each address and the address that f returns true
-// for is returned.
-//
-// Returns nil of no address matches.
-func (m ReadOnlyAddressableEndpointState) AddrOrMatching(addr tcpip.Address, spoofingOrPrimiscuous bool, f func(AddressEndpoint) bool) AddressEndpoint {
- m.inner.mu.RLock()
- defer m.inner.mu.RUnlock()
-
- if ep, ok := m.inner.mu.endpoints[addr]; ok {
- if ep.IsAssigned(spoofingOrPrimiscuous) && ep.IncRef() {
- return ep
- }
- }
-
- for _, ep := range m.inner.mu.endpoints {
- if ep.IsAssigned(spoofingOrPrimiscuous) && f(ep) && ep.IncRef() {
- return ep
- }
- }
-
- return nil
-}
-
-// Lookup returns the AddressEndpoint for the passed address.
+// GetAddress does not increment the address's reference count or check if the
+// address is considered bound to the endpoint.
//
-// Returns nil if the passed address is not associated with the
-// AddressableEndpointState.
-func (m ReadOnlyAddressableEndpointState) Lookup(addr tcpip.Address) AddressEndpoint {
- m.inner.mu.RLock()
- defer m.inner.mu.RUnlock()
+// Returns nil if the passed address is not associated with the endpoint.
+func (a *AddressableEndpointState) GetAddress(addr tcpip.Address) AddressEndpoint {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
- ep, ok := m.inner.mu.endpoints[addr]
+ ep, ok := a.mu.endpoints[addr]
if !ok {
return nil
}
return ep
}
-// ForEach calls f for each address pair.
+// ForEachEndpoint calls f for each address.
//
-// If f returns false, f is no longer be called.
-func (m ReadOnlyAddressableEndpointState) ForEach(f func(AddressEndpoint) bool) {
- m.inner.mu.RLock()
- defer m.inner.mu.RUnlock()
+// Once f returns false, f will no longer be called.
+func (a *AddressableEndpointState) ForEachEndpoint(f func(AddressEndpoint) bool) {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
- for _, ep := range m.inner.mu.endpoints {
+ for _, ep := range a.mu.endpoints {
if !f(ep) {
return
}
@@ -119,21 +82,15 @@ func (m ReadOnlyAddressableEndpointState) ForEach(f func(AddressEndpoint) bool)
}
// ForEachPrimaryEndpoint calls f for each primary address.
-//
-// If f returns false, f is no longer be called.
-func (m ReadOnlyAddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint)) {
- m.inner.mu.RLock()
- defer m.inner.mu.RUnlock()
- for _, ep := range m.inner.mu.primary {
+func (a *AddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint)) {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
+
+ for _, ep := range a.mu.primary {
f(ep)
}
}
-// ReadOnly returns a readonly reference to a.
-func (a *AddressableEndpointState) ReadOnly() ReadOnlyAddressableEndpointState {
- return ReadOnlyAddressableEndpointState{inner: a}
-}
-
func (a *AddressableEndpointState) releaseAddressState(addrState *addressState) {
a.mu.Lock()
defer a.mu.Unlock()
@@ -335,11 +292,6 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error {
a.mu.Lock()
defer a.mu.Unlock()
-
- if _, ok := a.mu.groups[addr]; ok {
- panic(fmt.Sprintf("group address = %s must be removed with LeaveGroup", addr))
- }
-
return a.removePermanentAddressLocked(addr)
}
@@ -471,8 +423,19 @@ func (a *AddressableEndpointState) acquirePrimaryAddressRLocked(isValid func(*ad
return deprecatedEndpoint
}
-// AcquireAssignedAddress implements AddressableEndpoint.
-func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint {
+// AcquireAssignedAddressOrMatching returns an address endpoint that is
+// considered assigned to the addressable endpoint.
+//
+// If the address is an exact match with an existing address, that address is
+// returned. Otherwise, if f is provided, f is called with each address and
+// the address that f returns true for is returned.
+//
+// If there is no matching address, a temporary address will be returned if
+// allowTemp is true.
+//
+// Regardless how the address was obtained, it will be acquired before it is
+// returned.
+func (a *AddressableEndpointState) AcquireAssignedAddressOrMatching(localAddr tcpip.Address, f func(AddressEndpoint) bool, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint {
a.mu.Lock()
defer a.mu.Unlock()
@@ -488,6 +451,14 @@ func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Addres
return addrState
}
+ if f != nil {
+ for _, addrState := range a.mu.endpoints {
+ if addrState.IsAssigned(allowTemp) && f(addrState) && addrState.IncRef() {
+ return addrState
+ }
+ }
+ }
+
if !allowTemp {
return nil
}
@@ -520,6 +491,11 @@ func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Addres
return ep
}
+// AcquireAssignedAddress implements AddressableEndpoint.
+func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint {
+ return a.AcquireAssignedAddressOrMatching(localAddr, nil, allowTemp, tempPEB)
+}
+
// AcquireOutgoingPrimaryAddress implements AddressableEndpoint.
func (a *AddressableEndpointState) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint {
a.mu.RLock()
@@ -588,72 +564,11 @@ func (a *AddressableEndpointState) PermanentAddresses() []tcpip.AddressWithPrefi
return addrs
}
-// JoinGroup implements GroupAddressableEndpoint.
-func (a *AddressableEndpointState) JoinGroup(group tcpip.Address) (bool, *tcpip.Error) {
- a.mu.Lock()
- defer a.mu.Unlock()
-
- joins, ok := a.mu.groups[group]
- if !ok {
- ep, err := a.addAndAcquireAddressLocked(group.WithPrefix(), NeverPrimaryEndpoint, AddressConfigStatic, false /* deprecated */, true /* permanent */)
- if err != nil {
- return false, err
- }
- // We have no need for the address endpoint.
- a.decAddressRefLocked(ep)
- }
-
- a.mu.groups[group] = joins + 1
- return !ok, nil
-}
-
-// LeaveGroup implements GroupAddressableEndpoint.
-func (a *AddressableEndpointState) LeaveGroup(group tcpip.Address) (bool, *tcpip.Error) {
- a.mu.Lock()
- defer a.mu.Unlock()
-
- joins, ok := a.mu.groups[group]
- if !ok {
- return false, tcpip.ErrBadLocalAddress
- }
-
- if joins == 1 {
- a.removeGroupAddressLocked(group)
- delete(a.mu.groups, group)
- return true, nil
- }
-
- a.mu.groups[group] = joins - 1
- return false, nil
-}
-
-// IsInGroup implements GroupAddressableEndpoint.
-func (a *AddressableEndpointState) IsInGroup(group tcpip.Address) bool {
- a.mu.RLock()
- defer a.mu.RUnlock()
- _, ok := a.mu.groups[group]
- return ok
-}
-
-func (a *AddressableEndpointState) removeGroupAddressLocked(group tcpip.Address) {
- if err := a.removePermanentAddressLocked(group); err != nil {
- // removePermanentEndpointLocked would only return an error if group is
- // not bound to the addressable endpoint, but we know it MUST be assigned
- // since we have group in our map of groups.
- panic(fmt.Sprintf("error removing group address = %s: %s", group, err))
- }
-}
-
// Cleanup forcefully leaves all groups and removes all permanent addresses.
func (a *AddressableEndpointState) Cleanup() {
a.mu.Lock()
defer a.mu.Unlock()
- for group := range a.mu.groups {
- a.removeGroupAddressLocked(group)
- }
- a.mu.groups = make(map[tcpip.Address]uint32)
-
for _, ep := range a.mu.endpoints {
// removePermanentEndpointLocked returns tcpip.ErrBadLocalAddress if ep is
// not a permanent address.
diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go
index 26787d0a3..140f146f6 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state_test.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go
@@ -53,25 +53,9 @@ func TestAddressableEndpointStateCleanup(t *testing.T) {
ep.DecRef()
}
- group := tcpip.Address("\x02")
- if added, err := s.JoinGroup(group); err != nil {
- t.Fatalf("s.JoinGroup(%s): %s", group, err)
- } else if !added {
- t.Fatalf("got s.JoinGroup(%s) = false, want = true", group)
- }
- if !s.IsInGroup(group) {
- t.Fatalf("got s.IsInGroup(%s) = false, want = true", group)
- }
-
s.Cleanup()
- {
- ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint)
- if ep != nil {
- ep.DecRef()
- t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix())
- }
- }
- if s.IsInGroup(group) {
- t.Fatalf("got s.IsInGroup(%s) = true, want = false", group)
+ if ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint); ep != nil {
+ ep.DecRef()
+ t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix())
}
}
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go
index 6dc9e7859..5ec9b3411 100644
--- a/pkg/tcpip/stack/forwarding_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -309,7 +309,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress {
func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
p := fwdTestPacketInfo{
- RemoteLinkAddress: r.RemoteLinkAddress,
+ RemoteLinkAddress: r.RemoteLinkAddress(),
LocalLinkAddress: r.LocalLinkAddress,
Pkt: pkt,
}
@@ -333,20 +333,6 @@ func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuffer
return n, nil
}
-// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- p := fwdTestPacketInfo{
- Pkt: NewPacketBuffer(PacketBufferOptions{Data: vv}),
- }
-
- select {
- case e.C <- p:
- default:
- }
-
- return nil
-}
-
// Wait implements stack.LinkEndpoint.Wait.
func (*fwdTestLinkEndpoint) Wait() {}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 73a01c2dd..31b67b987 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -26,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
@@ -352,7 +353,7 @@ func TestDADDisabled(t *testing.T) {
}
// We should not have sent any NDP NS messages.
- if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != 0 {
+ if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got != 0 {
t.Fatalf("got NeighborSolicit = %d, want = 0", got)
}
}
@@ -465,14 +466,18 @@ func TestDADResolve(t *testing.T) {
if err != tcpip.ErrNoRoute {
t.Errorf("got FindRoute(%d, '', %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute)
}
- r.Release()
+ if r != nil {
+ r.Release()
+ }
}
{
r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false)
if err != tcpip.ErrNoRoute {
t.Errorf("got FindRoute(%d, %s, %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr1, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute)
}
- r.Release()
+ if r != nil {
+ r.Release()
+ }
}
if t.Failed() {
@@ -510,7 +515,9 @@ func TestDADResolve(t *testing.T) {
} else if r.LocalAddress != addr1 {
t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, addr1)
}
- r.Release()
+ if r != nil {
+ r.Release()
+ }
}
if t.Failed() {
@@ -518,7 +525,7 @@ func TestDADResolve(t *testing.T) {
}
// Should not have sent any more NS messages.
- if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != uint64(test.dupAddrDetectTransmits) {
+ if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got != uint64(test.dupAddrDetectTransmits) {
t.Fatalf("got NeighborSolicit = %d, want = %d", got, test.dupAddrDetectTransmits)
}
@@ -533,8 +540,8 @@ func TestDADResolve(t *testing.T) {
// Make sure the right remote link address is used.
snmc := header.SolicitedNodeAddr(addr1)
- if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want {
- t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
+ if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(snmc); got != want {
+ t.Errorf("got remote link address = %s, want = %s", got, want)
}
// Check NDP NS packet.
@@ -563,7 +570,7 @@ func rxNDPSolicit(e *channel.Endpoint, tgt tcpip.Address) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize))
pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns := header.NDPNeighborSolicit(pkt.MessageBody())
ns.SetTargetAddress(tgt)
snmc := header.SolicitedNodeAddr(tgt)
pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{}))
@@ -605,7 +612,7 @@ func TestDADFail(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize)
pkt := header.ICMPv6(hdr.Prepend(naSize))
pkt.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(pkt.NDPPayload())
+ na := header.NDPNeighborAdvert(pkt.MessageBody())
na.SetSolicitedFlag(true)
na.SetOverrideFlag(true)
na.SetTargetAddress(tgt)
@@ -666,7 +673,7 @@ func TestDADFail(t *testing.T) {
// Receive a packet to simulate an address conflict.
test.rxPkt(e, addr1)
- stat := test.getStat(s.Stats().ICMP.V6PacketsReceived)
+ stat := test.getStat(s.Stats().ICMP.V6.PacketsReceived)
if got := stat.Value(); got != 1 {
t.Fatalf("got stat = %d, want = 1", got)
}
@@ -803,7 +810,7 @@ func TestDADStop(t *testing.T) {
}
// Should not have sent more than 1 NS message.
- if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 {
+ if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got > 1 {
t.Errorf("got NeighborSolicit = %d, want <= 1", got)
}
})
@@ -982,7 +989,7 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo
pkt := header.ICMPv6(hdr.Prepend(icmpSize))
pkt.SetType(header.ICMPv6RouterAdvert)
pkt.SetCode(0)
- raPayload := pkt.NDPPayload()
+ raPayload := pkt.MessageBody()
ra := header.NDPRouterAdvert(raPayload)
// Populate the Router Lifetime.
binary.BigEndian.PutUint16(raPayload[2:], rl)
@@ -2162,8 +2169,8 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) {
NDPConfigs: ipv6.NDPConfigurations{
AutoGenTempGlobalAddresses: true,
},
- NDPDisp: &ndpDisp,
- AutoGenIPv6LinkLocal: true,
+ NDPDisp: &ndpDisp,
+ AutoGenLinkLocal: true,
})},
})
@@ -2843,9 +2850,7 @@ func addrForNewConnectionTo(t *testing.T, s *stack.Stack, addr tcpip.FullAddress
t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
}
defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
- }
+ ep.SocketOptions().SetV6Only(true)
if err := ep.Connect(addr); err != nil {
t.Fatalf("ep.Connect(%+v): %s", addr, err)
}
@@ -2879,9 +2884,7 @@ func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullA
t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
}
defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
- }
+ ep.SocketOptions().SetV6Only(true)
if err := ep.Bind(addr); err != nil {
t.Fatalf("ep.Bind(%+v): %s", addr, err)
}
@@ -3250,9 +3253,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
}
defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
- }
+ ep.SocketOptions().SetV6Only(true)
if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute {
t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, tcpip.ErrNoRoute)
@@ -4044,9 +4045,9 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
ndpConfigs.AutoGenAddressConflictRetries = maxRetries
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal,
- NDPConfigs: ndpConfigs,
- NDPDisp: &ndpDisp,
+ AutoGenLinkLocal: addrType.autoGenLinkLocal,
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{
NICNameFromID: func(_ tcpip.NICID, nicName string) string {
return nicName
@@ -4179,9 +4180,9 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal,
- NDPConfigs: addrType.ndpConfigs,
- NDPDisp: &ndpDisp,
+ AutoGenLinkLocal: addrType.autoGenLinkLocal,
+ NDPConfigs: addrType.ndpConfigs,
+ NDPDisp: &ndpDisp,
})},
})
if err := s.CreateNIC(nicID, e); err != nil {
@@ -4708,7 +4709,7 @@ func TestCleanupNDPState(t *testing.T) {
}
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- AutoGenIPv6LinkLocal: true,
+ AutoGenLinkLocal: true,
NDPConfigs: ipv6.NDPConfigurations{
HandleRAs: true,
DiscoverDefaultRouters: true,
@@ -5174,113 +5175,99 @@ func TestRouterSolicitation(t *testing.T) {
},
}
- // This Run will not return until the parallel tests finish.
- //
- // We need this because we need to do some teardown work after the
- // parallel tests complete.
- //
- // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
- // more details.
- t.Run("group", func(t *testing.T) {
- for _, test := range tests {
- test := test
-
- t.Run(test.name, func(t *testing.T) {
- t.Parallel()
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
+ e := channelLinkWithHeaderLength{
+ Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr),
+ headerLength: test.linkHeaderLen,
+ }
+ e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ waitForPkt := func(timeout time.Duration) {
+ t.Helper()
- e := channelLinkWithHeaderLength{
- Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr),
- headerLength: test.linkHeaderLen,
+ clock.Advance(timeout)
+ p, ok := e.Read()
+ if !ok {
+ t.Fatal("expected router solicitation packet")
}
- e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- waitForPkt := func(timeout time.Duration) {
- t.Helper()
- ctx, cancel := context.WithTimeout(context.Background(), timeout)
- defer cancel()
- p, ok := e.ReadContext(ctx)
- if !ok {
- t.Fatal("timed out waiting for packet")
- return
- }
- if p.Proto != header.IPv6ProtocolNumber {
- t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
- }
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
+ }
- // Make sure the right remote link address is used.
- if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want {
- t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
- }
+ // Make sure the right remote link address is used.
+ if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); got != want {
+ t.Errorf("got remote link address = %s, want = %s", got, want)
+ }
- checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
- checker.SrcAddr(test.expectedSrcAddr),
- checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
- checker.TTL(header.NDPHopLimit),
- checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)),
- )
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(test.expectedSrcAddr),
+ checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)),
+ )
- if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want {
- t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want)
- }
- }
- waitForNothing := func(timeout time.Duration) {
- t.Helper()
- ctx, cancel := context.WithTimeout(context.Background(), timeout)
- defer cancel()
- if _, ok := e.ReadContext(ctx); ok {
- t.Fatal("unexpectedly got a packet")
- }
- }
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- MaxRtrSolicitations: test.maxRtrSolicit,
- RtrSolicitationInterval: test.rtrSolicitInt,
- MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
- },
- })},
- })
- if err := s.CreateNIC(nicID, &e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want {
+ t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want)
}
+ }
+ waitForNothing := func(timeout time.Duration) {
+ t.Helper()
- if addr := test.nicAddr; addr != "" {
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err)
- }
+ clock.Advance(timeout)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("unexpectedly got a packet = %#v", p)
}
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ MaxRtrSolicitations: test.maxRtrSolicit,
+ RtrSolicitationInterval: test.rtrSolicitInt,
+ MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
+ },
+ })},
+ Clock: clock,
+ })
+ if err := s.CreateNIC(nicID, &e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
- // Make sure each RS is sent at the right time.
- remaining := test.maxRtrSolicit
- if remaining > 0 {
- waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncPositiveEventTimeout)
- remaining--
+ if addr := test.nicAddr; addr != "" {
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err)
}
+ }
- for ; remaining > 0; remaining-- {
- if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout {
- waitForNothing(test.effectiveRtrSolicitInt - defaultAsyncNegativeEventTimeout)
- waitForPkt(defaultAsyncPositiveEventTimeout)
- } else {
- waitForPkt(test.effectiveRtrSolicitInt + defaultAsyncPositiveEventTimeout)
- }
- }
+ // Make sure each RS is sent at the right time.
+ remaining := test.maxRtrSolicit
+ if remaining > 0 {
+ waitForPkt(test.effectiveMaxRtrSolicitDelay)
+ remaining--
+ }
- // Make sure no more RS.
- if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay {
- waitForNothing(test.effectiveRtrSolicitInt + defaultAsyncNegativeEventTimeout)
+ for ; remaining > 0; remaining-- {
+ if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout {
+ waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond)
+ waitForPkt(time.Nanosecond)
} else {
- waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultAsyncNegativeEventTimeout)
+ waitForPkt(test.effectiveRtrSolicitInt)
}
+ }
- // Make sure the counter got properly
- // incremented.
- if got, want := s.Stats().ICMP.V6PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want {
- t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want)
- }
- })
- }
- })
+ // Make sure no more RS.
+ if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay {
+ waitForNothing(test.effectiveRtrSolicitInt)
+ } else {
+ waitForNothing(test.effectiveMaxRtrSolicitDelay)
+ }
+
+ if got, want := s.Stats().ICMP.V6.PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want {
+ t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want)
+ }
+ })
+ }
}
func TestStopStartSolicitingRouters(t *testing.T) {
diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go
index 177bf5516..317f6871d 100644
--- a/pkg/tcpip/stack/neighbor_cache.go
+++ b/pkg/tcpip/stack/neighbor_cache.go
@@ -24,9 +24,16 @@ import (
const neighborCacheSize = 512 // max entries per interface
+// NeighborStats holds metrics for the neighbor table.
+type NeighborStats struct {
+ // FailedEntryLookups counts the number of lookups performed on an entry in
+ // Failed state.
+ FailedEntryLookups *tcpip.StatCounter
+}
+
// neighborCache maps IP addresses to link addresses. It uses the Least
// Recently Used (LRU) eviction strategy to implement a bounded cache for
-// dynmically acquired entries. It contains the state machine and configuration
+// dynamically acquired entries. It contains the state machine and configuration
// for running Neighbor Unreachability Detection (NUD).
//
// There are two types of entries in the neighbor cache:
@@ -175,14 +182,15 @@ func (n *neighborCache) removeWaker(addr tcpip.Address, waker *sleep.Waker) {
// entries returns all entries in the neighbor cache.
func (n *neighborCache) entries() []NeighborEntry {
- entries := make([]NeighborEntry, 0, len(n.cache))
n.mu.RLock()
+ defer n.mu.RUnlock()
+
+ entries := make([]NeighborEntry, 0, len(n.cache))
for _, entry := range n.cache {
entry.mu.RLock()
entries = append(entries, entry.neigh)
entry.mu.RUnlock()
}
- n.mu.RUnlock()
return entries
}
@@ -226,6 +234,8 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd
}
// removeEntryLocked removes the specified entry from the neighbor cache.
+//
+// Prerequisite: n.mu and entry.mu MUST be locked.
func (n *neighborCache) removeEntryLocked(entry *neighborEntry) {
if entry.neigh.State != Static {
n.dynamic.lru.Remove(entry)
diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go
index ed33418f3..732a299f7 100644
--- a/pkg/tcpip/stack/neighbor_cache_test.go
+++ b/pkg/tcpip/stack/neighbor_cache_test.go
@@ -80,17 +80,20 @@ func entryDiffOptsWithSort() []cmp.Option {
func newTestNeighborCache(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *neighborCache {
config.resetInvalidFields()
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
- return &neighborCache{
+ neigh := &neighborCache{
nic: &NIC{
stack: &Stack{
clock: clock,
nudDisp: nudDisp,
},
- id: 1,
+ id: 1,
+ stats: makeNICStats(),
},
state: NewNUDState(config, rng),
cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize),
}
+ neigh.nic.neigh = neigh
+ return neigh
}
// testEntryStore contains a set of IP to NeighborEntry mappings.
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
index 493e48031..32399b4f5 100644
--- a/pkg/tcpip/stack/neighbor_entry.go
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -258,7 +258,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
case Failed:
e.notifyWakersLocked()
- e.job = e.nic.stack.newJob(&e.mu, func() {
+ e.job = e.nic.stack.newJob(&doubleLock{first: &e.nic.neigh.mu, second: &e.mu}, func() {
e.nic.neigh.removeEntryLocked(e)
})
e.job.Schedule(config.UnreachableTime)
@@ -347,9 +347,10 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
e.setStateLocked(Delay)
e.dispatchChangeEventLocked()
- case Incomplete, Reachable, Delay, Probe, Static, Failed:
+ case Incomplete, Reachable, Delay, Probe, Static:
// Do nothing
-
+ case Failed:
+ e.nic.stats.Neighbor.FailedEntryLookups.Increment()
default:
panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State))
}
@@ -511,3 +512,23 @@ func (e *neighborEntry) handleUpperLevelConfirmationLocked() {
panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State))
}
}
+
+// doubleLock combines two locks into one while maintaining lock ordering.
+//
+// TODO(gvisor.dev/issue/4796): Remove this once subsequent traffic to a Failed
+// neighbor is allowed.
+type doubleLock struct {
+ first, second sync.Locker
+}
+
+// Lock locks both locks in order: first then second.
+func (l *doubleLock) Lock() {
+ l.first.Lock()
+ l.second.Lock()
+}
+
+// Unlock unlocks both locks in reverse order: second then first.
+func (l *doubleLock) Unlock() {
+ l.second.Unlock()
+ l.first.Unlock()
+}
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
index c2b763325..c497d3932 100644
--- a/pkg/tcpip/stack/neighbor_entry_test.go
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -89,7 +89,7 @@ func eventDiffOptsWithSort() []cmp.Option {
// | Stale | Reachable | Solicited confirmation w/o address | Notify wakers | Changed |
// | Stale | Stale | Override confirmation | Update LinkAddr | Changed |
// | Stale | Stale | Probe w/ different address | Update LinkAddr | Changed |
-// | Stale | Delay | Packet sent | | Changed |
+// | Stale | Delay | Packet queued | | Changed |
// | Delay | Reachable | Upper-layer confirmation | | Changed |
// | Delay | Reachable | Solicited override confirmation | Update LinkAddr | Changed |
// | Delay | Reachable | Solicited confirmation w/o address | Notify wakers | Changed |
@@ -101,6 +101,7 @@ func eventDiffOptsWithSort() []cmp.Option {
// | Probe | Stale | Probe or confirmation w/ different address | | Changed |
// | Probe | Probe | Retransmit timer expired | Send probe | Changed |
// | Probe | Failed | Max probes sent without reply | Notify wakers | Removed |
+// | Failed | Failed | Packet queued | | |
// | Failed | | Unreachability timer expired | Delete entry | |
type testEntryEventType uint8
@@ -228,6 +229,7 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e
clock: clock,
nudDisp: &disp,
},
+ stats: makeNICStats(),
}
nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{
header.IPv6ProtocolNumber: (&testIPv6Protocol{}).NewEndpoint(&nic, nil, nil, nil),
@@ -3433,6 +3435,146 @@ func TestEntryProbeToFailed(t *testing.T) {
nudDisp.mu.Unlock()
}
+func TestEntryFailedToFailed(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ c.MaxMulticastProbes = 3
+ c.MaxUnicastProbes = 3
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ // Verify the cache contains the entry.
+ if _, ok := e.nic.neigh.cache[entryTestAddr1]; !ok {
+ t.Errorf("expected entry %q to exist in the neighbor cache", entryTestAddr1)
+ }
+
+ // TODO(gvisor.dev/issue/4872): Use helper functions to start entry tests in
+ // their expected state.
+ e.mu.Lock()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes)
+ clock.Advance(waitFor)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ },
+ {
+ EventType: entryTestRemoved,
+ NICID: entryTestNICID,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ failedLookups := e.nic.stats.Neighbor.FailedEntryLookups
+ if got := failedLookups.Value(); got != 0 {
+ t.Errorf("got Neighbor.FailedEntryLookups = %d, want = 0", got)
+ }
+
+ e.mu.Lock()
+ // Verify queuing a packet to the entry immediately fails.
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ state := e.neigh.State
+ e.mu.Unlock()
+ if state != Failed {
+ t.Errorf("got e.neigh.State = %q, want = %q", state, Failed)
+ }
+
+ if got := failedLookups.Value(); got != 1 {
+ t.Errorf("got Neighbor.FailedEntryLookups = %d, want = 1", got)
+ }
+}
+
func TestEntryFailedGetsDeleted(t *testing.T) {
c := DefaultNUDConfigurations()
c.MaxMulticastProbes = 3
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 3e6ceff28..5887aa1ed 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -60,12 +60,14 @@ type NIC struct {
}
}
-// NICStats includes transmitted and received stats.
+// NICStats hold statistics for a NIC.
type NICStats struct {
Tx DirectionStats
Rx DirectionStats
DisabledRx DirectionStats
+
+ Neighbor NeighborStats
}
func makeNICStats() NICStats {
@@ -265,7 +267,7 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb
if ch, err := r.Resolve(nil); err != nil {
if err == tcpip.ErrWouldBlock {
r := r.Clone()
- n.stack.linkResQueue.enqueue(ch, &r, protocol, pkt)
+ n.stack.linkResQueue.enqueue(ch, r, protocol, pkt)
return nil
}
return err
@@ -277,9 +279,9 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb
// WritePacketToRemote implements NetworkInterface.
func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
r := Route{
- NetProto: protocol,
- RemoteLinkAddress: remoteLinkAddr,
+ NetProto: protocol,
}
+ r.ResolveWith(remoteLinkAddr)
return n.writePacket(&r, gso, protocol, pkt)
}
@@ -561,8 +563,7 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address
return tcpip.ErrNotSupported
}
- _, err := gep.JoinGroup(addr)
- return err
+ return gep.JoinGroup(addr)
}
// leaveGroup decrements the count for the given multicast address, and when it
@@ -578,11 +579,7 @@ func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Addres
return tcpip.ErrNotSupported
}
- if _, err := gep.LeaveGroup(addr); err != nil {
- return err
- }
-
- return nil
+ return gep.LeaveGroup(addr)
}
// isInGroup returns true if n has joined the multicast group addr.
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 2cb13c6fa..b334e27c4 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -259,15 +259,6 @@ const (
PacketLoop
)
-// NetOptions is an interface that allows us to pass network protocol specific
-// options through the Stack layer code.
-type NetOptions interface {
- // SizeWithPadding returns the amount of memory that must be allocated to
- // hold the options given that the value must be rounded up to the next
- // multiple of 4 bytes.
- SizeWithPadding() int
-}
-
// NetworkHeaderParams are the header parameters given as input by the
// transport endpoint to the network.
type NetworkHeaderParams struct {
@@ -279,10 +270,6 @@ type NetworkHeaderParams struct {
// TOS refers to TypeOfService or TrafficClass field of the IP-header.
TOS uint8
-
- // Options is a set of options to add to a network header (or nil).
- // It will be protocol specific opaque information from higher layers.
- Options NetOptions
}
// GroupAddressableEndpoint is an endpoint that supports group addressing.
@@ -291,14 +278,10 @@ type NetworkHeaderParams struct {
// endpoints may associate themselves with the same identifier (group address).
type GroupAddressableEndpoint interface {
// JoinGroup joins the specified group.
- //
- // Returns true if the group was newly joined.
- JoinGroup(group tcpip.Address) (bool, *tcpip.Error)
+ JoinGroup(group tcpip.Address) *tcpip.Error
// LeaveGroup attempts to leave the specified group.
- //
- // Returns tcpip.ErrBadLocalAddress if the endpoint has not joined the group.
- LeaveGroup(group tcpip.Address) (bool, *tcpip.Error)
+ LeaveGroup(group tcpip.Address) *tcpip.Error
// IsInGroup returns true if the endpoint is a member of the specified group.
IsInGroup(group tcpip.Address) bool
@@ -739,10 +722,6 @@ type LinkEndpoint interface {
// endpoint.
Capabilities() LinkEndpointCapabilities
- // WriteRawPacket writes a packet directly to the link. The packet
- // should already have an ethernet header. It takes ownership of vv.
- WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error
-
// Attach attaches the data link layer endpoint to the network-layer
// dispatcher of the stack.
//
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 53cb6694f..de5fe6ffe 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -18,19 +18,22 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
// Route represents a route through the networking stack to a given destination.
+//
+// It is safe to call Route's methods from multiple goroutines.
+//
+// The exported fields are immutable.
+//
+// TODO(gvisor.dev/issue/4902): Unexpose immutable fields.
type Route struct {
// RemoteAddress is the final destination of the route.
RemoteAddress tcpip.Address
- // RemoteLinkAddress is the link-layer (MAC) address of the
- // final destination of the route.
- RemoteLinkAddress tcpip.LinkAddress
-
// LocalAddress is the local address where the route starts.
LocalAddress tcpip.Address
@@ -52,8 +55,16 @@ type Route struct {
// address's assigned status without the NIC.
localAddressNIC *NIC
- // localAddressEndpoint is the local address this route is associated with.
- localAddressEndpoint AssignableAddressEndpoint
+ mu struct {
+ sync.RWMutex
+
+ // localAddressEndpoint is the local address this route is associated with.
+ localAddressEndpoint AssignableAddressEndpoint
+
+ // remoteLinkAddress is the link-layer (MAC) address of the next hop in the
+ // route.
+ remoteLinkAddress tcpip.LinkAddress
+ }
// outgoingNIC is the interface this route uses to write packets.
outgoingNIC *NIC
@@ -71,22 +82,24 @@ type Route struct {
// ownership of the provided local address.
//
// Returns an empty route if validation fails.
-func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) Route {
- addrWithPrefix := addressEndpoint.AddressWithPrefix()
+func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, localAddr, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) *Route {
+ if len(localAddr) == 0 {
+ localAddr = addressEndpoint.AddressWithPrefix().Address
+ }
- if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(addrWithPrefix.Address) {
+ if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(localAddr) {
addressEndpoint.DecRef()
- return Route{}
+ return nil
}
// If no remote address is provided, use the local address.
if len(remoteAddr) == 0 {
- remoteAddr = addrWithPrefix.Address
+ remoteAddr = localAddr
}
r := makeRoute(
netProto,
- addrWithPrefix.Address,
+ localAddr,
remoteAddr,
outgoingNIC,
localAddressNIC,
@@ -99,8 +112,8 @@ func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndp
// broadcast it.
if len(gateway) > 0 {
r.NextHop = gateway
- } else if subnet := addrWithPrefix.Subnet(); subnet.IsBroadcast(remoteAddr) {
- r.RemoteLinkAddress = header.EthernetBroadcastAddress
+ } else if subnet := addressEndpoint.Subnet(); subnet.IsBroadcast(remoteAddr) {
+ r.ResolveWith(header.EthernetBroadcastAddress)
}
return r
@@ -108,11 +121,15 @@ func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndp
// makeRoute initializes a new route. It takes ownership of the provided
// AssignableAddressEndpoint.
-func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route {
+func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) *Route {
if localAddressNIC.stack != outgoingNIC.stack {
panic(fmt.Sprintf("cannot create a route with NICs from different stacks"))
}
+ if len(localAddr) == 0 {
+ localAddr = localAddressEndpoint.AddressWithPrefix().Address
+ }
+
loop := PacketOut
// TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the
@@ -133,18 +150,21 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip
return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop)
}
-func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) Route {
- r := Route{
- NetProto: netProto,
- LocalAddress: localAddr,
- LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(),
- RemoteAddress: remoteAddr,
- localAddressNIC: localAddressNIC,
- localAddressEndpoint: localAddressEndpoint,
- outgoingNIC: outgoingNIC,
- Loop: loop,
+func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) *Route {
+ r := &Route{
+ NetProto: netProto,
+ LocalAddress: localAddr,
+ LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(),
+ RemoteAddress: remoteAddr,
+ localAddressNIC: localAddressNIC,
+ outgoingNIC: outgoingNIC,
+ Loop: loop,
}
+ r.mu.Lock()
+ r.mu.localAddressEndpoint = localAddressEndpoint
+ r.mu.Unlock()
+
if r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 {
if linkRes, ok := r.outgoingNIC.stack.linkAddrResolvers[r.NetProto]; ok {
r.linkRes = linkRes
@@ -159,7 +179,7 @@ func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr
// provided AssignableAddressEndpoint.
//
// A local route is a route to a destination that is local to the stack.
-func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint) Route {
+func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint) *Route {
loop := PacketLoop
// TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the
// link endpoint level. We can remove this check once loopback interfaces
@@ -170,6 +190,14 @@ func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr
return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop)
}
+// RemoteLinkAddress returns the link-layer (MAC) address of the next hop in
+// the route.
+func (r *Route) RemoteLinkAddress() tcpip.LinkAddress {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ return r.mu.remoteLinkAddress
+}
+
// NICID returns the id of the NIC from which this route originates.
func (r *Route) NICID() tcpip.NICID {
return r.outgoingNIC.ID()
@@ -231,7 +259,9 @@ func (r *Route) GSOMaxSize() uint32 {
// ResolveWith immediately resolves a route with the specified remote link
// address.
func (r *Route) ResolveWith(addr tcpip.LinkAddress) {
- r.RemoteLinkAddress = addr
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.mu.remoteLinkAddress = addr
}
// Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in
@@ -244,7 +274,10 @@ func (r *Route) ResolveWith(addr tcpip.LinkAddress) {
//
// The NIC r uses must not be locked.
func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
- if !r.IsResolutionRequired() {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if !r.isResolutionRequiredRLocked() {
// Nothing to do if there is no cache (which does the resolution on cache miss) or
// link address is already known.
return nil, nil
@@ -254,7 +287,7 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
if nextAddr == "" {
// Local link address is already known.
if r.RemoteAddress == r.LocalAddress {
- r.RemoteLinkAddress = r.LocalLinkAddress
+ r.mu.remoteLinkAddress = r.LocalLinkAddress
return nil, nil
}
nextAddr = r.RemoteAddress
@@ -272,7 +305,7 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
if err != nil {
return ch, err
}
- r.RemoteLinkAddress = entry.LinkAddr
+ r.mu.remoteLinkAddress = entry.LinkAddr
return nil, nil
}
@@ -280,7 +313,7 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
if err != nil {
return ch, err
}
- r.RemoteLinkAddress = linkAddr
+ r.mu.remoteLinkAddress = linkAddr
return nil, nil
}
@@ -309,7 +342,13 @@ func (r *Route) local() bool {
//
// The NICs the route is associated with must not be locked.
func (r *Route) IsResolutionRequired() bool {
- if !r.isValidForOutgoing() || r.RemoteLinkAddress != "" || r.local() {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ return r.isResolutionRequiredRLocked()
+}
+
+func (r *Route) isResolutionRequiredRLocked() bool {
+ if !r.isValidForOutgoingRLocked() || r.mu.remoteLinkAddress != "" || r.local() {
return false
}
@@ -317,11 +356,18 @@ func (r *Route) IsResolutionRequired() bool {
}
func (r *Route) isValidForOutgoing() bool {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ return r.isValidForOutgoingRLocked()
+}
+
+func (r *Route) isValidForOutgoingRLocked() bool {
if !r.outgoingNIC.Enabled() {
return false
}
- if !r.localAddressNIC.isValidForOutgoing(r.localAddressEndpoint) {
+ localAddressEndpoint := r.mu.localAddressEndpoint
+ if localAddressEndpoint == nil || !r.localAddressNIC.isValidForOutgoing(localAddressEndpoint) {
return false
}
@@ -375,37 +421,44 @@ func (r *Route) MTU() uint32 {
// Release frees all resources associated with the route.
func (r *Route) Release() {
- if r.localAddressEndpoint != nil {
- r.localAddressEndpoint.DecRef()
- r.localAddressEndpoint = nil
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if r.mu.localAddressEndpoint != nil {
+ r.mu.localAddressEndpoint.DecRef()
+ r.mu.localAddressEndpoint = nil
}
}
// Clone clones the route.
-func (r *Route) Clone() Route {
- if r.localAddressEndpoint != nil {
- if !r.localAddressEndpoint.IncRef() {
- panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", r.LocalAddress))
- }
+func (r *Route) Clone() *Route {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+
+ newRoute := &Route{
+ RemoteAddress: r.RemoteAddress,
+ LocalAddress: r.LocalAddress,
+ LocalLinkAddress: r.LocalLinkAddress,
+ NextHop: r.NextHop,
+ NetProto: r.NetProto,
+ Loop: r.Loop,
+ localAddressNIC: r.localAddressNIC,
+ outgoingNIC: r.outgoingNIC,
+ linkCache: r.linkCache,
+ linkRes: r.linkRes,
}
- return *r
-}
-// MakeLoopedRoute duplicates the given route with special handling for routes
-// used for sending multicast or broadcast packets. In those cases the
-// multicast/broadcast address is the remote address when sending out, but for
-// incoming (looped) packets it becomes the local address. Similarly, the local
-// interface address that was the local address going out becomes the remote
-// address coming in. This is different to unicast routes where local and
-// remote addresses remain the same as they identify location (local vs remote)
-// not direction (source vs destination).
-func (r *Route) MakeLoopedRoute() Route {
- l := r.Clone()
- if r.RemoteAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(r.RemoteAddress) || header.IsV6MulticastAddress(r.RemoteAddress) {
- l.RemoteAddress, l.LocalAddress = l.LocalAddress, l.RemoteAddress
- l.RemoteLinkAddress = l.LocalLinkAddress
+ newRoute.mu.Lock()
+ defer newRoute.mu.Unlock()
+ newRoute.mu.localAddressEndpoint = r.mu.localAddressEndpoint
+ if newRoute.mu.localAddressEndpoint != nil {
+ if !newRoute.mu.localAddressEndpoint.IncRef() {
+ panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", newRoute.LocalAddress))
+ }
}
- return l
+ newRoute.mu.remoteLinkAddress = r.mu.remoteLinkAddress
+
+ return newRoute
}
// Stack returns the instance of the Stack that owns this route.
@@ -418,7 +471,14 @@ func (r *Route) isV4Broadcast(addr tcpip.Address) bool {
return true
}
- subnet := r.localAddressEndpoint.Subnet()
+ r.mu.RLock()
+ localAddressEndpoint := r.mu.localAddressEndpoint
+ r.mu.RUnlock()
+ if localAddressEndpoint == nil {
+ return false
+ }
+
+ subnet := localAddressEndpoint.Subnet()
return subnet.IsBroadcast(addr)
}
@@ -428,27 +488,3 @@ func (r *Route) IsOutboundBroadcast() bool {
// Only IPv4 has a notion of broadcast.
return r.isV4Broadcast(r.RemoteAddress)
}
-
-// isInboundBroadcast returns true if the route is for an inbound broadcast
-// packet.
-func (r *Route) isInboundBroadcast() bool {
- // Only IPv4 has a notion of broadcast.
- return r.isV4Broadcast(r.LocalAddress)
-}
-
-// ReverseRoute returns new route with given source and destination address.
-func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route {
- return Route{
- NetProto: r.NetProto,
- LocalAddress: dst,
- LocalLinkAddress: r.RemoteLinkAddress,
- RemoteAddress: src,
- RemoteLinkAddress: r.LocalLinkAddress,
- Loop: r.Loop,
- localAddressNIC: r.localAddressNIC,
- localAddressEndpoint: r.localAddressEndpoint,
- outgoingNIC: r.outgoingNIC,
- linkCache: r.linkCache,
- linkRes: r.linkRes,
- }
-}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index e0025e0a9..dc4f5b3e7 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -1118,6 +1118,16 @@ func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber,
return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint)
}
+// AddAddressWithPrefix is the same as AddAddress, but allows you to specify
+// the address prefix.
+func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.AddressWithPrefix) *tcpip.Error {
+ ap := tcpip.ProtocolAddress{
+ Protocol: protocol,
+ AddressWithPrefix: addr,
+ }
+ return s.AddProtocolAddressWithOptions(id, ap, CanBePrimaryEndpoint)
+}
+
// AddProtocolAddress adds a new network-layer protocol address to the
// specified NIC.
func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress) *tcpip.Error {
@@ -1208,10 +1218,10 @@ func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netP
// from the specified NIC.
//
// Precondition: s.mu must be read locked.
-func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) {
+func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) *Route {
localAddressEndpoint := localAddressNIC.getAddressOrCreateTempInner(netProto, localAddr, false /* createTemp */, NeverPrimaryEndpoint)
if localAddressEndpoint == nil {
- return Route{}, false
+ return nil
}
var outgoingNIC *NIC
@@ -1235,12 +1245,12 @@ func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, re
// route.
if outgoingNIC == nil {
localAddressEndpoint.DecRef()
- return Route{}, false
+ return nil
}
r := makeLocalRoute(
netProto,
- localAddressEndpoint.AddressWithPrefix().Address,
+ localAddr,
remoteAddr,
outgoingNIC,
localAddressNIC,
@@ -1249,10 +1259,10 @@ func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, re
if r.IsOutboundBroadcast() {
r.Release()
- return Route{}, false
+ return nil
}
- return r, true
+ return r
}
// findLocalRouteRLocked returns a local route.
@@ -1261,26 +1271,26 @@ func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, re
// is, a local route is a route where packets never have to leave the stack.
//
// Precondition: s.mu must be read locked.
-func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) {
+func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) *Route {
if len(localAddr) == 0 {
localAddr = remoteAddr
}
if localAddressNICID == 0 {
for _, localAddressNIC := range s.nics {
- if r, ok := s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto); ok {
- return r, true
+ if r := s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto); r != nil {
+ return r
}
}
- return Route{}, false
+ return nil
}
if localAddressNIC, ok := s.nics[localAddressNICID]; ok {
return s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto)
}
- return Route{}, false
+ return nil
}
// FindRoute creates a route to the given destination address, leaving through
@@ -1294,7 +1304,7 @@ func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr,
// If no local address is provided, the stack will select a local address. If no
// remote address is provided, the stack wil use a remote address equal to the
// local address.
-func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (Route, *tcpip.Error) {
+func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (*Route, *tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -1305,7 +1315,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
needRoute := !(isLocalBroadcast || isMulticast || isLinkLocal || isLoopback)
if s.handleLocal && !isMulticast && !isLocalBroadcast {
- if r, ok := s.findLocalRouteRLocked(id, localAddr, remoteAddr, netProto); ok {
+ if r := s.findLocalRouteRLocked(id, localAddr, remoteAddr, netProto); r != nil {
return r, nil
}
}
@@ -1317,7 +1327,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil {
return makeRoute(
netProto,
- addressEndpoint.AddressWithPrefix().Address,
+ localAddr,
remoteAddr,
nic, /* outboundNIC */
nic, /* localAddressNIC*/
@@ -1329,9 +1339,9 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
}
if isLoopback {
- return Route{}, tcpip.ErrBadLocalAddress
+ return nil, tcpip.ErrBadLocalAddress
}
- return Route{}, tcpip.ErrNetworkUnreachable
+ return nil, tcpip.ErrNetworkUnreachable
}
canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal
@@ -1354,8 +1364,8 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
if needRoute {
gateway = route.Gateway
}
- r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop)
- if r == (Route{}) {
+ r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop)
+ if r == nil {
panic(fmt.Sprintf("non-forwarding route validation failed with route table entry = %#v, id = %d, localAddr = %s, remoteAddr = %s", route, id, localAddr, remoteAddr))
}
return r, nil
@@ -1391,13 +1401,13 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
if id != 0 {
if aNIC, ok := s.nics[id]; ok {
if addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto); addressEndpoint != nil {
- if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) {
+ if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != nil {
return r, nil
}
}
}
- return Route{}, tcpip.ErrNoRoute
+ return nil, tcpip.ErrNoRoute
}
if id == 0 {
@@ -1409,7 +1419,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
continue
}
- if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) {
+ if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != nil {
return r, nil
}
}
@@ -1417,12 +1427,12 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
}
if needRoute {
- return Route{}, tcpip.ErrNoRoute
+ return nil, tcpip.ErrNoRoute
}
if header.IsV6LoopbackAddress(remoteAddr) {
- return Route{}, tcpip.ErrBadLocalAddress
+ return nil, tcpip.ErrBadLocalAddress
}
- return Route{}, tcpip.ErrNetworkUnreachable
+ return nil, tcpip.ErrNetworkUnreachable
}
// CheckNetworkProtocol checks if a given network protocol is enabled in the
@@ -1810,49 +1820,20 @@ func (s *Stack) unregisterPacketEndpointLocked(nicID tcpip.NICID, netProto tcpip
nic.unregisterPacketEndpoint(netProto, ep)
}
-// WritePacket writes data directly to the specified NIC. It adds an ethernet
-// header based on the arguments.
-func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error {
+// WritePacketToRemote writes a payload on the specified NIC using the provided
+// network protocol and remote link address.
+func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error {
s.mu.Lock()
nic, ok := s.nics[nicID]
s.mu.Unlock()
if !ok {
return tcpip.ErrUnknownDevice
}
-
- // Add our own fake ethernet header.
- ethFields := header.EthernetFields{
- SrcAddr: nic.LinkEndpoint.LinkAddress(),
- DstAddr: dst,
- Type: netProto,
- }
- fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
- fakeHeader.Encode(&ethFields)
- vv := buffer.View(fakeHeader).ToVectorisedView()
- vv.Append(payload)
-
- if err := nic.LinkEndpoint.WriteRawPacket(vv); err != nil {
- return err
- }
-
- return nil
-}
-
-// WriteRawPacket writes data directly to the specified NIC without adding any
-// headers.
-func (s *Stack) WriteRawPacket(nicID tcpip.NICID, payload buffer.VectorisedView) *tcpip.Error {
- s.mu.Lock()
- nic, ok := s.nics[nicID]
- s.mu.Unlock()
- if !ok {
- return tcpip.ErrUnknownDevice
- }
-
- if err := nic.LinkEndpoint.WriteRawPacket(payload); err != nil {
- return err
- }
-
- return nil
+ pkt := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: int(nic.MaxHeaderLength()),
+ Data: payload,
+ })
+ return nic.WritePacketToRemote(remote, nil, netProto, pkt)
}
// NetworkProtocolInstance returns the protocol instance in the stack for the
@@ -1912,7 +1893,6 @@ func (s *Stack) RemoveTCPProbe() {
// JoinGroup joins the given multicast group on the given NIC.
func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error {
- // TODO: notify network of subscription via igmp protocol.
s.mu.RLock()
defer s.mu.RUnlock()
@@ -2159,3 +2139,43 @@ func (s *Stack) networkProtocolNumbers() []tcpip.NetworkProtocolNumber {
}
return protos
}
+
+func isSubnetBroadcastOnNIC(nic *NIC, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
+ addressEndpoint := nic.getAddressOrCreateTempInner(protocol, addr, false /* createTemp */, NeverPrimaryEndpoint)
+ if addressEndpoint == nil {
+ return false
+ }
+
+ subnet := addressEndpoint.Subnet()
+ addressEndpoint.DecRef()
+ return subnet.IsBroadcast(addr)
+}
+
+// IsSubnetBroadcast returns true if the provided address is a subnet-local
+// broadcast address on the specified NIC and protocol.
+//
+// Returns false if the NIC is unknown or if the protocol is unknown or does
+// not support addressing.
+//
+// If the NIC is not specified, the stack will check all NICs.
+func (s *Stack) IsSubnetBroadcast(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nicID != 0 {
+ nic, ok := s.nics[nicID]
+ if !ok {
+ return false
+ }
+
+ return isSubnetBroadcastOnNIC(nic, protocol, addr)
+ }
+
+ for _, nic := range s.nics {
+ if isSubnetBroadcastOnNIC(nic, protocol, addr) {
+ return true
+ }
+ }
+
+ return false
+}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 61db3164b..457990945 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -27,7 +27,6 @@ import (
"time"
"github.com/google/go-cmp/cmp"
- "github.com/google/go-cmp/cmp/cmpopts"
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -407,7 +406,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro
return send(r, payload)
}
-func send(r stack.Route, payload buffer.View) *tcpip.Error {
+func send(r *stack.Route, payload buffer.View) *tcpip.Error {
return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()),
Data: payload.ToVectorisedView(),
@@ -425,7 +424,7 @@ func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.En
}
}
-func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View) {
+func testSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View) {
t.Helper()
ep.Drain()
if err := send(r, payload); err != nil {
@@ -436,7 +435,7 @@ func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.
}
}
-func testFailingSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
+func testFailingSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
t.Helper()
if gotErr := send(r, payload); gotErr != wantErr {
t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr)
@@ -1563,15 +1562,15 @@ func TestSpoofingNoAddress(t *testing.T) {
// testSendTo(t, s, remoteAddr, ep, nil)
}
-func verifyRoute(gotRoute, wantRoute stack.Route) error {
+func verifyRoute(gotRoute, wantRoute *stack.Route) error {
if gotRoute.LocalAddress != wantRoute.LocalAddress {
return fmt.Errorf("bad local address: got %s, want = %s", gotRoute.LocalAddress, wantRoute.LocalAddress)
}
if gotRoute.RemoteAddress != wantRoute.RemoteAddress {
return fmt.Errorf("bad remote address: got %s, want = %s", gotRoute.RemoteAddress, wantRoute.RemoteAddress)
}
- if gotRoute.RemoteLinkAddress != wantRoute.RemoteLinkAddress {
- return fmt.Errorf("bad remote link address: got %s, want = %s", gotRoute.RemoteLinkAddress, wantRoute.RemoteLinkAddress)
+ if got, want := gotRoute.RemoteLinkAddress(), wantRoute.RemoteLinkAddress(); got != want {
+ return fmt.Errorf("bad remote link address: got %s, want = %s", got, want)
}
if gotRoute.NextHop != wantRoute.NextHop {
return fmt.Errorf("bad next-hop address: got %s, want = %s", gotRoute.NextHop, wantRoute.NextHop)
@@ -1603,7 +1602,7 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) {
if err != nil {
t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
}
- if err := verifyRoute(r, stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ if err := verifyRoute(r, &stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil {
t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
}
@@ -1657,7 +1656,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
if err != nil {
t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
}
- if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ if err := verifyRoute(r, &stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
}
@@ -1667,7 +1666,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
if err != nil {
t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err)
}
- if err := verifyRoute(r, stack.Route{LocalAddress: nic2Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ if err := verifyRoute(r, &stack.Route{LocalAddress: nic2Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err)
}
@@ -1683,7 +1682,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
if err != nil {
t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err)
}
- if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ if err := verifyRoute(r, &stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err)
}
}
@@ -2407,9 +2406,9 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
}
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- AutoGenIPv6LinkLocal: test.autoGen,
- NDPDisp: &ndpDisp,
- OpaqueIIDOpts: test.iidOpts,
+ AutoGenLinkLocal: test.autoGen,
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: test.iidOpts,
})},
}
@@ -2502,8 +2501,8 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- AutoGenIPv6LinkLocal: true,
- OpaqueIIDOpts: test.opaqueIIDOpts,
+ AutoGenLinkLocal: true,
+ OpaqueIIDOpts: test.opaqueIIDOpts,
})},
}
@@ -2536,9 +2535,9 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) {
ndpConfigs := ipv6.DefaultNDPConfigurations()
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ndpConfigs,
- AutoGenIPv6LinkLocal: true,
- NDPDisp: &ndpDisp,
+ NDPConfigs: ndpConfigs,
+ AutoGenLinkLocal: true,
+ NDPDisp: &ndpDisp,
})},
}
@@ -3351,11 +3350,16 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
remNetSubnetBcast := remNetSubnet.Broadcast()
tests := []struct {
- name string
- nicAddr tcpip.ProtocolAddress
- routes []tcpip.Route
- remoteAddr tcpip.Address
- expectedRoute stack.Route
+ name string
+ nicAddr tcpip.ProtocolAddress
+ routes []tcpip.Route
+ remoteAddr tcpip.Address
+ expectedLocalAddress tcpip.Address
+ expectedRemoteAddress tcpip.Address
+ expectedRemoteLinkAddress tcpip.LinkAddress
+ expectedNextHop tcpip.Address
+ expectedNetProto tcpip.NetworkProtocolNumber
+ expectedLoop stack.PacketLooping
}{
// Broadcast to a locally attached subnet populates the broadcast MAC.
{
@@ -3370,14 +3374,12 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: ipv4SubnetBcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv4Addr.Address,
- RemoteAddress: ipv4SubnetBcast,
- RemoteLinkAddress: header.EthernetBroadcastAddress,
- NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut | stack.PacketLoop,
- },
+ remoteAddr: ipv4SubnetBcast,
+ expectedLocalAddress: ipv4Addr.Address,
+ expectedRemoteAddress: ipv4SubnetBcast,
+ expectedRemoteLinkAddress: header.EthernetBroadcastAddress,
+ expectedNetProto: header.IPv4ProtocolNumber,
+ expectedLoop: stack.PacketOut | stack.PacketLoop,
},
// Broadcast to a locally attached /31 subnet does not populate the
// broadcast MAC.
@@ -3393,13 +3395,11 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: ipv4Subnet31Bcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv4AddrPrefix31.Address,
- RemoteAddress: ipv4Subnet31Bcast,
- NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut,
- },
+ remoteAddr: ipv4Subnet31Bcast,
+ expectedLocalAddress: ipv4AddrPrefix31.Address,
+ expectedRemoteAddress: ipv4Subnet31Bcast,
+ expectedNetProto: header.IPv4ProtocolNumber,
+ expectedLoop: stack.PacketOut,
},
// Broadcast to a locally attached /32 subnet does not populate the
// broadcast MAC.
@@ -3415,13 +3415,11 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: ipv4Subnet32Bcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv4AddrPrefix32.Address,
- RemoteAddress: ipv4Subnet32Bcast,
- NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut,
- },
+ remoteAddr: ipv4Subnet32Bcast,
+ expectedLocalAddress: ipv4AddrPrefix32.Address,
+ expectedRemoteAddress: ipv4Subnet32Bcast,
+ expectedNetProto: header.IPv4ProtocolNumber,
+ expectedLoop: stack.PacketOut,
},
// IPv6 has no notion of a broadcast.
{
@@ -3436,13 +3434,11 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: ipv6SubnetBcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv6Addr.Address,
- RemoteAddress: ipv6SubnetBcast,
- NetProto: header.IPv6ProtocolNumber,
- Loop: stack.PacketOut,
- },
+ remoteAddr: ipv6SubnetBcast,
+ expectedLocalAddress: ipv6Addr.Address,
+ expectedRemoteAddress: ipv6SubnetBcast,
+ expectedNetProto: header.IPv6ProtocolNumber,
+ expectedLoop: stack.PacketOut,
},
// Broadcast to a remote subnet in the route table is send to the next-hop
// gateway.
@@ -3459,14 +3455,12 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: remNetSubnetBcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv4Addr.Address,
- RemoteAddress: remNetSubnetBcast,
- NextHop: ipv4Gateway,
- NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut,
- },
+ remoteAddr: remNetSubnetBcast,
+ expectedLocalAddress: ipv4Addr.Address,
+ expectedRemoteAddress: remNetSubnetBcast,
+ expectedNextHop: ipv4Gateway,
+ expectedNetProto: header.IPv4ProtocolNumber,
+ expectedLoop: stack.PacketOut,
},
// Broadcast to an unknown subnet follows the default route. Note that this
// is essentially just routing an unknown destination IP, because w/o any
@@ -3484,14 +3478,12 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: remNetSubnetBcast,
- expectedRoute: stack.Route{
- LocalAddress: ipv4Addr.Address,
- RemoteAddress: remNetSubnetBcast,
- NextHop: ipv4Gateway,
- NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut,
- },
+ remoteAddr: remNetSubnetBcast,
+ expectedLocalAddress: ipv4Addr.Address,
+ expectedRemoteAddress: remNetSubnetBcast,
+ expectedNextHop: ipv4Gateway,
+ expectedNetProto: header.IPv4ProtocolNumber,
+ expectedLoop: stack.PacketOut,
},
}
@@ -3520,10 +3512,27 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
t.Fatalf("got unexpected address length = %d bytes", l)
}
- if r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, test.remoteAddr, netProto, false /* multicastLoop */); err != nil {
+ r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, test.remoteAddr, netProto, false /* multicastLoop */)
+ if err != nil {
t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, test.remoteAddr, netProto, err)
- } else if diff := cmp.Diff(r, test.expectedRoute, cmpopts.IgnoreUnexported(r)); diff != "" {
- t.Errorf("route mismatch (-want +got):\n%s", diff)
+ }
+ if r.LocalAddress != test.expectedLocalAddress {
+ t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, test.expectedLocalAddress)
+ }
+ if r.RemoteAddress != test.expectedRemoteAddress {
+ t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress, test.expectedRemoteAddress)
+ }
+ if got := r.RemoteLinkAddress(); got != test.expectedRemoteLinkAddress {
+ t.Errorf("got r.RemoteLinkAddress() = %s, want = %s", got, test.expectedRemoteLinkAddress)
+ }
+ if r.NextHop != test.expectedNextHop {
+ t.Errorf("got r.NextHop = %s, want = %s", r.NextHop, test.expectedNextHop)
+ }
+ if r.NetProto != test.expectedNetProto {
+ t.Errorf("got r.NetProto = %d, want = %d", r.NetProto, test.expectedNetProto)
+ }
+ if r.Loop != test.expectedLoop {
+ t.Errorf("got r.Loop = %x, want = %x", r.Loop, test.expectedLoop)
}
})
}
@@ -4091,10 +4100,12 @@ func TestFindRouteWithForwarding(t *testing.T) {
s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}})
r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */)
+ if r != nil {
+ defer r.Release()
+ }
if err != test.findRouteErr {
t.Fatalf("FindRoute(%d, %s, %s, %d, false) = %s, want = %s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, err, test.findRouteErr)
}
- defer r.Release()
if test.findRouteErr != nil {
return
@@ -4152,3 +4163,63 @@ func TestFindRouteWithForwarding(t *testing.T) {
})
}
}
+
+func TestWritePacketToRemote(t *testing.T) {
+ const nicID = 1
+ const MTU = 1280
+ e := channel.New(1, MTU, linkAddr1)
+ s := stack.New(stack.Options{})
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("CreateNIC(%d) = %s", nicID, err)
+ }
+ tests := []struct {
+ name string
+ protocol tcpip.NetworkProtocolNumber
+ payload []byte
+ }{
+ {
+ name: "SuccessIPv4",
+ protocol: header.IPv4ProtocolNumber,
+ payload: []byte{1, 2, 3, 4},
+ },
+ {
+ name: "SuccessIPv6",
+ protocol: header.IPv6ProtocolNumber,
+ payload: []byte{5, 6, 7, 8},
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ if err := s.WritePacketToRemote(nicID, linkAddr2, test.protocol, buffer.View(test.payload).ToVectorisedView()); err != nil {
+ t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s", err)
+ }
+
+ pkt, ok := e.Read()
+ if got, want := ok, true; got != want {
+ t.Fatalf("e.Read() = %t, want %t", got, want)
+ }
+ if got, want := pkt.Proto, test.protocol; got != want {
+ t.Fatalf("pkt.Proto = %d, want %d", got, want)
+ }
+ if got, want := pkt.Route.RemoteLinkAddress(), linkAddr2; got != want {
+ t.Fatalf("pkt.Route.RemoteAddress = %s, want %s", got, want)
+ }
+ if diff := cmp.Diff(pkt.Pkt.Data.ToView(), buffer.View(test.payload)); diff != "" {
+ t.Errorf("pkt.Pkt.Data mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+
+ t.Run("InvalidNICID", func(t *testing.T) {
+ if got, want := s.WritePacketToRemote(234, linkAddr2, header.IPv4ProtocolNumber, buffer.View([]byte{1}).ToVectorisedView()), tcpip.ErrUnknownDevice; got != want {
+ t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s, want = %s", got, want)
+ }
+ pkt, ok := e.Read()
+ if got, want := ok, false; got != want {
+ t.Fatalf("e.Read() = %t, %v; want %t", got, pkt, want)
+ }
+ })
+}
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 41a8e5ad0..2cdb5ca79 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -307,9 +307,7 @@ func TestBindToDeviceDistribution(t *testing.T) {
}(ep)
defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.ReusePortOption, endpoint.reuse); err != nil {
- t.Fatalf("SetSockOptBool(ReusePortOption, %t) on endpoint %d failed: %s", endpoint.reuse, i, err)
- }
+ ep.SocketOptions().SetReusePort(endpoint.reuse)
bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice)
if err := ep.SetSockOpt(&bindToDeviceOption); err != nil {
t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", bindToDeviceOption, bindToDeviceOption, i, err)
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 5b9043d85..d9769e47d 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -38,14 +38,15 @@ const (
// use it.
type fakeTransportEndpoint struct {
stack.TransportEndpointInfo
+ tcpip.DefaultSocketOptionsHandler
proto *fakeTransportProtocol
peerAddr tcpip.Address
- route stack.Route
+ route *stack.Route
uniqueID uint64
// acceptQueue is non-nil iff bound.
- acceptQueue []fakeTransportEndpoint
+ acceptQueue []*fakeTransportEndpoint
// ops is used to set and get socket options.
ops tcpip.SocketOptions
@@ -64,8 +65,11 @@ func (*fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {}
func (f *fakeTransportEndpoint) SocketOptions() *tcpip.SocketOptions {
return &f.ops
}
+
func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint {
- return &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
+ ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
+ ep.ops.InitHandler(ep)
+ return ep
}
func (f *fakeTransportEndpoint) Abort() {
@@ -114,21 +118,11 @@ func (*fakeTransportEndpoint) SetSockOpt(tcpip.SettableSocketOption) *tcpip.Erro
return tcpip.ErrInvalidEndpointState
}
-// SetSockOptBool sets a socket option. Currently not supported.
-func (*fakeTransportEndpoint) SetSockOptBool(tcpip.SockOptBool, bool) *tcpip.Error {
- return tcpip.ErrInvalidEndpointState
-}
-
// SetSockOptInt sets a socket option. Currently not supported.
func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
-// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
-func (*fakeTransportEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- return false, tcpip.ErrUnknownProtocolOption
-}
-
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
return -1, tcpip.ErrUnknownProtocolOption
@@ -189,7 +183,7 @@ func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *wai
if len(f.acceptQueue) == 0 {
return nil, nil, nil
}
- a := &f.acceptQueue[0]
+ a := f.acceptQueue[0]
f.acceptQueue = f.acceptQueue[1:]
return a, nil, nil
}
@@ -206,7 +200,7 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error {
); err != nil {
return err
}
- f.acceptQueue = []fakeTransportEndpoint{}
+ f.acceptQueue = []*fakeTransportEndpoint{}
return nil
}
@@ -232,7 +226,7 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *
}
route.ResolveWith(pkt.SourceLinkAddress())
- f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{
+ ep := &fakeTransportEndpoint{
TransportEndpointInfo: stack.TransportEndpointInfo{
ID: f.ID,
NetProto: f.NetProto,
@@ -240,7 +234,9 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *
proto: f.proto,
peerAddr: route.RemoteAddress,
route: route,
- })
+ }
+ ep.ops.InitHandler(ep)
+ f.acceptQueue = append(f.acceptQueue, ep)
}
func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, *stack.PacketBuffer) {