summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack/stack.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack/stack.go')
-rw-r--r--pkg/tcpip/stack/stack.go303
1 files changed, 269 insertions, 34 deletions
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index ba0e1a7ec..a23fb97ff 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -22,6 +22,7 @@ package stack
import (
"bytes"
"encoding/binary"
+ "fmt"
mathrand "math/rand"
"sync/atomic"
"time"
@@ -52,7 +53,7 @@ const (
type transportProtocolState struct {
proto TransportProtocol
- defaultHandler func(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool
+ defaultHandler func(id TransportEndpointID, pkt *PacketBuffer) bool
}
// TCPProbeFunc is the expected function type for a TCP probe function to be
@@ -759,7 +760,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber,
//
// It must be called only during initialization of the stack. Changing it as the
// stack is operating is not supported.
-func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, *PacketBuffer) bool) {
+func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(TransportEndpointID, *PacketBuffer) bool) {
state := s.transportProtocols[p]
if state != nil {
state.defaultHandler = h
@@ -1202,59 +1203,225 @@ func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netP
return nic.findEndpoint(netProto, localAddr, CanBePrimaryEndpoint)
}
+// findLocalRouteFromNICRLocked is like findLocalRouteRLocked but finds a route
+// from the specified NIC.
+//
+// Precondition: s.mu must be read locked.
+func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) {
+ localAddressEndpoint := localAddressNIC.getAddressOrCreateTempInner(netProto, localAddr, false /* createTemp */, NeverPrimaryEndpoint)
+ if localAddressEndpoint == nil {
+ return Route{}, false
+ }
+
+ var outgoingNIC *NIC
+ // Prefer a local route to the same interface as the local address.
+ if localAddressNIC.hasAddress(netProto, remoteAddr) {
+ outgoingNIC = localAddressNIC
+ }
+
+ // If the remote address isn't owned by the local address's NIC, check all
+ // NICs.
+ if outgoingNIC == nil {
+ for _, nic := range s.nics {
+ if nic.hasAddress(netProto, remoteAddr) {
+ outgoingNIC = nic
+ break
+ }
+ }
+ }
+
+ // If the remote address is not owned by the stack, we can't return a local
+ // route.
+ if outgoingNIC == nil {
+ localAddressEndpoint.DecRef()
+ return Route{}, false
+ }
+
+ r := makeLocalRoute(
+ netProto,
+ localAddressEndpoint.AddressWithPrefix().Address,
+ remoteAddr,
+ outgoingNIC,
+ localAddressNIC,
+ localAddressEndpoint,
+ )
+
+ if r.IsOutboundBroadcast() {
+ r.Release()
+ return Route{}, false
+ }
+
+ return r, true
+}
+
+// findLocalRouteRLocked returns a local route.
+//
+// A local route is a route to some remote address which the stack owns. That
+// is, a local route is a route where packets never have to leave the stack.
+//
+// Precondition: s.mu must be read locked.
+func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) {
+ if len(localAddr) == 0 {
+ localAddr = remoteAddr
+ }
+
+ if localAddressNICID == 0 {
+ for _, localAddressNIC := range s.nics {
+ if r, ok := s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto); ok {
+ return r, true
+ }
+ }
+
+ return Route{}, false
+ }
+
+ if localAddressNIC, ok := s.nics[localAddressNICID]; ok {
+ return s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto)
+ }
+
+ return Route{}, false
+}
+
// FindRoute creates a route to the given destination address, leaving through
-// the given nic and local address (if provided).
+// the given NIC and local address (if provided).
+//
+// If a NIC is not specified, the returned route will leave through the same
+// NIC as the NIC that has the local address assigned when forwarding is
+// disabled. If forwarding is enabled and the NIC is unspecified, the route may
+// leave through any interface unless the route is link-local.
+//
+// If no local address is provided, the stack will select a local address. If no
+// remote address is provided, the stack wil use a remote address equal to the
+// local address.
func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (Route, *tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
+ isLinkLocal := header.IsV6LinkLocalAddress(remoteAddr) || header.IsV6LinkLocalMulticastAddress(remoteAddr)
isLocalBroadcast := remoteAddr == header.IPv4Broadcast
isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)
- isLinkLocal := header.IsV6LinkLocalAddress(remoteAddr) || header.IsV6LinkLocalMulticastAddress(remoteAddr)
- IsLoopback := header.IsV4LoopbackAddress(remoteAddr) || header.IsV6LoopbackAddress(remoteAddr)
- needRoute := !(isLocalBroadcast || isMulticast || isLinkLocal || IsLoopback)
+ isLoopback := header.IsV4LoopbackAddress(remoteAddr) || header.IsV6LoopbackAddress(remoteAddr)
+ needRoute := !(isLocalBroadcast || isMulticast || isLinkLocal || isLoopback)
+
+ if s.handleLocal && !isMulticast && !isLocalBroadcast {
+ if r, ok := s.findLocalRouteRLocked(id, localAddr, remoteAddr, netProto); ok {
+ return r, nil
+ }
+ }
+
+ // If the interface is specified and we do not need a route, return a route
+ // through the interface if the interface is valid and enabled.
if id != 0 && !needRoute {
if nic, ok := s.nics[id]; ok && nic.Enabled() {
if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil {
- return makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback()), nil
+ return makeRoute(
+ netProto,
+ addressEndpoint.AddressWithPrefix().Address,
+ remoteAddr,
+ nic, /* outboundNIC */
+ nic, /* localAddressNIC*/
+ addressEndpoint,
+ s.handleLocal,
+ multicastLoop,
+ ), nil
}
}
- } else {
- for _, route := range s.routeTable {
- if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) {
- continue
+
+ if isLoopback {
+ return Route{}, tcpip.ErrBadLocalAddress
+ }
+ return Route{}, tcpip.ErrNetworkUnreachable
+ }
+
+ canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal
+
+ // Find a route to the remote with the route table.
+ var chosenRoute tcpip.Route
+ for _, route := range s.routeTable {
+ if len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr) {
+ continue
+ }
+
+ nic, ok := s.nics[route.NIC]
+ if !ok || !nic.Enabled() {
+ continue
+ }
+
+ if id == 0 || id == route.NIC {
+ if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil {
+ var gateway tcpip.Address
+ if needRoute {
+ gateway = route.Gateway
+ }
+ r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop)
+ if r == (Route{}) {
+ panic(fmt.Sprintf("non-forwarding route validation failed with route table entry = %#v, id = %d, localAddr = %s, remoteAddr = %s", route, id, localAddr, remoteAddr))
+ }
+ return r, nil
}
- if nic, ok := s.nics[route.NIC]; ok && nic.Enabled() {
- if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil {
- if len(remoteAddr) == 0 {
- // If no remote address was provided, then the route
- // provided will refer to the link local address.
- remoteAddr = addressEndpoint.AddressWithPrefix().Address
- }
+ }
+
+ // If the stack has forwarding enabled and we haven't found a valid route to
+ // the remote address yet, keep track of the first valid route. We keep
+ // iterating because we prefer routes that let us use a local address that
+ // is assigned to the outgoing interface. There is no requirement to do this
+ // from any RFC but simply a choice made to better follow a strong host
+ // model which the netstack follows at the time of writing.
+ if canForward && chosenRoute == (tcpip.Route{}) {
+ chosenRoute = route
+ }
+ }
+
+ if chosenRoute != (tcpip.Route{}) {
+ // At this point we know the stack has forwarding enabled since chosenRoute is
+ // only set when forwarding is enabled.
+ nic, ok := s.nics[chosenRoute.NIC]
+ if !ok {
+ // If the route's NIC was invalid, we should not have chosen the route.
+ panic(fmt.Sprintf("chosen route must have a valid NIC with ID = %d", chosenRoute.NIC))
+ }
+
+ var gateway tcpip.Address
+ if needRoute {
+ gateway = chosenRoute.Gateway
+ }
- r := makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback())
- if len(route.Gateway) > 0 {
- if needRoute {
- r.NextHop = route.Gateway
- }
- } else if subnet := addressEndpoint.AddressWithPrefix().Subnet(); subnet.IsBroadcast(remoteAddr) {
- r.RemoteLinkAddress = header.EthernetBroadcastAddress
+ // Use the specified NIC to get the local address endpoint.
+ if id != 0 {
+ if aNIC, ok := s.nics[id]; ok {
+ if addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto); addressEndpoint != nil {
+ if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) {
+ return r, nil
}
+ }
+ }
+
+ return Route{}, tcpip.ErrNoRoute
+ }
+ if id == 0 {
+ // If an interface is not specified, try to find a NIC that holds the local
+ // address endpoint to construct a route.
+ for _, aNIC := range s.nics {
+ addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto)
+ if addressEndpoint == nil {
+ continue
+ }
+
+ if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) {
return r, nil
}
}
}
}
- if !needRoute {
- if IsLoopback {
- return Route{}, tcpip.ErrBadLocalAddress
- }
- return Route{}, tcpip.ErrNetworkUnreachable
+ if needRoute {
+ return Route{}, tcpip.ErrNoRoute
}
-
- return Route{}, tcpip.ErrNoRoute
+ if isLoopback {
+ return Route{}, tcpip.ErrBadLocalAddress
+ }
+ return Route{}, tcpip.ErrNetworkUnreachable
}
// CheckNetworkProtocol checks if a given network protocol is enabled in the
@@ -1470,8 +1637,8 @@ func (s *Stack) CompleteTransportEndpointCleanup(ep TransportEndpoint) {
// FindTransportEndpoint finds an endpoint that most closely matches the provided
// id. If no endpoint is found it returns nil.
-func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint {
- return s.demux.findTransportEndpoint(netProto, transProto, id, r)
+func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, nicID tcpip.NICID) TransportEndpoint {
+ return s.demux.findTransportEndpoint(netProto, transProto, id, nicID)
}
// RegisterRawTransportEndpoint registers the given endpoint with the stack
@@ -1923,3 +2090,71 @@ func (s *Stack) FindNICNameFromID(id tcpip.NICID) string {
func (s *Stack) NewJob(l sync.Locker, f func()) *tcpip.Job {
return tcpip.NewJob(s.clock, l, f)
}
+
+// ParseResult indicates the result of a parsing attempt.
+type ParseResult int
+
+const (
+ // ParsedOK indicates that a packet was successfully parsed.
+ ParsedOK ParseResult = iota
+
+ // UnknownNetworkProtocol indicates that the network protocol is unknown.
+ UnknownNetworkProtocol
+
+ // NetworkLayerParseError indicates that the network packet was not
+ // successfully parsed.
+ NetworkLayerParseError
+
+ // UnknownTransportProtocol indicates that the transport protocol is unknown.
+ UnknownTransportProtocol
+
+ // TransportLayerParseError indicates that the transport packet was not
+ // successfully parsed.
+ TransportLayerParseError
+)
+
+// ParsePacketBuffer parses the provided packet buffer.
+func (s *Stack) ParsePacketBuffer(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) ParseResult {
+ netProto, ok := s.networkProtocols[protocol]
+ if !ok {
+ return UnknownNetworkProtocol
+ }
+
+ transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt)
+ if !ok {
+ return NetworkLayerParseError
+ }
+ if !hasTransportHdr {
+ return ParsedOK
+ }
+
+ // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader
+ // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a
+ // full explanation.
+ if transProtoNum == header.ICMPv4ProtocolNumber || transProtoNum == header.ICMPv6ProtocolNumber {
+ return ParsedOK
+ }
+
+ pkt.TransportProtocolNumber = transProtoNum
+ // Parse the transport header if present.
+ state, ok := s.transportProtocols[transProtoNum]
+ if !ok {
+ return UnknownTransportProtocol
+ }
+
+ if !state.proto.Parse(pkt) {
+ return TransportLayerParseError
+ }
+
+ return ParsedOK
+}
+
+// networkProtocolNumbers returns the network protocol numbers the stack is
+// configured with.
+func (s *Stack) networkProtocolNumbers() []tcpip.NetworkProtocolNumber {
+ protos := make([]tcpip.NetworkProtocolNumber, 0, len(s.networkProtocols))
+ for p := range s.networkProtocols {
+ protos = append(protos, p)
+ }
+ return protos
+}