diff options
Diffstat (limited to 'pkg/tcpip/tests/integration')
-rw-r--r-- | pkg/tcpip/tests/integration/BUILD | 167 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/forward_test.go | 698 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/iptables_test.go | 2271 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/istio_test.go | 365 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/link_resolution_test.go | 1640 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/loopback_test.go | 782 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/multicast_broadcast_test.go | 723 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/route_test.go | 441 |
8 files changed, 0 insertions, 7087 deletions
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD deleted file mode 100644 index 99f4d4d0e..000000000 --- a/pkg/tcpip/tests/integration/BUILD +++ /dev/null @@ -1,167 +0,0 @@ -load("//tools:defs.bzl", "go_test") - -package(licenses = ["notice"]) - -go_test( - name = "forward_test", - size = "small", - srcs = ["forward_test.go"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/checker", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/network/arp", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/tests/utils", - "//pkg/tcpip/testutil", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) - -go_test( - name = "iptables_test", - size = "small", - srcs = ["iptables_test.go"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/network/arp", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/tests/utils", - "//pkg/tcpip/testutil", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) - -go_test( - name = "link_resolution_test", - size = "small", - srcs = ["link_resolution_test.go"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/faketime", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/pipe", - "//pkg/tcpip/network/arp", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/tests/utils", - "//pkg/tcpip/testutil", - "//pkg/tcpip/transport/icmp", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - "@com_github_google_go_cmp//cmp:go_default_library", - "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", - ], -) - -go_test( - name = "loopback_test", - size = "small", - srcs = ["loopback_test.go"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/tests/utils", - "//pkg/tcpip/testutil", - "//pkg/tcpip/transport/icmp", - "//pkg/tcpip/transport/tcp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) - -go_test( - name = "multicast_broadcast_test", - size = "small", - srcs = ["multicast_broadcast_test.go"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/tests/utils", - "//pkg/tcpip/testutil", - "//pkg/tcpip/transport/icmp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) - -go_test( - name = "route_test", - size = "small", - srcs = ["route_test.go"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/stack", - "//pkg/tcpip/tests/utils", - "//pkg/tcpip/testutil", - "//pkg/tcpip/transport/icmp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) - -go_test( - name = "istio_test", - size = "small", - srcs = ["istio_test.go"], - deps = [ - "//pkg/context", - "//pkg/rand", - "//pkg/sync", - "//pkg/tcpip", - "//pkg/tcpip/adapters/gonet", - "//pkg/tcpip/header", - "//pkg/tcpip/link/loopback", - "//pkg/tcpip/link/pipe", - "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/stack", - "//pkg/tcpip/testutil", - "//pkg/tcpip/transport/tcp", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go deleted file mode 100644 index 6e1d4720d..000000000 --- a/pkg/tcpip/tests/integration/forward_test.go +++ /dev/null @@ -1,698 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package forward_test - -import ( - "bytes" - "fmt" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/network/arp" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/tests/utils" - "gvisor.dev/gvisor/pkg/tcpip/testutil" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ttl = 64 - -var ( - ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10") - ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a") -) - -func rxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) { - utils.RxICMPv4EchoRequest(e, src, dst, ttl) -} - -func rxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) { - utils.RxICMPv6EchoRequest(e, src, dst, ttl) -} - -func forwardedICMPv4EchoRequestChecker(t *testing.T, b []byte, src, dst tcpip.Address) { - checker.IPv4(t, b, - checker.SrcAddr(src), - checker.DstAddr(dst), - checker.TTL(ttl-1), - checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4Echo))) -} - -func forwardedICMPv6EchoRequestChecker(t *testing.T, b []byte, src, dst tcpip.Address) { - checker.IPv6(t, b, - checker.SrcAddr(src), - checker.DstAddr(dst), - checker.TTL(ttl-1), - checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6EchoRequest))) -} - -func TestForwarding(t *testing.T) { - const listenPort = 8080 - - type endpointAndAddresses struct { - serverEP tcpip.Endpoint - serverAddr tcpip.Address - serverReadableCH chan struct{} - - clientEP tcpip.Endpoint - clientAddr tcpip.Address - clientReadableCH chan struct{} - } - - newEP := func(t *testing.T, s *stack.Stack, transProto tcpip.TransportProtocolNumber, netProto tcpip.NetworkProtocolNumber) (tcpip.Endpoint, chan struct{}) { - t.Helper() - var wq waiter.Queue - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - ep, err := s.NewEndpoint(transProto, netProto, &wq) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d, _): %s", transProto, netProto, err) - } - - t.Cleanup(func() { - wq.EventUnregister(&we) - }) - - return ep, ch - } - - tests := []struct { - name string - epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses - }{ - { - name: "IPv4 host1 server with host2 client", - epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { - ep1, ep1WECH := newEP(t, host1Stack, proto, ipv4.ProtocolNumber) - ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber) - return endpointAndAddresses{ - serverEP: ep1, - serverAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address, - serverReadableCH: ep1WECH, - - clientEP: ep2, - clientAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address, - clientReadableCH: ep2WECH, - } - }, - }, - { - name: "IPv6 host2 server with host1 client", - epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { - ep1, ep1WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber) - ep2, ep2WECH := newEP(t, host1Stack, proto, ipv6.ProtocolNumber) - return endpointAndAddresses{ - serverEP: ep1, - serverAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address, - serverReadableCH: ep1WECH, - - clientEP: ep2, - clientAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address, - clientReadableCH: ep2WECH, - } - }, - }, - { - name: "IPv4 host2 server with routerNIC1 client", - epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { - ep1, ep1WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber) - ep2, ep2WECH := newEP(t, routerStack, proto, ipv4.ProtocolNumber) - return endpointAndAddresses{ - serverEP: ep1, - serverAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address, - serverReadableCH: ep1WECH, - - clientEP: ep2, - clientAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, - clientReadableCH: ep2WECH, - } - }, - }, - { - name: "IPv6 routerNIC2 server with host1 client", - epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { - ep1, ep1WECH := newEP(t, routerStack, proto, ipv6.ProtocolNumber) - ep2, ep2WECH := newEP(t, host1Stack, proto, ipv6.ProtocolNumber) - return endpointAndAddresses{ - serverEP: ep1, - serverAddr: utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address, - serverReadableCH: ep1WECH, - - clientEP: ep2, - clientAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address, - clientReadableCH: ep2WECH, - } - }, - }, - } - - subTests := []struct { - name string - proto tcpip.TransportProtocolNumber - expectedConnectErr tcpip.Error - setupServer func(t *testing.T, ep tcpip.Endpoint) - setupServerConn func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) - needRemoteAddr bool - }{ - { - name: "UDP", - proto: udp.ProtocolNumber, - expectedConnectErr: nil, - setupServerConn: func(t *testing.T, ep tcpip.Endpoint, _ <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { - t.Helper() - - if err := ep.Connect(clientAddr); err != nil { - t.Fatalf("ep.Connect(%#v): %s", clientAddr, err) - } - return nil, nil - }, - needRemoteAddr: true, - }, - { - name: "TCP", - proto: tcp.ProtocolNumber, - expectedConnectErr: &tcpip.ErrConnectStarted{}, - setupServer: func(t *testing.T, ep tcpip.Endpoint) { - t.Helper() - - if err := ep.Listen(1); err != nil { - t.Fatalf("ep.Listen(1): %s", err) - } - }, - setupServerConn: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { - t.Helper() - - var addr tcpip.FullAddress - for { - newEP, wq, err := ep.Accept(&addr) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - <-ch - continue - } - if err != nil { - t.Fatalf("ep.Accept(_): %s", err) - } - if diff := cmp.Diff(clientAddr, addr, checker.IgnoreCmpPath( - "NIC", - )); diff != "" { - t.Errorf("accepted address mismatch (-want +got):\n%s", diff) - } - - we, newCH := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - return newEP, newCH - } - }, - needRemoteAddr: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - for _, subTest := range subTests { - t.Run(subTest.name, func(t *testing.T) { - stackOpts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, - } - - host1Stack := stack.New(stackOpts) - routerStack := stack.New(stackOpts) - host2Stack := stack.New(stackOpts) - utils.SetupRoutedStacks(t, host1Stack, routerStack, host2Stack) - - epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto) - defer epsAndAddrs.serverEP.Close() - defer epsAndAddrs.clientEP.Close() - - serverAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverAddr, Port: listenPort} - if err := epsAndAddrs.serverEP.Bind(serverAddr); err != nil { - t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", serverAddr, err) - } - clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr} - if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil { - t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err) - } - - if subTest.setupServer != nil { - subTest.setupServer(t, epsAndAddrs.serverEP) - } - { - err := epsAndAddrs.clientEP.Connect(serverAddr) - if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" { - t.Fatalf("unexpected error from epsAndAddrs.clientEP.Connect(%#v), (-want, +got):\n%s", serverAddr, diff) - } - } - if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil { - t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err) - } else { - clientAddr = addr - clientAddr.NIC = 0 - } - - serverEP := epsAndAddrs.serverEP - serverCH := epsAndAddrs.serverReadableCH - if ep, ch := subTest.setupServerConn(t, serverEP, serverCH, clientAddr); ep != nil { - defer ep.Close() - serverEP = ep - serverCH = ch - } - - write := func(ep tcpip.Endpoint, data []byte) { - t.Helper() - - var r bytes.Reader - r.Reset(data) - var wOpts tcpip.WriteOptions - n, err := ep.Write(&r, wOpts) - if err != nil { - t.Fatalf("ep.Write(_, %#v): %s", wOpts, err) - } - if want := int64(len(data)); n != want { - t.Fatalf("got ep.Write(_, %#v) = (%d, _), want = (%d, _)", wOpts, n, want) - } - } - - data := []byte{1, 2, 3, 4} - write(epsAndAddrs.clientEP, data) - - read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.FullAddress) { - t.Helper() - - var buf bytes.Buffer - var res tcpip.ReadResult - for { - var err tcpip.Error - opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr} - res, err = ep.Read(&buf, opts) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - <-ch - continue - } - if err != nil { - t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err) - } - break - } - - readResult := tcpip.ReadResult{ - Count: len(data), - Total: len(data), - } - if subTest.needRemoteAddr { - readResult.RemoteAddr = expectedFrom - } - if diff := cmp.Diff(readResult, res, checker.IgnoreCmpPath( - "ControlMessages", - "RemoteAddr.NIC", - )); diff != "" { - t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) - } - if diff := cmp.Diff(buf.Bytes(), data); diff != "" { - t.Errorf("received data mismatch (-want +got):\n%s", diff) - } - - if t.Failed() { - t.FailNow() - } - } - - read(serverCH, serverEP, data, clientAddr) - - data = []byte{5, 6, 7, 8, 9, 10, 11, 12} - write(serverEP, data) - read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, serverAddr) - }) - } - }) - } -} - -func TestMulticastForwarding(t *testing.T) { - const ( - nicID1 = 1 - nicID2 = 2 - ) - - var ( - ipv4LinkLocalUnicastAddr = testutil.MustParse4("169.254.0.10") - ipv4LinkLocalMulticastAddr = testutil.MustParse4("224.0.0.10") - - ipv6LinkLocalUnicastAddr = testutil.MustParse6("fe80::a") - ipv6LinkLocalMulticastAddr = testutil.MustParse6("ff02::a") - ) - - tests := []struct { - name string - srcAddr, dstAddr tcpip.Address - rx func(*channel.Endpoint, tcpip.Address, tcpip.Address) - expectForward bool - checker func(*testing.T, []byte) - }{ - { - name: "IPv4 link-local multicast destination", - srcAddr: utils.RemoteIPv4Addr, - dstAddr: ipv4LinkLocalMulticastAddr, - rx: rxICMPv4EchoRequest, - expectForward: false, - }, - { - name: "IPv4 link-local source", - srcAddr: ipv4LinkLocalUnicastAddr, - dstAddr: utils.RemoteIPv4Addr, - rx: rxICMPv4EchoRequest, - expectForward: false, - }, - { - name: "IPv4 link-local destination", - srcAddr: utils.RemoteIPv4Addr, - dstAddr: ipv4LinkLocalUnicastAddr, - rx: rxICMPv4EchoRequest, - expectForward: false, - }, - { - name: "IPv4 non-link-local unicast", - srcAddr: utils.RemoteIPv4Addr, - dstAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, - rx: rxICMPv4EchoRequest, - expectForward: true, - checker: func(t *testing.T, b []byte) { - forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address) - }, - }, - { - name: "IPv4 non-link-local multicast", - srcAddr: utils.RemoteIPv4Addr, - dstAddr: ipv4GlobalMulticastAddr, - rx: rxICMPv4EchoRequest, - expectForward: true, - checker: func(t *testing.T, b []byte) { - forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr) - }, - }, - - { - name: "IPv6 link-local multicast destination", - srcAddr: utils.RemoteIPv6Addr, - dstAddr: ipv6LinkLocalMulticastAddr, - rx: rxICMPv6EchoRequest, - expectForward: false, - }, - { - name: "IPv6 link-local source", - srcAddr: ipv6LinkLocalUnicastAddr, - dstAddr: utils.RemoteIPv6Addr, - rx: rxICMPv6EchoRequest, - expectForward: false, - }, - { - name: "IPv6 link-local destination", - srcAddr: utils.RemoteIPv6Addr, - dstAddr: ipv6LinkLocalUnicastAddr, - rx: rxICMPv6EchoRequest, - expectForward: false, - }, - { - name: "IPv6 non-link-local unicast", - srcAddr: utils.RemoteIPv6Addr, - dstAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, - rx: rxICMPv6EchoRequest, - expectForward: true, - checker: func(t *testing.T, b []byte) { - forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address) - }, - }, - { - name: "IPv6 non-link-local multicast", - srcAddr: utils.RemoteIPv6Addr, - dstAddr: ipv6GlobalMulticastAddr, - rx: rxICMPv6EchoRequest, - expectForward: true, - checker: func(t *testing.T, b []byte) { - forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr) - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - - e1 := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNIC(nicID1, e1); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID1, err) - } - - e2 := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNIC(nicID2, e2); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID2, err) - } - - protocolAddrV4 := tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: utils.Ipv4Addr, - } - if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err) - } - protocolAddrV6 := tcpip.ProtocolAddress{ - Protocol: ipv6.ProtocolNumber, - AddressWithPrefix: utils.Ipv6Addr, - } - if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err) - } - - if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { - t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err) - } - if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: nicID2, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: nicID2, - }, - }) - - test.rx(e1, test.srcAddr, test.dstAddr) - - p, ok := e2.Read() - if ok != test.expectForward { - t.Fatalf("got e2.Read() = (%#v, %t), want = (_, %t)", p, ok, test.expectForward) - } - - if test.expectForward { - test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader())) - } - }) - } -} - -func TestPerInterfaceForwarding(t *testing.T) { - const ( - nicID1 = 1 - nicID2 = 2 - ) - - tests := []struct { - name string - srcAddr, dstAddr tcpip.Address - rx func(*channel.Endpoint, tcpip.Address, tcpip.Address) - checker func(*testing.T, []byte) - }{ - { - name: "IPv4 unicast", - srcAddr: utils.RemoteIPv4Addr, - dstAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, - rx: rxICMPv4EchoRequest, - checker: func(t *testing.T, b []byte) { - forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address) - }, - }, - { - name: "IPv4 multicast", - srcAddr: utils.RemoteIPv4Addr, - dstAddr: ipv4GlobalMulticastAddr, - rx: rxICMPv4EchoRequest, - checker: func(t *testing.T, b []byte) { - forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr) - }, - }, - - { - name: "IPv6 unicast", - srcAddr: utils.RemoteIPv6Addr, - dstAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, - rx: rxICMPv6EchoRequest, - checker: func(t *testing.T, b []byte) { - forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address) - }, - }, - { - name: "IPv6 multicast", - srcAddr: utils.RemoteIPv6Addr, - dstAddr: ipv6GlobalMulticastAddr, - rx: rxICMPv6EchoRequest, - checker: func(t *testing.T, b []byte) { - forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr) - }, - }, - } - - netProtos := [...]tcpip.NetworkProtocolNumber{ipv4.ProtocolNumber, ipv6.ProtocolNumber} - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ - // ARP is not used in this test but it is a network protocol that does - // not support forwarding. We install the protocol to make sure that - // forwarding information for a NIC is only reported for network - // protocols that support forwarding. - arp.NewProtocol, - - ipv4.NewProtocol, - ipv6.NewProtocol, - }, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - - e1 := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNIC(nicID1, e1); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID1, err) - } - - e2 := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNIC(nicID2, e2); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID2, err) - } - - for _, add := range [...]struct { - nicID tcpip.NICID - addr tcpip.ProtocolAddress - }{ - { - nicID: nicID1, - addr: utils.RouterNIC1IPv4Addr, - }, - { - nicID: nicID1, - addr: utils.RouterNIC1IPv6Addr, - }, - { - nicID: nicID2, - addr: utils.RouterNIC2IPv4Addr, - }, - { - nicID: nicID2, - addr: utils.RouterNIC2IPv6Addr, - }, - } { - if err := s.AddProtocolAddress(add.nicID, add.addr, stack.AddressProperties{}); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", add.nicID, add.addr, err) - } - } - - // Only enable forwarding on NIC1 and make sure that only packets arriving - // on NIC1 are forwarded. - for _, netProto := range netProtos { - if err := s.SetNICForwarding(nicID1, netProto, true); err != nil { - t.Fatalf("s.SetNICForwarding(%d, %d, true): %s", nicID1, netProtos, err) - } - } - - nicsInfo := s.NICInfo() - for _, subTest := range [...]struct { - nicID tcpip.NICID - nicEP *channel.Endpoint - otherNICID tcpip.NICID - otherNICEP *channel.Endpoint - expectForwarding bool - }{ - { - nicID: nicID1, - nicEP: e1, - otherNICID: nicID2, - otherNICEP: e2, - expectForwarding: true, - }, - { - nicID: nicID2, - nicEP: e2, - otherNICID: nicID2, - otherNICEP: e1, - expectForwarding: false, - }, - } { - t.Run(fmt.Sprintf("Packet arriving at NIC%d", subTest.nicID), func(t *testing.T) { - nicInfo, ok := nicsInfo[subTest.nicID] - if !ok { - t.Errorf("expected NIC info for NIC %d; got = %#v", subTest.nicID, nicsInfo) - } else { - forwarding := make(map[tcpip.NetworkProtocolNumber]bool) - for _, netProto := range netProtos { - forwarding[netProto] = subTest.expectForwarding - } - - if diff := cmp.Diff(forwarding, nicInfo.Forwarding); diff != "" { - t.Errorf("nicsInfo[%d].Forwarding mismatch (-want +got):\n%s", subTest.nicID, diff) - } - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: subTest.otherNICID, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: subTest.otherNICID, - }, - }) - - test.rx(subTest.nicEP, test.srcAddr, test.dstAddr) - if p, ok := subTest.nicEP.Read(); ok { - t.Errorf("unexpectedly got a response from the interface the packet arrived on: %#v", p) - } - if p, ok := subTest.otherNICEP.Read(); ok != subTest.expectForwarding { - t.Errorf("got otherNICEP.Read() = (%#v, %t), want = (_, %t)", p, ok, subTest.expectForwarding) - } else if subTest.expectForwarding { - test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader())) - } - }) - } - }) - } -} diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go deleted file mode 100644 index b2383576c..000000000 --- a/pkg/tcpip/tests/integration/iptables_test.go +++ /dev/null @@ -1,2271 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package iptables_test - -import ( - "bytes" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/network/arp" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/tests/utils" - "gvisor.dev/gvisor/pkg/tcpip/testutil" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -type inputIfNameMatcher struct { - name string -} - -var _ stack.Matcher = (*inputIfNameMatcher)(nil) - -func (*inputIfNameMatcher) Name() string { - return "inputIfNameMatcher" -} - -func (im *inputIfNameMatcher) Match(hook stack.Hook, _ *stack.PacketBuffer, inNicName, _ string) (bool, bool) { - return (hook == stack.Input && im.name != "" && im.name == inNicName), false -} - -const ( - nicID = 1 - nicName = "nic1" - anotherNicName = "nic2" - linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") - srcAddrV4 = tcpip.Address("\x0a\x00\x00\x01") - dstAddrV4 = tcpip.Address("\x0a\x00\x00\x02") - srcAddrV6 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - dstAddrV6 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - payloadSize = 20 -) - -func genStackV6(t *testing.T) (*stack.Stack, *channel.Endpoint) { - t.Helper() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - }) - e := channel.New(0, header.IPv6MinimumMTU, linkAddr) - nicOpts := stack.NICOptions{Name: nicName} - if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { - t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err) - } - protocolAddr := tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: dstAddrV6.WithPrefix(), - } - if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) - } - return s, e -} - -func genStackV4(t *testing.T) (*stack.Stack, *channel.Endpoint) { - t.Helper() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - }) - e := channel.New(0, header.IPv4MinimumMTU, linkAddr) - nicOpts := stack.NICOptions{Name: nicName} - if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { - t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err) - } - protocolAddr := tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: dstAddrV4.WithPrefix(), - } - if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) - } - return s, e -} - -func genPacketV6() *stack.PacketBuffer { - pktSize := header.IPv6MinimumSize + payloadSize - hdr := buffer.NewPrependable(pktSize) - ip := header.IPv6(hdr.Prepend(pktSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: payloadSize, - TransportProtocol: 99, - HopLimit: 255, - SrcAddr: srcAddrV6, - DstAddr: dstAddrV6, - }) - vv := hdr.View().ToVectorisedView() - return stack.NewPacketBuffer(stack.PacketBufferOptions{Data: vv}) -} - -func genPacketV4() *stack.PacketBuffer { - pktSize := header.IPv4MinimumSize + payloadSize - hdr := buffer.NewPrependable(pktSize) - ip := header.IPv4(hdr.Prepend(pktSize)) - ip.Encode(&header.IPv4Fields{ - TOS: 0, - TotalLength: uint16(pktSize), - ID: 1, - Flags: 0, - FragmentOffset: 16, - TTL: 48, - Protocol: 99, - SrcAddr: srcAddrV4, - DstAddr: dstAddrV4, - }) - ip.SetChecksum(0) - ip.SetChecksum(^ip.CalculateChecksum()) - vv := hdr.View().ToVectorisedView() - return stack.NewPacketBuffer(stack.PacketBufferOptions{Data: vv}) -} - -func TestIPTablesStatsForInput(t *testing.T) { - tests := []struct { - name string - setupStack func(*testing.T) (*stack.Stack, *channel.Endpoint) - setupFilter func(*testing.T, *stack.Stack) - genPacket func() *stack.PacketBuffer - proto tcpip.NetworkProtocolNumber - expectReceived int - expectInputDropped int - }{ - { - name: "IPv6 Accept", - setupStack: genStackV6, - setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ }, - genPacket: genPacketV6, - proto: header.IPv6ProtocolNumber, - expectReceived: 1, - expectInputDropped: 0, - }, - { - name: "IPv4 Accept", - setupStack: genStackV4, - setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ }, - genPacket: genPacketV4, - proto: header.IPv4ProtocolNumber, - expectReceived: 1, - expectInputDropped: 0, - }, - { - name: "IPv6 Drop (input interface matches)", - setupStack: genStackV6, - setupFilter: func(t *testing.T, s *stack.Stack) { - t.Helper() - ipt := s.IPTables() - filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) - ruleIdx := filter.BuiltinChains[stack.Input] - filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: nicName} - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}} - // Make sure the packet is not dropped by the next rule. - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err) - } - }, - genPacket: genPacketV6, - proto: header.IPv6ProtocolNumber, - expectReceived: 1, - expectInputDropped: 1, - }, - { - name: "IPv4 Drop (input interface matches)", - setupStack: genStackV4, - setupFilter: func(t *testing.T, s *stack.Stack) { - t.Helper() - ipt := s.IPTables() - filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) - ruleIdx := filter.BuiltinChains[stack.Input] - filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: nicName} - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}} - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err) - } - }, - genPacket: genPacketV4, - proto: header.IPv4ProtocolNumber, - expectReceived: 1, - expectInputDropped: 1, - }, - { - name: "IPv6 Accept (input interface does not match)", - setupStack: genStackV6, - setupFilter: func(t *testing.T, s *stack.Stack) { - t.Helper() - ipt := s.IPTables() - filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) - ruleIdx := filter.BuiltinChains[stack.Input] - filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: anotherNicName} - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err) - } - }, - genPacket: genPacketV6, - proto: header.IPv6ProtocolNumber, - expectReceived: 1, - expectInputDropped: 0, - }, - { - name: "IPv4 Accept (input interface does not match)", - setupStack: genStackV4, - setupFilter: func(t *testing.T, s *stack.Stack) { - t.Helper() - ipt := s.IPTables() - filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) - ruleIdx := filter.BuiltinChains[stack.Input] - filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: anotherNicName} - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err) - } - }, - genPacket: genPacketV4, - proto: header.IPv4ProtocolNumber, - expectReceived: 1, - expectInputDropped: 0, - }, - { - name: "IPv6 Drop (input interface does not match but invert is true)", - setupStack: genStackV6, - setupFilter: func(t *testing.T, s *stack.Stack) { - t.Helper() - ipt := s.IPTables() - filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) - ruleIdx := filter.BuiltinChains[stack.Input] - filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{ - InputInterface: anotherNicName, - InputInterfaceInvert: true, - } - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err) - } - }, - genPacket: genPacketV6, - proto: header.IPv6ProtocolNumber, - expectReceived: 1, - expectInputDropped: 1, - }, - { - name: "IPv4 Drop (input interface does not match but invert is true)", - setupStack: genStackV4, - setupFilter: func(t *testing.T, s *stack.Stack) { - t.Helper() - ipt := s.IPTables() - filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) - ruleIdx := filter.BuiltinChains[stack.Input] - filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{ - InputInterface: anotherNicName, - InputInterfaceInvert: true, - } - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err) - } - }, - genPacket: genPacketV4, - proto: header.IPv4ProtocolNumber, - expectReceived: 1, - expectInputDropped: 1, - }, - { - name: "IPv6 Accept (input interface does not match using a matcher)", - setupStack: genStackV6, - setupFilter: func(t *testing.T, s *stack.Stack) { - t.Helper() - ipt := s.IPTables() - filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) - ruleIdx := filter.BuiltinChains[stack.Input] - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}} - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err) - } - }, - genPacket: genPacketV6, - proto: header.IPv6ProtocolNumber, - expectReceived: 1, - expectInputDropped: 0, - }, - { - name: "IPv4 Accept (input interface does not match using a matcher)", - setupStack: genStackV4, - setupFilter: func(t *testing.T, s *stack.Stack) { - t.Helper() - ipt := s.IPTables() - filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) - ruleIdx := filter.BuiltinChains[stack.Input] - filter.Rules[ruleIdx].Target = &stack.DropTarget{} - filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}} - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err) - } - }, - genPacket: genPacketV4, - proto: header.IPv4ProtocolNumber, - expectReceived: 1, - expectInputDropped: 0, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s, e := test.setupStack(t) - test.setupFilter(t, s) - e.InjectInbound(test.proto, test.genPacket()) - - if got := int(s.Stats().IP.PacketsReceived.Value()); got != test.expectReceived { - t.Errorf("got PacketReceived = %d, want = %d", got, test.expectReceived) - } - if got := int(s.Stats().IP.IPTablesInputDropped.Value()); got != test.expectInputDropped { - t.Errorf("got IPTablesInputDropped = %d, want = %d", got, test.expectInputDropped) - } - }) - } -} - -var _ stack.LinkEndpoint = (*channelEndpointWithoutWritePacket)(nil) - -// channelEndpointWithoutWritePacket is a channel endpoint that does not support -// stack.LinkEndpoint.WritePacket. -type channelEndpointWithoutWritePacket struct { - *channel.Endpoint - - t *testing.T -} - -func (c *channelEndpointWithoutWritePacket) WritePacket(stack.RouteInfo, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { - c.t.Error("unexpectedly called WritePacket; all writes should go through WritePackets") - return &tcpip.ErrNotSupported{} -} - -var _ stack.Matcher = (*udpSourcePortMatcher)(nil) - -type udpSourcePortMatcher struct { - port uint16 -} - -func (*udpSourcePortMatcher) Name() string { - return "udpSourcePortMatcher" -} - -func (m *udpSourcePortMatcher) Match(_ stack.Hook, pkt *stack.PacketBuffer, _, _ string) (matches, hotdrop bool) { - udp := header.UDP(pkt.TransportHeader().View()) - if len(udp) < header.UDPMinimumSize { - // Drop immediately as the packet is invalid. - return false, true - } - - return udp.SourcePort() == m.port, false -} - -func TestIPTableWritePackets(t *testing.T) { - const ( - nicID = 1 - - dropLocalPort = utils.LocalPort - 1 - acceptPackets = 2 - dropPackets = 3 - ) - - udpHdr := func(hdr buffer.View, srcAddr, dstAddr tcpip.Address, srcPort, dstPort uint16) { - u := header.UDP(hdr) - u.Encode(&header.UDPFields{ - SrcPort: srcPort, - DstPort: dstPort, - Length: header.UDPMinimumSize, - }) - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, srcAddr, dstAddr, header.UDPMinimumSize) - sum = header.Checksum(hdr, sum) - u.SetChecksum(^u.CalculateChecksum(sum)) - } - - tests := []struct { - name string - setupFilter func(*testing.T, *stack.Stack) - genPacket func(*stack.Route) stack.PacketBufferList - proto tcpip.NetworkProtocolNumber - remoteAddr tcpip.Address - expectSent uint64 - expectOutputDropped uint64 - }{ - { - name: "IPv4 Accept", - setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ }, - genPacket: func(r *stack.Route) stack.PacketBufferList { - var pkts stack.PacketBufferList - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), - }) - hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) - udpHdr(hdr, r.LocalAddress(), r.RemoteAddress(), utils.LocalPort, utils.RemotePort) - pkts.PushFront(pkt) - - return pkts - }, - proto: header.IPv4ProtocolNumber, - remoteAddr: dstAddrV4, - expectSent: 1, - expectOutputDropped: 0, - }, - { - name: "IPv4 Drop Other Port", - setupFilter: func(t *testing.T, s *stack.Stack) { - t.Helper() - - table := stack.Table{ - Rules: []stack.Rule{ - { - Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}, - }, - { - Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}, - }, - { - Matchers: []stack.Matcher{&udpSourcePortMatcher{port: dropLocalPort}}, - Target: &stack.DropTarget{NetworkProtocol: header.IPv4ProtocolNumber}, - }, - { - Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}, - }, - { - Target: &stack.ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}, - }, - }, - BuiltinChains: [stack.NumHooks]int{ - stack.Prerouting: stack.HookUnset, - stack.Input: 0, - stack.Forward: 1, - stack.Output: 2, - stack.Postrouting: stack.HookUnset, - }, - Underflows: [stack.NumHooks]int{ - stack.Prerouting: stack.HookUnset, - stack.Input: 0, - stack.Forward: 1, - stack.Output: 2, - stack.Postrouting: stack.HookUnset, - }, - } - - if err := s.IPTables().ReplaceTable(stack.FilterID, table, false /* ipv4 */); err != nil { - t.Fatalf("ReplaceTable(%d, _, false): %s", stack.FilterID, err) - } - }, - genPacket: func(r *stack.Route) stack.PacketBufferList { - var pkts stack.PacketBufferList - - for i := 0; i < acceptPackets; i++ { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), - }) - hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) - udpHdr(hdr, r.LocalAddress(), r.RemoteAddress(), utils.LocalPort, utils.RemotePort) - pkts.PushFront(pkt) - } - for i := 0; i < dropPackets; i++ { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), - }) - hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) - udpHdr(hdr, r.LocalAddress(), r.RemoteAddress(), dropLocalPort, utils.RemotePort) - pkts.PushFront(pkt) - } - - return pkts - }, - proto: header.IPv4ProtocolNumber, - remoteAddr: dstAddrV4, - expectSent: acceptPackets, - expectOutputDropped: dropPackets, - }, - { - name: "IPv6 Accept", - setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ }, - genPacket: func(r *stack.Route) stack.PacketBufferList { - var pkts stack.PacketBufferList - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), - }) - hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) - udpHdr(hdr, r.LocalAddress(), r.RemoteAddress(), utils.LocalPort, utils.RemotePort) - pkts.PushFront(pkt) - - return pkts - }, - proto: header.IPv6ProtocolNumber, - remoteAddr: dstAddrV6, - expectSent: 1, - expectOutputDropped: 0, - }, - { - name: "IPv6 Drop Other Port", - setupFilter: func(t *testing.T, s *stack.Stack) { - t.Helper() - - table := stack.Table{ - Rules: []stack.Rule{ - { - Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}, - }, - { - Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}, - }, - { - Matchers: []stack.Matcher{&udpSourcePortMatcher{port: dropLocalPort}}, - Target: &stack.DropTarget{NetworkProtocol: header.IPv6ProtocolNumber}, - }, - { - Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}, - }, - { - Target: &stack.ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}, - }, - }, - BuiltinChains: [stack.NumHooks]int{ - stack.Prerouting: stack.HookUnset, - stack.Input: 0, - stack.Forward: 1, - stack.Output: 2, - stack.Postrouting: stack.HookUnset, - }, - Underflows: [stack.NumHooks]int{ - stack.Prerouting: stack.HookUnset, - stack.Input: 0, - stack.Forward: 1, - stack.Output: 2, - stack.Postrouting: stack.HookUnset, - }, - } - - if err := s.IPTables().ReplaceTable(stack.FilterID, table, true /* ipv6 */); err != nil { - t.Fatalf("ReplaceTable(%d, _, true): %s", stack.FilterID, err) - } - }, - genPacket: func(r *stack.Route) stack.PacketBufferList { - var pkts stack.PacketBufferList - - for i := 0; i < acceptPackets; i++ { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), - }) - hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) - udpHdr(hdr, r.LocalAddress(), r.RemoteAddress(), utils.LocalPort, utils.RemotePort) - pkts.PushFront(pkt) - } - for i := 0; i < dropPackets; i++ { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), - }) - hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) - udpHdr(hdr, r.LocalAddress(), r.RemoteAddress(), dropLocalPort, utils.RemotePort) - pkts.PushFront(pkt) - } - - return pkts - }, - proto: header.IPv6ProtocolNumber, - remoteAddr: dstAddrV6, - expectSent: acceptPackets, - expectOutputDropped: dropPackets, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - e := channelEndpointWithoutWritePacket{ - Endpoint: channel.New(4, header.IPv6MinimumMTU, linkAddr), - t: t, - } - if err := s.CreateNIC(nicID, &e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - protocolAddrV6 := tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: srcAddrV6.WithPrefix(), - } - if err := s.AddProtocolAddress(nicID, protocolAddrV6, stack.AddressProperties{}); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV6, err) - } - protocolAddrV4 := tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: srcAddrV4.WithPrefix(), - } - if err := s.AddProtocolAddress(nicID, protocolAddrV4, stack.AddressProperties{}); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV4, err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: nicID, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: nicID, - }, - }) - - test.setupFilter(t, s) - - r, err := s.FindRoute(nicID, "", test.remoteAddr, test.proto, false) - if err != nil { - t.Fatalf("FindRoute(%d, '', %s, %d, false): %s", nicID, test.remoteAddr, test.proto, err) - } - defer r.Release() - - pkts := test.genPacket(r) - pktsLen := pkts.Len() - if n, err := r.WritePackets(pkts, stack.NetworkHeaderParams{ - Protocol: header.UDPProtocolNumber, - TTL: 64, - }); err != nil { - t.Fatalf("WritePackets(...): %s", err) - } else if n != pktsLen { - t.Fatalf("got WritePackets(...) = %d, want = %d", n, pktsLen) - } - - if got := s.Stats().IP.PacketsSent.Value(); got != test.expectSent { - t.Errorf("got PacketSent = %d, want = %d", got, test.expectSent) - } - if got := s.Stats().IP.IPTablesOutputDropped.Value(); got != test.expectOutputDropped { - t.Errorf("got IPTablesOutputDropped = %d, want = %d", got, test.expectOutputDropped) - } - }) - } -} - -const ttl = 64 - -var ( - ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10") - ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a") -) - -func rxICMPv4EchoReply(e *channel.Endpoint, src, dst tcpip.Address) { - utils.RxICMPv4EchoReply(e, src, dst, ttl) -} - -func rxICMPv6EchoReply(e *channel.Endpoint, src, dst tcpip.Address) { - utils.RxICMPv6EchoReply(e, src, dst, ttl) -} - -func forwardedICMPv4EchoReplyChecker(t *testing.T, b []byte, src, dst tcpip.Address) { - checker.IPv4(t, b, - checker.SrcAddr(src), - checker.DstAddr(dst), - checker.TTL(ttl-1), - checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4EchoReply))) -} - -func forwardedICMPv6EchoReplyChecker(t *testing.T, b []byte, src, dst tcpip.Address) { - checker.IPv6(t, b, - checker.SrcAddr(src), - checker.DstAddr(dst), - checker.TTL(ttl-1), - checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6EchoReply))) -} - -func boolToInt(v bool) uint64 { - if v { - return 1 - } - return 0 -} - -func setupDropFilter(hook stack.Hook, f stack.IPHeaderFilter) func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { - return func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber) { - t.Helper() - - ipv6 := netProto == ipv6.ProtocolNumber - - ipt := s.IPTables() - filter := ipt.GetTable(stack.FilterID, ipv6) - ruleIdx := filter.BuiltinChains[hook] - filter.Rules[ruleIdx].Filter = f - filter.Rules[ruleIdx].Target = &stack.DropTarget{NetworkProtocol: netProto} - // Make sure the packet is not dropped by the next rule. - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{NetworkProtocol: netProto} - if err := ipt.ReplaceTable(stack.FilterID, filter, ipv6); err != nil { - t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, ipv6, err) - } - } -} - -func TestForwardingHook(t *testing.T) { - const ( - nicID1 = 1 - nicID2 = 2 - - nic1Name = "nic1" - nic2Name = "nic2" - - otherNICName = "otherNIC" - ) - - tests := []struct { - name string - netProto tcpip.NetworkProtocolNumber - local bool - srcAddr, dstAddr tcpip.Address - rx func(*channel.Endpoint, tcpip.Address, tcpip.Address) - checker func(*testing.T, []byte) - }{ - { - name: "IPv4 remote", - netProto: ipv4.ProtocolNumber, - local: false, - srcAddr: utils.RemoteIPv4Addr, - dstAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, - rx: rxICMPv4EchoReply, - checker: func(t *testing.T, b []byte) { - forwardedICMPv4EchoReplyChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address) - }, - }, - { - name: "IPv4 local", - netProto: ipv4.ProtocolNumber, - local: true, - srcAddr: utils.RemoteIPv4Addr, - dstAddr: utils.Ipv4Addr.Address, - rx: rxICMPv4EchoReply, - }, - { - name: "IPv6 remote", - netProto: ipv6.ProtocolNumber, - local: false, - srcAddr: utils.RemoteIPv6Addr, - dstAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, - rx: rxICMPv6EchoReply, - checker: func(t *testing.T, b []byte) { - forwardedICMPv6EchoReplyChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address) - }, - }, - { - name: "IPv6 local", - netProto: ipv6.ProtocolNumber, - local: true, - srcAddr: utils.RemoteIPv6Addr, - dstAddr: utils.Ipv6Addr.Address, - rx: rxICMPv6EchoReply, - }, - } - - subTests := []struct { - name string - setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) - expectForward bool - }{ - { - name: "Accept", - setupFilter: func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { /* no filter */ }, - expectForward: true, - }, - - { - name: "Drop", - setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{}), - expectForward: false, - }, - { - name: "Drop with input NIC filtering", - setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name}), - expectForward: false, - }, - { - name: "Drop with output NIC filtering", - setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: nic2Name}), - expectForward: false, - }, - { - name: "Drop with input and output NIC filtering", - setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: nic2Name}), - expectForward: false, - }, - - { - name: "Drop with other input NIC filtering", - setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName}), - expectForward: true, - }, - { - name: "Drop with other output NIC filtering", - setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: otherNICName}), - expectForward: true, - }, - { - name: "Drop with other input and output NIC filtering", - setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: nic2Name}), - expectForward: true, - }, - { - name: "Drop with input and other output NIC filtering", - setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: otherNICName}), - expectForward: true, - }, - { - name: "Drop with other input and other output NIC filtering", - setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: otherNICName}), - expectForward: true, - }, - - { - name: "Drop with inverted input NIC filtering", - setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, InputInterfaceInvert: true}), - expectForward: true, - }, - { - name: "Drop with inverted output NIC filtering", - setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: nic2Name, OutputInterfaceInvert: true}), - expectForward: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - for _, subTest := range subTests { - t.Run(subTest.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - }) - - subTest.setupFilter(t, s, test.netProto) - - e1 := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil { - t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err) - } - - e2 := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil { - t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err) - } - - protocolAddrV4 := tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: utils.Ipv4Addr.Address.WithPrefix(), - } - if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err) - } - protocolAddrV6 := tcpip.ProtocolAddress{ - Protocol: ipv6.ProtocolNumber, - AddressWithPrefix: utils.Ipv6Addr.Address.WithPrefix(), - } - if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err) - } - - if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { - t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err) - } - if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: nicID2, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: nicID2, - }, - }) - - test.rx(e1, test.srcAddr, test.dstAddr) - - expectTransmitPacket := subTest.expectForward && !test.local - - ep1, err := s.GetNetworkEndpoint(nicID1, test.netProto) - if err != nil { - t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, test.netProto, err) - } - ep1Stats := ep1.Stats() - ipEP1Stats, ok := ep1Stats.(stack.IPNetworkEndpointStats) - if !ok { - t.Fatalf("got ep1Stats = %T, want = stack.IPNetworkEndpointStats", ep1Stats) - } - ip1Stats := ipEP1Stats.IPStats() - - if got := ip1Stats.PacketsReceived.Value(); got != 1 { - t.Errorf("got ip1Stats.PacketsReceived.Value() = %d, want = 1", got) - } - if got := ip1Stats.ValidPacketsReceived.Value(); got != 1 { - t.Errorf("got ip1Stats.ValidPacketsReceived.Value() = %d, want = 1", got) - } - if got, want := ip1Stats.IPTablesForwardDropped.Value(), boolToInt(!subTest.expectForward); got != want { - t.Errorf("got ip1Stats.IPTablesForwardDropped.Value() = %d, want = %d", got, want) - } - if got := ip1Stats.PacketsSent.Value(); got != 0 { - t.Errorf("got ip1Stats.PacketsSent.Value() = %d, want = 0", got) - } - - ep2, err := s.GetNetworkEndpoint(nicID2, test.netProto) - if err != nil { - t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID2, test.netProto, err) - } - ep2Stats := ep2.Stats() - ipEP2Stats, ok := ep2Stats.(stack.IPNetworkEndpointStats) - if !ok { - t.Fatalf("got ep2Stats = %T, want = stack.IPNetworkEndpointStats", ep2Stats) - } - ip2Stats := ipEP2Stats.IPStats() - if got := ip2Stats.PacketsReceived.Value(); got != 0 { - t.Errorf("got ip2Stats.PacketsReceived.Value() = %d, want = 0", got) - } - if got, want := ip2Stats.ValidPacketsReceived.Value(), boolToInt(subTest.expectForward && test.local); got != want { - t.Errorf("got ip2Stats.ValidPacketsReceived.Value() = %d, want = %d", got, want) - } - if got, want := ip2Stats.PacketsSent.Value(), boolToInt(expectTransmitPacket); got != want { - t.Errorf("got ip2Stats.PacketsSent.Value() = %d, want = %d", got, want) - } - - p, ok := e2.Read() - if ok != expectTransmitPacket { - t.Fatalf("got e2.Read() = (%#v, %t), want = (_, %t)", p, ok, expectTransmitPacket) - } - if expectTransmitPacket { - test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader())) - } - }) - } - }) - } -} - -func TestInputHookWithLocalForwarding(t *testing.T) { - const ( - nicID1 = 1 - nicID2 = 2 - - nic1Name = "nic1" - nic2Name = "nic2" - - otherNICName = "otherNIC" - ) - - tests := []struct { - name string - netProto tcpip.NetworkProtocolNumber - rx func(*channel.Endpoint) - checker func(*testing.T, []byte) - }{ - { - name: "IPv4", - netProto: ipv4.ProtocolNumber, - rx: func(e *channel.Endpoint) { - utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address, ttl) - }, - checker: func(t *testing.T, b []byte) { - checker.IPv4(t, b, - checker.SrcAddr(utils.Ipv4Addr2.AddressWithPrefix.Address), - checker.DstAddr(utils.RemoteIPv4Addr), - checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4EchoReply))) - }, - }, - { - name: "IPv6", - netProto: ipv6.ProtocolNumber, - rx: func(e *channel.Endpoint) { - utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address, ttl) - }, - checker: func(t *testing.T, b []byte) { - checker.IPv6(t, b, - checker.SrcAddr(utils.Ipv6Addr2.AddressWithPrefix.Address), - checker.DstAddr(utils.RemoteIPv6Addr), - checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6EchoReply))) - }, - }, - } - - subTests := []struct { - name string - setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) - expectDrop bool - }{ - { - name: "Accept", - setupFilter: func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { /* no filter */ }, - expectDrop: false, - }, - - { - name: "Drop", - setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{}), - expectDrop: true, - }, - { - name: "Drop with input NIC filtering on arrival NIC", - setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: nic1Name}), - expectDrop: true, - }, - { - name: "Drop with input NIC filtering on delivered NIC", - setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: nic2Name}), - expectDrop: false, - }, - - { - name: "Drop with input NIC filtering on other NIC", - setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: otherNICName}), - expectDrop: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - for _, subTest := range subTests { - t.Run(subTest.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - }) - - subTest.setupFilter(t, s, test.netProto) - - e1 := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil { - t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err) - } - if err := s.AddProtocolAddress(nicID1, utils.Ipv4Addr1, stack.AddressProperties{}); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID1, utils.Ipv4Addr1, err) - } - if err := s.AddProtocolAddress(nicID1, utils.Ipv6Addr1, stack.AddressProperties{}); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID1, utils.Ipv6Addr1, err) - } - - e2 := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil { - t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err) - } - if err := s.AddProtocolAddress(nicID2, utils.Ipv4Addr2, stack.AddressProperties{}); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID2, utils.Ipv4Addr2, err) - } - if err := s.AddProtocolAddress(nicID2, utils.Ipv6Addr2, stack.AddressProperties{}); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID2, utils.Ipv6Addr2, err) - } - - if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { - t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err) - } - if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: nicID1, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: nicID1, - }, - }) - - test.rx(e1) - - ep1, err := s.GetNetworkEndpoint(nicID1, test.netProto) - if err != nil { - t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, test.netProto, err) - } - ep1Stats := ep1.Stats() - ipEP1Stats, ok := ep1Stats.(stack.IPNetworkEndpointStats) - if !ok { - t.Fatalf("got ep1Stats = %T, want = stack.IPNetworkEndpointStats", ep1Stats) - } - ip1Stats := ipEP1Stats.IPStats() - - if got := ip1Stats.PacketsReceived.Value(); got != 1 { - t.Errorf("got ip1Stats.PacketsReceived.Value() = %d, want = 1", got) - } - if got := ip1Stats.ValidPacketsReceived.Value(); got != 1 { - t.Errorf("got ip1Stats.ValidPacketsReceived.Value() = %d, want = 1", got) - } - if got, want := ip1Stats.PacketsSent.Value(), boolToInt(!subTest.expectDrop); got != want { - t.Errorf("got ip1Stats.PacketsSent.Value() = %d, want = %d", got, want) - } - - ep2, err := s.GetNetworkEndpoint(nicID2, test.netProto) - if err != nil { - t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID2, test.netProto, err) - } - ep2Stats := ep2.Stats() - ipEP2Stats, ok := ep2Stats.(stack.IPNetworkEndpointStats) - if !ok { - t.Fatalf("got ep2Stats = %T, want = stack.IPNetworkEndpointStats", ep2Stats) - } - ip2Stats := ipEP2Stats.IPStats() - if got := ip2Stats.PacketsReceived.Value(); got != 0 { - t.Errorf("got ip2Stats.PacketsReceived.Value() = %d, want = 0", got) - } - if got := ip2Stats.ValidPacketsReceived.Value(); got != 1 { - t.Errorf("got ip2Stats.ValidPacketsReceived.Value() = %d, want = 1", got) - } - if got, want := ip2Stats.IPTablesInputDropped.Value(), boolToInt(subTest.expectDrop); got != want { - t.Errorf("got ip2Stats.IPTablesInputDropped.Value() = %d, want = %d", got, want) - } - if got := ip2Stats.PacketsSent.Value(); got != 0 { - t.Errorf("got ip2Stats.PacketsSent.Value() = %d, want = 0", got) - } - - if p, ok := e1.Read(); ok == subTest.expectDrop { - t.Errorf("got e1.Read() = (%#v, %t), want = (_, %t)", p, ok, !subTest.expectDrop) - } else if !subTest.expectDrop { - test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader())) - } - if p, ok := e2.Read(); ok { - t.Errorf("got e1.Read() = (%#v, true), want = (_, false)", p) - } - }) - } - }) - } -} - -func TestNAT(t *testing.T) { - const listenPort uint16 = 8080 - - type endpointAndAddresses struct { - serverEP tcpip.Endpoint - serverAddr tcpip.FullAddress - serverReadableCH chan struct{} - serverConnectAddr tcpip.Address - - clientEP tcpip.Endpoint - clientAddr tcpip.Address - clientReadableCH chan struct{} - clientConnectAddr tcpip.FullAddress - } - - newEP := func(t *testing.T, s *stack.Stack, transProto tcpip.TransportProtocolNumber, netProto tcpip.NetworkProtocolNumber) (tcpip.Endpoint, chan struct{}) { - t.Helper() - var wq waiter.Queue - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - t.Cleanup(func() { - wq.EventUnregister(&we) - }) - - ep, err := s.NewEndpoint(transProto, netProto, &wq) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d, _): %s", transProto, netProto, err) - } - t.Cleanup(ep.Close) - - return ep, ch - } - - setupNAT := func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, hook stack.Hook, filter stack.IPHeaderFilter, target stack.Target) { - t.Helper() - - ipv6 := netProto == ipv6.ProtocolNumber - ipt := s.IPTables() - table := ipt.GetTable(stack.NATID, ipv6) - ruleIdx := table.BuiltinChains[hook] - table.Rules[ruleIdx].Filter = filter - table.Rules[ruleIdx].Target = target - // Make sure the packet is not dropped by the next rule. - table.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.NATID, table, ipv6); err != nil { - t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err) - } - } - - setupDNAT := func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, target stack.Target) { - t.Helper() - - setupNAT( - t, - s, - netProto, - stack.Prerouting, - stack.IPHeaderFilter{ - Protocol: transProto, - CheckProtocol: true, - InputInterface: utils.RouterNIC2Name, - }, - target) - } - - setupSNAT := func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, target stack.Target) { - t.Helper() - - setupNAT( - t, - s, - netProto, - stack.Postrouting, - stack.IPHeaderFilter{ - Protocol: transProto, - CheckProtocol: true, - OutputInterface: utils.RouterNIC1Name, - }, - target) - } - - type natType struct { - name string - setupNAT func(_ *testing.T, _ *stack.Stack, _ tcpip.NetworkProtocolNumber, _ tcpip.TransportProtocolNumber, snatAddr, dnatAddr tcpip.Address) - } - - snatTypes := []natType{ - { - name: "SNAT", - setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, snatAddr, _ tcpip.Address) { - t.Helper() - - setupSNAT(t, s, netProto, transProto, &stack.SNATTarget{NetworkProtocol: netProto, Addr: snatAddr}) - }, - }, - { - name: "Masquerade", - setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, _, _ tcpip.Address) { - t.Helper() - - setupSNAT(t, s, netProto, transProto, &stack.MasqueradeTarget{NetworkProtocol: netProto}) - }, - }, - } - dnatTypes := []natType{ - { - name: "Redirect", - setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, _, _ tcpip.Address) { - t.Helper() - - setupDNAT(t, s, netProto, transProto, &stack.RedirectTarget{NetworkProtocol: netProto, Port: listenPort}) - }, - }, - { - name: "DNAT", - setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, _, dnatAddr tcpip.Address) { - t.Helper() - - setupDNAT(t, s, netProto, transProto, &stack.DNATTarget{NetworkProtocol: netProto, Addr: dnatAddr, Port: listenPort}) - }, - }, - } - - setupTwiceNAT := func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, dnatAddr tcpip.Address, snatTarget stack.Target) { - t.Helper() - - ipv6 := netProto == ipv6.ProtocolNumber - ipt := s.IPTables() - - table := stack.Table{ - Rules: []stack.Rule{ - // Prerouting - { - Filter: stack.IPHeaderFilter{ - Protocol: transProto, - CheckProtocol: true, - InputInterface: utils.RouterNIC2Name, - }, - Target: &stack.DNATTarget{NetworkProtocol: netProto, Addr: dnatAddr, Port: listenPort}, - }, - { - Target: &stack.AcceptTarget{}, - }, - - // Input - { - Target: &stack.AcceptTarget{}, - }, - - // Forward - { - Target: &stack.AcceptTarget{}, - }, - - // Output - { - Target: &stack.AcceptTarget{}, - }, - - // Postrouting - { - Filter: stack.IPHeaderFilter{ - Protocol: transProto, - CheckProtocol: true, - OutputInterface: utils.RouterNIC1Name, - }, - Target: snatTarget, - }, - { - Target: &stack.AcceptTarget{}, - }, - }, - BuiltinChains: [stack.NumHooks]int{ - stack.Prerouting: 0, - stack.Input: 2, - stack.Forward: 3, - stack.Output: 4, - stack.Postrouting: 5, - }, - } - - if err := ipt.ReplaceTable(stack.NATID, table, ipv6); err != nil { - t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err) - } - } - twiceNATTypes := []natType{ - { - name: "DNAT-Masquerade", - setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, snatAddr, dnatAddr tcpip.Address) { - t.Helper() - - setupTwiceNAT(t, s, netProto, transProto, dnatAddr, &stack.MasqueradeTarget{NetworkProtocol: netProto}) - }, - }, - { - name: "DNAT-SNAT", - setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, snatAddr, dnatAddr tcpip.Address) { - t.Helper() - - setupTwiceNAT(t, s, netProto, transProto, dnatAddr, &stack.SNATTarget{NetworkProtocol: netProto, Addr: snatAddr}) - }, - }, - } - - tests := []struct { - name string - netProto tcpip.NetworkProtocolNumber - // Setups up the stacks in such a way that: - // - // - Host2 is the client for all tests. - // - When performing SNAT only: - // + Host1 is the server. - // + NAT will transform client-originating packets' source addresses to - // the router's NIC1's address before reaching Host1. - // - When performing DNAT only: - // + Router is the server. - // + Client will send packets directed to Host1. - // + NAT will transform client-originating packets' destination addresses - // to the router's NIC2's address. - // - When performing Twice-NAT: - // + Host1 is the server. - // + Client will send packets directed to router's NIC2. - // + NAT will transform client originating packets' destination addresses - // to Host1's address. - // + NAT will transform client-originating packets' source addresses to - // the router's NIC1's address before reaching Host1. - epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses - natTypes []natType - }{ - { - name: "IPv4 SNAT", - netProto: ipv4.ProtocolNumber, - epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { - t.Helper() - - listenerStack := host1Stack - serverAddr := tcpip.FullAddress{ - Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address, - Port: listenPort, - } - serverConnectAddr := utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address - clientConnectPort := serverAddr.Port - ep1, ep1WECH := newEP(t, listenerStack, proto, ipv4.ProtocolNumber) - ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber) - return endpointAndAddresses{ - serverEP: ep1, - serverAddr: serverAddr, - serverReadableCH: ep1WECH, - serverConnectAddr: serverConnectAddr, - - clientEP: ep2, - clientAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address, - clientReadableCH: ep2WECH, - clientConnectAddr: tcpip.FullAddress{ - Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address, - Port: clientConnectPort, - }, - } - }, - natTypes: snatTypes, - }, - { - name: "IPv4 DNAT", - netProto: ipv4.ProtocolNumber, - epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { - t.Helper() - - // If we are performing DNAT, then the packet will be redirected - // to the router. - listenerStack := routerStack - serverAddr := tcpip.FullAddress{ - Addr: utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address, - Port: listenPort, - } - serverConnectAddr := utils.Host2IPv4Addr.AddressWithPrefix.Address - // DNAT will update the destination port to what the server is - // bound to. - clientConnectPort := serverAddr.Port + 1 - ep1, ep1WECH := newEP(t, listenerStack, proto, ipv4.ProtocolNumber) - ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber) - return endpointAndAddresses{ - serverEP: ep1, - serverAddr: serverAddr, - serverReadableCH: ep1WECH, - serverConnectAddr: serverConnectAddr, - - clientEP: ep2, - clientAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address, - clientReadableCH: ep2WECH, - clientConnectAddr: tcpip.FullAddress{ - Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address, - Port: clientConnectPort, - }, - } - }, - natTypes: dnatTypes, - }, - { - name: "IPv4 Twice-NAT", - netProto: ipv4.ProtocolNumber, - epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { - t.Helper() - - listenerStack := host1Stack - serverAddr := tcpip.FullAddress{ - Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address, - Port: listenPort, - } - serverConnectAddr := utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address - clientConnectPort := serverAddr.Port - ep1, ep1WECH := newEP(t, listenerStack, proto, ipv4.ProtocolNumber) - ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber) - return endpointAndAddresses{ - serverEP: ep1, - serverAddr: serverAddr, - serverReadableCH: ep1WECH, - serverConnectAddr: serverConnectAddr, - - clientEP: ep2, - clientAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address, - clientReadableCH: ep2WECH, - clientConnectAddr: tcpip.FullAddress{ - Addr: utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address, - Port: clientConnectPort, - }, - } - }, - natTypes: twiceNATTypes, - }, - { - name: "IPv6 SNAT", - netProto: ipv6.ProtocolNumber, - epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { - t.Helper() - - listenerStack := host1Stack - serverAddr := tcpip.FullAddress{ - Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address, - Port: listenPort, - } - serverConnectAddr := utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address - clientConnectPort := serverAddr.Port - ep1, ep1WECH := newEP(t, listenerStack, proto, ipv6.ProtocolNumber) - ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber) - return endpointAndAddresses{ - serverEP: ep1, - serverAddr: serverAddr, - serverReadableCH: ep1WECH, - serverConnectAddr: serverConnectAddr, - - clientEP: ep2, - clientAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address, - clientReadableCH: ep2WECH, - clientConnectAddr: tcpip.FullAddress{ - Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address, - Port: clientConnectPort, - }, - } - }, - natTypes: snatTypes, - }, - { - name: "IPv6 DNAT", - netProto: ipv6.ProtocolNumber, - epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { - t.Helper() - - // If we are performing DNAT, then the packet will be redirected - // to the router. - listenerStack := routerStack - serverAddr := tcpip.FullAddress{ - Addr: utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address, - Port: listenPort, - } - serverConnectAddr := utils.Host2IPv6Addr.AddressWithPrefix.Address - // DNAT will update the destination port to what the server is - // bound to. - clientConnectPort := serverAddr.Port + 1 - ep1, ep1WECH := newEP(t, listenerStack, proto, ipv6.ProtocolNumber) - ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber) - return endpointAndAddresses{ - serverEP: ep1, - serverAddr: serverAddr, - serverReadableCH: ep1WECH, - serverConnectAddr: serverConnectAddr, - - clientEP: ep2, - clientAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address, - clientReadableCH: ep2WECH, - clientConnectAddr: tcpip.FullAddress{ - Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address, - Port: clientConnectPort, - }, - } - }, - natTypes: dnatTypes, - }, - { - name: "IPv6 Twice-NAT", - netProto: ipv6.ProtocolNumber, - epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { - t.Helper() - - listenerStack := host1Stack - serverAddr := tcpip.FullAddress{ - Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address, - Port: listenPort, - } - serverConnectAddr := utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address - clientConnectPort := serverAddr.Port - ep1, ep1WECH := newEP(t, listenerStack, proto, ipv6.ProtocolNumber) - ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber) - return endpointAndAddresses{ - serverEP: ep1, - serverAddr: serverAddr, - serverReadableCH: ep1WECH, - serverConnectAddr: serverConnectAddr, - - clientEP: ep2, - clientAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address, - clientReadableCH: ep2WECH, - clientConnectAddr: tcpip.FullAddress{ - Addr: utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address, - Port: clientConnectPort, - }, - } - }, - natTypes: twiceNATTypes, - }, - } - - subTests := []struct { - name string - proto tcpip.TransportProtocolNumber - expectedConnectErr tcpip.Error - setupServer func(t *testing.T, ep tcpip.Endpoint) - setupServerConn func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) - needRemoteAddr bool - }{ - { - name: "UDP", - proto: udp.ProtocolNumber, - expectedConnectErr: nil, - setupServerConn: func(t *testing.T, ep tcpip.Endpoint, _ <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { - t.Helper() - - if err := ep.Connect(clientAddr); err != nil { - t.Fatalf("ep.Connect(%#v): %s", clientAddr, err) - } - return nil, nil - }, - needRemoteAddr: true, - }, - { - name: "TCP", - proto: tcp.ProtocolNumber, - expectedConnectErr: &tcpip.ErrConnectStarted{}, - setupServer: func(t *testing.T, ep tcpip.Endpoint) { - t.Helper() - - if err := ep.Listen(1); err != nil { - t.Fatalf("ep.Listen(1): %s", err) - } - }, - setupServerConn: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { - t.Helper() - - var addr tcpip.FullAddress - for { - newEP, wq, err := ep.Accept(&addr) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - <-ch - continue - } - if err != nil { - t.Fatalf("ep.Accept(_): %s", err) - } - if diff := cmp.Diff(clientAddr, addr, checker.IgnoreCmpPath( - "NIC", - )); diff != "" { - t.Errorf("accepted address mismatch (-want +got):\n%s", diff) - } - - we, newCH := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - return newEP, newCH - } - }, - needRemoteAddr: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - for _, subTest := range subTests { - t.Run(subTest.name, func(t *testing.T) { - for _, natType := range test.natTypes { - t.Run(natType.name, func(t *testing.T) { - stackOpts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, - } - - host1Stack := stack.New(stackOpts) - routerStack := stack.New(stackOpts) - host2Stack := stack.New(stackOpts) - utils.SetupRoutedStacks(t, host1Stack, routerStack, host2Stack) - - epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto) - natType.setupNAT(t, routerStack, test.netProto, subTest.proto, epsAndAddrs.serverConnectAddr, epsAndAddrs.serverAddr.Addr) - - if err := epsAndAddrs.serverEP.Bind(epsAndAddrs.serverAddr); err != nil { - t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", epsAndAddrs.serverAddr, err) - } - clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr} - if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil { - t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err) - } - - if subTest.setupServer != nil { - subTest.setupServer(t, epsAndAddrs.serverEP) - } - { - err := epsAndAddrs.clientEP.Connect(epsAndAddrs.clientConnectAddr) - if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" { - t.Fatalf("unexpected error from epsAndAddrs.clientEP.Connect(%#v), (-want, +got):\n%s", epsAndAddrs.clientConnectAddr, diff) - } - } - serverConnectAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverConnectAddr} - if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil { - t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err) - } else { - serverConnectAddr.Port = addr.Port - } - - serverEP := epsAndAddrs.serverEP - serverCH := epsAndAddrs.serverReadableCH - if ep, ch := subTest.setupServerConn(t, serverEP, serverCH, serverConnectAddr); ep != nil { - defer ep.Close() - serverEP = ep - serverCH = ch - } - - write := func(ep tcpip.Endpoint, data []byte) { - t.Helper() - - var r bytes.Reader - r.Reset(data) - var wOpts tcpip.WriteOptions - n, err := ep.Write(&r, wOpts) - if err != nil { - t.Fatalf("ep.Write(_, %#v): %s", wOpts, err) - } - if want := int64(len(data)); n != want { - t.Fatalf("got ep.Write(_, %#v) = (%d, _), want = (%d, _)", wOpts, n, want) - } - } - - read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.FullAddress) { - t.Helper() - - var buf bytes.Buffer - var res tcpip.ReadResult - for { - var err tcpip.Error - opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr} - res, err = ep.Read(&buf, opts) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - <-ch - continue - } - if err != nil { - t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err) - } - break - } - - readResult := tcpip.ReadResult{ - Count: len(data), - Total: len(data), - } - if subTest.needRemoteAddr { - readResult.RemoteAddr = expectedFrom - } - if diff := cmp.Diff(readResult, res, checker.IgnoreCmpPath( - "ControlMessages", - "RemoteAddr.NIC", - )); diff != "" { - t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) - } - if diff := cmp.Diff(buf.Bytes(), data); diff != "" { - t.Errorf("received data mismatch (-want +got):\n%s", diff) - } - - if t.Failed() { - t.FailNow() - } - } - - { - data := []byte{1, 2, 3, 4} - write(epsAndAddrs.clientEP, data) - read(serverCH, serverEP, data, serverConnectAddr) - } - - { - data := []byte{5, 6, 7, 8, 9, 10, 11, 12} - write(serverEP, data) - read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, epsAndAddrs.clientConnectAddr) - } - }) - } - }) - } - }) - } -} - -func TestNATICMPError(t *testing.T) { - const ( - srcPort = 1234 - dstPort = 5432 - dataSize = 4 - ) - - type icmpTypeTest struct { - name string - val uint8 - expectResponse bool - } - - type transportTypeTest struct { - name string - proto tcpip.TransportProtocolNumber - buf buffer.View - checkNATed func(*testing.T, buffer.View) - } - - ipHdr := func(v buffer.View, totalLen int, transProto tcpip.TransportProtocolNumber, srcAddr, dstAddr tcpip.Address) { - ip := header.IPv4(v) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(totalLen), - Protocol: uint8(transProto), - TTL: 64, - SrcAddr: srcAddr, - DstAddr: dstAddr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - } - - ip6Hdr := func(v buffer.View, payloadLen int, transProto tcpip.TransportProtocolNumber, srcAddr, dstAddr tcpip.Address) { - ip := header.IPv6(v) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLen), - TransportProtocol: transProto, - HopLimit: 64, - SrcAddr: srcAddr, - DstAddr: dstAddr, - }) - } - - tests := []struct { - name string - netProto tcpip.NetworkProtocolNumber - host1Addr tcpip.Address - icmpError func(*testing.T, buffer.View, uint8) buffer.View - decrementTTL func(buffer.View) - checkNATedError func(*testing.T, buffer.View, buffer.View, uint8) - - transportTypes []transportTypeTest - icmpTypes []icmpTypeTest - }{ - { - name: "IPv4", - netProto: ipv4.ProtocolNumber, - host1Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address, - icmpError: func(t *testing.T, original buffer.View, icmpType uint8) buffer.View { - hdr := buffer.NewPrependable(header.IPv4MinimumSize + header.ICMPv4MinimumSize + len(original)) - if n := copy(hdr.Prepend(len(original)), original); n != len(original) { - t.Fatalf("got copy(...) = %d, want = %d", n, len(original)) - } - icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) - icmp.SetType(header.ICMPv4Type(icmpType)) - icmp.SetChecksum(0) - icmp.SetChecksum(header.ICMPv4Checksum(icmp, 0)) - ipHdr( - hdr.Prepend(header.IPv4MinimumSize), - hdr.UsedLength(), - header.ICMPv4ProtocolNumber, - utils.Host1IPv4Addr.AddressWithPrefix.Address, - utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, - ) - return hdr.View() - }, - decrementTTL: func(v buffer.View) { - ip := header.IPv4(v) - ip.SetTTL(ip.TTL() - 1) - ip.SetChecksum(0) - ip.SetChecksum(^ip.CalculateChecksum()) - }, - checkNATedError: func(t *testing.T, v buffer.View, original buffer.View, icmpType uint8) { - checker.IPv4(t, v, - checker.SrcAddr(utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address), - checker.DstAddr(utils.Host2IPv4Addr.AddressWithPrefix.Address), - checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4Type(icmpType)), - checker.ICMPv4Checksum(), - checker.ICMPv4Payload(original), - ), - ) - }, - transportTypes: []transportTypeTest{ - { - name: "UDP", - proto: header.UDPProtocolNumber, - buf: func() buffer.View { - udpSize := header.UDPMinimumSize + dataSize - hdr := buffer.NewPrependable(header.IPv4MinimumSize + udpSize) - udp := header.UDP(hdr.Prepend(udpSize)) - udp.SetSourcePort(srcPort) - udp.SetDestinationPort(dstPort) - udp.SetChecksum(0) - udp.SetChecksum(^udp.CalculateChecksum(header.PseudoHeaderChecksum( - header.UDPProtocolNumber, - utils.Host2IPv4Addr.AddressWithPrefix.Address, - utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address, - uint16(len(udp)), - ))) - ipHdr( - hdr.Prepend(header.IPv4MinimumSize), - hdr.UsedLength(), - header.UDPProtocolNumber, - utils.Host2IPv4Addr.AddressWithPrefix.Address, - utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address, - ) - return hdr.View() - }(), - checkNATed: func(t *testing.T, v buffer.View) { - checker.IPv4(t, v, - checker.SrcAddr(utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address), - checker.DstAddr(utils.Host1IPv4Addr.AddressWithPrefix.Address), - checker.UDP( - checker.SrcPort(srcPort), - checker.DstPort(dstPort), - ), - ) - }, - }, - { - name: "TCP", - proto: header.TCPProtocolNumber, - buf: func() buffer.View { - tcpSize := header.TCPMinimumSize + dataSize - hdr := buffer.NewPrependable(header.IPv4MinimumSize + tcpSize) - tcp := header.TCP(hdr.Prepend(tcpSize)) - tcp.SetSourcePort(srcPort) - tcp.SetDestinationPort(dstPort) - tcp.SetDataOffset(header.TCPMinimumSize) - tcp.SetChecksum(0) - tcp.SetChecksum(^tcp.CalculateChecksum(header.PseudoHeaderChecksum( - header.TCPProtocolNumber, - utils.Host2IPv4Addr.AddressWithPrefix.Address, - utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address, - uint16(len(tcp)), - ))) - ipHdr( - hdr.Prepend(header.IPv4MinimumSize), - hdr.UsedLength(), - header.TCPProtocolNumber, - utils.Host2IPv4Addr.AddressWithPrefix.Address, - utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address, - ) - return hdr.View() - }(), - checkNATed: func(t *testing.T, v buffer.View) { - checker.IPv4(t, v, - checker.SrcAddr(utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address), - checker.DstAddr(utils.Host1IPv4Addr.AddressWithPrefix.Address), - checker.TCP( - checker.SrcPort(srcPort), - checker.DstPort(dstPort), - ), - ) - }, - }, - }, - icmpTypes: []icmpTypeTest{ - { - name: "Destination Unreachable", - val: uint8(header.ICMPv4DstUnreachable), - expectResponse: true, - }, - { - name: "Time Exceeded", - val: uint8(header.ICMPv4TimeExceeded), - expectResponse: true, - }, - { - name: "Parameter Problem", - val: uint8(header.ICMPv4ParamProblem), - expectResponse: true, - }, - { - name: "Echo Request", - val: uint8(header.ICMPv4Echo), - expectResponse: false, - }, - { - name: "Echo Reply", - val: uint8(header.ICMPv4EchoReply), - expectResponse: false, - }, - }, - }, - { - name: "IPv6", - netProto: ipv6.ProtocolNumber, - host1Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address, - icmpError: func(t *testing.T, original buffer.View, icmpType uint8) buffer.View { - payloadLen := header.ICMPv6MinimumSize + len(original) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadLen) - icmp := header.ICMPv6(hdr.Prepend(payloadLen)) - icmp.SetType(header.ICMPv6Type(icmpType)) - if n := copy(icmp.Payload(), original); n != len(original) { - t.Fatalf("got copy(...) = %d, want = %d", n, len(original)) - } - icmp.SetChecksum(0) - icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: icmp, - Src: utils.Host1IPv6Addr.AddressWithPrefix.Address, - Dst: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, - })) - ip6Hdr( - hdr.Prepend(header.IPv6MinimumSize), - payloadLen, - header.ICMPv6ProtocolNumber, - utils.Host1IPv6Addr.AddressWithPrefix.Address, - utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, - ) - return hdr.View() - }, - decrementTTL: func(v buffer.View) { - ip := header.IPv6(v) - ip.SetHopLimit(ip.HopLimit() - 1) - }, - checkNATedError: func(t *testing.T, v buffer.View, original buffer.View, icmpType uint8) { - checker.IPv6(t, v, - checker.SrcAddr(utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address), - checker.DstAddr(utils.Host2IPv6Addr.AddressWithPrefix.Address), - checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6Type(icmpType)), - checker.ICMPv6Payload(original), - ), - ) - }, - transportTypes: []transportTypeTest{ - { - name: "UDP", - proto: header.UDPProtocolNumber, - buf: func() buffer.View { - udpSize := header.UDPMinimumSize + dataSize - hdr := buffer.NewPrependable(header.IPv6MinimumSize + udpSize) - udp := header.UDP(hdr.Prepend(udpSize)) - udp.SetSourcePort(srcPort) - udp.SetDestinationPort(dstPort) - udp.SetChecksum(0) - udp.SetChecksum(^udp.CalculateChecksum(header.PseudoHeaderChecksum( - header.UDPProtocolNumber, - utils.Host2IPv6Addr.AddressWithPrefix.Address, - utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address, - uint16(len(udp)), - ))) - ip6Hdr( - hdr.Prepend(header.IPv6MinimumSize), - len(udp), - header.UDPProtocolNumber, - utils.Host2IPv6Addr.AddressWithPrefix.Address, - utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address, - ) - return hdr.View() - }(), - checkNATed: func(t *testing.T, v buffer.View) { - checker.IPv6(t, v, - checker.SrcAddr(utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address), - checker.DstAddr(utils.Host1IPv6Addr.AddressWithPrefix.Address), - checker.UDP( - checker.SrcPort(srcPort), - checker.DstPort(dstPort), - ), - ) - }, - }, - { - name: "TCP", - proto: header.TCPProtocolNumber, - buf: func() buffer.View { - tcpSize := header.TCPMinimumSize + dataSize - hdr := buffer.NewPrependable(header.IPv6MinimumSize + tcpSize) - tcp := header.TCP(hdr.Prepend(tcpSize)) - tcp.SetSourcePort(srcPort) - tcp.SetDestinationPort(dstPort) - tcp.SetDataOffset(header.TCPMinimumSize) - tcp.SetChecksum(0) - tcp.SetChecksum(^tcp.CalculateChecksum(header.PseudoHeaderChecksum( - header.TCPProtocolNumber, - utils.Host2IPv6Addr.AddressWithPrefix.Address, - utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address, - uint16(len(tcp)), - ))) - ip6Hdr( - hdr.Prepend(header.IPv6MinimumSize), - len(tcp), - header.TCPProtocolNumber, - utils.Host2IPv6Addr.AddressWithPrefix.Address, - utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address, - ) - return hdr.View() - }(), - checkNATed: func(t *testing.T, v buffer.View) { - checker.IPv6(t, v, - checker.SrcAddr(utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address), - checker.DstAddr(utils.Host1IPv6Addr.AddressWithPrefix.Address), - checker.TCP( - checker.SrcPort(srcPort), - checker.DstPort(dstPort), - ), - ) - }, - }, - }, - icmpTypes: []icmpTypeTest{ - { - name: "Destination Unreachable", - val: uint8(header.ICMPv6DstUnreachable), - expectResponse: true, - }, - { - name: "Packet Too Big", - val: uint8(header.ICMPv6PacketTooBig), - expectResponse: true, - }, - { - name: "Time Exceeded", - val: uint8(header.ICMPv6TimeExceeded), - expectResponse: true, - }, - { - name: "Parameter Problem", - val: uint8(header.ICMPv6ParamProblem), - expectResponse: true, - }, - { - name: "Echo Request", - val: uint8(header.ICMPv6EchoRequest), - expectResponse: false, - }, - { - name: "Echo Reply", - val: uint8(header.ICMPv6EchoReply), - expectResponse: false, - }, - }, - }, - } - - trimTests := []struct { - name string - trimLen int - expectNATedICMP bool - }{ - { - name: "Trim nothing", - trimLen: 0, - expectNATedICMP: true, - }, - { - name: "Trim data", - trimLen: dataSize, - expectNATedICMP: true, - }, - { - name: "Trim data and transport header", - trimLen: dataSize + 1, - expectNATedICMP: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - for _, transportType := range test.transportTypes { - t.Run(transportType.name, func(t *testing.T) { - for _, icmpType := range test.icmpTypes { - t.Run(icmpType.name, func(t *testing.T) { - for _, trimTest := range trimTests { - t.Run(trimTest.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, - }) - - ep1 := channel.New(1, header.IPv6MinimumMTU, "") - ep2 := channel.New(1, header.IPv6MinimumMTU, "") - utils.SetupRouterStack(t, s, ep1, ep2) - - ipv6 := test.netProto == ipv6.ProtocolNumber - ipt := s.IPTables() - - table := stack.Table{ - Rules: []stack.Rule{ - // Prerouting - { - Filter: stack.IPHeaderFilter{ - Protocol: transportType.proto, - CheckProtocol: true, - InputInterface: utils.RouterNIC2Name, - }, - Target: &stack.DNATTarget{NetworkProtocol: test.netProto, Addr: test.host1Addr, Port: dstPort}, - }, - { - Target: &stack.AcceptTarget{}, - }, - - // Input - { - Target: &stack.AcceptTarget{}, - }, - - // Forward - { - Target: &stack.AcceptTarget{}, - }, - - // Output - { - Target: &stack.AcceptTarget{}, - }, - - // Postrouting - { - Filter: stack.IPHeaderFilter{ - Protocol: transportType.proto, - CheckProtocol: true, - OutputInterface: utils.RouterNIC1Name, - }, - Target: &stack.MasqueradeTarget{NetworkProtocol: test.netProto}, - }, - { - Target: &stack.AcceptTarget{}, - }, - }, - BuiltinChains: [stack.NumHooks]int{ - stack.Prerouting: 0, - stack.Input: 2, - stack.Forward: 3, - stack.Output: 4, - stack.Postrouting: 5, - }, - } - - if err := ipt.ReplaceTable(stack.NATID, table, ipv6); err != nil { - t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err) - } - - buf := transportType.buf - - ep2.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: append(buffer.View(nil), buf...).ToVectorisedView(), - })) - - { - pkt, ok := ep1.Read() - if !ok { - t.Fatal("expected to read a packet on ep1") - } - pktView := stack.PayloadSince(pkt.Pkt.NetworkHeader()) - transportType.checkNATed(t, pktView) - if t.Failed() { - t.FailNow() - } - - pktView = pktView[:len(pktView)-trimTest.trimLen] - buf = buf[:len(buf)-trimTest.trimLen] - - ep1.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: test.icmpError(t, pktView, icmpType.val).ToVectorisedView(), - })) - } - - pkt, ok := ep2.Read() - expectResponse := icmpType.expectResponse && trimTest.expectNATedICMP - if ok != expectResponse { - t.Fatalf("got ep2.Read() = (%#v, %t), want = (_, %t)", pkt, ok, expectResponse) - } - if !expectResponse { - return - } - test.decrementTTL(buf) - test.checkNATedError(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), buf, icmpType.val) - }) - } - }) - } - }) - } - }) - } -} diff --git a/pkg/tcpip/tests/integration/istio_test.go b/pkg/tcpip/tests/integration/istio_test.go deleted file mode 100644 index 95d994ef8..000000000 --- a/pkg/tcpip/tests/integration/istio_test.go +++ /dev/null @@ -1,365 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package istio_test - -import ( - "fmt" - "io" - "net" - "net/http" - "strconv" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/rand" - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/link/pipe" - "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/testutil" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" -) - -// testContext encapsulates the state required to run tests that simulate -// an istio like environment. -// -// A diagram depicting the setup is shown below. -// +-----------------------------------------------------------------------+ -// | +-------------------------------------------------+ | -// | + ----------+ | + -----------------+ PROXY +----------+ | | -// | | clientEP | | | serverListeningEP|--accepted-> | serverEP |-+ | | -// | + ----------+ | + -----------------+ +----------+ | | | -// | | -------|-------------+ +----------+ | | | -// | | | | | proxyEP |-+ | | -// | +-----redirect | +----------+ | | -// | + ------------+---|------+---+ | -// | | | -// | Local Stack. | | -// +-------------------------------------------------------|---------------+ -// | -// +-----------------------------------------------------------------------+ -// | remoteStack | | -// | +-------------SYN ---------------| | -// | | | | -// | +-------------------|--------------------------------|-_---+ | -// | | + -----------------+ + ----------+ | | | -// | | | remoteListeningEP|--accepted--->| remoteEP |<++ | | -// | | + -----------------+ + ----------+ | | -// | | Remote HTTP Server | | -// | +----------------------------------------------------------+ | -// +-----------------------------------------------------------------------+ -// -type testContext struct { - // localServerListener is the listening port for the server which will proxy - // all traffic to the remote EP. - localServerListener *gonet.TCPListener - - // remoteListenListener is the remote listening endpoint that will receive - // connections from server. - remoteServerListener *gonet.TCPListener - - // localStack is the stack used to create client/server endpoints and - // also the stack on which we install NAT redirect rules. - localStack *stack.Stack - - // remoteStack is the stack that represents a *remote* server. - remoteStack *stack.Stack - - // defaultResponse is the response served by the HTTP server for all GET - defaultResponse []byte - - // requests. wg is used to wait for HTTP server and Proxy to terminate before - // returning from cleanup. - wg sync.WaitGroup -} - -func (ctx *testContext) cleanup() { - ctx.localServerListener.Close() - ctx.localStack.Close() - ctx.remoteServerListener.Close() - ctx.remoteStack.Close() - ctx.wg.Wait() -} - -const ( - localServerPort = 8080 - remoteServerPort = 9090 -) - -var ( - localIPv4Addr1 = testutil.MustParse4("10.0.0.1") - localIPv4Addr2 = testutil.MustParse4("10.0.0.2") - loopbackIPv4Addr = testutil.MustParse4("127.0.0.1") - remoteIPv4Addr1 = testutil.MustParse4("10.0.0.3") -) - -func newTestContext(t *testing.T) *testContext { - t.Helper() - localNIC, remoteNIC := pipe.New("" /* linkAddr1 */, "" /* linkAddr2 */) - - localStack := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, - HandleLocal: true, - }) - - remoteStack := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, - HandleLocal: true, - }) - - // Add loopback NIC. We need a loopback NIC as NAT redirect rule redirect to - // loopback address + specified port. - loopbackNIC := loopback.New() - const loopbackNICID = tcpip.NICID(1) - if err := localStack.CreateNIC(loopbackNICID, sniffer.New(loopbackNIC)); err != nil { - t.Fatalf("localStack.CreateNIC(%d, _): %s", loopbackNICID, err) - } - loopbackAddr := tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: loopbackIPv4Addr.WithPrefix(), - } - if err := localStack.AddProtocolAddress(loopbackNICID, loopbackAddr, stack.AddressProperties{}); err != nil { - t.Fatalf("localStack.AddProtocolAddress(%d, %+v, {}): %s", loopbackNICID, loopbackAddr, err) - } - - // Create linked NICs that connects the local and remote stack. - const localNICID = tcpip.NICID(2) - const remoteNICID = tcpip.NICID(3) - if err := localStack.CreateNIC(localNICID, sniffer.New(localNIC)); err != nil { - t.Fatalf("localStack.CreateNIC(%d, _): %s", localNICID, err) - } - if err := remoteStack.CreateNIC(remoteNICID, sniffer.New(remoteNIC)); err != nil { - t.Fatalf("remoteStack.CreateNIC(%d, _): %s", remoteNICID, err) - } - - for _, addr := range []tcpip.Address{localIPv4Addr1, localIPv4Addr2} { - localProtocolAddr := tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: addr.WithPrefix(), - } - if err := localStack.AddProtocolAddress(localNICID, localProtocolAddr, stack.AddressProperties{}); err != nil { - t.Fatalf("localStack.AddProtocolAddress(%d, %+v, {}): %s", localNICID, localProtocolAddr, err) - } - } - - remoteProtocolAddr := tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: remoteIPv4Addr1.WithPrefix(), - } - if err := remoteStack.AddProtocolAddress(remoteNICID, remoteProtocolAddr, stack.AddressProperties{}); err != nil { - t.Fatalf("remoteStack.AddProtocolAddress(%d, %+v, {}): %s", remoteNICID, remoteProtocolAddr, err) - } - - // Setup route table for local and remote stacks. - localStack.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4LoopbackSubnet, - NIC: loopbackNICID, - }, - { - Destination: header.IPv4EmptySubnet, - NIC: localNICID, - }, - }) - remoteStack.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: remoteNICID, - }, - }) - - const netProto = ipv4.ProtocolNumber - localServerAddress := tcpip.FullAddress{ - Port: localServerPort, - } - - localServerListener, err := gonet.ListenTCP(localStack, localServerAddress, netProto) - if err != nil { - t.Fatalf("gonet.ListenTCP(_, %+v, %d) = %s", localServerAddress, netProto, err) - } - - remoteServerAddress := tcpip.FullAddress{ - Port: remoteServerPort, - } - remoteServerListener, err := gonet.ListenTCP(remoteStack, remoteServerAddress, netProto) - if err != nil { - t.Fatalf("gonet.ListenTCP(_, %+v, %d) = %s", remoteServerAddress, netProto, err) - } - - // Initialize a random default response served by the HTTP server. - defaultResponse := make([]byte, 512<<10) - if _, err := rand.Read(defaultResponse); err != nil { - t.Fatalf("rand.Read(buf) failed: %s", err) - } - - tc := &testContext{ - localServerListener: localServerListener, - remoteServerListener: remoteServerListener, - localStack: localStack, - remoteStack: remoteStack, - defaultResponse: defaultResponse, - } - - tc.startServers(t) - return tc -} - -func (ctx *testContext) startServers(t *testing.T) { - ctx.wg.Add(1) - go func() { - defer ctx.wg.Done() - ctx.startHTTPServer() - }() - ctx.wg.Add(1) - go func() { - defer ctx.wg.Done() - ctx.startTCPProxyServer(t) - }() -} - -func (ctx *testContext) startTCPProxyServer(t *testing.T) { - t.Helper() - for { - conn, err := ctx.localServerListener.Accept() - if err != nil { - t.Logf("terminating local proxy server: %s", err) - return - } - // Start a goroutine to handle this inbound connection. - go func() { - remoteServerAddr := tcpip.FullAddress{ - Addr: remoteIPv4Addr1, - Port: remoteServerPort, - } - localServerAddr := tcpip.FullAddress{ - Addr: localIPv4Addr2, - } - serverConn, err := gonet.DialTCPWithBind(context.Background(), ctx.localStack, localServerAddr, remoteServerAddr, ipv4.ProtocolNumber) - if err != nil { - t.Logf("gonet.DialTCP(_, %+v, %d) = %s", remoteServerAddr, ipv4.ProtocolNumber, err) - return - } - proxy(conn, serverConn) - t.Logf("proxying completed") - }() - } -} - -// proxy transparently proxies the TCP payload from conn1 to conn2 -// and vice versa. -func proxy(conn1, conn2 net.Conn) { - var wg sync.WaitGroup - wg.Add(1) - go func() { - io.Copy(conn2, conn1) - conn1.Close() - conn2.Close() - }() - wg.Add(1) - go func() { - io.Copy(conn1, conn2) - conn1.Close() - conn2.Close() - }() - wg.Wait() -} - -func (ctx *testContext) startHTTPServer() { - handlerFunc := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(ctx.defaultResponse)) - }) - s := &http.Server{ - Handler: handlerFunc, - } - s.Serve(ctx.remoteServerListener) -} - -func TestOutboundNATRedirect(t *testing.T) { - ctx := newTestContext(t) - defer ctx.cleanup() - - // Install an IPTable rule to redirect all TCP traffic with the sourceIP of - // localIPv4Addr1 to the tcp proxy port. - ipt := ctx.localStack.IPTables() - tbl := ipt.GetTable(stack.NATID, false /* ipv6 */) - ruleIdx := tbl.BuiltinChains[stack.Output] - tbl.Rules[ruleIdx].Filter = stack.IPHeaderFilter{ - Protocol: tcp.ProtocolNumber, - CheckProtocol: true, - Src: localIPv4Addr1, - SrcMask: tcpip.Address("\xff\xff\xff\xff"), - } - tbl.Rules[ruleIdx].Target = &stack.RedirectTarget{ - Port: localServerPort, - NetworkProtocol: ipv4.ProtocolNumber, - } - tbl.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.NATID, tbl, false /* ipv6 */); err != nil { - t.Fatalf("ipt.ReplaceTable(%d, _, false): %s", stack.NATID, err) - } - - dialFunc := func(protocol, address string) (net.Conn, error) { - host, port, err := net.SplitHostPort(address) - if err != nil { - return nil, fmt.Errorf("unable to parse address: %s, err: %s", address, err) - } - - remoteServerIP := net.ParseIP(host) - remoteServerPort, err := strconv.Atoi(port) - if err != nil { - return nil, fmt.Errorf("unable to parse port from string %s, err: %s", port, err) - } - remoteAddress := tcpip.FullAddress{ - Addr: tcpip.Address(remoteServerIP.To4()), - Port: uint16(remoteServerPort), - } - - // Dial with an explicit source address bound so that the redirect rule will - // be able to correctly redirect these packets. - localAddr := tcpip.FullAddress{Addr: localIPv4Addr1} - return gonet.DialTCPWithBind(context.Background(), ctx.localStack, localAddr, remoteAddress, ipv4.ProtocolNumber) - } - - httpClient := &http.Client{ - Transport: &http.Transport{ - Dial: dialFunc, - }, - } - - serverURL := fmt.Sprintf("http://[%s]:%d/", net.IP(remoteIPv4Addr1), remoteServerPort) - response, err := httpClient.Get(serverURL) - if err != nil { - t.Fatalf("httpClient.Get(\"/\") failed: %s", err) - } - if got, want := response.StatusCode, http.StatusOK; got != want { - t.Fatalf("unexpected status code got: %d, want: %d", got, want) - } - body, err := io.ReadAll(response.Body) - if err != nil { - t.Fatalf("io.ReadAll(response.Body) failed: %s", err) - } - response.Body.Close() - if diff := cmp.Diff(body, ctx.defaultResponse); diff != "" { - t.Fatalf("unexpected response (-want +got): \n %s", diff) - } -} diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go deleted file mode 100644 index 95ddd8ec3..000000000 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ /dev/null @@ -1,1640 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package link_resolution_test - -import ( - "bytes" - "fmt" - "net" - "runtime" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/pipe" - "gvisor.dev/gvisor/pkg/tcpip/network/arp" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/tests/utils" - tcptestutil "gvisor.dev/gvisor/pkg/tcpip/testutil" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -func setupStack(t *testing.T, stackOpts stack.Options, host1NICID, host2NICID tcpip.NICID) (*stack.Stack, *stack.Stack) { - host1Stack := stack.New(stackOpts) - host2Stack := stack.New(stackOpts) - - host1NIC, host2NIC := pipe.New(utils.LinkAddr1, utils.LinkAddr2) - - if err := host1Stack.CreateNIC(host1NICID, utils.NewEthernetEndpoint(host1NIC)); err != nil { - t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) - } - if err := host2Stack.CreateNIC(host2NICID, utils.NewEthernetEndpoint(host2NIC)); err != nil { - t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) - } - - if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv4Addr1, stack.AddressProperties{}); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", host1NICID, utils.Ipv4Addr1, err) - } - if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv4Addr2, stack.AddressProperties{}); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", host2NICID, utils.Ipv4Addr2, err) - } - if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv6Addr1, stack.AddressProperties{}); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", host1NICID, utils.Ipv6Addr1, err) - } - if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv6Addr2, stack.AddressProperties{}); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", host2NICID, utils.Ipv6Addr2, err) - } - - host1Stack.SetRouteTable([]tcpip.Route{ - { - Destination: utils.Ipv4Addr1.AddressWithPrefix.Subnet(), - NIC: host1NICID, - }, - { - Destination: utils.Ipv6Addr1.AddressWithPrefix.Subnet(), - NIC: host1NICID, - }, - }) - host2Stack.SetRouteTable([]tcpip.Route{ - { - Destination: utils.Ipv4Addr2.AddressWithPrefix.Subnet(), - NIC: host2NICID, - }, - { - Destination: utils.Ipv6Addr2.AddressWithPrefix.Subnet(), - NIC: host2NICID, - }, - }) - - return host1Stack, host2Stack -} - -// TestPing tests that two hosts can ping eachother when link resolution is -// enabled. -func TestPing(t *testing.T) { - const ( - host1NICID = 1 - host2NICID = 4 - - // icmpDataOffset is the offset to the data in both ICMPv4 and ICMPv6 echo - // request/reply packets. - icmpDataOffset = 8 - ) - - tests := []struct { - name string - transProto tcpip.TransportProtocolNumber - netProto tcpip.NetworkProtocolNumber - remoteAddr tcpip.Address - icmpBuf func(*testing.T) []byte - }{ - { - name: "IPv4 Ping", - transProto: icmp.ProtocolNumber4, - netProto: ipv4.ProtocolNumber, - remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, - icmpBuf: func(t *testing.T) []byte { - data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} - hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data))) - hdr.SetType(header.ICMPv4Echo) - if n := copy(hdr.Payload(), data[:]); n != len(data) { - t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data)) - } - return hdr - }, - }, - { - name: "IPv6 Ping", - transProto: icmp.ProtocolNumber6, - netProto: ipv6.ProtocolNumber, - remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, - icmpBuf: func(t *testing.T) []byte { - data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} - hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data))) - hdr.SetType(header.ICMPv6EchoRequest) - if n := copy(hdr.Payload(), data[:]); n != len(data) { - t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data)) - } - return hdr - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - stackOpts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6}, - } - - host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) - - var wq waiter.Queue - we, waiterCH := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - ep, err := host1Stack.NewEndpoint(test.transProto, test.netProto, &wq) - if err != nil { - t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err) - } - defer ep.Close() - - icmpBuf := test.icmpBuf(t) - var r bytes.Reader - r.Reset(icmpBuf) - wOpts := tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: test.remoteAddr}} - if n, err := ep.Write(&r, wOpts); err != nil { - t.Fatalf("ep.Write(_, _): %s", err) - } else if want := int64(len(icmpBuf)); n != want { - t.Fatalf("got ep.Write(_, _) = (%d, _), want = (%d, _)", n, want) - } - - // Wait for the endpoint to be readable. - <-waiterCH - - var buf bytes.Buffer - opts := tcpip.ReadOptions{NeedRemoteAddr: true} - res, err := ep.Read(&buf, opts) - if err != nil { - t.Fatalf("ep.Read(_, %d, %#v): %s", len(icmpBuf), opts, err) - } - if diff := cmp.Diff(tcpip.ReadResult{ - Count: buf.Len(), - Total: buf.Len(), - RemoteAddr: tcpip.FullAddress{Addr: test.remoteAddr}, - }, res, checker.IgnoreCmpPath( - "ControlMessages", - "RemoteAddr.NIC", - "RemoteAddr.Port", - )); diff != "" { - t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) - } - if diff := cmp.Diff(buf.Bytes()[icmpDataOffset:], icmpBuf[icmpDataOffset:]); diff != "" { - t.Errorf("received data mismatch (-want +got):\n%s", diff) - } - }) - } -} - -type transportError struct { - origin tcpip.SockErrOrigin - typ uint8 - code uint8 - info uint32 - kind stack.TransportErrorKind -} - -func TestTCPLinkResolutionFailure(t *testing.T) { - const ( - host1NICID = 1 - host2NICID = 4 - ) - - tests := []struct { - name string - netProto tcpip.NetworkProtocolNumber - remoteAddr tcpip.Address - expectedWriteErr tcpip.Error - sockError tcpip.SockError - transErr transportError - }{ - { - name: "IPv4 with resolvable remote", - netProto: ipv4.ProtocolNumber, - remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, - expectedWriteErr: nil, - }, - { - name: "IPv6 with resolvable remote", - netProto: ipv6.ProtocolNumber, - remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, - expectedWriteErr: nil, - }, - { - name: "IPv4 without resolvable remote", - netProto: ipv4.ProtocolNumber, - remoteAddr: utils.Ipv4Addr3.AddressWithPrefix.Address, - expectedWriteErr: &tcpip.ErrNoRoute{}, - sockError: tcpip.SockError{ - Err: &tcpip.ErrNoRoute{}, - Dst: tcpip.FullAddress{ - NIC: host1NICID, - Addr: utils.Ipv4Addr3.AddressWithPrefix.Address, - Port: 1234, - }, - Offender: tcpip.FullAddress{ - NIC: host1NICID, - Addr: utils.Ipv4Addr1.AddressWithPrefix.Address, - }, - NetProto: ipv4.ProtocolNumber, - }, - transErr: transportError{ - origin: tcpip.SockExtErrorOriginICMP, - typ: uint8(header.ICMPv4DstUnreachable), - code: uint8(header.ICMPv4HostUnreachable), - kind: stack.DestinationHostUnreachableTransportError, - }, - }, - { - name: "IPv6 without resolvable remote", - netProto: ipv6.ProtocolNumber, - remoteAddr: utils.Ipv6Addr3.AddressWithPrefix.Address, - expectedWriteErr: &tcpip.ErrNoRoute{}, - sockError: tcpip.SockError{ - Err: &tcpip.ErrNoRoute{}, - Dst: tcpip.FullAddress{ - NIC: host1NICID, - Addr: utils.Ipv6Addr3.AddressWithPrefix.Address, - Port: 1234, - }, - Offender: tcpip.FullAddress{ - NIC: host1NICID, - Addr: utils.Ipv6Addr1.AddressWithPrefix.Address, - }, - NetProto: ipv6.ProtocolNumber, - }, - transErr: transportError{ - origin: tcpip.SockExtErrorOriginICMP6, - typ: uint8(header.ICMPv6DstUnreachable), - code: uint8(header.ICMPv6AddressUnreachable), - kind: stack.DestinationHostUnreachableTransportError, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - stackOpts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, - Clock: clock, - } - - host1Stack, host2Stack := setupStack(t, stackOpts, host1NICID, host2NICID) - - var listenerWQ waiter.Queue - listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, test.netProto, &listenerWQ) - if err != nil { - t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, test.netProto, err) - } - defer listenerEP.Close() - - listenerAddr := tcpip.FullAddress{Port: 1234} - if err := listenerEP.Bind(listenerAddr); err != nil { - t.Fatalf("listenerEP.Bind(%#v): %s", listenerAddr, err) - } - - if err := listenerEP.Listen(1); err != nil { - t.Fatalf("listenerEP.Listen(1): %s", err) - } - - var clientWQ waiter.Queue - we, ch := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&we, waiter.WritableEvents|waiter.EventErr) - clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, test.netProto, &clientWQ) - if err != nil { - t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, test.netProto, err) - } - defer clientEP.Close() - - sockOpts := clientEP.SocketOptions() - sockOpts.SetRecvError(true) - - remoteAddr := listenerAddr - remoteAddr.Addr = test.remoteAddr - { - err := clientEP.Connect(remoteAddr) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("got clientEP.Connect(%#v) = %s, want = %s", remoteAddr, err, &tcpip.ErrConnectStarted{}) - } - } - - // Wait for an error due to link resolution failing, or the endpoint to be - // writable. - if test.expectedWriteErr != nil { - nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto) - if err != nil { - t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err) - } - clock.Advance(time.Duration(nudConfigs.MaxMulticastProbes) * nudConfigs.RetransmitTimer) - } else { - clock.RunImmediatelyScheduledJobs() - } - <-ch - - { - var r bytes.Reader - r.Reset([]byte{0}) - var wOpts tcpip.WriteOptions - _, err := clientEP.Write(&r, wOpts) - if diff := cmp.Diff(test.expectedWriteErr, err); diff != "" { - t.Errorf("unexpected error from clientEP.Write(_, %#v), (-want, +got):\n%s", wOpts, diff) - } - } - - if test.expectedWriteErr == nil { - return - } - - sockErr := sockOpts.DequeueErr() - if sockErr == nil { - t.Fatalf("got sockOpts.DequeueErr() = nil, want = non-nil") - } - - sockErrCmpOpts := []cmp.Option{ - cmpopts.IgnoreUnexported(tcpip.SockError{}), - cmp.Comparer(func(a, b tcpip.Error) bool { - // tcpip.Error holds an unexported field but the errors netstack uses - // are pre defined so we can simply compare pointers. - return a == b - }), - checker.IgnoreCmpPath( - // Ignore the payload since we do not know the TCP seq/ack numbers. - "Payload", - // Ignore the cause since we will compare its properties separately - // since the concrete type of the cause is unknown. - "Cause", - ), - } - - if addr, err := clientEP.GetLocalAddress(); err != nil { - t.Fatalf("clientEP.GetLocalAddress(): %s", err) - } else { - test.sockError.Offender.Port = addr.Port - } - if diff := cmp.Diff(&test.sockError, sockErr, sockErrCmpOpts...); diff != "" { - t.Errorf("socket error mismatch (-want +got):\n%s", diff) - } - - transErr, ok := sockErr.Cause.(stack.TransportError) - if !ok { - t.Fatalf("socket error cause is not a transport error; cause = %#v", sockErr.Cause) - } - if diff := cmp.Diff( - test.transErr, - transportError{ - origin: transErr.Origin(), - typ: transErr.Type(), - code: transErr.Code(), - info: transErr.Info(), - kind: transErr.Kind(), - }, - cmp.AllowUnexported(transportError{}), - ); diff != "" { - t.Errorf("socket error mismatch (-want +got):\n%s", diff) - } - }) - } -} - -func TestForwardingWithLinkResolutionFailure(t *testing.T) { - const ( - incomingNICID = 1 - outgoingNICID = 2 - ttl = 2 - expectedHostUnreachableErrorCount = 1 - ) - outgoingLinkAddr := tcptestutil.MustParseLink("02:03:03:04:05:06") - - rxICMPv4EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) { - utils.RxICMPv4EchoRequest(e, src, dst, ttl) - } - - rxICMPv6EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) { - utils.RxICMPv6EchoRequest(e, src, dst, ttl) - } - - arpChecker := func(t *testing.T, request channel.PacketInfo, src, dst tcpip.Address) { - if request.Proto != arp.ProtocolNumber { - t.Errorf("got request.Proto = %d, want = %d", request.Proto, arp.ProtocolNumber) - } - if request.Route.RemoteLinkAddress != header.EthernetBroadcastAddress { - t.Errorf("got request.Route.RemoteLinkAddress = %s, want = %s", request.Route.RemoteLinkAddress, header.EthernetBroadcastAddress) - } - rep := header.ARP(request.Pkt.NetworkHeader().View()) - if got := rep.Op(); got != header.ARPRequest { - t.Errorf("got Op() = %d, want = %d", got, header.ARPRequest) - } - if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != outgoingLinkAddr { - t.Errorf("got HardwareAddressSender = %s, want = %s", got, outgoingLinkAddr) - } - if got := tcpip.Address(rep.ProtocolAddressSender()); got != src { - t.Errorf("got ProtocolAddressSender = %s, want = %s", got, src) - } - if got := tcpip.Address(rep.ProtocolAddressTarget()); got != dst { - t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, dst) - } - } - - ndpChecker := func(t *testing.T, request channel.PacketInfo, src, dst tcpip.Address) { - if request.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", request.Proto, header.IPv6ProtocolNumber) - } - - snmc := header.SolicitedNodeAddr(dst) - if want := header.EthernetAddressFromMulticastIPv6Address(snmc); request.Route.RemoteLinkAddress != want { - t.Errorf("got remote link address = %s, want = %s", request.Route.RemoteLinkAddress, want) - } - - checker.IPv6(t, stack.PayloadSince(request.Pkt.NetworkHeader()), - checker.SrcAddr(src), - checker.DstAddr(snmc), - checker.TTL(header.NDPHopLimit), - checker.NDPNS( - checker.NDPNSTargetAddress(dst), - )) - } - - icmpv4Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) { - checker.IPv4(t, b, - checker.SrcAddr(src), - checker.DstAddr(dst), - checker.TTL(ipv4.DefaultTTL), - checker.ICMPv4( - checker.ICMPv4Checksum(), - checker.ICMPv4Type(header.ICMPv4DstUnreachable), - checker.ICMPv4Code(header.ICMPv4HostUnreachable), - ), - ) - } - - icmpv6Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) { - checker.IPv6(t, b, - checker.SrcAddr(src), - checker.DstAddr(dst), - checker.TTL(ipv6.DefaultTTL), - checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6DstUnreachable), - checker.ICMPv6Code(header.ICMPv6AddressUnreachable), - ), - ) - } - - tests := []struct { - name string - networkProtocolFactory []stack.NetworkProtocolFactory - networkProtocolNumber tcpip.NetworkProtocolNumber - sourceAddr tcpip.Address - destAddr tcpip.Address - incomingAddr tcpip.AddressWithPrefix - outgoingAddr tcpip.AddressWithPrefix - transportProtocol func(*stack.Stack) stack.TransportProtocol - rx func(*channel.Endpoint, tcpip.Address, tcpip.Address) - linkResolutionRequestChecker func(*testing.T, channel.PacketInfo, tcpip.Address, tcpip.Address) - icmpReplyChecker func(*testing.T, []byte, tcpip.Address, tcpip.Address) - mtu uint32 - }{ - { - name: "IPv4 Host unreachable", - networkProtocolFactory: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, - networkProtocolNumber: header.IPv4ProtocolNumber, - sourceAddr: tcptestutil.MustParse4("10.0.0.2"), - destAddr: tcptestutil.MustParse4("11.0.0.2"), - incomingAddr: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()), - PrefixLen: 8, - }, - outgoingAddr: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("11.0.0.1").To4()), - PrefixLen: 8, - }, - transportProtocol: icmp.NewProtocol4, - linkResolutionRequestChecker: arpChecker, - icmpReplyChecker: icmpv4Checker, - rx: rxICMPv4EchoRequest, - mtu: ipv4.MaxTotalSize, - }, - { - name: "IPv6 Host unreachable", - networkProtocolFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - networkProtocolNumber: header.IPv6ProtocolNumber, - sourceAddr: tcptestutil.MustParse6("10::2"), - destAddr: tcptestutil.MustParse6("11::2"), - incomingAddr: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("10::1").To16()), - PrefixLen: 64, - }, - outgoingAddr: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("11::1").To16()), - PrefixLen: 64, - }, - transportProtocol: icmp.NewProtocol6, - linkResolutionRequestChecker: ndpChecker, - icmpReplyChecker: icmpv6Checker, - rx: rxICMPv6EchoRequest, - mtu: header.IPv6MinimumMTU, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - - s := stack.New(stack.Options{ - NetworkProtocols: test.networkProtocolFactory, - TransportProtocols: []stack.TransportProtocolFactory{test.transportProtocol}, - Clock: clock, - }) - - // Set up endpoint through which we will receive packets. - incomingEndpoint := channel.New(1, test.mtu, "") - if err := s.CreateNIC(incomingNICID, incomingEndpoint); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err) - } - incomingProtoAddr := tcpip.ProtocolAddress{ - Protocol: test.networkProtocolNumber, - AddressWithPrefix: test.incomingAddr, - } - if err := s.AddProtocolAddress(incomingNICID, incomingProtoAddr, stack.AddressProperties{}); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", incomingNICID, incomingProtoAddr, err) - } - - // Set up endpoint through which we will attempt to forward packets. - outgoingEndpoint := channel.New(1, test.mtu, outgoingLinkAddr) - outgoingEndpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired - if err := s.CreateNIC(outgoingNICID, outgoingEndpoint); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err) - } - outgoingProtoAddr := tcpip.ProtocolAddress{ - Protocol: test.networkProtocolNumber, - AddressWithPrefix: test.outgoingAddr, - } - if err := s.AddProtocolAddress(outgoingNICID, outgoingProtoAddr, stack.AddressProperties{}); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", outgoingNICID, outgoingProtoAddr, err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: test.incomingAddr.Subnet(), - NIC: incomingNICID, - }, - { - Destination: test.outgoingAddr.Subnet(), - NIC: outgoingNICID, - }, - }) - - if err := s.SetForwardingDefaultAndAllNICs(test.networkProtocolNumber, true); err != nil { - t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", test.networkProtocolNumber, err) - } - - test.rx(incomingEndpoint, test.sourceAddr, test.destAddr) - - nudConfigs, err := s.NUDConfigurations(outgoingNICID, test.networkProtocolNumber) - if err != nil { - t.Fatalf("s.NUDConfigurations(%d, %d): %s", outgoingNICID, test.networkProtocolNumber, err) - } - // Trigger the first packet on the endpoint. - clock.RunImmediatelyScheduledJobs() - - for i := 0; i < int(nudConfigs.MaxMulticastProbes); i++ { - request, ok := outgoingEndpoint.Read() - if !ok { - t.Fatal("expected ARP packet through outgoing NIC") - } - - test.linkResolutionRequestChecker(t, request, test.outgoingAddr.Address, test.destAddr) - - // Advance the clock the span of one request timeout. - clock.Advance(nudConfigs.RetransmitTimer) - } - - // Next, we make a blocking read to retrieve the error packet. This is - // necessary because outgoing packets are dequeued asynchronously when - // link resolution fails, and this dequeue is what triggers the ICMP - // error. - reply, ok := incomingEndpoint.Read() - if !ok { - t.Fatal("expected ICMP packet through incoming NIC") - } - - test.icmpReplyChecker(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), test.incomingAddr.Address, test.sourceAddr) - - // Since link resolution failed, we don't expect the packet to be - // forwarded. - forwardedPacket, ok := outgoingEndpoint.Read() - if ok { - t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", forwardedPacket) - } - - if got, want := s.Stats().IP.Forwarding.HostUnreachable.Value(), expectedHostUnreachableErrorCount; int(got) != want { - t.Errorf("got rt.Stats().IP.Forwarding.HostUnreachable.Value() = %d, want = %d", got, want) - } - }) - } -} - -func TestGetLinkAddress(t *testing.T) { - const ( - host1NICID = 1 - host2NICID = 4 - ) - - tests := []struct { - name string - netProto tcpip.NetworkProtocolNumber - remoteAddr, localAddr tcpip.Address - expectedErr tcpip.Error - }{ - { - name: "IPv4 resolvable", - netProto: ipv4.ProtocolNumber, - remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, - expectedErr: nil, - }, - { - name: "IPv6 resolvable", - netProto: ipv6.ProtocolNumber, - remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, - expectedErr: nil, - }, - { - name: "IPv4 not resolvable", - netProto: ipv4.ProtocolNumber, - remoteAddr: utils.Ipv4Addr3.AddressWithPrefix.Address, - expectedErr: &tcpip.ErrTimeout{}, - }, - { - name: "IPv6 not resolvable", - netProto: ipv6.ProtocolNumber, - remoteAddr: utils.Ipv6Addr3.AddressWithPrefix.Address, - expectedErr: &tcpip.ErrTimeout{}, - }, - { - name: "IPv4 bad local address", - netProto: ipv4.ProtocolNumber, - remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, - localAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, - expectedErr: &tcpip.ErrBadLocalAddress{}, - }, - { - name: "IPv6 bad local address", - netProto: ipv6.ProtocolNumber, - remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, - localAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, - expectedErr: &tcpip.ErrBadLocalAddress{}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - stackOpts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - Clock: clock, - } - - host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) - - ch := make(chan stack.LinkResolutionResult, 1) - err := host1Stack.GetLinkAddress(host1NICID, test.remoteAddr, test.localAddr, test.netProto, func(r stack.LinkResolutionResult) { - ch <- r - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", host1NICID, test.remoteAddr, test.netProto, err, &tcpip.ErrWouldBlock{}) - } - wantRes := stack.LinkResolutionResult{Err: test.expectedErr} - if test.expectedErr == nil { - wantRes.LinkAddress = utils.LinkAddr2 - } - - nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto) - if err != nil { - t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err) - } - - clock.Advance(time.Duration(nudConfigs.MaxMulticastProbes) * nudConfigs.RetransmitTimer) - select { - case got := <-ch: - if diff := cmp.Diff(wantRes, got); diff != "" { - t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("event didn't arrive") - } - }) - } -} - -func TestRouteResolvedFields(t *testing.T) { - const ( - host1NICID = 1 - host2NICID = 4 - ) - - tests := []struct { - name string - netProto tcpip.NetworkProtocolNumber - localAddr tcpip.Address - remoteAddr tcpip.Address - immediatelyResolvable bool - expectedErr tcpip.Error - expectedLinkAddr tcpip.LinkAddress - }{ - { - name: "IPv4 immediately resolvable", - netProto: ipv4.ProtocolNumber, - localAddr: utils.Ipv4Addr1.AddressWithPrefix.Address, - remoteAddr: header.IPv4AllSystems, - immediatelyResolvable: true, - expectedErr: nil, - expectedLinkAddr: header.EthernetAddressFromMulticastIPv4Address(header.IPv4AllSystems), - }, - { - name: "IPv6 immediately resolvable", - netProto: ipv6.ProtocolNumber, - localAddr: utils.Ipv6Addr1.AddressWithPrefix.Address, - remoteAddr: header.IPv6AllNodesMulticastAddress, - immediatelyResolvable: true, - expectedErr: nil, - expectedLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllNodesMulticastAddress), - }, - { - name: "IPv4 resolvable", - netProto: ipv4.ProtocolNumber, - localAddr: utils.Ipv4Addr1.AddressWithPrefix.Address, - remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, - immediatelyResolvable: false, - expectedErr: nil, - expectedLinkAddr: utils.LinkAddr2, - }, - { - name: "IPv6 resolvable", - netProto: ipv6.ProtocolNumber, - localAddr: utils.Ipv6Addr1.AddressWithPrefix.Address, - remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, - immediatelyResolvable: false, - expectedErr: nil, - expectedLinkAddr: utils.LinkAddr2, - }, - { - name: "IPv4 not resolvable", - netProto: ipv4.ProtocolNumber, - localAddr: utils.Ipv4Addr1.AddressWithPrefix.Address, - remoteAddr: utils.Ipv4Addr3.AddressWithPrefix.Address, - immediatelyResolvable: false, - expectedErr: &tcpip.ErrTimeout{}, - }, - { - name: "IPv6 not resolvable", - netProto: ipv6.ProtocolNumber, - localAddr: utils.Ipv6Addr1.AddressWithPrefix.Address, - remoteAddr: utils.Ipv6Addr3.AddressWithPrefix.Address, - immediatelyResolvable: false, - expectedErr: &tcpip.ErrTimeout{}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - stackOpts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - Clock: clock, - } - - host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) - r, err := host1Stack.FindRoute(host1NICID, test.localAddr, test.remoteAddr, test.netProto, false /* multicastLoop */) - if err != nil { - t.Fatalf("host1Stack.FindRoute(%d, %s, %s, %d, false): %s", host1NICID, test.localAddr, test.remoteAddr, test.netProto, err) - } - defer r.Release() - - var wantRouteInfo stack.RouteInfo - wantRouteInfo.LocalLinkAddress = utils.LinkAddr1 - wantRouteInfo.LocalAddress = test.localAddr - wantRouteInfo.RemoteAddress = test.remoteAddr - wantRouteInfo.NetProto = test.netProto - wantRouteInfo.Loop = stack.PacketOut - wantRouteInfo.RemoteLinkAddress = test.expectedLinkAddr - - ch := make(chan stack.ResolvedFieldsResult, 1) - - if !test.immediatelyResolvable { - wantUnresolvedRouteInfo := wantRouteInfo - wantUnresolvedRouteInfo.RemoteLinkAddress = "" - - err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) { - ch <- r - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got r.ResolvedFields(_) = %s, want = %s", err, &tcpip.ErrWouldBlock{}) - } - - nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto) - if err != nil { - t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err) - } - clock.Advance(time.Duration(nudConfigs.MaxMulticastProbes) * nudConfigs.RetransmitTimer) - - select { - case got := <-ch: - if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Err: test.expectedErr}, got, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" { - t.Errorf("route resolve result mismatch (-want +got):\n%s", diff) - } - default: - t.Fatalf("event didn't arrive") - } - - if test.expectedErr != nil { - return - } - - // At this point the neighbor table should be populated so the route - // should be immediately resolvable. - } - - if err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) { - ch <- r - }); err != nil { - t.Errorf("r.ResolvedFields(_): %s", err) - } - select { - case routeResolveRes := <-ch: - if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Err: nil}, routeResolveRes, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" { - t.Errorf("route resolve result from resolved route mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected route to be immediately resolvable") - } - }) - } -} - -func TestWritePacketsLinkResolution(t *testing.T) { - const ( - host1NICID = 1 - host2NICID = 4 - ) - - tests := []struct { - name string - netProto tcpip.NetworkProtocolNumber - remoteAddr tcpip.Address - expectedWriteErr tcpip.Error - }{ - { - name: "IPv4", - netProto: ipv4.ProtocolNumber, - remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, - expectedWriteErr: nil, - }, - { - name: "IPv6", - netProto: ipv6.ProtocolNumber, - remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, - expectedWriteErr: nil, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - stackOpts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - } - - host1Stack, host2Stack := setupStack(t, stackOpts, host1NICID, host2NICID) - - var serverWQ waiter.Queue - serverWE, serverCH := waiter.NewChannelEntry(nil) - serverWQ.EventRegister(&serverWE, waiter.ReadableEvents) - serverEP, err := host2Stack.NewEndpoint(udp.ProtocolNumber, test.netProto, &serverWQ) - if err != nil { - t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.netProto, err) - } - defer serverEP.Close() - - serverAddr := tcpip.FullAddress{Port: 1234} - if err := serverEP.Bind(serverAddr); err != nil { - t.Fatalf("serverEP.Bind(%#v): %s", serverAddr, err) - } - - r, err := host1Stack.FindRoute(host1NICID, "", test.remoteAddr, test.netProto, false /* multicastLoop */) - if err != nil { - t.Fatalf("host1Stack.FindRoute(%d, '', %s, %d, false): %s", host1NICID, test.remoteAddr, test.netProto, err) - } - defer r.Release() - - data := []byte{1, 2} - var pkts stack.PacketBufferList - for _, d := range data { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()), - Data: buffer.View([]byte{d}).ToVectorisedView(), - }) - pkt.TransportProtocolNumber = udp.ProtocolNumber - length := uint16(pkt.Size()) - udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) - udpHdr.Encode(&header.UDPFields{ - SrcPort: 5555, - DstPort: serverAddr.Port, - Length: length, - }) - xsum := r.PseudoHeaderChecksum(udp.ProtocolNumber, length) - xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) - udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum)) - - pkts.PushBack(pkt) - } - - params := stack.NetworkHeaderParams{ - Protocol: udp.ProtocolNumber, - TTL: 64, - TOS: stack.DefaultTOS, - } - - if n, err := r.WritePackets(pkts, params); err != nil { - t.Fatalf("r.WritePackets(_, %#v): %s", params, err) - } else if want := pkts.Len(); want != n { - t.Fatalf("got r.WritePackets(_, %#v) = %d, want = %d", params, n, want) - } - - var writer bytes.Buffer - count := 0 - for { - var rOpts tcpip.ReadOptions - res, err := serverEP.Read(&writer, rOpts) - if err != nil { - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - // Should not have anymore bytes to read after we read the sent - // number of bytes. - if count == len(data) { - break - } - - <-serverCH - continue - } - - t.Fatalf("serverEP.Read(_, %#v): %s", rOpts, err) - } - count += res.Count - } - - if got, want := host2Stack.Stats().UDP.PacketsReceived.Value(), uint64(len(data)); got != want { - t.Errorf("got host2Stack.Stats().UDP.PacketsReceived.Value() = %d, want = %d", got, want) - } - if diff := cmp.Diff(data, writer.Bytes()); diff != "" { - t.Errorf("read bytes mismatch (-want +got):\n%s", diff) - } - }) - } -} - -type eventType int - -const ( - entryAdded eventType = iota - entryChanged - entryRemoved -) - -func (t eventType) String() string { - switch t { - case entryAdded: - return "add" - case entryChanged: - return "change" - case entryRemoved: - return "remove" - default: - return fmt.Sprintf("unknown (%d)", t) - } -} - -type eventInfo struct { - eventType eventType - nicID tcpip.NICID - entry stack.NeighborEntry -} - -func (e eventInfo) String() string { - return fmt.Sprintf("%s event for NIC #%d, %#v", e.eventType, e.nicID, e.entry) -} - -var _ stack.NUDDispatcher = (*nudDispatcher)(nil) - -type nudDispatcher struct { - c chan eventInfo -} - -func (d *nudDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry stack.NeighborEntry) { - e := eventInfo{ - eventType: entryAdded, - nicID: nicID, - entry: entry, - } - d.c <- e -} - -func (d *nudDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry stack.NeighborEntry) { - e := eventInfo{ - eventType: entryChanged, - nicID: nicID, - entry: entry, - } - d.c <- e -} - -func (d *nudDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.NeighborEntry) { - e := eventInfo{ - eventType: entryRemoved, - nicID: nicID, - entry: entry, - } - d.c <- e -} - -func (d *nudDispatcher) expectEvent(want eventInfo) error { - select { - case got := <-d.c: - if diff := cmp.Diff(want, got, cmp.AllowUnexported(eventInfo{}), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAt")); diff != "" { - return fmt.Errorf("got invalid event (-want +got):\n%s", diff) - } - return nil - default: - return fmt.Errorf("event didn't arrive") - } -} - -// TestTCPConfirmNeighborReachability tests that TCP informs layers beneath it -// that the neighbor used for a route is reachable. -func TestTCPConfirmNeighborReachability(t *testing.T) { - tests := []struct { - name string - netProto tcpip.NetworkProtocolNumber - remoteAddr tcpip.Address - neighborAddr tcpip.Address - getEndpoints func(*testing.T, *stack.Stack, *stack.Stack, *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) - isHost1Listener bool - }{ - { - name: "IPv4 active connection through neighbor", - netProto: ipv4.ProtocolNumber, - remoteAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address, - neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { - var listenerWQ waiter.Queue - listenerWE, listenerCH := waiter.NewChannelEntry(nil) - listenerWQ.EventRegister(&listenerWE, waiter.EventIn) - listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) - if err != nil { - t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) - } - t.Cleanup(listenerEP.Close) - - var clientWQ waiter.Queue - clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) - clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) - if err != nil { - t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) - } - - return listenerEP, listenerCH, clientEP, clientCH - }, - }, - { - name: "IPv6 active connection through neighbor", - netProto: ipv6.ProtocolNumber, - remoteAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address, - neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { - var listenerWQ waiter.Queue - listenerWE, listenerCH := waiter.NewChannelEntry(nil) - listenerWQ.EventRegister(&listenerWE, waiter.EventIn) - listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) - if err != nil { - t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) - } - t.Cleanup(listenerEP.Close) - - var clientWQ waiter.Queue - clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) - clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) - if err != nil { - t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) - } - - return listenerEP, listenerCH, clientEP, clientCH - }, - }, - { - name: "IPv4 active connection to neighbor", - netProto: ipv4.ProtocolNumber, - remoteAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, - neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { - var listenerWQ waiter.Queue - listenerWE, listenerCH := waiter.NewChannelEntry(nil) - listenerWQ.EventRegister(&listenerWE, waiter.EventIn) - listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) - if err != nil { - t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) - } - t.Cleanup(listenerEP.Close) - - var clientWQ waiter.Queue - clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) - clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) - if err != nil { - t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) - } - - return listenerEP, listenerCH, clientEP, clientCH - }, - }, - { - name: "IPv6 active connection to neighbor", - netProto: ipv6.ProtocolNumber, - remoteAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, - neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { - var listenerWQ waiter.Queue - listenerWE, listenerCH := waiter.NewChannelEntry(nil) - listenerWQ.EventRegister(&listenerWE, waiter.EventIn) - listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) - if err != nil { - t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) - } - t.Cleanup(listenerEP.Close) - - var clientWQ waiter.Queue - clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) - clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) - if err != nil { - t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) - } - - return listenerEP, listenerCH, clientEP, clientCH - }, - }, - { - name: "IPv4 passive connection to neighbor", - netProto: ipv4.ProtocolNumber, - remoteAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address, - neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { - var listenerWQ waiter.Queue - listenerWE, listenerCH := waiter.NewChannelEntry(nil) - listenerWQ.EventRegister(&listenerWE, waiter.EventIn) - listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) - if err != nil { - t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) - } - t.Cleanup(listenerEP.Close) - - var clientWQ waiter.Queue - clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) - clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) - if err != nil { - t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) - } - - return listenerEP, listenerCH, clientEP, clientCH - }, - isHost1Listener: true, - }, - { - name: "IPv6 passive connection to neighbor", - netProto: ipv6.ProtocolNumber, - remoteAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address, - neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { - var listenerWQ waiter.Queue - listenerWE, listenerCH := waiter.NewChannelEntry(nil) - listenerWQ.EventRegister(&listenerWE, waiter.EventIn) - listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) - if err != nil { - t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) - } - t.Cleanup(listenerEP.Close) - - var clientWQ waiter.Queue - clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) - clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) - if err != nil { - t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) - } - - return listenerEP, listenerCH, clientEP, clientCH - }, - isHost1Listener: true, - }, - { - name: "IPv4 passive connection through neighbor", - netProto: ipv4.ProtocolNumber, - remoteAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address, - neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { - var listenerWQ waiter.Queue - listenerWE, listenerCH := waiter.NewChannelEntry(nil) - listenerWQ.EventRegister(&listenerWE, waiter.EventIn) - listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) - if err != nil { - t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) - } - t.Cleanup(listenerEP.Close) - - var clientWQ waiter.Queue - clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) - clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) - if err != nil { - t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) - } - - return listenerEP, listenerCH, clientEP, clientCH - }, - isHost1Listener: true, - }, - { - name: "IPv6 passive connection through neighbor", - netProto: ipv6.ProtocolNumber, - remoteAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address, - neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { - var listenerWQ waiter.Queue - listenerWE, listenerCH := waiter.NewChannelEntry(nil) - listenerWQ.EventRegister(&listenerWE, waiter.EventIn) - listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) - if err != nil { - t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) - } - t.Cleanup(listenerEP.Close) - - var clientWQ waiter.Queue - clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) - clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) - if err != nil { - t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) - } - - return listenerEP, listenerCH, clientEP, clientCH - }, - isHost1Listener: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - nudDisp := nudDispatcher{ - c: make(chan eventInfo, 3), - } - stackOpts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, - Clock: clock, - } - host1StackOpts := stackOpts - host1StackOpts.NUDDisp = &nudDisp - - host1Stack := stack.New(host1StackOpts) - routerStack := stack.New(stackOpts) - host2Stack := stack.New(stackOpts) - utils.SetupRoutedStacks(t, host1Stack, routerStack, host2Stack) - - // Add a reachable dynamic entry to our neighbor table for the remote. - { - ch := make(chan stack.LinkResolutionResult, 1) - err := host1Stack.GetLinkAddress(utils.Host1NICID, test.neighborAddr, "", test.netProto, func(r stack.LinkResolutionResult) { - ch <- r - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", utils.Host1NICID, test.neighborAddr, test.netProto, err, &tcpip.ErrWouldBlock{}) - } - if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: utils.LinkAddr2, Err: nil}, <-ch); diff != "" { - t.Fatalf("link resolution mismatch (-want +got):\n%s", diff) - } - } - if err := nudDisp.expectEvent(eventInfo{ - eventType: entryAdded, - nicID: utils.Host1NICID, - entry: stack.NeighborEntry{State: stack.Incomplete, Addr: test.neighborAddr}, - }); err != nil { - t.Fatalf("error waiting for initial NUD event: %s", err) - } - if err := nudDisp.expectEvent(eventInfo{ - eventType: entryChanged, - nicID: utils.Host1NICID, - entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, - }); err != nil { - t.Fatalf("error waiting for reachable NUD event: %s", err) - } - - // Wait for the remote's neighbor entry to be stale before creating a - // TCP connection from host1 to some remote. - nudConfigs, err := host1Stack.NUDConfigurations(utils.Host1NICID, test.netProto) - if err != nil { - t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", utils.Host1NICID, test.netProto, err) - } - // The maximum reachable time for a neighbor is some maximum random factor - // applied to the base reachable time. - // - // See NUDConfigurations.BaseReachableTime for more information. - maxReachableTime := time.Duration(float32(nudConfigs.BaseReachableTime) * nudConfigs.MaxRandomFactor) - clock.Advance(maxReachableTime) - if err := nudDisp.expectEvent(eventInfo{ - eventType: entryChanged, - nicID: utils.Host1NICID, - entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, - }); err != nil { - t.Fatalf("error waiting for stale NUD event: %s", err) - } - - listenerEP, listenerCH, clientEP, clientCH := test.getEndpoints(t, host1Stack, routerStack, host2Stack) - defer clientEP.Close() - listenerAddr := tcpip.FullAddress{Addr: test.remoteAddr, Port: 1234} - if err := listenerEP.Bind(listenerAddr); err != nil { - t.Fatalf("listenerEP.Bind(%#v): %s", listenerAddr, err) - } - if err := listenerEP.Listen(1); err != nil { - t.Fatalf("listenerEP.Listen(1): %s", err) - } - { - err := clientEP.Connect(listenerAddr) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("got clientEP.Connect(%#v) = %s, want = %s", listenerAddr, err, &tcpip.ErrConnectStarted{}) - } - } - - // Wait for the TCP handshake to complete then make sure the neighbor is - // reachable without entering the probe state as TCP should provide NUD - // with confirmation that the neighbor is reachable (indicated by a - // successful 3-way handshake). - <-clientCH - if err := nudDisp.expectEvent(eventInfo{ - eventType: entryChanged, - nicID: utils.Host1NICID, - entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, - }); err != nil { - t.Fatalf("error waiting for delay NUD event: %s", err) - } - <-listenerCH - if err := nudDisp.expectEvent(eventInfo{ - eventType: entryChanged, - nicID: utils.Host1NICID, - entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, - }); err != nil { - t.Fatalf("error waiting for reachable NUD event: %s", err) - } - - peerEP, peerWQ, err := listenerEP.Accept(nil) - if err != nil { - t.Fatalf("listenerEP.Accept(): %s", err) - } - defer peerEP.Close() - peerWE, peerCH := waiter.NewChannelEntry(nil) - peerWQ.EventRegister(&peerWE, waiter.ReadableEvents) - - // Wait for the neighbor to be stale again then send data to the remote. - // - // On successful transmission, the neighbor should become reachable - // without probing the neighbor as a TCP ACK would be received which is an - // indication of the neighbor being reachable. - clock.Advance(maxReachableTime) - if err := nudDisp.expectEvent(eventInfo{ - eventType: entryChanged, - nicID: utils.Host1NICID, - entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, - }); err != nil { - t.Fatalf("error waiting for stale NUD event: %s", err) - } - { - var r bytes.Reader - r.Reset([]byte{0}) - var wOpts tcpip.WriteOptions - if _, err := clientEP.Write(&r, wOpts); err != nil { - t.Errorf("clientEP.Write(_, %#v): %s", wOpts, err) - } - } - // Heads up, there is a race here. - // - // Incoming TCP segments are handled in - // tcp.(*endpoint).handleSegmentLocked: - // - // - tcp.(*endpoint).rcv.handleRcvdSegment puts the segment on the - // segment queue and notifies waiting readers (such as this channel) - // - // - tcp.(*endpoint).snd.handleRcvdSegment sends an ACK for the segment - // and notifies the NUD machinery that the peer is reachable - // - // Thus we must permit a delay between the readable signal and the - // expected NUD event. - // - // At the time of writing, this race is reliably hit with gotsan. - <-peerCH - for len(nudDisp.c) == 0 { - runtime.Gosched() - } - if err := nudDisp.expectEvent(eventInfo{ - eventType: entryChanged, - nicID: utils.Host1NICID, - entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, - }); err != nil { - t.Fatalf("error waiting for delay NUD event: %s", err) - } - if test.isHost1Listener { - // If host1 is not the client, host1 does not send any data so TCP - // has no way to know it is making forward progress. Because of this, - // TCP should not mark the route reachable and NUD should go through the - // probe state. - clock.Advance(nudConfigs.DelayFirstProbeTime) - if err := nudDisp.expectEvent(eventInfo{ - eventType: entryChanged, - nicID: utils.Host1NICID, - entry: stack.NeighborEntry{State: stack.Probe, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, - }); err != nil { - t.Fatalf("error waiting for probe NUD event: %s", err) - } - } - { - var r bytes.Reader - r.Reset([]byte{0}) - var wOpts tcpip.WriteOptions - if _, err := peerEP.Write(&r, wOpts); err != nil { - t.Errorf("peerEP.Write(_, %#v): %s", wOpts, err) - } - } - <-clientCH - if err := nudDisp.expectEvent(eventInfo{ - eventType: entryChanged, - nicID: utils.Host1NICID, - entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, - }); err != nil { - t.Fatalf("error waiting for reachable NUD event: %s", err) - } - }) - } -} - -func TestDAD(t *testing.T) { - dadConfigs := stack.DADConfigurations{ - DupAddrDetectTransmits: 1, - RetransmitTimer: time.Second, - } - - tests := []struct { - name string - netProto tcpip.NetworkProtocolNumber - dadNetProto tcpip.NetworkProtocolNumber - remoteAddr tcpip.Address - expectedResult stack.DADResult - }{ - { - name: "IPv4 own address", - netProto: ipv4.ProtocolNumber, - dadNetProto: arp.ProtocolNumber, - remoteAddr: utils.Ipv4Addr1.AddressWithPrefix.Address, - expectedResult: &stack.DADSucceeded{}, - }, - { - name: "IPv6 own address", - netProto: ipv6.ProtocolNumber, - dadNetProto: ipv6.ProtocolNumber, - remoteAddr: utils.Ipv6Addr1.AddressWithPrefix.Address, - expectedResult: &stack.DADSucceeded{}, - }, - { - name: "IPv4 duplicate address", - netProto: ipv4.ProtocolNumber, - dadNetProto: arp.ProtocolNumber, - remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, - expectedResult: &stack.DADDupAddrDetected{HolderLinkAddress: utils.LinkAddr2}, - }, - { - name: "IPv6 duplicate address", - netProto: ipv6.ProtocolNumber, - dadNetProto: ipv6.ProtocolNumber, - remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, - expectedResult: &stack.DADDupAddrDetected{HolderLinkAddress: utils.LinkAddr2}, - }, - { - name: "IPv4 no duplicate address", - netProto: ipv4.ProtocolNumber, - dadNetProto: arp.ProtocolNumber, - remoteAddr: utils.Ipv4Addr3.AddressWithPrefix.Address, - expectedResult: &stack.DADSucceeded{}, - }, - { - name: "IPv6 no duplicate address", - netProto: ipv6.ProtocolNumber, - dadNetProto: ipv6.ProtocolNumber, - remoteAddr: utils.Ipv6Addr3.AddressWithPrefix.Address, - expectedResult: &stack.DADSucceeded{}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - stackOpts := stack.Options{ - Clock: clock, - NetworkProtocols: []stack.NetworkProtocolFactory{ - arp.NewProtocol, - ipv4.NewProtocol, - ipv6.NewProtocol, - }, - } - - host1Stack, _ := setupStack(t, stackOpts, utils.Host1NICID, utils.Host2NICID) - - // DAD should be disabled by default. - if res, err := host1Stack.CheckDuplicateAddress(utils.Host1NICID, test.netProto, test.remoteAddr, func(r stack.DADResult) { - t.Errorf("unexpectedly called DAD completion handler when DAD was supposed to be disabled") - }); err != nil { - t.Fatalf("host1Stack.CheckDuplicateAddress(%d, %d, %s, _): %s", utils.Host1NICID, test.netProto, test.remoteAddr, err) - } else if res != stack.DADDisabled { - t.Errorf("got host1Stack.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", utils.Host1NICID, test.netProto, test.remoteAddr, res, stack.DADDisabled) - } - - // Enable DAD then attempt to check if an address is duplicated. - netEP, err := host1Stack.GetNetworkEndpoint(utils.Host1NICID, test.dadNetProto) - if err != nil { - t.Fatalf("host1Stack.GetNetworkEndpoint(%d, %d): %s", utils.Host1NICID, test.dadNetProto, err) - } - dad, ok := netEP.(stack.DuplicateAddressDetector) - if !ok { - t.Fatalf("expected %T to implement stack.DuplicateAddressDetector", netEP) - } - dad.SetDADConfigurations(dadConfigs) - ch := make(chan stack.DADResult, 3) - if res, err := host1Stack.CheckDuplicateAddress(utils.Host1NICID, test.netProto, test.remoteAddr, func(r stack.DADResult) { - ch <- r - }); err != nil { - t.Fatalf("host1Stack.CheckDuplicateAddress(%d, %d, %s, _): %s", utils.Host1NICID, test.netProto, test.remoteAddr, err) - } else if res != stack.DADStarting { - t.Errorf("got host1Stack.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", utils.Host1NICID, test.netProto, test.remoteAddr, res, stack.DADStarting) - } - - expectResults := 1 - if _, ok := test.expectedResult.(*stack.DADSucceeded); ok { - const delta = time.Nanosecond - clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer - delta) - select { - case r := <-ch: - t.Fatalf("unexpectedly got DAD result before the DAD timeout; r = %#v", r) - default: - } - - // If we expect the resolve to succeed try requesting DAD again on the - // same address. The handler for the new request should be called once - // the original DAD request completes. - expectResults = 2 - if res, err := host1Stack.CheckDuplicateAddress(utils.Host1NICID, test.netProto, test.remoteAddr, func(r stack.DADResult) { - ch <- r - }); err != nil { - t.Fatalf("host1Stack.CheckDuplicateAddress(%d, %d, %s, _): %s", utils.Host1NICID, test.netProto, test.remoteAddr, err) - } else if res != stack.DADAlreadyRunning { - t.Errorf("got host1Stack.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", utils.Host1NICID, test.netProto, test.remoteAddr, res, stack.DADAlreadyRunning) - } - - clock.Advance(delta) - } - - for i := 0; i < expectResults; i++ { - if diff := cmp.Diff(test.expectedResult, <-ch); diff != "" { - t.Errorf("(i=%d) DAD result mismatch (-want +got):\n%s", i, diff) - } - } - - // Should have no more results. - select { - case r := <-ch: - t.Errorf("unexpectedly got an extra DAD result; r = %#v", r) - default: - } - }) - } -} diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go deleted file mode 100644 index f33223e79..000000000 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ /dev/null @@ -1,782 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package loopback_test - -import ( - "bytes" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/tests/utils" - "gvisor.dev/gvisor/pkg/tcpip/testutil" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil) - -type ndpDispatcher struct{} - -func (*ndpDispatcher) OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) { -} - -func (*ndpDispatcher) OnOffLinkRouteUpdated(tcpip.NICID, tcpip.Subnet, tcpip.Address, header.NDPRoutePreference) { -} - -func (*ndpDispatcher) OnOffLinkRouteInvalidated(tcpip.NICID, tcpip.Subnet, tcpip.Address) {} - -func (*ndpDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) { -} - -func (*ndpDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) {} - -func (*ndpDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) { -} - -func (*ndpDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) {} - -func (*ndpDispatcher) OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) {} - -func (*ndpDispatcher) OnRecursiveDNSServerOption(tcpip.NICID, []tcpip.Address, time.Duration) {} - -func (*ndpDispatcher) OnDNSSearchListOption(tcpip.NICID, []string, time.Duration) {} - -func (*ndpDispatcher) OnDHCPv6Configuration(tcpip.NICID, ipv6.DHCPv6ConfigurationFromNDPRA) {} - -// TestInitialLoopbackAddresses tests that the loopback interface does not -// auto-generate a link-local address when it is brought up. -func TestInitialLoopbackAddresses(t *testing.T) { - const nicID = 1 - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPDisp: &ndpDispatcher{}, - AutoGenLinkLocal: true, - OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(nicID tcpip.NICID, nicName string) string { - t.Fatalf("should not attempt to get name for NIC with ID = %d; nicName = %s", nicID, nicName) - return "" - }, - }, - })}, - }) - - if err := s.CreateNIC(nicID, loopback.New()); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - - nicsInfo := s.NICInfo() - if nicInfo, ok := nicsInfo[nicID]; !ok { - t.Fatalf("did not find NIC with ID = %d in s.NICInfo() = %#v", nicID, nicsInfo) - } else if got := len(nicInfo.ProtocolAddresses); got != 0 { - t.Fatalf("got len(nicInfo.ProtocolAddresses) = %d, want = 0; nicInfo.ProtocolAddresses = %#v", got, nicInfo.ProtocolAddresses) - } -} - -// TestLoopbackAcceptAllInSubnetUDP tests that a loopback interface considers -// itself bound to all addresses in the subnet of an assigned address and UDP -// traffic is sent/received correctly. -func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) { - const ( - nicID = 1 - localPort = 80 - ) - - data := []byte{1, 2, 3, 4} - - ipv4ProtocolAddress := tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: utils.Ipv4Addr, - } - ipv4Bytes := []byte(ipv4ProtocolAddress.AddressWithPrefix.Address) - ipv4Bytes[len(ipv4Bytes)-1]++ - otherIPv4Address := tcpip.Address(ipv4Bytes) - - ipv6ProtocolAddress := tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: utils.Ipv6Addr, - } - ipv6Bytes := []byte(utils.Ipv6Addr.Address) - ipv6Bytes[len(ipv6Bytes)-1]++ - otherIPv6Address := tcpip.Address(ipv6Bytes) - - tests := []struct { - name string - addAddress tcpip.ProtocolAddress - bindAddr tcpip.Address - dstAddr tcpip.Address - expectRx bool - }{ - { - name: "IPv4 bind to wildcard and send to assigned address", - addAddress: ipv4ProtocolAddress, - dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address, - expectRx: true, - }, - { - name: "IPv4 bind to wildcard and send to other subnet-local address", - addAddress: ipv4ProtocolAddress, - dstAddr: otherIPv4Address, - expectRx: true, - }, - { - name: "IPv4 bind to wildcard send to other address", - addAddress: ipv4ProtocolAddress, - dstAddr: utils.RemoteIPv4Addr, - expectRx: false, - }, - { - name: "IPv4 bind to other subnet-local address and send to assigned address", - addAddress: ipv4ProtocolAddress, - bindAddr: otherIPv4Address, - dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address, - expectRx: false, - }, - { - name: "IPv4 bind and send to other subnet-local address", - addAddress: ipv4ProtocolAddress, - bindAddr: otherIPv4Address, - dstAddr: otherIPv4Address, - expectRx: true, - }, - { - name: "IPv4 bind to assigned address and send to other subnet-local address", - addAddress: ipv4ProtocolAddress, - bindAddr: ipv4ProtocolAddress.AddressWithPrefix.Address, - dstAddr: otherIPv4Address, - expectRx: false, - }, - - { - name: "IPv6 bind and send to assigned address", - addAddress: ipv6ProtocolAddress, - bindAddr: utils.Ipv6Addr.Address, - dstAddr: utils.Ipv6Addr.Address, - expectRx: true, - }, - { - name: "IPv6 bind to wildcard and send to other subnet-local address", - addAddress: ipv6ProtocolAddress, - dstAddr: otherIPv6Address, - expectRx: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - if err := s.CreateNIC(nicID, loopback.New()); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - if err := s.AddProtocolAddress(nicID, test.addAddress, stack.AddressProperties{}); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, test.addAddress, err) - } - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: nicID, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: nicID, - }, - }) - - var wq waiter.Queue - rep, err := s.NewEndpoint(udp.ProtocolNumber, test.addAddress.Protocol, &wq) - if err != nil { - t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err) - } - defer rep.Close() - - bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: localPort} - if err := rep.Bind(bindAddr); err != nil { - t.Fatalf("rep.Bind(%+v): %s", bindAddr, err) - } - - sep, err := s.NewEndpoint(udp.ProtocolNumber, test.addAddress.Protocol, &wq) - if err != nil { - t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err) - } - defer sep.Close() - - wopts := tcpip.WriteOptions{ - To: &tcpip.FullAddress{ - Addr: test.dstAddr, - Port: localPort, - }, - } - var r bytes.Reader - r.Reset(data) - n, err := sep.Write(&r, wopts) - if err != nil { - t.Fatalf("sep.Write(_, _): %s", err) - } - if want := int64(len(data)); n != want { - t.Fatalf("got sep.Write(_, _) = (%d, nil), want = (%d, nil)", n, want) - } - - var buf bytes.Buffer - opts := tcpip.ReadOptions{NeedRemoteAddr: true} - if res, err := rep.Read(&buf, opts); test.expectRx { - if err != nil { - t.Fatalf("rep.Read(_, %#v): %s", opts, err) - } - if diff := cmp.Diff(tcpip.ReadResult{ - Count: buf.Len(), - Total: buf.Len(), - RemoteAddr: tcpip.FullAddress{ - Addr: test.addAddress.AddressWithPrefix.Address, - }, - }, res, - checker.IgnoreCmpPath("ControlMessages", "RemoteAddr.NIC", "RemoteAddr.Port"), - ); diff != "" { - t.Errorf("rep.Read: unexpected result (-want +got):\n%s", diff) - } - if diff := cmp.Diff(data, buf.Bytes()); diff != "" { - t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) - } - } else if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got rep.Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), &tcpip.ErrWouldBlock{}) - } - }) - } -} - -// TestLoopbackSubnetLifetimeBoundToAddr tests that the lifetime of an address -// in a loopback interface's associated subnet is bound to the permanently bound -// address. -func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) { - const nicID = 1 - - protoAddr := tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: utils.Ipv4Addr, - } - addrBytes := []byte(utils.Ipv4Addr.Address) - addrBytes[len(addrBytes)-1]++ - otherAddr := tcpip.Address(addrBytes) - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - }) - if err := s.CreateNIC(nicID, loopback.New()); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err) - } - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: nicID, - }, - }) - - r, err := s.FindRoute(nicID, otherAddr, utils.RemoteIPv4Addr, ipv4.ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, otherAddr, utils.RemoteIPv4Addr, ipv4.ProtocolNumber, err) - } - defer r.Release() - - params := stack.NetworkHeaderParams{ - Protocol: 111, - TTL: 64, - TOS: stack.DefaultTOS, - } - data := buffer.View([]byte{1, 2, 3, 4}) - if err := r.WritePacket(params, stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()), - Data: data.ToVectorisedView(), - })); err != nil { - t.Fatalf("r.WritePacket(%#v, _): %s", params, err) - } - - // Removing the address should make the endpoint invalid. - if err := s.RemoveAddress(nicID, protoAddr.AddressWithPrefix.Address); err != nil { - t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, protoAddr.AddressWithPrefix.Address, err) - } - { - err := r.WritePacket(params, stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()), - Data: data.ToVectorisedView(), - })) - if _, ok := err.(*tcpip.ErrInvalidEndpointState); !ok { - t.Fatalf("got r.WritePacket(%#v, _) = %s, want = %s", params, err, &tcpip.ErrInvalidEndpointState{}) - } - } -} - -// TestLoopbackAcceptAllInSubnetTCP tests that a loopback interface considers -// itself bound to all addresses in the subnet of an assigned address and TCP -// traffic is sent/received correctly. -func TestLoopbackAcceptAllInSubnetTCP(t *testing.T) { - const ( - nicID = 1 - localPort = 80 - ) - - ipv4ProtocolAddress := tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: utils.Ipv4Addr, - } - ipv4ProtocolAddress.AddressWithPrefix.PrefixLen = 8 - ipv4Bytes := []byte(ipv4ProtocolAddress.AddressWithPrefix.Address) - ipv4Bytes[len(ipv4Bytes)-1]++ - otherIPv4Address := tcpip.Address(ipv4Bytes) - - ipv6ProtocolAddress := tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: utils.Ipv6Addr, - } - ipv6Bytes := []byte(utils.Ipv6Addr.Address) - ipv6Bytes[len(ipv6Bytes)-1]++ - otherIPv6Address := tcpip.Address(ipv6Bytes) - - tests := []struct { - name string - addAddress tcpip.ProtocolAddress - bindAddr tcpip.Address - dstAddr tcpip.Address - expectAccept bool - }{ - { - name: "IPv4 bind to wildcard and send to assigned address", - addAddress: ipv4ProtocolAddress, - dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address, - expectAccept: true, - }, - { - name: "IPv4 bind to wildcard and send to other subnet-local address", - addAddress: ipv4ProtocolAddress, - dstAddr: otherIPv4Address, - expectAccept: true, - }, - { - name: "IPv4 bind to wildcard send to other address", - addAddress: ipv4ProtocolAddress, - dstAddr: utils.RemoteIPv4Addr, - expectAccept: false, - }, - { - name: "IPv4 bind to other subnet-local address and send to assigned address", - addAddress: ipv4ProtocolAddress, - bindAddr: otherIPv4Address, - dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address, - expectAccept: false, - }, - { - name: "IPv4 bind and send to other subnet-local address", - addAddress: ipv4ProtocolAddress, - bindAddr: otherIPv4Address, - dstAddr: otherIPv4Address, - expectAccept: true, - }, - { - name: "IPv4 bind to assigned address and send to other subnet-local address", - addAddress: ipv4ProtocolAddress, - bindAddr: ipv4ProtocolAddress.AddressWithPrefix.Address, - dstAddr: otherIPv4Address, - expectAccept: false, - }, - - { - name: "IPv6 bind and send to assigned address", - addAddress: ipv6ProtocolAddress, - bindAddr: utils.Ipv6Addr.Address, - dstAddr: utils.Ipv6Addr.Address, - expectAccept: true, - }, - { - name: "IPv6 bind to wildcard and send to other subnet-local address", - addAddress: ipv6ProtocolAddress, - dstAddr: otherIPv6Address, - expectAccept: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, - }) - if err := s.CreateNIC(nicID, loopback.New()); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - if err := s.AddProtocolAddress(nicID, test.addAddress, stack.AddressProperties{}); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, test.addAddress, err) - } - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: nicID, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: nicID, - }, - }) - - var wq waiter.Queue - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - listeningEndpoint, err := s.NewEndpoint(tcp.ProtocolNumber, test.addAddress.Protocol, &wq) - if err != nil { - t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err) - } - defer listeningEndpoint.Close() - - bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: localPort} - if err := listeningEndpoint.Bind(bindAddr); err != nil { - t.Fatalf("listeningEndpoint.Bind(%#v): %s", bindAddr, err) - } - - if err := listeningEndpoint.Listen(1); err != nil { - t.Fatalf("listeningEndpoint.Listen(1): %s", err) - } - - connectingEndpoint, err := s.NewEndpoint(tcp.ProtocolNumber, test.addAddress.Protocol, &wq) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err) - } - defer connectingEndpoint.Close() - - connectAddr := tcpip.FullAddress{ - Addr: test.dstAddr, - Port: localPort, - } - { - err := connectingEndpoint.Connect(connectAddr) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("connectingEndpoint.Connect(%#v): %s", connectAddr, err) - } - } - - if !test.expectAccept { - _, _, err := listeningEndpoint.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got listeningEndpoint.Accept(nil) = %s, want = %s", err, &tcpip.ErrWouldBlock{}) - } - return - } - - // Wait for the listening endpoint to be "readable". That is, wait for a - // new connection. - <-ch - var addr tcpip.FullAddress - if _, _, err := listeningEndpoint.Accept(&addr); err != nil { - t.Fatalf("listeningEndpoint.Accept(nil): %s", err) - } - if addr.Addr != test.addAddress.AddressWithPrefix.Address { - t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.addAddress.AddressWithPrefix.Address) - } - }) - } -} - -func TestExternalLoopbackTraffic(t *testing.T) { - const ( - nicID1 = 1 - nicID2 = 2 - - numPackets = 1 - ttl = 64 - ) - ipv4Loopback := testutil.MustParse4("127.0.0.1") - - loopbackSourcedICMPv4 := func(e *channel.Endpoint) { - utils.RxICMPv4EchoRequest(e, ipv4Loopback, utils.Ipv4Addr.Address, ttl) - } - - loopbackSourcedICMPv6 := func(e *channel.Endpoint) { - utils.RxICMPv6EchoRequest(e, header.IPv6Loopback, utils.Ipv6Addr.Address, ttl) - } - - loopbackDestinedICMPv4 := func(e *channel.Endpoint) { - utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, ipv4Loopback, ttl) - } - - loopbackDestinedICMPv6 := func(e *channel.Endpoint) { - utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, header.IPv6Loopback, ttl) - } - - invalidSrcAddrStat := func(s tcpip.IPStats) *tcpip.StatCounter { - return s.InvalidSourceAddressesReceived - } - - invalidDestAddrStat := func(s tcpip.IPStats) *tcpip.StatCounter { - return s.InvalidDestinationAddressesReceived - } - - tests := []struct { - name string - allowExternalLoopback bool - forwarding bool - rxICMP func(*channel.Endpoint) - invalidAddressStat func(tcpip.IPStats) *tcpip.StatCounter - shouldAccept bool - }{ - { - name: "IPv4 external loopback sourced traffic without forwarding and drop external loopback disabled", - allowExternalLoopback: true, - forwarding: false, - rxICMP: loopbackSourcedICMPv4, - invalidAddressStat: invalidSrcAddrStat, - shouldAccept: true, - }, - { - name: "IPv4 external loopback sourced traffic without forwarding and drop external loopback enabled", - allowExternalLoopback: false, - forwarding: false, - rxICMP: loopbackSourcedICMPv4, - invalidAddressStat: invalidSrcAddrStat, - shouldAccept: false, - }, - { - name: "IPv4 external loopback sourced traffic with forwarding and drop external loopback disabled", - allowExternalLoopback: true, - forwarding: true, - rxICMP: loopbackSourcedICMPv4, - invalidAddressStat: invalidSrcAddrStat, - shouldAccept: true, - }, - { - name: "IPv4 external loopback sourced traffic with forwarding and drop external loopback enabled", - allowExternalLoopback: false, - forwarding: true, - rxICMP: loopbackSourcedICMPv4, - invalidAddressStat: invalidSrcAddrStat, - shouldAccept: false, - }, - { - name: "IPv4 external loopback destined traffic without forwarding and drop external loopback disabled", - allowExternalLoopback: true, - forwarding: false, - rxICMP: loopbackDestinedICMPv4, - invalidAddressStat: invalidDestAddrStat, - shouldAccept: false, - }, - { - name: "IPv4 external loopback destined traffic without forwarding and drop external loopback enabled", - allowExternalLoopback: false, - forwarding: false, - rxICMP: loopbackDestinedICMPv4, - invalidAddressStat: invalidDestAddrStat, - shouldAccept: false, - }, - { - name: "IPv4 external loopback destined traffic with forwarding and drop external loopback disabled", - allowExternalLoopback: true, - forwarding: true, - rxICMP: loopbackDestinedICMPv4, - invalidAddressStat: invalidDestAddrStat, - shouldAccept: true, - }, - { - name: "IPv4 external loopback destined traffic with forwarding and drop external loopback enabled", - allowExternalLoopback: false, - forwarding: true, - rxICMP: loopbackDestinedICMPv4, - invalidAddressStat: invalidDestAddrStat, - shouldAccept: false, - }, - - { - name: "IPv6 external loopback sourced traffic without forwarding and drop external loopback disabled", - allowExternalLoopback: true, - forwarding: false, - rxICMP: loopbackSourcedICMPv6, - invalidAddressStat: invalidSrcAddrStat, - shouldAccept: true, - }, - { - name: "IPv6 external loopback sourced traffic without forwarding and drop external loopback enabled", - allowExternalLoopback: false, - forwarding: false, - rxICMP: loopbackSourcedICMPv6, - invalidAddressStat: invalidSrcAddrStat, - shouldAccept: false, - }, - { - name: "IPv6 external loopback sourced traffic with forwarding and drop external loopback disabled", - allowExternalLoopback: true, - forwarding: true, - rxICMP: loopbackSourcedICMPv6, - invalidAddressStat: invalidSrcAddrStat, - shouldAccept: true, - }, - { - name: "IPv6 external loopback sourced traffic with forwarding and drop external loopback enabled", - allowExternalLoopback: false, - forwarding: true, - rxICMP: loopbackSourcedICMPv6, - invalidAddressStat: invalidSrcAddrStat, - shouldAccept: false, - }, - { - name: "IPv6 external loopback destined traffic without forwarding and drop external loopback disabled", - allowExternalLoopback: true, - forwarding: false, - rxICMP: loopbackDestinedICMPv6, - invalidAddressStat: invalidDestAddrStat, - shouldAccept: false, - }, - { - name: "IPv6 external loopback destined traffic without forwarding and drop external loopback enabled", - allowExternalLoopback: false, - forwarding: false, - rxICMP: loopbackDestinedICMPv6, - invalidAddressStat: invalidDestAddrStat, - shouldAccept: false, - }, - { - name: "IPv6 external loopback destined traffic with forwarding and drop external loopback disabled", - allowExternalLoopback: true, - forwarding: true, - rxICMP: loopbackDestinedICMPv6, - invalidAddressStat: invalidDestAddrStat, - shouldAccept: true, - }, - { - name: "IPv6 external loopback destined traffic with forwarding and drop external loopback enabled", - allowExternalLoopback: false, - forwarding: true, - rxICMP: loopbackDestinedICMPv6, - invalidAddressStat: invalidDestAddrStat, - shouldAccept: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ - ipv4.NewProtocolWithOptions(ipv4.Options{ - AllowExternalLoopbackTraffic: test.allowExternalLoopback, - }), - ipv6.NewProtocolWithOptions(ipv6.Options{ - AllowExternalLoopbackTraffic: test.allowExternalLoopback, - }), - }, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6}, - }) - e := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNIC(nicID1, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) - } - v4Addr := tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: utils.Ipv4Addr, - } - if err := s.AddProtocolAddress(nicID1, v4Addr, stack.AddressProperties{}); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, v4Addr, err) - } - v6Addr := tcpip.ProtocolAddress{ - Protocol: ipv6.ProtocolNumber, - AddressWithPrefix: utils.Ipv6Addr, - } - if err := s.AddProtocolAddress(nicID1, v6Addr, stack.AddressProperties{}); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, v6Addr, err) - } - - if err := s.CreateNIC(nicID2, loopback.New()); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) - } - protocolAddrV4 := tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: ipv4Loopback, - PrefixLen: 8, - }, - } - if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err) - } - protocolAddrV6 := tcpip.ProtocolAddress{ - Protocol: ipv6.ProtocolNumber, - AddressWithPrefix: header.IPv6Loopback.WithPrefix(), - } - if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err) - } - - if test.forwarding { - if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { - t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err) - } - if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) - } - } - - s.SetRouteTable([]tcpip.Route{ - tcpip.Route{ - Destination: header.IPv4EmptySubnet, - NIC: nicID1, - }, - tcpip.Route{ - Destination: header.IPv6EmptySubnet, - NIC: nicID1, - }, - tcpip.Route{ - Destination: ipv4Loopback.WithPrefix().Subnet(), - NIC: nicID2, - }, - tcpip.Route{ - Destination: header.IPv6Loopback.WithPrefix().Subnet(), - NIC: nicID2, - }, - }) - - stats := s.Stats().IP - invalidAddressStat := test.invalidAddressStat(stats) - deliveredPacketsStat := stats.PacketsDelivered - if got := invalidAddressStat.Value(); got != 0 { - t.Fatalf("got invalidAddressStat.Value() = %d, want = 0", got) - } - if got := deliveredPacketsStat.Value(); got != 0 { - t.Fatalf("got deliveredPacketsStat.Value() = %d, want = 0", got) - } - test.rxICMP(e) - var expectedInvalidPackets uint64 - if !test.shouldAccept { - expectedInvalidPackets = numPackets - } - if got := invalidAddressStat.Value(); got != expectedInvalidPackets { - t.Fatalf("got invalidAddressStat.Value() = %d, want = %d", got, expectedInvalidPackets) - } - if got, want := deliveredPacketsStat.Value(), numPackets-expectedInvalidPackets; got != want { - t.Fatalf("got deliveredPacketsStat.Value() = %d, want = %d", got, want) - } - }) - } -} diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go deleted file mode 100644 index 7753e7d6e..000000000 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ /dev/null @@ -1,723 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package multicast_broadcast_test - -import ( - "bytes" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/tests/utils" - "gvisor.dev/gvisor/pkg/tcpip/testutil" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -const ( - defaultMTU = 1280 - ttl = 255 -) - -// TestPingMulticastBroadcast tests that responding to an Echo Request destined -// to a multicast or broadcast address uses a unicast source address for the -// reply. -func TestPingMulticastBroadcast(t *testing.T) { - const ( - nicID = 1 - ttl = 64 - ) - - tests := []struct { - name string - protoNum tcpip.NetworkProtocolNumber - rxICMP func(*channel.Endpoint, tcpip.Address, tcpip.Address, uint8) - srcAddr tcpip.Address - dstAddr tcpip.Address - expectedSrc tcpip.Address - }{ - { - name: "IPv4 unicast", - protoNum: header.IPv4ProtocolNumber, - dstAddr: utils.Ipv4Addr.Address, - srcAddr: utils.RemoteIPv4Addr, - rxICMP: utils.RxICMPv4EchoRequest, - expectedSrc: utils.Ipv4Addr.Address, - }, - { - name: "IPv4 directed broadcast", - protoNum: header.IPv4ProtocolNumber, - rxICMP: utils.RxICMPv4EchoRequest, - srcAddr: utils.RemoteIPv4Addr, - dstAddr: utils.Ipv4SubnetBcast, - expectedSrc: utils.Ipv4Addr.Address, - }, - { - name: "IPv4 broadcast", - protoNum: header.IPv4ProtocolNumber, - rxICMP: utils.RxICMPv4EchoRequest, - srcAddr: utils.RemoteIPv4Addr, - dstAddr: header.IPv4Broadcast, - expectedSrc: utils.Ipv4Addr.Address, - }, - { - name: "IPv4 all-systems multicast", - protoNum: header.IPv4ProtocolNumber, - rxICMP: utils.RxICMPv4EchoRequest, - srcAddr: utils.RemoteIPv4Addr, - dstAddr: header.IPv4AllSystems, - expectedSrc: utils.Ipv4Addr.Address, - }, - { - name: "IPv6 unicast", - protoNum: header.IPv6ProtocolNumber, - rxICMP: utils.RxICMPv6EchoRequest, - srcAddr: utils.RemoteIPv6Addr, - dstAddr: utils.Ipv6Addr.Address, - expectedSrc: utils.Ipv6Addr.Address, - }, - { - name: "IPv6 all-nodes multicast", - protoNum: header.IPv6ProtocolNumber, - rxICMP: utils.RxICMPv6EchoRequest, - srcAddr: utils.RemoteIPv6Addr, - dstAddr: header.IPv6AllNodesMulticastAddress, - expectedSrc: utils.Ipv6Addr.Address, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6}, - }) - // We only expect a single packet in response to our ICMP Echo Request. - e := channel.New(1, defaultMTU, "") - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: utils.Ipv4Addr} - if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr, stack.AddressProperties{}); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtoAddr, err) - } - ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: utils.Ipv6Addr} - if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr, stack.AddressProperties{}); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv6ProtoAddr, err) - } - - // Default routes for IPv4 and IPv6 so ICMP can find a route to the remote - // node when attempting to send the ICMP Echo Reply. - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv6EmptySubnet, - NIC: nicID, - }, - { - Destination: header.IPv4EmptySubnet, - NIC: nicID, - }, - }) - - test.rxICMP(e, test.srcAddr, test.dstAddr, ttl) - pkt, ok := e.Read() - if !ok { - t.Fatal("expected ICMP response") - } - - if pkt.Route.LocalAddress != test.expectedSrc { - t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.expectedSrc) - } - // The destination of the response packet should be the source of the - // original packet. - if pkt.Route.RemoteAddress != test.srcAddr { - t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.srcAddr) - } - - src, dst := s.NetworkProtocolInstance(test.protoNum).ParseAddresses(stack.PayloadSince(pkt.Pkt.NetworkHeader())) - if src != test.expectedSrc { - t.Errorf("got pkt source = %s, want = %s", src, test.expectedSrc) - } - // The destination of the response packet should be the source of the - // original packet. - if dst != test.srcAddr { - t.Errorf("got pkt destination = %s, want = %s", dst, test.srcAddr) - } - }) - } - -} - -func rxIPv4UDP(e *channel.Endpoint, src, dst tcpip.Address, data []byte) { - payloadLen := header.UDPMinimumSize + len(data) - totalLen := header.IPv4MinimumSize + payloadLen - hdr := buffer.NewPrependable(totalLen) - u := header.UDP(hdr.Prepend(payloadLen)) - u.Encode(&header.UDPFields{ - SrcPort: utils.RemotePort, - DstPort: utils.LocalPort, - Length: uint16(payloadLen), - }) - copy(u.Payload(), data) - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(payloadLen)) - sum = header.Checksum(data, sum) - u.SetChecksum(^u.CalculateChecksum(sum)) - - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(totalLen), - Protocol: uint8(udp.ProtocolNumber), - TTL: ttl, - SrcAddr: src, - DstAddr: dst, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) -} - -func rxIPv6UDP(e *channel.Endpoint, src, dst tcpip.Address, data []byte) { - payloadLen := header.UDPMinimumSize + len(data) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadLen) - u := header.UDP(hdr.Prepend(payloadLen)) - u.Encode(&header.UDPFields{ - SrcPort: utils.RemotePort, - DstPort: utils.LocalPort, - Length: uint16(payloadLen), - }) - copy(u.Payload(), data) - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(payloadLen)) - sum = header.Checksum(data, sum) - u.SetChecksum(^u.CalculateChecksum(sum)) - - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLen), - TransportProtocol: udp.ProtocolNumber, - HopLimit: ttl, - SrcAddr: src, - DstAddr: dst, - }) - - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) -} - -// TestIncomingMulticastAndBroadcast tests receiving a packet destined to some -// multicast or broadcast address. -func TestIncomingMulticastAndBroadcast(t *testing.T) { - const nicID = 1 - - data := []byte{1, 2, 3, 4} - - tests := []struct { - name string - proto tcpip.NetworkProtocolNumber - remoteAddr tcpip.Address - localAddr tcpip.AddressWithPrefix - rxUDP func(*channel.Endpoint, tcpip.Address, tcpip.Address, []byte) - bindAddr tcpip.Address - dstAddr tcpip.Address - expectRx bool - }{ - { - name: "IPv4 unicast binding to unicast", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - bindAddr: utils.Ipv4Addr.Address, - dstAddr: utils.Ipv4Addr.Address, - expectRx: true, - }, - { - name: "IPv4 unicast binding to broadcast", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - bindAddr: header.IPv4Broadcast, - dstAddr: utils.Ipv4Addr.Address, - expectRx: false, - }, - { - name: "IPv4 unicast binding to wildcard", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - dstAddr: utils.Ipv4Addr.Address, - expectRx: true, - }, - - { - name: "IPv4 directed broadcast binding to subnet broadcast", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - bindAddr: utils.Ipv4SubnetBcast, - dstAddr: utils.Ipv4SubnetBcast, - expectRx: true, - }, - { - name: "IPv4 directed broadcast binding to broadcast", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - bindAddr: header.IPv4Broadcast, - dstAddr: utils.Ipv4SubnetBcast, - expectRx: false, - }, - { - name: "IPv4 directed broadcast binding to wildcard", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - dstAddr: utils.Ipv4SubnetBcast, - expectRx: true, - }, - - { - name: "IPv4 broadcast binding to broadcast", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - bindAddr: header.IPv4Broadcast, - dstAddr: header.IPv4Broadcast, - expectRx: true, - }, - { - name: "IPv4 broadcast binding to subnet broadcast", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - bindAddr: utils.Ipv4SubnetBcast, - dstAddr: header.IPv4Broadcast, - expectRx: false, - }, - { - name: "IPv4 broadcast binding to wildcard", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - dstAddr: utils.Ipv4SubnetBcast, - expectRx: true, - }, - - { - name: "IPv4 all-systems multicast binding to all-systems multicast", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - bindAddr: header.IPv4AllSystems, - dstAddr: header.IPv4AllSystems, - expectRx: true, - }, - { - name: "IPv4 all-systems multicast binding to wildcard", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - dstAddr: header.IPv4AllSystems, - expectRx: true, - }, - { - name: "IPv4 all-systems multicast binding to unicast", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - bindAddr: utils.Ipv4Addr.Address, - dstAddr: header.IPv4AllSystems, - expectRx: false, - }, - - // IPv6 has no notion of a broadcast. - { - name: "IPv6 unicast binding to wildcard", - dstAddr: utils.Ipv6Addr.Address, - proto: header.IPv6ProtocolNumber, - remoteAddr: utils.RemoteIPv6Addr, - localAddr: utils.Ipv6Addr, - rxUDP: rxIPv6UDP, - expectRx: true, - }, - { - name: "IPv6 broadcast-like address binding to wildcard", - dstAddr: utils.Ipv6SubnetBcast, - proto: header.IPv6ProtocolNumber, - remoteAddr: utils.RemoteIPv6Addr, - localAddr: utils.Ipv6Addr, - rxUDP: rxIPv6UDP, - expectRx: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - e := channel.New(0, defaultMTU, "") - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr} - if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err) - } - - var wq waiter.Queue - ep, err := s.NewEndpoint(udp.ProtocolNumber, test.proto, &wq) - if err != nil { - t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.proto, err) - } - defer ep.Close() - - bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: utils.LocalPort} - if err := ep.Bind(bindAddr); err != nil { - t.Fatalf("ep.Bind(%#v): %s", bindAddr, err) - } - - test.rxUDP(e, test.remoteAddr, test.dstAddr, data) - var buf bytes.Buffer - var opts tcpip.ReadOptions - if res, err := ep.Read(&buf, opts); test.expectRx { - if err != nil { - t.Fatalf("ep.Read(_, %#v): %s", opts, err) - } - if diff := cmp.Diff(tcpip.ReadResult{ - Count: buf.Len(), - Total: buf.Len(), - }, res, checker.IgnoreCmpPath("ControlMessages")); diff != "" { - t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) - } - if diff := cmp.Diff(data, buf.Bytes()); diff != "" { - t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) - } - } else if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), &tcpip.ErrWouldBlock{}) - } - }) - } -} - -// TestReuseAddrAndBroadcast makes sure broadcast packets are received by all -// interested endpoints. -func TestReuseAddrAndBroadcast(t *testing.T) { - const ( - nicID = 1 - localPort = 9000 - ) - loopbackBroadcast := testutil.MustParse4("127.255.255.255") - - tests := []struct { - name string - broadcastAddr tcpip.Address - }{ - { - name: "Subnet directed broadcast", - broadcastAddr: loopbackBroadcast, - }, - { - name: "IPv4 broadcast", - broadcastAddr: header.IPv4Broadcast, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - if err := s.CreateNIC(nicID, loopback.New()); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - protoAddr := tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: "\x7f\x00\x00\x01", - PrefixLen: 8, - }, - } - if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err) - } - - s.SetRouteTable([]tcpip.Route{ - { - // We use the empty subnet instead of just the loopback subnet so we - // also have a route to the IPv4 Broadcast address. - Destination: header.IPv4EmptySubnet, - NIC: nicID, - }, - }) - - type endpointAndWaiter struct { - ep tcpip.Endpoint - ch chan struct{} - } - var eps []endpointAndWaiter - // We create endpoints that bind to both the wildcard address and the - // broadcast address to make sure both of these types of "broadcast - // interested" endpoints receive broadcast packets. - for _, bindWildcard := range []bool{false, true} { - // Create multiple endpoints for each type of "broadcast interested" - // endpoint so we can test that all endpoints receive the broadcast - // packet. - for i := 0; i < 2; i++ { - var wq waiter.Queue - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { - t.Fatalf("(eps[%d]) NewEndpoint(%d, %d, _): %s", len(eps), udp.ProtocolNumber, ipv4.ProtocolNumber, err) - } - defer ep.Close() - - ep.SocketOptions().SetReuseAddress(true) - ep.SocketOptions().SetBroadcast(true) - - bindAddr := tcpip.FullAddress{Port: localPort} - if bindWildcard { - if err := ep.Bind(bindAddr); err != nil { - t.Fatalf("eps[%d].Bind(%#v): %s", len(eps), bindAddr, err) - } - } else { - bindAddr.Addr = test.broadcastAddr - if err := ep.Bind(bindAddr); err != nil { - t.Fatalf("eps[%d].Bind(%#v): %s", len(eps), bindAddr, err) - } - } - - eps = append(eps, endpointAndWaiter{ep: ep, ch: ch}) - } - } - - for i, wep := range eps { - writeOpts := tcpip.WriteOptions{ - To: &tcpip.FullAddress{ - Addr: test.broadcastAddr, - Port: localPort, - }, - } - data := []byte{byte(i), 2, 3, 4} - var r bytes.Reader - r.Reset(data) - if n, err := wep.ep.Write(&r, writeOpts); err != nil { - t.Fatalf("eps[%d].Write(_, _): %s", i, err) - } else if want := int64(len(data)); n != want { - t.Fatalf("got eps[%d].Write(_, _) = (%d, nil), want = (%d, nil)", i, n, want) - } - - for j, rep := range eps { - // Wait for the endpoint to become readable. - <-rep.ch - - var buf bytes.Buffer - result, err := rep.ep.Read(&buf, tcpip.ReadOptions{}) - if err != nil { - t.Errorf("(eps[%d] write) eps[%d].Read: %s", i, j, err) - continue - } - if diff := cmp.Diff(tcpip.ReadResult{ - Count: buf.Len(), - Total: buf.Len(), - }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { - t.Errorf("(eps[%d] write) eps[%d].Read: unexpected result (-want +got):\n%s", i, j, diff) - } - if diff := cmp.Diff([]byte(data), buf.Bytes()); diff != "" { - t.Errorf("(eps[%d] write) got UDP payload from eps[%d] mismatch (-want +got):\n%s", i, j, diff) - } - } - } - }) - } -} - -func TestUDPAddRemoveMembershipSocketOption(t *testing.T) { - const ( - nicID = 1 - ) - - data := []byte{1, 2, 3, 4} - - tests := []struct { - name string - proto tcpip.NetworkProtocolNumber - remoteAddr tcpip.Address - localAddr tcpip.AddressWithPrefix - rxUDP func(*channel.Endpoint, tcpip.Address, tcpip.Address, []byte) - multicastAddr tcpip.Address - }{ - { - name: "IPv4 unicast binding to unicast", - multicastAddr: "\xe0\x01\x02\x03", - proto: header.IPv4ProtocolNumber, - remoteAddr: utils.RemoteIPv4Addr, - localAddr: utils.Ipv4Addr, - rxUDP: rxIPv4UDP, - }, - { - name: "IPv6 broadcast-like address binding to wildcard", - multicastAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04", - proto: header.IPv6ProtocolNumber, - remoteAddr: utils.RemoteIPv6Addr, - localAddr: utils.Ipv6Addr, - rxUDP: rxIPv6UDP, - }, - } - - subTests := []struct { - name string - specifyNICID bool - specifyNICAddr bool - }{ - { - name: "Specify NIC ID and NIC address", - specifyNICID: true, - specifyNICAddr: true, - }, - { - name: "Don't specify NIC ID or NIC address", - specifyNICID: false, - specifyNICAddr: false, - }, - { - name: "Specify NIC ID but don't specify NIC address", - specifyNICID: true, - specifyNICAddr: false, - }, - { - name: "Don't specify NIC ID but specify NIC address", - specifyNICID: false, - specifyNICAddr: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - for _, subTest := range subTests { - t.Run(subTest.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - }) - e := channel.New(0, defaultMTU, "") - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr} - if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err) - } - - // Set the route table so that UDP can find a NIC that is - // routable to the multicast address when the NIC isn't specified. - if !subTest.specifyNICID && !subTest.specifyNICAddr { - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv6EmptySubnet, - NIC: nicID, - }, - { - Destination: header.IPv4EmptySubnet, - NIC: nicID, - }, - }) - } - - var wq waiter.Queue - ep, err := s.NewEndpoint(udp.ProtocolNumber, test.proto, &wq) - if err != nil { - t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.proto, err) - } - defer ep.Close() - - bindAddr := tcpip.FullAddress{Port: utils.LocalPort} - if err := ep.Bind(bindAddr); err != nil { - t.Fatalf("ep.Bind(%#v): %s", bindAddr, err) - } - - memOpt := tcpip.MembershipOption{MulticastAddr: test.multicastAddr} - if subTest.specifyNICID { - memOpt.NIC = nicID - } - if subTest.specifyNICAddr { - memOpt.InterfaceAddr = test.localAddr.Address - } - - // We should receive UDP packets to the group once we join the - // multicast group. - addOpt := tcpip.AddMembershipOption(memOpt) - if err := ep.SetSockOpt(&addOpt); err != nil { - t.Fatalf("ep.SetSockOpt(&%#v): %s", addOpt, err) - } - test.rxUDP(e, test.remoteAddr, test.multicastAddr, data) - var buf bytes.Buffer - result, err := ep.Read(&buf, tcpip.ReadOptions{}) - if err != nil { - t.Fatalf("ep.Read: %s", err) - } else { - if diff := cmp.Diff(tcpip.ReadResult{ - Count: buf.Len(), - Total: buf.Len(), - }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { - t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) - } - if diff := cmp.Diff(data, buf.Bytes()); diff != "" { - t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) - } - } - - // We should not receive UDP packets to the group once we leave - // the multicast group. - removeOpt := tcpip.RemoveMembershipOption(memOpt) - if err := ep.SetSockOpt(&removeOpt); err != nil { - t.Fatalf("ep.SetSockOpt(&%#v): %s", removeOpt, err) - } - { - _, err := ep.Read(&buf, tcpip.ReadOptions{}) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got ep.Read = (_, %s), want = (_, %s)", err, &tcpip.ErrWouldBlock{}) - } - } - }) - } - }) - } -} diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go deleted file mode 100644 index 422eb8408..000000000 --- a/pkg/tcpip/tests/integration/route_test.go +++ /dev/null @@ -1,441 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package route_test - -import ( - "bytes" - "fmt" - "testing" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/checker" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/tests/utils" - "gvisor.dev/gvisor/pkg/tcpip/testutil" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" - "gvisor.dev/gvisor/pkg/waiter" -) - -// TestLocalPing tests pinging a remote that is local the stack. -// -// This tests that a local route is created and packets do not leave the stack. -func TestLocalPing(t *testing.T) { - const ( - nicID = 1 - - // icmpDataOffset is the offset to the data in both ICMPv4 and ICMPv6 echo - // request/reply packets. - icmpDataOffset = 8 - ) - ipv4Loopback := tcpip.AddressWithPrefix{ - Address: testutil.MustParse4("127.0.0.1"), - PrefixLen: 8, - } - - channelEP := func() stack.LinkEndpoint { return channel.New(1, header.IPv6MinimumMTU, "") } - channelEPCheck := func(t *testing.T, e stack.LinkEndpoint) { - channelEP := e.(*channel.Endpoint) - if n := channelEP.Drain(); n != 0 { - t.Fatalf("got channelEP.Drain() = %d, want = 0", n) - } - } - - ipv4ICMPBuf := func(t *testing.T) buffer.View { - data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} - hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data))) - hdr.SetType(header.ICMPv4Echo) - if n := copy(hdr.Payload(), data[:]); n != len(data) { - t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data)) - } - return buffer.View(hdr) - } - - ipv6ICMPBuf := func(t *testing.T) buffer.View { - data := [8]byte{1, 2, 3, 4, 5, 6, 7, 9} - hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data))) - hdr.SetType(header.ICMPv6EchoRequest) - if n := copy(hdr.Payload(), data[:]); n != len(data) { - t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data)) - } - return buffer.View(hdr) - } - - tests := []struct { - name string - transProto tcpip.TransportProtocolNumber - netProto tcpip.NetworkProtocolNumber - linkEndpoint func() stack.LinkEndpoint - localAddr tcpip.AddressWithPrefix - icmpBuf func(*testing.T) buffer.View - expectedConnectErr tcpip.Error - checkLinkEndpoint func(t *testing.T, e stack.LinkEndpoint) - }{ - { - name: "IPv4 loopback", - transProto: icmp.ProtocolNumber4, - netProto: ipv4.ProtocolNumber, - linkEndpoint: loopback.New, - localAddr: ipv4Loopback, - icmpBuf: ipv4ICMPBuf, - checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, - }, - { - name: "IPv6 loopback", - transProto: icmp.ProtocolNumber6, - netProto: ipv6.ProtocolNumber, - linkEndpoint: loopback.New, - localAddr: header.IPv6Loopback.WithPrefix(), - icmpBuf: ipv6ICMPBuf, - checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, - }, - { - name: "IPv4 non-loopback", - transProto: icmp.ProtocolNumber4, - netProto: ipv4.ProtocolNumber, - linkEndpoint: channelEP, - localAddr: utils.Ipv4Addr, - icmpBuf: ipv4ICMPBuf, - checkLinkEndpoint: channelEPCheck, - }, - { - name: "IPv6 non-loopback", - transProto: icmp.ProtocolNumber6, - netProto: ipv6.ProtocolNumber, - linkEndpoint: channelEP, - localAddr: utils.Ipv6Addr, - icmpBuf: ipv6ICMPBuf, - checkLinkEndpoint: channelEPCheck, - }, - { - name: "IPv4 loopback without local address", - transProto: icmp.ProtocolNumber4, - netProto: ipv4.ProtocolNumber, - linkEndpoint: loopback.New, - icmpBuf: ipv4ICMPBuf, - expectedConnectErr: &tcpip.ErrNoRoute{}, - checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, - }, - { - name: "IPv6 loopback without local address", - transProto: icmp.ProtocolNumber6, - netProto: ipv6.ProtocolNumber, - linkEndpoint: loopback.New, - icmpBuf: ipv6ICMPBuf, - expectedConnectErr: &tcpip.ErrNoRoute{}, - checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, - }, - { - name: "IPv4 non-loopback without local address", - transProto: icmp.ProtocolNumber4, - netProto: ipv4.ProtocolNumber, - linkEndpoint: channelEP, - icmpBuf: ipv4ICMPBuf, - expectedConnectErr: &tcpip.ErrNoRoute{}, - checkLinkEndpoint: channelEPCheck, - }, - { - name: "IPv6 non-loopback without local address", - transProto: icmp.ProtocolNumber6, - netProto: ipv6.ProtocolNumber, - linkEndpoint: channelEP, - icmpBuf: ipv6ICMPBuf, - expectedConnectErr: &tcpip.ErrNoRoute{}, - checkLinkEndpoint: channelEPCheck, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - for _, allowExternalLoopback := range []bool{true, false} { - t.Run(fmt.Sprintf("AllowExternalLoopback=%t", allowExternalLoopback), func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ - ipv4.NewProtocolWithOptions(ipv4.Options{ - AllowExternalLoopbackTraffic: allowExternalLoopback, - }), - ipv6.NewProtocolWithOptions(ipv6.Options{ - AllowExternalLoopbackTraffic: allowExternalLoopback, - }), - }, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6}, - HandleLocal: true, - }) - e := test.linkEndpoint() - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - - if len(test.localAddr.Address) != 0 { - protocolAddr := tcpip.ProtocolAddress{ - Protocol: test.netProto, - AddressWithPrefix: test.localAddr, - } - if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) - } - } - - var wq waiter.Queue - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.ReadableEvents) - ep, err := s.NewEndpoint(test.transProto, test.netProto, &wq) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err) - } - defer ep.Close() - - connAddr := tcpip.FullAddress{Addr: test.localAddr.Address} - if err := ep.Connect(connAddr); err != test.expectedConnectErr { - t.Fatalf("got ep.Connect(%#v) = %s, want = %s", connAddr, err, test.expectedConnectErr) - } - - if test.expectedConnectErr != nil { - return - } - - var r bytes.Reader - payload := test.icmpBuf(t) - r.Reset(payload) - var wOpts tcpip.WriteOptions - if n, err := ep.Write(&r, wOpts); err != nil { - t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err) - } else if n != int64(len(payload)) { - t.Fatalf("got ep.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", payload, wOpts, n, len(payload)) - } - - // Wait for the endpoint to become readable. - <-ch - - var w bytes.Buffer - rr, err := ep.Read(&w, tcpip.ReadOptions{ - NeedRemoteAddr: true, - }) - if err != nil { - t.Fatalf("ep.Read(...): %s", err) - } - if diff := cmp.Diff(buffer.View(w.Bytes()[icmpDataOffset:]), payload[icmpDataOffset:]); diff != "" { - t.Errorf("received data mismatch (-want +got):\n%s", diff) - } - if rr.RemoteAddr.Addr != test.localAddr.Address { - t.Errorf("got addr.Addr = %s, want = %s", rr.RemoteAddr.Addr, test.localAddr.Address) - } - - test.checkLinkEndpoint(t, e) - }) - } - }) - } -} - -// TestLocalUDP tests sending UDP packets between two endpoints that are local -// to the stack. -// -// This tests that that packets never leave the stack and the addresses -// used when sending a packet. -func TestLocalUDP(t *testing.T) { - const ( - nicID = 1 - ) - - tests := []struct { - name string - canBePrimaryAddr tcpip.ProtocolAddress - firstPrimaryAddr tcpip.ProtocolAddress - }{ - { - name: "IPv4", - canBePrimaryAddr: utils.Ipv4Addr1, - firstPrimaryAddr: utils.Ipv4Addr2, - }, - { - name: "IPv6", - canBePrimaryAddr: utils.Ipv6Addr1, - firstPrimaryAddr: utils.Ipv6Addr2, - }, - } - - subTests := []struct { - name string - addAddress bool - expectedWriteErr tcpip.Error - }{ - { - name: "Unassigned local address", - addAddress: false, - expectedWriteErr: &tcpip.ErrNoRoute{}, - }, - { - name: "Assigned local address", - addAddress: true, - expectedWriteErr: nil, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - for _, subTest := range subTests { - t.Run(subTest.name, func(t *testing.T) { - stackOpts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - HandleLocal: true, - } - - s := stack.New(stackOpts) - ep := channel.New(1, header.IPv6MinimumMTU, "") - - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) - } - - if subTest.addAddress { - if err := s.AddProtocolAddress(nicID, test.canBePrimaryAddr, stack.AddressProperties{}); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, test.canBePrimaryAddr, err) - } - properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint} - if err := s.AddProtocolAddress(nicID, test.firstPrimaryAddr, properties); err != nil { - t.Fatalf("s.AddProtocolAddress(%d, %+v, %+v): %s", nicID, test.firstPrimaryAddr, properties, err) - } - } - - var serverWQ waiter.Queue - serverWE, serverCH := waiter.NewChannelEntry(nil) - serverWQ.EventRegister(&serverWE, waiter.ReadableEvents) - server, err := s.NewEndpoint(udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, &serverWQ) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d): %s", udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, err) - } - defer server.Close() - - bindAddr := tcpip.FullAddress{Port: 80} - if err := server.Bind(bindAddr); err != nil { - t.Fatalf("server.Bind(%#v): %s", bindAddr, err) - } - - var clientWQ waiter.Queue - clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.ReadableEvents) - client, err := s.NewEndpoint(udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, &clientWQ) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d): %s", udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, err) - } - defer client.Close() - - serverAddr := tcpip.FullAddress{ - Addr: test.canBePrimaryAddr.AddressWithPrefix.Address, - Port: 80, - } - - clientPayload := []byte{1, 2, 3, 4} - { - var r bytes.Reader - r.Reset(clientPayload) - wOpts := tcpip.WriteOptions{ - To: &serverAddr, - } - if n, err := client.Write(&r, wOpts); err != subTest.expectedWriteErr { - t.Fatalf("got client.Write(%#v, %#v) = (%d, %s), want = (_, %s)", clientPayload, wOpts, n, err, subTest.expectedWriteErr) - } else if subTest.expectedWriteErr != nil { - // Nothing else to test if we expected not to be able to send the - // UDP packet. - return - } else if n != int64(len(clientPayload)) { - t.Fatalf("got client.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", clientPayload, wOpts, n, len(clientPayload)) - } - } - - // Wait for the server endpoint to become readable. - <-serverCH - - var clientAddr tcpip.FullAddress - var readBuf bytes.Buffer - if read, err := server.Read(&readBuf, tcpip.ReadOptions{NeedRemoteAddr: true}); err != nil { - t.Fatalf("server.Read(_): %s", err) - } else { - clientAddr = read.RemoteAddr - - if diff := cmp.Diff(tcpip.ReadResult{ - Count: readBuf.Len(), - Total: readBuf.Len(), - RemoteAddr: tcpip.FullAddress{ - Addr: test.canBePrimaryAddr.AddressWithPrefix.Address, - }, - }, read, checker.IgnoreCmpPath( - "ControlMessages", - "RemoteAddr.NIC", - "RemoteAddr.Port", - )); diff != "" { - t.Errorf("server.Read: unexpected result (-want +got):\n%s", diff) - } - if diff := cmp.Diff(buffer.View(clientPayload), buffer.View(readBuf.Bytes())); diff != "" { - t.Errorf("server read clientPayload mismatch (-want +got):\n%s", diff) - } - if t.Failed() { - t.FailNow() - } - } - - serverPayload := []byte{1, 2, 3, 4} - { - var r bytes.Reader - r.Reset(serverPayload) - wOpts := tcpip.WriteOptions{ - To: &clientAddr, - } - if n, err := server.Write(&r, wOpts); err != nil { - t.Fatalf("server.Write(%#v, %#v): %s", serverPayload, wOpts, err) - } else if n != int64(len(serverPayload)) { - t.Fatalf("got server.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", serverPayload, wOpts, n, len(serverPayload)) - } - } - - // Wait for the client endpoint to become readable. - <-clientCH - - readBuf.Reset() - if read, err := client.Read(&readBuf, tcpip.ReadOptions{NeedRemoteAddr: true}); err != nil { - t.Fatalf("client.Read(_): %s", err) - } else { - if diff := cmp.Diff(tcpip.ReadResult{ - Count: readBuf.Len(), - Total: readBuf.Len(), - RemoteAddr: tcpip.FullAddress{Addr: serverAddr.Addr}, - }, read, checker.IgnoreCmpPath( - "ControlMessages", - "RemoteAddr.NIC", - "RemoteAddr.Port", - )); diff != "" { - t.Errorf("client.Read: unexpected result (-want +got):\n%s", diff) - } - if diff := cmp.Diff(buffer.View(serverPayload), buffer.View(readBuf.Bytes())); diff != "" { - t.Errorf("client read serverPayload mismatch (-want +got):\n%s", diff) - } - if t.Failed() { - t.FailNow() - } - } - }) - } - }) - } -} |