diff options
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/nic.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/stack/route.go | 100 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 42 |
3 files changed, 96 insertions, 48 deletions
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 43696ba14..5f3757de0 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -267,7 +267,7 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb if ch, err := r.Resolve(nil); err != nil { if err == tcpip.ErrWouldBlock { r := r.Clone() - n.stack.linkResQueue.enqueue(ch, &r, protocol, pkt) + n.stack.linkResQueue.enqueue(ch, r, protocol, pkt) return nil } return err diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index f0b256507..8cb7c5dd8 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -18,11 +18,18 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" ) // Route represents a route through the networking stack to a given destination. +// +// It is safe to call Route's methods from multiple goroutines. +// +// The exported fields are immutable. +// +// TODO(gvisor.dev/issue/4902): Unexpose immutable fields. type Route struct { // RemoteAddress is the final destination of the route. RemoteAddress tcpip.Address @@ -52,8 +59,12 @@ type Route struct { // address's assigned status without the NIC. localAddressNIC *NIC - // localAddressEndpoint is the local address this route is associated with. - localAddressEndpoint AssignableAddressEndpoint + mu struct { + sync.RWMutex + + // localAddressEndpoint is the local address this route is associated with. + localAddressEndpoint AssignableAddressEndpoint + } // outgoingNIC is the interface this route uses to write packets. outgoingNIC *NIC @@ -71,14 +82,14 @@ type Route struct { // ownership of the provided local address. // // Returns an empty route if validation fails. -func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, localAddr, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) Route { +func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, localAddr, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) *Route { if len(localAddr) == 0 { localAddr = addressEndpoint.AddressWithPrefix().Address } if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(localAddr) { addressEndpoint.DecRef() - return Route{} + return nil } // If no remote address is provided, use the local address. @@ -110,7 +121,7 @@ func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndp // makeRoute initializes a new route. It takes ownership of the provided // AssignableAddressEndpoint. -func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route { +func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) *Route { if localAddressNIC.stack != outgoingNIC.stack { panic(fmt.Sprintf("cannot create a route with NICs from different stacks")) } @@ -139,18 +150,21 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop) } -func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) Route { - r := Route{ - NetProto: netProto, - LocalAddress: localAddr, - LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(), - RemoteAddress: remoteAddr, - localAddressNIC: localAddressNIC, - localAddressEndpoint: localAddressEndpoint, - outgoingNIC: outgoingNIC, - Loop: loop, +func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) *Route { + r := &Route{ + NetProto: netProto, + LocalAddress: localAddr, + LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(), + RemoteAddress: remoteAddr, + localAddressNIC: localAddressNIC, + outgoingNIC: outgoingNIC, + Loop: loop, } + r.mu.Lock() + r.mu.localAddressEndpoint = localAddressEndpoint + r.mu.Unlock() + if r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 { if linkRes, ok := r.outgoingNIC.stack.linkAddrResolvers[r.NetProto]; ok { r.linkRes = linkRes @@ -165,7 +179,7 @@ func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr // provided AssignableAddressEndpoint. // // A local route is a route to a destination that is local to the stack. -func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint) Route { +func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint) *Route { loop := PacketLoop // TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the // link endpoint level. We can remove this check once loopback interfaces @@ -327,7 +341,10 @@ func (r *Route) isValidForOutgoing() bool { return false } - if !r.localAddressNIC.isValidForOutgoing(r.localAddressEndpoint) { + r.mu.RLock() + localAddressEndpoint := r.mu.localAddressEndpoint + r.mu.RUnlock() + if localAddressEndpoint == nil || !r.localAddressNIC.isValidForOutgoing(localAddressEndpoint) { return false } @@ -381,20 +398,44 @@ func (r *Route) MTU() uint32 { // Release frees all resources associated with the route. func (r *Route) Release() { - if r.localAddressEndpoint != nil { - r.localAddressEndpoint.DecRef() - r.localAddressEndpoint = nil + r.mu.Lock() + defer r.mu.Unlock() + + if r.mu.localAddressEndpoint != nil { + r.mu.localAddressEndpoint.DecRef() + r.mu.localAddressEndpoint = nil } } // Clone clones the route. -func (r *Route) Clone() Route { - if r.localAddressEndpoint != nil { - if !r.localAddressEndpoint.IncRef() { - panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", r.LocalAddress)) +func (r *Route) Clone() *Route { + r.mu.RLock() + defer r.mu.RUnlock() + + newRoute := &Route{ + RemoteAddress: r.RemoteAddress, + RemoteLinkAddress: r.RemoteLinkAddress, + LocalAddress: r.LocalAddress, + LocalLinkAddress: r.LocalLinkAddress, + NextHop: r.NextHop, + NetProto: r.NetProto, + Loop: r.Loop, + localAddressNIC: r.localAddressNIC, + outgoingNIC: r.outgoingNIC, + linkCache: r.linkCache, + linkRes: r.linkRes, + } + + newRoute.mu.Lock() + defer newRoute.mu.Unlock() + newRoute.mu.localAddressEndpoint = r.mu.localAddressEndpoint + if newRoute.mu.localAddressEndpoint != nil { + if !newRoute.mu.localAddressEndpoint.IncRef() { + panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", newRoute.LocalAddress)) } } - return *r + + return newRoute } // Stack returns the instance of the Stack that owns this route. @@ -407,7 +448,14 @@ func (r *Route) isV4Broadcast(addr tcpip.Address) bool { return true } - subnet := r.localAddressEndpoint.Subnet() + r.mu.RLock() + localAddressEndpoint := r.mu.localAddressEndpoint + r.mu.RUnlock() + if localAddressEndpoint == nil { + return false + } + + subnet := localAddressEndpoint.Subnet() return subnet.IsBroadcast(addr) } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index a2d234e7d..dc4f5b3e7 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1218,10 +1218,10 @@ func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netP // from the specified NIC. // // Precondition: s.mu must be read locked. -func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) { +func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) *Route { localAddressEndpoint := localAddressNIC.getAddressOrCreateTempInner(netProto, localAddr, false /* createTemp */, NeverPrimaryEndpoint) if localAddressEndpoint == nil { - return Route{}, false + return nil } var outgoingNIC *NIC @@ -1245,7 +1245,7 @@ func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, re // route. if outgoingNIC == nil { localAddressEndpoint.DecRef() - return Route{}, false + return nil } r := makeLocalRoute( @@ -1259,10 +1259,10 @@ func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, re if r.IsOutboundBroadcast() { r.Release() - return Route{}, false + return nil } - return r, true + return r } // findLocalRouteRLocked returns a local route. @@ -1271,26 +1271,26 @@ func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, re // is, a local route is a route where packets never have to leave the stack. // // Precondition: s.mu must be read locked. -func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) { +func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) *Route { if len(localAddr) == 0 { localAddr = remoteAddr } if localAddressNICID == 0 { for _, localAddressNIC := range s.nics { - if r, ok := s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto); ok { - return r, true + if r := s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto); r != nil { + return r } } - return Route{}, false + return nil } if localAddressNIC, ok := s.nics[localAddressNICID]; ok { return s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto) } - return Route{}, false + return nil } // FindRoute creates a route to the given destination address, leaving through @@ -1304,7 +1304,7 @@ func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, // If no local address is provided, the stack will select a local address. If no // remote address is provided, the stack wil use a remote address equal to the // local address. -func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (Route, *tcpip.Error) { +func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (*Route, *tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() @@ -1315,7 +1315,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n needRoute := !(isLocalBroadcast || isMulticast || isLinkLocal || isLoopback) if s.handleLocal && !isMulticast && !isLocalBroadcast { - if r, ok := s.findLocalRouteRLocked(id, localAddr, remoteAddr, netProto); ok { + if r := s.findLocalRouteRLocked(id, localAddr, remoteAddr, netProto); r != nil { return r, nil } } @@ -1339,9 +1339,9 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n } if isLoopback { - return Route{}, tcpip.ErrBadLocalAddress + return nil, tcpip.ErrBadLocalAddress } - return Route{}, tcpip.ErrNetworkUnreachable + return nil, tcpip.ErrNetworkUnreachable } canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal @@ -1365,7 +1365,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n gateway = route.Gateway } r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop) - if r == (Route{}) { + if r == nil { panic(fmt.Sprintf("non-forwarding route validation failed with route table entry = %#v, id = %d, localAddr = %s, remoteAddr = %s", route, id, localAddr, remoteAddr)) } return r, nil @@ -1401,13 +1401,13 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n if id != 0 { if aNIC, ok := s.nics[id]; ok { if addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto); addressEndpoint != nil { - if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) { + if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != nil { return r, nil } } } - return Route{}, tcpip.ErrNoRoute + return nil, tcpip.ErrNoRoute } if id == 0 { @@ -1419,7 +1419,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n continue } - if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) { + if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != nil { return r, nil } } @@ -1427,12 +1427,12 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n } if needRoute { - return Route{}, tcpip.ErrNoRoute + return nil, tcpip.ErrNoRoute } if header.IsV6LoopbackAddress(remoteAddr) { - return Route{}, tcpip.ErrBadLocalAddress + return nil, tcpip.ErrBadLocalAddress } - return Route{}, tcpip.ErrNetworkUnreachable + return nil, tcpip.ErrNetworkUnreachable } // CheckNetworkProtocol checks if a given network protocol is enabled in the |