diff options
author | Ghanan Gowripalan <ghanan@google.com> | 2021-10-19 17:22:45 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2021-10-19 17:25:55 -0700 |
commit | bdf4e41c863ce025c67bfd30b5c52d15bdc54ced (patch) | |
tree | 26d1cf814ed2e03d90000cf6fef4a57f3264b5ae /pkg/tcpip | |
parent | 6dde3d5ae51030fceb0798d671d19ec1df3ae7a3 (diff) |
Always parse Transport headers
..including ICMP headers before delivering them to the
TransportDispatcher.
Updates #3810.
PiperOrigin-RevId: 404404002
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/header/parse/parse.go | 68 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/icmp.go | 29 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 30 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp.go | 81 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 38 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 19 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 18 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_test.go | 7 |
9 files changed, 161 insertions, 135 deletions
diff --git a/pkg/tcpip/header/parse/parse.go b/pkg/tcpip/header/parse/parse.go index 1c913b5e1..80a9ad6be 100644 --- a/pkg/tcpip/header/parse/parse.go +++ b/pkg/tcpip/header/parse/parse.go @@ -110,6 +110,16 @@ traverseExtensions: switch extHdr := extHdr.(type) { case header.IPv6FragmentExtHdr: + if extHdr.IsAtomic() { + // This fragment extension header indicates that this packet is an + // atomic fragment. An atomic fragment is a fragment that contains + // all the data required to reassemble a full packet. As per RFC 6946, + // atomic fragments must not interfere with "normal" fragmented traffic + // so we skip processing the fragment instead of feeding it through the + // reassembly process below. + continue + } + if fragID == 0 && fragOffset == 0 && !fragMore { fragID = extHdr.ID() fragOffset = extHdr.FragmentOffset() @@ -175,3 +185,61 @@ func TCP(pkt *stack.PacketBuffer) bool { pkt.TransportProtocolNumber = header.TCPProtocolNumber return ok } + +// ICMPv4 populates the packet buffer's transport header with an ICMPv4 header, +// if present. +// +// Returns true if an ICMPv4 header was successfully parsed. +func ICMPv4(pkt *stack.PacketBuffer) bool { + if _, ok := pkt.TransportHeader().Consume(header.ICMPv4MinimumSize); ok { + pkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber + return true + } + return false +} + +// ICMPv6 populates the packet buffer's transport header with an ICMPv4 header, +// if present. +// +// Returns true if an ICMPv6 header was successfully parsed. +func ICMPv6(pkt *stack.PacketBuffer) bool { + hdr, ok := pkt.Data().PullUp(header.ICMPv6MinimumSize) + if !ok { + return false + } + + h := header.ICMPv6(hdr) + switch h.Type() { + case header.ICMPv6RouterSolicit, + header.ICMPv6RouterAdvert, + header.ICMPv6NeighborSolicit, + header.ICMPv6NeighborAdvert, + header.ICMPv6RedirectMsg: + size := pkt.Data().Size() + if _, ok := pkt.TransportHeader().Consume(size); !ok { + panic(fmt.Sprintf("expected to consume the full data of size = %d bytes into transport header", size)) + } + case header.ICMPv6MulticastListenerQuery, + header.ICMPv6MulticastListenerReport, + header.ICMPv6MulticastListenerDone: + size := header.ICMPv6HeaderSize + header.MLDMinimumSize + if _, ok := pkt.TransportHeader().Consume(size); !ok { + return false + } + case header.ICMPv6DstUnreachable, + header.ICMPv6PacketTooBig, + header.ICMPv6TimeExceeded, + header.ICMPv6ParamProblem, + header.ICMPv6EchoRequest, + header.ICMPv6EchoReply: + fallthrough + default: + if _, ok := pkt.TransportHeader().Consume(header.ICMPv6MinimumSize); !ok { + // Checked above if the packet buffer holds at least the minimum size for + // an ICMPv6 packet. + panic(fmt.Sprintf("expected to consume %d bytes", header.ICMPv6MinimumSize)) + } + } + pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber + return true +} diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 1c3b0887f..3eff0bbd8 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -175,18 +175,14 @@ func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.Packet func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { received := e.stats.icmp.packetsReceived - // ICMP packets don't have their TransportHeader fields set. See - // icmp/protocol.go:protocol.Parse for a full explanation. Not all ICMP types - // require consuming the header, so we only call PullUp. - v, ok := pkt.Data().PullUp(header.ICMPv4MinimumSize) - if !ok { + h := header.ICMPv4(pkt.TransportHeader().View()) + if len(h) < header.ICMPv4MinimumSize { received.invalid.Increment() return } - h := header.ICMPv4(v) // Only do in-stack processing if the checksum is correct. - if pkt.Data().AsRange().Checksum() != 0xffff { + if header.Checksum(h, pkt.Data().AsRange().Checksum()) != 0xffff { received.invalid.Increment() // It's possible that a raw socket expects to receive this regardless // of checksum errors. If it's an echo request we know it's safe because @@ -251,7 +247,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { // TODO(gvisor.dev/issue/4399): The copy may not be needed if there are no // waiting endpoints. Consider moving responsibility for doing the copy to // DeliverTransportPacket so that is is only done when needed. - replyData := pkt.Data().AsRange().ToOwnedView() + replyData := stack.PayloadSince(pkt.TransportHeader()) ipHdr := header.IPv4(pkt.NetworkHeader().View()) localAddressBroadcast := pkt.NetworkPacketInfo.LocalAddressBroadcast @@ -344,9 +340,6 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { mtu := h.MTU() code := h.Code() - if _, ok := pkt.Data().Consume(header.ICMPv4MinimumSize); !ok { - panic("could not consume ICMPv4MinimumSize bytes") - } switch code { case header.ICMPv4HostUnreachable: e.handleControl(&icmpv4DestinationHostUnreachableSockError{}, pkt) @@ -574,20 +567,6 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip // Don't respond to icmp error packets. if origIPHdr.Protocol() == uint8(header.ICMPv4ProtocolNumber) { - // TODO(gvisor.dev/issue/3810): - // Unfortunately the current stack pretty much always has ICMPv4 headers - // in the Data section of the packet but there is no guarantee that is the - // case. If this is the case grab the header to make it like all other - // packet types. When this is cleaned up the Consume should be removed. - if transportHeader.IsEmpty() { - var ok bool - transportHeader, ok = pkt.TransportHeader().Consume(header.ICMPv4MinimumSize) - if !ok { - return nil - } - } else if transportHeader.Size() < header.ICMPv4MinimumSize { - return nil - } // We need to decide to explicitly name the packets we can respond to or // the ones we can not respond to. The decision is somewhat arbitrary and // if problems arise this could be reversed. It was judged less of a breach diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 6e52cc9bb..d1d509702 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -984,7 +984,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer, } proto := h.Protocol() - resPkt, _, ready, err := e.protocol.fragmentation.Process( + resPkt, transProtoNum, ready, err := e.protocol.fragmentation.Process( // As per RFC 791 section 2.3, the identification value is unique // for a source-destination pair and protocol. fragmentation.FragmentID{ @@ -1015,6 +1015,8 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer, h.SetTotalLength(uint16(pkt.Data().Size() + len(h))) h.SetFlagsFragmentOffset(0, 0) + e.protocol.parseTransport(pkt, tcpip.TransportProtocolNumber(transProtoNum)) + // Now that the packet is reassembled, it can be sent to raw sockets. e.dispatcher.DeliverRawPacket(h.TransportProtocol(), pkt) } @@ -1310,19 +1312,29 @@ func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv4, bool) } if hasTransportHdr { - switch err := p.stack.ParsePacketBufferTransport(transProtoNum, pkt); err { - case stack.ParsedOK: - case stack.UnknownTransportProtocol, stack.TransportLayerParseError: - // The transport layer will handle unknown protocols and transport layer - // parsing errors. - default: - panic(fmt.Sprintf("unexpected error parsing transport header = %d", err)) - } + p.parseTransport(pkt, transProtoNum) } return h, true } +func (p *protocol) parseTransport(pkt *stack.PacketBuffer, transProtoNum tcpip.TransportProtocolNumber) { + if transProtoNum == header.ICMPv4ProtocolNumber { + // The transport layer will handle transport layer parsing errors. + _ = parse.ICMPv4(pkt) + return + } + + switch err := p.stack.ParsePacketBufferTransport(transProtoNum, pkt); err { + case stack.ParsedOK: + case stack.UnknownTransportProtocol, stack.TransportLayerParseError: + // The transport layer will handle unknown protocols and transport layer + // parsing errors. + default: + panic(fmt.Sprintf("unexpected error parsing transport header = %d", err)) + } +} + // Parse implements stack.NetworkProtocol. func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { if ok := parse.IPv4(pkt); !ok { diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index ff23d48e7..adfc8d8da 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -274,7 +274,7 @@ func isMLDValid(pkt *stack.PacketBuffer, iph header.IPv6, routerAlert *header.IP if routerAlert == nil || routerAlert.Value != header.IPv6RouterAlertMLD { return false } - if pkt.Data().Size() < header.ICMPv6HeaderSize+header.MLDMinimumSize { + if pkt.TransportHeader().View().Size() < header.ICMPv6HeaderSize+header.MLDMinimumSize { return false } if iph.HopLimit() != header.MLDHopLimit { @@ -289,20 +289,17 @@ func isMLDValid(pkt *stack.PacketBuffer, iph header.IPv6, routerAlert *header.IP func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, routerAlert *header.IPv6RouterAlertOption) { sent := e.stats.icmp.packetsSent received := e.stats.icmp.packetsReceived - // ICMP packets don't have their TransportHeader fields set. See - // icmp/protocol.go:protocol.Parse for a full explanation. - v, ok := pkt.Data().PullUp(header.ICMPv6HeaderSize) - if !ok { + h := header.ICMPv6(pkt.TransportHeader().View()) + if len(h) < header.ICMPv6MinimumSize { received.invalid.Increment() return } - h := header.ICMPv6(v) iph := header.IPv6(pkt.NetworkHeader().View()) srcAddr := iph.SourceAddress() dstAddr := iph.DestinationAddress() // Validate ICMPv6 checksum before processing the packet. - payload := pkt.Data().AsRange().SubRange(len(h)) + payload := pkt.Data().AsRange() if got, want := h.Checksum(), header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ Header: h, Src: srcAddr, @@ -329,12 +326,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r switch icmpType := h.Type(); icmpType { case header.ICMPv6PacketTooBig: received.packetTooBig.Increment() - hdr, ok := pkt.Data().Consume(header.ICMPv6PacketTooBigMinimumSize) - if !ok { - received.invalid.Increment() - return - } - networkMTU, err := calculateNetworkMTU(header.ICMPv6(hdr).MTU(), header.IPv6MinimumSize) + networkMTU, err := calculateNetworkMTU(h.MTU(), header.IPv6MinimumSize) if err != nil { networkMTU = 0 } @@ -342,13 +334,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r case header.ICMPv6DstUnreachable: received.dstUnreachable.Increment() - hdr, ok := pkt.Data().Consume(header.ICMPv6DstUnreachableMinimumSize) - if !ok { - received.invalid.Increment() - return - } - code := header.ICMPv6(hdr).Code() - switch code { + switch h.Code() { case header.ICMPv6NetworkUnreachable: e.handleControl(&icmpv6DestinationNetworkUnreachableSockError{}, pkt) case header.ICMPv6PortUnreachable: @@ -356,16 +342,12 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r } case header.ICMPv6NeighborSolicit: received.neighborSolicit.Increment() - if !isNDPValid() || pkt.Data().Size() < header.ICMPv6NeighborSolicitMinimumSize { + if !isNDPValid() || len(h) < header.ICMPv6NeighborSolicitMinimumSize { received.invalid.Increment() return } - // The remainder of payload must be only the neighbor solicitation, so - // payload.AsView() always returns the solicitation. Per RFC 6980 section 5, - // NDP messages cannot be fragmented. Also note that in the common case NDP - // datagrams are very small and AsView() will not incur allocations. - ns := header.NDPNeighborSolicit(payload.AsView()) + ns := header.NDPNeighborSolicit(h.MessageBody()) targetAddr := ns.TargetAddress() // As per RFC 4861 section 4.3, the Target Address MUST NOT be a multicast @@ -578,16 +560,12 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r case header.ICMPv6NeighborAdvert: received.neighborAdvert.Increment() - if !isNDPValid() || pkt.Data().Size() < header.ICMPv6NeighborAdvertMinimumSize { + if !isNDPValid() || len(h) < header.ICMPv6NeighborAdvertMinimumSize { received.invalid.Increment() return } - // The remainder of payload must be only the neighbor advertisement, so - // payload.AsView() always returns the advertisement. Per RFC 6980 section - // 5, NDP messages cannot be fragmented. Also note that in the common case - // NDP datagrams are very small and AsView() will not incur allocations. - na := header.NDPNeighborAdvert(payload.AsView()) + na := header.NDPNeighborAdvert(h.MessageBody()) it, err := na.Options().Iter(false /* check */) if err != nil { @@ -674,12 +652,6 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r case header.ICMPv6EchoRequest: received.echoRequest.Increment() - icmpHdr, ok := pkt.TransportHeader().Consume(header.ICMPv6EchoMinimumSize) - if !ok { - received.invalid.Increment() - return - } - // As per RFC 4291 section 2.7, multicast addresses must not be used as // source addresses in IPv6 packets. localAddr := dstAddr @@ -705,7 +677,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r }) icmp := header.ICMPv6(replyPkt.TransportHeader().Push(header.ICMPv6EchoMinimumSize)) pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber - copy(icmp, icmpHdr) + copy(icmp, h) icmp.SetType(header.ICMPv6EchoReply) dataRange := replyPkt.Data().AsRange() icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ @@ -727,7 +699,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r case header.ICMPv6EchoReply: received.echoReply.Increment() - if pkt.Data().Size() < header.ICMPv6EchoMinimumSize { + if len(h) < header.ICMPv6EchoMinimumSize { received.invalid.Increment() return } @@ -747,7 +719,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r // // Is the NDP payload of sufficient size to hold a Router Solictation? - if !isNDPValid() || pkt.Data().Size()-header.ICMPv6HeaderSize < header.NDPRSMinimumSize { + if !isNDPValid() || len(h)-header.ICMPv6HeaderSize < header.NDPRSMinimumSize { received.invalid.Increment() return } @@ -757,9 +729,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r return } - // Note that in the common case NDP datagrams are very small and AsView() - // will not incur allocations. - rs := header.NDPRouterSolicit(payload.AsView()) + rs := header.NDPRouterSolicit(h.MessageBody()) it, err := rs.Options().Iter(false /* check */) if err != nil { // Options are not valid as per the wire format, silently drop the packet. @@ -803,7 +773,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r // // Is the NDP payload of sufficient size to hold a Router Advertisement? - if !isNDPValid() || pkt.Data().Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize { + if !isNDPValid() || len(h)-header.ICMPv6HeaderSize < header.NDPRAMinimumSize { received.invalid.Increment() return } @@ -817,9 +787,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r return } - // Note that in the common case NDP datagrams are very small and AsView() - // will not incur allocations. - ra := header.NDPRouterAdvert(payload.AsView()) + ra := header.NDPRouterAdvert(h.MessageBody()) it, err := ra.Options().Iter(false /* check */) if err != nil { // Options are not valid as per the wire format, silently drop the packet. @@ -897,11 +865,11 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r switch icmpType { case header.ICMPv6MulticastListenerQuery: e.mu.Lock() - e.mu.mld.handleMulticastListenerQuery(header.MLD(payload.AsView())) + e.mu.mld.handleMulticastListenerQuery(header.MLD(h.MessageBody())) e.mu.Unlock() case header.ICMPv6MulticastListenerReport: e.mu.Lock() - e.mu.mld.handleMulticastListenerReport(header.MLD(payload.AsView())) + e.mu.mld.handleMulticastListenerReport(header.MLD(h.MessageBody())) e.mu.Unlock() case header.ICMPv6MulticastListenerDone: default: @@ -1182,18 +1150,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip } if pkt.TransportProtocolNumber == header.ICMPv6ProtocolNumber { - // TODO(gvisor.dev/issues/3810): Sort this out when ICMP headers are stored. - // Unfortunately at this time ICMP Packets do not have a transport - // header separated out. It is in the Data part so we need to - // separate it out now. We will just pretend it is a minimal length - // ICMP packet as we don't really care if any later bits of a - // larger ICMP packet are in the header view or in the Data view. - transport, ok := pkt.TransportHeader().Consume(header.ICMPv6MinimumSize) - if !ok { - return nil - } - typ := header.ICMPv6(transport).Type() - if typ.IsErrorType() || typ == header.ICMPv6RedirectMsg { + if typ := header.ICMPv6(pkt.TransportHeader().View()).Type(); typ.IsErrorType() || typ == header.ICMPv6RedirectMsg { return nil } } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 0406a2e6e..7d3e1fd53 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -1554,13 +1554,19 @@ func (e *endpoint) processExtensionHeaders(h header.IPv6, pkt *stack.PacketBuffe return fmt.Errorf("could not consume %d bytes", trim) } + proto := tcpip.TransportProtocolNumber(extHdr.Identifier) + // If the packet was reassembled from a fragment, it will not have a + // transport header set yet. + if pkt.TransportHeader().View().IsEmpty() { + e.protocol.parseTransport(pkt, proto) + } + stats.PacketsDelivered.Increment() - if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber { - pkt.TransportProtocolNumber = p + if proto == header.ICMPv6ProtocolNumber { e.handleICMP(pkt, hasFragmentHeader, routerAlert) } else { stats.PacketsDelivered.Increment() - switch res := e.dispatcher.DeliverTransportPacket(p, pkt); res { + switch res := e.dispatcher.DeliverTransportPacket(proto, pkt); res { case stack.TransportPacketHandled: case stack.TransportPacketDestinationPortUnreachable: // As per RFC 4443 section 3.1: @@ -2161,19 +2167,29 @@ func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv6, bool) } if hasTransportHdr { - switch err := p.stack.ParsePacketBufferTransport(transProtoNum, pkt); err { - case stack.ParsedOK: - case stack.UnknownTransportProtocol, stack.TransportLayerParseError: - // The transport layer will handle unknown protocols and transport layer - // parsing errors. - default: - panic(fmt.Sprintf("unexpected error parsing transport header = %d", err)) - } + p.parseTransport(pkt, transProtoNum) } return h, true } +func (p *protocol) parseTransport(pkt *stack.PacketBuffer, transProtoNum tcpip.TransportProtocolNumber) { + if transProtoNum == header.ICMPv6ProtocolNumber { + // The transport layer will handle transport layer parsing errors. + _ = parse.ICMPv6(pkt) + return + } + + switch err := p.stack.ParsePacketBufferTransport(transProtoNum, pkt); err { + case stack.ParsedOK: + case stack.UnknownTransportProtocol, stack.TransportLayerParseError: + // The transport layer will handle unknown protocols and transport layer + // parsing errors. + default: + panic(fmt.Sprintf("unexpected error parsing transport header = %d", err)) + } +} + // Parse implements stack.NetworkProtocol. func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { proto, _, fragOffset, fragMore, ok := parse.IPv6(pkt) diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 29d580e76..e251e3b24 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -833,24 +833,9 @@ func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt transProto := state.proto - // TransportHeader is empty only when pkt is an ICMP packet or was reassembled - // from fragments. if pkt.TransportHeader().View().IsEmpty() { - // ICMP packets don't have their TransportHeader fields set yet, parse it - // here. See icmp/protocol.go:protocol.Parse for a full explanation. - if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { - // ICMP packets may be longer, but until icmp.Parse is implemented, here - // we parse it using the minimum size. - if _, ok := pkt.TransportHeader().Consume(transProto.MinimumPacketSize()); !ok { - n.stats.malformedL4RcvdPackets.Increment() - // We consider a malformed transport packet handled because there is - // nothing the caller can do. - return TransportPacketHandled - } - } else if !transProto.Parse(pkt) { - n.stats.malformedL4RcvdPackets.Increment() - return TransportPacketHandled - } + n.stats.malformedL4RcvdPackets.Increment() + return TransportPacketHandled } srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader().View()) diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index ee6767654..a05fd7036 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1865,12 +1865,6 @@ const ( // ParsePacketBufferTransport parses the provided packet buffer's transport // header. func (s *Stack) ParsePacketBufferTransport(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) ParseResult { - // ICMP packets don't have their TransportHeader fields set yet, parse it - // here. See icmp/protocol.go:protocol.Parse for a full explanation. - if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { - return ParsedOK - } - pkt.TransportProtocolNumber = protocol // Parse the transport header if present. state, ok := s.transportProtocols[protocol] diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index c23e91702..f5a35eac4 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -155,8 +155,18 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { return } + transProtoNum := tcpip.TransportProtocolNumber(netHdr[protocolNumberOffset]) + switch err := f.proto.stack.ParsePacketBufferTransport(transProtoNum, pkt); err { + case stack.ParsedOK: + case stack.UnknownTransportProtocol, stack.TransportLayerParseError: + // The transport layer will handle unknown protocols and transport layer + // parsing errors. + default: + panic(fmt.Sprintf("unexpected error parsing transport header = %d", err)) + } + // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) + f.dispatcher.DeliverTransportPacket(transProtoNum, pkt) } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { @@ -218,6 +228,8 @@ func (*fakeNetworkEndpointStats) IsNetworkEndpointStats() {} // number of packets sent and received via endpoints of this protocol. The index // where packets are added is given by the packet's destination address MOD 10. type fakeNetworkProtocol struct { + stack *stack.Stack + packetCount [10]int sendPacketCount [10]int defaultTTL uint8 @@ -299,8 +311,8 @@ func (f *fakeNetworkEndpoint) SetForwarding(v bool) { f.mu.forwarding = v } -func fakeNetFactory(*stack.Stack) stack.NetworkProtocol { - return &fakeNetworkProtocol{} +func fakeNetFactory(s *stack.Stack) stack.NetworkProtocol { + return &fakeNetworkProtocol{stack: s} } // linkEPWithMockedAttach is a stack.LinkEndpoint that tests can use to verify diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 655931715..51870d03f 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -331,8 +331,11 @@ func (*fakeTransportProtocol) Wait() {} // Parse implements TransportProtocol.Parse. func (*fakeTransportProtocol) Parse(pkt *stack.PacketBuffer) bool { - _, ok := pkt.TransportHeader().Consume(fakeTransHeaderLen) - return ok + if _, ok := pkt.TransportHeader().Consume(fakeTransHeaderLen); ok { + pkt.TransportProtocolNumber = fakeTransNumber + return true + } + return false } func fakeTransFactory(s *stack.Stack) stack.TransportProtocol { |