From 24416032ab848cff7696b3f37e4c18220aeee2c0 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Fri, 5 Feb 2021 16:44:49 -0800 Subject: Refactor locally delivered packets Make it clear that failing to parse a looped back is not a packet sending error but a malformed received packet error. FindNetworkEndpoint returns nil when no network endpoint is found instead of an error. PiperOrigin-RevId: 355954946 --- pkg/tcpip/network/ipv4/ipv4.go | 103 ++++++++++++++++++++++++++++------------- 1 file changed, 71 insertions(+), 32 deletions(-) (limited to 'pkg/tcpip/network/ipv4') diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index b2d626107..e1e05e39c 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -347,15 +347,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // short circuits broadcasts before they are sent out to other hosts. if pkt.NatDone { netHeader := header.IPv4(pkt.NetworkHeader().View()) - ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()) - if err == nil { - pkt := pkt.CloneToInbound() - if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - // Since we rewrote the packet but it is being routed back to us, we can - // safely assume the checksum is valid. - pkt.RXTransportChecksumValidated = true - ep.(*endpoint).handlePacket(pkt) - } + if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); ep != nil { + // Since we rewrote the packet but it is being routed back to us, we + // can safely assume the checksum is valid. + ep.(*endpoint).handleLocalPacket(pkt, true /* canSkipRXChecksum */) return nil } } @@ -365,14 +360,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, headerIncluded bool) tcpip.Error { if r.Loop&stack.PacketLoop != 0 { - pkt := pkt.CloneToInbound() - if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - // If the packet was generated by the stack (not a raw/packet endpoint - // where a packet may be written with the header included), then we can - // safely assume the checksum is valid. - pkt.RXTransportChecksumValidated = !headerIncluded - e.handlePacket(pkt) - } + // If the packet was generated by the stack (not a raw/packet endpoint + // where a packet may be written with the header included), then we can + // safely assume the checksum is valid. + e.handleLocalPacket(pkt, !headerIncluded /* canSkipRXChecksum */) } if r.Loop&stack.PacketOut == 0 { return nil @@ -471,14 +462,10 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } if _, ok := natPkts[pkt]; ok { netHeader := header.IPv4(pkt.NetworkHeader().View()) - if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { - pkt := pkt.CloneToInbound() - if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - // Since we rewrote the packet but it is being routed back to us, we - // can safely assume the checksum is valid. - pkt.RXTransportChecksumValidated = true - ep.(*endpoint).handlePacket(pkt) - } + if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); ep != nil { + // Since we rewrote the packet but it is being routed back to us, we + // can safely assume the checksum is valid. + ep.(*endpoint).handleLocalPacket(pkt, true /* canSkipRXChecksum */) n++ continue } @@ -573,14 +560,10 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { dstAddr := h.DestinationAddress() // Check if the destination is owned by the stack. - networkEndpoint, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr) - if err == nil { - networkEndpoint.(*endpoint).handlePacket(pkt) + if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr); ep != nil { + ep.(*endpoint).handlePacket(pkt) return nil } - if _, ok := err.(*tcpip.ErrBadAddress); !ok { - return err - } r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) if err != nil { @@ -619,8 +602,26 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } - // Loopback traffic skips the prerouting chain. + if !e.protocol.parse(pkt) { + stats.MalformedPacketsReceived.Increment() + return + } + if !e.nic.IsLoopback() { + if e.protocol.stack.HandleLocal() { + addressEndpoint := e.AcquireAssignedAddress(header.IPv4(pkt.NetworkHeader().View()).SourceAddress(), e.nic.Promiscuous(), stack.CanBePrimaryEndpoint) + if addressEndpoint != nil { + addressEndpoint.DecRef() + + // The source address is one of our own, so we never should have gotten + // a packet like this unless HandleLocal is false or our NIC is the + // loopback interface. + stats.InvalidSourceAddressesReceived.Increment() + return + } + } + + // Loopback traffic skips the prerouting chain. inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. @@ -632,6 +633,21 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { e.handlePacket(pkt) } +func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum bool) { + stats := e.stats.ip + + stats.PacketsReceived.Increment() + + pkt = pkt.CloneToInbound() + if e.protocol.parse(pkt) { + pkt.RXTransportChecksumValidated = canSkipRXChecksum + e.handlePacket(pkt) + return + } + + stats.MalformedPacketsReceived.Increment() +} + // handlePacket is like HandlePacket except it does not perform the prerouting // iptables hook. func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { @@ -1043,6 +1059,29 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} +// parse is like Parse but also attempts to parse the transport layer. +// +// Returns true if the network header was successfully parsed. +func (p *protocol) parse(pkt *stack.PacketBuffer) bool { + transProtoNum, hasTransportHdr, ok := p.Parse(pkt) + if !ok { + return false + } + + if hasTransportHdr { + switch err := p.stack.ParsePacketBufferTransport(transProtoNum, pkt); err { + case stack.ParsedOK: + case stack.UnknownTransportProtocol, stack.TransportLayerParseError: + // The transport layer will handle unknown protocols and transport layer + // parsing errors. + default: + panic(fmt.Sprintf("unexpected error parsing transport header = %d", err)) + } + } + + return true +} + // Parse implements stack.NetworkProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { if ok := parse.IPv4(pkt); !ok { -- cgit v1.2.3