diff options
Diffstat (limited to 'pkg/tcpip/stack/stack.go')
-rw-r--r-- | pkg/tcpip/stack/stack.go | 303 |
1 files changed, 269 insertions, 34 deletions
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index ba0e1a7ec..a23fb97ff 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -22,6 +22,7 @@ package stack import ( "bytes" "encoding/binary" + "fmt" mathrand "math/rand" "sync/atomic" "time" @@ -52,7 +53,7 @@ const ( type transportProtocolState struct { proto TransportProtocol - defaultHandler func(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool + defaultHandler func(id TransportEndpointID, pkt *PacketBuffer) bool } // TCPProbeFunc is the expected function type for a TCP probe function to be @@ -759,7 +760,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, // // It must be called only during initialization of the stack. Changing it as the // stack is operating is not supported. -func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, *PacketBuffer) bool) { +func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(TransportEndpointID, *PacketBuffer) bool) { state := s.transportProtocols[p] if state != nil { state.defaultHandler = h @@ -1202,59 +1203,225 @@ func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netP return nic.findEndpoint(netProto, localAddr, CanBePrimaryEndpoint) } +// findLocalRouteFromNICRLocked is like findLocalRouteRLocked but finds a route +// from the specified NIC. +// +// Precondition: s.mu must be read locked. +func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) { + localAddressEndpoint := localAddressNIC.getAddressOrCreateTempInner(netProto, localAddr, false /* createTemp */, NeverPrimaryEndpoint) + if localAddressEndpoint == nil { + return Route{}, false + } + + var outgoingNIC *NIC + // Prefer a local route to the same interface as the local address. + if localAddressNIC.hasAddress(netProto, remoteAddr) { + outgoingNIC = localAddressNIC + } + + // If the remote address isn't owned by the local address's NIC, check all + // NICs. + if outgoingNIC == nil { + for _, nic := range s.nics { + if nic.hasAddress(netProto, remoteAddr) { + outgoingNIC = nic + break + } + } + } + + // If the remote address is not owned by the stack, we can't return a local + // route. + if outgoingNIC == nil { + localAddressEndpoint.DecRef() + return Route{}, false + } + + r := makeLocalRoute( + netProto, + localAddressEndpoint.AddressWithPrefix().Address, + remoteAddr, + outgoingNIC, + localAddressNIC, + localAddressEndpoint, + ) + + if r.IsOutboundBroadcast() { + r.Release() + return Route{}, false + } + + return r, true +} + +// findLocalRouteRLocked returns a local route. +// +// A local route is a route to some remote address which the stack owns. That +// is, a local route is a route where packets never have to leave the stack. +// +// Precondition: s.mu must be read locked. +func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) { + if len(localAddr) == 0 { + localAddr = remoteAddr + } + + if localAddressNICID == 0 { + for _, localAddressNIC := range s.nics { + if r, ok := s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto); ok { + return r, true + } + } + + return Route{}, false + } + + if localAddressNIC, ok := s.nics[localAddressNICID]; ok { + return s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto) + } + + return Route{}, false +} + // FindRoute creates a route to the given destination address, leaving through -// the given nic and local address (if provided). +// the given NIC and local address (if provided). +// +// If a NIC is not specified, the returned route will leave through the same +// NIC as the NIC that has the local address assigned when forwarding is +// disabled. If forwarding is enabled and the NIC is unspecified, the route may +// leave through any interface unless the route is link-local. +// +// If no local address is provided, the stack will select a local address. If no +// remote address is provided, the stack wil use a remote address equal to the +// local address. func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (Route, *tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() + isLinkLocal := header.IsV6LinkLocalAddress(remoteAddr) || header.IsV6LinkLocalMulticastAddress(remoteAddr) isLocalBroadcast := remoteAddr == header.IPv4Broadcast isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr) - isLinkLocal := header.IsV6LinkLocalAddress(remoteAddr) || header.IsV6LinkLocalMulticastAddress(remoteAddr) - IsLoopback := header.IsV4LoopbackAddress(remoteAddr) || header.IsV6LoopbackAddress(remoteAddr) - needRoute := !(isLocalBroadcast || isMulticast || isLinkLocal || IsLoopback) + isLoopback := header.IsV4LoopbackAddress(remoteAddr) || header.IsV6LoopbackAddress(remoteAddr) + needRoute := !(isLocalBroadcast || isMulticast || isLinkLocal || isLoopback) + + if s.handleLocal && !isMulticast && !isLocalBroadcast { + if r, ok := s.findLocalRouteRLocked(id, localAddr, remoteAddr, netProto); ok { + return r, nil + } + } + + // If the interface is specified and we do not need a route, return a route + // through the interface if the interface is valid and enabled. if id != 0 && !needRoute { if nic, ok := s.nics[id]; ok && nic.Enabled() { if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { - return makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback()), nil + return makeRoute( + netProto, + addressEndpoint.AddressWithPrefix().Address, + remoteAddr, + nic, /* outboundNIC */ + nic, /* localAddressNIC*/ + addressEndpoint, + s.handleLocal, + multicastLoop, + ), nil } } - } else { - for _, route := range s.routeTable { - if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) { - continue + + if isLoopback { + return Route{}, tcpip.ErrBadLocalAddress + } + return Route{}, tcpip.ErrNetworkUnreachable + } + + canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal + + // Find a route to the remote with the route table. + var chosenRoute tcpip.Route + for _, route := range s.routeTable { + if len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr) { + continue + } + + nic, ok := s.nics[route.NIC] + if !ok || !nic.Enabled() { + continue + } + + if id == 0 || id == route.NIC { + if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { + var gateway tcpip.Address + if needRoute { + gateway = route.Gateway + } + r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop) + if r == (Route{}) { + panic(fmt.Sprintf("non-forwarding route validation failed with route table entry = %#v, id = %d, localAddr = %s, remoteAddr = %s", route, id, localAddr, remoteAddr)) + } + return r, nil } - if nic, ok := s.nics[route.NIC]; ok && nic.Enabled() { - if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { - if len(remoteAddr) == 0 { - // If no remote address was provided, then the route - // provided will refer to the link local address. - remoteAddr = addressEndpoint.AddressWithPrefix().Address - } + } + + // If the stack has forwarding enabled and we haven't found a valid route to + // the remote address yet, keep track of the first valid route. We keep + // iterating because we prefer routes that let us use a local address that + // is assigned to the outgoing interface. There is no requirement to do this + // from any RFC but simply a choice made to better follow a strong host + // model which the netstack follows at the time of writing. + if canForward && chosenRoute == (tcpip.Route{}) { + chosenRoute = route + } + } + + if chosenRoute != (tcpip.Route{}) { + // At this point we know the stack has forwarding enabled since chosenRoute is + // only set when forwarding is enabled. + nic, ok := s.nics[chosenRoute.NIC] + if !ok { + // If the route's NIC was invalid, we should not have chosen the route. + panic(fmt.Sprintf("chosen route must have a valid NIC with ID = %d", chosenRoute.NIC)) + } + + var gateway tcpip.Address + if needRoute { + gateway = chosenRoute.Gateway + } - r := makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback()) - if len(route.Gateway) > 0 { - if needRoute { - r.NextHop = route.Gateway - } - } else if subnet := addressEndpoint.AddressWithPrefix().Subnet(); subnet.IsBroadcast(remoteAddr) { - r.RemoteLinkAddress = header.EthernetBroadcastAddress + // Use the specified NIC to get the local address endpoint. + if id != 0 { + if aNIC, ok := s.nics[id]; ok { + if addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto); addressEndpoint != nil { + if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) { + return r, nil } + } + } + + return Route{}, tcpip.ErrNoRoute + } + if id == 0 { + // If an interface is not specified, try to find a NIC that holds the local + // address endpoint to construct a route. + for _, aNIC := range s.nics { + addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto) + if addressEndpoint == nil { + continue + } + + if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) { return r, nil } } } } - if !needRoute { - if IsLoopback { - return Route{}, tcpip.ErrBadLocalAddress - } - return Route{}, tcpip.ErrNetworkUnreachable + if needRoute { + return Route{}, tcpip.ErrNoRoute } - - return Route{}, tcpip.ErrNoRoute + if isLoopback { + return Route{}, tcpip.ErrBadLocalAddress + } + return Route{}, tcpip.ErrNetworkUnreachable } // CheckNetworkProtocol checks if a given network protocol is enabled in the @@ -1470,8 +1637,8 @@ func (s *Stack) CompleteTransportEndpointCleanup(ep TransportEndpoint) { // FindTransportEndpoint finds an endpoint that most closely matches the provided // id. If no endpoint is found it returns nil. -func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint { - return s.demux.findTransportEndpoint(netProto, transProto, id, r) +func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, nicID tcpip.NICID) TransportEndpoint { + return s.demux.findTransportEndpoint(netProto, transProto, id, nicID) } // RegisterRawTransportEndpoint registers the given endpoint with the stack @@ -1923,3 +2090,71 @@ func (s *Stack) FindNICNameFromID(id tcpip.NICID) string { func (s *Stack) NewJob(l sync.Locker, f func()) *tcpip.Job { return tcpip.NewJob(s.clock, l, f) } + +// ParseResult indicates the result of a parsing attempt. +type ParseResult int + +const ( + // ParsedOK indicates that a packet was successfully parsed. + ParsedOK ParseResult = iota + + // UnknownNetworkProtocol indicates that the network protocol is unknown. + UnknownNetworkProtocol + + // NetworkLayerParseError indicates that the network packet was not + // successfully parsed. + NetworkLayerParseError + + // UnknownTransportProtocol indicates that the transport protocol is unknown. + UnknownTransportProtocol + + // TransportLayerParseError indicates that the transport packet was not + // successfully parsed. + TransportLayerParseError +) + +// ParsePacketBuffer parses the provided packet buffer. +func (s *Stack) ParsePacketBuffer(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) ParseResult { + netProto, ok := s.networkProtocols[protocol] + if !ok { + return UnknownNetworkProtocol + } + + transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt) + if !ok { + return NetworkLayerParseError + } + if !hasTransportHdr { + return ParsedOK + } + + // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader + // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a + // full explanation. + if transProtoNum == header.ICMPv4ProtocolNumber || transProtoNum == header.ICMPv6ProtocolNumber { + return ParsedOK + } + + pkt.TransportProtocolNumber = transProtoNum + // Parse the transport header if present. + state, ok := s.transportProtocols[transProtoNum] + if !ok { + return UnknownTransportProtocol + } + + if !state.proto.Parse(pkt) { + return TransportLayerParseError + } + + return ParsedOK +} + +// networkProtocolNumbers returns the network protocol numbers the stack is +// configured with. +func (s *Stack) networkProtocolNumbers() []tcpip.NetworkProtocolNumber { + protos := make([]tcpip.NetworkProtocolNumber, 0, len(s.networkProtocols)) + for p := range s.networkProtocols { + protos = append(protos, p) + } + return protos +} |