summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/tcpip/stack/stack.go116
1 files changed, 65 insertions, 51 deletions
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index e720d676f..66bf22823 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -385,6 +385,15 @@ type Stack struct {
stats tcpip.Stats
+ // LOCK ORDERING: mu > route.mu.
+ route struct {
+ mu struct {
+ sync.RWMutex
+
+ table []tcpip.Route
+ }
+ }
+
mu sync.RWMutex
nics map[tcpip.NICID]*NIC
@@ -392,11 +401,6 @@ type Stack struct {
cleanupEndpointsMu sync.Mutex
cleanupEndpoints map[TransportEndpoint]struct{}
- // route is the route table passed in by the user via SetRouteTable(),
- // it is used by FindRoute() to build a route for a specific
- // destination.
- routeTable []tcpip.Route
-
*ports.PortManager
// If not nil, then any new endpoints will have this probe function
@@ -813,38 +817,37 @@ func (s *Stack) Forwarding(protocolNum tcpip.NetworkProtocolNumber) bool {
//
// This method takes ownership of the table.
func (s *Stack) SetRouteTable(table []tcpip.Route) {
- s.mu.Lock()
- defer s.mu.Unlock()
-
- s.routeTable = table
+ s.route.mu.Lock()
+ defer s.route.mu.Unlock()
+ s.route.mu.table = table
}
// GetRouteTable returns the route table which is currently in use.
func (s *Stack) GetRouteTable() []tcpip.Route {
- s.mu.Lock()
- defer s.mu.Unlock()
- return append([]tcpip.Route(nil), s.routeTable...)
+ s.route.mu.RLock()
+ defer s.route.mu.RUnlock()
+ return append([]tcpip.Route(nil), s.route.mu.table...)
}
// AddRoute appends a route to the route table.
func (s *Stack) AddRoute(route tcpip.Route) {
- s.mu.Lock()
- defer s.mu.Unlock()
- s.routeTable = append(s.routeTable, route)
+ s.route.mu.Lock()
+ defer s.route.mu.Unlock()
+ s.route.mu.table = append(s.route.mu.table, route)
}
// RemoveRoutes removes matching routes from the route table.
func (s *Stack) RemoveRoutes(match func(tcpip.Route) bool) {
- s.mu.Lock()
- defer s.mu.Unlock()
+ s.route.mu.Lock()
+ defer s.route.mu.Unlock()
var filteredRoutes []tcpip.Route
- for _, route := range s.routeTable {
+ for _, route := range s.route.mu.table {
if !match(route) {
filteredRoutes = append(filteredRoutes, route)
}
}
- s.routeTable = filteredRoutes
+ s.route.mu.table = filteredRoutes
}
// NewEndpoint creates a new transport layer endpoint of the given protocol.
@@ -1017,17 +1020,18 @@ func (s *Stack) removeNICLocked(id tcpip.NICID) tcpip.Error {
delete(s.nics, id)
// Remove routes in-place. n tracks the number of routes written.
+ s.route.mu.Lock()
n := 0
- for i, r := range s.routeTable {
- s.routeTable[i] = tcpip.Route{}
+ for i, r := range s.route.mu.table {
+ s.route.mu.table[i] = tcpip.Route{}
if r.NIC != id {
// Keep this route.
- s.routeTable[n] = r
+ s.route.mu.table[n] = r
n++
}
}
-
- s.routeTable = s.routeTable[:n]
+ s.route.mu.table = s.route.mu.table[:n]
+ s.route.mu.Unlock()
return nic.remove()
}
@@ -1352,39 +1356,49 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
// Find a route to the remote with the route table.
var chosenRoute tcpip.Route
- for _, route := range s.routeTable {
- if len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr) {
- continue
- }
+ if r := func() *Route {
+ s.route.mu.RLock()
+ defer s.route.mu.RUnlock()
- nic, ok := s.nics[route.NIC]
- if !ok || !nic.Enabled() {
- continue
- }
+ for _, route := range s.route.mu.table {
+ if len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr) {
+ continue
+ }
- if id == 0 || id == route.NIC {
- if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil {
- var gateway tcpip.Address
- if needRoute {
- gateway = route.Gateway
- }
- r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop)
- if r == nil {
- panic(fmt.Sprintf("non-forwarding route validation failed with route table entry = %#v, id = %d, localAddr = %s, remoteAddr = %s", route, id, localAddr, remoteAddr))
+ nic, ok := s.nics[route.NIC]
+ if !ok || !nic.Enabled() {
+ continue
+ }
+
+ if id == 0 || id == route.NIC {
+ if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil {
+ var gateway tcpip.Address
+ if needRoute {
+ gateway = route.Gateway
+ }
+ r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop)
+ if r == nil {
+ panic(fmt.Sprintf("non-forwarding route validation failed with route table entry = %#v, id = %d, localAddr = %s, remoteAddr = %s", route, id, localAddr, remoteAddr))
+ }
+ return r
}
- return r, nil
}
- }
- // If the stack has forwarding enabled and we haven't found a valid route to
- // the remote address yet, keep track of the first valid route. We keep
- // iterating because we prefer routes that let us use a local address that
- // is assigned to the outgoing interface. There is no requirement to do this
- // from any RFC but simply a choice made to better follow a strong host
- // model which the netstack follows at the time of writing.
- if canForward && chosenRoute == (tcpip.Route{}) {
- chosenRoute = route
+ // If the stack has forwarding enabled and we haven't found a valid route
+ // to the remote address yet, keep track of the first valid route. We
+ // keep iterating because we prefer routes that let us use a local
+ // address that is assigned to the outgoing interface. There is no
+ // requirement to do this from any RFC but simply a choice made to better
+ // follow a strong host model which the netstack follows at the time of
+ // writing.
+ if canForward && chosenRoute == (tcpip.Route{}) {
+ chosenRoute = route
+ }
}
+
+ return nil
+ }(); r != nil {
+ return r, nil
}
if chosenRoute != (tcpip.Route{}) {