diff options
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/stack.go | 116 |
1 files changed, 65 insertions, 51 deletions
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index e720d676f..66bf22823 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -385,6 +385,15 @@ type Stack struct { stats tcpip.Stats + // LOCK ORDERING: mu > route.mu. + route struct { + mu struct { + sync.RWMutex + + table []tcpip.Route + } + } + mu sync.RWMutex nics map[tcpip.NICID]*NIC @@ -392,11 +401,6 @@ type Stack struct { cleanupEndpointsMu sync.Mutex cleanupEndpoints map[TransportEndpoint]struct{} - // route is the route table passed in by the user via SetRouteTable(), - // it is used by FindRoute() to build a route for a specific - // destination. - routeTable []tcpip.Route - *ports.PortManager // If not nil, then any new endpoints will have this probe function @@ -813,38 +817,37 @@ func (s *Stack) Forwarding(protocolNum tcpip.NetworkProtocolNumber) bool { // // This method takes ownership of the table. func (s *Stack) SetRouteTable(table []tcpip.Route) { - s.mu.Lock() - defer s.mu.Unlock() - - s.routeTable = table + s.route.mu.Lock() + defer s.route.mu.Unlock() + s.route.mu.table = table } // GetRouteTable returns the route table which is currently in use. func (s *Stack) GetRouteTable() []tcpip.Route { - s.mu.Lock() - defer s.mu.Unlock() - return append([]tcpip.Route(nil), s.routeTable...) + s.route.mu.RLock() + defer s.route.mu.RUnlock() + return append([]tcpip.Route(nil), s.route.mu.table...) } // AddRoute appends a route to the route table. func (s *Stack) AddRoute(route tcpip.Route) { - s.mu.Lock() - defer s.mu.Unlock() - s.routeTable = append(s.routeTable, route) + s.route.mu.Lock() + defer s.route.mu.Unlock() + s.route.mu.table = append(s.route.mu.table, route) } // RemoveRoutes removes matching routes from the route table. func (s *Stack) RemoveRoutes(match func(tcpip.Route) bool) { - s.mu.Lock() - defer s.mu.Unlock() + s.route.mu.Lock() + defer s.route.mu.Unlock() var filteredRoutes []tcpip.Route - for _, route := range s.routeTable { + for _, route := range s.route.mu.table { if !match(route) { filteredRoutes = append(filteredRoutes, route) } } - s.routeTable = filteredRoutes + s.route.mu.table = filteredRoutes } // NewEndpoint creates a new transport layer endpoint of the given protocol. @@ -1017,17 +1020,18 @@ func (s *Stack) removeNICLocked(id tcpip.NICID) tcpip.Error { delete(s.nics, id) // Remove routes in-place. n tracks the number of routes written. + s.route.mu.Lock() n := 0 - for i, r := range s.routeTable { - s.routeTable[i] = tcpip.Route{} + for i, r := range s.route.mu.table { + s.route.mu.table[i] = tcpip.Route{} if r.NIC != id { // Keep this route. - s.routeTable[n] = r + s.route.mu.table[n] = r n++ } } - - s.routeTable = s.routeTable[:n] + s.route.mu.table = s.route.mu.table[:n] + s.route.mu.Unlock() return nic.remove() } @@ -1352,39 +1356,49 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n // Find a route to the remote with the route table. var chosenRoute tcpip.Route - for _, route := range s.routeTable { - if len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr) { - continue - } + if r := func() *Route { + s.route.mu.RLock() + defer s.route.mu.RUnlock() - nic, ok := s.nics[route.NIC] - if !ok || !nic.Enabled() { - continue - } + for _, route := range s.route.mu.table { + if len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr) { + continue + } - if id == 0 || id == route.NIC { - if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { - var gateway tcpip.Address - if needRoute { - gateway = route.Gateway - } - r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop) - if r == nil { - panic(fmt.Sprintf("non-forwarding route validation failed with route table entry = %#v, id = %d, localAddr = %s, remoteAddr = %s", route, id, localAddr, remoteAddr)) + nic, ok := s.nics[route.NIC] + if !ok || !nic.Enabled() { + continue + } + + if id == 0 || id == route.NIC { + if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { + var gateway tcpip.Address + if needRoute { + gateway = route.Gateway + } + r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop) + if r == nil { + panic(fmt.Sprintf("non-forwarding route validation failed with route table entry = %#v, id = %d, localAddr = %s, remoteAddr = %s", route, id, localAddr, remoteAddr)) + } + return r } - return r, nil } - } - // If the stack has forwarding enabled and we haven't found a valid route to - // the remote address yet, keep track of the first valid route. We keep - // iterating because we prefer routes that let us use a local address that - // is assigned to the outgoing interface. There is no requirement to do this - // from any RFC but simply a choice made to better follow a strong host - // model which the netstack follows at the time of writing. - if canForward && chosenRoute == (tcpip.Route{}) { - chosenRoute = route + // If the stack has forwarding enabled and we haven't found a valid route + // to the remote address yet, keep track of the first valid route. We + // keep iterating because we prefer routes that let us use a local + // address that is assigned to the outgoing interface. There is no + // requirement to do this from any RFC but simply a choice made to better + // follow a strong host model which the netstack follows at the time of + // writing. + if canForward && chosenRoute == (tcpip.Route{}) { + chosenRoute = route + } } + + return nil + }(); r != nil { + return r, nil } if chosenRoute != (tcpip.Route{}) { |