diff options
author | Nick Brown <nickbrow@google.com> | 2021-05-19 12:09:54 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2021-05-19 12:13:09 -0700 |
commit | b8b43f70c898861a6fd642ded069dd33d35f17f9 (patch) | |
tree | cc78171174e81454a2d761cec23f0812f95d56ee /pkg/tcpip/tests | |
parent | 2f3eda37a4dd64f1202fac1c69ab819fb0bd7a5e (diff) |
Send ICMP errors when link address resolution fails
Before this change, we would silently drop packets when link resolution
failed. This change brings us into line with RFC 792 (IPv4) and RFC 4443 (IPv6),
both of which specify that gateways should return an ICMP error to the sender
when link resolution fails.
PiperOrigin-RevId: 374699789
Diffstat (limited to 'pkg/tcpip/tests')
-rw-r--r-- | pkg/tcpip/tests/integration/BUILD | 4 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/link_resolution_test.go | 248 |
2 files changed, 252 insertions, 0 deletions
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index ab2dab60c..e446c8cf3 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -48,17 +48,21 @@ go_test( size = "small", srcs = ["link_resolution_test.go"], deps = [ + "//pkg/context", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", "//pkg/tcpip/faketime", "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", "//pkg/tcpip/link/pipe", "//pkg/tcpip/network/arp", + "//pkg/tcpip/network/internal/testutil", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/tcpip/tests/utils", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index c657714ba..e94f31ad7 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -17,22 +17,27 @@ package link_resolution_test import ( "bytes" "fmt" + "net" "testing" "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/pipe" "gvisor.dev/gvisor/pkg/tcpip/network/arp" + iptestutil "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/tests/utils" + tcptestutil "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" @@ -395,6 +400,249 @@ func TestTCPLinkResolutionFailure(t *testing.T) { } } +func TestForwardingWithLinkResolutionFailure(t *testing.T) { + const ( + incomingNICID = 1 + outgoingNICID = 2 + randomSequence = 123 + randomIdent = 42 + randomTimeOffset = 0x10203040 + ttl = 2 + expectedHostUnreachableErrorCount = 1 + ) + outgoingLinkAddr := tcptestutil.MustParseLink("02:03:03:04:05:06") + + rxICMPv4EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) { + utils.RxICMPv4EchoRequest(e, src, dst, ttl) + } + + rxICMPv6EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) { + utils.RxICMPv6EchoRequest(e, src, dst, ttl) + } + + arpChecker := func(t *testing.T, request channel.PacketInfo, src, dst tcpip.Address) { + if request.Proto != arp.ProtocolNumber { + t.Errorf("got request.Proto = %d, want = %d", request.Proto, arp.ProtocolNumber) + } + if request.Route.RemoteLinkAddress != header.EthernetBroadcastAddress { + t.Errorf("got request.Route.RemoteLinkAddress = %s, want = %s", request.Route.RemoteLinkAddress, header.EthernetBroadcastAddress) + } + rep := header.ARP(request.Pkt.NetworkHeader().View()) + if got := rep.Op(); got != header.ARPRequest { + t.Errorf("got Op() = %d, want = %d", got, header.ARPRequest) + } + if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != outgoingLinkAddr { + t.Errorf("got HardwareAddressSender = %s, want = %s", got, outgoingLinkAddr) + } + if got := tcpip.Address(rep.ProtocolAddressSender()); got != src { + t.Errorf("got ProtocolAddressSender = %s, want = %s", got, src) + } + if got := tcpip.Address(rep.ProtocolAddressTarget()); got != dst { + t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, dst) + } + } + + ndpChecker := func(t *testing.T, request channel.PacketInfo, src, dst tcpip.Address) { + if request.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", request.Proto, header.IPv6ProtocolNumber) + } + + snmc := header.SolicitedNodeAddr(dst) + if want := header.EthernetAddressFromMulticastIPv6Address(snmc); request.Route.RemoteLinkAddress != want { + t.Errorf("got remote link address = %s, want = %s", request.Route.RemoteLinkAddress, want) + } + + checker.IPv6(t, stack.PayloadSince(request.Pkt.NetworkHeader()), + checker.SrcAddr(src), + checker.DstAddr(snmc), + checker.TTL(header.NDPHopLimit), + checker.NDPNS( + checker.NDPNSTargetAddress(dst), + )) + } + + icmpv4Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) { + checker.IPv4(t, b, + checker.SrcAddr(src), + checker.DstAddr(dst), + checker.TTL(ipv4.DefaultTTL), + checker.ICMPv4( + checker.ICMPv4Checksum(), + checker.ICMPv4Type(header.ICMPv4DstUnreachable), + checker.ICMPv4Code(header.ICMPv4HostUnreachable), + ), + ) + } + + icmpv6Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) { + checker.IPv6(t, b, + checker.SrcAddr(src), + checker.DstAddr(dst), + checker.TTL(ipv6.DefaultTTL), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6DstUnreachable), + checker.ICMPv6Code(header.ICMPv6AddressUnreachable), + ), + ) + } + + tests := []struct { + name string + networkProtocolFactory []stack.NetworkProtocolFactory + networkProtocolNumber tcpip.NetworkProtocolNumber + sourceAddr tcpip.Address + destAddr tcpip.Address + incomingAddr tcpip.AddressWithPrefix + outgoingAddr tcpip.AddressWithPrefix + transportProtocol func(*stack.Stack) stack.TransportProtocol + rx func(*channel.Endpoint, tcpip.Address, tcpip.Address) + linkResolutionRequestChecker func(*testing.T, channel.PacketInfo, tcpip.Address, tcpip.Address) + icmpReplyChecker func(*testing.T, []byte, tcpip.Address, tcpip.Address) + mtu uint32 + }{ + { + name: "IPv4 Host unreachable", + networkProtocolFactory: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, + networkProtocolNumber: header.IPv4ProtocolNumber, + sourceAddr: tcptestutil.MustParse4("10.0.0.2"), + destAddr: tcptestutil.MustParse4("11.0.0.2"), + incomingAddr: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()), + PrefixLen: 8, + }, + outgoingAddr: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("11.0.0.1").To4()), + PrefixLen: 8, + }, + transportProtocol: icmp.NewProtocol4, + linkResolutionRequestChecker: arpChecker, + icmpReplyChecker: icmpv4Checker, + rx: rxICMPv4EchoRequest, + mtu: ipv4.MaxTotalSize, + }, + { + name: "IPv6 Host unreachable", + networkProtocolFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + networkProtocolNumber: header.IPv6ProtocolNumber, + sourceAddr: tcptestutil.MustParse6("10::2"), + destAddr: tcptestutil.MustParse6("11::2"), + incomingAddr: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("10::1").To16()), + PrefixLen: 64, + }, + outgoingAddr: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("11::1").To16()), + PrefixLen: 64, + }, + transportProtocol: icmp.NewProtocol6, + linkResolutionRequestChecker: ndpChecker, + icmpReplyChecker: icmpv6Checker, + rx: rxICMPv6EchoRequest, + mtu: header.IPv6MinimumMTU, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + + s := stack.New(stack.Options{ + NetworkProtocols: test.networkProtocolFactory, + TransportProtocols: []stack.TransportProtocolFactory{test.transportProtocol}, + Clock: clock, + }) + + // Set up endpoint through which we will receive packets. + incomingEndpoint := channel.New(1, test.mtu, "") + if err := s.CreateNIC(incomingNICID, incomingEndpoint); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err) + } + incomingProtoAddr := tcpip.ProtocolAddress{ + Protocol: test.networkProtocolNumber, + AddressWithPrefix: test.incomingAddr, + } + if err := s.AddProtocolAddress(incomingNICID, incomingProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingProtoAddr, err) + } + + // Set up endpoint through which we will attempt to forward packets. + outgoingEndpoint := channel.New(1, test.mtu, outgoingLinkAddr) + outgoingEndpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired + if err := s.CreateNIC(outgoingNICID, outgoingEndpoint); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err) + } + outgoingProtoAddr := tcpip.ProtocolAddress{ + Protocol: test.networkProtocolNumber, + AddressWithPrefix: test.outgoingAddr, + } + if err := s.AddProtocolAddress(outgoingNICID, outgoingProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingProtoAddr, err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: test.incomingAddr.Subnet(), + NIC: incomingNICID, + }, + { + Destination: test.outgoingAddr.Subnet(), + NIC: outgoingNICID, + }, + }) + + if err := s.SetForwardingDefaultAndAllNICs(test.networkProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", test.networkProtocolNumber, err) + } + + test.rx(incomingEndpoint, test.sourceAddr, test.destAddr) + + var request channel.PacketInfo + var ok bool + nudConfigs, err := s.NUDConfigurations(outgoingNICID, test.networkProtocolNumber) + if err != nil { + t.Fatalf("s.NUDConfigurations(%d, %d): %s", outgoingNICID, test.networkProtocolNumber, err) + } + // Trigger the first packet on the endpoint. + iptestutil.RunImmediatelyScheduledJobs(clock) + + for i := 0; i < int(nudConfigs.MaxMulticastProbes); i++ { + if request, ok = outgoingEndpoint.Read(); !ok { + t.Fatal("expected ARP packet through outgoing NIC") + } + + test.linkResolutionRequestChecker(t, request, test.outgoingAddr.Address, test.destAddr) + + // Advance the clock the span of one request timeout. + clock.Advance(nudConfigs.RetransmitTimer) + } + + // Next, we make a blocking read to retrieve the error packet. This is + // necessary because outgoing packets are dequeued asynchronously when + // link resolution fails, and this dequeue is what triggers the ICMP + // error. + // + // TODO(gvisor.dev/issue/6012): Replace with asynchronous read after we + // have integrated the stack clock with the dequeuing code. + reply, ok := incomingEndpoint.ReadContext(context.Background()) + if !ok { + t.Fatal("expected ICMP packet through incoming NIC") + } + + test.icmpReplyChecker(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), test.incomingAddr.Address, test.sourceAddr) + + // Since link resolution failed, we don't expect the packet to be + // forwarded. + forwardedPacket, ok := outgoingEndpoint.Read() + if ok { + t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", forwardedPacket) + } + + if got, want := s.Stats().IP.Forwarding.HostUnreachable.Value(), expectedHostUnreachableErrorCount; int(got) != want { + t.Errorf("got rt.Stats().IP.Forwarding.HostUnreachable.Value() = %d, want = %d", got, want) + } + }) + } +} + func TestGetLinkAddress(t *testing.T) { const ( host1NICID = 1 |