From bdf4e41c863ce025c67bfd30b5c52d15bdc54ced Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Tue, 19 Oct 2021 17:22:45 -0700 Subject: Always parse Transport headers ..including ICMP headers before delivering them to the TransportDispatcher. Updates #3810. PiperOrigin-RevId: 404404002 --- pkg/tcpip/stack/nic.go | 19 ++----------------- pkg/tcpip/stack/stack.go | 6 ------ pkg/tcpip/stack/stack_test.go | 18 +++++++++++++++--- pkg/tcpip/stack/transport_test.go | 7 +++++-- 4 files changed, 22 insertions(+), 28 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 29d580e76..e251e3b24 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -833,24 +833,9 @@ func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt transProto := state.proto - // TransportHeader is empty only when pkt is an ICMP packet or was reassembled - // from fragments. if pkt.TransportHeader().View().IsEmpty() { - // ICMP packets don't have their TransportHeader fields set yet, parse it - // here. See icmp/protocol.go:protocol.Parse for a full explanation. - if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { - // ICMP packets may be longer, but until icmp.Parse is implemented, here - // we parse it using the minimum size. - if _, ok := pkt.TransportHeader().Consume(transProto.MinimumPacketSize()); !ok { - n.stats.malformedL4RcvdPackets.Increment() - // We consider a malformed transport packet handled because there is - // nothing the caller can do. - return TransportPacketHandled - } - } else if !transProto.Parse(pkt) { - n.stats.malformedL4RcvdPackets.Increment() - return TransportPacketHandled - } + n.stats.malformedL4RcvdPackets.Increment() + return TransportPacketHandled } srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader().View()) diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index ee6767654..a05fd7036 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1865,12 +1865,6 @@ const ( // ParsePacketBufferTransport parses the provided packet buffer's transport // header. func (s *Stack) ParsePacketBufferTransport(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) ParseResult { - // ICMP packets don't have their TransportHeader fields set yet, parse it - // here. See icmp/protocol.go:protocol.Parse for a full explanation. - if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { - return ParsedOK - } - pkt.TransportProtocolNumber = protocol // Parse the transport header if present. state, ok := s.transportProtocols[protocol] diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index c23e91702..f5a35eac4 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -155,8 +155,18 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { return } + transProtoNum := tcpip.TransportProtocolNumber(netHdr[protocolNumberOffset]) + switch err := f.proto.stack.ParsePacketBufferTransport(transProtoNum, pkt); err { + case stack.ParsedOK: + case stack.UnknownTransportProtocol, stack.TransportLayerParseError: + // The transport layer will handle unknown protocols and transport layer + // parsing errors. + default: + panic(fmt.Sprintf("unexpected error parsing transport header = %d", err)) + } + // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) + f.dispatcher.DeliverTransportPacket(transProtoNum, pkt) } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { @@ -218,6 +228,8 @@ func (*fakeNetworkEndpointStats) IsNetworkEndpointStats() {} // number of packets sent and received via endpoints of this protocol. The index // where packets are added is given by the packet's destination address MOD 10. type fakeNetworkProtocol struct { + stack *stack.Stack + packetCount [10]int sendPacketCount [10]int defaultTTL uint8 @@ -299,8 +311,8 @@ func (f *fakeNetworkEndpoint) SetForwarding(v bool) { f.mu.forwarding = v } -func fakeNetFactory(*stack.Stack) stack.NetworkProtocol { - return &fakeNetworkProtocol{} +func fakeNetFactory(s *stack.Stack) stack.NetworkProtocol { + return &fakeNetworkProtocol{stack: s} } // linkEPWithMockedAttach is a stack.LinkEndpoint that tests can use to verify diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 655931715..51870d03f 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -331,8 +331,11 @@ func (*fakeTransportProtocol) Wait() {} // Parse implements TransportProtocol.Parse. func (*fakeTransportProtocol) Parse(pkt *stack.PacketBuffer) bool { - _, ok := pkt.TransportHeader().Consume(fakeTransHeaderLen) - return ok + if _, ok := pkt.TransportHeader().Consume(fakeTransHeaderLen); ok { + pkt.TransportProtocolNumber = fakeTransNumber + return true + } + return false } func fakeTransFactory(s *stack.Stack) stack.TransportProtocol { -- cgit v1.2.3