summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
authorGhanan Gowripalan <ghanan@google.com>2020-11-25 14:51:18 -0800
committergVisor bot <gvisor-bot@google.com>2020-11-25 14:52:59 -0800
commit2485a4e2cb4aaee8f1a5e760541fb02e9090de44 (patch)
treebf20ce486235c1bb9c9261a64f24823c822dafd5 /pkg/tcpip/stack
parent4d59a5a62223b56927b37a00cd5a6dea577fe4c6 (diff)
Make stack.Route safe to access concurrently
Multiple goroutines may use the same stack.Route concurrently so the stack.Route should make sure that any functions called on it are thread-safe. Fixes #4073 PiperOrigin-RevId: 344320491
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/ndp_test.go12
-rw-r--r--pkg/tcpip/stack/nic.go2
-rw-r--r--pkg/tcpip/stack/route.go100
-rw-r--r--pkg/tcpip/stack/stack.go42
-rw-r--r--pkg/tcpip/stack/stack_test.go36
-rw-r--r--pkg/tcpip/stack/transport_test.go2
6 files changed, 125 insertions, 69 deletions
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 7b0cd58f6..5c6bb6b0d 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -465,14 +465,18 @@ func TestDADResolve(t *testing.T) {
if err != tcpip.ErrNoRoute {
t.Errorf("got FindRoute(%d, '', %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute)
}
- r.Release()
+ if r != nil {
+ r.Release()
+ }
}
{
r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false)
if err != tcpip.ErrNoRoute {
t.Errorf("got FindRoute(%d, %s, %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr1, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute)
}
- r.Release()
+ if r != nil {
+ r.Release()
+ }
}
if t.Failed() {
@@ -510,7 +514,9 @@ func TestDADResolve(t *testing.T) {
} else if r.LocalAddress != addr1 {
t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, addr1)
}
- r.Release()
+ if r != nil {
+ r.Release()
+ }
}
if t.Failed() {
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 43696ba14..5f3757de0 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -267,7 +267,7 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb
if ch, err := r.Resolve(nil); err != nil {
if err == tcpip.ErrWouldBlock {
r := r.Clone()
- n.stack.linkResQueue.enqueue(ch, &r, protocol, pkt)
+ n.stack.linkResQueue.enqueue(ch, r, protocol, pkt)
return nil
}
return err
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index f0b256507..8cb7c5dd8 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -18,11 +18,18 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
// Route represents a route through the networking stack to a given destination.
+//
+// It is safe to call Route's methods from multiple goroutines.
+//
+// The exported fields are immutable.
+//
+// TODO(gvisor.dev/issue/4902): Unexpose immutable fields.
type Route struct {
// RemoteAddress is the final destination of the route.
RemoteAddress tcpip.Address
@@ -52,8 +59,12 @@ type Route struct {
// address's assigned status without the NIC.
localAddressNIC *NIC
- // localAddressEndpoint is the local address this route is associated with.
- localAddressEndpoint AssignableAddressEndpoint
+ mu struct {
+ sync.RWMutex
+
+ // localAddressEndpoint is the local address this route is associated with.
+ localAddressEndpoint AssignableAddressEndpoint
+ }
// outgoingNIC is the interface this route uses to write packets.
outgoingNIC *NIC
@@ -71,14 +82,14 @@ type Route struct {
// ownership of the provided local address.
//
// Returns an empty route if validation fails.
-func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, localAddr, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) Route {
+func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, localAddr, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) *Route {
if len(localAddr) == 0 {
localAddr = addressEndpoint.AddressWithPrefix().Address
}
if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(localAddr) {
addressEndpoint.DecRef()
- return Route{}
+ return nil
}
// If no remote address is provided, use the local address.
@@ -110,7 +121,7 @@ func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndp
// makeRoute initializes a new route. It takes ownership of the provided
// AssignableAddressEndpoint.
-func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route {
+func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) *Route {
if localAddressNIC.stack != outgoingNIC.stack {
panic(fmt.Sprintf("cannot create a route with NICs from different stacks"))
}
@@ -139,18 +150,21 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip
return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop)
}
-func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) Route {
- r := Route{
- NetProto: netProto,
- LocalAddress: localAddr,
- LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(),
- RemoteAddress: remoteAddr,
- localAddressNIC: localAddressNIC,
- localAddressEndpoint: localAddressEndpoint,
- outgoingNIC: outgoingNIC,
- Loop: loop,
+func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) *Route {
+ r := &Route{
+ NetProto: netProto,
+ LocalAddress: localAddr,
+ LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(),
+ RemoteAddress: remoteAddr,
+ localAddressNIC: localAddressNIC,
+ outgoingNIC: outgoingNIC,
+ Loop: loop,
}
+ r.mu.Lock()
+ r.mu.localAddressEndpoint = localAddressEndpoint
+ r.mu.Unlock()
+
if r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 {
if linkRes, ok := r.outgoingNIC.stack.linkAddrResolvers[r.NetProto]; ok {
r.linkRes = linkRes
@@ -165,7 +179,7 @@ func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr
// provided AssignableAddressEndpoint.
//
// A local route is a route to a destination that is local to the stack.
-func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint) Route {
+func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint) *Route {
loop := PacketLoop
// TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the
// link endpoint level. We can remove this check once loopback interfaces
@@ -327,7 +341,10 @@ func (r *Route) isValidForOutgoing() bool {
return false
}
- if !r.localAddressNIC.isValidForOutgoing(r.localAddressEndpoint) {
+ r.mu.RLock()
+ localAddressEndpoint := r.mu.localAddressEndpoint
+ r.mu.RUnlock()
+ if localAddressEndpoint == nil || !r.localAddressNIC.isValidForOutgoing(localAddressEndpoint) {
return false
}
@@ -381,20 +398,44 @@ func (r *Route) MTU() uint32 {
// Release frees all resources associated with the route.
func (r *Route) Release() {
- if r.localAddressEndpoint != nil {
- r.localAddressEndpoint.DecRef()
- r.localAddressEndpoint = nil
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if r.mu.localAddressEndpoint != nil {
+ r.mu.localAddressEndpoint.DecRef()
+ r.mu.localAddressEndpoint = nil
}
}
// Clone clones the route.
-func (r *Route) Clone() Route {
- if r.localAddressEndpoint != nil {
- if !r.localAddressEndpoint.IncRef() {
- panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", r.LocalAddress))
+func (r *Route) Clone() *Route {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+
+ newRoute := &Route{
+ RemoteAddress: r.RemoteAddress,
+ RemoteLinkAddress: r.RemoteLinkAddress,
+ LocalAddress: r.LocalAddress,
+ LocalLinkAddress: r.LocalLinkAddress,
+ NextHop: r.NextHop,
+ NetProto: r.NetProto,
+ Loop: r.Loop,
+ localAddressNIC: r.localAddressNIC,
+ outgoingNIC: r.outgoingNIC,
+ linkCache: r.linkCache,
+ linkRes: r.linkRes,
+ }
+
+ newRoute.mu.Lock()
+ defer newRoute.mu.Unlock()
+ newRoute.mu.localAddressEndpoint = r.mu.localAddressEndpoint
+ if newRoute.mu.localAddressEndpoint != nil {
+ if !newRoute.mu.localAddressEndpoint.IncRef() {
+ panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", newRoute.LocalAddress))
}
}
- return *r
+
+ return newRoute
}
// Stack returns the instance of the Stack that owns this route.
@@ -407,7 +448,14 @@ func (r *Route) isV4Broadcast(addr tcpip.Address) bool {
return true
}
- subnet := r.localAddressEndpoint.Subnet()
+ r.mu.RLock()
+ localAddressEndpoint := r.mu.localAddressEndpoint
+ r.mu.RUnlock()
+ if localAddressEndpoint == nil {
+ return false
+ }
+
+ subnet := localAddressEndpoint.Subnet()
return subnet.IsBroadcast(addr)
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index a2d234e7d..dc4f5b3e7 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -1218,10 +1218,10 @@ func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netP
// from the specified NIC.
//
// Precondition: s.mu must be read locked.
-func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) {
+func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) *Route {
localAddressEndpoint := localAddressNIC.getAddressOrCreateTempInner(netProto, localAddr, false /* createTemp */, NeverPrimaryEndpoint)
if localAddressEndpoint == nil {
- return Route{}, false
+ return nil
}
var outgoingNIC *NIC
@@ -1245,7 +1245,7 @@ func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, re
// route.
if outgoingNIC == nil {
localAddressEndpoint.DecRef()
- return Route{}, false
+ return nil
}
r := makeLocalRoute(
@@ -1259,10 +1259,10 @@ func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, re
if r.IsOutboundBroadcast() {
r.Release()
- return Route{}, false
+ return nil
}
- return r, true
+ return r
}
// findLocalRouteRLocked returns a local route.
@@ -1271,26 +1271,26 @@ func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, re
// is, a local route is a route where packets never have to leave the stack.
//
// Precondition: s.mu must be read locked.
-func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) {
+func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) *Route {
if len(localAddr) == 0 {
localAddr = remoteAddr
}
if localAddressNICID == 0 {
for _, localAddressNIC := range s.nics {
- if r, ok := s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto); ok {
- return r, true
+ if r := s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto); r != nil {
+ return r
}
}
- return Route{}, false
+ return nil
}
if localAddressNIC, ok := s.nics[localAddressNICID]; ok {
return s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto)
}
- return Route{}, false
+ return nil
}
// FindRoute creates a route to the given destination address, leaving through
@@ -1304,7 +1304,7 @@ func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr,
// If no local address is provided, the stack will select a local address. If no
// remote address is provided, the stack wil use a remote address equal to the
// local address.
-func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (Route, *tcpip.Error) {
+func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (*Route, *tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -1315,7 +1315,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
needRoute := !(isLocalBroadcast || isMulticast || isLinkLocal || isLoopback)
if s.handleLocal && !isMulticast && !isLocalBroadcast {
- if r, ok := s.findLocalRouteRLocked(id, localAddr, remoteAddr, netProto); ok {
+ if r := s.findLocalRouteRLocked(id, localAddr, remoteAddr, netProto); r != nil {
return r, nil
}
}
@@ -1339,9 +1339,9 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
}
if isLoopback {
- return Route{}, tcpip.ErrBadLocalAddress
+ return nil, tcpip.ErrBadLocalAddress
}
- return Route{}, tcpip.ErrNetworkUnreachable
+ return nil, tcpip.ErrNetworkUnreachable
}
canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal
@@ -1365,7 +1365,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
gateway = route.Gateway
}
r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop)
- if r == (Route{}) {
+ if r == nil {
panic(fmt.Sprintf("non-forwarding route validation failed with route table entry = %#v, id = %d, localAddr = %s, remoteAddr = %s", route, id, localAddr, remoteAddr))
}
return r, nil
@@ -1401,13 +1401,13 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
if id != 0 {
if aNIC, ok := s.nics[id]; ok {
if addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto); addressEndpoint != nil {
- if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) {
+ if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != nil {
return r, nil
}
}
}
- return Route{}, tcpip.ErrNoRoute
+ return nil, tcpip.ErrNoRoute
}
if id == 0 {
@@ -1419,7 +1419,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
continue
}
- if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) {
+ if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != nil {
return r, nil
}
}
@@ -1427,12 +1427,12 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
}
if needRoute {
- return Route{}, tcpip.ErrNoRoute
+ return nil, tcpip.ErrNoRoute
}
if header.IsV6LoopbackAddress(remoteAddr) {
- return Route{}, tcpip.ErrBadLocalAddress
+ return nil, tcpip.ErrBadLocalAddress
}
- return Route{}, tcpip.ErrNetworkUnreachable
+ return nil, tcpip.ErrNetworkUnreachable
}
// CheckNetworkProtocol checks if a given network protocol is enabled in the
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 9d2d0aa84..fb0f1d8f7 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -407,7 +407,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro
return send(r, payload)
}
-func send(r stack.Route, payload buffer.View) *tcpip.Error {
+func send(r *stack.Route, payload buffer.View) *tcpip.Error {
return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()),
Data: payload.ToVectorisedView(),
@@ -425,7 +425,7 @@ func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.En
}
}
-func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View) {
+func testSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View) {
t.Helper()
ep.Drain()
if err := send(r, payload); err != nil {
@@ -436,7 +436,7 @@ func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.
}
}
-func testFailingSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
+func testFailingSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
t.Helper()
if gotErr := send(r, payload); gotErr != wantErr {
t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr)
@@ -1563,7 +1563,7 @@ func TestSpoofingNoAddress(t *testing.T) {
// testSendTo(t, s, remoteAddr, ep, nil)
}
-func verifyRoute(gotRoute, wantRoute stack.Route) error {
+func verifyRoute(gotRoute, wantRoute *stack.Route) error {
if gotRoute.LocalAddress != wantRoute.LocalAddress {
return fmt.Errorf("bad local address: got %s, want = %s", gotRoute.LocalAddress, wantRoute.LocalAddress)
}
@@ -1603,7 +1603,7 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) {
if err != nil {
t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
}
- if err := verifyRoute(r, stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ if err := verifyRoute(r, &stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil {
t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
}
@@ -1657,7 +1657,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
if err != nil {
t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
}
- if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ if err := verifyRoute(r, &stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
}
@@ -1667,7 +1667,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
if err != nil {
t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err)
}
- if err := verifyRoute(r, stack.Route{LocalAddress: nic2Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ if err := verifyRoute(r, &stack.Route{LocalAddress: nic2Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err)
}
@@ -1683,7 +1683,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
if err != nil {
t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err)
}
- if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ if err := verifyRoute(r, &stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err)
}
}
@@ -3355,7 +3355,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
nicAddr tcpip.ProtocolAddress
routes []tcpip.Route
remoteAddr tcpip.Address
- expectedRoute stack.Route
+ expectedRoute *stack.Route
}{
// Broadcast to a locally attached subnet populates the broadcast MAC.
{
@@ -3371,7 +3371,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
},
},
remoteAddr: ipv4SubnetBcast,
- expectedRoute: stack.Route{
+ expectedRoute: &stack.Route{
LocalAddress: ipv4Addr.Address,
RemoteAddress: ipv4SubnetBcast,
RemoteLinkAddress: header.EthernetBroadcastAddress,
@@ -3394,7 +3394,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
},
},
remoteAddr: ipv4Subnet31Bcast,
- expectedRoute: stack.Route{
+ expectedRoute: &stack.Route{
LocalAddress: ipv4AddrPrefix31.Address,
RemoteAddress: ipv4Subnet31Bcast,
NetProto: header.IPv4ProtocolNumber,
@@ -3416,7 +3416,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
},
},
remoteAddr: ipv4Subnet32Bcast,
- expectedRoute: stack.Route{
+ expectedRoute: &stack.Route{
LocalAddress: ipv4AddrPrefix32.Address,
RemoteAddress: ipv4Subnet32Bcast,
NetProto: header.IPv4ProtocolNumber,
@@ -3437,7 +3437,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
},
},
remoteAddr: ipv6SubnetBcast,
- expectedRoute: stack.Route{
+ expectedRoute: &stack.Route{
LocalAddress: ipv6Addr.Address,
RemoteAddress: ipv6SubnetBcast,
NetProto: header.IPv6ProtocolNumber,
@@ -3460,7 +3460,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
},
},
remoteAddr: remNetSubnetBcast,
- expectedRoute: stack.Route{
+ expectedRoute: &stack.Route{
LocalAddress: ipv4Addr.Address,
RemoteAddress: remNetSubnetBcast,
NextHop: ipv4Gateway,
@@ -3485,7 +3485,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
},
},
remoteAddr: remNetSubnetBcast,
- expectedRoute: stack.Route{
+ expectedRoute: &stack.Route{
LocalAddress: ipv4Addr.Address,
RemoteAddress: remNetSubnetBcast,
NextHop: ipv4Gateway,
@@ -3522,7 +3522,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
if r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, test.remoteAddr, netProto, false /* multicastLoop */); err != nil {
t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, test.remoteAddr, netProto, err)
- } else if diff := cmp.Diff(r, test.expectedRoute, cmpopts.IgnoreUnexported(r)); diff != "" {
+ } else if diff := cmp.Diff(r, test.expectedRoute, cmpopts.IgnoreUnexported(stack.Route{})); diff != "" {
t.Errorf("route mismatch (-want +got):\n%s", diff)
}
})
@@ -4091,10 +4091,12 @@ func TestFindRouteWithForwarding(t *testing.T) {
s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}})
r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */)
+ if r != nil {
+ defer r.Release()
+ }
if err != test.findRouteErr {
t.Fatalf("FindRoute(%d, %s, %s, %d, false) = %s, want = %s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, err, test.findRouteErr)
}
- defer r.Release()
if test.findRouteErr != nil {
return
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index fbac66993..3c6ec0c3a 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -42,7 +42,7 @@ type fakeTransportEndpoint struct {
proto *fakeTransportProtocol
peerAddr tcpip.Address
- route stack.Route
+ route *stack.Route
uniqueID uint64
// acceptQueue is non-nil iff bound.