From 24416032ab848cff7696b3f37e4c18220aeee2c0 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Fri, 5 Feb 2021 16:44:49 -0800 Subject: Refactor locally delivered packets Make it clear that failing to parse a looped back is not a packet sending error but a malformed received packet error. FindNetworkEndpoint returns nil when no network endpoint is found instead of an error. PiperOrigin-RevId: 355954946 --- pkg/tcpip/network/BUILD | 1 - pkg/tcpip/network/arp/arp.go | 5 ++ pkg/tcpip/network/ip_test.go | 14 ----- pkg/tcpip/network/ipv4/ipv4.go | 103 +++++++++++++++++++++++++----------- pkg/tcpip/network/ipv6/icmp_test.go | 71 ++++++++----------------- pkg/tcpip/network/ipv6/ipv6.go | 103 +++++++++++++++++++++++++----------- pkg/tcpip/network/ipv6/ndp_test.go | 14 ++--- pkg/tcpip/stack/forwarding_test.go | 4 ++ pkg/tcpip/stack/nic.go | 30 ----------- pkg/tcpip/stack/stack.go | 44 ++++++--------- pkg/tcpip/stack/stack_test.go | 4 ++ 11 files changed, 201 insertions(+), 192 deletions(-) diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index 0caa65251..fa8814bac 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -16,7 +16,6 @@ go_test( "//pkg/tcpip/checker", "//pkg/tcpip/faketime", "//pkg/tcpip/header", - "//pkg/tcpip/header/parse", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", "//pkg/tcpip/network/ipv4", diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 0d7fadc31..bd9b9c020 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -129,6 +129,11 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } + if _, _, ok := e.protocol.Parse(pkt); !ok { + stats.malformedPacketsReceived.Increment() + return + } + h := header.ARP(pkt.NetworkHeader().View()) if !h.IsValid() { stats.malformedPacketsReceived.Increment() diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 6a1f11a36..a176ef2b9 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -24,7 +24,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -626,9 +625,6 @@ func TestReceive(t *testing.T) { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: view.ToVectorisedView(), }) - if ok := parse.IPv4(pkt); !ok { - t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) - } ep.HandlePacket(pkt) }, }, @@ -664,9 +660,6 @@ func TestReceive(t *testing.T) { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: view.ToVectorisedView(), }) - if _, _, _, _, ok := parse.IPv6(pkt); !ok { - t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) - } ep.HandlePacket(pkt) }, }, @@ -943,9 +936,6 @@ func TestIPv4FragmentationReceive(t *testing.T) { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: frag1.ToVectorisedView(), }) - if _, _, ok := proto.Parse(pkt); !ok { - t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) - } addressableEndpoint, ok := ep.(stack.AddressableEndpoint) if !ok { @@ -967,9 +957,6 @@ func TestIPv4FragmentationReceive(t *testing.T) { pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: frag2.ToVectorisedView(), }) - if _, _, ok := proto.Parse(pkt); !ok { - t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) - } ep.HandlePacket(pkt) if nic.testObject.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) @@ -1234,7 +1221,6 @@ func truncatedPacket(view buffer.View, trunc, netHdrLen int) *stack.PacketBuffer pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: v.ToVectorisedView(), }) - _, _ = pkt.NetworkHeader().Consume(netHdrLen) return pkt } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index b2d626107..e1e05e39c 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -347,15 +347,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // short circuits broadcasts before they are sent out to other hosts. if pkt.NatDone { netHeader := header.IPv4(pkt.NetworkHeader().View()) - ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()) - if err == nil { - pkt := pkt.CloneToInbound() - if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - // Since we rewrote the packet but it is being routed back to us, we can - // safely assume the checksum is valid. - pkt.RXTransportChecksumValidated = true - ep.(*endpoint).handlePacket(pkt) - } + if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); ep != nil { + // Since we rewrote the packet but it is being routed back to us, we + // can safely assume the checksum is valid. + ep.(*endpoint).handleLocalPacket(pkt, true /* canSkipRXChecksum */) return nil } } @@ -365,14 +360,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, headerIncluded bool) tcpip.Error { if r.Loop&stack.PacketLoop != 0 { - pkt := pkt.CloneToInbound() - if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - // If the packet was generated by the stack (not a raw/packet endpoint - // where a packet may be written with the header included), then we can - // safely assume the checksum is valid. - pkt.RXTransportChecksumValidated = !headerIncluded - e.handlePacket(pkt) - } + // If the packet was generated by the stack (not a raw/packet endpoint + // where a packet may be written with the header included), then we can + // safely assume the checksum is valid. + e.handleLocalPacket(pkt, !headerIncluded /* canSkipRXChecksum */) } if r.Loop&stack.PacketOut == 0 { return nil @@ -471,14 +462,10 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } if _, ok := natPkts[pkt]; ok { netHeader := header.IPv4(pkt.NetworkHeader().View()) - if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { - pkt := pkt.CloneToInbound() - if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - // Since we rewrote the packet but it is being routed back to us, we - // can safely assume the checksum is valid. - pkt.RXTransportChecksumValidated = true - ep.(*endpoint).handlePacket(pkt) - } + if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); ep != nil { + // Since we rewrote the packet but it is being routed back to us, we + // can safely assume the checksum is valid. + ep.(*endpoint).handleLocalPacket(pkt, true /* canSkipRXChecksum */) n++ continue } @@ -573,14 +560,10 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { dstAddr := h.DestinationAddress() // Check if the destination is owned by the stack. - networkEndpoint, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr) - if err == nil { - networkEndpoint.(*endpoint).handlePacket(pkt) + if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr); ep != nil { + ep.(*endpoint).handlePacket(pkt) return nil } - if _, ok := err.(*tcpip.ErrBadAddress); !ok { - return err - } r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) if err != nil { @@ -619,8 +602,26 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } - // Loopback traffic skips the prerouting chain. + if !e.protocol.parse(pkt) { + stats.MalformedPacketsReceived.Increment() + return + } + if !e.nic.IsLoopback() { + if e.protocol.stack.HandleLocal() { + addressEndpoint := e.AcquireAssignedAddress(header.IPv4(pkt.NetworkHeader().View()).SourceAddress(), e.nic.Promiscuous(), stack.CanBePrimaryEndpoint) + if addressEndpoint != nil { + addressEndpoint.DecRef() + + // The source address is one of our own, so we never should have gotten + // a packet like this unless HandleLocal is false or our NIC is the + // loopback interface. + stats.InvalidSourceAddressesReceived.Increment() + return + } + } + + // Loopback traffic skips the prerouting chain. inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. @@ -632,6 +633,21 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { e.handlePacket(pkt) } +func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum bool) { + stats := e.stats.ip + + stats.PacketsReceived.Increment() + + pkt = pkt.CloneToInbound() + if e.protocol.parse(pkt) { + pkt.RXTransportChecksumValidated = canSkipRXChecksum + e.handlePacket(pkt) + return + } + + stats.MalformedPacketsReceived.Increment() +} + // handlePacket is like HandlePacket except it does not perform the prerouting // iptables hook. func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { @@ -1043,6 +1059,29 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} +// parse is like Parse but also attempts to parse the transport layer. +// +// Returns true if the network header was successfully parsed. +func (p *protocol) parse(pkt *stack.PacketBuffer) bool { + transProtoNum, hasTransportHdr, ok := p.Parse(pkt) + if !ok { + return false + } + + if hasTransportHdr { + switch err := p.stack.ParsePacketBufferTransport(transProtoNum, pkt); err { + case stack.ParsedOK: + case stack.UnknownTransportProtocol, stack.TransportLayerParseError: + // The transport layer will handle unknown protocols and transport layer + // parsing errors. + default: + panic(fmt.Sprintf("unexpected error parsing transport header = %d", err)) + } + } + + return true +} + // Parse implements stack.NetworkProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { if ok := parse.IPv4(pkt); !ok { diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 92f9ee2c2..ca46ec61f 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -149,6 +149,23 @@ func (t *testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, return nil } +func handleICMPInIPv6(ep stack.NetworkEndpoint, src, dst tcpip.Address, icmp header.ICMPv6) { + ip := buffer.NewView(header.IPv6MinimumSize) + header.IPv6(ip).Encode(&header.IPv6Fields{ + PayloadLength: uint16(len(icmp)), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: src, + DstAddr: dst, + }) + vv := ip.ToVectorisedView() + vv.AppendView(buffer.View(icmp)) + ep.HandlePacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.IPv6MinimumSize, + Data: vv, + })) +} + func TestICMPCounts(t *testing.T) { tests := []struct { name string @@ -282,33 +299,17 @@ func TestICMPCounts(t *testing.T) { }, } - handleIPv6Payload := func(icmp header.ICMPv6) { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.IPv6MinimumSize, - Data: buffer.View(icmp).ToVectorisedView(), - }) - ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - ep.HandlePacket(pkt) - } - for _, typ := range types { icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) copy(icmp[typ.size:], typ.extraData) icmp.SetType(typ.typ) icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView())) - handleIPv6Payload(icmp) + handleICMPInIPv6(ep, lladdr1, lladdr0, icmp) } // Construct an empty ICMP packet so that // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented. - handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize))) + handleICMPInIPv6(ep, lladdr1, lladdr0, header.ICMPv6(buffer.NewView(header.IPv6MinimumSize))) icmpv6Stats := s.Stats().ICMP.V6.PacketsReceived visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) { @@ -440,33 +441,17 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { }, } - handleIPv6Payload := func(icmp header.ICMPv6) { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.IPv6MinimumSize, - Data: buffer.View(icmp).ToVectorisedView(), - }) - ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - ep.HandlePacket(pkt) - } - for _, typ := range types { icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) copy(icmp[typ.size:], typ.extraData) icmp.SetType(typ.typ) icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView())) - handleIPv6Payload(icmp) + handleICMPInIPv6(ep, lladdr1, lladdr0, icmp) } // Construct an empty ICMP packet so that // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented. - handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize))) + handleICMPInIPv6(ep, lladdr1, lladdr0, header.ICMPv6(buffer.NewView(header.IPv6MinimumSize))) icmpv6Stats := s.Stats().ICMP.V6.PacketsReceived visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) { @@ -1818,19 +1803,7 @@ func TestCallsToNeighborCache(t *testing.T) { icmp := test.createPacket() icmp.SetChecksum(header.ICMPv6Checksum(icmp, test.source, test.destination, buffer.VectorisedView{})) - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.IPv6MinimumSize, - Data: buffer.View(icmp).ToVectorisedView(), - }) - ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: header.NDPHopLimit, - SrcAddr: test.source, - DstAddr: test.destination, - }) - ep.HandlePacket(pkt) + handleICMPInIPv6(ep, test.source, test.destination, icmp) // Confirm the endpoint calls the correct NUDHandler method. if testInterface.probeCount != test.wantProbeCount { diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index c2e8c3ea7..5cad546b8 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -648,14 +648,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // short circuits broadcasts before they are sent out to other hosts. if pkt.NatDone { netHeader := header.IPv6(pkt.NetworkHeader().View()) - if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { - pkt := pkt.CloneToInbound() - if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - // Since we rewrote the packet but it is being routed back to us, we can - // safely assume the checksum is valid. - pkt.RXTransportChecksumValidated = true - ep.(*endpoint).handlePacket(pkt) - } + if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); ep != nil { + // Since we rewrote the packet but it is being routed back to us, we + // can safely assume the checksum is valid. + ep.(*endpoint).handleLocalPacket(pkt, true /* canSkipRXChecksum */) return nil } } @@ -665,14 +661,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber, headerIncluded bool) tcpip.Error { if r.Loop&stack.PacketLoop != 0 { - pkt := pkt.CloneToInbound() - if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - // If the packet was generated by the stack (not a raw/packet endpoint - // where a packet may be written with the header included), then we can - // safely assume the checksum is valid. - pkt.RXTransportChecksumValidated = !headerIncluded - e.handlePacket(pkt) - } + // If the packet was generated by the stack (not a raw/packet endpoint + // where a packet may be written with the header included), then we can + // safely assume the checksum is valid. + e.handleLocalPacket(pkt, !headerIncluded /* canSkipRXChecksum */) } if r.Loop&stack.PacketOut == 0 { return nil @@ -771,14 +763,10 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } if _, ok := natPkts[pkt]; ok { netHeader := header.IPv6(pkt.NetworkHeader().View()) - if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { - pkt := pkt.CloneToInbound() - if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - // Since we rewrote the packet but it is being routed back to us, we - // can safely assume the checksum is valid. - pkt.RXTransportChecksumValidated = true - ep.(*endpoint).handlePacket(pkt) - } + if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); ep != nil { + // Since we rewrote the packet but it is being routed back to us, we + // can safely assume the checksum is valid. + ep.(*endpoint).handleLocalPacket(pkt, true /* canSkipRXChecksum */) n++ continue } @@ -852,14 +840,11 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { dstAddr := h.DestinationAddress() // Check if the destination is owned by the stack. - networkEndpoint, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr) - if err == nil { - networkEndpoint.(*endpoint).handlePacket(pkt) + + if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr); ep != nil { + ep.(*endpoint).handlePacket(pkt) return nil } - if _, ok := err.(*tcpip.ErrBadAddress); !ok { - return err - } r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) if err != nil { @@ -896,8 +881,26 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } - // Loopback traffic skips the prerouting chain. + if !e.protocol.parse(pkt) { + stats.MalformedPacketsReceived.Increment() + return + } + if !e.nic.IsLoopback() { + if e.protocol.stack.HandleLocal() { + addressEndpoint := e.AcquireAssignedAddress(header.IPv6(pkt.NetworkHeader().View()).SourceAddress(), e.nic.Promiscuous(), stack.CanBePrimaryEndpoint) + if addressEndpoint != nil { + addressEndpoint.DecRef() + + // The source address is one of our own, so we never should have gotten + // a packet like this unless HandleLocal is false or our NIC is the + // loopback interface. + stats.InvalidSourceAddressesReceived.Increment() + return + } + } + + // Loopback traffic skips the prerouting chain. inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. @@ -909,6 +912,21 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { e.handlePacket(pkt) } +func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum bool) { + stats := e.stats.ip + + stats.PacketsReceived.Increment() + + pkt = pkt.CloneToInbound() + if e.protocol.parse(pkt) { + pkt.RXTransportChecksumValidated = canSkipRXChecksum + e.handlePacket(pkt) + return + } + + stats.MalformedPacketsReceived.Increment() +} + // handlePacket is like HandlePacket except it does not perform the prerouting // iptables hook. func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { @@ -1798,6 +1816,29 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} +// parse is like Parse but also attempts to parse the transport layer. +// +// Returns true if the network header was successfully parsed. +func (p *protocol) parse(pkt *stack.PacketBuffer) bool { + transProtoNum, hasTransportHdr, ok := p.Parse(pkt) + if !ok { + return false + } + + if hasTransportHdr { + switch err := p.stack.ParsePacketBufferTransport(transProtoNum, pkt); err { + case stack.ParsedOK: + case stack.UnknownTransportProtocol, stack.TransportLayerParseError: + // The transport layer will handle unknown protocols and transport layer + // parsing errors. + default: + panic(fmt.Sprintf("unexpected error parsing transport header = %d", err)) + } + } + + return true +} + // Parse implements stack.NetworkProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { proto, _, fragOffset, fragMore, ok := parse.IPv6(pkt) diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 8edaa9508..104fe2139 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -977,12 +977,8 @@ func TestNDPValidation(t *testing.T) { } extHdrsLen := extHdrs.Length() - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.IPv6MinimumSize + extHdrsLen, - Data: payload.ToVectorisedView(), - }) - ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen)) - ip.Encode(&header.IPv6Fields{ + ip := buffer.NewView(header.IPv6MinimumSize + extHdrsLen) + header.IPv6(ip).Encode(&header.IPv6Fields{ PayloadLength: uint16(len(payload) + extHdrsLen), TransportProtocol: header.ICMPv6ProtocolNumber, HopLimit: hopLimit, @@ -990,7 +986,11 @@ func TestNDPValidation(t *testing.T) { DstAddr: lladdr0, ExtensionHeaders: extHdrs, }) - ep.HandlePacket(pkt) + vv := ip.ToVectorisedView() + vv.AppendView(payload) + ep.HandlePacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + })) } var tllData [header.NDPLinkLayerAddressSize]byte diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index c24f56ece..0cb9ec3a3 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -75,6 +75,10 @@ func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 { } func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) { + if _, _, ok := f.proto.Parse(pkt); !ok { + return + } + netHdr := pkt.NetworkHeader().View() _, dst := f.proto.ParseAddresses(netHdr) diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 41a489047..6f2a0e487 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -777,36 +777,6 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp anyEPs.forEach(deliverPacketEPs) } - // Parse headers. - netProto := n.stack.NetworkProtocolInstance(protocol) - transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt) - if !ok { - // The packet is too small to contain a network header. - n.stack.stats.MalformedRcvdPackets.Increment() - return - } - if hasTransportHdr { - pkt.TransportProtocolNumber = transProtoNum - // Parse the transport header if present. - if state, ok := n.stack.transportProtocols[transProtoNum]; ok { - state.proto.Parse(pkt) - } - } - - if n.stack.handleLocal && !n.IsLoopback() { - src, _ := netProto.ParseAddresses(pkt.NetworkHeader().View()) - if r := n.getAddress(protocol, src); r != nil { - r.DecRef() - - // The source address is one of our own, so we never should have gotten a - // packet like this unless handleLocal is false. Loopback also calls this - // function even though the packets didn't come from the physical interface - // so don't drop those. - n.stack.stats.IP.InvalidSourceAddressesReceived.Increment() - return - } - } - networkEndpoint.HandlePacket(pkt) } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index a51d758d0..035ab33ca 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1319,6 +1319,11 @@ func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, return nil } +// HandleLocal returns true if non-loopback interfaces are allowed to loop packets. +func (s *Stack) HandleLocal() bool { + return s.handleLocal +} + // FindRoute creates a route to the given destination address, leaving through // the given NIC and local address (if provided). // @@ -2063,7 +2068,9 @@ func generateRandInt64() int64 { } // FindNetworkEndpoint returns the network endpoint for the given address. -func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) (NetworkEndpoint, tcpip.Error) { +// +// Returns nil if the address is not associated with any network endpoint. +func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) NetworkEndpoint { s.mu.RLock() defer s.mu.RUnlock() @@ -2073,9 +2080,9 @@ func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, addres continue } addressEndpoint.DecRef() - return nic.getNetworkEndpoint(netProto), nil + return nic.getNetworkEndpoint(netProto) } - return nil, &tcpip.ErrBadAddress{} + return nil } // FindNICNameFromID returns the name of the NIC for the given NICID. @@ -2103,13 +2110,6 @@ const ( // ParsedOK indicates that a packet was successfully parsed. ParsedOK ParseResult = iota - // UnknownNetworkProtocol indicates that the network protocol is unknown. - UnknownNetworkProtocol - - // NetworkLayerParseError indicates that the network packet was not - // successfully parsed. - NetworkLayerParseError - // UnknownTransportProtocol indicates that the transport protocol is unknown. UnknownTransportProtocol @@ -2118,31 +2118,19 @@ const ( TransportLayerParseError ) -// ParsePacketBuffer parses the provided packet buffer. -func (s *Stack) ParsePacketBuffer(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) ParseResult { - netProto, ok := s.networkProtocols[protocol] - if !ok { - return UnknownNetworkProtocol - } - - transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt) - if !ok { - return NetworkLayerParseError - } - if !hasTransportHdr { - return ParsedOK - } - +// ParsePacketBufferTransport parses the provided packet buffer's transport +// header. +func (s *Stack) ParsePacketBufferTransport(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) ParseResult { // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a // full explanation. - if transProtoNum == header.ICMPv4ProtocolNumber || transProtoNum == header.ICMPv6ProtocolNumber { + if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { return ParsedOK } - pkt.TransportProtocolNumber = transProtoNum + pkt.TransportProtocolNumber = protocol // Parse the transport header if present. - state, ok := s.transportProtocols[transProtoNum] + state, ok := s.transportProtocols[protocol] if !ok { return UnknownTransportProtocol } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index b641a4aaa..b3386f705 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -119,6 +119,10 @@ func (*fakeNetworkEndpoint) DefaultTTL() uint8 { } func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { + if _, _, ok := f.proto.Parse(pkt); !ok { + return + } + // Increment the received packet count in the protocol descriptor. netHdr := pkt.NetworkHeader().View() -- cgit v1.2.3