diff options
author | Ghanan Gowripalan <ghanan@google.com> | 2021-02-05 16:44:49 -0800 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2021-02-05 16:47:11 -0800 |
commit | 24416032ab848cff7696b3f37e4c18220aeee2c0 (patch) | |
tree | dd50fbdc84304102a6f9dcb7f6bfa594299bc4d7 /pkg/tcpip/stack | |
parent | 3514c289a9c9da232bf3054c971c3e0434d8cfa3 (diff) |
Refactor locally delivered packets
Make it clear that failing to parse a looped back is not a packet
sending error but a malformed received packet error.
FindNetworkEndpoint returns nil when no network endpoint is found
instead of an error.
PiperOrigin-RevId: 355954946
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/forwarding_test.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 30 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 44 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 4 |
4 files changed, 24 insertions, 58 deletions
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index c24f56ece..0cb9ec3a3 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -75,6 +75,10 @@ func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 { } func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) { + if _, _, ok := f.proto.Parse(pkt); !ok { + return + } + netHdr := pkt.NetworkHeader().View() _, dst := f.proto.ParseAddresses(netHdr) diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 41a489047..6f2a0e487 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -777,36 +777,6 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp anyEPs.forEach(deliverPacketEPs) } - // Parse headers. - netProto := n.stack.NetworkProtocolInstance(protocol) - transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt) - if !ok { - // The packet is too small to contain a network header. - n.stack.stats.MalformedRcvdPackets.Increment() - return - } - if hasTransportHdr { - pkt.TransportProtocolNumber = transProtoNum - // Parse the transport header if present. - if state, ok := n.stack.transportProtocols[transProtoNum]; ok { - state.proto.Parse(pkt) - } - } - - if n.stack.handleLocal && !n.IsLoopback() { - src, _ := netProto.ParseAddresses(pkt.NetworkHeader().View()) - if r := n.getAddress(protocol, src); r != nil { - r.DecRef() - - // The source address is one of our own, so we never should have gotten a - // packet like this unless handleLocal is false. Loopback also calls this - // function even though the packets didn't come from the physical interface - // so don't drop those. - n.stack.stats.IP.InvalidSourceAddressesReceived.Increment() - return - } - } - networkEndpoint.HandlePacket(pkt) } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index a51d758d0..035ab33ca 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1319,6 +1319,11 @@ func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, return nil } +// HandleLocal returns true if non-loopback interfaces are allowed to loop packets. +func (s *Stack) HandleLocal() bool { + return s.handleLocal +} + // FindRoute creates a route to the given destination address, leaving through // the given NIC and local address (if provided). // @@ -2063,7 +2068,9 @@ func generateRandInt64() int64 { } // FindNetworkEndpoint returns the network endpoint for the given address. -func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) (NetworkEndpoint, tcpip.Error) { +// +// Returns nil if the address is not associated with any network endpoint. +func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) NetworkEndpoint { s.mu.RLock() defer s.mu.RUnlock() @@ -2073,9 +2080,9 @@ func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, addres continue } addressEndpoint.DecRef() - return nic.getNetworkEndpoint(netProto), nil + return nic.getNetworkEndpoint(netProto) } - return nil, &tcpip.ErrBadAddress{} + return nil } // FindNICNameFromID returns the name of the NIC for the given NICID. @@ -2103,13 +2110,6 @@ const ( // ParsedOK indicates that a packet was successfully parsed. ParsedOK ParseResult = iota - // UnknownNetworkProtocol indicates that the network protocol is unknown. - UnknownNetworkProtocol - - // NetworkLayerParseError indicates that the network packet was not - // successfully parsed. - NetworkLayerParseError - // UnknownTransportProtocol indicates that the transport protocol is unknown. UnknownTransportProtocol @@ -2118,31 +2118,19 @@ const ( TransportLayerParseError ) -// ParsePacketBuffer parses the provided packet buffer. -func (s *Stack) ParsePacketBuffer(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) ParseResult { - netProto, ok := s.networkProtocols[protocol] - if !ok { - return UnknownNetworkProtocol - } - - transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt) - if !ok { - return NetworkLayerParseError - } - if !hasTransportHdr { - return ParsedOK - } - +// ParsePacketBufferTransport parses the provided packet buffer's transport +// header. +func (s *Stack) ParsePacketBufferTransport(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) ParseResult { // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a // full explanation. - if transProtoNum == header.ICMPv4ProtocolNumber || transProtoNum == header.ICMPv6ProtocolNumber { + if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { return ParsedOK } - pkt.TransportProtocolNumber = transProtoNum + pkt.TransportProtocolNumber = protocol // Parse the transport header if present. - state, ok := s.transportProtocols[transProtoNum] + state, ok := s.transportProtocols[protocol] if !ok { return UnknownTransportProtocol } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index b641a4aaa..b3386f705 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -119,6 +119,10 @@ func (*fakeNetworkEndpoint) DefaultTTL() uint8 { } func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { + if _, _, ok := f.proto.Parse(pkt); !ok { + return + } + // Increment the received packet count in the protocol descriptor. netHdr := pkt.NetworkHeader().View() |