diff options
author | Ghanan Gowripalan <ghanan@google.com> | 2020-10-08 16:14:01 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2020-10-08 16:15:59 -0700 |
commit | 6768e6c59ec252854a1633e184b69dc5723ac3f3 (patch) | |
tree | 9e9915094781b9b31b201434c7a1e1f0940c2869 /pkg/tcpip/network/ipv4 | |
parent | 40269d0c24d1ea9b040a8326c9fa01b03477410a (diff) |
Do not resolve routes immediately
When a response needs to be sent to an incoming packet, the stack should
consult its neighbour table to determine the remote address's link
address.
When an entry does not exist in the stack's neighbor table, the stack
should queue the packet while link resolution completes. See comments.
PiperOrigin-RevId: 336185457
Diffstat (limited to 'pkg/tcpip/network/ipv4')
-rw-r--r-- | pkg/tcpip/network/ipv4/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/icmp.go | 40 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 26 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4_test.go | 203 |
4 files changed, 240 insertions, 30 deletions
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index ee2c23e91..7fc12e229 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -32,6 +32,7 @@ go_test( "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/sniffer", + "//pkg/tcpip/network/arp", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/testutil", "//pkg/tcpip/stack", diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index eab9a530c..3407755ed 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -102,8 +102,6 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) - remoteLinkAddr := r.RemoteLinkAddress - // As per RFC 1122 section 3.2.1.3, when a host sends any datagram, the IP // source address MUST be one of its own IP addresses (but not a broadcast // or multicast address). @@ -119,9 +117,6 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { } defer r.Release() - // Use the remote link address from the incoming packet. - r.ResolveWith(remoteLinkAddr) - // TODO(gvisor.dev/issue/3810:) When adding protocol numbers into the // header information, we may have to change this code to handle the // ICMP header no longer being in the data buffer. @@ -244,13 +239,7 @@ func (*icmpReasonProtoUnreachable) isICMPReason() {} // the problematic packet. It incorporates as much of that packet as // possible as well as any error metadata as is available. returnError // expects pkt to hold a valid IPv4 packet as per the wire format. -func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { - sent := r.Stats().ICMP.V4PacketsSent - if !r.Stack().AllowICMPMessage() { - sent.RateLimited.Increment() - return nil - } - +func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { // We check we are responding only when we are allowed to. // See RFC 1812 section 4.3.2.7 (shown below). // @@ -279,6 +268,25 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc return nil } + // Even if we were able to receive a packet from some remote, we may not have + // a route to it - the remote may be blocked via routing rules. We must always + // consult our routing table and find a route to the remote before sending any + // packet. + route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + if err != nil { + return err + } + defer route.Release() + // From this point on, the incoming route should no longer be used; route + // must be used to send the ICMP error. + r = nil + + sent := p.stack.Stats().ICMP.V4PacketsSent + if !p.stack.AllowICMPMessage() { + sent.RateLimited.Increment() + return nil + } + networkHeader := pkt.NetworkHeader().View() transportHeader := pkt.TransportHeader().View() @@ -329,11 +337,11 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc // least 8 bytes of the payload must be included. Today linux and other // systems implement the RFC 1812 definition and not the original // requirement. We treat 8 bytes as the minimum but will try send more. - mtu := int(r.MTU()) + mtu := int(route.MTU()) if mtu > header.IPv4MinimumProcessableDatagramSize { mtu = header.IPv4MinimumProcessableDatagramSize } - headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize + headerLen := int(route.MaxHeaderLength()) + header.ICMPv4MinimumSize available := int(mtu) - headerLen if available < header.IPv4MinimumSize+header.ICMPv4MinimumErrorPayloadSize { @@ -378,11 +386,11 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data)) counter := sent.DstUnreachable - if err := r.WritePacket( + if err := route.WritePacket( nil, /* gso */ stack.NetworkHeaderParams{ Protocol: header.ICMPv4ProtocolNumber, - TTL: r.DefaultTTL(), + TTL: route.DefaultTTL(), TOS: stack.DefaultTOS, }, icmpPkt, diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 99274dd45..115fb1ab0 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -66,7 +66,6 @@ var _ stack.NetworkEndpoint = (*endpoint)(nil) type endpoint struct { nic stack.NetworkInterface - linkEP stack.LinkEndpoint dispatcher stack.TransportDispatcher protocol *protocol @@ -87,7 +86,6 @@ type endpoint struct { func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { e := &endpoint{ nic: nic, - linkEP: nic.LinkEndpoint(), dispatcher: dispatcher, protocol: p, } @@ -178,18 +176,18 @@ func (e *endpoint) DefaultTTL() uint8 { // MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus // the network layer max header length. func (e *endpoint) MTU() uint32 { - return calculateMTU(e.linkEP.MTU()) + return calculateMTU(e.nic.MTU()) } // MaxHeaderLength returns the maximum length needed by ipv4 headers (and // underlying protocols). func (e *endpoint) MaxHeaderLength() uint16 { - return e.linkEP.MaxHeaderLength() + header.IPv4MaximumHeaderSize + return e.nic.MaxHeaderLength() + header.IPv4MaximumHeaderSize } // GSOMaxSize returns the maximum GSO packet size. func (e *endpoint) GSOMaxSize() uint32 { - if gso, ok := e.linkEP.(stack.GSOEndpoint); ok { + if gso, ok := e.nic.(stack.GSOEndpoint); ok { return gso.GSOMaxSize() } return 0 @@ -210,7 +208,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu uint for { fragPkt, more := buildNextFragment(&pf, networkHeader) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, fragPkt); err != nil { + if err := e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt); err != nil { r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pf.RemainingFragmentCount() + 1)) return err } @@ -283,10 +281,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw if r.Loop&stack.PacketOut == 0 { return nil } - if pkt.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) { - return e.writePacketFragments(r, gso, e.linkEP.MTU(), pkt) + if pkt.Size() > int(e.nic.MTU()) && (gso == nil || gso.Type == stack.GSONone) { + return e.writePacketFragments(r, gso, e.nic.MTU(), pkt) } - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return err } @@ -316,7 +314,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. - n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) + n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber) r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) if err != nil { r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) @@ -343,7 +341,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe continue } } - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n - len(dropped))) // Dropped packets aren't errors, so include them in @@ -404,7 +402,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu return nil } - if err := e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt); err != nil { + if err := e.nic.WritePacket(r, nil /* gso */, ProtocolNumber, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return err } @@ -512,13 +510,13 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // 3 (Port Unreachable), when the designated transport protocol // (e.g., UDP) is unable to demultiplex the datagram but has no // protocol mechanism to inform the sender. - _ = returnError(r, &icmpReasonPortUnreachable{}, pkt) + _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt) case stack.TransportPacketProtocolUnreachable: // As per RFC: 1122 Section 3.2.2.1 // A host SHOULD generate Destination Unreachable messages with code: // 2 (Protocol Unreachable), when the designated transport protocol // is not supported - _ = returnError(r, &icmpReasonProtoUnreachable{}, pkt) + _ = e.protocol.returnError(r, &icmpReasonProtoUnreachable{}, pkt) default: panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res)) } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index f250a3cde..9916d783f 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -16,6 +16,7 @@ package ipv4_test import ( "bytes" + "context" "encoding/hex" "math" "net" @@ -28,6 +29,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" + "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -1492,3 +1494,204 @@ func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, lm.limit-- return false, false } + +func TestPacketQueing(t *testing.T) { + const nicID = 1 + + var ( + host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") + host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") + + host1IPv4Addr = tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()), + PrefixLen: 24, + }, + } + host2IPv4Addr = tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()), + PrefixLen: 8, + }, + } + ) + + tests := []struct { + name string + rxPkt func(*channel.Endpoint) + checkResp func(*testing.T, *channel.Endpoint) + }{ + { + name: "ICMP Error", + rxPkt: func(e *channel.Endpoint) { + hdr := buffer.NewPrependable(header.IPv4MinimumSize + header.UDPMinimumSize) + u := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + u.Encode(&header.UDPFields{ + SrcPort: 5555, + DstPort: 80, + Length: header.UDPMinimumSize, + }) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv4Addr.AddressWithPrefix.Address, host1IPv4Addr.AddressWithPrefix.Address, header.UDPMinimumSize) + sum = header.Checksum(header.UDP([]byte{}), sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: header.IPv4MinimumSize + header.UDPMinimumSize, + TTL: ipv4.DefaultTTL, + Protocol: uint8(udp.ProtocolNumber), + SrcAddr: host2IPv4Addr.AddressWithPrefix.Address, + DstAddr: host1IPv4Addr.AddressWithPrefix.Address, + }) + e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + }, + checkResp: func(t *testing.T, e *channel.Endpoint) { + p, ok := e.ReadContext(context.Background()) + if !ok { + t.Fatalf("timed out waiting for packet") + } + if p.Proto != header.IPv4ProtocolNumber { + t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber) + } + if p.Route.RemoteLinkAddress != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) + } + checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), + checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4DstUnreachable), + checker.ICMPv4Code(header.ICMPv4PortUnreachable))) + }, + }, + + { + name: "Ping", + rxPkt: func(e *channel.Endpoint) { + totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize + hdr := buffer.NewPrependable(totalLen) + pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + pkt.SetType(header.ICMPv4Echo) + pkt.SetCode(0) + pkt.SetChecksum(0) + pkt.SetChecksum(^header.Checksum(pkt, 0)) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(totalLen), + Protocol: uint8(icmp.ProtocolNumber4), + TTL: ipv4.DefaultTTL, + SrcAddr: host2IPv4Addr.AddressWithPrefix.Address, + DstAddr: host1IPv4Addr.AddressWithPrefix.Address, + }) + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + }, + checkResp: func(t *testing.T, e *channel.Endpoint) { + p, ok := e.ReadContext(context.Background()) + if !ok { + t.Fatalf("timed out waiting for packet") + } + if p.Proto != header.IPv4ProtocolNumber { + t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber) + } + if p.Route.RemoteLinkAddress != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) + } + checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), + checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4EchoReply), + checker.ICMPv4Code(header.ICMPv4UnusedCode))) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e := channel.New(1, header.IPv6MinimumMTU, host1NICLinkAddr) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + }) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, arp.ProtocolNumber, arp.ProtocolAddress, err) + } + if err := s.AddProtocolAddress(nicID, host1IPv4Addr); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv4Addr, err) + } + + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), + NIC: nicID, + }, + }) + + // Receive a packet to trigger link resolution before a response is sent. + test.rxPkt(e) + + // Wait for a ARP request since link address resolution should be + // performed. + { + p, ok := e.ReadContext(context.Background()) + if !ok { + t.Fatalf("timed out waiting for packet") + } + if p.Proto != arp.ProtocolNumber { + t.Errorf("got p.Proto = %d, want = %d", p.Proto, arp.ProtocolNumber) + } + if p.Route.RemoteLinkAddress != header.EthernetBroadcastAddress { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, header.EthernetBroadcastAddress) + } + rep := header.ARP(p.Pkt.NetworkHeader().View()) + if got := rep.Op(); got != header.ARPRequest { + t.Errorf("got Op() = %d, want = %d", got, header.ARPRequest) + } + if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != host1NICLinkAddr { + t.Errorf("got HardwareAddressSender = %s, want = %s", got, host1NICLinkAddr) + } + if got := tcpip.Address(rep.ProtocolAddressSender()); got != host1IPv4Addr.AddressWithPrefix.Address { + t.Errorf("got ProtocolAddressSender = %s, want = %s", got, host1IPv4Addr.AddressWithPrefix.Address) + } + if got := tcpip.Address(rep.ProtocolAddressTarget()); got != host2IPv4Addr.AddressWithPrefix.Address { + t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, host2IPv4Addr.AddressWithPrefix.Address) + } + } + + // Send an ARP reply to complete link address resolution. + { + hdr := buffer.View(make([]byte, header.ARPSize)) + packet := header.ARP(hdr) + packet.SetIPv4OverEthernet() + packet.SetOp(header.ARPReply) + copy(packet.HardwareAddressSender(), host2NICLinkAddr) + copy(packet.ProtocolAddressSender(), host2IPv4Addr.AddressWithPrefix.Address) + copy(packet.HardwareAddressTarget(), host1NICLinkAddr) + copy(packet.ProtocolAddressTarget(), host1IPv4Addr.AddressWithPrefix.Address) + e.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.ToVectorisedView(), + })) + } + + // Expect the response now that the link address has resolved. + test.checkResp(t, e) + + // Since link resolution was already performed, it shouldn't be performed + // again. + test.rxPkt(e) + test.checkResp(t, e) + }) + } +} |