diff options
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/BUILD | 6 | ||||
-rw-r--r-- | pkg/tcpip/stack/addressable_endpoint_state.go | 153 | ||||
-rw-r--r-- | pkg/tcpip/stack/addressable_endpoint_state_test.go | 22 | ||||
-rw-r--r-- | pkg/tcpip/stack/forwarding_test.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/stack/ndp_test.go | 211 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_cache.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 15 | ||||
-rw-r--r-- | pkg/tcpip/stack/registration.go | 21 | ||||
-rw-r--r-- | pkg/tcpip/stack/route.go | 145 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 42 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 143 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_test.go | 13 |
12 files changed, 369 insertions, 409 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index d09ebe7fa..9cc6074da 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test", "most_shards") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) @@ -112,7 +112,7 @@ go_test( "transport_demuxer_test.go", "transport_test.go", ], - shard_count = 20, + shard_count = most_shards, deps = [ ":stack", "//pkg/rand", @@ -120,6 +120,7 @@ go_test( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", @@ -131,7 +132,6 @@ go_test( "//pkg/tcpip/transport/udp", "//pkg/waiter", "@com_github_google_go_cmp//cmp:go_default_library", - "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], ) diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index adeebfe37..6e4f5fa46 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -21,7 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" ) -var _ GroupAddressableEndpoint = (*AddressableEndpointState)(nil) var _ AddressableEndpoint = (*AddressableEndpointState)(nil) // AddressableEndpointState is an implementation of an AddressableEndpoint. @@ -37,10 +36,6 @@ type AddressableEndpointState struct { endpoints map[tcpip.Address]*addressState primary []*addressState - - // groups holds the mapping between group addresses and the number of times - // they have been joined. - groups map[tcpip.Address]uint32 } } @@ -53,65 +48,33 @@ func (a *AddressableEndpointState) Init(networkEndpoint NetworkEndpoint) { a.mu.Lock() defer a.mu.Unlock() a.mu.endpoints = make(map[tcpip.Address]*addressState) - a.mu.groups = make(map[tcpip.Address]uint32) -} - -// ReadOnlyAddressableEndpointState provides read-only access to an -// AddressableEndpointState. -type ReadOnlyAddressableEndpointState struct { - inner *AddressableEndpointState } -// AddrOrMatching returns an endpoint for the passed address that is consisdered -// bound to the wrapped AddressableEndpointState. -// -// If addr is an exact match with an existing address, that address is returned. -// Otherwise, f is called with each address and the address that f returns true -// for is returned. +// GetAddress returns the AddressEndpoint for the passed address. // -// Returns nil of no address matches. -func (m ReadOnlyAddressableEndpointState) AddrOrMatching(addr tcpip.Address, spoofingOrPrimiscuous bool, f func(AddressEndpoint) bool) AddressEndpoint { - m.inner.mu.RLock() - defer m.inner.mu.RUnlock() - - if ep, ok := m.inner.mu.endpoints[addr]; ok { - if ep.IsAssigned(spoofingOrPrimiscuous) && ep.IncRef() { - return ep - } - } - - for _, ep := range m.inner.mu.endpoints { - if ep.IsAssigned(spoofingOrPrimiscuous) && f(ep) && ep.IncRef() { - return ep - } - } - - return nil -} - -// Lookup returns the AddressEndpoint for the passed address. +// GetAddress does not increment the address's reference count or check if the +// address is considered bound to the endpoint. // -// Returns nil if the passed address is not associated with the -// AddressableEndpointState. -func (m ReadOnlyAddressableEndpointState) Lookup(addr tcpip.Address) AddressEndpoint { - m.inner.mu.RLock() - defer m.inner.mu.RUnlock() +// Returns nil if the passed address is not associated with the endpoint. +func (a *AddressableEndpointState) GetAddress(addr tcpip.Address) AddressEndpoint { + a.mu.RLock() + defer a.mu.RUnlock() - ep, ok := m.inner.mu.endpoints[addr] + ep, ok := a.mu.endpoints[addr] if !ok { return nil } return ep } -// ForEach calls f for each address pair. +// ForEachEndpoint calls f for each address. // -// If f returns false, f is no longer be called. -func (m ReadOnlyAddressableEndpointState) ForEach(f func(AddressEndpoint) bool) { - m.inner.mu.RLock() - defer m.inner.mu.RUnlock() +// Once f returns false, f will no longer be called. +func (a *AddressableEndpointState) ForEachEndpoint(f func(AddressEndpoint) bool) { + a.mu.RLock() + defer a.mu.RUnlock() - for _, ep := range m.inner.mu.endpoints { + for _, ep := range a.mu.endpoints { if !f(ep) { return } @@ -119,21 +82,15 @@ func (m ReadOnlyAddressableEndpointState) ForEach(f func(AddressEndpoint) bool) } // ForEachPrimaryEndpoint calls f for each primary address. -// -// If f returns false, f is no longer be called. -func (m ReadOnlyAddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint)) { - m.inner.mu.RLock() - defer m.inner.mu.RUnlock() - for _, ep := range m.inner.mu.primary { +func (a *AddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint)) { + a.mu.RLock() + defer a.mu.RUnlock() + + for _, ep := range a.mu.primary { f(ep) } } -// ReadOnly returns a readonly reference to a. -func (a *AddressableEndpointState) ReadOnly() ReadOnlyAddressableEndpointState { - return ReadOnlyAddressableEndpointState{inner: a} -} - func (a *AddressableEndpointState) releaseAddressState(addrState *addressState) { a.mu.Lock() defer a.mu.Unlock() @@ -335,11 +292,6 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error { a.mu.Lock() defer a.mu.Unlock() - - if _, ok := a.mu.groups[addr]; ok { - panic(fmt.Sprintf("group address = %s must be removed with LeaveGroup", addr)) - } - return a.removePermanentAddressLocked(addr) } @@ -471,8 +423,19 @@ func (a *AddressableEndpointState) acquirePrimaryAddressRLocked(isValid func(*ad return deprecatedEndpoint } -// AcquireAssignedAddress implements AddressableEndpoint. -func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint { +// AcquireAssignedAddressOrMatching returns an address endpoint that is +// considered assigned to the addressable endpoint. +// +// If the address is an exact match with an existing address, that address is +// returned. Otherwise, if f is provided, f is called with each address and +// the address that f returns true for is returned. +// +// If there is no matching address, a temporary address will be returned if +// allowTemp is true. +// +// Regardless how the address was obtained, it will be acquired before it is +// returned. +func (a *AddressableEndpointState) AcquireAssignedAddressOrMatching(localAddr tcpip.Address, f func(AddressEndpoint) bool, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint { a.mu.Lock() defer a.mu.Unlock() @@ -488,6 +451,14 @@ func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Addres return addrState } + if f != nil { + for _, addrState := range a.mu.endpoints { + if addrState.IsAssigned(allowTemp) && f(addrState) && addrState.IncRef() { + return addrState + } + } + } + if !allowTemp { return nil } @@ -520,6 +491,11 @@ func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Addres return ep } +// AcquireAssignedAddress implements AddressableEndpoint. +func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint { + return a.AcquireAssignedAddressOrMatching(localAddr, nil, allowTemp, tempPEB) +} + // AcquireOutgoingPrimaryAddress implements AddressableEndpoint. func (a *AddressableEndpointState) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint { a.mu.RLock() @@ -588,50 +564,11 @@ func (a *AddressableEndpointState) PermanentAddresses() []tcpip.AddressWithPrefi return addrs } -// JoinGroup implements GroupAddressableEndpoint. -func (a *AddressableEndpointState) JoinGroup(group tcpip.Address) (bool, *tcpip.Error) { - a.mu.Lock() - defer a.mu.Unlock() - - joins, ok := a.mu.groups[group] - a.mu.groups[group] = joins + 1 - return !ok, nil -} - -// LeaveGroup implements GroupAddressableEndpoint. -func (a *AddressableEndpointState) LeaveGroup(group tcpip.Address) (bool, *tcpip.Error) { - a.mu.Lock() - defer a.mu.Unlock() - - joins, ok := a.mu.groups[group] - if !ok { - return false, tcpip.ErrBadLocalAddress - } - - if joins == 1 { - delete(a.mu.groups, group) - return true, nil - } - - a.mu.groups[group] = joins - 1 - return false, nil -} - -// IsInGroup implements GroupAddressableEndpoint. -func (a *AddressableEndpointState) IsInGroup(group tcpip.Address) bool { - a.mu.RLock() - defer a.mu.RUnlock() - _, ok := a.mu.groups[group] - return ok -} - // Cleanup forcefully leaves all groups and removes all permanent addresses. func (a *AddressableEndpointState) Cleanup() { a.mu.Lock() defer a.mu.Unlock() - a.mu.groups = make(map[tcpip.Address]uint32) - for _, ep := range a.mu.endpoints { // removePermanentEndpointLocked returns tcpip.ErrBadLocalAddress if ep is // not a permanent address. diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go index 26787d0a3..140f146f6 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state_test.go +++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go @@ -53,25 +53,9 @@ func TestAddressableEndpointStateCleanup(t *testing.T) { ep.DecRef() } - group := tcpip.Address("\x02") - if added, err := s.JoinGroup(group); err != nil { - t.Fatalf("s.JoinGroup(%s): %s", group, err) - } else if !added { - t.Fatalf("got s.JoinGroup(%s) = false, want = true", group) - } - if !s.IsInGroup(group) { - t.Fatalf("got s.IsInGroup(%s) = false, want = true", group) - } - s.Cleanup() - { - ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint) - if ep != nil { - ep.DecRef() - t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix()) - } - } - if s.IsInGroup(group) { - t.Fatalf("got s.IsInGroup(%s) = true, want = false", group) + if ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint); ep != nil { + ep.DecRef() + t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix()) } } diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index cb7dec1ea..5ec9b3411 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -309,7 +309,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { p := fwdTestPacketInfo{ - RemoteLinkAddress: r.RemoteLinkAddress, + RemoteLinkAddress: r.RemoteLinkAddress(), LocalLinkAddress: r.LocalLinkAddress, Pkt: pkt, } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 7b0cd58f6..31b67b987 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -26,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" @@ -352,7 +353,7 @@ func TestDADDisabled(t *testing.T) { } // We should not have sent any NDP NS messages. - if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != 0 { + if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got != 0 { t.Fatalf("got NeighborSolicit = %d, want = 0", got) } } @@ -465,14 +466,18 @@ func TestDADResolve(t *testing.T) { if err != tcpip.ErrNoRoute { t.Errorf("got FindRoute(%d, '', %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute) } - r.Release() + if r != nil { + r.Release() + } } { r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false) if err != tcpip.ErrNoRoute { t.Errorf("got FindRoute(%d, %s, %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr1, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute) } - r.Release() + if r != nil { + r.Release() + } } if t.Failed() { @@ -510,7 +515,9 @@ func TestDADResolve(t *testing.T) { } else if r.LocalAddress != addr1 { t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, addr1) } - r.Release() + if r != nil { + r.Release() + } } if t.Failed() { @@ -518,7 +525,7 @@ func TestDADResolve(t *testing.T) { } // Should not have sent any more NS messages. - if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != uint64(test.dupAddrDetectTransmits) { + if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got != uint64(test.dupAddrDetectTransmits) { t.Fatalf("got NeighborSolicit = %d, want = %d", got, test.dupAddrDetectTransmits) } @@ -533,8 +540,8 @@ func TestDADResolve(t *testing.T) { // Make sure the right remote link address is used. snmc := header.SolicitedNodeAddr(addr1) - if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want { - t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) + if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(snmc); got != want { + t.Errorf("got remote link address = %s, want = %s", got, want) } // Check NDP NS packet. @@ -563,7 +570,7 @@ func rxNDPSolicit(e *channel.Endpoint, tgt tcpip.Address) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize) pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns := header.NDPNeighborSolicit(pkt.MessageBody()) ns.SetTargetAddress(tgt) snmc := header.SolicitedNodeAddr(tgt) pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{})) @@ -605,7 +612,7 @@ func TestDADFail(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize) pkt := header.ICMPv6(hdr.Prepend(naSize)) pkt.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(pkt.NDPPayload()) + na := header.NDPNeighborAdvert(pkt.MessageBody()) na.SetSolicitedFlag(true) na.SetOverrideFlag(true) na.SetTargetAddress(tgt) @@ -666,7 +673,7 @@ func TestDADFail(t *testing.T) { // Receive a packet to simulate an address conflict. test.rxPkt(e, addr1) - stat := test.getStat(s.Stats().ICMP.V6PacketsReceived) + stat := test.getStat(s.Stats().ICMP.V6.PacketsReceived) if got := stat.Value(); got != 1 { t.Fatalf("got stat = %d, want = 1", got) } @@ -803,7 +810,7 @@ func TestDADStop(t *testing.T) { } // Should not have sent more than 1 NS message. - if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 { + if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got > 1 { t.Errorf("got NeighborSolicit = %d, want <= 1", got) } }) @@ -982,7 +989,7 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo pkt := header.ICMPv6(hdr.Prepend(icmpSize)) pkt.SetType(header.ICMPv6RouterAdvert) pkt.SetCode(0) - raPayload := pkt.NDPPayload() + raPayload := pkt.MessageBody() ra := header.NDPRouterAdvert(raPayload) // Populate the Router Lifetime. binary.BigEndian.PutUint16(raPayload[2:], rl) @@ -2843,9 +2850,7 @@ func addrForNewConnectionTo(t *testing.T, s *stack.Stack, addr tcpip.FullAddress t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) } defer ep.Close() - if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) - } + ep.SocketOptions().SetV6Only(true) if err := ep.Connect(addr); err != nil { t.Fatalf("ep.Connect(%+v): %s", addr, err) } @@ -2879,9 +2884,7 @@ func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullA t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) } defer ep.Close() - if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) - } + ep.SocketOptions().SetV6Only(true) if err := ep.Bind(addr); err != nil { t.Fatalf("ep.Bind(%+v): %s", addr, err) } @@ -3250,9 +3253,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) } defer ep.Close() - if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) - } + ep.SocketOptions().SetV6Only(true) if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute { t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, tcpip.ErrNoRoute) @@ -5174,113 +5175,99 @@ func TestRouterSolicitation(t *testing.T) { }, } - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the - // parallel tests complete. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. - t.Run("group", func(t *testing.T) { - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + e := channelLinkWithHeaderLength{ + Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr), + headerLength: test.linkHeaderLen, + } + e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired + waitForPkt := func(timeout time.Duration) { + t.Helper() - e := channelLinkWithHeaderLength{ - Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr), - headerLength: test.linkHeaderLen, + clock.Advance(timeout) + p, ok := e.Read() + if !ok { + t.Fatal("expected router solicitation packet") } - e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired - waitForPkt := func(timeout time.Duration) { - t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - p, ok := e.ReadContext(ctx) - if !ok { - t.Fatal("timed out waiting for packet") - return - } - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } + if p.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } - // Make sure the right remote link address is used. - if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want { - t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) - } + // Make sure the right remote link address is used. + if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); got != want { + t.Errorf("got remote link address = %s, want = %s", got, want) + } - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(test.expectedSrcAddr), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), - checker.TTL(header.NDPHopLimit), - checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), - ) + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(test.expectedSrcAddr), + checker.DstAddr(header.IPv6AllRoutersMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), + ) - if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { - t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) - } - } - waitForNothing := func(timeout time.Duration) { - t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - if _, ok := e.ReadContext(ctx); ok { - t.Fatal("unexpectedly got a packet") - } - } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - MaxRtrSolicitations: test.maxRtrSolicit, - RtrSolicitationInterval: test.rtrSolicitInt, - MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, - }, - })}, - }) - if err := s.CreateNIC(nicID, &e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { + t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) } + } + waitForNothing := func(timeout time.Duration) { + t.Helper() - if addr := test.nicAddr; addr != "" { - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err) - } + clock.Advance(timeout) + if p, ok := e.Read(); ok { + t.Fatalf("unexpectedly got a packet = %#v", p) } + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + MaxRtrSolicitations: test.maxRtrSolicit, + RtrSolicitationInterval: test.rtrSolicitInt, + MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, + }, + })}, + Clock: clock, + }) + if err := s.CreateNIC(nicID, &e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } - // Make sure each RS is sent at the right time. - remaining := test.maxRtrSolicit - if remaining > 0 { - waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncPositiveEventTimeout) - remaining-- + if addr := test.nicAddr; addr != "" { + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err) } + } - for ; remaining > 0; remaining-- { - if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout { - waitForNothing(test.effectiveRtrSolicitInt - defaultAsyncNegativeEventTimeout) - waitForPkt(defaultAsyncPositiveEventTimeout) - } else { - waitForPkt(test.effectiveRtrSolicitInt + defaultAsyncPositiveEventTimeout) - } - } + // Make sure each RS is sent at the right time. + remaining := test.maxRtrSolicit + if remaining > 0 { + waitForPkt(test.effectiveMaxRtrSolicitDelay) + remaining-- + } - // Make sure no more RS. - if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay { - waitForNothing(test.effectiveRtrSolicitInt + defaultAsyncNegativeEventTimeout) + for ; remaining > 0; remaining-- { + if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout { + waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond) + waitForPkt(time.Nanosecond) } else { - waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultAsyncNegativeEventTimeout) + waitForPkt(test.effectiveRtrSolicitInt) } + } - // Make sure the counter got properly - // incremented. - if got, want := s.Stats().ICMP.V6PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want { - t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want) - } - }) - } - }) + // Make sure no more RS. + if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay { + waitForNothing(test.effectiveRtrSolicitInt) + } else { + waitForNothing(test.effectiveMaxRtrSolicitDelay) + } + + if got, want := s.Stats().ICMP.V6.PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want { + t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want) + } + }) + } } func TestStopStartSolicitingRouters(t *testing.T) { diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go index 0d3f626cf..317f6871d 100644 --- a/pkg/tcpip/stack/neighbor_cache.go +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -182,14 +182,15 @@ func (n *neighborCache) removeWaker(addr tcpip.Address, waker *sleep.Waker) { // entries returns all entries in the neighbor cache. func (n *neighborCache) entries() []NeighborEntry { - entries := make([]NeighborEntry, 0, len(n.cache)) n.mu.RLock() + defer n.mu.RUnlock() + + entries := make([]NeighborEntry, 0, len(n.cache)) for _, entry := range n.cache { entry.mu.RLock() entries = append(entries, entry.neigh) entry.mu.RUnlock() } - n.mu.RUnlock() return entries } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 43696ba14..5887aa1ed 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -267,7 +267,7 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb if ch, err := r.Resolve(nil); err != nil { if err == tcpip.ErrWouldBlock { r := r.Clone() - n.stack.linkResQueue.enqueue(ch, &r, protocol, pkt) + n.stack.linkResQueue.enqueue(ch, r, protocol, pkt) return nil } return err @@ -279,9 +279,9 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb // WritePacketToRemote implements NetworkInterface. func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { r := Route{ - NetProto: protocol, - RemoteLinkAddress: remoteLinkAddr, + NetProto: protocol, } + r.ResolveWith(remoteLinkAddr) return n.writePacket(&r, gso, protocol, pkt) } @@ -563,8 +563,7 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address return tcpip.ErrNotSupported } - _, err := gep.JoinGroup(addr) - return err + return gep.JoinGroup(addr) } // leaveGroup decrements the count for the given multicast address, and when it @@ -580,11 +579,7 @@ func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Addres return tcpip.ErrNotSupported } - if _, err := gep.LeaveGroup(addr); err != nil { - return err - } - - return nil + return gep.LeaveGroup(addr) } // isInGroup returns true if n has joined the multicast group addr. diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 43ca03ada..b334e27c4 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -259,15 +259,6 @@ const ( PacketLoop ) -// NetOptions is an interface that allows us to pass network protocol specific -// options through the Stack layer code. -type NetOptions interface { - // SizeWithPadding returns the amount of memory that must be allocated to - // hold the options given that the value must be rounded up to the next - // multiple of 4 bytes. - SizeWithPadding() int -} - // NetworkHeaderParams are the header parameters given as input by the // transport endpoint to the network. type NetworkHeaderParams struct { @@ -279,10 +270,6 @@ type NetworkHeaderParams struct { // TOS refers to TypeOfService or TrafficClass field of the IP-header. TOS uint8 - - // Options is a set of options to add to a network header (or nil). - // It will be protocol specific opaque information from higher layers. - Options NetOptions } // GroupAddressableEndpoint is an endpoint that supports group addressing. @@ -291,14 +278,10 @@ type NetworkHeaderParams struct { // endpoints may associate themselves with the same identifier (group address). type GroupAddressableEndpoint interface { // JoinGroup joins the specified group. - // - // Returns true if the group was newly joined. - JoinGroup(group tcpip.Address) (bool, *tcpip.Error) + JoinGroup(group tcpip.Address) *tcpip.Error // LeaveGroup attempts to leave the specified group. - // - // Returns tcpip.ErrBadLocalAddress if the endpoint has not joined the group. - LeaveGroup(group tcpip.Address) (bool, *tcpip.Error) + LeaveGroup(group tcpip.Address) *tcpip.Error // IsInGroup returns true if the endpoint is a member of the specified group. IsInGroup(group tcpip.Address) bool diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index f0b256507..de5fe6ffe 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -18,19 +18,22 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" ) // Route represents a route through the networking stack to a given destination. +// +// It is safe to call Route's methods from multiple goroutines. +// +// The exported fields are immutable. +// +// TODO(gvisor.dev/issue/4902): Unexpose immutable fields. type Route struct { // RemoteAddress is the final destination of the route. RemoteAddress tcpip.Address - // RemoteLinkAddress is the link-layer (MAC) address of the - // final destination of the route. - RemoteLinkAddress tcpip.LinkAddress - // LocalAddress is the local address where the route starts. LocalAddress tcpip.Address @@ -52,8 +55,16 @@ type Route struct { // address's assigned status without the NIC. localAddressNIC *NIC - // localAddressEndpoint is the local address this route is associated with. - localAddressEndpoint AssignableAddressEndpoint + mu struct { + sync.RWMutex + + // localAddressEndpoint is the local address this route is associated with. + localAddressEndpoint AssignableAddressEndpoint + + // remoteLinkAddress is the link-layer (MAC) address of the next hop in the + // route. + remoteLinkAddress tcpip.LinkAddress + } // outgoingNIC is the interface this route uses to write packets. outgoingNIC *NIC @@ -71,14 +82,14 @@ type Route struct { // ownership of the provided local address. // // Returns an empty route if validation fails. -func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, localAddr, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) Route { +func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, localAddr, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) *Route { if len(localAddr) == 0 { localAddr = addressEndpoint.AddressWithPrefix().Address } if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(localAddr) { addressEndpoint.DecRef() - return Route{} + return nil } // If no remote address is provided, use the local address. @@ -102,7 +113,7 @@ func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndp if len(gateway) > 0 { r.NextHop = gateway } else if subnet := addressEndpoint.Subnet(); subnet.IsBroadcast(remoteAddr) { - r.RemoteLinkAddress = header.EthernetBroadcastAddress + r.ResolveWith(header.EthernetBroadcastAddress) } return r @@ -110,7 +121,7 @@ func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndp // makeRoute initializes a new route. It takes ownership of the provided // AssignableAddressEndpoint. -func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route { +func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) *Route { if localAddressNIC.stack != outgoingNIC.stack { panic(fmt.Sprintf("cannot create a route with NICs from different stacks")) } @@ -139,18 +150,21 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop) } -func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) Route { - r := Route{ - NetProto: netProto, - LocalAddress: localAddr, - LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(), - RemoteAddress: remoteAddr, - localAddressNIC: localAddressNIC, - localAddressEndpoint: localAddressEndpoint, - outgoingNIC: outgoingNIC, - Loop: loop, +func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) *Route { + r := &Route{ + NetProto: netProto, + LocalAddress: localAddr, + LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(), + RemoteAddress: remoteAddr, + localAddressNIC: localAddressNIC, + outgoingNIC: outgoingNIC, + Loop: loop, } + r.mu.Lock() + r.mu.localAddressEndpoint = localAddressEndpoint + r.mu.Unlock() + if r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 { if linkRes, ok := r.outgoingNIC.stack.linkAddrResolvers[r.NetProto]; ok { r.linkRes = linkRes @@ -165,7 +179,7 @@ func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr // provided AssignableAddressEndpoint. // // A local route is a route to a destination that is local to the stack. -func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint) Route { +func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint) *Route { loop := PacketLoop // TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the // link endpoint level. We can remove this check once loopback interfaces @@ -176,6 +190,14 @@ func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop) } +// RemoteLinkAddress returns the link-layer (MAC) address of the next hop in +// the route. +func (r *Route) RemoteLinkAddress() tcpip.LinkAddress { + r.mu.RLock() + defer r.mu.RUnlock() + return r.mu.remoteLinkAddress +} + // NICID returns the id of the NIC from which this route originates. func (r *Route) NICID() tcpip.NICID { return r.outgoingNIC.ID() @@ -237,7 +259,9 @@ func (r *Route) GSOMaxSize() uint32 { // ResolveWith immediately resolves a route with the specified remote link // address. func (r *Route) ResolveWith(addr tcpip.LinkAddress) { - r.RemoteLinkAddress = addr + r.mu.Lock() + defer r.mu.Unlock() + r.mu.remoteLinkAddress = addr } // Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in @@ -250,7 +274,10 @@ func (r *Route) ResolveWith(addr tcpip.LinkAddress) { // // The NIC r uses must not be locked. func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { - if !r.IsResolutionRequired() { + r.mu.Lock() + defer r.mu.Unlock() + + if !r.isResolutionRequiredRLocked() { // Nothing to do if there is no cache (which does the resolution on cache miss) or // link address is already known. return nil, nil @@ -260,7 +287,7 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { if nextAddr == "" { // Local link address is already known. if r.RemoteAddress == r.LocalAddress { - r.RemoteLinkAddress = r.LocalLinkAddress + r.mu.remoteLinkAddress = r.LocalLinkAddress return nil, nil } nextAddr = r.RemoteAddress @@ -278,7 +305,7 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { if err != nil { return ch, err } - r.RemoteLinkAddress = entry.LinkAddr + r.mu.remoteLinkAddress = entry.LinkAddr return nil, nil } @@ -286,7 +313,7 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { if err != nil { return ch, err } - r.RemoteLinkAddress = linkAddr + r.mu.remoteLinkAddress = linkAddr return nil, nil } @@ -315,7 +342,13 @@ func (r *Route) local() bool { // // The NICs the route is associated with must not be locked. func (r *Route) IsResolutionRequired() bool { - if !r.isValidForOutgoing() || r.RemoteLinkAddress != "" || r.local() { + r.mu.RLock() + defer r.mu.RUnlock() + return r.isResolutionRequiredRLocked() +} + +func (r *Route) isResolutionRequiredRLocked() bool { + if !r.isValidForOutgoingRLocked() || r.mu.remoteLinkAddress != "" || r.local() { return false } @@ -323,11 +356,18 @@ func (r *Route) IsResolutionRequired() bool { } func (r *Route) isValidForOutgoing() bool { + r.mu.RLock() + defer r.mu.RUnlock() + return r.isValidForOutgoingRLocked() +} + +func (r *Route) isValidForOutgoingRLocked() bool { if !r.outgoingNIC.Enabled() { return false } - if !r.localAddressNIC.isValidForOutgoing(r.localAddressEndpoint) { + localAddressEndpoint := r.mu.localAddressEndpoint + if localAddressEndpoint == nil || !r.localAddressNIC.isValidForOutgoing(localAddressEndpoint) { return false } @@ -381,20 +421,44 @@ func (r *Route) MTU() uint32 { // Release frees all resources associated with the route. func (r *Route) Release() { - if r.localAddressEndpoint != nil { - r.localAddressEndpoint.DecRef() - r.localAddressEndpoint = nil + r.mu.Lock() + defer r.mu.Unlock() + + if r.mu.localAddressEndpoint != nil { + r.mu.localAddressEndpoint.DecRef() + r.mu.localAddressEndpoint = nil } } // Clone clones the route. -func (r *Route) Clone() Route { - if r.localAddressEndpoint != nil { - if !r.localAddressEndpoint.IncRef() { - panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", r.LocalAddress)) +func (r *Route) Clone() *Route { + r.mu.RLock() + defer r.mu.RUnlock() + + newRoute := &Route{ + RemoteAddress: r.RemoteAddress, + LocalAddress: r.LocalAddress, + LocalLinkAddress: r.LocalLinkAddress, + NextHop: r.NextHop, + NetProto: r.NetProto, + Loop: r.Loop, + localAddressNIC: r.localAddressNIC, + outgoingNIC: r.outgoingNIC, + linkCache: r.linkCache, + linkRes: r.linkRes, + } + + newRoute.mu.Lock() + defer newRoute.mu.Unlock() + newRoute.mu.localAddressEndpoint = r.mu.localAddressEndpoint + if newRoute.mu.localAddressEndpoint != nil { + if !newRoute.mu.localAddressEndpoint.IncRef() { + panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", newRoute.LocalAddress)) } } - return *r + newRoute.mu.remoteLinkAddress = r.mu.remoteLinkAddress + + return newRoute } // Stack returns the instance of the Stack that owns this route. @@ -407,7 +471,14 @@ func (r *Route) isV4Broadcast(addr tcpip.Address) bool { return true } - subnet := r.localAddressEndpoint.Subnet() + r.mu.RLock() + localAddressEndpoint := r.mu.localAddressEndpoint + r.mu.RUnlock() + if localAddressEndpoint == nil { + return false + } + + subnet := localAddressEndpoint.Subnet() return subnet.IsBroadcast(addr) } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index a2d234e7d..dc4f5b3e7 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1218,10 +1218,10 @@ func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netP // from the specified NIC. // // Precondition: s.mu must be read locked. -func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) { +func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) *Route { localAddressEndpoint := localAddressNIC.getAddressOrCreateTempInner(netProto, localAddr, false /* createTemp */, NeverPrimaryEndpoint) if localAddressEndpoint == nil { - return Route{}, false + return nil } var outgoingNIC *NIC @@ -1245,7 +1245,7 @@ func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, re // route. if outgoingNIC == nil { localAddressEndpoint.DecRef() - return Route{}, false + return nil } r := makeLocalRoute( @@ -1259,10 +1259,10 @@ func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, re if r.IsOutboundBroadcast() { r.Release() - return Route{}, false + return nil } - return r, true + return r } // findLocalRouteRLocked returns a local route. @@ -1271,26 +1271,26 @@ func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, re // is, a local route is a route where packets never have to leave the stack. // // Precondition: s.mu must be read locked. -func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) { +func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) *Route { if len(localAddr) == 0 { localAddr = remoteAddr } if localAddressNICID == 0 { for _, localAddressNIC := range s.nics { - if r, ok := s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto); ok { - return r, true + if r := s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto); r != nil { + return r } } - return Route{}, false + return nil } if localAddressNIC, ok := s.nics[localAddressNICID]; ok { return s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto) } - return Route{}, false + return nil } // FindRoute creates a route to the given destination address, leaving through @@ -1304,7 +1304,7 @@ func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, // If no local address is provided, the stack will select a local address. If no // remote address is provided, the stack wil use a remote address equal to the // local address. -func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (Route, *tcpip.Error) { +func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (*Route, *tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() @@ -1315,7 +1315,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n needRoute := !(isLocalBroadcast || isMulticast || isLinkLocal || isLoopback) if s.handleLocal && !isMulticast && !isLocalBroadcast { - if r, ok := s.findLocalRouteRLocked(id, localAddr, remoteAddr, netProto); ok { + if r := s.findLocalRouteRLocked(id, localAddr, remoteAddr, netProto); r != nil { return r, nil } } @@ -1339,9 +1339,9 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n } if isLoopback { - return Route{}, tcpip.ErrBadLocalAddress + return nil, tcpip.ErrBadLocalAddress } - return Route{}, tcpip.ErrNetworkUnreachable + return nil, tcpip.ErrNetworkUnreachable } canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal @@ -1365,7 +1365,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n gateway = route.Gateway } r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop) - if r == (Route{}) { + if r == nil { panic(fmt.Sprintf("non-forwarding route validation failed with route table entry = %#v, id = %d, localAddr = %s, remoteAddr = %s", route, id, localAddr, remoteAddr)) } return r, nil @@ -1401,13 +1401,13 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n if id != 0 { if aNIC, ok := s.nics[id]; ok { if addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto); addressEndpoint != nil { - if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) { + if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != nil { return r, nil } } } - return Route{}, tcpip.ErrNoRoute + return nil, tcpip.ErrNoRoute } if id == 0 { @@ -1419,7 +1419,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n continue } - if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) { + if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != nil { return r, nil } } @@ -1427,12 +1427,12 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n } if needRoute { - return Route{}, tcpip.ErrNoRoute + return nil, tcpip.ErrNoRoute } if header.IsV6LoopbackAddress(remoteAddr) { - return Route{}, tcpip.ErrBadLocalAddress + return nil, tcpip.ErrBadLocalAddress } - return Route{}, tcpip.ErrNetworkUnreachable + return nil, tcpip.ErrNetworkUnreachable } // CheckNetworkProtocol checks if a given network protocol is enabled in the diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 9d2d0aa84..457990945 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -27,7 +27,6 @@ import ( "time" "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -407,7 +406,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro return send(r, payload) } -func send(r stack.Route, payload buffer.View) *tcpip.Error { +func send(r *stack.Route, payload buffer.View) *tcpip.Error { return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: payload.ToVectorisedView(), @@ -425,7 +424,7 @@ func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.En } } -func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View) { +func testSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View) { t.Helper() ep.Drain() if err := send(r, payload); err != nil { @@ -436,7 +435,7 @@ func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer. } } -func testFailingSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { +func testFailingSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { t.Helper() if gotErr := send(r, payload); gotErr != wantErr { t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr) @@ -1563,15 +1562,15 @@ func TestSpoofingNoAddress(t *testing.T) { // testSendTo(t, s, remoteAddr, ep, nil) } -func verifyRoute(gotRoute, wantRoute stack.Route) error { +func verifyRoute(gotRoute, wantRoute *stack.Route) error { if gotRoute.LocalAddress != wantRoute.LocalAddress { return fmt.Errorf("bad local address: got %s, want = %s", gotRoute.LocalAddress, wantRoute.LocalAddress) } if gotRoute.RemoteAddress != wantRoute.RemoteAddress { return fmt.Errorf("bad remote address: got %s, want = %s", gotRoute.RemoteAddress, wantRoute.RemoteAddress) } - if gotRoute.RemoteLinkAddress != wantRoute.RemoteLinkAddress { - return fmt.Errorf("bad remote link address: got %s, want = %s", gotRoute.RemoteLinkAddress, wantRoute.RemoteLinkAddress) + if got, want := gotRoute.RemoteLinkAddress(), wantRoute.RemoteLinkAddress(); got != want { + return fmt.Errorf("bad remote link address: got %s, want = %s", got, want) } if gotRoute.NextHop != wantRoute.NextHop { return fmt.Errorf("bad next-hop address: got %s, want = %s", gotRoute.NextHop, wantRoute.NextHop) @@ -1603,7 +1602,7 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { if err != nil { t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) } - if err := verifyRoute(r, stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil { + if err := verifyRoute(r, &stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil { t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) } @@ -1657,7 +1656,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { if err != nil { t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) } - if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + if err := verifyRoute(r, &stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) } @@ -1667,7 +1666,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { if err != nil { t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) } - if err := verifyRoute(r, stack.Route{LocalAddress: nic2Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + if err := verifyRoute(r, &stack.Route{LocalAddress: nic2Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err) } @@ -1683,7 +1682,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { if err != nil { t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) } - if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + if err := verifyRoute(r, &stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err) } } @@ -3351,11 +3350,16 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { remNetSubnetBcast := remNetSubnet.Broadcast() tests := []struct { - name string - nicAddr tcpip.ProtocolAddress - routes []tcpip.Route - remoteAddr tcpip.Address - expectedRoute stack.Route + name string + nicAddr tcpip.ProtocolAddress + routes []tcpip.Route + remoteAddr tcpip.Address + expectedLocalAddress tcpip.Address + expectedRemoteAddress tcpip.Address + expectedRemoteLinkAddress tcpip.LinkAddress + expectedNextHop tcpip.Address + expectedNetProto tcpip.NetworkProtocolNumber + expectedLoop stack.PacketLooping }{ // Broadcast to a locally attached subnet populates the broadcast MAC. { @@ -3370,14 +3374,12 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { NIC: nicID1, }, }, - remoteAddr: ipv4SubnetBcast, - expectedRoute: stack.Route{ - LocalAddress: ipv4Addr.Address, - RemoteAddress: ipv4SubnetBcast, - RemoteLinkAddress: header.EthernetBroadcastAddress, - NetProto: header.IPv4ProtocolNumber, - Loop: stack.PacketOut | stack.PacketLoop, - }, + remoteAddr: ipv4SubnetBcast, + expectedLocalAddress: ipv4Addr.Address, + expectedRemoteAddress: ipv4SubnetBcast, + expectedRemoteLinkAddress: header.EthernetBroadcastAddress, + expectedNetProto: header.IPv4ProtocolNumber, + expectedLoop: stack.PacketOut | stack.PacketLoop, }, // Broadcast to a locally attached /31 subnet does not populate the // broadcast MAC. @@ -3393,13 +3395,11 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { NIC: nicID1, }, }, - remoteAddr: ipv4Subnet31Bcast, - expectedRoute: stack.Route{ - LocalAddress: ipv4AddrPrefix31.Address, - RemoteAddress: ipv4Subnet31Bcast, - NetProto: header.IPv4ProtocolNumber, - Loop: stack.PacketOut, - }, + remoteAddr: ipv4Subnet31Bcast, + expectedLocalAddress: ipv4AddrPrefix31.Address, + expectedRemoteAddress: ipv4Subnet31Bcast, + expectedNetProto: header.IPv4ProtocolNumber, + expectedLoop: stack.PacketOut, }, // Broadcast to a locally attached /32 subnet does not populate the // broadcast MAC. @@ -3415,13 +3415,11 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { NIC: nicID1, }, }, - remoteAddr: ipv4Subnet32Bcast, - expectedRoute: stack.Route{ - LocalAddress: ipv4AddrPrefix32.Address, - RemoteAddress: ipv4Subnet32Bcast, - NetProto: header.IPv4ProtocolNumber, - Loop: stack.PacketOut, - }, + remoteAddr: ipv4Subnet32Bcast, + expectedLocalAddress: ipv4AddrPrefix32.Address, + expectedRemoteAddress: ipv4Subnet32Bcast, + expectedNetProto: header.IPv4ProtocolNumber, + expectedLoop: stack.PacketOut, }, // IPv6 has no notion of a broadcast. { @@ -3436,13 +3434,11 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { NIC: nicID1, }, }, - remoteAddr: ipv6SubnetBcast, - expectedRoute: stack.Route{ - LocalAddress: ipv6Addr.Address, - RemoteAddress: ipv6SubnetBcast, - NetProto: header.IPv6ProtocolNumber, - Loop: stack.PacketOut, - }, + remoteAddr: ipv6SubnetBcast, + expectedLocalAddress: ipv6Addr.Address, + expectedRemoteAddress: ipv6SubnetBcast, + expectedNetProto: header.IPv6ProtocolNumber, + expectedLoop: stack.PacketOut, }, // Broadcast to a remote subnet in the route table is send to the next-hop // gateway. @@ -3459,14 +3455,12 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { NIC: nicID1, }, }, - remoteAddr: remNetSubnetBcast, - expectedRoute: stack.Route{ - LocalAddress: ipv4Addr.Address, - RemoteAddress: remNetSubnetBcast, - NextHop: ipv4Gateway, - NetProto: header.IPv4ProtocolNumber, - Loop: stack.PacketOut, - }, + remoteAddr: remNetSubnetBcast, + expectedLocalAddress: ipv4Addr.Address, + expectedRemoteAddress: remNetSubnetBcast, + expectedNextHop: ipv4Gateway, + expectedNetProto: header.IPv4ProtocolNumber, + expectedLoop: stack.PacketOut, }, // Broadcast to an unknown subnet follows the default route. Note that this // is essentially just routing an unknown destination IP, because w/o any @@ -3484,14 +3478,12 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { NIC: nicID1, }, }, - remoteAddr: remNetSubnetBcast, - expectedRoute: stack.Route{ - LocalAddress: ipv4Addr.Address, - RemoteAddress: remNetSubnetBcast, - NextHop: ipv4Gateway, - NetProto: header.IPv4ProtocolNumber, - Loop: stack.PacketOut, - }, + remoteAddr: remNetSubnetBcast, + expectedLocalAddress: ipv4Addr.Address, + expectedRemoteAddress: remNetSubnetBcast, + expectedNextHop: ipv4Gateway, + expectedNetProto: header.IPv4ProtocolNumber, + expectedLoop: stack.PacketOut, }, } @@ -3520,10 +3512,27 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { t.Fatalf("got unexpected address length = %d bytes", l) } - if r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, test.remoteAddr, netProto, false /* multicastLoop */); err != nil { + r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, test.remoteAddr, netProto, false /* multicastLoop */) + if err != nil { t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, test.remoteAddr, netProto, err) - } else if diff := cmp.Diff(r, test.expectedRoute, cmpopts.IgnoreUnexported(r)); diff != "" { - t.Errorf("route mismatch (-want +got):\n%s", diff) + } + if r.LocalAddress != test.expectedLocalAddress { + t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, test.expectedLocalAddress) + } + if r.RemoteAddress != test.expectedRemoteAddress { + t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress, test.expectedRemoteAddress) + } + if got := r.RemoteLinkAddress(); got != test.expectedRemoteLinkAddress { + t.Errorf("got r.RemoteLinkAddress() = %s, want = %s", got, test.expectedRemoteLinkAddress) + } + if r.NextHop != test.expectedNextHop { + t.Errorf("got r.NextHop = %s, want = %s", r.NextHop, test.expectedNextHop) + } + if r.NetProto != test.expectedNetProto { + t.Errorf("got r.NetProto = %d, want = %d", r.NetProto, test.expectedNetProto) + } + if r.Loop != test.expectedLoop { + t.Errorf("got r.Loop = %x, want = %x", r.Loop, test.expectedLoop) } }) } @@ -4091,10 +4100,12 @@ func TestFindRouteWithForwarding(t *testing.T) { s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}}) r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */) + if r != nil { + defer r.Release() + } if err != test.findRouteErr { t.Fatalf("FindRoute(%d, %s, %s, %d, false) = %s, want = %s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, err, test.findRouteErr) } - defer r.Release() if test.findRouteErr != nil { return @@ -4193,7 +4204,7 @@ func TestWritePacketToRemote(t *testing.T) { if got, want := pkt.Proto, test.protocol; got != want { t.Fatalf("pkt.Proto = %d, want %d", got, want) } - if got, want := pkt.Route.RemoteLinkAddress, linkAddr2; got != want { + if got, want := pkt.Route.RemoteLinkAddress(), linkAddr2; got != want { t.Fatalf("pkt.Route.RemoteAddress = %s, want %s", got, want) } if diff := cmp.Diff(pkt.Pkt.Data.ToView(), buffer.View(test.payload)); diff != "" { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index fbac66993..d9769e47d 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -42,7 +42,7 @@ type fakeTransportEndpoint struct { proto *fakeTransportProtocol peerAddr tcpip.Address - route stack.Route + route *stack.Route uniqueID uint64 // acceptQueue is non-nil iff bound. @@ -65,6 +65,7 @@ func (*fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {} func (f *fakeTransportEndpoint) SocketOptions() *tcpip.SocketOptions { return &f.ops } + func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} ep.ops.InitHandler(ep) @@ -117,21 +118,11 @@ func (*fakeTransportEndpoint) SetSockOpt(tcpip.SettableSocketOption) *tcpip.Erro return tcpip.ErrInvalidEndpointState } -// SetSockOptBool sets a socket option. Currently not supported. -func (*fakeTransportEndpoint) SetSockOptBool(tcpip.SockOptBool, bool) *tcpip.Error { - return tcpip.ErrInvalidEndpointState -} - // SetSockOptInt sets a socket option. Currently not supported. func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) *tcpip.Error { return tcpip.ErrInvalidEndpointState } -// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. -func (*fakeTransportEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - return false, tcpip.ErrUnknownProtocolOption -} - // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return -1, tcpip.ErrUnknownProtocolOption |