diff options
author | Ghanan Gowripalan <ghanan@google.com> | 2021-02-05 16:44:49 -0800 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2021-02-05 16:47:11 -0800 |
commit | 24416032ab848cff7696b3f37e4c18220aeee2c0 (patch) | |
tree | dd50fbdc84304102a6f9dcb7f6bfa594299bc4d7 /pkg/tcpip/network/ipv6 | |
parent | 3514c289a9c9da232bf3054c971c3e0434d8cfa3 (diff) |
Refactor locally delivered packets
Make it clear that failing to parse a looped back is not a packet
sending error but a malformed received packet error.
FindNetworkEndpoint returns nil when no network endpoint is found
instead of an error.
PiperOrigin-RevId: 355954946
Diffstat (limited to 'pkg/tcpip/network/ipv6')
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp_test.go | 71 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 103 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ndp_test.go | 14 |
3 files changed, 101 insertions, 87 deletions
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 92f9ee2c2..ca46ec61f 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -149,6 +149,23 @@ func (t *testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, return nil } +func handleICMPInIPv6(ep stack.NetworkEndpoint, src, dst tcpip.Address, icmp header.ICMPv6) { + ip := buffer.NewView(header.IPv6MinimumSize) + header.IPv6(ip).Encode(&header.IPv6Fields{ + PayloadLength: uint16(len(icmp)), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: src, + DstAddr: dst, + }) + vv := ip.ToVectorisedView() + vv.AppendView(buffer.View(icmp)) + ep.HandlePacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.IPv6MinimumSize, + Data: vv, + })) +} + func TestICMPCounts(t *testing.T) { tests := []struct { name string @@ -282,33 +299,17 @@ func TestICMPCounts(t *testing.T) { }, } - handleIPv6Payload := func(icmp header.ICMPv6) { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.IPv6MinimumSize, - Data: buffer.View(icmp).ToVectorisedView(), - }) - ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - ep.HandlePacket(pkt) - } - for _, typ := range types { icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) copy(icmp[typ.size:], typ.extraData) icmp.SetType(typ.typ) icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView())) - handleIPv6Payload(icmp) + handleICMPInIPv6(ep, lladdr1, lladdr0, icmp) } // Construct an empty ICMP packet so that // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented. - handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize))) + handleICMPInIPv6(ep, lladdr1, lladdr0, header.ICMPv6(buffer.NewView(header.IPv6MinimumSize))) icmpv6Stats := s.Stats().ICMP.V6.PacketsReceived visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) { @@ -440,33 +441,17 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { }, } - handleIPv6Payload := func(icmp header.ICMPv6) { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.IPv6MinimumSize, - Data: buffer.View(icmp).ToVectorisedView(), - }) - ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - ep.HandlePacket(pkt) - } - for _, typ := range types { icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) copy(icmp[typ.size:], typ.extraData) icmp.SetType(typ.typ) icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView())) - handleIPv6Payload(icmp) + handleICMPInIPv6(ep, lladdr1, lladdr0, icmp) } // Construct an empty ICMP packet so that // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented. - handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize))) + handleICMPInIPv6(ep, lladdr1, lladdr0, header.ICMPv6(buffer.NewView(header.IPv6MinimumSize))) icmpv6Stats := s.Stats().ICMP.V6.PacketsReceived visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) { @@ -1818,19 +1803,7 @@ func TestCallsToNeighborCache(t *testing.T) { icmp := test.createPacket() icmp.SetChecksum(header.ICMPv6Checksum(icmp, test.source, test.destination, buffer.VectorisedView{})) - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.IPv6MinimumSize, - Data: buffer.View(icmp).ToVectorisedView(), - }) - ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: header.NDPHopLimit, - SrcAddr: test.source, - DstAddr: test.destination, - }) - ep.HandlePacket(pkt) + handleICMPInIPv6(ep, test.source, test.destination, icmp) // Confirm the endpoint calls the correct NUDHandler method. if testInterface.probeCount != test.wantProbeCount { diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index c2e8c3ea7..5cad546b8 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -648,14 +648,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // short circuits broadcasts before they are sent out to other hosts. if pkt.NatDone { netHeader := header.IPv6(pkt.NetworkHeader().View()) - if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { - pkt := pkt.CloneToInbound() - if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - // Since we rewrote the packet but it is being routed back to us, we can - // safely assume the checksum is valid. - pkt.RXTransportChecksumValidated = true - ep.(*endpoint).handlePacket(pkt) - } + if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); ep != nil { + // Since we rewrote the packet but it is being routed back to us, we + // can safely assume the checksum is valid. + ep.(*endpoint).handleLocalPacket(pkt, true /* canSkipRXChecksum */) return nil } } @@ -665,14 +661,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber, headerIncluded bool) tcpip.Error { if r.Loop&stack.PacketLoop != 0 { - pkt := pkt.CloneToInbound() - if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - // If the packet was generated by the stack (not a raw/packet endpoint - // where a packet may be written with the header included), then we can - // safely assume the checksum is valid. - pkt.RXTransportChecksumValidated = !headerIncluded - e.handlePacket(pkt) - } + // If the packet was generated by the stack (not a raw/packet endpoint + // where a packet may be written with the header included), then we can + // safely assume the checksum is valid. + e.handleLocalPacket(pkt, !headerIncluded /* canSkipRXChecksum */) } if r.Loop&stack.PacketOut == 0 { return nil @@ -771,14 +763,10 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } if _, ok := natPkts[pkt]; ok { netHeader := header.IPv6(pkt.NetworkHeader().View()) - if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { - pkt := pkt.CloneToInbound() - if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - // Since we rewrote the packet but it is being routed back to us, we - // can safely assume the checksum is valid. - pkt.RXTransportChecksumValidated = true - ep.(*endpoint).handlePacket(pkt) - } + if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); ep != nil { + // Since we rewrote the packet but it is being routed back to us, we + // can safely assume the checksum is valid. + ep.(*endpoint).handleLocalPacket(pkt, true /* canSkipRXChecksum */) n++ continue } @@ -852,14 +840,11 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { dstAddr := h.DestinationAddress() // Check if the destination is owned by the stack. - networkEndpoint, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr) - if err == nil { - networkEndpoint.(*endpoint).handlePacket(pkt) + + if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr); ep != nil { + ep.(*endpoint).handlePacket(pkt) return nil } - if _, ok := err.(*tcpip.ErrBadAddress); !ok { - return err - } r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) if err != nil { @@ -896,8 +881,26 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } - // Loopback traffic skips the prerouting chain. + if !e.protocol.parse(pkt) { + stats.MalformedPacketsReceived.Increment() + return + } + if !e.nic.IsLoopback() { + if e.protocol.stack.HandleLocal() { + addressEndpoint := e.AcquireAssignedAddress(header.IPv6(pkt.NetworkHeader().View()).SourceAddress(), e.nic.Promiscuous(), stack.CanBePrimaryEndpoint) + if addressEndpoint != nil { + addressEndpoint.DecRef() + + // The source address is one of our own, so we never should have gotten + // a packet like this unless HandleLocal is false or our NIC is the + // loopback interface. + stats.InvalidSourceAddressesReceived.Increment() + return + } + } + + // Loopback traffic skips the prerouting chain. inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. @@ -909,6 +912,21 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { e.handlePacket(pkt) } +func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum bool) { + stats := e.stats.ip + + stats.PacketsReceived.Increment() + + pkt = pkt.CloneToInbound() + if e.protocol.parse(pkt) { + pkt.RXTransportChecksumValidated = canSkipRXChecksum + e.handlePacket(pkt) + return + } + + stats.MalformedPacketsReceived.Increment() +} + // handlePacket is like HandlePacket except it does not perform the prerouting // iptables hook. func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { @@ -1798,6 +1816,29 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} +// parse is like Parse but also attempts to parse the transport layer. +// +// Returns true if the network header was successfully parsed. +func (p *protocol) parse(pkt *stack.PacketBuffer) bool { + transProtoNum, hasTransportHdr, ok := p.Parse(pkt) + if !ok { + return false + } + + if hasTransportHdr { + switch err := p.stack.ParsePacketBufferTransport(transProtoNum, pkt); err { + case stack.ParsedOK: + case stack.UnknownTransportProtocol, stack.TransportLayerParseError: + // The transport layer will handle unknown protocols and transport layer + // parsing errors. + default: + panic(fmt.Sprintf("unexpected error parsing transport header = %d", err)) + } + } + + return true +} + // Parse implements stack.NetworkProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { proto, _, fragOffset, fragMore, ok := parse.IPv6(pkt) diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 8edaa9508..104fe2139 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -977,12 +977,8 @@ func TestNDPValidation(t *testing.T) { } extHdrsLen := extHdrs.Length() - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.IPv6MinimumSize + extHdrsLen, - Data: payload.ToVectorisedView(), - }) - ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen)) - ip.Encode(&header.IPv6Fields{ + ip := buffer.NewView(header.IPv6MinimumSize + extHdrsLen) + header.IPv6(ip).Encode(&header.IPv6Fields{ PayloadLength: uint16(len(payload) + extHdrsLen), TransportProtocol: header.ICMPv6ProtocolNumber, HopLimit: hopLimit, @@ -990,7 +986,11 @@ func TestNDPValidation(t *testing.T) { DstAddr: lladdr0, ExtensionHeaders: extHdrs, }) - ep.HandlePacket(pkt) + vv := ip.ToVectorisedView() + vv.AppendView(payload) + ep.HandlePacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + })) } var tllData [header.NDPLinkLayerAddressSize]byte |