diff options
-rw-r--r-- | pkg/tcpip/network/arp/arp.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 103 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 103 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 30 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 44 |
5 files changed, 164 insertions, 121 deletions
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 0d7fadc31..bd9b9c020 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -129,6 +129,11 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } + if _, _, ok := e.protocol.Parse(pkt); !ok { + stats.malformedPacketsReceived.Increment() + return + } + h := header.ARP(pkt.NetworkHeader().View()) if !h.IsValid() { stats.malformedPacketsReceived.Increment() diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index b2d626107..e1e05e39c 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -347,15 +347,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // short circuits broadcasts before they are sent out to other hosts. if pkt.NatDone { netHeader := header.IPv4(pkt.NetworkHeader().View()) - ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()) - if err == nil { - pkt := pkt.CloneToInbound() - if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - // Since we rewrote the packet but it is being routed back to us, we can - // safely assume the checksum is valid. - pkt.RXTransportChecksumValidated = true - ep.(*endpoint).handlePacket(pkt) - } + if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); ep != nil { + // Since we rewrote the packet but it is being routed back to us, we + // can safely assume the checksum is valid. + ep.(*endpoint).handleLocalPacket(pkt, true /* canSkipRXChecksum */) return nil } } @@ -365,14 +360,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, headerIncluded bool) tcpip.Error { if r.Loop&stack.PacketLoop != 0 { - pkt := pkt.CloneToInbound() - if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - // If the packet was generated by the stack (not a raw/packet endpoint - // where a packet may be written with the header included), then we can - // safely assume the checksum is valid. - pkt.RXTransportChecksumValidated = !headerIncluded - e.handlePacket(pkt) - } + // If the packet was generated by the stack (not a raw/packet endpoint + // where a packet may be written with the header included), then we can + // safely assume the checksum is valid. + e.handleLocalPacket(pkt, !headerIncluded /* canSkipRXChecksum */) } if r.Loop&stack.PacketOut == 0 { return nil @@ -471,14 +462,10 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } if _, ok := natPkts[pkt]; ok { netHeader := header.IPv4(pkt.NetworkHeader().View()) - if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { - pkt := pkt.CloneToInbound() - if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - // Since we rewrote the packet but it is being routed back to us, we - // can safely assume the checksum is valid. - pkt.RXTransportChecksumValidated = true - ep.(*endpoint).handlePacket(pkt) - } + if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); ep != nil { + // Since we rewrote the packet but it is being routed back to us, we + // can safely assume the checksum is valid. + ep.(*endpoint).handleLocalPacket(pkt, true /* canSkipRXChecksum */) n++ continue } @@ -573,14 +560,10 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { dstAddr := h.DestinationAddress() // Check if the destination is owned by the stack. - networkEndpoint, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr) - if err == nil { - networkEndpoint.(*endpoint).handlePacket(pkt) + if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr); ep != nil { + ep.(*endpoint).handlePacket(pkt) return nil } - if _, ok := err.(*tcpip.ErrBadAddress); !ok { - return err - } r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) if err != nil { @@ -619,8 +602,26 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } - // Loopback traffic skips the prerouting chain. + if !e.protocol.parse(pkt) { + stats.MalformedPacketsReceived.Increment() + return + } + if !e.nic.IsLoopback() { + if e.protocol.stack.HandleLocal() { + addressEndpoint := e.AcquireAssignedAddress(header.IPv4(pkt.NetworkHeader().View()).SourceAddress(), e.nic.Promiscuous(), stack.CanBePrimaryEndpoint) + if addressEndpoint != nil { + addressEndpoint.DecRef() + + // The source address is one of our own, so we never should have gotten + // a packet like this unless HandleLocal is false or our NIC is the + // loopback interface. + stats.InvalidSourceAddressesReceived.Increment() + return + } + } + + // Loopback traffic skips the prerouting chain. inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. @@ -632,6 +633,21 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { e.handlePacket(pkt) } +func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum bool) { + stats := e.stats.ip + + stats.PacketsReceived.Increment() + + pkt = pkt.CloneToInbound() + if e.protocol.parse(pkt) { + pkt.RXTransportChecksumValidated = canSkipRXChecksum + e.handlePacket(pkt) + return + } + + stats.MalformedPacketsReceived.Increment() +} + // handlePacket is like HandlePacket except it does not perform the prerouting // iptables hook. func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { @@ -1043,6 +1059,29 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} +// parse is like Parse but also attempts to parse the transport layer. +// +// Returns true if the network header was successfully parsed. +func (p *protocol) parse(pkt *stack.PacketBuffer) bool { + transProtoNum, hasTransportHdr, ok := p.Parse(pkt) + if !ok { + return false + } + + if hasTransportHdr { + switch err := p.stack.ParsePacketBufferTransport(transProtoNum, pkt); err { + case stack.ParsedOK: + case stack.UnknownTransportProtocol, stack.TransportLayerParseError: + // The transport layer will handle unknown protocols and transport layer + // parsing errors. + default: + panic(fmt.Sprintf("unexpected error parsing transport header = %d", err)) + } + } + + return true +} + // Parse implements stack.NetworkProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { if ok := parse.IPv4(pkt); !ok { diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index c2e8c3ea7..5cad546b8 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -648,14 +648,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // short circuits broadcasts before they are sent out to other hosts. if pkt.NatDone { netHeader := header.IPv6(pkt.NetworkHeader().View()) - if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { - pkt := pkt.CloneToInbound() - if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - // Since we rewrote the packet but it is being routed back to us, we can - // safely assume the checksum is valid. - pkt.RXTransportChecksumValidated = true - ep.(*endpoint).handlePacket(pkt) - } + if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); ep != nil { + // Since we rewrote the packet but it is being routed back to us, we + // can safely assume the checksum is valid. + ep.(*endpoint).handleLocalPacket(pkt, true /* canSkipRXChecksum */) return nil } } @@ -665,14 +661,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber, headerIncluded bool) tcpip.Error { if r.Loop&stack.PacketLoop != 0 { - pkt := pkt.CloneToInbound() - if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - // If the packet was generated by the stack (not a raw/packet endpoint - // where a packet may be written with the header included), then we can - // safely assume the checksum is valid. - pkt.RXTransportChecksumValidated = !headerIncluded - e.handlePacket(pkt) - } + // If the packet was generated by the stack (not a raw/packet endpoint + // where a packet may be written with the header included), then we can + // safely assume the checksum is valid. + e.handleLocalPacket(pkt, !headerIncluded /* canSkipRXChecksum */) } if r.Loop&stack.PacketOut == 0 { return nil @@ -771,14 +763,10 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } if _, ok := natPkts[pkt]; ok { netHeader := header.IPv6(pkt.NetworkHeader().View()) - if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { - pkt := pkt.CloneToInbound() - if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - // Since we rewrote the packet but it is being routed back to us, we - // can safely assume the checksum is valid. - pkt.RXTransportChecksumValidated = true - ep.(*endpoint).handlePacket(pkt) - } + if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); ep != nil { + // Since we rewrote the packet but it is being routed back to us, we + // can safely assume the checksum is valid. + ep.(*endpoint).handleLocalPacket(pkt, true /* canSkipRXChecksum */) n++ continue } @@ -852,14 +840,11 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { dstAddr := h.DestinationAddress() // Check if the destination is owned by the stack. - networkEndpoint, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr) - if err == nil { - networkEndpoint.(*endpoint).handlePacket(pkt) + + if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr); ep != nil { + ep.(*endpoint).handlePacket(pkt) return nil } - if _, ok := err.(*tcpip.ErrBadAddress); !ok { - return err - } r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) if err != nil { @@ -896,8 +881,26 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } - // Loopback traffic skips the prerouting chain. + if !e.protocol.parse(pkt) { + stats.MalformedPacketsReceived.Increment() + return + } + if !e.nic.IsLoopback() { + if e.protocol.stack.HandleLocal() { + addressEndpoint := e.AcquireAssignedAddress(header.IPv6(pkt.NetworkHeader().View()).SourceAddress(), e.nic.Promiscuous(), stack.CanBePrimaryEndpoint) + if addressEndpoint != nil { + addressEndpoint.DecRef() + + // The source address is one of our own, so we never should have gotten + // a packet like this unless HandleLocal is false or our NIC is the + // loopback interface. + stats.InvalidSourceAddressesReceived.Increment() + return + } + } + + // Loopback traffic skips the prerouting chain. inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. @@ -909,6 +912,21 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { e.handlePacket(pkt) } +func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum bool) { + stats := e.stats.ip + + stats.PacketsReceived.Increment() + + pkt = pkt.CloneToInbound() + if e.protocol.parse(pkt) { + pkt.RXTransportChecksumValidated = canSkipRXChecksum + e.handlePacket(pkt) + return + } + + stats.MalformedPacketsReceived.Increment() +} + // handlePacket is like HandlePacket except it does not perform the prerouting // iptables hook. func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { @@ -1798,6 +1816,29 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} +// parse is like Parse but also attempts to parse the transport layer. +// +// Returns true if the network header was successfully parsed. +func (p *protocol) parse(pkt *stack.PacketBuffer) bool { + transProtoNum, hasTransportHdr, ok := p.Parse(pkt) + if !ok { + return false + } + + if hasTransportHdr { + switch err := p.stack.ParsePacketBufferTransport(transProtoNum, pkt); err { + case stack.ParsedOK: + case stack.UnknownTransportProtocol, stack.TransportLayerParseError: + // The transport layer will handle unknown protocols and transport layer + // parsing errors. + default: + panic(fmt.Sprintf("unexpected error parsing transport header = %d", err)) + } + } + + return true +} + // Parse implements stack.NetworkProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { proto, _, fragOffset, fragMore, ok := parse.IPv6(pkt) diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 41a489047..6f2a0e487 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -777,36 +777,6 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp anyEPs.forEach(deliverPacketEPs) } - // Parse headers. - netProto := n.stack.NetworkProtocolInstance(protocol) - transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt) - if !ok { - // The packet is too small to contain a network header. - n.stack.stats.MalformedRcvdPackets.Increment() - return - } - if hasTransportHdr { - pkt.TransportProtocolNumber = transProtoNum - // Parse the transport header if present. - if state, ok := n.stack.transportProtocols[transProtoNum]; ok { - state.proto.Parse(pkt) - } - } - - if n.stack.handleLocal && !n.IsLoopback() { - src, _ := netProto.ParseAddresses(pkt.NetworkHeader().View()) - if r := n.getAddress(protocol, src); r != nil { - r.DecRef() - - // The source address is one of our own, so we never should have gotten a - // packet like this unless handleLocal is false. Loopback also calls this - // function even though the packets didn't come from the physical interface - // so don't drop those. - n.stack.stats.IP.InvalidSourceAddressesReceived.Increment() - return - } - } - networkEndpoint.HandlePacket(pkt) } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index a51d758d0..035ab33ca 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1319,6 +1319,11 @@ func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, return nil } +// HandleLocal returns true if non-loopback interfaces are allowed to loop packets. +func (s *Stack) HandleLocal() bool { + return s.handleLocal +} + // FindRoute creates a route to the given destination address, leaving through // the given NIC and local address (if provided). // @@ -2063,7 +2068,9 @@ func generateRandInt64() int64 { } // FindNetworkEndpoint returns the network endpoint for the given address. -func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) (NetworkEndpoint, tcpip.Error) { +// +// Returns nil if the address is not associated with any network endpoint. +func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) NetworkEndpoint { s.mu.RLock() defer s.mu.RUnlock() @@ -2073,9 +2080,9 @@ func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, addres continue } addressEndpoint.DecRef() - return nic.getNetworkEndpoint(netProto), nil + return nic.getNetworkEndpoint(netProto) } - return nil, &tcpip.ErrBadAddress{} + return nil } // FindNICNameFromID returns the name of the NIC for the given NICID. @@ -2103,13 +2110,6 @@ const ( // ParsedOK indicates that a packet was successfully parsed. ParsedOK ParseResult = iota - // UnknownNetworkProtocol indicates that the network protocol is unknown. - UnknownNetworkProtocol - - // NetworkLayerParseError indicates that the network packet was not - // successfully parsed. - NetworkLayerParseError - // UnknownTransportProtocol indicates that the transport protocol is unknown. UnknownTransportProtocol @@ -2118,31 +2118,19 @@ const ( TransportLayerParseError ) -// ParsePacketBuffer parses the provided packet buffer. -func (s *Stack) ParsePacketBuffer(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) ParseResult { - netProto, ok := s.networkProtocols[protocol] - if !ok { - return UnknownNetworkProtocol - } - - transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt) - if !ok { - return NetworkLayerParseError - } - if !hasTransportHdr { - return ParsedOK - } - +// ParsePacketBufferTransport parses the provided packet buffer's transport +// header. +func (s *Stack) ParsePacketBufferTransport(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) ParseResult { // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a // full explanation. - if transProtoNum == header.ICMPv4ProtocolNumber || transProtoNum == header.ICMPv6ProtocolNumber { + if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { return ParsedOK } - pkt.TransportProtocolNumber = transProtoNum + pkt.TransportProtocolNumber = protocol // Parse the transport header if present. - state, ok := s.transportProtocols[transProtoNum] + state, ok := s.transportProtocols[protocol] if !ok { return UnknownTransportProtocol } |