summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/tests/integration
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/tests/integration')
-rw-r--r--pkg/tcpip/tests/integration/BUILD167
-rw-r--r--pkg/tcpip/tests/integration/forward_test.go698
-rw-r--r--pkg/tcpip/tests/integration/iptables_test.go2271
-rw-r--r--pkg/tcpip/tests/integration/istio_test.go365
-rw-r--r--pkg/tcpip/tests/integration/link_resolution_test.go1640
-rw-r--r--pkg/tcpip/tests/integration/loopback_test.go782
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go723
-rw-r--r--pkg/tcpip/tests/integration/route_test.go441
8 files changed, 0 insertions, 7087 deletions
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
deleted file mode 100644
index 99f4d4d0e..000000000
--- a/pkg/tcpip/tests/integration/BUILD
+++ /dev/null
@@ -1,167 +0,0 @@
-load("//tools:defs.bzl", "go_test")
-
-package(licenses = ["notice"])
-
-go_test(
- name = "forward_test",
- size = "small",
- srcs = ["forward_test.go"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/checker",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/network/arp",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/tests/utils",
- "//pkg/tcpip/testutil",
- "//pkg/tcpip/transport/tcp",
- "//pkg/tcpip/transport/udp",
- "//pkg/waiter",
- "@com_github_google_go_cmp//cmp:go_default_library",
- ],
-)
-
-go_test(
- name = "iptables_test",
- size = "small",
- srcs = ["iptables_test.go"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/checker",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/network/arp",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/tests/utils",
- "//pkg/tcpip/testutil",
- "//pkg/tcpip/transport/tcp",
- "//pkg/tcpip/transport/udp",
- "//pkg/waiter",
- "@com_github_google_go_cmp//cmp:go_default_library",
- ],
-)
-
-go_test(
- name = "link_resolution_test",
- size = "small",
- srcs = ["link_resolution_test.go"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/checker",
- "//pkg/tcpip/faketime",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/link/pipe",
- "//pkg/tcpip/network/arp",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/tests/utils",
- "//pkg/tcpip/testutil",
- "//pkg/tcpip/transport/icmp",
- "//pkg/tcpip/transport/tcp",
- "//pkg/tcpip/transport/udp",
- "//pkg/waiter",
- "@com_github_google_go_cmp//cmp:go_default_library",
- "@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
- ],
-)
-
-go_test(
- name = "loopback_test",
- size = "small",
- srcs = ["loopback_test.go"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/checker",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/link/loopback",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/tests/utils",
- "//pkg/tcpip/testutil",
- "//pkg/tcpip/transport/icmp",
- "//pkg/tcpip/transport/tcp",
- "//pkg/tcpip/transport/udp",
- "//pkg/waiter",
- "@com_github_google_go_cmp//cmp:go_default_library",
- ],
-)
-
-go_test(
- name = "multicast_broadcast_test",
- size = "small",
- srcs = ["multicast_broadcast_test.go"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/checker",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/link/loopback",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/tests/utils",
- "//pkg/tcpip/testutil",
- "//pkg/tcpip/transport/icmp",
- "//pkg/tcpip/transport/udp",
- "//pkg/waiter",
- "@com_github_google_go_cmp//cmp:go_default_library",
- ],
-)
-
-go_test(
- name = "route_test",
- size = "small",
- srcs = ["route_test.go"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/checker",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/link/loopback",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/tests/utils",
- "//pkg/tcpip/testutil",
- "//pkg/tcpip/transport/icmp",
- "//pkg/tcpip/transport/udp",
- "//pkg/waiter",
- "@com_github_google_go_cmp//cmp:go_default_library",
- ],
-)
-
-go_test(
- name = "istio_test",
- size = "small",
- srcs = ["istio_test.go"],
- deps = [
- "//pkg/context",
- "//pkg/rand",
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/adapters/gonet",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/loopback",
- "//pkg/tcpip/link/pipe",
- "//pkg/tcpip/link/sniffer",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/testutil",
- "//pkg/tcpip/transport/tcp",
- "@com_github_google_go_cmp//cmp:go_default_library",
- ],
-)
diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go
deleted file mode 100644
index 6e1d4720d..000000000
--- a/pkg/tcpip/tests/integration/forward_test.go
+++ /dev/null
@@ -1,698 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package forward_test
-
-import (
- "bytes"
- "fmt"
- "testing"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/network/arp"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/tests/utils"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-const ttl = 64
-
-var (
- ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10")
- ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a")
-)
-
-func rxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) {
- utils.RxICMPv4EchoRequest(e, src, dst, ttl)
-}
-
-func rxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) {
- utils.RxICMPv6EchoRequest(e, src, dst, ttl)
-}
-
-func forwardedICMPv4EchoRequestChecker(t *testing.T, b []byte, src, dst tcpip.Address) {
- checker.IPv4(t, b,
- checker.SrcAddr(src),
- checker.DstAddr(dst),
- checker.TTL(ttl-1),
- checker.ICMPv4(
- checker.ICMPv4Type(header.ICMPv4Echo)))
-}
-
-func forwardedICMPv6EchoRequestChecker(t *testing.T, b []byte, src, dst tcpip.Address) {
- checker.IPv6(t, b,
- checker.SrcAddr(src),
- checker.DstAddr(dst),
- checker.TTL(ttl-1),
- checker.ICMPv6(
- checker.ICMPv6Type(header.ICMPv6EchoRequest)))
-}
-
-func TestForwarding(t *testing.T) {
- const listenPort = 8080
-
- type endpointAndAddresses struct {
- serverEP tcpip.Endpoint
- serverAddr tcpip.Address
- serverReadableCH chan struct{}
-
- clientEP tcpip.Endpoint
- clientAddr tcpip.Address
- clientReadableCH chan struct{}
- }
-
- newEP := func(t *testing.T, s *stack.Stack, transProto tcpip.TransportProtocolNumber, netProto tcpip.NetworkProtocolNumber) (tcpip.Endpoint, chan struct{}) {
- t.Helper()
- var wq waiter.Queue
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- ep, err := s.NewEndpoint(transProto, netProto, &wq)
- if err != nil {
- t.Fatalf("s.NewEndpoint(%d, %d, _): %s", transProto, netProto, err)
- }
-
- t.Cleanup(func() {
- wq.EventUnregister(&we)
- })
-
- return ep, ch
- }
-
- tests := []struct {
- name string
- epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses
- }{
- {
- name: "IPv4 host1 server with host2 client",
- epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
- ep1, ep1WECH := newEP(t, host1Stack, proto, ipv4.ProtocolNumber)
- ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber)
- return endpointAndAddresses{
- serverEP: ep1,
- serverAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
- serverReadableCH: ep1WECH,
-
- clientEP: ep2,
- clientAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address,
- clientReadableCH: ep2WECH,
- }
- },
- },
- {
- name: "IPv6 host2 server with host1 client",
- epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
- ep1, ep1WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber)
- ep2, ep2WECH := newEP(t, host1Stack, proto, ipv6.ProtocolNumber)
- return endpointAndAddresses{
- serverEP: ep1,
- serverAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address,
- serverReadableCH: ep1WECH,
-
- clientEP: ep2,
- clientAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
- clientReadableCH: ep2WECH,
- }
- },
- },
- {
- name: "IPv4 host2 server with routerNIC1 client",
- epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
- ep1, ep1WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber)
- ep2, ep2WECH := newEP(t, routerStack, proto, ipv4.ProtocolNumber)
- return endpointAndAddresses{
- serverEP: ep1,
- serverAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address,
- serverReadableCH: ep1WECH,
-
- clientEP: ep2,
- clientAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
- clientReadableCH: ep2WECH,
- }
- },
- },
- {
- name: "IPv6 routerNIC2 server with host1 client",
- epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
- ep1, ep1WECH := newEP(t, routerStack, proto, ipv6.ProtocolNumber)
- ep2, ep2WECH := newEP(t, host1Stack, proto, ipv6.ProtocolNumber)
- return endpointAndAddresses{
- serverEP: ep1,
- serverAddr: utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address,
- serverReadableCH: ep1WECH,
-
- clientEP: ep2,
- clientAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
- clientReadableCH: ep2WECH,
- }
- },
- },
- }
-
- subTests := []struct {
- name string
- proto tcpip.TransportProtocolNumber
- expectedConnectErr tcpip.Error
- setupServer func(t *testing.T, ep tcpip.Endpoint)
- setupServerConn func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{})
- needRemoteAddr bool
- }{
- {
- name: "UDP",
- proto: udp.ProtocolNumber,
- expectedConnectErr: nil,
- setupServerConn: func(t *testing.T, ep tcpip.Endpoint, _ <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
- t.Helper()
-
- if err := ep.Connect(clientAddr); err != nil {
- t.Fatalf("ep.Connect(%#v): %s", clientAddr, err)
- }
- return nil, nil
- },
- needRemoteAddr: true,
- },
- {
- name: "TCP",
- proto: tcp.ProtocolNumber,
- expectedConnectErr: &tcpip.ErrConnectStarted{},
- setupServer: func(t *testing.T, ep tcpip.Endpoint) {
- t.Helper()
-
- if err := ep.Listen(1); err != nil {
- t.Fatalf("ep.Listen(1): %s", err)
- }
- },
- setupServerConn: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
- t.Helper()
-
- var addr tcpip.FullAddress
- for {
- newEP, wq, err := ep.Accept(&addr)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
- <-ch
- continue
- }
- if err != nil {
- t.Fatalf("ep.Accept(_): %s", err)
- }
- if diff := cmp.Diff(clientAddr, addr, checker.IgnoreCmpPath(
- "NIC",
- )); diff != "" {
- t.Errorf("accepted address mismatch (-want +got):\n%s", diff)
- }
-
- we, newCH := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- return newEP, newCH
- }
- },
- needRemoteAddr: false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- for _, subTest := range subTests {
- t.Run(subTest.name, func(t *testing.T) {
- stackOpts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
- }
-
- host1Stack := stack.New(stackOpts)
- routerStack := stack.New(stackOpts)
- host2Stack := stack.New(stackOpts)
- utils.SetupRoutedStacks(t, host1Stack, routerStack, host2Stack)
-
- epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto)
- defer epsAndAddrs.serverEP.Close()
- defer epsAndAddrs.clientEP.Close()
-
- serverAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverAddr, Port: listenPort}
- if err := epsAndAddrs.serverEP.Bind(serverAddr); err != nil {
- t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", serverAddr, err)
- }
- clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr}
- if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil {
- t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err)
- }
-
- if subTest.setupServer != nil {
- subTest.setupServer(t, epsAndAddrs.serverEP)
- }
- {
- err := epsAndAddrs.clientEP.Connect(serverAddr)
- if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" {
- t.Fatalf("unexpected error from epsAndAddrs.clientEP.Connect(%#v), (-want, +got):\n%s", serverAddr, diff)
- }
- }
- if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil {
- t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err)
- } else {
- clientAddr = addr
- clientAddr.NIC = 0
- }
-
- serverEP := epsAndAddrs.serverEP
- serverCH := epsAndAddrs.serverReadableCH
- if ep, ch := subTest.setupServerConn(t, serverEP, serverCH, clientAddr); ep != nil {
- defer ep.Close()
- serverEP = ep
- serverCH = ch
- }
-
- write := func(ep tcpip.Endpoint, data []byte) {
- t.Helper()
-
- var r bytes.Reader
- r.Reset(data)
- var wOpts tcpip.WriteOptions
- n, err := ep.Write(&r, wOpts)
- if err != nil {
- t.Fatalf("ep.Write(_, %#v): %s", wOpts, err)
- }
- if want := int64(len(data)); n != want {
- t.Fatalf("got ep.Write(_, %#v) = (%d, _), want = (%d, _)", wOpts, n, want)
- }
- }
-
- data := []byte{1, 2, 3, 4}
- write(epsAndAddrs.clientEP, data)
-
- read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.FullAddress) {
- t.Helper()
-
- var buf bytes.Buffer
- var res tcpip.ReadResult
- for {
- var err tcpip.Error
- opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr}
- res, err = ep.Read(&buf, opts)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
- <-ch
- continue
- }
- if err != nil {
- t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err)
- }
- break
- }
-
- readResult := tcpip.ReadResult{
- Count: len(data),
- Total: len(data),
- }
- if subTest.needRemoteAddr {
- readResult.RemoteAddr = expectedFrom
- }
- if diff := cmp.Diff(readResult, res, checker.IgnoreCmpPath(
- "ControlMessages",
- "RemoteAddr.NIC",
- )); diff != "" {
- t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
- }
- if diff := cmp.Diff(buf.Bytes(), data); diff != "" {
- t.Errorf("received data mismatch (-want +got):\n%s", diff)
- }
-
- if t.Failed() {
- t.FailNow()
- }
- }
-
- read(serverCH, serverEP, data, clientAddr)
-
- data = []byte{5, 6, 7, 8, 9, 10, 11, 12}
- write(serverEP, data)
- read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, serverAddr)
- })
- }
- })
- }
-}
-
-func TestMulticastForwarding(t *testing.T) {
- const (
- nicID1 = 1
- nicID2 = 2
- )
-
- var (
- ipv4LinkLocalUnicastAddr = testutil.MustParse4("169.254.0.10")
- ipv4LinkLocalMulticastAddr = testutil.MustParse4("224.0.0.10")
-
- ipv6LinkLocalUnicastAddr = testutil.MustParse6("fe80::a")
- ipv6LinkLocalMulticastAddr = testutil.MustParse6("ff02::a")
- )
-
- tests := []struct {
- name string
- srcAddr, dstAddr tcpip.Address
- rx func(*channel.Endpoint, tcpip.Address, tcpip.Address)
- expectForward bool
- checker func(*testing.T, []byte)
- }{
- {
- name: "IPv4 link-local multicast destination",
- srcAddr: utils.RemoteIPv4Addr,
- dstAddr: ipv4LinkLocalMulticastAddr,
- rx: rxICMPv4EchoRequest,
- expectForward: false,
- },
- {
- name: "IPv4 link-local source",
- srcAddr: ipv4LinkLocalUnicastAddr,
- dstAddr: utils.RemoteIPv4Addr,
- rx: rxICMPv4EchoRequest,
- expectForward: false,
- },
- {
- name: "IPv4 link-local destination",
- srcAddr: utils.RemoteIPv4Addr,
- dstAddr: ipv4LinkLocalUnicastAddr,
- rx: rxICMPv4EchoRequest,
- expectForward: false,
- },
- {
- name: "IPv4 non-link-local unicast",
- srcAddr: utils.RemoteIPv4Addr,
- dstAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
- rx: rxICMPv4EchoRequest,
- expectForward: true,
- checker: func(t *testing.T, b []byte) {
- forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address)
- },
- },
- {
- name: "IPv4 non-link-local multicast",
- srcAddr: utils.RemoteIPv4Addr,
- dstAddr: ipv4GlobalMulticastAddr,
- rx: rxICMPv4EchoRequest,
- expectForward: true,
- checker: func(t *testing.T, b []byte) {
- forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr)
- },
- },
-
- {
- name: "IPv6 link-local multicast destination",
- srcAddr: utils.RemoteIPv6Addr,
- dstAddr: ipv6LinkLocalMulticastAddr,
- rx: rxICMPv6EchoRequest,
- expectForward: false,
- },
- {
- name: "IPv6 link-local source",
- srcAddr: ipv6LinkLocalUnicastAddr,
- dstAddr: utils.RemoteIPv6Addr,
- rx: rxICMPv6EchoRequest,
- expectForward: false,
- },
- {
- name: "IPv6 link-local destination",
- srcAddr: utils.RemoteIPv6Addr,
- dstAddr: ipv6LinkLocalUnicastAddr,
- rx: rxICMPv6EchoRequest,
- expectForward: false,
- },
- {
- name: "IPv6 non-link-local unicast",
- srcAddr: utils.RemoteIPv6Addr,
- dstAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
- rx: rxICMPv6EchoRequest,
- expectForward: true,
- checker: func(t *testing.T, b []byte) {
- forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address)
- },
- },
- {
- name: "IPv6 non-link-local multicast",
- srcAddr: utils.RemoteIPv6Addr,
- dstAddr: ipv6GlobalMulticastAddr,
- rx: rxICMPv6EchoRequest,
- expectForward: true,
- checker: func(t *testing.T, b []byte) {
- forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr)
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- })
-
- e1 := channel.New(1, header.IPv6MinimumMTU, "")
- if err := s.CreateNIC(nicID1, e1); err != nil {
- t.Fatalf("s.CreateNIC(%d, _): %s", nicID1, err)
- }
-
- e2 := channel.New(1, header.IPv6MinimumMTU, "")
- if err := s.CreateNIC(nicID2, e2); err != nil {
- t.Fatalf("s.CreateNIC(%d, _): %s", nicID2, err)
- }
-
- protocolAddrV4 := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: utils.Ipv4Addr,
- }
- if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err)
- }
- protocolAddrV6 := tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: utils.Ipv6Addr,
- }
- if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err)
- }
-
- if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
- t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err)
- }
- if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil {
- t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv4EmptySubnet,
- NIC: nicID2,
- },
- {
- Destination: header.IPv6EmptySubnet,
- NIC: nicID2,
- },
- })
-
- test.rx(e1, test.srcAddr, test.dstAddr)
-
- p, ok := e2.Read()
- if ok != test.expectForward {
- t.Fatalf("got e2.Read() = (%#v, %t), want = (_, %t)", p, ok, test.expectForward)
- }
-
- if test.expectForward {
- test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader()))
- }
- })
- }
-}
-
-func TestPerInterfaceForwarding(t *testing.T) {
- const (
- nicID1 = 1
- nicID2 = 2
- )
-
- tests := []struct {
- name string
- srcAddr, dstAddr tcpip.Address
- rx func(*channel.Endpoint, tcpip.Address, tcpip.Address)
- checker func(*testing.T, []byte)
- }{
- {
- name: "IPv4 unicast",
- srcAddr: utils.RemoteIPv4Addr,
- dstAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
- rx: rxICMPv4EchoRequest,
- checker: func(t *testing.T, b []byte) {
- forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address)
- },
- },
- {
- name: "IPv4 multicast",
- srcAddr: utils.RemoteIPv4Addr,
- dstAddr: ipv4GlobalMulticastAddr,
- rx: rxICMPv4EchoRequest,
- checker: func(t *testing.T, b []byte) {
- forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr)
- },
- },
-
- {
- name: "IPv6 unicast",
- srcAddr: utils.RemoteIPv6Addr,
- dstAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
- rx: rxICMPv6EchoRequest,
- checker: func(t *testing.T, b []byte) {
- forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address)
- },
- },
- {
- name: "IPv6 multicast",
- srcAddr: utils.RemoteIPv6Addr,
- dstAddr: ipv6GlobalMulticastAddr,
- rx: rxICMPv6EchoRequest,
- checker: func(t *testing.T, b []byte) {
- forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr)
- },
- },
- }
-
- netProtos := [...]tcpip.NetworkProtocolNumber{ipv4.ProtocolNumber, ipv6.ProtocolNumber}
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{
- // ARP is not used in this test but it is a network protocol that does
- // not support forwarding. We install the protocol to make sure that
- // forwarding information for a NIC is only reported for network
- // protocols that support forwarding.
- arp.NewProtocol,
-
- ipv4.NewProtocol,
- ipv6.NewProtocol,
- },
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- })
-
- e1 := channel.New(1, header.IPv6MinimumMTU, "")
- if err := s.CreateNIC(nicID1, e1); err != nil {
- t.Fatalf("s.CreateNIC(%d, _): %s", nicID1, err)
- }
-
- e2 := channel.New(1, header.IPv6MinimumMTU, "")
- if err := s.CreateNIC(nicID2, e2); err != nil {
- t.Fatalf("s.CreateNIC(%d, _): %s", nicID2, err)
- }
-
- for _, add := range [...]struct {
- nicID tcpip.NICID
- addr tcpip.ProtocolAddress
- }{
- {
- nicID: nicID1,
- addr: utils.RouterNIC1IPv4Addr,
- },
- {
- nicID: nicID1,
- addr: utils.RouterNIC1IPv6Addr,
- },
- {
- nicID: nicID2,
- addr: utils.RouterNIC2IPv4Addr,
- },
- {
- nicID: nicID2,
- addr: utils.RouterNIC2IPv6Addr,
- },
- } {
- if err := s.AddProtocolAddress(add.nicID, add.addr, stack.AddressProperties{}); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", add.nicID, add.addr, err)
- }
- }
-
- // Only enable forwarding on NIC1 and make sure that only packets arriving
- // on NIC1 are forwarded.
- for _, netProto := range netProtos {
- if err := s.SetNICForwarding(nicID1, netProto, true); err != nil {
- t.Fatalf("s.SetNICForwarding(%d, %d, true): %s", nicID1, netProtos, err)
- }
- }
-
- nicsInfo := s.NICInfo()
- for _, subTest := range [...]struct {
- nicID tcpip.NICID
- nicEP *channel.Endpoint
- otherNICID tcpip.NICID
- otherNICEP *channel.Endpoint
- expectForwarding bool
- }{
- {
- nicID: nicID1,
- nicEP: e1,
- otherNICID: nicID2,
- otherNICEP: e2,
- expectForwarding: true,
- },
- {
- nicID: nicID2,
- nicEP: e2,
- otherNICID: nicID2,
- otherNICEP: e1,
- expectForwarding: false,
- },
- } {
- t.Run(fmt.Sprintf("Packet arriving at NIC%d", subTest.nicID), func(t *testing.T) {
- nicInfo, ok := nicsInfo[subTest.nicID]
- if !ok {
- t.Errorf("expected NIC info for NIC %d; got = %#v", subTest.nicID, nicsInfo)
- } else {
- forwarding := make(map[tcpip.NetworkProtocolNumber]bool)
- for _, netProto := range netProtos {
- forwarding[netProto] = subTest.expectForwarding
- }
-
- if diff := cmp.Diff(forwarding, nicInfo.Forwarding); diff != "" {
- t.Errorf("nicsInfo[%d].Forwarding mismatch (-want +got):\n%s", subTest.nicID, diff)
- }
- }
-
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv4EmptySubnet,
- NIC: subTest.otherNICID,
- },
- {
- Destination: header.IPv6EmptySubnet,
- NIC: subTest.otherNICID,
- },
- })
-
- test.rx(subTest.nicEP, test.srcAddr, test.dstAddr)
- if p, ok := subTest.nicEP.Read(); ok {
- t.Errorf("unexpectedly got a response from the interface the packet arrived on: %#v", p)
- }
- if p, ok := subTest.otherNICEP.Read(); ok != subTest.expectForwarding {
- t.Errorf("got otherNICEP.Read() = (%#v, %t), want = (_, %t)", p, ok, subTest.expectForwarding)
- } else if subTest.expectForwarding {
- test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader()))
- }
- })
- }
- })
- }
-}
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go
deleted file mode 100644
index b2383576c..000000000
--- a/pkg/tcpip/tests/integration/iptables_test.go
+++ /dev/null
@@ -1,2271 +0,0 @@
-// Copyright 2021 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package iptables_test
-
-import (
- "bytes"
- "testing"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/network/arp"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/tests/utils"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-type inputIfNameMatcher struct {
- name string
-}
-
-var _ stack.Matcher = (*inputIfNameMatcher)(nil)
-
-func (*inputIfNameMatcher) Name() string {
- return "inputIfNameMatcher"
-}
-
-func (im *inputIfNameMatcher) Match(hook stack.Hook, _ *stack.PacketBuffer, inNicName, _ string) (bool, bool) {
- return (hook == stack.Input && im.name != "" && im.name == inNicName), false
-}
-
-const (
- nicID = 1
- nicName = "nic1"
- anotherNicName = "nic2"
- linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
- srcAddrV4 = tcpip.Address("\x0a\x00\x00\x01")
- dstAddrV4 = tcpip.Address("\x0a\x00\x00\x02")
- srcAddrV6 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- dstAddrV6 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- payloadSize = 20
-)
-
-func genStackV6(t *testing.T) (*stack.Stack, *channel.Endpoint) {
- t.Helper()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- })
- e := channel.New(0, header.IPv6MinimumMTU, linkAddr)
- nicOpts := stack.NICOptions{Name: nicName}
- if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
- t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err)
- }
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: dstAddrV6.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
- return s, e
-}
-
-func genStackV4(t *testing.T) (*stack.Stack, *channel.Endpoint) {
- t.Helper()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- })
- e := channel.New(0, header.IPv4MinimumMTU, linkAddr)
- nicOpts := stack.NICOptions{Name: nicName}
- if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
- t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err)
- }
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: dstAddrV4.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
- return s, e
-}
-
-func genPacketV6() *stack.PacketBuffer {
- pktSize := header.IPv6MinimumSize + payloadSize
- hdr := buffer.NewPrependable(pktSize)
- ip := header.IPv6(hdr.Prepend(pktSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: payloadSize,
- TransportProtocol: 99,
- HopLimit: 255,
- SrcAddr: srcAddrV6,
- DstAddr: dstAddrV6,
- })
- vv := hdr.View().ToVectorisedView()
- return stack.NewPacketBuffer(stack.PacketBufferOptions{Data: vv})
-}
-
-func genPacketV4() *stack.PacketBuffer {
- pktSize := header.IPv4MinimumSize + payloadSize
- hdr := buffer.NewPrependable(pktSize)
- ip := header.IPv4(hdr.Prepend(pktSize))
- ip.Encode(&header.IPv4Fields{
- TOS: 0,
- TotalLength: uint16(pktSize),
- ID: 1,
- Flags: 0,
- FragmentOffset: 16,
- TTL: 48,
- Protocol: 99,
- SrcAddr: srcAddrV4,
- DstAddr: dstAddrV4,
- })
- ip.SetChecksum(0)
- ip.SetChecksum(^ip.CalculateChecksum())
- vv := hdr.View().ToVectorisedView()
- return stack.NewPacketBuffer(stack.PacketBufferOptions{Data: vv})
-}
-
-func TestIPTablesStatsForInput(t *testing.T) {
- tests := []struct {
- name string
- setupStack func(*testing.T) (*stack.Stack, *channel.Endpoint)
- setupFilter func(*testing.T, *stack.Stack)
- genPacket func() *stack.PacketBuffer
- proto tcpip.NetworkProtocolNumber
- expectReceived int
- expectInputDropped int
- }{
- {
- name: "IPv6 Accept",
- setupStack: genStackV6,
- setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ },
- genPacket: genPacketV6,
- proto: header.IPv6ProtocolNumber,
- expectReceived: 1,
- expectInputDropped: 0,
- },
- {
- name: "IPv4 Accept",
- setupStack: genStackV4,
- setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ },
- genPacket: genPacketV4,
- proto: header.IPv4ProtocolNumber,
- expectReceived: 1,
- expectInputDropped: 0,
- },
- {
- name: "IPv6 Drop (input interface matches)",
- setupStack: genStackV6,
- setupFilter: func(t *testing.T, s *stack.Stack) {
- t.Helper()
- ipt := s.IPTables()
- filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
- ruleIdx := filter.BuiltinChains[stack.Input]
- filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: nicName}
- filter.Rules[ruleIdx].Target = &stack.DropTarget{}
- filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}}
- // Make sure the packet is not dropped by the next rule.
- filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
- if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
- t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err)
- }
- },
- genPacket: genPacketV6,
- proto: header.IPv6ProtocolNumber,
- expectReceived: 1,
- expectInputDropped: 1,
- },
- {
- name: "IPv4 Drop (input interface matches)",
- setupStack: genStackV4,
- setupFilter: func(t *testing.T, s *stack.Stack) {
- t.Helper()
- ipt := s.IPTables()
- filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
- ruleIdx := filter.BuiltinChains[stack.Input]
- filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: nicName}
- filter.Rules[ruleIdx].Target = &stack.DropTarget{}
- filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}}
- filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
- if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
- t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err)
- }
- },
- genPacket: genPacketV4,
- proto: header.IPv4ProtocolNumber,
- expectReceived: 1,
- expectInputDropped: 1,
- },
- {
- name: "IPv6 Accept (input interface does not match)",
- setupStack: genStackV6,
- setupFilter: func(t *testing.T, s *stack.Stack) {
- t.Helper()
- ipt := s.IPTables()
- filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
- ruleIdx := filter.BuiltinChains[stack.Input]
- filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: anotherNicName}
- filter.Rules[ruleIdx].Target = &stack.DropTarget{}
- filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
- if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
- t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err)
- }
- },
- genPacket: genPacketV6,
- proto: header.IPv6ProtocolNumber,
- expectReceived: 1,
- expectInputDropped: 0,
- },
- {
- name: "IPv4 Accept (input interface does not match)",
- setupStack: genStackV4,
- setupFilter: func(t *testing.T, s *stack.Stack) {
- t.Helper()
- ipt := s.IPTables()
- filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
- ruleIdx := filter.BuiltinChains[stack.Input]
- filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: anotherNicName}
- filter.Rules[ruleIdx].Target = &stack.DropTarget{}
- filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
- if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
- t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err)
- }
- },
- genPacket: genPacketV4,
- proto: header.IPv4ProtocolNumber,
- expectReceived: 1,
- expectInputDropped: 0,
- },
- {
- name: "IPv6 Drop (input interface does not match but invert is true)",
- setupStack: genStackV6,
- setupFilter: func(t *testing.T, s *stack.Stack) {
- t.Helper()
- ipt := s.IPTables()
- filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
- ruleIdx := filter.BuiltinChains[stack.Input]
- filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{
- InputInterface: anotherNicName,
- InputInterfaceInvert: true,
- }
- filter.Rules[ruleIdx].Target = &stack.DropTarget{}
- filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
- if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
- t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err)
- }
- },
- genPacket: genPacketV6,
- proto: header.IPv6ProtocolNumber,
- expectReceived: 1,
- expectInputDropped: 1,
- },
- {
- name: "IPv4 Drop (input interface does not match but invert is true)",
- setupStack: genStackV4,
- setupFilter: func(t *testing.T, s *stack.Stack) {
- t.Helper()
- ipt := s.IPTables()
- filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
- ruleIdx := filter.BuiltinChains[stack.Input]
- filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{
- InputInterface: anotherNicName,
- InputInterfaceInvert: true,
- }
- filter.Rules[ruleIdx].Target = &stack.DropTarget{}
- filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
- if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
- t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err)
- }
- },
- genPacket: genPacketV4,
- proto: header.IPv4ProtocolNumber,
- expectReceived: 1,
- expectInputDropped: 1,
- },
- {
- name: "IPv6 Accept (input interface does not match using a matcher)",
- setupStack: genStackV6,
- setupFilter: func(t *testing.T, s *stack.Stack) {
- t.Helper()
- ipt := s.IPTables()
- filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
- ruleIdx := filter.BuiltinChains[stack.Input]
- filter.Rules[ruleIdx].Target = &stack.DropTarget{}
- filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}}
- filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
- if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
- t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err)
- }
- },
- genPacket: genPacketV6,
- proto: header.IPv6ProtocolNumber,
- expectReceived: 1,
- expectInputDropped: 0,
- },
- {
- name: "IPv4 Accept (input interface does not match using a matcher)",
- setupStack: genStackV4,
- setupFilter: func(t *testing.T, s *stack.Stack) {
- t.Helper()
- ipt := s.IPTables()
- filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
- ruleIdx := filter.BuiltinChains[stack.Input]
- filter.Rules[ruleIdx].Target = &stack.DropTarget{}
- filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}}
- filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
- if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
- t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err)
- }
- },
- genPacket: genPacketV4,
- proto: header.IPv4ProtocolNumber,
- expectReceived: 1,
- expectInputDropped: 0,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s, e := test.setupStack(t)
- test.setupFilter(t, s)
- e.InjectInbound(test.proto, test.genPacket())
-
- if got := int(s.Stats().IP.PacketsReceived.Value()); got != test.expectReceived {
- t.Errorf("got PacketReceived = %d, want = %d", got, test.expectReceived)
- }
- if got := int(s.Stats().IP.IPTablesInputDropped.Value()); got != test.expectInputDropped {
- t.Errorf("got IPTablesInputDropped = %d, want = %d", got, test.expectInputDropped)
- }
- })
- }
-}
-
-var _ stack.LinkEndpoint = (*channelEndpointWithoutWritePacket)(nil)
-
-// channelEndpointWithoutWritePacket is a channel endpoint that does not support
-// stack.LinkEndpoint.WritePacket.
-type channelEndpointWithoutWritePacket struct {
- *channel.Endpoint
-
- t *testing.T
-}
-
-func (c *channelEndpointWithoutWritePacket) WritePacket(stack.RouteInfo, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error {
- c.t.Error("unexpectedly called WritePacket; all writes should go through WritePackets")
- return &tcpip.ErrNotSupported{}
-}
-
-var _ stack.Matcher = (*udpSourcePortMatcher)(nil)
-
-type udpSourcePortMatcher struct {
- port uint16
-}
-
-func (*udpSourcePortMatcher) Name() string {
- return "udpSourcePortMatcher"
-}
-
-func (m *udpSourcePortMatcher) Match(_ stack.Hook, pkt *stack.PacketBuffer, _, _ string) (matches, hotdrop bool) {
- udp := header.UDP(pkt.TransportHeader().View())
- if len(udp) < header.UDPMinimumSize {
- // Drop immediately as the packet is invalid.
- return false, true
- }
-
- return udp.SourcePort() == m.port, false
-}
-
-func TestIPTableWritePackets(t *testing.T) {
- const (
- nicID = 1
-
- dropLocalPort = utils.LocalPort - 1
- acceptPackets = 2
- dropPackets = 3
- )
-
- udpHdr := func(hdr buffer.View, srcAddr, dstAddr tcpip.Address, srcPort, dstPort uint16) {
- u := header.UDP(hdr)
- u.Encode(&header.UDPFields{
- SrcPort: srcPort,
- DstPort: dstPort,
- Length: header.UDPMinimumSize,
- })
- sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, srcAddr, dstAddr, header.UDPMinimumSize)
- sum = header.Checksum(hdr, sum)
- u.SetChecksum(^u.CalculateChecksum(sum))
- }
-
- tests := []struct {
- name string
- setupFilter func(*testing.T, *stack.Stack)
- genPacket func(*stack.Route) stack.PacketBufferList
- proto tcpip.NetworkProtocolNumber
- remoteAddr tcpip.Address
- expectSent uint64
- expectOutputDropped uint64
- }{
- {
- name: "IPv4 Accept",
- setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ },
- genPacket: func(r *stack.Route) stack.PacketBufferList {
- var pkts stack.PacketBufferList
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
- })
- hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
- udpHdr(hdr, r.LocalAddress(), r.RemoteAddress(), utils.LocalPort, utils.RemotePort)
- pkts.PushFront(pkt)
-
- return pkts
- },
- proto: header.IPv4ProtocolNumber,
- remoteAddr: dstAddrV4,
- expectSent: 1,
- expectOutputDropped: 0,
- },
- {
- name: "IPv4 Drop Other Port",
- setupFilter: func(t *testing.T, s *stack.Stack) {
- t.Helper()
-
- table := stack.Table{
- Rules: []stack.Rule{
- {
- Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber},
- },
- {
- Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber},
- },
- {
- Matchers: []stack.Matcher{&udpSourcePortMatcher{port: dropLocalPort}},
- Target: &stack.DropTarget{NetworkProtocol: header.IPv4ProtocolNumber},
- },
- {
- Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber},
- },
- {
- Target: &stack.ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber},
- },
- },
- BuiltinChains: [stack.NumHooks]int{
- stack.Prerouting: stack.HookUnset,
- stack.Input: 0,
- stack.Forward: 1,
- stack.Output: 2,
- stack.Postrouting: stack.HookUnset,
- },
- Underflows: [stack.NumHooks]int{
- stack.Prerouting: stack.HookUnset,
- stack.Input: 0,
- stack.Forward: 1,
- stack.Output: 2,
- stack.Postrouting: stack.HookUnset,
- },
- }
-
- if err := s.IPTables().ReplaceTable(stack.FilterID, table, false /* ipv4 */); err != nil {
- t.Fatalf("ReplaceTable(%d, _, false): %s", stack.FilterID, err)
- }
- },
- genPacket: func(r *stack.Route) stack.PacketBufferList {
- var pkts stack.PacketBufferList
-
- for i := 0; i < acceptPackets; i++ {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
- })
- hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
- udpHdr(hdr, r.LocalAddress(), r.RemoteAddress(), utils.LocalPort, utils.RemotePort)
- pkts.PushFront(pkt)
- }
- for i := 0; i < dropPackets; i++ {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
- })
- hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
- udpHdr(hdr, r.LocalAddress(), r.RemoteAddress(), dropLocalPort, utils.RemotePort)
- pkts.PushFront(pkt)
- }
-
- return pkts
- },
- proto: header.IPv4ProtocolNumber,
- remoteAddr: dstAddrV4,
- expectSent: acceptPackets,
- expectOutputDropped: dropPackets,
- },
- {
- name: "IPv6 Accept",
- setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ },
- genPacket: func(r *stack.Route) stack.PacketBufferList {
- var pkts stack.PacketBufferList
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
- })
- hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
- udpHdr(hdr, r.LocalAddress(), r.RemoteAddress(), utils.LocalPort, utils.RemotePort)
- pkts.PushFront(pkt)
-
- return pkts
- },
- proto: header.IPv6ProtocolNumber,
- remoteAddr: dstAddrV6,
- expectSent: 1,
- expectOutputDropped: 0,
- },
- {
- name: "IPv6 Drop Other Port",
- setupFilter: func(t *testing.T, s *stack.Stack) {
- t.Helper()
-
- table := stack.Table{
- Rules: []stack.Rule{
- {
- Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber},
- },
- {
- Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber},
- },
- {
- Matchers: []stack.Matcher{&udpSourcePortMatcher{port: dropLocalPort}},
- Target: &stack.DropTarget{NetworkProtocol: header.IPv6ProtocolNumber},
- },
- {
- Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber},
- },
- {
- Target: &stack.ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber},
- },
- },
- BuiltinChains: [stack.NumHooks]int{
- stack.Prerouting: stack.HookUnset,
- stack.Input: 0,
- stack.Forward: 1,
- stack.Output: 2,
- stack.Postrouting: stack.HookUnset,
- },
- Underflows: [stack.NumHooks]int{
- stack.Prerouting: stack.HookUnset,
- stack.Input: 0,
- stack.Forward: 1,
- stack.Output: 2,
- stack.Postrouting: stack.HookUnset,
- },
- }
-
- if err := s.IPTables().ReplaceTable(stack.FilterID, table, true /* ipv6 */); err != nil {
- t.Fatalf("ReplaceTable(%d, _, true): %s", stack.FilterID, err)
- }
- },
- genPacket: func(r *stack.Route) stack.PacketBufferList {
- var pkts stack.PacketBufferList
-
- for i := 0; i < acceptPackets; i++ {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
- })
- hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
- udpHdr(hdr, r.LocalAddress(), r.RemoteAddress(), utils.LocalPort, utils.RemotePort)
- pkts.PushFront(pkt)
- }
- for i := 0; i < dropPackets; i++ {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
- })
- hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
- udpHdr(hdr, r.LocalAddress(), r.RemoteAddress(), dropLocalPort, utils.RemotePort)
- pkts.PushFront(pkt)
- }
-
- return pkts
- },
- proto: header.IPv6ProtocolNumber,
- remoteAddr: dstAddrV6,
- expectSent: acceptPackets,
- expectOutputDropped: dropPackets,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- })
- e := channelEndpointWithoutWritePacket{
- Endpoint: channel.New(4, header.IPv6MinimumMTU, linkAddr),
- t: t,
- }
- if err := s.CreateNIC(nicID, &e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- protocolAddrV6 := tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: srcAddrV6.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddrV6, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV6, err)
- }
- protocolAddrV4 := tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: srcAddrV4.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddrV4, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV4, err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv4EmptySubnet,
- NIC: nicID,
- },
- {
- Destination: header.IPv6EmptySubnet,
- NIC: nicID,
- },
- })
-
- test.setupFilter(t, s)
-
- r, err := s.FindRoute(nicID, "", test.remoteAddr, test.proto, false)
- if err != nil {
- t.Fatalf("FindRoute(%d, '', %s, %d, false): %s", nicID, test.remoteAddr, test.proto, err)
- }
- defer r.Release()
-
- pkts := test.genPacket(r)
- pktsLen := pkts.Len()
- if n, err := r.WritePackets(pkts, stack.NetworkHeaderParams{
- Protocol: header.UDPProtocolNumber,
- TTL: 64,
- }); err != nil {
- t.Fatalf("WritePackets(...): %s", err)
- } else if n != pktsLen {
- t.Fatalf("got WritePackets(...) = %d, want = %d", n, pktsLen)
- }
-
- if got := s.Stats().IP.PacketsSent.Value(); got != test.expectSent {
- t.Errorf("got PacketSent = %d, want = %d", got, test.expectSent)
- }
- if got := s.Stats().IP.IPTablesOutputDropped.Value(); got != test.expectOutputDropped {
- t.Errorf("got IPTablesOutputDropped = %d, want = %d", got, test.expectOutputDropped)
- }
- })
- }
-}
-
-const ttl = 64
-
-var (
- ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10")
- ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a")
-)
-
-func rxICMPv4EchoReply(e *channel.Endpoint, src, dst tcpip.Address) {
- utils.RxICMPv4EchoReply(e, src, dst, ttl)
-}
-
-func rxICMPv6EchoReply(e *channel.Endpoint, src, dst tcpip.Address) {
- utils.RxICMPv6EchoReply(e, src, dst, ttl)
-}
-
-func forwardedICMPv4EchoReplyChecker(t *testing.T, b []byte, src, dst tcpip.Address) {
- checker.IPv4(t, b,
- checker.SrcAddr(src),
- checker.DstAddr(dst),
- checker.TTL(ttl-1),
- checker.ICMPv4(
- checker.ICMPv4Type(header.ICMPv4EchoReply)))
-}
-
-func forwardedICMPv6EchoReplyChecker(t *testing.T, b []byte, src, dst tcpip.Address) {
- checker.IPv6(t, b,
- checker.SrcAddr(src),
- checker.DstAddr(dst),
- checker.TTL(ttl-1),
- checker.ICMPv6(
- checker.ICMPv6Type(header.ICMPv6EchoReply)))
-}
-
-func boolToInt(v bool) uint64 {
- if v {
- return 1
- }
- return 0
-}
-
-func setupDropFilter(hook stack.Hook, f stack.IPHeaderFilter) func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) {
- return func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber) {
- t.Helper()
-
- ipv6 := netProto == ipv6.ProtocolNumber
-
- ipt := s.IPTables()
- filter := ipt.GetTable(stack.FilterID, ipv6)
- ruleIdx := filter.BuiltinChains[hook]
- filter.Rules[ruleIdx].Filter = f
- filter.Rules[ruleIdx].Target = &stack.DropTarget{NetworkProtocol: netProto}
- // Make sure the packet is not dropped by the next rule.
- filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{NetworkProtocol: netProto}
- if err := ipt.ReplaceTable(stack.FilterID, filter, ipv6); err != nil {
- t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, ipv6, err)
- }
- }
-}
-
-func TestForwardingHook(t *testing.T) {
- const (
- nicID1 = 1
- nicID2 = 2
-
- nic1Name = "nic1"
- nic2Name = "nic2"
-
- otherNICName = "otherNIC"
- )
-
- tests := []struct {
- name string
- netProto tcpip.NetworkProtocolNumber
- local bool
- srcAddr, dstAddr tcpip.Address
- rx func(*channel.Endpoint, tcpip.Address, tcpip.Address)
- checker func(*testing.T, []byte)
- }{
- {
- name: "IPv4 remote",
- netProto: ipv4.ProtocolNumber,
- local: false,
- srcAddr: utils.RemoteIPv4Addr,
- dstAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
- rx: rxICMPv4EchoReply,
- checker: func(t *testing.T, b []byte) {
- forwardedICMPv4EchoReplyChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address)
- },
- },
- {
- name: "IPv4 local",
- netProto: ipv4.ProtocolNumber,
- local: true,
- srcAddr: utils.RemoteIPv4Addr,
- dstAddr: utils.Ipv4Addr.Address,
- rx: rxICMPv4EchoReply,
- },
- {
- name: "IPv6 remote",
- netProto: ipv6.ProtocolNumber,
- local: false,
- srcAddr: utils.RemoteIPv6Addr,
- dstAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
- rx: rxICMPv6EchoReply,
- checker: func(t *testing.T, b []byte) {
- forwardedICMPv6EchoReplyChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address)
- },
- },
- {
- name: "IPv6 local",
- netProto: ipv6.ProtocolNumber,
- local: true,
- srcAddr: utils.RemoteIPv6Addr,
- dstAddr: utils.Ipv6Addr.Address,
- rx: rxICMPv6EchoReply,
- },
- }
-
- subTests := []struct {
- name string
- setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber)
- expectForward bool
- }{
- {
- name: "Accept",
- setupFilter: func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { /* no filter */ },
- expectForward: true,
- },
-
- {
- name: "Drop",
- setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{}),
- expectForward: false,
- },
- {
- name: "Drop with input NIC filtering",
- setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name}),
- expectForward: false,
- },
- {
- name: "Drop with output NIC filtering",
- setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: nic2Name}),
- expectForward: false,
- },
- {
- name: "Drop with input and output NIC filtering",
- setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: nic2Name}),
- expectForward: false,
- },
-
- {
- name: "Drop with other input NIC filtering",
- setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName}),
- expectForward: true,
- },
- {
- name: "Drop with other output NIC filtering",
- setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: otherNICName}),
- expectForward: true,
- },
- {
- name: "Drop with other input and output NIC filtering",
- setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: nic2Name}),
- expectForward: true,
- },
- {
- name: "Drop with input and other output NIC filtering",
- setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: otherNICName}),
- expectForward: true,
- },
- {
- name: "Drop with other input and other output NIC filtering",
- setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: otherNICName}),
- expectForward: true,
- },
-
- {
- name: "Drop with inverted input NIC filtering",
- setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, InputInterfaceInvert: true}),
- expectForward: true,
- },
- {
- name: "Drop with inverted output NIC filtering",
- setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: nic2Name, OutputInterfaceInvert: true}),
- expectForward: true,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- for _, subTest := range subTests {
- t.Run(subTest.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- })
-
- subTest.setupFilter(t, s, test.netProto)
-
- e1 := channel.New(1, header.IPv6MinimumMTU, "")
- if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil {
- t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err)
- }
-
- e2 := channel.New(1, header.IPv6MinimumMTU, "")
- if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil {
- t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err)
- }
-
- protocolAddrV4 := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: utils.Ipv4Addr.Address.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err)
- }
- protocolAddrV6 := tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: utils.Ipv6Addr.Address.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err)
- }
-
- if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
- t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err)
- }
- if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil {
- t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv4EmptySubnet,
- NIC: nicID2,
- },
- {
- Destination: header.IPv6EmptySubnet,
- NIC: nicID2,
- },
- })
-
- test.rx(e1, test.srcAddr, test.dstAddr)
-
- expectTransmitPacket := subTest.expectForward && !test.local
-
- ep1, err := s.GetNetworkEndpoint(nicID1, test.netProto)
- if err != nil {
- t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, test.netProto, err)
- }
- ep1Stats := ep1.Stats()
- ipEP1Stats, ok := ep1Stats.(stack.IPNetworkEndpointStats)
- if !ok {
- t.Fatalf("got ep1Stats = %T, want = stack.IPNetworkEndpointStats", ep1Stats)
- }
- ip1Stats := ipEP1Stats.IPStats()
-
- if got := ip1Stats.PacketsReceived.Value(); got != 1 {
- t.Errorf("got ip1Stats.PacketsReceived.Value() = %d, want = 1", got)
- }
- if got := ip1Stats.ValidPacketsReceived.Value(); got != 1 {
- t.Errorf("got ip1Stats.ValidPacketsReceived.Value() = %d, want = 1", got)
- }
- if got, want := ip1Stats.IPTablesForwardDropped.Value(), boolToInt(!subTest.expectForward); got != want {
- t.Errorf("got ip1Stats.IPTablesForwardDropped.Value() = %d, want = %d", got, want)
- }
- if got := ip1Stats.PacketsSent.Value(); got != 0 {
- t.Errorf("got ip1Stats.PacketsSent.Value() = %d, want = 0", got)
- }
-
- ep2, err := s.GetNetworkEndpoint(nicID2, test.netProto)
- if err != nil {
- t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID2, test.netProto, err)
- }
- ep2Stats := ep2.Stats()
- ipEP2Stats, ok := ep2Stats.(stack.IPNetworkEndpointStats)
- if !ok {
- t.Fatalf("got ep2Stats = %T, want = stack.IPNetworkEndpointStats", ep2Stats)
- }
- ip2Stats := ipEP2Stats.IPStats()
- if got := ip2Stats.PacketsReceived.Value(); got != 0 {
- t.Errorf("got ip2Stats.PacketsReceived.Value() = %d, want = 0", got)
- }
- if got, want := ip2Stats.ValidPacketsReceived.Value(), boolToInt(subTest.expectForward && test.local); got != want {
- t.Errorf("got ip2Stats.ValidPacketsReceived.Value() = %d, want = %d", got, want)
- }
- if got, want := ip2Stats.PacketsSent.Value(), boolToInt(expectTransmitPacket); got != want {
- t.Errorf("got ip2Stats.PacketsSent.Value() = %d, want = %d", got, want)
- }
-
- p, ok := e2.Read()
- if ok != expectTransmitPacket {
- t.Fatalf("got e2.Read() = (%#v, %t), want = (_, %t)", p, ok, expectTransmitPacket)
- }
- if expectTransmitPacket {
- test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader()))
- }
- })
- }
- })
- }
-}
-
-func TestInputHookWithLocalForwarding(t *testing.T) {
- const (
- nicID1 = 1
- nicID2 = 2
-
- nic1Name = "nic1"
- nic2Name = "nic2"
-
- otherNICName = "otherNIC"
- )
-
- tests := []struct {
- name string
- netProto tcpip.NetworkProtocolNumber
- rx func(*channel.Endpoint)
- checker func(*testing.T, []byte)
- }{
- {
- name: "IPv4",
- netProto: ipv4.ProtocolNumber,
- rx: func(e *channel.Endpoint) {
- utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address, ttl)
- },
- checker: func(t *testing.T, b []byte) {
- checker.IPv4(t, b,
- checker.SrcAddr(utils.Ipv4Addr2.AddressWithPrefix.Address),
- checker.DstAddr(utils.RemoteIPv4Addr),
- checker.ICMPv4(
- checker.ICMPv4Type(header.ICMPv4EchoReply)))
- },
- },
- {
- name: "IPv6",
- netProto: ipv6.ProtocolNumber,
- rx: func(e *channel.Endpoint) {
- utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address, ttl)
- },
- checker: func(t *testing.T, b []byte) {
- checker.IPv6(t, b,
- checker.SrcAddr(utils.Ipv6Addr2.AddressWithPrefix.Address),
- checker.DstAddr(utils.RemoteIPv6Addr),
- checker.ICMPv6(
- checker.ICMPv6Type(header.ICMPv6EchoReply)))
- },
- },
- }
-
- subTests := []struct {
- name string
- setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber)
- expectDrop bool
- }{
- {
- name: "Accept",
- setupFilter: func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { /* no filter */ },
- expectDrop: false,
- },
-
- {
- name: "Drop",
- setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{}),
- expectDrop: true,
- },
- {
- name: "Drop with input NIC filtering on arrival NIC",
- setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: nic1Name}),
- expectDrop: true,
- },
- {
- name: "Drop with input NIC filtering on delivered NIC",
- setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: nic2Name}),
- expectDrop: false,
- },
-
- {
- name: "Drop with input NIC filtering on other NIC",
- setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: otherNICName}),
- expectDrop: false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- for _, subTest := range subTests {
- t.Run(subTest.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- })
-
- subTest.setupFilter(t, s, test.netProto)
-
- e1 := channel.New(1, header.IPv6MinimumMTU, "")
- if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil {
- t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err)
- }
- if err := s.AddProtocolAddress(nicID1, utils.Ipv4Addr1, stack.AddressProperties{}); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID1, utils.Ipv4Addr1, err)
- }
- if err := s.AddProtocolAddress(nicID1, utils.Ipv6Addr1, stack.AddressProperties{}); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID1, utils.Ipv6Addr1, err)
- }
-
- e2 := channel.New(1, header.IPv6MinimumMTU, "")
- if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil {
- t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err)
- }
- if err := s.AddProtocolAddress(nicID2, utils.Ipv4Addr2, stack.AddressProperties{}); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID2, utils.Ipv4Addr2, err)
- }
- if err := s.AddProtocolAddress(nicID2, utils.Ipv6Addr2, stack.AddressProperties{}); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID2, utils.Ipv6Addr2, err)
- }
-
- if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
- t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err)
- }
- if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil {
- t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv4EmptySubnet,
- NIC: nicID1,
- },
- {
- Destination: header.IPv6EmptySubnet,
- NIC: nicID1,
- },
- })
-
- test.rx(e1)
-
- ep1, err := s.GetNetworkEndpoint(nicID1, test.netProto)
- if err != nil {
- t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, test.netProto, err)
- }
- ep1Stats := ep1.Stats()
- ipEP1Stats, ok := ep1Stats.(stack.IPNetworkEndpointStats)
- if !ok {
- t.Fatalf("got ep1Stats = %T, want = stack.IPNetworkEndpointStats", ep1Stats)
- }
- ip1Stats := ipEP1Stats.IPStats()
-
- if got := ip1Stats.PacketsReceived.Value(); got != 1 {
- t.Errorf("got ip1Stats.PacketsReceived.Value() = %d, want = 1", got)
- }
- if got := ip1Stats.ValidPacketsReceived.Value(); got != 1 {
- t.Errorf("got ip1Stats.ValidPacketsReceived.Value() = %d, want = 1", got)
- }
- if got, want := ip1Stats.PacketsSent.Value(), boolToInt(!subTest.expectDrop); got != want {
- t.Errorf("got ip1Stats.PacketsSent.Value() = %d, want = %d", got, want)
- }
-
- ep2, err := s.GetNetworkEndpoint(nicID2, test.netProto)
- if err != nil {
- t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID2, test.netProto, err)
- }
- ep2Stats := ep2.Stats()
- ipEP2Stats, ok := ep2Stats.(stack.IPNetworkEndpointStats)
- if !ok {
- t.Fatalf("got ep2Stats = %T, want = stack.IPNetworkEndpointStats", ep2Stats)
- }
- ip2Stats := ipEP2Stats.IPStats()
- if got := ip2Stats.PacketsReceived.Value(); got != 0 {
- t.Errorf("got ip2Stats.PacketsReceived.Value() = %d, want = 0", got)
- }
- if got := ip2Stats.ValidPacketsReceived.Value(); got != 1 {
- t.Errorf("got ip2Stats.ValidPacketsReceived.Value() = %d, want = 1", got)
- }
- if got, want := ip2Stats.IPTablesInputDropped.Value(), boolToInt(subTest.expectDrop); got != want {
- t.Errorf("got ip2Stats.IPTablesInputDropped.Value() = %d, want = %d", got, want)
- }
- if got := ip2Stats.PacketsSent.Value(); got != 0 {
- t.Errorf("got ip2Stats.PacketsSent.Value() = %d, want = 0", got)
- }
-
- if p, ok := e1.Read(); ok == subTest.expectDrop {
- t.Errorf("got e1.Read() = (%#v, %t), want = (_, %t)", p, ok, !subTest.expectDrop)
- } else if !subTest.expectDrop {
- test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader()))
- }
- if p, ok := e2.Read(); ok {
- t.Errorf("got e1.Read() = (%#v, true), want = (_, false)", p)
- }
- })
- }
- })
- }
-}
-
-func TestNAT(t *testing.T) {
- const listenPort uint16 = 8080
-
- type endpointAndAddresses struct {
- serverEP tcpip.Endpoint
- serverAddr tcpip.FullAddress
- serverReadableCH chan struct{}
- serverConnectAddr tcpip.Address
-
- clientEP tcpip.Endpoint
- clientAddr tcpip.Address
- clientReadableCH chan struct{}
- clientConnectAddr tcpip.FullAddress
- }
-
- newEP := func(t *testing.T, s *stack.Stack, transProto tcpip.TransportProtocolNumber, netProto tcpip.NetworkProtocolNumber) (tcpip.Endpoint, chan struct{}) {
- t.Helper()
- var wq waiter.Queue
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- t.Cleanup(func() {
- wq.EventUnregister(&we)
- })
-
- ep, err := s.NewEndpoint(transProto, netProto, &wq)
- if err != nil {
- t.Fatalf("s.NewEndpoint(%d, %d, _): %s", transProto, netProto, err)
- }
- t.Cleanup(ep.Close)
-
- return ep, ch
- }
-
- setupNAT := func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, hook stack.Hook, filter stack.IPHeaderFilter, target stack.Target) {
- t.Helper()
-
- ipv6 := netProto == ipv6.ProtocolNumber
- ipt := s.IPTables()
- table := ipt.GetTable(stack.NATID, ipv6)
- ruleIdx := table.BuiltinChains[hook]
- table.Rules[ruleIdx].Filter = filter
- table.Rules[ruleIdx].Target = target
- // Make sure the packet is not dropped by the next rule.
- table.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
- if err := ipt.ReplaceTable(stack.NATID, table, ipv6); err != nil {
- t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err)
- }
- }
-
- setupDNAT := func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, target stack.Target) {
- t.Helper()
-
- setupNAT(
- t,
- s,
- netProto,
- stack.Prerouting,
- stack.IPHeaderFilter{
- Protocol: transProto,
- CheckProtocol: true,
- InputInterface: utils.RouterNIC2Name,
- },
- target)
- }
-
- setupSNAT := func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, target stack.Target) {
- t.Helper()
-
- setupNAT(
- t,
- s,
- netProto,
- stack.Postrouting,
- stack.IPHeaderFilter{
- Protocol: transProto,
- CheckProtocol: true,
- OutputInterface: utils.RouterNIC1Name,
- },
- target)
- }
-
- type natType struct {
- name string
- setupNAT func(_ *testing.T, _ *stack.Stack, _ tcpip.NetworkProtocolNumber, _ tcpip.TransportProtocolNumber, snatAddr, dnatAddr tcpip.Address)
- }
-
- snatTypes := []natType{
- {
- name: "SNAT",
- setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, snatAddr, _ tcpip.Address) {
- t.Helper()
-
- setupSNAT(t, s, netProto, transProto, &stack.SNATTarget{NetworkProtocol: netProto, Addr: snatAddr})
- },
- },
- {
- name: "Masquerade",
- setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, _, _ tcpip.Address) {
- t.Helper()
-
- setupSNAT(t, s, netProto, transProto, &stack.MasqueradeTarget{NetworkProtocol: netProto})
- },
- },
- }
- dnatTypes := []natType{
- {
- name: "Redirect",
- setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, _, _ tcpip.Address) {
- t.Helper()
-
- setupDNAT(t, s, netProto, transProto, &stack.RedirectTarget{NetworkProtocol: netProto, Port: listenPort})
- },
- },
- {
- name: "DNAT",
- setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, _, dnatAddr tcpip.Address) {
- t.Helper()
-
- setupDNAT(t, s, netProto, transProto, &stack.DNATTarget{NetworkProtocol: netProto, Addr: dnatAddr, Port: listenPort})
- },
- },
- }
-
- setupTwiceNAT := func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, dnatAddr tcpip.Address, snatTarget stack.Target) {
- t.Helper()
-
- ipv6 := netProto == ipv6.ProtocolNumber
- ipt := s.IPTables()
-
- table := stack.Table{
- Rules: []stack.Rule{
- // Prerouting
- {
- Filter: stack.IPHeaderFilter{
- Protocol: transProto,
- CheckProtocol: true,
- InputInterface: utils.RouterNIC2Name,
- },
- Target: &stack.DNATTarget{NetworkProtocol: netProto, Addr: dnatAddr, Port: listenPort},
- },
- {
- Target: &stack.AcceptTarget{},
- },
-
- // Input
- {
- Target: &stack.AcceptTarget{},
- },
-
- // Forward
- {
- Target: &stack.AcceptTarget{},
- },
-
- // Output
- {
- Target: &stack.AcceptTarget{},
- },
-
- // Postrouting
- {
- Filter: stack.IPHeaderFilter{
- Protocol: transProto,
- CheckProtocol: true,
- OutputInterface: utils.RouterNIC1Name,
- },
- Target: snatTarget,
- },
- {
- Target: &stack.AcceptTarget{},
- },
- },
- BuiltinChains: [stack.NumHooks]int{
- stack.Prerouting: 0,
- stack.Input: 2,
- stack.Forward: 3,
- stack.Output: 4,
- stack.Postrouting: 5,
- },
- }
-
- if err := ipt.ReplaceTable(stack.NATID, table, ipv6); err != nil {
- t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err)
- }
- }
- twiceNATTypes := []natType{
- {
- name: "DNAT-Masquerade",
- setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, snatAddr, dnatAddr tcpip.Address) {
- t.Helper()
-
- setupTwiceNAT(t, s, netProto, transProto, dnatAddr, &stack.MasqueradeTarget{NetworkProtocol: netProto})
- },
- },
- {
- name: "DNAT-SNAT",
- setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, snatAddr, dnatAddr tcpip.Address) {
- t.Helper()
-
- setupTwiceNAT(t, s, netProto, transProto, dnatAddr, &stack.SNATTarget{NetworkProtocol: netProto, Addr: snatAddr})
- },
- },
- }
-
- tests := []struct {
- name string
- netProto tcpip.NetworkProtocolNumber
- // Setups up the stacks in such a way that:
- //
- // - Host2 is the client for all tests.
- // - When performing SNAT only:
- // + Host1 is the server.
- // + NAT will transform client-originating packets' source addresses to
- // the router's NIC1's address before reaching Host1.
- // - When performing DNAT only:
- // + Router is the server.
- // + Client will send packets directed to Host1.
- // + NAT will transform client-originating packets' destination addresses
- // to the router's NIC2's address.
- // - When performing Twice-NAT:
- // + Host1 is the server.
- // + Client will send packets directed to router's NIC2.
- // + NAT will transform client originating packets' destination addresses
- // to Host1's address.
- // + NAT will transform client-originating packets' source addresses to
- // the router's NIC1's address before reaching Host1.
- epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses
- natTypes []natType
- }{
- {
- name: "IPv4 SNAT",
- netProto: ipv4.ProtocolNumber,
- epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
- t.Helper()
-
- listenerStack := host1Stack
- serverAddr := tcpip.FullAddress{
- Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
- Port: listenPort,
- }
- serverConnectAddr := utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address
- clientConnectPort := serverAddr.Port
- ep1, ep1WECH := newEP(t, listenerStack, proto, ipv4.ProtocolNumber)
- ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber)
- return endpointAndAddresses{
- serverEP: ep1,
- serverAddr: serverAddr,
- serverReadableCH: ep1WECH,
- serverConnectAddr: serverConnectAddr,
-
- clientEP: ep2,
- clientAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address,
- clientReadableCH: ep2WECH,
- clientConnectAddr: tcpip.FullAddress{
- Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
- Port: clientConnectPort,
- },
- }
- },
- natTypes: snatTypes,
- },
- {
- name: "IPv4 DNAT",
- netProto: ipv4.ProtocolNumber,
- epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
- t.Helper()
-
- // If we are performing DNAT, then the packet will be redirected
- // to the router.
- listenerStack := routerStack
- serverAddr := tcpip.FullAddress{
- Addr: utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address,
- Port: listenPort,
- }
- serverConnectAddr := utils.Host2IPv4Addr.AddressWithPrefix.Address
- // DNAT will update the destination port to what the server is
- // bound to.
- clientConnectPort := serverAddr.Port + 1
- ep1, ep1WECH := newEP(t, listenerStack, proto, ipv4.ProtocolNumber)
- ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber)
- return endpointAndAddresses{
- serverEP: ep1,
- serverAddr: serverAddr,
- serverReadableCH: ep1WECH,
- serverConnectAddr: serverConnectAddr,
-
- clientEP: ep2,
- clientAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address,
- clientReadableCH: ep2WECH,
- clientConnectAddr: tcpip.FullAddress{
- Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
- Port: clientConnectPort,
- },
- }
- },
- natTypes: dnatTypes,
- },
- {
- name: "IPv4 Twice-NAT",
- netProto: ipv4.ProtocolNumber,
- epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
- t.Helper()
-
- listenerStack := host1Stack
- serverAddr := tcpip.FullAddress{
- Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
- Port: listenPort,
- }
- serverConnectAddr := utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address
- clientConnectPort := serverAddr.Port
- ep1, ep1WECH := newEP(t, listenerStack, proto, ipv4.ProtocolNumber)
- ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber)
- return endpointAndAddresses{
- serverEP: ep1,
- serverAddr: serverAddr,
- serverReadableCH: ep1WECH,
- serverConnectAddr: serverConnectAddr,
-
- clientEP: ep2,
- clientAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address,
- clientReadableCH: ep2WECH,
- clientConnectAddr: tcpip.FullAddress{
- Addr: utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address,
- Port: clientConnectPort,
- },
- }
- },
- natTypes: twiceNATTypes,
- },
- {
- name: "IPv6 SNAT",
- netProto: ipv6.ProtocolNumber,
- epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
- t.Helper()
-
- listenerStack := host1Stack
- serverAddr := tcpip.FullAddress{
- Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
- Port: listenPort,
- }
- serverConnectAddr := utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address
- clientConnectPort := serverAddr.Port
- ep1, ep1WECH := newEP(t, listenerStack, proto, ipv6.ProtocolNumber)
- ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber)
- return endpointAndAddresses{
- serverEP: ep1,
- serverAddr: serverAddr,
- serverReadableCH: ep1WECH,
- serverConnectAddr: serverConnectAddr,
-
- clientEP: ep2,
- clientAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address,
- clientReadableCH: ep2WECH,
- clientConnectAddr: tcpip.FullAddress{
- Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
- Port: clientConnectPort,
- },
- }
- },
- natTypes: snatTypes,
- },
- {
- name: "IPv6 DNAT",
- netProto: ipv6.ProtocolNumber,
- epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
- t.Helper()
-
- // If we are performing DNAT, then the packet will be redirected
- // to the router.
- listenerStack := routerStack
- serverAddr := tcpip.FullAddress{
- Addr: utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address,
- Port: listenPort,
- }
- serverConnectAddr := utils.Host2IPv6Addr.AddressWithPrefix.Address
- // DNAT will update the destination port to what the server is
- // bound to.
- clientConnectPort := serverAddr.Port + 1
- ep1, ep1WECH := newEP(t, listenerStack, proto, ipv6.ProtocolNumber)
- ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber)
- return endpointAndAddresses{
- serverEP: ep1,
- serverAddr: serverAddr,
- serverReadableCH: ep1WECH,
- serverConnectAddr: serverConnectAddr,
-
- clientEP: ep2,
- clientAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address,
- clientReadableCH: ep2WECH,
- clientConnectAddr: tcpip.FullAddress{
- Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
- Port: clientConnectPort,
- },
- }
- },
- natTypes: dnatTypes,
- },
- {
- name: "IPv6 Twice-NAT",
- netProto: ipv6.ProtocolNumber,
- epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
- t.Helper()
-
- listenerStack := host1Stack
- serverAddr := tcpip.FullAddress{
- Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
- Port: listenPort,
- }
- serverConnectAddr := utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address
- clientConnectPort := serverAddr.Port
- ep1, ep1WECH := newEP(t, listenerStack, proto, ipv6.ProtocolNumber)
- ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber)
- return endpointAndAddresses{
- serverEP: ep1,
- serverAddr: serverAddr,
- serverReadableCH: ep1WECH,
- serverConnectAddr: serverConnectAddr,
-
- clientEP: ep2,
- clientAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address,
- clientReadableCH: ep2WECH,
- clientConnectAddr: tcpip.FullAddress{
- Addr: utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address,
- Port: clientConnectPort,
- },
- }
- },
- natTypes: twiceNATTypes,
- },
- }
-
- subTests := []struct {
- name string
- proto tcpip.TransportProtocolNumber
- expectedConnectErr tcpip.Error
- setupServer func(t *testing.T, ep tcpip.Endpoint)
- setupServerConn func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{})
- needRemoteAddr bool
- }{
- {
- name: "UDP",
- proto: udp.ProtocolNumber,
- expectedConnectErr: nil,
- setupServerConn: func(t *testing.T, ep tcpip.Endpoint, _ <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
- t.Helper()
-
- if err := ep.Connect(clientAddr); err != nil {
- t.Fatalf("ep.Connect(%#v): %s", clientAddr, err)
- }
- return nil, nil
- },
- needRemoteAddr: true,
- },
- {
- name: "TCP",
- proto: tcp.ProtocolNumber,
- expectedConnectErr: &tcpip.ErrConnectStarted{},
- setupServer: func(t *testing.T, ep tcpip.Endpoint) {
- t.Helper()
-
- if err := ep.Listen(1); err != nil {
- t.Fatalf("ep.Listen(1): %s", err)
- }
- },
- setupServerConn: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
- t.Helper()
-
- var addr tcpip.FullAddress
- for {
- newEP, wq, err := ep.Accept(&addr)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
- <-ch
- continue
- }
- if err != nil {
- t.Fatalf("ep.Accept(_): %s", err)
- }
- if diff := cmp.Diff(clientAddr, addr, checker.IgnoreCmpPath(
- "NIC",
- )); diff != "" {
- t.Errorf("accepted address mismatch (-want +got):\n%s", diff)
- }
-
- we, newCH := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- return newEP, newCH
- }
- },
- needRemoteAddr: false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- for _, subTest := range subTests {
- t.Run(subTest.name, func(t *testing.T) {
- for _, natType := range test.natTypes {
- t.Run(natType.name, func(t *testing.T) {
- stackOpts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
- }
-
- host1Stack := stack.New(stackOpts)
- routerStack := stack.New(stackOpts)
- host2Stack := stack.New(stackOpts)
- utils.SetupRoutedStacks(t, host1Stack, routerStack, host2Stack)
-
- epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto)
- natType.setupNAT(t, routerStack, test.netProto, subTest.proto, epsAndAddrs.serverConnectAddr, epsAndAddrs.serverAddr.Addr)
-
- if err := epsAndAddrs.serverEP.Bind(epsAndAddrs.serverAddr); err != nil {
- t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", epsAndAddrs.serverAddr, err)
- }
- clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr}
- if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil {
- t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err)
- }
-
- if subTest.setupServer != nil {
- subTest.setupServer(t, epsAndAddrs.serverEP)
- }
- {
- err := epsAndAddrs.clientEP.Connect(epsAndAddrs.clientConnectAddr)
- if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" {
- t.Fatalf("unexpected error from epsAndAddrs.clientEP.Connect(%#v), (-want, +got):\n%s", epsAndAddrs.clientConnectAddr, diff)
- }
- }
- serverConnectAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverConnectAddr}
- if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil {
- t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err)
- } else {
- serverConnectAddr.Port = addr.Port
- }
-
- serverEP := epsAndAddrs.serverEP
- serverCH := epsAndAddrs.serverReadableCH
- if ep, ch := subTest.setupServerConn(t, serverEP, serverCH, serverConnectAddr); ep != nil {
- defer ep.Close()
- serverEP = ep
- serverCH = ch
- }
-
- write := func(ep tcpip.Endpoint, data []byte) {
- t.Helper()
-
- var r bytes.Reader
- r.Reset(data)
- var wOpts tcpip.WriteOptions
- n, err := ep.Write(&r, wOpts)
- if err != nil {
- t.Fatalf("ep.Write(_, %#v): %s", wOpts, err)
- }
- if want := int64(len(data)); n != want {
- t.Fatalf("got ep.Write(_, %#v) = (%d, _), want = (%d, _)", wOpts, n, want)
- }
- }
-
- read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.FullAddress) {
- t.Helper()
-
- var buf bytes.Buffer
- var res tcpip.ReadResult
- for {
- var err tcpip.Error
- opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr}
- res, err = ep.Read(&buf, opts)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
- <-ch
- continue
- }
- if err != nil {
- t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err)
- }
- break
- }
-
- readResult := tcpip.ReadResult{
- Count: len(data),
- Total: len(data),
- }
- if subTest.needRemoteAddr {
- readResult.RemoteAddr = expectedFrom
- }
- if diff := cmp.Diff(readResult, res, checker.IgnoreCmpPath(
- "ControlMessages",
- "RemoteAddr.NIC",
- )); diff != "" {
- t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
- }
- if diff := cmp.Diff(buf.Bytes(), data); diff != "" {
- t.Errorf("received data mismatch (-want +got):\n%s", diff)
- }
-
- if t.Failed() {
- t.FailNow()
- }
- }
-
- {
- data := []byte{1, 2, 3, 4}
- write(epsAndAddrs.clientEP, data)
- read(serverCH, serverEP, data, serverConnectAddr)
- }
-
- {
- data := []byte{5, 6, 7, 8, 9, 10, 11, 12}
- write(serverEP, data)
- read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, epsAndAddrs.clientConnectAddr)
- }
- })
- }
- })
- }
- })
- }
-}
-
-func TestNATICMPError(t *testing.T) {
- const (
- srcPort = 1234
- dstPort = 5432
- dataSize = 4
- )
-
- type icmpTypeTest struct {
- name string
- val uint8
- expectResponse bool
- }
-
- type transportTypeTest struct {
- name string
- proto tcpip.TransportProtocolNumber
- buf buffer.View
- checkNATed func(*testing.T, buffer.View)
- }
-
- ipHdr := func(v buffer.View, totalLen int, transProto tcpip.TransportProtocolNumber, srcAddr, dstAddr tcpip.Address) {
- ip := header.IPv4(v)
- ip.Encode(&header.IPv4Fields{
- TotalLength: uint16(totalLen),
- Protocol: uint8(transProto),
- TTL: 64,
- SrcAddr: srcAddr,
- DstAddr: dstAddr,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
- }
-
- ip6Hdr := func(v buffer.View, payloadLen int, transProto tcpip.TransportProtocolNumber, srcAddr, dstAddr tcpip.Address) {
- ip := header.IPv6(v)
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLen),
- TransportProtocol: transProto,
- HopLimit: 64,
- SrcAddr: srcAddr,
- DstAddr: dstAddr,
- })
- }
-
- tests := []struct {
- name string
- netProto tcpip.NetworkProtocolNumber
- host1Addr tcpip.Address
- icmpError func(*testing.T, buffer.View, uint8) buffer.View
- decrementTTL func(buffer.View)
- checkNATedError func(*testing.T, buffer.View, buffer.View, uint8)
-
- transportTypes []transportTypeTest
- icmpTypes []icmpTypeTest
- }{
- {
- name: "IPv4",
- netProto: ipv4.ProtocolNumber,
- host1Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
- icmpError: func(t *testing.T, original buffer.View, icmpType uint8) buffer.View {
- hdr := buffer.NewPrependable(header.IPv4MinimumSize + header.ICMPv4MinimumSize + len(original))
- if n := copy(hdr.Prepend(len(original)), original); n != len(original) {
- t.Fatalf("got copy(...) = %d, want = %d", n, len(original))
- }
- icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
- icmp.SetType(header.ICMPv4Type(icmpType))
- icmp.SetChecksum(0)
- icmp.SetChecksum(header.ICMPv4Checksum(icmp, 0))
- ipHdr(
- hdr.Prepend(header.IPv4MinimumSize),
- hdr.UsedLength(),
- header.ICMPv4ProtocolNumber,
- utils.Host1IPv4Addr.AddressWithPrefix.Address,
- utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
- )
- return hdr.View()
- },
- decrementTTL: func(v buffer.View) {
- ip := header.IPv4(v)
- ip.SetTTL(ip.TTL() - 1)
- ip.SetChecksum(0)
- ip.SetChecksum(^ip.CalculateChecksum())
- },
- checkNATedError: func(t *testing.T, v buffer.View, original buffer.View, icmpType uint8) {
- checker.IPv4(t, v,
- checker.SrcAddr(utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address),
- checker.DstAddr(utils.Host2IPv4Addr.AddressWithPrefix.Address),
- checker.ICMPv4(
- checker.ICMPv4Type(header.ICMPv4Type(icmpType)),
- checker.ICMPv4Checksum(),
- checker.ICMPv4Payload(original),
- ),
- )
- },
- transportTypes: []transportTypeTest{
- {
- name: "UDP",
- proto: header.UDPProtocolNumber,
- buf: func() buffer.View {
- udpSize := header.UDPMinimumSize + dataSize
- hdr := buffer.NewPrependable(header.IPv4MinimumSize + udpSize)
- udp := header.UDP(hdr.Prepend(udpSize))
- udp.SetSourcePort(srcPort)
- udp.SetDestinationPort(dstPort)
- udp.SetChecksum(0)
- udp.SetChecksum(^udp.CalculateChecksum(header.PseudoHeaderChecksum(
- header.UDPProtocolNumber,
- utils.Host2IPv4Addr.AddressWithPrefix.Address,
- utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address,
- uint16(len(udp)),
- )))
- ipHdr(
- hdr.Prepend(header.IPv4MinimumSize),
- hdr.UsedLength(),
- header.UDPProtocolNumber,
- utils.Host2IPv4Addr.AddressWithPrefix.Address,
- utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address,
- )
- return hdr.View()
- }(),
- checkNATed: func(t *testing.T, v buffer.View) {
- checker.IPv4(t, v,
- checker.SrcAddr(utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address),
- checker.DstAddr(utils.Host1IPv4Addr.AddressWithPrefix.Address),
- checker.UDP(
- checker.SrcPort(srcPort),
- checker.DstPort(dstPort),
- ),
- )
- },
- },
- {
- name: "TCP",
- proto: header.TCPProtocolNumber,
- buf: func() buffer.View {
- tcpSize := header.TCPMinimumSize + dataSize
- hdr := buffer.NewPrependable(header.IPv4MinimumSize + tcpSize)
- tcp := header.TCP(hdr.Prepend(tcpSize))
- tcp.SetSourcePort(srcPort)
- tcp.SetDestinationPort(dstPort)
- tcp.SetDataOffset(header.TCPMinimumSize)
- tcp.SetChecksum(0)
- tcp.SetChecksum(^tcp.CalculateChecksum(header.PseudoHeaderChecksum(
- header.TCPProtocolNumber,
- utils.Host2IPv4Addr.AddressWithPrefix.Address,
- utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address,
- uint16(len(tcp)),
- )))
- ipHdr(
- hdr.Prepend(header.IPv4MinimumSize),
- hdr.UsedLength(),
- header.TCPProtocolNumber,
- utils.Host2IPv4Addr.AddressWithPrefix.Address,
- utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address,
- )
- return hdr.View()
- }(),
- checkNATed: func(t *testing.T, v buffer.View) {
- checker.IPv4(t, v,
- checker.SrcAddr(utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address),
- checker.DstAddr(utils.Host1IPv4Addr.AddressWithPrefix.Address),
- checker.TCP(
- checker.SrcPort(srcPort),
- checker.DstPort(dstPort),
- ),
- )
- },
- },
- },
- icmpTypes: []icmpTypeTest{
- {
- name: "Destination Unreachable",
- val: uint8(header.ICMPv4DstUnreachable),
- expectResponse: true,
- },
- {
- name: "Time Exceeded",
- val: uint8(header.ICMPv4TimeExceeded),
- expectResponse: true,
- },
- {
- name: "Parameter Problem",
- val: uint8(header.ICMPv4ParamProblem),
- expectResponse: true,
- },
- {
- name: "Echo Request",
- val: uint8(header.ICMPv4Echo),
- expectResponse: false,
- },
- {
- name: "Echo Reply",
- val: uint8(header.ICMPv4EchoReply),
- expectResponse: false,
- },
- },
- },
- {
- name: "IPv6",
- netProto: ipv6.ProtocolNumber,
- host1Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
- icmpError: func(t *testing.T, original buffer.View, icmpType uint8) buffer.View {
- payloadLen := header.ICMPv6MinimumSize + len(original)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadLen)
- icmp := header.ICMPv6(hdr.Prepend(payloadLen))
- icmp.SetType(header.ICMPv6Type(icmpType))
- if n := copy(icmp.Payload(), original); n != len(original) {
- t.Fatalf("got copy(...) = %d, want = %d", n, len(original))
- }
- icmp.SetChecksum(0)
- icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: icmp,
- Src: utils.Host1IPv6Addr.AddressWithPrefix.Address,
- Dst: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
- }))
- ip6Hdr(
- hdr.Prepend(header.IPv6MinimumSize),
- payloadLen,
- header.ICMPv6ProtocolNumber,
- utils.Host1IPv6Addr.AddressWithPrefix.Address,
- utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
- )
- return hdr.View()
- },
- decrementTTL: func(v buffer.View) {
- ip := header.IPv6(v)
- ip.SetHopLimit(ip.HopLimit() - 1)
- },
- checkNATedError: func(t *testing.T, v buffer.View, original buffer.View, icmpType uint8) {
- checker.IPv6(t, v,
- checker.SrcAddr(utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address),
- checker.DstAddr(utils.Host2IPv6Addr.AddressWithPrefix.Address),
- checker.ICMPv6(
- checker.ICMPv6Type(header.ICMPv6Type(icmpType)),
- checker.ICMPv6Payload(original),
- ),
- )
- },
- transportTypes: []transportTypeTest{
- {
- name: "UDP",
- proto: header.UDPProtocolNumber,
- buf: func() buffer.View {
- udpSize := header.UDPMinimumSize + dataSize
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + udpSize)
- udp := header.UDP(hdr.Prepend(udpSize))
- udp.SetSourcePort(srcPort)
- udp.SetDestinationPort(dstPort)
- udp.SetChecksum(0)
- udp.SetChecksum(^udp.CalculateChecksum(header.PseudoHeaderChecksum(
- header.UDPProtocolNumber,
- utils.Host2IPv6Addr.AddressWithPrefix.Address,
- utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address,
- uint16(len(udp)),
- )))
- ip6Hdr(
- hdr.Prepend(header.IPv6MinimumSize),
- len(udp),
- header.UDPProtocolNumber,
- utils.Host2IPv6Addr.AddressWithPrefix.Address,
- utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address,
- )
- return hdr.View()
- }(),
- checkNATed: func(t *testing.T, v buffer.View) {
- checker.IPv6(t, v,
- checker.SrcAddr(utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address),
- checker.DstAddr(utils.Host1IPv6Addr.AddressWithPrefix.Address),
- checker.UDP(
- checker.SrcPort(srcPort),
- checker.DstPort(dstPort),
- ),
- )
- },
- },
- {
- name: "TCP",
- proto: header.TCPProtocolNumber,
- buf: func() buffer.View {
- tcpSize := header.TCPMinimumSize + dataSize
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + tcpSize)
- tcp := header.TCP(hdr.Prepend(tcpSize))
- tcp.SetSourcePort(srcPort)
- tcp.SetDestinationPort(dstPort)
- tcp.SetDataOffset(header.TCPMinimumSize)
- tcp.SetChecksum(0)
- tcp.SetChecksum(^tcp.CalculateChecksum(header.PseudoHeaderChecksum(
- header.TCPProtocolNumber,
- utils.Host2IPv6Addr.AddressWithPrefix.Address,
- utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address,
- uint16(len(tcp)),
- )))
- ip6Hdr(
- hdr.Prepend(header.IPv6MinimumSize),
- len(tcp),
- header.TCPProtocolNumber,
- utils.Host2IPv6Addr.AddressWithPrefix.Address,
- utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address,
- )
- return hdr.View()
- }(),
- checkNATed: func(t *testing.T, v buffer.View) {
- checker.IPv6(t, v,
- checker.SrcAddr(utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address),
- checker.DstAddr(utils.Host1IPv6Addr.AddressWithPrefix.Address),
- checker.TCP(
- checker.SrcPort(srcPort),
- checker.DstPort(dstPort),
- ),
- )
- },
- },
- },
- icmpTypes: []icmpTypeTest{
- {
- name: "Destination Unreachable",
- val: uint8(header.ICMPv6DstUnreachable),
- expectResponse: true,
- },
- {
- name: "Packet Too Big",
- val: uint8(header.ICMPv6PacketTooBig),
- expectResponse: true,
- },
- {
- name: "Time Exceeded",
- val: uint8(header.ICMPv6TimeExceeded),
- expectResponse: true,
- },
- {
- name: "Parameter Problem",
- val: uint8(header.ICMPv6ParamProblem),
- expectResponse: true,
- },
- {
- name: "Echo Request",
- val: uint8(header.ICMPv6EchoRequest),
- expectResponse: false,
- },
- {
- name: "Echo Reply",
- val: uint8(header.ICMPv6EchoReply),
- expectResponse: false,
- },
- },
- },
- }
-
- trimTests := []struct {
- name string
- trimLen int
- expectNATedICMP bool
- }{
- {
- name: "Trim nothing",
- trimLen: 0,
- expectNATedICMP: true,
- },
- {
- name: "Trim data",
- trimLen: dataSize,
- expectNATedICMP: true,
- },
- {
- name: "Trim data and transport header",
- trimLen: dataSize + 1,
- expectNATedICMP: false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- for _, transportType := range test.transportTypes {
- t.Run(transportType.name, func(t *testing.T) {
- for _, icmpType := range test.icmpTypes {
- t.Run(icmpType.name, func(t *testing.T) {
- for _, trimTest := range trimTests {
- t.Run(trimTest.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
- })
-
- ep1 := channel.New(1, header.IPv6MinimumMTU, "")
- ep2 := channel.New(1, header.IPv6MinimumMTU, "")
- utils.SetupRouterStack(t, s, ep1, ep2)
-
- ipv6 := test.netProto == ipv6.ProtocolNumber
- ipt := s.IPTables()
-
- table := stack.Table{
- Rules: []stack.Rule{
- // Prerouting
- {
- Filter: stack.IPHeaderFilter{
- Protocol: transportType.proto,
- CheckProtocol: true,
- InputInterface: utils.RouterNIC2Name,
- },
- Target: &stack.DNATTarget{NetworkProtocol: test.netProto, Addr: test.host1Addr, Port: dstPort},
- },
- {
- Target: &stack.AcceptTarget{},
- },
-
- // Input
- {
- Target: &stack.AcceptTarget{},
- },
-
- // Forward
- {
- Target: &stack.AcceptTarget{},
- },
-
- // Output
- {
- Target: &stack.AcceptTarget{},
- },
-
- // Postrouting
- {
- Filter: stack.IPHeaderFilter{
- Protocol: transportType.proto,
- CheckProtocol: true,
- OutputInterface: utils.RouterNIC1Name,
- },
- Target: &stack.MasqueradeTarget{NetworkProtocol: test.netProto},
- },
- {
- Target: &stack.AcceptTarget{},
- },
- },
- BuiltinChains: [stack.NumHooks]int{
- stack.Prerouting: 0,
- stack.Input: 2,
- stack.Forward: 3,
- stack.Output: 4,
- stack.Postrouting: 5,
- },
- }
-
- if err := ipt.ReplaceTable(stack.NATID, table, ipv6); err != nil {
- t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err)
- }
-
- buf := transportType.buf
-
- ep2.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: append(buffer.View(nil), buf...).ToVectorisedView(),
- }))
-
- {
- pkt, ok := ep1.Read()
- if !ok {
- t.Fatal("expected to read a packet on ep1")
- }
- pktView := stack.PayloadSince(pkt.Pkt.NetworkHeader())
- transportType.checkNATed(t, pktView)
- if t.Failed() {
- t.FailNow()
- }
-
- pktView = pktView[:len(pktView)-trimTest.trimLen]
- buf = buf[:len(buf)-trimTest.trimLen]
-
- ep1.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: test.icmpError(t, pktView, icmpType.val).ToVectorisedView(),
- }))
- }
-
- pkt, ok := ep2.Read()
- expectResponse := icmpType.expectResponse && trimTest.expectNATedICMP
- if ok != expectResponse {
- t.Fatalf("got ep2.Read() = (%#v, %t), want = (_, %t)", pkt, ok, expectResponse)
- }
- if !expectResponse {
- return
- }
- test.decrementTTL(buf)
- test.checkNATedError(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), buf, icmpType.val)
- })
- }
- })
- }
- })
- }
- })
- }
-}
diff --git a/pkg/tcpip/tests/integration/istio_test.go b/pkg/tcpip/tests/integration/istio_test.go
deleted file mode 100644
index 95d994ef8..000000000
--- a/pkg/tcpip/tests/integration/istio_test.go
+++ /dev/null
@@ -1,365 +0,0 @@
-// Copyright 2021 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package istio_test
-
-import (
- "fmt"
- "io"
- "net"
- "net/http"
- "strconv"
- "testing"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/context"
- "gvisor.dev/gvisor/pkg/rand"
- "gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
- "gvisor.dev/gvisor/pkg/tcpip/link/pipe"
- "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
-)
-
-// testContext encapsulates the state required to run tests that simulate
-// an istio like environment.
-//
-// A diagram depicting the setup is shown below.
-// +-----------------------------------------------------------------------+
-// | +-------------------------------------------------+ |
-// | + ----------+ | + -----------------+ PROXY +----------+ | |
-// | | clientEP | | | serverListeningEP|--accepted-> | serverEP |-+ | |
-// | + ----------+ | + -----------------+ +----------+ | | |
-// | | -------|-------------+ +----------+ | | |
-// | | | | | proxyEP |-+ | |
-// | +-----redirect | +----------+ | |
-// | + ------------+---|------+---+ |
-// | | |
-// | Local Stack. | |
-// +-------------------------------------------------------|---------------+
-// |
-// +-----------------------------------------------------------------------+
-// | remoteStack | |
-// | +-------------SYN ---------------| |
-// | | | |
-// | +-------------------|--------------------------------|-_---+ |
-// | | + -----------------+ + ----------+ | | |
-// | | | remoteListeningEP|--accepted--->| remoteEP |<++ | |
-// | | + -----------------+ + ----------+ | |
-// | | Remote HTTP Server | |
-// | +----------------------------------------------------------+ |
-// +-----------------------------------------------------------------------+
-//
-type testContext struct {
- // localServerListener is the listening port for the server which will proxy
- // all traffic to the remote EP.
- localServerListener *gonet.TCPListener
-
- // remoteListenListener is the remote listening endpoint that will receive
- // connections from server.
- remoteServerListener *gonet.TCPListener
-
- // localStack is the stack used to create client/server endpoints and
- // also the stack on which we install NAT redirect rules.
- localStack *stack.Stack
-
- // remoteStack is the stack that represents a *remote* server.
- remoteStack *stack.Stack
-
- // defaultResponse is the response served by the HTTP server for all GET
- defaultResponse []byte
-
- // requests. wg is used to wait for HTTP server and Proxy to terminate before
- // returning from cleanup.
- wg sync.WaitGroup
-}
-
-func (ctx *testContext) cleanup() {
- ctx.localServerListener.Close()
- ctx.localStack.Close()
- ctx.remoteServerListener.Close()
- ctx.remoteStack.Close()
- ctx.wg.Wait()
-}
-
-const (
- localServerPort = 8080
- remoteServerPort = 9090
-)
-
-var (
- localIPv4Addr1 = testutil.MustParse4("10.0.0.1")
- localIPv4Addr2 = testutil.MustParse4("10.0.0.2")
- loopbackIPv4Addr = testutil.MustParse4("127.0.0.1")
- remoteIPv4Addr1 = testutil.MustParse4("10.0.0.3")
-)
-
-func newTestContext(t *testing.T) *testContext {
- t.Helper()
- localNIC, remoteNIC := pipe.New("" /* linkAddr1 */, "" /* linkAddr2 */)
-
- localStack := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
- HandleLocal: true,
- })
-
- remoteStack := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
- HandleLocal: true,
- })
-
- // Add loopback NIC. We need a loopback NIC as NAT redirect rule redirect to
- // loopback address + specified port.
- loopbackNIC := loopback.New()
- const loopbackNICID = tcpip.NICID(1)
- if err := localStack.CreateNIC(loopbackNICID, sniffer.New(loopbackNIC)); err != nil {
- t.Fatalf("localStack.CreateNIC(%d, _): %s", loopbackNICID, err)
- }
- loopbackAddr := tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: loopbackIPv4Addr.WithPrefix(),
- }
- if err := localStack.AddProtocolAddress(loopbackNICID, loopbackAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("localStack.AddProtocolAddress(%d, %+v, {}): %s", loopbackNICID, loopbackAddr, err)
- }
-
- // Create linked NICs that connects the local and remote stack.
- const localNICID = tcpip.NICID(2)
- const remoteNICID = tcpip.NICID(3)
- if err := localStack.CreateNIC(localNICID, sniffer.New(localNIC)); err != nil {
- t.Fatalf("localStack.CreateNIC(%d, _): %s", localNICID, err)
- }
- if err := remoteStack.CreateNIC(remoteNICID, sniffer.New(remoteNIC)); err != nil {
- t.Fatalf("remoteStack.CreateNIC(%d, _): %s", remoteNICID, err)
- }
-
- for _, addr := range []tcpip.Address{localIPv4Addr1, localIPv4Addr2} {
- localProtocolAddr := tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: addr.WithPrefix(),
- }
- if err := localStack.AddProtocolAddress(localNICID, localProtocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("localStack.AddProtocolAddress(%d, %+v, {}): %s", localNICID, localProtocolAddr, err)
- }
- }
-
- remoteProtocolAddr := tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: remoteIPv4Addr1.WithPrefix(),
- }
- if err := remoteStack.AddProtocolAddress(remoteNICID, remoteProtocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("remoteStack.AddProtocolAddress(%d, %+v, {}): %s", remoteNICID, remoteProtocolAddr, err)
- }
-
- // Setup route table for local and remote stacks.
- localStack.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv4LoopbackSubnet,
- NIC: loopbackNICID,
- },
- {
- Destination: header.IPv4EmptySubnet,
- NIC: localNICID,
- },
- })
- remoteStack.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv4EmptySubnet,
- NIC: remoteNICID,
- },
- })
-
- const netProto = ipv4.ProtocolNumber
- localServerAddress := tcpip.FullAddress{
- Port: localServerPort,
- }
-
- localServerListener, err := gonet.ListenTCP(localStack, localServerAddress, netProto)
- if err != nil {
- t.Fatalf("gonet.ListenTCP(_, %+v, %d) = %s", localServerAddress, netProto, err)
- }
-
- remoteServerAddress := tcpip.FullAddress{
- Port: remoteServerPort,
- }
- remoteServerListener, err := gonet.ListenTCP(remoteStack, remoteServerAddress, netProto)
- if err != nil {
- t.Fatalf("gonet.ListenTCP(_, %+v, %d) = %s", remoteServerAddress, netProto, err)
- }
-
- // Initialize a random default response served by the HTTP server.
- defaultResponse := make([]byte, 512<<10)
- if _, err := rand.Read(defaultResponse); err != nil {
- t.Fatalf("rand.Read(buf) failed: %s", err)
- }
-
- tc := &testContext{
- localServerListener: localServerListener,
- remoteServerListener: remoteServerListener,
- localStack: localStack,
- remoteStack: remoteStack,
- defaultResponse: defaultResponse,
- }
-
- tc.startServers(t)
- return tc
-}
-
-func (ctx *testContext) startServers(t *testing.T) {
- ctx.wg.Add(1)
- go func() {
- defer ctx.wg.Done()
- ctx.startHTTPServer()
- }()
- ctx.wg.Add(1)
- go func() {
- defer ctx.wg.Done()
- ctx.startTCPProxyServer(t)
- }()
-}
-
-func (ctx *testContext) startTCPProxyServer(t *testing.T) {
- t.Helper()
- for {
- conn, err := ctx.localServerListener.Accept()
- if err != nil {
- t.Logf("terminating local proxy server: %s", err)
- return
- }
- // Start a goroutine to handle this inbound connection.
- go func() {
- remoteServerAddr := tcpip.FullAddress{
- Addr: remoteIPv4Addr1,
- Port: remoteServerPort,
- }
- localServerAddr := tcpip.FullAddress{
- Addr: localIPv4Addr2,
- }
- serverConn, err := gonet.DialTCPWithBind(context.Background(), ctx.localStack, localServerAddr, remoteServerAddr, ipv4.ProtocolNumber)
- if err != nil {
- t.Logf("gonet.DialTCP(_, %+v, %d) = %s", remoteServerAddr, ipv4.ProtocolNumber, err)
- return
- }
- proxy(conn, serverConn)
- t.Logf("proxying completed")
- }()
- }
-}
-
-// proxy transparently proxies the TCP payload from conn1 to conn2
-// and vice versa.
-func proxy(conn1, conn2 net.Conn) {
- var wg sync.WaitGroup
- wg.Add(1)
- go func() {
- io.Copy(conn2, conn1)
- conn1.Close()
- conn2.Close()
- }()
- wg.Add(1)
- go func() {
- io.Copy(conn1, conn2)
- conn1.Close()
- conn2.Close()
- }()
- wg.Wait()
-}
-
-func (ctx *testContext) startHTTPServer() {
- handlerFunc := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.Write([]byte(ctx.defaultResponse))
- })
- s := &http.Server{
- Handler: handlerFunc,
- }
- s.Serve(ctx.remoteServerListener)
-}
-
-func TestOutboundNATRedirect(t *testing.T) {
- ctx := newTestContext(t)
- defer ctx.cleanup()
-
- // Install an IPTable rule to redirect all TCP traffic with the sourceIP of
- // localIPv4Addr1 to the tcp proxy port.
- ipt := ctx.localStack.IPTables()
- tbl := ipt.GetTable(stack.NATID, false /* ipv6 */)
- ruleIdx := tbl.BuiltinChains[stack.Output]
- tbl.Rules[ruleIdx].Filter = stack.IPHeaderFilter{
- Protocol: tcp.ProtocolNumber,
- CheckProtocol: true,
- Src: localIPv4Addr1,
- SrcMask: tcpip.Address("\xff\xff\xff\xff"),
- }
- tbl.Rules[ruleIdx].Target = &stack.RedirectTarget{
- Port: localServerPort,
- NetworkProtocol: ipv4.ProtocolNumber,
- }
- tbl.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
- if err := ipt.ReplaceTable(stack.NATID, tbl, false /* ipv6 */); err != nil {
- t.Fatalf("ipt.ReplaceTable(%d, _, false): %s", stack.NATID, err)
- }
-
- dialFunc := func(protocol, address string) (net.Conn, error) {
- host, port, err := net.SplitHostPort(address)
- if err != nil {
- return nil, fmt.Errorf("unable to parse address: %s, err: %s", address, err)
- }
-
- remoteServerIP := net.ParseIP(host)
- remoteServerPort, err := strconv.Atoi(port)
- if err != nil {
- return nil, fmt.Errorf("unable to parse port from string %s, err: %s", port, err)
- }
- remoteAddress := tcpip.FullAddress{
- Addr: tcpip.Address(remoteServerIP.To4()),
- Port: uint16(remoteServerPort),
- }
-
- // Dial with an explicit source address bound so that the redirect rule will
- // be able to correctly redirect these packets.
- localAddr := tcpip.FullAddress{Addr: localIPv4Addr1}
- return gonet.DialTCPWithBind(context.Background(), ctx.localStack, localAddr, remoteAddress, ipv4.ProtocolNumber)
- }
-
- httpClient := &http.Client{
- Transport: &http.Transport{
- Dial: dialFunc,
- },
- }
-
- serverURL := fmt.Sprintf("http://[%s]:%d/", net.IP(remoteIPv4Addr1), remoteServerPort)
- response, err := httpClient.Get(serverURL)
- if err != nil {
- t.Fatalf("httpClient.Get(\"/\") failed: %s", err)
- }
- if got, want := response.StatusCode, http.StatusOK; got != want {
- t.Fatalf("unexpected status code got: %d, want: %d", got, want)
- }
- body, err := io.ReadAll(response.Body)
- if err != nil {
- t.Fatalf("io.ReadAll(response.Body) failed: %s", err)
- }
- response.Body.Close()
- if diff := cmp.Diff(body, ctx.defaultResponse); diff != "" {
- t.Fatalf("unexpected response (-want +got): \n %s", diff)
- }
-}
diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go
deleted file mode 100644
index 95ddd8ec3..000000000
--- a/pkg/tcpip/tests/integration/link_resolution_test.go
+++ /dev/null
@@ -1,1640 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package link_resolution_test
-
-import (
- "bytes"
- "fmt"
- "net"
- "runtime"
- "testing"
- "time"
-
- "github.com/google/go-cmp/cmp"
- "github.com/google/go-cmp/cmp/cmpopts"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/pipe"
- "gvisor.dev/gvisor/pkg/tcpip/network/arp"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/tests/utils"
- tcptestutil "gvisor.dev/gvisor/pkg/tcpip/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-func setupStack(t *testing.T, stackOpts stack.Options, host1NICID, host2NICID tcpip.NICID) (*stack.Stack, *stack.Stack) {
- host1Stack := stack.New(stackOpts)
- host2Stack := stack.New(stackOpts)
-
- host1NIC, host2NIC := pipe.New(utils.LinkAddr1, utils.LinkAddr2)
-
- if err := host1Stack.CreateNIC(host1NICID, utils.NewEthernetEndpoint(host1NIC)); err != nil {
- t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err)
- }
- if err := host2Stack.CreateNIC(host2NICID, utils.NewEthernetEndpoint(host2NIC)); err != nil {
- t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err)
- }
-
- if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv4Addr1, stack.AddressProperties{}); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", host1NICID, utils.Ipv4Addr1, err)
- }
- if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv4Addr2, stack.AddressProperties{}); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", host2NICID, utils.Ipv4Addr2, err)
- }
- if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv6Addr1, stack.AddressProperties{}); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", host1NICID, utils.Ipv6Addr1, err)
- }
- if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv6Addr2, stack.AddressProperties{}); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", host2NICID, utils.Ipv6Addr2, err)
- }
-
- host1Stack.SetRouteTable([]tcpip.Route{
- {
- Destination: utils.Ipv4Addr1.AddressWithPrefix.Subnet(),
- NIC: host1NICID,
- },
- {
- Destination: utils.Ipv6Addr1.AddressWithPrefix.Subnet(),
- NIC: host1NICID,
- },
- })
- host2Stack.SetRouteTable([]tcpip.Route{
- {
- Destination: utils.Ipv4Addr2.AddressWithPrefix.Subnet(),
- NIC: host2NICID,
- },
- {
- Destination: utils.Ipv6Addr2.AddressWithPrefix.Subnet(),
- NIC: host2NICID,
- },
- })
-
- return host1Stack, host2Stack
-}
-
-// TestPing tests that two hosts can ping eachother when link resolution is
-// enabled.
-func TestPing(t *testing.T) {
- const (
- host1NICID = 1
- host2NICID = 4
-
- // icmpDataOffset is the offset to the data in both ICMPv4 and ICMPv6 echo
- // request/reply packets.
- icmpDataOffset = 8
- )
-
- tests := []struct {
- name string
- transProto tcpip.TransportProtocolNumber
- netProto tcpip.NetworkProtocolNumber
- remoteAddr tcpip.Address
- icmpBuf func(*testing.T) []byte
- }{
- {
- name: "IPv4 Ping",
- transProto: icmp.ProtocolNumber4,
- netProto: ipv4.ProtocolNumber,
- remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
- icmpBuf: func(t *testing.T) []byte {
- data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
- hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data)))
- hdr.SetType(header.ICMPv4Echo)
- if n := copy(hdr.Payload(), data[:]); n != len(data) {
- t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data))
- }
- return hdr
- },
- },
- {
- name: "IPv6 Ping",
- transProto: icmp.ProtocolNumber6,
- netProto: ipv6.ProtocolNumber,
- remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
- icmpBuf: func(t *testing.T) []byte {
- data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
- hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data)))
- hdr.SetType(header.ICMPv6EchoRequest)
- if n := copy(hdr.Payload(), data[:]); n != len(data) {
- t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data))
- }
- return hdr
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- stackOpts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6},
- }
-
- host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID)
-
- var wq waiter.Queue
- we, waiterCH := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- ep, err := host1Stack.NewEndpoint(test.transProto, test.netProto, &wq)
- if err != nil {
- t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err)
- }
- defer ep.Close()
-
- icmpBuf := test.icmpBuf(t)
- var r bytes.Reader
- r.Reset(icmpBuf)
- wOpts := tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: test.remoteAddr}}
- if n, err := ep.Write(&r, wOpts); err != nil {
- t.Fatalf("ep.Write(_, _): %s", err)
- } else if want := int64(len(icmpBuf)); n != want {
- t.Fatalf("got ep.Write(_, _) = (%d, _), want = (%d, _)", n, want)
- }
-
- // Wait for the endpoint to be readable.
- <-waiterCH
-
- var buf bytes.Buffer
- opts := tcpip.ReadOptions{NeedRemoteAddr: true}
- res, err := ep.Read(&buf, opts)
- if err != nil {
- t.Fatalf("ep.Read(_, %d, %#v): %s", len(icmpBuf), opts, err)
- }
- if diff := cmp.Diff(tcpip.ReadResult{
- Count: buf.Len(),
- Total: buf.Len(),
- RemoteAddr: tcpip.FullAddress{Addr: test.remoteAddr},
- }, res, checker.IgnoreCmpPath(
- "ControlMessages",
- "RemoteAddr.NIC",
- "RemoteAddr.Port",
- )); diff != "" {
- t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
- }
- if diff := cmp.Diff(buf.Bytes()[icmpDataOffset:], icmpBuf[icmpDataOffset:]); diff != "" {
- t.Errorf("received data mismatch (-want +got):\n%s", diff)
- }
- })
- }
-}
-
-type transportError struct {
- origin tcpip.SockErrOrigin
- typ uint8
- code uint8
- info uint32
- kind stack.TransportErrorKind
-}
-
-func TestTCPLinkResolutionFailure(t *testing.T) {
- const (
- host1NICID = 1
- host2NICID = 4
- )
-
- tests := []struct {
- name string
- netProto tcpip.NetworkProtocolNumber
- remoteAddr tcpip.Address
- expectedWriteErr tcpip.Error
- sockError tcpip.SockError
- transErr transportError
- }{
- {
- name: "IPv4 with resolvable remote",
- netProto: ipv4.ProtocolNumber,
- remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
- expectedWriteErr: nil,
- },
- {
- name: "IPv6 with resolvable remote",
- netProto: ipv6.ProtocolNumber,
- remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
- expectedWriteErr: nil,
- },
- {
- name: "IPv4 without resolvable remote",
- netProto: ipv4.ProtocolNumber,
- remoteAddr: utils.Ipv4Addr3.AddressWithPrefix.Address,
- expectedWriteErr: &tcpip.ErrNoRoute{},
- sockError: tcpip.SockError{
- Err: &tcpip.ErrNoRoute{},
- Dst: tcpip.FullAddress{
- NIC: host1NICID,
- Addr: utils.Ipv4Addr3.AddressWithPrefix.Address,
- Port: 1234,
- },
- Offender: tcpip.FullAddress{
- NIC: host1NICID,
- Addr: utils.Ipv4Addr1.AddressWithPrefix.Address,
- },
- NetProto: ipv4.ProtocolNumber,
- },
- transErr: transportError{
- origin: tcpip.SockExtErrorOriginICMP,
- typ: uint8(header.ICMPv4DstUnreachable),
- code: uint8(header.ICMPv4HostUnreachable),
- kind: stack.DestinationHostUnreachableTransportError,
- },
- },
- {
- name: "IPv6 without resolvable remote",
- netProto: ipv6.ProtocolNumber,
- remoteAddr: utils.Ipv6Addr3.AddressWithPrefix.Address,
- expectedWriteErr: &tcpip.ErrNoRoute{},
- sockError: tcpip.SockError{
- Err: &tcpip.ErrNoRoute{},
- Dst: tcpip.FullAddress{
- NIC: host1NICID,
- Addr: utils.Ipv6Addr3.AddressWithPrefix.Address,
- Port: 1234,
- },
- Offender: tcpip.FullAddress{
- NIC: host1NICID,
- Addr: utils.Ipv6Addr1.AddressWithPrefix.Address,
- },
- NetProto: ipv6.ProtocolNumber,
- },
- transErr: transportError{
- origin: tcpip.SockExtErrorOriginICMP6,
- typ: uint8(header.ICMPv6DstUnreachable),
- code: uint8(header.ICMPv6AddressUnreachable),
- kind: stack.DestinationHostUnreachableTransportError,
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- clock := faketime.NewManualClock()
- stackOpts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
- Clock: clock,
- }
-
- host1Stack, host2Stack := setupStack(t, stackOpts, host1NICID, host2NICID)
-
- var listenerWQ waiter.Queue
- listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, test.netProto, &listenerWQ)
- if err != nil {
- t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, test.netProto, err)
- }
- defer listenerEP.Close()
-
- listenerAddr := tcpip.FullAddress{Port: 1234}
- if err := listenerEP.Bind(listenerAddr); err != nil {
- t.Fatalf("listenerEP.Bind(%#v): %s", listenerAddr, err)
- }
-
- if err := listenerEP.Listen(1); err != nil {
- t.Fatalf("listenerEP.Listen(1): %s", err)
- }
-
- var clientWQ waiter.Queue
- we, ch := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&we, waiter.WritableEvents|waiter.EventErr)
- clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, test.netProto, &clientWQ)
- if err != nil {
- t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, test.netProto, err)
- }
- defer clientEP.Close()
-
- sockOpts := clientEP.SocketOptions()
- sockOpts.SetRecvError(true)
-
- remoteAddr := listenerAddr
- remoteAddr.Addr = test.remoteAddr
- {
- err := clientEP.Connect(remoteAddr)
- if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
- t.Fatalf("got clientEP.Connect(%#v) = %s, want = %s", remoteAddr, err, &tcpip.ErrConnectStarted{})
- }
- }
-
- // Wait for an error due to link resolution failing, or the endpoint to be
- // writable.
- if test.expectedWriteErr != nil {
- nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto)
- if err != nil {
- t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err)
- }
- clock.Advance(time.Duration(nudConfigs.MaxMulticastProbes) * nudConfigs.RetransmitTimer)
- } else {
- clock.RunImmediatelyScheduledJobs()
- }
- <-ch
-
- {
- var r bytes.Reader
- r.Reset([]byte{0})
- var wOpts tcpip.WriteOptions
- _, err := clientEP.Write(&r, wOpts)
- if diff := cmp.Diff(test.expectedWriteErr, err); diff != "" {
- t.Errorf("unexpected error from clientEP.Write(_, %#v), (-want, +got):\n%s", wOpts, diff)
- }
- }
-
- if test.expectedWriteErr == nil {
- return
- }
-
- sockErr := sockOpts.DequeueErr()
- if sockErr == nil {
- t.Fatalf("got sockOpts.DequeueErr() = nil, want = non-nil")
- }
-
- sockErrCmpOpts := []cmp.Option{
- cmpopts.IgnoreUnexported(tcpip.SockError{}),
- cmp.Comparer(func(a, b tcpip.Error) bool {
- // tcpip.Error holds an unexported field but the errors netstack uses
- // are pre defined so we can simply compare pointers.
- return a == b
- }),
- checker.IgnoreCmpPath(
- // Ignore the payload since we do not know the TCP seq/ack numbers.
- "Payload",
- // Ignore the cause since we will compare its properties separately
- // since the concrete type of the cause is unknown.
- "Cause",
- ),
- }
-
- if addr, err := clientEP.GetLocalAddress(); err != nil {
- t.Fatalf("clientEP.GetLocalAddress(): %s", err)
- } else {
- test.sockError.Offender.Port = addr.Port
- }
- if diff := cmp.Diff(&test.sockError, sockErr, sockErrCmpOpts...); diff != "" {
- t.Errorf("socket error mismatch (-want +got):\n%s", diff)
- }
-
- transErr, ok := sockErr.Cause.(stack.TransportError)
- if !ok {
- t.Fatalf("socket error cause is not a transport error; cause = %#v", sockErr.Cause)
- }
- if diff := cmp.Diff(
- test.transErr,
- transportError{
- origin: transErr.Origin(),
- typ: transErr.Type(),
- code: transErr.Code(),
- info: transErr.Info(),
- kind: transErr.Kind(),
- },
- cmp.AllowUnexported(transportError{}),
- ); diff != "" {
- t.Errorf("socket error mismatch (-want +got):\n%s", diff)
- }
- })
- }
-}
-
-func TestForwardingWithLinkResolutionFailure(t *testing.T) {
- const (
- incomingNICID = 1
- outgoingNICID = 2
- ttl = 2
- expectedHostUnreachableErrorCount = 1
- )
- outgoingLinkAddr := tcptestutil.MustParseLink("02:03:03:04:05:06")
-
- rxICMPv4EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) {
- utils.RxICMPv4EchoRequest(e, src, dst, ttl)
- }
-
- rxICMPv6EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) {
- utils.RxICMPv6EchoRequest(e, src, dst, ttl)
- }
-
- arpChecker := func(t *testing.T, request channel.PacketInfo, src, dst tcpip.Address) {
- if request.Proto != arp.ProtocolNumber {
- t.Errorf("got request.Proto = %d, want = %d", request.Proto, arp.ProtocolNumber)
- }
- if request.Route.RemoteLinkAddress != header.EthernetBroadcastAddress {
- t.Errorf("got request.Route.RemoteLinkAddress = %s, want = %s", request.Route.RemoteLinkAddress, header.EthernetBroadcastAddress)
- }
- rep := header.ARP(request.Pkt.NetworkHeader().View())
- if got := rep.Op(); got != header.ARPRequest {
- t.Errorf("got Op() = %d, want = %d", got, header.ARPRequest)
- }
- if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != outgoingLinkAddr {
- t.Errorf("got HardwareAddressSender = %s, want = %s", got, outgoingLinkAddr)
- }
- if got := tcpip.Address(rep.ProtocolAddressSender()); got != src {
- t.Errorf("got ProtocolAddressSender = %s, want = %s", got, src)
- }
- if got := tcpip.Address(rep.ProtocolAddressTarget()); got != dst {
- t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, dst)
- }
- }
-
- ndpChecker := func(t *testing.T, request channel.PacketInfo, src, dst tcpip.Address) {
- if request.Proto != header.IPv6ProtocolNumber {
- t.Fatalf("got Proto = %d, want = %d", request.Proto, header.IPv6ProtocolNumber)
- }
-
- snmc := header.SolicitedNodeAddr(dst)
- if want := header.EthernetAddressFromMulticastIPv6Address(snmc); request.Route.RemoteLinkAddress != want {
- t.Errorf("got remote link address = %s, want = %s", request.Route.RemoteLinkAddress, want)
- }
-
- checker.IPv6(t, stack.PayloadSince(request.Pkt.NetworkHeader()),
- checker.SrcAddr(src),
- checker.DstAddr(snmc),
- checker.TTL(header.NDPHopLimit),
- checker.NDPNS(
- checker.NDPNSTargetAddress(dst),
- ))
- }
-
- icmpv4Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) {
- checker.IPv4(t, b,
- checker.SrcAddr(src),
- checker.DstAddr(dst),
- checker.TTL(ipv4.DefaultTTL),
- checker.ICMPv4(
- checker.ICMPv4Checksum(),
- checker.ICMPv4Type(header.ICMPv4DstUnreachable),
- checker.ICMPv4Code(header.ICMPv4HostUnreachable),
- ),
- )
- }
-
- icmpv6Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) {
- checker.IPv6(t, b,
- checker.SrcAddr(src),
- checker.DstAddr(dst),
- checker.TTL(ipv6.DefaultTTL),
- checker.ICMPv6(
- checker.ICMPv6Type(header.ICMPv6DstUnreachable),
- checker.ICMPv6Code(header.ICMPv6AddressUnreachable),
- ),
- )
- }
-
- tests := []struct {
- name string
- networkProtocolFactory []stack.NetworkProtocolFactory
- networkProtocolNumber tcpip.NetworkProtocolNumber
- sourceAddr tcpip.Address
- destAddr tcpip.Address
- incomingAddr tcpip.AddressWithPrefix
- outgoingAddr tcpip.AddressWithPrefix
- transportProtocol func(*stack.Stack) stack.TransportProtocol
- rx func(*channel.Endpoint, tcpip.Address, tcpip.Address)
- linkResolutionRequestChecker func(*testing.T, channel.PacketInfo, tcpip.Address, tcpip.Address)
- icmpReplyChecker func(*testing.T, []byte, tcpip.Address, tcpip.Address)
- mtu uint32
- }{
- {
- name: "IPv4 Host unreachable",
- networkProtocolFactory: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
- networkProtocolNumber: header.IPv4ProtocolNumber,
- sourceAddr: tcptestutil.MustParse4("10.0.0.2"),
- destAddr: tcptestutil.MustParse4("11.0.0.2"),
- incomingAddr: tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()),
- PrefixLen: 8,
- },
- outgoingAddr: tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("11.0.0.1").To4()),
- PrefixLen: 8,
- },
- transportProtocol: icmp.NewProtocol4,
- linkResolutionRequestChecker: arpChecker,
- icmpReplyChecker: icmpv4Checker,
- rx: rxICMPv4EchoRequest,
- mtu: ipv4.MaxTotalSize,
- },
- {
- name: "IPv6 Host unreachable",
- networkProtocolFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- networkProtocolNumber: header.IPv6ProtocolNumber,
- sourceAddr: tcptestutil.MustParse6("10::2"),
- destAddr: tcptestutil.MustParse6("11::2"),
- incomingAddr: tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("10::1").To16()),
- PrefixLen: 64,
- },
- outgoingAddr: tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("11::1").To16()),
- PrefixLen: 64,
- },
- transportProtocol: icmp.NewProtocol6,
- linkResolutionRequestChecker: ndpChecker,
- icmpReplyChecker: icmpv6Checker,
- rx: rxICMPv6EchoRequest,
- mtu: header.IPv6MinimumMTU,
- },
- }
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- clock := faketime.NewManualClock()
-
- s := stack.New(stack.Options{
- NetworkProtocols: test.networkProtocolFactory,
- TransportProtocols: []stack.TransportProtocolFactory{test.transportProtocol},
- Clock: clock,
- })
-
- // Set up endpoint through which we will receive packets.
- incomingEndpoint := channel.New(1, test.mtu, "")
- if err := s.CreateNIC(incomingNICID, incomingEndpoint); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err)
- }
- incomingProtoAddr := tcpip.ProtocolAddress{
- Protocol: test.networkProtocolNumber,
- AddressWithPrefix: test.incomingAddr,
- }
- if err := s.AddProtocolAddress(incomingNICID, incomingProtoAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", incomingNICID, incomingProtoAddr, err)
- }
-
- // Set up endpoint through which we will attempt to forward packets.
- outgoingEndpoint := channel.New(1, test.mtu, outgoingLinkAddr)
- outgoingEndpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- if err := s.CreateNIC(outgoingNICID, outgoingEndpoint); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err)
- }
- outgoingProtoAddr := tcpip.ProtocolAddress{
- Protocol: test.networkProtocolNumber,
- AddressWithPrefix: test.outgoingAddr,
- }
- if err := s.AddProtocolAddress(outgoingNICID, outgoingProtoAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", outgoingNICID, outgoingProtoAddr, err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: test.incomingAddr.Subnet(),
- NIC: incomingNICID,
- },
- {
- Destination: test.outgoingAddr.Subnet(),
- NIC: outgoingNICID,
- },
- })
-
- if err := s.SetForwardingDefaultAndAllNICs(test.networkProtocolNumber, true); err != nil {
- t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", test.networkProtocolNumber, err)
- }
-
- test.rx(incomingEndpoint, test.sourceAddr, test.destAddr)
-
- nudConfigs, err := s.NUDConfigurations(outgoingNICID, test.networkProtocolNumber)
- if err != nil {
- t.Fatalf("s.NUDConfigurations(%d, %d): %s", outgoingNICID, test.networkProtocolNumber, err)
- }
- // Trigger the first packet on the endpoint.
- clock.RunImmediatelyScheduledJobs()
-
- for i := 0; i < int(nudConfigs.MaxMulticastProbes); i++ {
- request, ok := outgoingEndpoint.Read()
- if !ok {
- t.Fatal("expected ARP packet through outgoing NIC")
- }
-
- test.linkResolutionRequestChecker(t, request, test.outgoingAddr.Address, test.destAddr)
-
- // Advance the clock the span of one request timeout.
- clock.Advance(nudConfigs.RetransmitTimer)
- }
-
- // Next, we make a blocking read to retrieve the error packet. This is
- // necessary because outgoing packets are dequeued asynchronously when
- // link resolution fails, and this dequeue is what triggers the ICMP
- // error.
- reply, ok := incomingEndpoint.Read()
- if !ok {
- t.Fatal("expected ICMP packet through incoming NIC")
- }
-
- test.icmpReplyChecker(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), test.incomingAddr.Address, test.sourceAddr)
-
- // Since link resolution failed, we don't expect the packet to be
- // forwarded.
- forwardedPacket, ok := outgoingEndpoint.Read()
- if ok {
- t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", forwardedPacket)
- }
-
- if got, want := s.Stats().IP.Forwarding.HostUnreachable.Value(), expectedHostUnreachableErrorCount; int(got) != want {
- t.Errorf("got rt.Stats().IP.Forwarding.HostUnreachable.Value() = %d, want = %d", got, want)
- }
- })
- }
-}
-
-func TestGetLinkAddress(t *testing.T) {
- const (
- host1NICID = 1
- host2NICID = 4
- )
-
- tests := []struct {
- name string
- netProto tcpip.NetworkProtocolNumber
- remoteAddr, localAddr tcpip.Address
- expectedErr tcpip.Error
- }{
- {
- name: "IPv4 resolvable",
- netProto: ipv4.ProtocolNumber,
- remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
- expectedErr: nil,
- },
- {
- name: "IPv6 resolvable",
- netProto: ipv6.ProtocolNumber,
- remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
- expectedErr: nil,
- },
- {
- name: "IPv4 not resolvable",
- netProto: ipv4.ProtocolNumber,
- remoteAddr: utils.Ipv4Addr3.AddressWithPrefix.Address,
- expectedErr: &tcpip.ErrTimeout{},
- },
- {
- name: "IPv6 not resolvable",
- netProto: ipv6.ProtocolNumber,
- remoteAddr: utils.Ipv6Addr3.AddressWithPrefix.Address,
- expectedErr: &tcpip.ErrTimeout{},
- },
- {
- name: "IPv4 bad local address",
- netProto: ipv4.ProtocolNumber,
- remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
- localAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
- expectedErr: &tcpip.ErrBadLocalAddress{},
- },
- {
- name: "IPv6 bad local address",
- netProto: ipv6.ProtocolNumber,
- remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
- localAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
- expectedErr: &tcpip.ErrBadLocalAddress{},
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- clock := faketime.NewManualClock()
- stackOpts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
- Clock: clock,
- }
-
- host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID)
-
- ch := make(chan stack.LinkResolutionResult, 1)
- err := host1Stack.GetLinkAddress(host1NICID, test.remoteAddr, test.localAddr, test.netProto, func(r stack.LinkResolutionResult) {
- ch <- r
- })
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", host1NICID, test.remoteAddr, test.netProto, err, &tcpip.ErrWouldBlock{})
- }
- wantRes := stack.LinkResolutionResult{Err: test.expectedErr}
- if test.expectedErr == nil {
- wantRes.LinkAddress = utils.LinkAddr2
- }
-
- nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto)
- if err != nil {
- t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err)
- }
-
- clock.Advance(time.Duration(nudConfigs.MaxMulticastProbes) * nudConfigs.RetransmitTimer)
- select {
- case got := <-ch:
- if diff := cmp.Diff(wantRes, got); diff != "" {
- t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("event didn't arrive")
- }
- })
- }
-}
-
-func TestRouteResolvedFields(t *testing.T) {
- const (
- host1NICID = 1
- host2NICID = 4
- )
-
- tests := []struct {
- name string
- netProto tcpip.NetworkProtocolNumber
- localAddr tcpip.Address
- remoteAddr tcpip.Address
- immediatelyResolvable bool
- expectedErr tcpip.Error
- expectedLinkAddr tcpip.LinkAddress
- }{
- {
- name: "IPv4 immediately resolvable",
- netProto: ipv4.ProtocolNumber,
- localAddr: utils.Ipv4Addr1.AddressWithPrefix.Address,
- remoteAddr: header.IPv4AllSystems,
- immediatelyResolvable: true,
- expectedErr: nil,
- expectedLinkAddr: header.EthernetAddressFromMulticastIPv4Address(header.IPv4AllSystems),
- },
- {
- name: "IPv6 immediately resolvable",
- netProto: ipv6.ProtocolNumber,
- localAddr: utils.Ipv6Addr1.AddressWithPrefix.Address,
- remoteAddr: header.IPv6AllNodesMulticastAddress,
- immediatelyResolvable: true,
- expectedErr: nil,
- expectedLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllNodesMulticastAddress),
- },
- {
- name: "IPv4 resolvable",
- netProto: ipv4.ProtocolNumber,
- localAddr: utils.Ipv4Addr1.AddressWithPrefix.Address,
- remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
- immediatelyResolvable: false,
- expectedErr: nil,
- expectedLinkAddr: utils.LinkAddr2,
- },
- {
- name: "IPv6 resolvable",
- netProto: ipv6.ProtocolNumber,
- localAddr: utils.Ipv6Addr1.AddressWithPrefix.Address,
- remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
- immediatelyResolvable: false,
- expectedErr: nil,
- expectedLinkAddr: utils.LinkAddr2,
- },
- {
- name: "IPv4 not resolvable",
- netProto: ipv4.ProtocolNumber,
- localAddr: utils.Ipv4Addr1.AddressWithPrefix.Address,
- remoteAddr: utils.Ipv4Addr3.AddressWithPrefix.Address,
- immediatelyResolvable: false,
- expectedErr: &tcpip.ErrTimeout{},
- },
- {
- name: "IPv6 not resolvable",
- netProto: ipv6.ProtocolNumber,
- localAddr: utils.Ipv6Addr1.AddressWithPrefix.Address,
- remoteAddr: utils.Ipv6Addr3.AddressWithPrefix.Address,
- immediatelyResolvable: false,
- expectedErr: &tcpip.ErrTimeout{},
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- clock := faketime.NewManualClock()
- stackOpts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
- Clock: clock,
- }
-
- host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID)
- r, err := host1Stack.FindRoute(host1NICID, test.localAddr, test.remoteAddr, test.netProto, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("host1Stack.FindRoute(%d, %s, %s, %d, false): %s", host1NICID, test.localAddr, test.remoteAddr, test.netProto, err)
- }
- defer r.Release()
-
- var wantRouteInfo stack.RouteInfo
- wantRouteInfo.LocalLinkAddress = utils.LinkAddr1
- wantRouteInfo.LocalAddress = test.localAddr
- wantRouteInfo.RemoteAddress = test.remoteAddr
- wantRouteInfo.NetProto = test.netProto
- wantRouteInfo.Loop = stack.PacketOut
- wantRouteInfo.RemoteLinkAddress = test.expectedLinkAddr
-
- ch := make(chan stack.ResolvedFieldsResult, 1)
-
- if !test.immediatelyResolvable {
- wantUnresolvedRouteInfo := wantRouteInfo
- wantUnresolvedRouteInfo.RemoteLinkAddress = ""
-
- err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) {
- ch <- r
- })
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Errorf("got r.ResolvedFields(_) = %s, want = %s", err, &tcpip.ErrWouldBlock{})
- }
-
- nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto)
- if err != nil {
- t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err)
- }
- clock.Advance(time.Duration(nudConfigs.MaxMulticastProbes) * nudConfigs.RetransmitTimer)
-
- select {
- case got := <-ch:
- if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Err: test.expectedErr}, got, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" {
- t.Errorf("route resolve result mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatalf("event didn't arrive")
- }
-
- if test.expectedErr != nil {
- return
- }
-
- // At this point the neighbor table should be populated so the route
- // should be immediately resolvable.
- }
-
- if err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) {
- ch <- r
- }); err != nil {
- t.Errorf("r.ResolvedFields(_): %s", err)
- }
- select {
- case routeResolveRes := <-ch:
- if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Err: nil}, routeResolveRes, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" {
- t.Errorf("route resolve result from resolved route mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected route to be immediately resolvable")
- }
- })
- }
-}
-
-func TestWritePacketsLinkResolution(t *testing.T) {
- const (
- host1NICID = 1
- host2NICID = 4
- )
-
- tests := []struct {
- name string
- netProto tcpip.NetworkProtocolNumber
- remoteAddr tcpip.Address
- expectedWriteErr tcpip.Error
- }{
- {
- name: "IPv4",
- netProto: ipv4.ProtocolNumber,
- remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
- expectedWriteErr: nil,
- },
- {
- name: "IPv6",
- netProto: ipv6.ProtocolNumber,
- remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
- expectedWriteErr: nil,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- stackOpts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- }
-
- host1Stack, host2Stack := setupStack(t, stackOpts, host1NICID, host2NICID)
-
- var serverWQ waiter.Queue
- serverWE, serverCH := waiter.NewChannelEntry(nil)
- serverWQ.EventRegister(&serverWE, waiter.ReadableEvents)
- serverEP, err := host2Stack.NewEndpoint(udp.ProtocolNumber, test.netProto, &serverWQ)
- if err != nil {
- t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.netProto, err)
- }
- defer serverEP.Close()
-
- serverAddr := tcpip.FullAddress{Port: 1234}
- if err := serverEP.Bind(serverAddr); err != nil {
- t.Fatalf("serverEP.Bind(%#v): %s", serverAddr, err)
- }
-
- r, err := host1Stack.FindRoute(host1NICID, "", test.remoteAddr, test.netProto, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("host1Stack.FindRoute(%d, '', %s, %d, false): %s", host1NICID, test.remoteAddr, test.netProto, err)
- }
- defer r.Release()
-
- data := []byte{1, 2}
- var pkts stack.PacketBufferList
- for _, d := range data {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()),
- Data: buffer.View([]byte{d}).ToVectorisedView(),
- })
- pkt.TransportProtocolNumber = udp.ProtocolNumber
- length := uint16(pkt.Size())
- udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
- udpHdr.Encode(&header.UDPFields{
- SrcPort: 5555,
- DstPort: serverAddr.Port,
- Length: length,
- })
- xsum := r.PseudoHeaderChecksum(udp.ProtocolNumber, length)
- xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
- udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum))
-
- pkts.PushBack(pkt)
- }
-
- params := stack.NetworkHeaderParams{
- Protocol: udp.ProtocolNumber,
- TTL: 64,
- TOS: stack.DefaultTOS,
- }
-
- if n, err := r.WritePackets(pkts, params); err != nil {
- t.Fatalf("r.WritePackets(_, %#v): %s", params, err)
- } else if want := pkts.Len(); want != n {
- t.Fatalf("got r.WritePackets(_, %#v) = %d, want = %d", params, n, want)
- }
-
- var writer bytes.Buffer
- count := 0
- for {
- var rOpts tcpip.ReadOptions
- res, err := serverEP.Read(&writer, rOpts)
- if err != nil {
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
- // Should not have anymore bytes to read after we read the sent
- // number of bytes.
- if count == len(data) {
- break
- }
-
- <-serverCH
- continue
- }
-
- t.Fatalf("serverEP.Read(_, %#v): %s", rOpts, err)
- }
- count += res.Count
- }
-
- if got, want := host2Stack.Stats().UDP.PacketsReceived.Value(), uint64(len(data)); got != want {
- t.Errorf("got host2Stack.Stats().UDP.PacketsReceived.Value() = %d, want = %d", got, want)
- }
- if diff := cmp.Diff(data, writer.Bytes()); diff != "" {
- t.Errorf("read bytes mismatch (-want +got):\n%s", diff)
- }
- })
- }
-}
-
-type eventType int
-
-const (
- entryAdded eventType = iota
- entryChanged
- entryRemoved
-)
-
-func (t eventType) String() string {
- switch t {
- case entryAdded:
- return "add"
- case entryChanged:
- return "change"
- case entryRemoved:
- return "remove"
- default:
- return fmt.Sprintf("unknown (%d)", t)
- }
-}
-
-type eventInfo struct {
- eventType eventType
- nicID tcpip.NICID
- entry stack.NeighborEntry
-}
-
-func (e eventInfo) String() string {
- return fmt.Sprintf("%s event for NIC #%d, %#v", e.eventType, e.nicID, e.entry)
-}
-
-var _ stack.NUDDispatcher = (*nudDispatcher)(nil)
-
-type nudDispatcher struct {
- c chan eventInfo
-}
-
-func (d *nudDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry stack.NeighborEntry) {
- e := eventInfo{
- eventType: entryAdded,
- nicID: nicID,
- entry: entry,
- }
- d.c <- e
-}
-
-func (d *nudDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry stack.NeighborEntry) {
- e := eventInfo{
- eventType: entryChanged,
- nicID: nicID,
- entry: entry,
- }
- d.c <- e
-}
-
-func (d *nudDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.NeighborEntry) {
- e := eventInfo{
- eventType: entryRemoved,
- nicID: nicID,
- entry: entry,
- }
- d.c <- e
-}
-
-func (d *nudDispatcher) expectEvent(want eventInfo) error {
- select {
- case got := <-d.c:
- if diff := cmp.Diff(want, got, cmp.AllowUnexported(eventInfo{}), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAt")); diff != "" {
- return fmt.Errorf("got invalid event (-want +got):\n%s", diff)
- }
- return nil
- default:
- return fmt.Errorf("event didn't arrive")
- }
-}
-
-// TestTCPConfirmNeighborReachability tests that TCP informs layers beneath it
-// that the neighbor used for a route is reachable.
-func TestTCPConfirmNeighborReachability(t *testing.T) {
- tests := []struct {
- name string
- netProto tcpip.NetworkProtocolNumber
- remoteAddr tcpip.Address
- neighborAddr tcpip.Address
- getEndpoints func(*testing.T, *stack.Stack, *stack.Stack, *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{})
- isHost1Listener bool
- }{
- {
- name: "IPv4 active connection through neighbor",
- netProto: ipv4.ProtocolNumber,
- remoteAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address,
- neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
- var listenerWQ waiter.Queue
- listenerWE, listenerCH := waiter.NewChannelEntry(nil)
- listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
- listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
- if err != nil {
- t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
- }
- t.Cleanup(listenerEP.Close)
-
- var clientWQ waiter.Queue
- clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
- clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
- if err != nil {
- t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
- }
-
- return listenerEP, listenerCH, clientEP, clientCH
- },
- },
- {
- name: "IPv6 active connection through neighbor",
- netProto: ipv6.ProtocolNumber,
- remoteAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address,
- neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
- var listenerWQ waiter.Queue
- listenerWE, listenerCH := waiter.NewChannelEntry(nil)
- listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
- listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ)
- if err != nil {
- t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
- }
- t.Cleanup(listenerEP.Close)
-
- var clientWQ waiter.Queue
- clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
- clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ)
- if err != nil {
- t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
- }
-
- return listenerEP, listenerCH, clientEP, clientCH
- },
- },
- {
- name: "IPv4 active connection to neighbor",
- netProto: ipv4.ProtocolNumber,
- remoteAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
- neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
- var listenerWQ waiter.Queue
- listenerWE, listenerCH := waiter.NewChannelEntry(nil)
- listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
- listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
- if err != nil {
- t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
- }
- t.Cleanup(listenerEP.Close)
-
- var clientWQ waiter.Queue
- clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
- clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
- if err != nil {
- t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
- }
-
- return listenerEP, listenerCH, clientEP, clientCH
- },
- },
- {
- name: "IPv6 active connection to neighbor",
- netProto: ipv6.ProtocolNumber,
- remoteAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
- neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
- var listenerWQ waiter.Queue
- listenerWE, listenerCH := waiter.NewChannelEntry(nil)
- listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
- listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ)
- if err != nil {
- t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
- }
- t.Cleanup(listenerEP.Close)
-
- var clientWQ waiter.Queue
- clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
- clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ)
- if err != nil {
- t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
- }
-
- return listenerEP, listenerCH, clientEP, clientCH
- },
- },
- {
- name: "IPv4 passive connection to neighbor",
- netProto: ipv4.ProtocolNumber,
- remoteAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
- neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
- var listenerWQ waiter.Queue
- listenerWE, listenerCH := waiter.NewChannelEntry(nil)
- listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
- listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
- if err != nil {
- t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
- }
- t.Cleanup(listenerEP.Close)
-
- var clientWQ waiter.Queue
- clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
- clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
- if err != nil {
- t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
- }
-
- return listenerEP, listenerCH, clientEP, clientCH
- },
- isHost1Listener: true,
- },
- {
- name: "IPv6 passive connection to neighbor",
- netProto: ipv6.ProtocolNumber,
- remoteAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
- neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
- var listenerWQ waiter.Queue
- listenerWE, listenerCH := waiter.NewChannelEntry(nil)
- listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
- listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ)
- if err != nil {
- t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
- }
- t.Cleanup(listenerEP.Close)
-
- var clientWQ waiter.Queue
- clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
- clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ)
- if err != nil {
- t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
- }
-
- return listenerEP, listenerCH, clientEP, clientCH
- },
- isHost1Listener: true,
- },
- {
- name: "IPv4 passive connection through neighbor",
- netProto: ipv4.ProtocolNumber,
- remoteAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
- neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
- var listenerWQ waiter.Queue
- listenerWE, listenerCH := waiter.NewChannelEntry(nil)
- listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
- listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
- if err != nil {
- t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
- }
- t.Cleanup(listenerEP.Close)
-
- var clientWQ waiter.Queue
- clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
- clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
- if err != nil {
- t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
- }
-
- return listenerEP, listenerCH, clientEP, clientCH
- },
- isHost1Listener: true,
- },
- {
- name: "IPv6 passive connection through neighbor",
- netProto: ipv6.ProtocolNumber,
- remoteAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
- neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
- var listenerWQ waiter.Queue
- listenerWE, listenerCH := waiter.NewChannelEntry(nil)
- listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
- listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ)
- if err != nil {
- t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
- }
- t.Cleanup(listenerEP.Close)
-
- var clientWQ waiter.Queue
- clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
- clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ)
- if err != nil {
- t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
- }
-
- return listenerEP, listenerCH, clientEP, clientCH
- },
- isHost1Listener: true,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- clock := faketime.NewManualClock()
- nudDisp := nudDispatcher{
- c: make(chan eventInfo, 3),
- }
- stackOpts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
- Clock: clock,
- }
- host1StackOpts := stackOpts
- host1StackOpts.NUDDisp = &nudDisp
-
- host1Stack := stack.New(host1StackOpts)
- routerStack := stack.New(stackOpts)
- host2Stack := stack.New(stackOpts)
- utils.SetupRoutedStacks(t, host1Stack, routerStack, host2Stack)
-
- // Add a reachable dynamic entry to our neighbor table for the remote.
- {
- ch := make(chan stack.LinkResolutionResult, 1)
- err := host1Stack.GetLinkAddress(utils.Host1NICID, test.neighborAddr, "", test.netProto, func(r stack.LinkResolutionResult) {
- ch <- r
- })
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", utils.Host1NICID, test.neighborAddr, test.netProto, err, &tcpip.ErrWouldBlock{})
- }
- if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: utils.LinkAddr2, Err: nil}, <-ch); diff != "" {
- t.Fatalf("link resolution mismatch (-want +got):\n%s", diff)
- }
- }
- if err := nudDisp.expectEvent(eventInfo{
- eventType: entryAdded,
- nicID: utils.Host1NICID,
- entry: stack.NeighborEntry{State: stack.Incomplete, Addr: test.neighborAddr},
- }); err != nil {
- t.Fatalf("error waiting for initial NUD event: %s", err)
- }
- if err := nudDisp.expectEvent(eventInfo{
- eventType: entryChanged,
- nicID: utils.Host1NICID,
- entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
- }); err != nil {
- t.Fatalf("error waiting for reachable NUD event: %s", err)
- }
-
- // Wait for the remote's neighbor entry to be stale before creating a
- // TCP connection from host1 to some remote.
- nudConfigs, err := host1Stack.NUDConfigurations(utils.Host1NICID, test.netProto)
- if err != nil {
- t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", utils.Host1NICID, test.netProto, err)
- }
- // The maximum reachable time for a neighbor is some maximum random factor
- // applied to the base reachable time.
- //
- // See NUDConfigurations.BaseReachableTime for more information.
- maxReachableTime := time.Duration(float32(nudConfigs.BaseReachableTime) * nudConfigs.MaxRandomFactor)
- clock.Advance(maxReachableTime)
- if err := nudDisp.expectEvent(eventInfo{
- eventType: entryChanged,
- nicID: utils.Host1NICID,
- entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
- }); err != nil {
- t.Fatalf("error waiting for stale NUD event: %s", err)
- }
-
- listenerEP, listenerCH, clientEP, clientCH := test.getEndpoints(t, host1Stack, routerStack, host2Stack)
- defer clientEP.Close()
- listenerAddr := tcpip.FullAddress{Addr: test.remoteAddr, Port: 1234}
- if err := listenerEP.Bind(listenerAddr); err != nil {
- t.Fatalf("listenerEP.Bind(%#v): %s", listenerAddr, err)
- }
- if err := listenerEP.Listen(1); err != nil {
- t.Fatalf("listenerEP.Listen(1): %s", err)
- }
- {
- err := clientEP.Connect(listenerAddr)
- if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
- t.Fatalf("got clientEP.Connect(%#v) = %s, want = %s", listenerAddr, err, &tcpip.ErrConnectStarted{})
- }
- }
-
- // Wait for the TCP handshake to complete then make sure the neighbor is
- // reachable without entering the probe state as TCP should provide NUD
- // with confirmation that the neighbor is reachable (indicated by a
- // successful 3-way handshake).
- <-clientCH
- if err := nudDisp.expectEvent(eventInfo{
- eventType: entryChanged,
- nicID: utils.Host1NICID,
- entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
- }); err != nil {
- t.Fatalf("error waiting for delay NUD event: %s", err)
- }
- <-listenerCH
- if err := nudDisp.expectEvent(eventInfo{
- eventType: entryChanged,
- nicID: utils.Host1NICID,
- entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
- }); err != nil {
- t.Fatalf("error waiting for reachable NUD event: %s", err)
- }
-
- peerEP, peerWQ, err := listenerEP.Accept(nil)
- if err != nil {
- t.Fatalf("listenerEP.Accept(): %s", err)
- }
- defer peerEP.Close()
- peerWE, peerCH := waiter.NewChannelEntry(nil)
- peerWQ.EventRegister(&peerWE, waiter.ReadableEvents)
-
- // Wait for the neighbor to be stale again then send data to the remote.
- //
- // On successful transmission, the neighbor should become reachable
- // without probing the neighbor as a TCP ACK would be received which is an
- // indication of the neighbor being reachable.
- clock.Advance(maxReachableTime)
- if err := nudDisp.expectEvent(eventInfo{
- eventType: entryChanged,
- nicID: utils.Host1NICID,
- entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
- }); err != nil {
- t.Fatalf("error waiting for stale NUD event: %s", err)
- }
- {
- var r bytes.Reader
- r.Reset([]byte{0})
- var wOpts tcpip.WriteOptions
- if _, err := clientEP.Write(&r, wOpts); err != nil {
- t.Errorf("clientEP.Write(_, %#v): %s", wOpts, err)
- }
- }
- // Heads up, there is a race here.
- //
- // Incoming TCP segments are handled in
- // tcp.(*endpoint).handleSegmentLocked:
- //
- // - tcp.(*endpoint).rcv.handleRcvdSegment puts the segment on the
- // segment queue and notifies waiting readers (such as this channel)
- //
- // - tcp.(*endpoint).snd.handleRcvdSegment sends an ACK for the segment
- // and notifies the NUD machinery that the peer is reachable
- //
- // Thus we must permit a delay between the readable signal and the
- // expected NUD event.
- //
- // At the time of writing, this race is reliably hit with gotsan.
- <-peerCH
- for len(nudDisp.c) == 0 {
- runtime.Gosched()
- }
- if err := nudDisp.expectEvent(eventInfo{
- eventType: entryChanged,
- nicID: utils.Host1NICID,
- entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
- }); err != nil {
- t.Fatalf("error waiting for delay NUD event: %s", err)
- }
- if test.isHost1Listener {
- // If host1 is not the client, host1 does not send any data so TCP
- // has no way to know it is making forward progress. Because of this,
- // TCP should not mark the route reachable and NUD should go through the
- // probe state.
- clock.Advance(nudConfigs.DelayFirstProbeTime)
- if err := nudDisp.expectEvent(eventInfo{
- eventType: entryChanged,
- nicID: utils.Host1NICID,
- entry: stack.NeighborEntry{State: stack.Probe, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
- }); err != nil {
- t.Fatalf("error waiting for probe NUD event: %s", err)
- }
- }
- {
- var r bytes.Reader
- r.Reset([]byte{0})
- var wOpts tcpip.WriteOptions
- if _, err := peerEP.Write(&r, wOpts); err != nil {
- t.Errorf("peerEP.Write(_, %#v): %s", wOpts, err)
- }
- }
- <-clientCH
- if err := nudDisp.expectEvent(eventInfo{
- eventType: entryChanged,
- nicID: utils.Host1NICID,
- entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
- }); err != nil {
- t.Fatalf("error waiting for reachable NUD event: %s", err)
- }
- })
- }
-}
-
-func TestDAD(t *testing.T) {
- dadConfigs := stack.DADConfigurations{
- DupAddrDetectTransmits: 1,
- RetransmitTimer: time.Second,
- }
-
- tests := []struct {
- name string
- netProto tcpip.NetworkProtocolNumber
- dadNetProto tcpip.NetworkProtocolNumber
- remoteAddr tcpip.Address
- expectedResult stack.DADResult
- }{
- {
- name: "IPv4 own address",
- netProto: ipv4.ProtocolNumber,
- dadNetProto: arp.ProtocolNumber,
- remoteAddr: utils.Ipv4Addr1.AddressWithPrefix.Address,
- expectedResult: &stack.DADSucceeded{},
- },
- {
- name: "IPv6 own address",
- netProto: ipv6.ProtocolNumber,
- dadNetProto: ipv6.ProtocolNumber,
- remoteAddr: utils.Ipv6Addr1.AddressWithPrefix.Address,
- expectedResult: &stack.DADSucceeded{},
- },
- {
- name: "IPv4 duplicate address",
- netProto: ipv4.ProtocolNumber,
- dadNetProto: arp.ProtocolNumber,
- remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
- expectedResult: &stack.DADDupAddrDetected{HolderLinkAddress: utils.LinkAddr2},
- },
- {
- name: "IPv6 duplicate address",
- netProto: ipv6.ProtocolNumber,
- dadNetProto: ipv6.ProtocolNumber,
- remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
- expectedResult: &stack.DADDupAddrDetected{HolderLinkAddress: utils.LinkAddr2},
- },
- {
- name: "IPv4 no duplicate address",
- netProto: ipv4.ProtocolNumber,
- dadNetProto: arp.ProtocolNumber,
- remoteAddr: utils.Ipv4Addr3.AddressWithPrefix.Address,
- expectedResult: &stack.DADSucceeded{},
- },
- {
- name: "IPv6 no duplicate address",
- netProto: ipv6.ProtocolNumber,
- dadNetProto: ipv6.ProtocolNumber,
- remoteAddr: utils.Ipv6Addr3.AddressWithPrefix.Address,
- expectedResult: &stack.DADSucceeded{},
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- clock := faketime.NewManualClock()
- stackOpts := stack.Options{
- Clock: clock,
- NetworkProtocols: []stack.NetworkProtocolFactory{
- arp.NewProtocol,
- ipv4.NewProtocol,
- ipv6.NewProtocol,
- },
- }
-
- host1Stack, _ := setupStack(t, stackOpts, utils.Host1NICID, utils.Host2NICID)
-
- // DAD should be disabled by default.
- if res, err := host1Stack.CheckDuplicateAddress(utils.Host1NICID, test.netProto, test.remoteAddr, func(r stack.DADResult) {
- t.Errorf("unexpectedly called DAD completion handler when DAD was supposed to be disabled")
- }); err != nil {
- t.Fatalf("host1Stack.CheckDuplicateAddress(%d, %d, %s, _): %s", utils.Host1NICID, test.netProto, test.remoteAddr, err)
- } else if res != stack.DADDisabled {
- t.Errorf("got host1Stack.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", utils.Host1NICID, test.netProto, test.remoteAddr, res, stack.DADDisabled)
- }
-
- // Enable DAD then attempt to check if an address is duplicated.
- netEP, err := host1Stack.GetNetworkEndpoint(utils.Host1NICID, test.dadNetProto)
- if err != nil {
- t.Fatalf("host1Stack.GetNetworkEndpoint(%d, %d): %s", utils.Host1NICID, test.dadNetProto, err)
- }
- dad, ok := netEP.(stack.DuplicateAddressDetector)
- if !ok {
- t.Fatalf("expected %T to implement stack.DuplicateAddressDetector", netEP)
- }
- dad.SetDADConfigurations(dadConfigs)
- ch := make(chan stack.DADResult, 3)
- if res, err := host1Stack.CheckDuplicateAddress(utils.Host1NICID, test.netProto, test.remoteAddr, func(r stack.DADResult) {
- ch <- r
- }); err != nil {
- t.Fatalf("host1Stack.CheckDuplicateAddress(%d, %d, %s, _): %s", utils.Host1NICID, test.netProto, test.remoteAddr, err)
- } else if res != stack.DADStarting {
- t.Errorf("got host1Stack.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", utils.Host1NICID, test.netProto, test.remoteAddr, res, stack.DADStarting)
- }
-
- expectResults := 1
- if _, ok := test.expectedResult.(*stack.DADSucceeded); ok {
- const delta = time.Nanosecond
- clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer - delta)
- select {
- case r := <-ch:
- t.Fatalf("unexpectedly got DAD result before the DAD timeout; r = %#v", r)
- default:
- }
-
- // If we expect the resolve to succeed try requesting DAD again on the
- // same address. The handler for the new request should be called once
- // the original DAD request completes.
- expectResults = 2
- if res, err := host1Stack.CheckDuplicateAddress(utils.Host1NICID, test.netProto, test.remoteAddr, func(r stack.DADResult) {
- ch <- r
- }); err != nil {
- t.Fatalf("host1Stack.CheckDuplicateAddress(%d, %d, %s, _): %s", utils.Host1NICID, test.netProto, test.remoteAddr, err)
- } else if res != stack.DADAlreadyRunning {
- t.Errorf("got host1Stack.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", utils.Host1NICID, test.netProto, test.remoteAddr, res, stack.DADAlreadyRunning)
- }
-
- clock.Advance(delta)
- }
-
- for i := 0; i < expectResults; i++ {
- if diff := cmp.Diff(test.expectedResult, <-ch); diff != "" {
- t.Errorf("(i=%d) DAD result mismatch (-want +got):\n%s", i, diff)
- }
- }
-
- // Should have no more results.
- select {
- case r := <-ch:
- t.Errorf("unexpectedly got an extra DAD result; r = %#v", r)
- default:
- }
- })
- }
-}
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
deleted file mode 100644
index f33223e79..000000000
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ /dev/null
@@ -1,782 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package loopback_test
-
-import (
- "bytes"
- "testing"
- "time"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/tests/utils"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil)
-
-type ndpDispatcher struct{}
-
-func (*ndpDispatcher) OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) {
-}
-
-func (*ndpDispatcher) OnOffLinkRouteUpdated(tcpip.NICID, tcpip.Subnet, tcpip.Address, header.NDPRoutePreference) {
-}
-
-func (*ndpDispatcher) OnOffLinkRouteInvalidated(tcpip.NICID, tcpip.Subnet, tcpip.Address) {}
-
-func (*ndpDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) {
-}
-
-func (*ndpDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) {}
-
-func (*ndpDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) {
-}
-
-func (*ndpDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) {}
-
-func (*ndpDispatcher) OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) {}
-
-func (*ndpDispatcher) OnRecursiveDNSServerOption(tcpip.NICID, []tcpip.Address, time.Duration) {}
-
-func (*ndpDispatcher) OnDNSSearchListOption(tcpip.NICID, []string, time.Duration) {}
-
-func (*ndpDispatcher) OnDHCPv6Configuration(tcpip.NICID, ipv6.DHCPv6ConfigurationFromNDPRA) {}
-
-// TestInitialLoopbackAddresses tests that the loopback interface does not
-// auto-generate a link-local address when it is brought up.
-func TestInitialLoopbackAddresses(t *testing.T) {
- const nicID = 1
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPDisp: &ndpDispatcher{},
- AutoGenLinkLocal: true,
- OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{
- NICNameFromID: func(nicID tcpip.NICID, nicName string) string {
- t.Fatalf("should not attempt to get name for NIC with ID = %d; nicName = %s", nicID, nicName)
- return ""
- },
- },
- })},
- })
-
- if err := s.CreateNIC(nicID, loopback.New()); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
-
- nicsInfo := s.NICInfo()
- if nicInfo, ok := nicsInfo[nicID]; !ok {
- t.Fatalf("did not find NIC with ID = %d in s.NICInfo() = %#v", nicID, nicsInfo)
- } else if got := len(nicInfo.ProtocolAddresses); got != 0 {
- t.Fatalf("got len(nicInfo.ProtocolAddresses) = %d, want = 0; nicInfo.ProtocolAddresses = %#v", got, nicInfo.ProtocolAddresses)
- }
-}
-
-// TestLoopbackAcceptAllInSubnetUDP tests that a loopback interface considers
-// itself bound to all addresses in the subnet of an assigned address and UDP
-// traffic is sent/received correctly.
-func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) {
- const (
- nicID = 1
- localPort = 80
- )
-
- data := []byte{1, 2, 3, 4}
-
- ipv4ProtocolAddress := tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: utils.Ipv4Addr,
- }
- ipv4Bytes := []byte(ipv4ProtocolAddress.AddressWithPrefix.Address)
- ipv4Bytes[len(ipv4Bytes)-1]++
- otherIPv4Address := tcpip.Address(ipv4Bytes)
-
- ipv6ProtocolAddress := tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: utils.Ipv6Addr,
- }
- ipv6Bytes := []byte(utils.Ipv6Addr.Address)
- ipv6Bytes[len(ipv6Bytes)-1]++
- otherIPv6Address := tcpip.Address(ipv6Bytes)
-
- tests := []struct {
- name string
- addAddress tcpip.ProtocolAddress
- bindAddr tcpip.Address
- dstAddr tcpip.Address
- expectRx bool
- }{
- {
- name: "IPv4 bind to wildcard and send to assigned address",
- addAddress: ipv4ProtocolAddress,
- dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address,
- expectRx: true,
- },
- {
- name: "IPv4 bind to wildcard and send to other subnet-local address",
- addAddress: ipv4ProtocolAddress,
- dstAddr: otherIPv4Address,
- expectRx: true,
- },
- {
- name: "IPv4 bind to wildcard send to other address",
- addAddress: ipv4ProtocolAddress,
- dstAddr: utils.RemoteIPv4Addr,
- expectRx: false,
- },
- {
- name: "IPv4 bind to other subnet-local address and send to assigned address",
- addAddress: ipv4ProtocolAddress,
- bindAddr: otherIPv4Address,
- dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address,
- expectRx: false,
- },
- {
- name: "IPv4 bind and send to other subnet-local address",
- addAddress: ipv4ProtocolAddress,
- bindAddr: otherIPv4Address,
- dstAddr: otherIPv4Address,
- expectRx: true,
- },
- {
- name: "IPv4 bind to assigned address and send to other subnet-local address",
- addAddress: ipv4ProtocolAddress,
- bindAddr: ipv4ProtocolAddress.AddressWithPrefix.Address,
- dstAddr: otherIPv4Address,
- expectRx: false,
- },
-
- {
- name: "IPv6 bind and send to assigned address",
- addAddress: ipv6ProtocolAddress,
- bindAddr: utils.Ipv6Addr.Address,
- dstAddr: utils.Ipv6Addr.Address,
- expectRx: true,
- },
- {
- name: "IPv6 bind to wildcard and send to other subnet-local address",
- addAddress: ipv6ProtocolAddress,
- dstAddr: otherIPv6Address,
- expectRx: false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- })
- if err := s.CreateNIC(nicID, loopback.New()); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
- if err := s.AddProtocolAddress(nicID, test.addAddress, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, test.addAddress, err)
- }
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv4EmptySubnet,
- NIC: nicID,
- },
- {
- Destination: header.IPv6EmptySubnet,
- NIC: nicID,
- },
- })
-
- var wq waiter.Queue
- rep, err := s.NewEndpoint(udp.ProtocolNumber, test.addAddress.Protocol, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err)
- }
- defer rep.Close()
-
- bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: localPort}
- if err := rep.Bind(bindAddr); err != nil {
- t.Fatalf("rep.Bind(%+v): %s", bindAddr, err)
- }
-
- sep, err := s.NewEndpoint(udp.ProtocolNumber, test.addAddress.Protocol, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err)
- }
- defer sep.Close()
-
- wopts := tcpip.WriteOptions{
- To: &tcpip.FullAddress{
- Addr: test.dstAddr,
- Port: localPort,
- },
- }
- var r bytes.Reader
- r.Reset(data)
- n, err := sep.Write(&r, wopts)
- if err != nil {
- t.Fatalf("sep.Write(_, _): %s", err)
- }
- if want := int64(len(data)); n != want {
- t.Fatalf("got sep.Write(_, _) = (%d, nil), want = (%d, nil)", n, want)
- }
-
- var buf bytes.Buffer
- opts := tcpip.ReadOptions{NeedRemoteAddr: true}
- if res, err := rep.Read(&buf, opts); test.expectRx {
- if err != nil {
- t.Fatalf("rep.Read(_, %#v): %s", opts, err)
- }
- if diff := cmp.Diff(tcpip.ReadResult{
- Count: buf.Len(),
- Total: buf.Len(),
- RemoteAddr: tcpip.FullAddress{
- Addr: test.addAddress.AddressWithPrefix.Address,
- },
- }, res,
- checker.IgnoreCmpPath("ControlMessages", "RemoteAddr.NIC", "RemoteAddr.Port"),
- ); diff != "" {
- t.Errorf("rep.Read: unexpected result (-want +got):\n%s", diff)
- }
- if diff := cmp.Diff(data, buf.Bytes()); diff != "" {
- t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
- }
- } else if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got rep.Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), &tcpip.ErrWouldBlock{})
- }
- })
- }
-}
-
-// TestLoopbackSubnetLifetimeBoundToAddr tests that the lifetime of an address
-// in a loopback interface's associated subnet is bound to the permanently bound
-// address.
-func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) {
- const nicID = 1
-
- protoAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: utils.Ipv4Addr,
- }
- addrBytes := []byte(utils.Ipv4Addr.Address)
- addrBytes[len(addrBytes)-1]++
- otherAddr := tcpip.Address(addrBytes)
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- })
- if err := s.CreateNIC(nicID, loopback.New()); err != nil {
- t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
- }
- if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
- }
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv4EmptySubnet,
- NIC: nicID,
- },
- })
-
- r, err := s.FindRoute(nicID, otherAddr, utils.RemoteIPv4Addr, ipv4.ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, otherAddr, utils.RemoteIPv4Addr, ipv4.ProtocolNumber, err)
- }
- defer r.Release()
-
- params := stack.NetworkHeaderParams{
- Protocol: 111,
- TTL: 64,
- TOS: stack.DefaultTOS,
- }
- data := buffer.View([]byte{1, 2, 3, 4})
- if err := r.WritePacket(params, stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength()),
- Data: data.ToVectorisedView(),
- })); err != nil {
- t.Fatalf("r.WritePacket(%#v, _): %s", params, err)
- }
-
- // Removing the address should make the endpoint invalid.
- if err := s.RemoveAddress(nicID, protoAddr.AddressWithPrefix.Address); err != nil {
- t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, protoAddr.AddressWithPrefix.Address, err)
- }
- {
- err := r.WritePacket(params, stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength()),
- Data: data.ToVectorisedView(),
- }))
- if _, ok := err.(*tcpip.ErrInvalidEndpointState); !ok {
- t.Fatalf("got r.WritePacket(%#v, _) = %s, want = %s", params, err, &tcpip.ErrInvalidEndpointState{})
- }
- }
-}
-
-// TestLoopbackAcceptAllInSubnetTCP tests that a loopback interface considers
-// itself bound to all addresses in the subnet of an assigned address and TCP
-// traffic is sent/received correctly.
-func TestLoopbackAcceptAllInSubnetTCP(t *testing.T) {
- const (
- nicID = 1
- localPort = 80
- )
-
- ipv4ProtocolAddress := tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: utils.Ipv4Addr,
- }
- ipv4ProtocolAddress.AddressWithPrefix.PrefixLen = 8
- ipv4Bytes := []byte(ipv4ProtocolAddress.AddressWithPrefix.Address)
- ipv4Bytes[len(ipv4Bytes)-1]++
- otherIPv4Address := tcpip.Address(ipv4Bytes)
-
- ipv6ProtocolAddress := tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: utils.Ipv6Addr,
- }
- ipv6Bytes := []byte(utils.Ipv6Addr.Address)
- ipv6Bytes[len(ipv6Bytes)-1]++
- otherIPv6Address := tcpip.Address(ipv6Bytes)
-
- tests := []struct {
- name string
- addAddress tcpip.ProtocolAddress
- bindAddr tcpip.Address
- dstAddr tcpip.Address
- expectAccept bool
- }{
- {
- name: "IPv4 bind to wildcard and send to assigned address",
- addAddress: ipv4ProtocolAddress,
- dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address,
- expectAccept: true,
- },
- {
- name: "IPv4 bind to wildcard and send to other subnet-local address",
- addAddress: ipv4ProtocolAddress,
- dstAddr: otherIPv4Address,
- expectAccept: true,
- },
- {
- name: "IPv4 bind to wildcard send to other address",
- addAddress: ipv4ProtocolAddress,
- dstAddr: utils.RemoteIPv4Addr,
- expectAccept: false,
- },
- {
- name: "IPv4 bind to other subnet-local address and send to assigned address",
- addAddress: ipv4ProtocolAddress,
- bindAddr: otherIPv4Address,
- dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address,
- expectAccept: false,
- },
- {
- name: "IPv4 bind and send to other subnet-local address",
- addAddress: ipv4ProtocolAddress,
- bindAddr: otherIPv4Address,
- dstAddr: otherIPv4Address,
- expectAccept: true,
- },
- {
- name: "IPv4 bind to assigned address and send to other subnet-local address",
- addAddress: ipv4ProtocolAddress,
- bindAddr: ipv4ProtocolAddress.AddressWithPrefix.Address,
- dstAddr: otherIPv4Address,
- expectAccept: false,
- },
-
- {
- name: "IPv6 bind and send to assigned address",
- addAddress: ipv6ProtocolAddress,
- bindAddr: utils.Ipv6Addr.Address,
- dstAddr: utils.Ipv6Addr.Address,
- expectAccept: true,
- },
- {
- name: "IPv6 bind to wildcard and send to other subnet-local address",
- addAddress: ipv6ProtocolAddress,
- dstAddr: otherIPv6Address,
- expectAccept: false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
- })
- if err := s.CreateNIC(nicID, loopback.New()); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
- if err := s.AddProtocolAddress(nicID, test.addAddress, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, test.addAddress, err)
- }
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv4EmptySubnet,
- NIC: nicID,
- },
- {
- Destination: header.IPv6EmptySubnet,
- NIC: nicID,
- },
- })
-
- var wq waiter.Queue
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
- listeningEndpoint, err := s.NewEndpoint(tcp.ProtocolNumber, test.addAddress.Protocol, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err)
- }
- defer listeningEndpoint.Close()
-
- bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: localPort}
- if err := listeningEndpoint.Bind(bindAddr); err != nil {
- t.Fatalf("listeningEndpoint.Bind(%#v): %s", bindAddr, err)
- }
-
- if err := listeningEndpoint.Listen(1); err != nil {
- t.Fatalf("listeningEndpoint.Listen(1): %s", err)
- }
-
- connectingEndpoint, err := s.NewEndpoint(tcp.ProtocolNumber, test.addAddress.Protocol, &wq)
- if err != nil {
- t.Fatalf("s.NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err)
- }
- defer connectingEndpoint.Close()
-
- connectAddr := tcpip.FullAddress{
- Addr: test.dstAddr,
- Port: localPort,
- }
- {
- err := connectingEndpoint.Connect(connectAddr)
- if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
- t.Fatalf("connectingEndpoint.Connect(%#v): %s", connectAddr, err)
- }
- }
-
- if !test.expectAccept {
- _, _, err := listeningEndpoint.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got listeningEndpoint.Accept(nil) = %s, want = %s", err, &tcpip.ErrWouldBlock{})
- }
- return
- }
-
- // Wait for the listening endpoint to be "readable". That is, wait for a
- // new connection.
- <-ch
- var addr tcpip.FullAddress
- if _, _, err := listeningEndpoint.Accept(&addr); err != nil {
- t.Fatalf("listeningEndpoint.Accept(nil): %s", err)
- }
- if addr.Addr != test.addAddress.AddressWithPrefix.Address {
- t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.addAddress.AddressWithPrefix.Address)
- }
- })
- }
-}
-
-func TestExternalLoopbackTraffic(t *testing.T) {
- const (
- nicID1 = 1
- nicID2 = 2
-
- numPackets = 1
- ttl = 64
- )
- ipv4Loopback := testutil.MustParse4("127.0.0.1")
-
- loopbackSourcedICMPv4 := func(e *channel.Endpoint) {
- utils.RxICMPv4EchoRequest(e, ipv4Loopback, utils.Ipv4Addr.Address, ttl)
- }
-
- loopbackSourcedICMPv6 := func(e *channel.Endpoint) {
- utils.RxICMPv6EchoRequest(e, header.IPv6Loopback, utils.Ipv6Addr.Address, ttl)
- }
-
- loopbackDestinedICMPv4 := func(e *channel.Endpoint) {
- utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, ipv4Loopback, ttl)
- }
-
- loopbackDestinedICMPv6 := func(e *channel.Endpoint) {
- utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, header.IPv6Loopback, ttl)
- }
-
- invalidSrcAddrStat := func(s tcpip.IPStats) *tcpip.StatCounter {
- return s.InvalidSourceAddressesReceived
- }
-
- invalidDestAddrStat := func(s tcpip.IPStats) *tcpip.StatCounter {
- return s.InvalidDestinationAddressesReceived
- }
-
- tests := []struct {
- name string
- allowExternalLoopback bool
- forwarding bool
- rxICMP func(*channel.Endpoint)
- invalidAddressStat func(tcpip.IPStats) *tcpip.StatCounter
- shouldAccept bool
- }{
- {
- name: "IPv4 external loopback sourced traffic without forwarding and drop external loopback disabled",
- allowExternalLoopback: true,
- forwarding: false,
- rxICMP: loopbackSourcedICMPv4,
- invalidAddressStat: invalidSrcAddrStat,
- shouldAccept: true,
- },
- {
- name: "IPv4 external loopback sourced traffic without forwarding and drop external loopback enabled",
- allowExternalLoopback: false,
- forwarding: false,
- rxICMP: loopbackSourcedICMPv4,
- invalidAddressStat: invalidSrcAddrStat,
- shouldAccept: false,
- },
- {
- name: "IPv4 external loopback sourced traffic with forwarding and drop external loopback disabled",
- allowExternalLoopback: true,
- forwarding: true,
- rxICMP: loopbackSourcedICMPv4,
- invalidAddressStat: invalidSrcAddrStat,
- shouldAccept: true,
- },
- {
- name: "IPv4 external loopback sourced traffic with forwarding and drop external loopback enabled",
- allowExternalLoopback: false,
- forwarding: true,
- rxICMP: loopbackSourcedICMPv4,
- invalidAddressStat: invalidSrcAddrStat,
- shouldAccept: false,
- },
- {
- name: "IPv4 external loopback destined traffic without forwarding and drop external loopback disabled",
- allowExternalLoopback: true,
- forwarding: false,
- rxICMP: loopbackDestinedICMPv4,
- invalidAddressStat: invalidDestAddrStat,
- shouldAccept: false,
- },
- {
- name: "IPv4 external loopback destined traffic without forwarding and drop external loopback enabled",
- allowExternalLoopback: false,
- forwarding: false,
- rxICMP: loopbackDestinedICMPv4,
- invalidAddressStat: invalidDestAddrStat,
- shouldAccept: false,
- },
- {
- name: "IPv4 external loopback destined traffic with forwarding and drop external loopback disabled",
- allowExternalLoopback: true,
- forwarding: true,
- rxICMP: loopbackDestinedICMPv4,
- invalidAddressStat: invalidDestAddrStat,
- shouldAccept: true,
- },
- {
- name: "IPv4 external loopback destined traffic with forwarding and drop external loopback enabled",
- allowExternalLoopback: false,
- forwarding: true,
- rxICMP: loopbackDestinedICMPv4,
- invalidAddressStat: invalidDestAddrStat,
- shouldAccept: false,
- },
-
- {
- name: "IPv6 external loopback sourced traffic without forwarding and drop external loopback disabled",
- allowExternalLoopback: true,
- forwarding: false,
- rxICMP: loopbackSourcedICMPv6,
- invalidAddressStat: invalidSrcAddrStat,
- shouldAccept: true,
- },
- {
- name: "IPv6 external loopback sourced traffic without forwarding and drop external loopback enabled",
- allowExternalLoopback: false,
- forwarding: false,
- rxICMP: loopbackSourcedICMPv6,
- invalidAddressStat: invalidSrcAddrStat,
- shouldAccept: false,
- },
- {
- name: "IPv6 external loopback sourced traffic with forwarding and drop external loopback disabled",
- allowExternalLoopback: true,
- forwarding: true,
- rxICMP: loopbackSourcedICMPv6,
- invalidAddressStat: invalidSrcAddrStat,
- shouldAccept: true,
- },
- {
- name: "IPv6 external loopback sourced traffic with forwarding and drop external loopback enabled",
- allowExternalLoopback: false,
- forwarding: true,
- rxICMP: loopbackSourcedICMPv6,
- invalidAddressStat: invalidSrcAddrStat,
- shouldAccept: false,
- },
- {
- name: "IPv6 external loopback destined traffic without forwarding and drop external loopback disabled",
- allowExternalLoopback: true,
- forwarding: false,
- rxICMP: loopbackDestinedICMPv6,
- invalidAddressStat: invalidDestAddrStat,
- shouldAccept: false,
- },
- {
- name: "IPv6 external loopback destined traffic without forwarding and drop external loopback enabled",
- allowExternalLoopback: false,
- forwarding: false,
- rxICMP: loopbackDestinedICMPv6,
- invalidAddressStat: invalidDestAddrStat,
- shouldAccept: false,
- },
- {
- name: "IPv6 external loopback destined traffic with forwarding and drop external loopback disabled",
- allowExternalLoopback: true,
- forwarding: true,
- rxICMP: loopbackDestinedICMPv6,
- invalidAddressStat: invalidDestAddrStat,
- shouldAccept: true,
- },
- {
- name: "IPv6 external loopback destined traffic with forwarding and drop external loopback enabled",
- allowExternalLoopback: false,
- forwarding: true,
- rxICMP: loopbackDestinedICMPv6,
- invalidAddressStat: invalidDestAddrStat,
- shouldAccept: false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{
- ipv4.NewProtocolWithOptions(ipv4.Options{
- AllowExternalLoopbackTraffic: test.allowExternalLoopback,
- }),
- ipv6.NewProtocolWithOptions(ipv6.Options{
- AllowExternalLoopbackTraffic: test.allowExternalLoopback,
- }),
- },
- TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6},
- })
- e := channel.New(1, header.IPv6MinimumMTU, "")
- if err := s.CreateNIC(nicID1, e); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
- }
- v4Addr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: utils.Ipv4Addr,
- }
- if err := s.AddProtocolAddress(nicID1, v4Addr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, v4Addr, err)
- }
- v6Addr := tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: utils.Ipv6Addr,
- }
- if err := s.AddProtocolAddress(nicID1, v6Addr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, v6Addr, err)
- }
-
- if err := s.CreateNIC(nicID2, loopback.New()); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
- }
- protocolAddrV4 := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: ipv4Loopback,
- PrefixLen: 8,
- },
- }
- if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err)
- }
- protocolAddrV6 := tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: header.IPv6Loopback.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err)
- }
-
- if test.forwarding {
- if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
- t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err)
- }
- if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil {
- t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err)
- }
- }
-
- s.SetRouteTable([]tcpip.Route{
- tcpip.Route{
- Destination: header.IPv4EmptySubnet,
- NIC: nicID1,
- },
- tcpip.Route{
- Destination: header.IPv6EmptySubnet,
- NIC: nicID1,
- },
- tcpip.Route{
- Destination: ipv4Loopback.WithPrefix().Subnet(),
- NIC: nicID2,
- },
- tcpip.Route{
- Destination: header.IPv6Loopback.WithPrefix().Subnet(),
- NIC: nicID2,
- },
- })
-
- stats := s.Stats().IP
- invalidAddressStat := test.invalidAddressStat(stats)
- deliveredPacketsStat := stats.PacketsDelivered
- if got := invalidAddressStat.Value(); got != 0 {
- t.Fatalf("got invalidAddressStat.Value() = %d, want = 0", got)
- }
- if got := deliveredPacketsStat.Value(); got != 0 {
- t.Fatalf("got deliveredPacketsStat.Value() = %d, want = 0", got)
- }
- test.rxICMP(e)
- var expectedInvalidPackets uint64
- if !test.shouldAccept {
- expectedInvalidPackets = numPackets
- }
- if got := invalidAddressStat.Value(); got != expectedInvalidPackets {
- t.Fatalf("got invalidAddressStat.Value() = %d, want = %d", got, expectedInvalidPackets)
- }
- if got, want := deliveredPacketsStat.Value(), numPackets-expectedInvalidPackets; got != want {
- t.Fatalf("got deliveredPacketsStat.Value() = %d, want = %d", got, want)
- }
- })
- }
-}
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
deleted file mode 100644
index 7753e7d6e..000000000
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ /dev/null
@@ -1,723 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package multicast_broadcast_test
-
-import (
- "bytes"
- "testing"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/tests/utils"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-const (
- defaultMTU = 1280
- ttl = 255
-)
-
-// TestPingMulticastBroadcast tests that responding to an Echo Request destined
-// to a multicast or broadcast address uses a unicast source address for the
-// reply.
-func TestPingMulticastBroadcast(t *testing.T) {
- const (
- nicID = 1
- ttl = 64
- )
-
- tests := []struct {
- name string
- protoNum tcpip.NetworkProtocolNumber
- rxICMP func(*channel.Endpoint, tcpip.Address, tcpip.Address, uint8)
- srcAddr tcpip.Address
- dstAddr tcpip.Address
- expectedSrc tcpip.Address
- }{
- {
- name: "IPv4 unicast",
- protoNum: header.IPv4ProtocolNumber,
- dstAddr: utils.Ipv4Addr.Address,
- srcAddr: utils.RemoteIPv4Addr,
- rxICMP: utils.RxICMPv4EchoRequest,
- expectedSrc: utils.Ipv4Addr.Address,
- },
- {
- name: "IPv4 directed broadcast",
- protoNum: header.IPv4ProtocolNumber,
- rxICMP: utils.RxICMPv4EchoRequest,
- srcAddr: utils.RemoteIPv4Addr,
- dstAddr: utils.Ipv4SubnetBcast,
- expectedSrc: utils.Ipv4Addr.Address,
- },
- {
- name: "IPv4 broadcast",
- protoNum: header.IPv4ProtocolNumber,
- rxICMP: utils.RxICMPv4EchoRequest,
- srcAddr: utils.RemoteIPv4Addr,
- dstAddr: header.IPv4Broadcast,
- expectedSrc: utils.Ipv4Addr.Address,
- },
- {
- name: "IPv4 all-systems multicast",
- protoNum: header.IPv4ProtocolNumber,
- rxICMP: utils.RxICMPv4EchoRequest,
- srcAddr: utils.RemoteIPv4Addr,
- dstAddr: header.IPv4AllSystems,
- expectedSrc: utils.Ipv4Addr.Address,
- },
- {
- name: "IPv6 unicast",
- protoNum: header.IPv6ProtocolNumber,
- rxICMP: utils.RxICMPv6EchoRequest,
- srcAddr: utils.RemoteIPv6Addr,
- dstAddr: utils.Ipv6Addr.Address,
- expectedSrc: utils.Ipv6Addr.Address,
- },
- {
- name: "IPv6 all-nodes multicast",
- protoNum: header.IPv6ProtocolNumber,
- rxICMP: utils.RxICMPv6EchoRequest,
- srcAddr: utils.RemoteIPv6Addr,
- dstAddr: header.IPv6AllNodesMulticastAddress,
- expectedSrc: utils.Ipv6Addr.Address,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6},
- })
- // We only expect a single packet in response to our ICMP Echo Request.
- e := channel.New(1, defaultMTU, "")
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
- ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: utils.Ipv4Addr}
- if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtoAddr, err)
- }
- ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: utils.Ipv6Addr}
- if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv6ProtoAddr, err)
- }
-
- // Default routes for IPv4 and IPv6 so ICMP can find a route to the remote
- // node when attempting to send the ICMP Echo Reply.
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv6EmptySubnet,
- NIC: nicID,
- },
- {
- Destination: header.IPv4EmptySubnet,
- NIC: nicID,
- },
- })
-
- test.rxICMP(e, test.srcAddr, test.dstAddr, ttl)
- pkt, ok := e.Read()
- if !ok {
- t.Fatal("expected ICMP response")
- }
-
- if pkt.Route.LocalAddress != test.expectedSrc {
- t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.expectedSrc)
- }
- // The destination of the response packet should be the source of the
- // original packet.
- if pkt.Route.RemoteAddress != test.srcAddr {
- t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.srcAddr)
- }
-
- src, dst := s.NetworkProtocolInstance(test.protoNum).ParseAddresses(stack.PayloadSince(pkt.Pkt.NetworkHeader()))
- if src != test.expectedSrc {
- t.Errorf("got pkt source = %s, want = %s", src, test.expectedSrc)
- }
- // The destination of the response packet should be the source of the
- // original packet.
- if dst != test.srcAddr {
- t.Errorf("got pkt destination = %s, want = %s", dst, test.srcAddr)
- }
- })
- }
-
-}
-
-func rxIPv4UDP(e *channel.Endpoint, src, dst tcpip.Address, data []byte) {
- payloadLen := header.UDPMinimumSize + len(data)
- totalLen := header.IPv4MinimumSize + payloadLen
- hdr := buffer.NewPrependable(totalLen)
- u := header.UDP(hdr.Prepend(payloadLen))
- u.Encode(&header.UDPFields{
- SrcPort: utils.RemotePort,
- DstPort: utils.LocalPort,
- Length: uint16(payloadLen),
- })
- copy(u.Payload(), data)
- sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(payloadLen))
- sum = header.Checksum(data, sum)
- u.SetChecksum(^u.CalculateChecksum(sum))
-
- ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
- ip.Encode(&header.IPv4Fields{
- TotalLength: uint16(totalLen),
- Protocol: uint8(udp.ProtocolNumber),
- TTL: ttl,
- SrcAddr: src,
- DstAddr: dst,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
-
- e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
-}
-
-func rxIPv6UDP(e *channel.Endpoint, src, dst tcpip.Address, data []byte) {
- payloadLen := header.UDPMinimumSize + len(data)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadLen)
- u := header.UDP(hdr.Prepend(payloadLen))
- u.Encode(&header.UDPFields{
- SrcPort: utils.RemotePort,
- DstPort: utils.LocalPort,
- Length: uint16(payloadLen),
- })
- copy(u.Payload(), data)
- sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(payloadLen))
- sum = header.Checksum(data, sum)
- u.SetChecksum(^u.CalculateChecksum(sum))
-
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLen),
- TransportProtocol: udp.ProtocolNumber,
- HopLimit: ttl,
- SrcAddr: src,
- DstAddr: dst,
- })
-
- e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
-}
-
-// TestIncomingMulticastAndBroadcast tests receiving a packet destined to some
-// multicast or broadcast address.
-func TestIncomingMulticastAndBroadcast(t *testing.T) {
- const nicID = 1
-
- data := []byte{1, 2, 3, 4}
-
- tests := []struct {
- name string
- proto tcpip.NetworkProtocolNumber
- remoteAddr tcpip.Address
- localAddr tcpip.AddressWithPrefix
- rxUDP func(*channel.Endpoint, tcpip.Address, tcpip.Address, []byte)
- bindAddr tcpip.Address
- dstAddr tcpip.Address
- expectRx bool
- }{
- {
- name: "IPv4 unicast binding to unicast",
- proto: header.IPv4ProtocolNumber,
- remoteAddr: utils.RemoteIPv4Addr,
- localAddr: utils.Ipv4Addr,
- rxUDP: rxIPv4UDP,
- bindAddr: utils.Ipv4Addr.Address,
- dstAddr: utils.Ipv4Addr.Address,
- expectRx: true,
- },
- {
- name: "IPv4 unicast binding to broadcast",
- proto: header.IPv4ProtocolNumber,
- remoteAddr: utils.RemoteIPv4Addr,
- localAddr: utils.Ipv4Addr,
- rxUDP: rxIPv4UDP,
- bindAddr: header.IPv4Broadcast,
- dstAddr: utils.Ipv4Addr.Address,
- expectRx: false,
- },
- {
- name: "IPv4 unicast binding to wildcard",
- proto: header.IPv4ProtocolNumber,
- remoteAddr: utils.RemoteIPv4Addr,
- localAddr: utils.Ipv4Addr,
- rxUDP: rxIPv4UDP,
- dstAddr: utils.Ipv4Addr.Address,
- expectRx: true,
- },
-
- {
- name: "IPv4 directed broadcast binding to subnet broadcast",
- proto: header.IPv4ProtocolNumber,
- remoteAddr: utils.RemoteIPv4Addr,
- localAddr: utils.Ipv4Addr,
- rxUDP: rxIPv4UDP,
- bindAddr: utils.Ipv4SubnetBcast,
- dstAddr: utils.Ipv4SubnetBcast,
- expectRx: true,
- },
- {
- name: "IPv4 directed broadcast binding to broadcast",
- proto: header.IPv4ProtocolNumber,
- remoteAddr: utils.RemoteIPv4Addr,
- localAddr: utils.Ipv4Addr,
- rxUDP: rxIPv4UDP,
- bindAddr: header.IPv4Broadcast,
- dstAddr: utils.Ipv4SubnetBcast,
- expectRx: false,
- },
- {
- name: "IPv4 directed broadcast binding to wildcard",
- proto: header.IPv4ProtocolNumber,
- remoteAddr: utils.RemoteIPv4Addr,
- localAddr: utils.Ipv4Addr,
- rxUDP: rxIPv4UDP,
- dstAddr: utils.Ipv4SubnetBcast,
- expectRx: true,
- },
-
- {
- name: "IPv4 broadcast binding to broadcast",
- proto: header.IPv4ProtocolNumber,
- remoteAddr: utils.RemoteIPv4Addr,
- localAddr: utils.Ipv4Addr,
- rxUDP: rxIPv4UDP,
- bindAddr: header.IPv4Broadcast,
- dstAddr: header.IPv4Broadcast,
- expectRx: true,
- },
- {
- name: "IPv4 broadcast binding to subnet broadcast",
- proto: header.IPv4ProtocolNumber,
- remoteAddr: utils.RemoteIPv4Addr,
- localAddr: utils.Ipv4Addr,
- rxUDP: rxIPv4UDP,
- bindAddr: utils.Ipv4SubnetBcast,
- dstAddr: header.IPv4Broadcast,
- expectRx: false,
- },
- {
- name: "IPv4 broadcast binding to wildcard",
- proto: header.IPv4ProtocolNumber,
- remoteAddr: utils.RemoteIPv4Addr,
- localAddr: utils.Ipv4Addr,
- rxUDP: rxIPv4UDP,
- dstAddr: utils.Ipv4SubnetBcast,
- expectRx: true,
- },
-
- {
- name: "IPv4 all-systems multicast binding to all-systems multicast",
- proto: header.IPv4ProtocolNumber,
- remoteAddr: utils.RemoteIPv4Addr,
- localAddr: utils.Ipv4Addr,
- rxUDP: rxIPv4UDP,
- bindAddr: header.IPv4AllSystems,
- dstAddr: header.IPv4AllSystems,
- expectRx: true,
- },
- {
- name: "IPv4 all-systems multicast binding to wildcard",
- proto: header.IPv4ProtocolNumber,
- remoteAddr: utils.RemoteIPv4Addr,
- localAddr: utils.Ipv4Addr,
- rxUDP: rxIPv4UDP,
- dstAddr: header.IPv4AllSystems,
- expectRx: true,
- },
- {
- name: "IPv4 all-systems multicast binding to unicast",
- proto: header.IPv4ProtocolNumber,
- remoteAddr: utils.RemoteIPv4Addr,
- localAddr: utils.Ipv4Addr,
- rxUDP: rxIPv4UDP,
- bindAddr: utils.Ipv4Addr.Address,
- dstAddr: header.IPv4AllSystems,
- expectRx: false,
- },
-
- // IPv6 has no notion of a broadcast.
- {
- name: "IPv6 unicast binding to wildcard",
- dstAddr: utils.Ipv6Addr.Address,
- proto: header.IPv6ProtocolNumber,
- remoteAddr: utils.RemoteIPv6Addr,
- localAddr: utils.Ipv6Addr,
- rxUDP: rxIPv6UDP,
- expectRx: true,
- },
- {
- name: "IPv6 broadcast-like address binding to wildcard",
- dstAddr: utils.Ipv6SubnetBcast,
- proto: header.IPv6ProtocolNumber,
- remoteAddr: utils.RemoteIPv6Addr,
- localAddr: utils.Ipv6Addr,
- rxUDP: rxIPv6UDP,
- expectRx: false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- })
- e := channel.New(0, defaultMTU, "")
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
- protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr}
- if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
- }
-
- var wq waiter.Queue
- ep, err := s.NewEndpoint(udp.ProtocolNumber, test.proto, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.proto, err)
- }
- defer ep.Close()
-
- bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: utils.LocalPort}
- if err := ep.Bind(bindAddr); err != nil {
- t.Fatalf("ep.Bind(%#v): %s", bindAddr, err)
- }
-
- test.rxUDP(e, test.remoteAddr, test.dstAddr, data)
- var buf bytes.Buffer
- var opts tcpip.ReadOptions
- if res, err := ep.Read(&buf, opts); test.expectRx {
- if err != nil {
- t.Fatalf("ep.Read(_, %#v): %s", opts, err)
- }
- if diff := cmp.Diff(tcpip.ReadResult{
- Count: buf.Len(),
- Total: buf.Len(),
- }, res, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
- t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
- }
- if diff := cmp.Diff(data, buf.Bytes()); diff != "" {
- t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
- }
- } else if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), &tcpip.ErrWouldBlock{})
- }
- })
- }
-}
-
-// TestReuseAddrAndBroadcast makes sure broadcast packets are received by all
-// interested endpoints.
-func TestReuseAddrAndBroadcast(t *testing.T) {
- const (
- nicID = 1
- localPort = 9000
- )
- loopbackBroadcast := testutil.MustParse4("127.255.255.255")
-
- tests := []struct {
- name string
- broadcastAddr tcpip.Address
- }{
- {
- name: "Subnet directed broadcast",
- broadcastAddr: loopbackBroadcast,
- },
- {
- name: "IPv4 broadcast",
- broadcastAddr: header.IPv4Broadcast,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- })
- if err := s.CreateNIC(nicID, loopback.New()); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
- protoAddr := tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: "\x7f\x00\x00\x01",
- PrefixLen: 8,
- },
- }
- if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {
- // We use the empty subnet instead of just the loopback subnet so we
- // also have a route to the IPv4 Broadcast address.
- Destination: header.IPv4EmptySubnet,
- NIC: nicID,
- },
- })
-
- type endpointAndWaiter struct {
- ep tcpip.Endpoint
- ch chan struct{}
- }
- var eps []endpointAndWaiter
- // We create endpoints that bind to both the wildcard address and the
- // broadcast address to make sure both of these types of "broadcast
- // interested" endpoints receive broadcast packets.
- for _, bindWildcard := range []bool{false, true} {
- // Create multiple endpoints for each type of "broadcast interested"
- // endpoint so we can test that all endpoints receive the broadcast
- // packet.
- for i := 0; i < 2; i++ {
- var wq waiter.Queue
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("(eps[%d]) NewEndpoint(%d, %d, _): %s", len(eps), udp.ProtocolNumber, ipv4.ProtocolNumber, err)
- }
- defer ep.Close()
-
- ep.SocketOptions().SetReuseAddress(true)
- ep.SocketOptions().SetBroadcast(true)
-
- bindAddr := tcpip.FullAddress{Port: localPort}
- if bindWildcard {
- if err := ep.Bind(bindAddr); err != nil {
- t.Fatalf("eps[%d].Bind(%#v): %s", len(eps), bindAddr, err)
- }
- } else {
- bindAddr.Addr = test.broadcastAddr
- if err := ep.Bind(bindAddr); err != nil {
- t.Fatalf("eps[%d].Bind(%#v): %s", len(eps), bindAddr, err)
- }
- }
-
- eps = append(eps, endpointAndWaiter{ep: ep, ch: ch})
- }
- }
-
- for i, wep := range eps {
- writeOpts := tcpip.WriteOptions{
- To: &tcpip.FullAddress{
- Addr: test.broadcastAddr,
- Port: localPort,
- },
- }
- data := []byte{byte(i), 2, 3, 4}
- var r bytes.Reader
- r.Reset(data)
- if n, err := wep.ep.Write(&r, writeOpts); err != nil {
- t.Fatalf("eps[%d].Write(_, _): %s", i, err)
- } else if want := int64(len(data)); n != want {
- t.Fatalf("got eps[%d].Write(_, _) = (%d, nil), want = (%d, nil)", i, n, want)
- }
-
- for j, rep := range eps {
- // Wait for the endpoint to become readable.
- <-rep.ch
-
- var buf bytes.Buffer
- result, err := rep.ep.Read(&buf, tcpip.ReadOptions{})
- if err != nil {
- t.Errorf("(eps[%d] write) eps[%d].Read: %s", i, j, err)
- continue
- }
- if diff := cmp.Diff(tcpip.ReadResult{
- Count: buf.Len(),
- Total: buf.Len(),
- }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
- t.Errorf("(eps[%d] write) eps[%d].Read: unexpected result (-want +got):\n%s", i, j, diff)
- }
- if diff := cmp.Diff([]byte(data), buf.Bytes()); diff != "" {
- t.Errorf("(eps[%d] write) got UDP payload from eps[%d] mismatch (-want +got):\n%s", i, j, diff)
- }
- }
- }
- })
- }
-}
-
-func TestUDPAddRemoveMembershipSocketOption(t *testing.T) {
- const (
- nicID = 1
- )
-
- data := []byte{1, 2, 3, 4}
-
- tests := []struct {
- name string
- proto tcpip.NetworkProtocolNumber
- remoteAddr tcpip.Address
- localAddr tcpip.AddressWithPrefix
- rxUDP func(*channel.Endpoint, tcpip.Address, tcpip.Address, []byte)
- multicastAddr tcpip.Address
- }{
- {
- name: "IPv4 unicast binding to unicast",
- multicastAddr: "\xe0\x01\x02\x03",
- proto: header.IPv4ProtocolNumber,
- remoteAddr: utils.RemoteIPv4Addr,
- localAddr: utils.Ipv4Addr,
- rxUDP: rxIPv4UDP,
- },
- {
- name: "IPv6 broadcast-like address binding to wildcard",
- multicastAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04",
- proto: header.IPv6ProtocolNumber,
- remoteAddr: utils.RemoteIPv6Addr,
- localAddr: utils.Ipv6Addr,
- rxUDP: rxIPv6UDP,
- },
- }
-
- subTests := []struct {
- name string
- specifyNICID bool
- specifyNICAddr bool
- }{
- {
- name: "Specify NIC ID and NIC address",
- specifyNICID: true,
- specifyNICAddr: true,
- },
- {
- name: "Don't specify NIC ID or NIC address",
- specifyNICID: false,
- specifyNICAddr: false,
- },
- {
- name: "Specify NIC ID but don't specify NIC address",
- specifyNICID: true,
- specifyNICAddr: false,
- },
- {
- name: "Don't specify NIC ID but specify NIC address",
- specifyNICID: false,
- specifyNICAddr: true,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- for _, subTest := range subTests {
- t.Run(subTest.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- })
- e := channel.New(0, defaultMTU, "")
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
- protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr}
- if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
- }
-
- // Set the route table so that UDP can find a NIC that is
- // routable to the multicast address when the NIC isn't specified.
- if !subTest.specifyNICID && !subTest.specifyNICAddr {
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv6EmptySubnet,
- NIC: nicID,
- },
- {
- Destination: header.IPv4EmptySubnet,
- NIC: nicID,
- },
- })
- }
-
- var wq waiter.Queue
- ep, err := s.NewEndpoint(udp.ProtocolNumber, test.proto, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.proto, err)
- }
- defer ep.Close()
-
- bindAddr := tcpip.FullAddress{Port: utils.LocalPort}
- if err := ep.Bind(bindAddr); err != nil {
- t.Fatalf("ep.Bind(%#v): %s", bindAddr, err)
- }
-
- memOpt := tcpip.MembershipOption{MulticastAddr: test.multicastAddr}
- if subTest.specifyNICID {
- memOpt.NIC = nicID
- }
- if subTest.specifyNICAddr {
- memOpt.InterfaceAddr = test.localAddr.Address
- }
-
- // We should receive UDP packets to the group once we join the
- // multicast group.
- addOpt := tcpip.AddMembershipOption(memOpt)
- if err := ep.SetSockOpt(&addOpt); err != nil {
- t.Fatalf("ep.SetSockOpt(&%#v): %s", addOpt, err)
- }
- test.rxUDP(e, test.remoteAddr, test.multicastAddr, data)
- var buf bytes.Buffer
- result, err := ep.Read(&buf, tcpip.ReadOptions{})
- if err != nil {
- t.Fatalf("ep.Read: %s", err)
- } else {
- if diff := cmp.Diff(tcpip.ReadResult{
- Count: buf.Len(),
- Total: buf.Len(),
- }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
- t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
- }
- if diff := cmp.Diff(data, buf.Bytes()); diff != "" {
- t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
- }
- }
-
- // We should not receive UDP packets to the group once we leave
- // the multicast group.
- removeOpt := tcpip.RemoveMembershipOption(memOpt)
- if err := ep.SetSockOpt(&removeOpt); err != nil {
- t.Fatalf("ep.SetSockOpt(&%#v): %s", removeOpt, err)
- }
- {
- _, err := ep.Read(&buf, tcpip.ReadOptions{})
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got ep.Read = (_, %s), want = (_, %s)", err, &tcpip.ErrWouldBlock{})
- }
- }
- })
- }
- })
- }
-}
diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go
deleted file mode 100644
index 422eb8408..000000000
--- a/pkg/tcpip/tests/integration/route_test.go
+++ /dev/null
@@ -1,441 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package route_test
-
-import (
- "bytes"
- "fmt"
- "testing"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/tests/utils"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-// TestLocalPing tests pinging a remote that is local the stack.
-//
-// This tests that a local route is created and packets do not leave the stack.
-func TestLocalPing(t *testing.T) {
- const (
- nicID = 1
-
- // icmpDataOffset is the offset to the data in both ICMPv4 and ICMPv6 echo
- // request/reply packets.
- icmpDataOffset = 8
- )
- ipv4Loopback := tcpip.AddressWithPrefix{
- Address: testutil.MustParse4("127.0.0.1"),
- PrefixLen: 8,
- }
-
- channelEP := func() stack.LinkEndpoint { return channel.New(1, header.IPv6MinimumMTU, "") }
- channelEPCheck := func(t *testing.T, e stack.LinkEndpoint) {
- channelEP := e.(*channel.Endpoint)
- if n := channelEP.Drain(); n != 0 {
- t.Fatalf("got channelEP.Drain() = %d, want = 0", n)
- }
- }
-
- ipv4ICMPBuf := func(t *testing.T) buffer.View {
- data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
- hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data)))
- hdr.SetType(header.ICMPv4Echo)
- if n := copy(hdr.Payload(), data[:]); n != len(data) {
- t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data))
- }
- return buffer.View(hdr)
- }
-
- ipv6ICMPBuf := func(t *testing.T) buffer.View {
- data := [8]byte{1, 2, 3, 4, 5, 6, 7, 9}
- hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data)))
- hdr.SetType(header.ICMPv6EchoRequest)
- if n := copy(hdr.Payload(), data[:]); n != len(data) {
- t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data))
- }
- return buffer.View(hdr)
- }
-
- tests := []struct {
- name string
- transProto tcpip.TransportProtocolNumber
- netProto tcpip.NetworkProtocolNumber
- linkEndpoint func() stack.LinkEndpoint
- localAddr tcpip.AddressWithPrefix
- icmpBuf func(*testing.T) buffer.View
- expectedConnectErr tcpip.Error
- checkLinkEndpoint func(t *testing.T, e stack.LinkEndpoint)
- }{
- {
- name: "IPv4 loopback",
- transProto: icmp.ProtocolNumber4,
- netProto: ipv4.ProtocolNumber,
- linkEndpoint: loopback.New,
- localAddr: ipv4Loopback,
- icmpBuf: ipv4ICMPBuf,
- checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {},
- },
- {
- name: "IPv6 loopback",
- transProto: icmp.ProtocolNumber6,
- netProto: ipv6.ProtocolNumber,
- linkEndpoint: loopback.New,
- localAddr: header.IPv6Loopback.WithPrefix(),
- icmpBuf: ipv6ICMPBuf,
- checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {},
- },
- {
- name: "IPv4 non-loopback",
- transProto: icmp.ProtocolNumber4,
- netProto: ipv4.ProtocolNumber,
- linkEndpoint: channelEP,
- localAddr: utils.Ipv4Addr,
- icmpBuf: ipv4ICMPBuf,
- checkLinkEndpoint: channelEPCheck,
- },
- {
- name: "IPv6 non-loopback",
- transProto: icmp.ProtocolNumber6,
- netProto: ipv6.ProtocolNumber,
- linkEndpoint: channelEP,
- localAddr: utils.Ipv6Addr,
- icmpBuf: ipv6ICMPBuf,
- checkLinkEndpoint: channelEPCheck,
- },
- {
- name: "IPv4 loopback without local address",
- transProto: icmp.ProtocolNumber4,
- netProto: ipv4.ProtocolNumber,
- linkEndpoint: loopback.New,
- icmpBuf: ipv4ICMPBuf,
- expectedConnectErr: &tcpip.ErrNoRoute{},
- checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {},
- },
- {
- name: "IPv6 loopback without local address",
- transProto: icmp.ProtocolNumber6,
- netProto: ipv6.ProtocolNumber,
- linkEndpoint: loopback.New,
- icmpBuf: ipv6ICMPBuf,
- expectedConnectErr: &tcpip.ErrNoRoute{},
- checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {},
- },
- {
- name: "IPv4 non-loopback without local address",
- transProto: icmp.ProtocolNumber4,
- netProto: ipv4.ProtocolNumber,
- linkEndpoint: channelEP,
- icmpBuf: ipv4ICMPBuf,
- expectedConnectErr: &tcpip.ErrNoRoute{},
- checkLinkEndpoint: channelEPCheck,
- },
- {
- name: "IPv6 non-loopback without local address",
- transProto: icmp.ProtocolNumber6,
- netProto: ipv6.ProtocolNumber,
- linkEndpoint: channelEP,
- icmpBuf: ipv6ICMPBuf,
- expectedConnectErr: &tcpip.ErrNoRoute{},
- checkLinkEndpoint: channelEPCheck,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- for _, allowExternalLoopback := range []bool{true, false} {
- t.Run(fmt.Sprintf("AllowExternalLoopback=%t", allowExternalLoopback), func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{
- ipv4.NewProtocolWithOptions(ipv4.Options{
- AllowExternalLoopbackTraffic: allowExternalLoopback,
- }),
- ipv6.NewProtocolWithOptions(ipv6.Options{
- AllowExternalLoopbackTraffic: allowExternalLoopback,
- }),
- },
- TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6},
- HandleLocal: true,
- })
- e := test.linkEndpoint()
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
- }
-
- if len(test.localAddr.Address) != 0 {
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: test.netProto,
- AddressWithPrefix: test.localAddr,
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
- }
-
- var wq waiter.Queue
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- ep, err := s.NewEndpoint(test.transProto, test.netProto, &wq)
- if err != nil {
- t.Fatalf("s.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err)
- }
- defer ep.Close()
-
- connAddr := tcpip.FullAddress{Addr: test.localAddr.Address}
- if err := ep.Connect(connAddr); err != test.expectedConnectErr {
- t.Fatalf("got ep.Connect(%#v) = %s, want = %s", connAddr, err, test.expectedConnectErr)
- }
-
- if test.expectedConnectErr != nil {
- return
- }
-
- var r bytes.Reader
- payload := test.icmpBuf(t)
- r.Reset(payload)
- var wOpts tcpip.WriteOptions
- if n, err := ep.Write(&r, wOpts); err != nil {
- t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err)
- } else if n != int64(len(payload)) {
- t.Fatalf("got ep.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", payload, wOpts, n, len(payload))
- }
-
- // Wait for the endpoint to become readable.
- <-ch
-
- var w bytes.Buffer
- rr, err := ep.Read(&w, tcpip.ReadOptions{
- NeedRemoteAddr: true,
- })
- if err != nil {
- t.Fatalf("ep.Read(...): %s", err)
- }
- if diff := cmp.Diff(buffer.View(w.Bytes()[icmpDataOffset:]), payload[icmpDataOffset:]); diff != "" {
- t.Errorf("received data mismatch (-want +got):\n%s", diff)
- }
- if rr.RemoteAddr.Addr != test.localAddr.Address {
- t.Errorf("got addr.Addr = %s, want = %s", rr.RemoteAddr.Addr, test.localAddr.Address)
- }
-
- test.checkLinkEndpoint(t, e)
- })
- }
- })
- }
-}
-
-// TestLocalUDP tests sending UDP packets between two endpoints that are local
-// to the stack.
-//
-// This tests that that packets never leave the stack and the addresses
-// used when sending a packet.
-func TestLocalUDP(t *testing.T) {
- const (
- nicID = 1
- )
-
- tests := []struct {
- name string
- canBePrimaryAddr tcpip.ProtocolAddress
- firstPrimaryAddr tcpip.ProtocolAddress
- }{
- {
- name: "IPv4",
- canBePrimaryAddr: utils.Ipv4Addr1,
- firstPrimaryAddr: utils.Ipv4Addr2,
- },
- {
- name: "IPv6",
- canBePrimaryAddr: utils.Ipv6Addr1,
- firstPrimaryAddr: utils.Ipv6Addr2,
- },
- }
-
- subTests := []struct {
- name string
- addAddress bool
- expectedWriteErr tcpip.Error
- }{
- {
- name: "Unassigned local address",
- addAddress: false,
- expectedWriteErr: &tcpip.ErrNoRoute{},
- },
- {
- name: "Assigned local address",
- addAddress: true,
- expectedWriteErr: nil,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- for _, subTest := range subTests {
- t.Run(subTest.name, func(t *testing.T) {
- stackOpts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- HandleLocal: true,
- }
-
- s := stack.New(stackOpts)
- ep := channel.New(1, header.IPv6MinimumMTU, "")
-
- if err := s.CreateNIC(nicID, ep); err != nil {
- t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
- }
-
- if subTest.addAddress {
- if err := s.AddProtocolAddress(nicID, test.canBePrimaryAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, test.canBePrimaryAddr, err)
- }
- properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint}
- if err := s.AddProtocolAddress(nicID, test.firstPrimaryAddr, properties); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %+v, %+v): %s", nicID, test.firstPrimaryAddr, properties, err)
- }
- }
-
- var serverWQ waiter.Queue
- serverWE, serverCH := waiter.NewChannelEntry(nil)
- serverWQ.EventRegister(&serverWE, waiter.ReadableEvents)
- server, err := s.NewEndpoint(udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, &serverWQ)
- if err != nil {
- t.Fatalf("s.NewEndpoint(%d, %d): %s", udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, err)
- }
- defer server.Close()
-
- bindAddr := tcpip.FullAddress{Port: 80}
- if err := server.Bind(bindAddr); err != nil {
- t.Fatalf("server.Bind(%#v): %s", bindAddr, err)
- }
-
- var clientWQ waiter.Queue
- clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.ReadableEvents)
- client, err := s.NewEndpoint(udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, &clientWQ)
- if err != nil {
- t.Fatalf("s.NewEndpoint(%d, %d): %s", udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, err)
- }
- defer client.Close()
-
- serverAddr := tcpip.FullAddress{
- Addr: test.canBePrimaryAddr.AddressWithPrefix.Address,
- Port: 80,
- }
-
- clientPayload := []byte{1, 2, 3, 4}
- {
- var r bytes.Reader
- r.Reset(clientPayload)
- wOpts := tcpip.WriteOptions{
- To: &serverAddr,
- }
- if n, err := client.Write(&r, wOpts); err != subTest.expectedWriteErr {
- t.Fatalf("got client.Write(%#v, %#v) = (%d, %s), want = (_, %s)", clientPayload, wOpts, n, err, subTest.expectedWriteErr)
- } else if subTest.expectedWriteErr != nil {
- // Nothing else to test if we expected not to be able to send the
- // UDP packet.
- return
- } else if n != int64(len(clientPayload)) {
- t.Fatalf("got client.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", clientPayload, wOpts, n, len(clientPayload))
- }
- }
-
- // Wait for the server endpoint to become readable.
- <-serverCH
-
- var clientAddr tcpip.FullAddress
- var readBuf bytes.Buffer
- if read, err := server.Read(&readBuf, tcpip.ReadOptions{NeedRemoteAddr: true}); err != nil {
- t.Fatalf("server.Read(_): %s", err)
- } else {
- clientAddr = read.RemoteAddr
-
- if diff := cmp.Diff(tcpip.ReadResult{
- Count: readBuf.Len(),
- Total: readBuf.Len(),
- RemoteAddr: tcpip.FullAddress{
- Addr: test.canBePrimaryAddr.AddressWithPrefix.Address,
- },
- }, read, checker.IgnoreCmpPath(
- "ControlMessages",
- "RemoteAddr.NIC",
- "RemoteAddr.Port",
- )); diff != "" {
- t.Errorf("server.Read: unexpected result (-want +got):\n%s", diff)
- }
- if diff := cmp.Diff(buffer.View(clientPayload), buffer.View(readBuf.Bytes())); diff != "" {
- t.Errorf("server read clientPayload mismatch (-want +got):\n%s", diff)
- }
- if t.Failed() {
- t.FailNow()
- }
- }
-
- serverPayload := []byte{1, 2, 3, 4}
- {
- var r bytes.Reader
- r.Reset(serverPayload)
- wOpts := tcpip.WriteOptions{
- To: &clientAddr,
- }
- if n, err := server.Write(&r, wOpts); err != nil {
- t.Fatalf("server.Write(%#v, %#v): %s", serverPayload, wOpts, err)
- } else if n != int64(len(serverPayload)) {
- t.Fatalf("got server.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", serverPayload, wOpts, n, len(serverPayload))
- }
- }
-
- // Wait for the client endpoint to become readable.
- <-clientCH
-
- readBuf.Reset()
- if read, err := client.Read(&readBuf, tcpip.ReadOptions{NeedRemoteAddr: true}); err != nil {
- t.Fatalf("client.Read(_): %s", err)
- } else {
- if diff := cmp.Diff(tcpip.ReadResult{
- Count: readBuf.Len(),
- Total: readBuf.Len(),
- RemoteAddr: tcpip.FullAddress{Addr: serverAddr.Addr},
- }, read, checker.IgnoreCmpPath(
- "ControlMessages",
- "RemoteAddr.NIC",
- "RemoteAddr.Port",
- )); diff != "" {
- t.Errorf("client.Read: unexpected result (-want +got):\n%s", diff)
- }
- if diff := cmp.Diff(buffer.View(serverPayload), buffer.View(readBuf.Bytes())); diff != "" {
- t.Errorf("client read serverPayload mismatch (-want +got):\n%s", diff)
- }
- if t.Failed() {
- t.FailNow()
- }
- }
- })
- }
- })
- }
-}