summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/BUILD98
-rw-r--r--pkg/tcpip/adapters/gonet/BUILD37
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_state_autogen.go3
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go803
-rw-r--r--pkg/tcpip/buffer/BUILD25
-rw-r--r--pkg/tcpip/buffer/buffer_state_autogen.go39
-rw-r--r--pkg/tcpip/buffer/buffer_unsafe_state_autogen.go3
-rw-r--r--pkg/tcpip/buffer/view_test.go629
-rw-r--r--pkg/tcpip/checker/BUILD17
-rw-r--r--pkg/tcpip/checker/checker.go1638
-rw-r--r--pkg/tcpip/faketime/BUILD21
-rw-r--r--pkg/tcpip/faketime/faketime_state_autogen.go3
-rw-r--r--pkg/tcpip/faketime/faketime_test.go95
-rw-r--r--pkg/tcpip/hash/jenkins/BUILD18
-rw-r--r--pkg/tcpip/hash/jenkins/jenkins_state_autogen.go3
-rw-r--r--pkg/tcpip/hash/jenkins/jenkins_test.go176
-rw-r--r--pkg/tcpip/header/BUILD76
-rw-r--r--pkg/tcpip/header/checksum_test.go461
-rw-r--r--pkg/tcpip/header/eth_test.go150
-rw-r--r--pkg/tcpip/header/header_state_autogen.go74
-rw-r--r--pkg/tcpip/header/igmp_test.go110
-rw-r--r--pkg/tcpip/header/ipv4_test.go254
-rw-r--r--pkg/tcpip/header/ipv6_extension_headers_test.go1346
-rw-r--r--pkg/tcpip/header/ipv6_test.go457
-rw-r--r--pkg/tcpip/header/ipversion_test.go67
-rw-r--r--pkg/tcpip/header/mld_test.go61
-rw-r--r--pkg/tcpip/header/ndp_test.go1748
-rw-r--r--pkg/tcpip/header/parse/BUILD15
-rw-r--r--pkg/tcpip/header/parse/parse_state_autogen.go3
-rw-r--r--pkg/tcpip/header/tcp_test.go168
-rw-r--r--pkg/tcpip/internal/tcp/BUILD12
-rw-r--r--pkg/tcpip/internal/tcp/tcp_state_autogen.go36
-rw-r--r--pkg/tcpip/link/channel/BUILD15
-rw-r--r--pkg/tcpip/link/channel/channel_state_autogen.go36
-rw-r--r--pkg/tcpip/link/ethernet/BUILD29
-rw-r--r--pkg/tcpip/link/ethernet/ethernet_state_autogen.go3
-rw-r--r--pkg/tcpip/link/ethernet/ethernet_test.go71
-rw-r--r--pkg/tcpip/link/fdbased/BUILD40
-rw-r--r--pkg/tcpip/link/fdbased/endpoint_test.go624
-rw-r--r--pkg/tcpip/link/fdbased/fdbased_state_autogen.go9
-rw-r--r--pkg/tcpip/link/fdbased/fdbased_unsafe_state_autogen.go7
-rw-r--r--pkg/tcpip/link/loopback/BUILD15
-rw-r--r--pkg/tcpip/link/loopback/loopback_state_autogen.go3
-rw-r--r--pkg/tcpip/link/muxed/BUILD29
-rw-r--r--pkg/tcpip/link/muxed/injectable_test.go101
-rw-r--r--pkg/tcpip/link/muxed/muxed_state_autogen.go3
-rw-r--r--pkg/tcpip/link/nested/BUILD31
-rw-r--r--pkg/tcpip/link/nested/nested_state_autogen.go3
-rw-r--r--pkg/tcpip/link/nested/nested_test.go109
-rw-r--r--pkg/tcpip/link/pipe/BUILD15
-rw-r--r--pkg/tcpip/link/pipe/pipe_state_autogen.go3
-rw-r--r--pkg/tcpip/link/qdisc/fifo/BUILD19
-rw-r--r--pkg/tcpip/link/qdisc/fifo/fifo_state_autogen.go3
-rw-r--r--pkg/tcpip/link/rawfile/BUILD33
-rw-r--r--pkg/tcpip/link/rawfile/errors_test.go55
-rw-r--r--pkg/tcpip/link/rawfile/rawfile_state_autogen.go6
-rw-r--r--pkg/tcpip/link/rawfile/rawfile_unsafe_state_autogen.go11
-rw-r--r--pkg/tcpip/link/sharedmem/BUILD43
-rw-r--r--pkg/tcpip/link/sharedmem/pipe/BUILD23
-rw-r--r--pkg/tcpip/link/sharedmem/pipe/pipe_state_autogen.go3
-rw-r--r--pkg/tcpip/link/sharedmem/pipe/pipe_test.go512
-rw-r--r--pkg/tcpip/link/sharedmem/pipe/pipe_unsafe_state_autogen.go3
-rw-r--r--pkg/tcpip/link/sharedmem/queue/BUILD27
-rw-r--r--pkg/tcpip/link/sharedmem/queue/queue_state_autogen.go3
-rw-r--r--pkg/tcpip/link/sharedmem/queue/queue_test.go517
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_state_autogen.go6
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_test.go815
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_unsafe_state_autogen.go3
-rw-r--r--pkg/tcpip/link/sniffer/BUILD21
-rw-r--r--pkg/tcpip/link/sniffer/sniffer_state_autogen.go3
-rw-r--r--pkg/tcpip/link/tun/BUILD42
-rw-r--r--pkg/tcpip/link/tun/tun_endpoint_refs.go140
-rw-r--r--pkg/tcpip/link/tun/tun_state_autogen.go68
-rw-r--r--pkg/tcpip/link/tun/tun_unsafe_state_autogen.go6
-rw-r--r--pkg/tcpip/link/waitable/BUILD30
-rw-r--r--pkg/tcpip/link/waitable/waitable_state_autogen.go3
-rw-r--r--pkg/tcpip/link/waitable/waitable_test.go187
-rw-r--r--pkg/tcpip/network/BUILD32
-rw-r--r--pkg/tcpip/network/arp/BUILD53
-rw-r--r--pkg/tcpip/network/arp/arp_state_autogen.go3
-rw-r--r--pkg/tcpip/network/arp/arp_test.go688
-rw-r--r--pkg/tcpip/network/arp/stats_test.go51
-rw-r--r--pkg/tcpip/network/hash/BUILD13
-rw-r--r--pkg/tcpip/network/hash/hash_state_autogen.go3
-rw-r--r--pkg/tcpip/network/internal/fragmentation/BUILD54
-rw-r--r--pkg/tcpip/network/internal/fragmentation/fragmentation_state_autogen.go68
-rw-r--r--pkg/tcpip/network/internal/fragmentation/fragmentation_test.go648
-rw-r--r--pkg/tcpip/network/internal/fragmentation/reassembler_list.go221
-rw-r--r--pkg/tcpip/network/internal/fragmentation/reassembler_test.go233
-rw-r--r--pkg/tcpip/network/internal/ip/BUILD40
-rw-r--r--pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go381
-rw-r--r--pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go808
-rw-r--r--pkg/tcpip/network/internal/ip/ip_state_autogen.go32
-rw-r--r--pkg/tcpip/network/internal/testutil/BUILD21
-rw-r--r--pkg/tcpip/network/internal/testutil/testutil.go134
-rw-r--r--pkg/tcpip/network/ip_test.go2147
-rw-r--r--pkg/tcpip/network/ipv4/BUILD67
-rw-r--r--pkg/tcpip/network/ipv4/igmp_test.go401
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_state_autogen.go113
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go3375
-rw-r--r--pkg/tcpip/network/ipv4/stats_test.go99
-rw-r--r--pkg/tcpip/network/ipv6/BUILD72
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go1758
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_state_autogen.go136
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go3523
-rw-r--r--pkg/tcpip/network/ipv6/mld_test.go620
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go1365
-rw-r--r--pkg/tcpip/network/multicast_group_test.go1285
-rw-r--r--pkg/tcpip/ports/BUILD28
-rw-r--r--pkg/tcpip/ports/ports_state_autogen.go42
-rw-r--r--pkg/tcpip/ports/ports_test.go525
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/BUILD21
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/main.go224
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/BUILD21
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/main.go235
-rw-r--r--pkg/tcpip/seqnum/BUILD9
-rw-r--r--pkg/tcpip/seqnum/seqnum_state_autogen.go3
-rw-r--r--pkg/tcpip/sock_err_list.go221
-rw-r--r--pkg/tcpip/stack/BUILD153
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state_test.go61
-rw-r--r--pkg/tcpip/stack/forwarding_test.go804
-rw-r--r--pkg/tcpip/stack/ndp_test.go5614
-rw-r--r--pkg/tcpip/stack/neighbor_cache_test.go1584
-rw-r--r--pkg/tcpip/stack/neighbor_entry_list.go221
-rw-r--r--pkg/tcpip/stack/neighbor_entry_test.go2269
-rw-r--r--pkg/tcpip/stack/nic_test.go219
-rw-r--r--pkg/tcpip/stack/nud_test.go816
-rw-r--r--pkg/tcpip/stack/packet_buffer_list.go221
-rw-r--r--pkg/tcpip/stack/packet_buffer_test.go675
-rw-r--r--pkg/tcpip/stack/stack_state_autogen.go1288
-rw-r--r--pkg/tcpip/stack/stack_test.go4671
-rw-r--r--pkg/tcpip/stack/stack_unsafe_state_autogen.go3
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go454
-rw-r--r--pkg/tcpip/stack/transport_test.go576
-rw-r--r--pkg/tcpip/stack/tuple_list.go221
-rw-r--r--pkg/tcpip/tcpip_state_autogen.go1361
-rw-r--r--pkg/tcpip/tcpip_test.go325
-rw-r--r--pkg/tcpip/tests/integration/BUILD141
-rw-r--r--pkg/tcpip/tests/integration/forward_test.go698
-rw-r--r--pkg/tcpip/tests/integration/iptables_test.go1158
-rw-r--r--pkg/tcpip/tests/integration/link_resolution_test.go1640
-rw-r--r--pkg/tcpip/tests/integration/loopback_test.go782
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go723
-rw-r--r--pkg/tcpip/tests/integration/route_test.go441
-rw-r--r--pkg/tcpip/tests/utils/BUILD22
-rw-r--r--pkg/tcpip/tests/utils/utils.go390
-rw-r--r--pkg/tcpip/testutil/BUILD21
-rw-r--r--pkg/tcpip/testutil/testutil.go123
-rw-r--r--pkg/tcpip/testutil/testutil_test.go103
-rw-r--r--pkg/tcpip/testutil/testutil_unsafe.go26
-rw-r--r--pkg/tcpip/timer_test.go353
-rw-r--r--pkg/tcpip/transport/BUILD13
-rw-r--r--pkg/tcpip/transport/icmp/BUILD59
-rw-r--r--pkg/tcpip/transport/icmp/icmp_packet_list.go221
-rw-r--r--pkg/tcpip/transport/icmp/icmp_state_autogen.go170
-rw-r--r--pkg/tcpip/transport/icmp/icmp_test.go239
-rw-r--r--pkg/tcpip/transport/internal/network/BUILD45
-rw-r--r--pkg/tcpip/transport/internal/network/endpoint_test.go318
-rw-r--r--pkg/tcpip/transport/internal/network/network_state_autogen.go110
-rw-r--r--pkg/tcpip/transport/packet/BUILD37
-rw-r--r--pkg/tcpip/transport/packet/packet_list.go221
-rw-r--r--pkg/tcpip/transport/packet/packet_state_autogen.go167
-rw-r--r--pkg/tcpip/transport/raw/BUILD41
-rw-r--r--pkg/tcpip/transport/raw/raw_packet_list.go221
-rw-r--r--pkg/tcpip/transport/raw/raw_state_autogen.go158
-rw-r--r--pkg/tcpip/transport/tcp/BUILD141
-rw-r--r--pkg/tcpip/transport/tcp/dual_stack_test.go650
-rw-r--r--pkg/tcpip/transport/tcp/rcv_test.go74
-rw-r--r--pkg/tcpip/transport/tcp/sack_scoreboard_test.go249
-rw-r--r--pkg/tcpip/transport/tcp/segment_test.go69
-rw-r--r--pkg/tcpip/transport/tcp/tcp_endpoint_list.go221
-rw-r--r--pkg/tcpip/transport/tcp/tcp_noracedetector_test.go559
-rw-r--r--pkg/tcpip/transport/tcp/tcp_rack_test.go1101
-rw-r--r--pkg/tcpip/transport/tcp/tcp_sack_test.go704
-rw-r--r--pkg/tcpip/transport/tcp/tcp_segment_list.go221
-rw-r--r--pkg/tcpip/transport/tcp/tcp_state_autogen.go901
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go8602
-rw-r--r--pkg/tcpip/transport/tcp/tcp_timestamp_test.go311
-rw-r--r--pkg/tcpip/transport/tcp/tcp_unsafe_state_autogen.go3
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/BUILD26
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go1268
-rw-r--r--pkg/tcpip/transport/tcp/timer_test.go50
-rw-r--r--pkg/tcpip/transport/tcpconntrack/BUILD23
-rw-r--r--pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go511
-rw-r--r--pkg/tcpip/transport/tcpconntrack/tcpconntrack_state_autogen.go3
-rw-r--r--pkg/tcpip/transport/transport_state_autogen.go3
-rw-r--r--pkg/tcpip/transport/udp/BUILD68
-rw-r--r--pkg/tcpip/transport/udp/udp_packet_list.go221
-rw-r--r--pkg/tcpip/transport/udp/udp_state_autogen.go194
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go2602
190 files changed, 7681 insertions, 73418 deletions
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
deleted file mode 100644
index dbe4506cc..000000000
--- a/pkg/tcpip/BUILD
+++ /dev/null
@@ -1,98 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-load("//tools:deps.bzl", "deps_test")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
-package(licenses = ["notice"])
-
-go_template_instance(
- name = "sock_err_list",
- out = "sock_err_list.go",
- package = "tcpip",
- prefix = "sockError",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*SockError",
- "Linker": "*SockError",
- },
-)
-
-go_library(
- name = "tcpip",
- srcs = [
- "errors.go",
- "sock_err_list.go",
- "socketops.go",
- "stdclock.go",
- "stdclock_state.go",
- "tcpip.go",
- "timer.go",
- ],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/atomicbitops",
- "//pkg/sync",
- "//pkg/tcpip/buffer",
- "//pkg/waiter",
- ],
-)
-
-deps_test(
- name = "netstack_deps_test",
- allowed = [
- # gVisor deps.
- "//pkg/atomicbitops",
- "//pkg/buffer",
- "//pkg/context",
- "//pkg/gohacks",
- "//pkg/goid",
- "//pkg/ilist",
- "//pkg/linewriter",
- "//pkg/log",
- "//pkg/rand",
- "//pkg/sleep",
- "//pkg/state",
- "//pkg/state/wire",
- "//pkg/sync",
- "//pkg/waiter",
-
- # Other deps.
- "@com_github_google_btree//:go_default_library",
- "@org_golang_x_sys//unix:go_default_library",
- "@org_golang_x_time//rate:go_default_library",
- ],
- allowed_prefixes = [
- "//pkg/tcpip",
- "@org_golang_x_sys//internal/unsafeheader",
- ],
- targets = [
- "//pkg/tcpip",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/fdbased",
- "//pkg/tcpip/link/loopback",
- "//pkg/tcpip/link/qdisc/fifo",
- "//pkg/tcpip/link/sniffer",
- "//pkg/tcpip/network/arp",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/transport/icmp",
- "//pkg/tcpip/transport/raw",
- "//pkg/tcpip/transport/tcp",
- "//pkg/tcpip/transport/udp",
- ],
-)
-
-go_test(
- name = "tcpip_test",
- size = "small",
- srcs = ["tcpip_test.go"],
- library = ":tcpip",
- deps = ["@com_github_google_go_cmp//cmp:go_default_library"],
-)
-
-go_test(
- name = "tcpip_x_test",
- size = "small",
- srcs = ["timer_test.go"],
- deps = [":tcpip"],
-)
diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD
deleted file mode 100644
index a984f1712..000000000
--- a/pkg/tcpip/adapters/gonet/BUILD
+++ /dev/null
@@ -1,37 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "gonet",
- srcs = ["gonet.go"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/transport/tcp",
- "//pkg/tcpip/transport/udp",
- "//pkg/waiter",
- ],
-)
-
-go_test(
- name = "gonet_test",
- size = "small",
- srcs = ["gonet_test.go"],
- library = ":gonet",
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/loopback",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/transport/tcp",
- "//pkg/tcpip/transport/udp",
- "//pkg/waiter",
- "@org_golang_x_net//nettest:go_default_library",
- ],
-)
diff --git a/pkg/tcpip/adapters/gonet/gonet_state_autogen.go b/pkg/tcpip/adapters/gonet/gonet_state_autogen.go
new file mode 100644
index 000000000..7a5c5419e
--- /dev/null
+++ b/pkg/tcpip/adapters/gonet/gonet_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package gonet
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
deleted file mode 100644
index c8460e63c..000000000
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ /dev/null
@@ -1,803 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package gonet
-
-import (
- "context"
- "fmt"
- "io"
- "net"
- "reflect"
- "strings"
- "testing"
- "time"
-
- "golang.org/x/net/nettest"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-const (
- NICID = 1
-)
-
-func TestTimeouts(t *testing.T) {
- nc := NewTCPConn(nil, nil)
- dlfs := []struct {
- name string
- f func(time.Time) error
- }{
- {"SetDeadline", nc.SetDeadline},
- {"SetReadDeadline", nc.SetReadDeadline},
- {"SetWriteDeadline", nc.SetWriteDeadline},
- }
-
- for _, dlf := range dlfs {
- if err := dlf.f(time.Time{}); err != nil {
- t.Errorf("got %s(time.Time{}) = %v, want = %v", dlf.name, err, nil)
- }
- }
-}
-
-func newLoopbackStack() (*stack.Stack, tcpip.Error) {
- // Create the stack and add a NIC.
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
- })
-
- if err := s.CreateNIC(NICID, loopback.New()); err != nil {
- return nil, err
- }
-
- // Add default route.
- s.SetRouteTable([]tcpip.Route{
- // IPv4
- {
- Destination: header.IPv4EmptySubnet,
- NIC: NICID,
- },
-
- // IPv6
- {
- Destination: header.IPv6EmptySubnet,
- NIC: NICID,
- },
- })
-
- return s, nil
-}
-
-type testConnection struct {
- wq *waiter.Queue
- e *waiter.Entry
- ch chan struct{}
- ep tcpip.Endpoint
-}
-
-func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, tcpip.Error) {
- wq := &waiter.Queue{}
- ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- return nil, err
- }
-
- entry, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&entry, waiter.WritableEvents)
-
- err = ep.Connect(addr)
- if _, ok := err.(*tcpip.ErrConnectStarted); ok {
- <-ch
- err = ep.LastError()
- }
- if err != nil {
- return nil, err
- }
-
- wq.EventUnregister(&entry)
- wq.EventRegister(&entry, waiter.ReadableEvents)
-
- return &testConnection{wq, &entry, ch, ep}, nil
-}
-
-func (c *testConnection) close() {
- c.wq.EventUnregister(c.e)
- c.ep.Close()
-}
-
-// TestCloseReader tests that Conn.Close() causes Conn.Read() to unblock.
-func TestCloseReader(t *testing.T) {
- s, err := newLoopbackStack()
- if err != nil {
- t.Fatalf("newLoopbackStack() = %v", err)
- }
- defer func() {
- s.Close()
- s.Wait()
- }()
-
- addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
-
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: addr.Addr.WithPrefix(),
- }
- if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
- }
-
- l, e := ListenTCP(s, addr, ipv4.ProtocolNumber)
- if e != nil {
- t.Fatalf("NewListener() = %v", e)
- }
- done := make(chan struct{})
- go func() {
- defer close(done)
- c, err := l.Accept()
- if err != nil {
- t.Errorf("l.Accept() = %v", err)
- // Cannot call Fatalf in goroutine. Just return from the goroutine.
- return
- }
-
- // Give c.Read() a chance to block before closing the connection.
- time.AfterFunc(time.Millisecond*50, func() {
- c.Close()
- })
-
- buf := make([]byte, 256)
- n, err := c.Read(buf)
- if n != 0 || err != io.EOF {
- t.Errorf("c.Read() = (%d, %v), want (0, EOF)", n, err)
- }
- }()
- sender, err := connect(s, addr)
- if err != nil {
- t.Fatalf("connect() = %v", err)
- }
-
- select {
- case <-done:
- case <-time.After(5 * time.Second):
- t.Errorf("c.Read() didn't unblock")
- }
- sender.close()
-}
-
-// TestCloseReaderWithForwarder tests that TCPConn.Close wakes TCPConn.Read when
-// using tcp.Forwarder.
-func TestCloseReaderWithForwarder(t *testing.T) {
- s, err := newLoopbackStack()
- if err != nil {
- t.Fatalf("newLoopbackStack() = %v", err)
- }
- defer func() {
- s.Close()
- s.Wait()
- }()
-
- addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: addr.Addr.WithPrefix(),
- }
- if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
- }
-
- done := make(chan struct{})
-
- fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
- defer close(done)
-
- var wq waiter.Queue
- ep, err := r.CreateEndpoint(&wq)
- if err != nil {
- t.Fatalf("r.CreateEndpoint() = %v", err)
- }
- defer ep.Close()
- r.Complete(false)
-
- c := NewTCPConn(&wq, ep)
-
- // Give c.Read() a chance to block before closing the connection.
- time.AfterFunc(time.Millisecond*50, func() {
- c.Close()
- })
-
- buf := make([]byte, 256)
- n, e := c.Read(buf)
- if n != 0 || e != io.EOF {
- t.Errorf("c.Read() = (%d, %v), want (0, EOF)", n, e)
- }
- })
- s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket)
-
- sender, err := connect(s, addr)
- if err != nil {
- t.Fatalf("connect() = %v", err)
- }
-
- select {
- case <-done:
- case <-time.After(5 * time.Second):
- t.Errorf("c.Read() didn't unblock")
- }
- sender.close()
-}
-
-func TestCloseRead(t *testing.T) {
- s, terr := newLoopbackStack()
- if terr != nil {
- t.Fatalf("newLoopbackStack() = %v", terr)
- }
- defer func() {
- s.Close()
- s.Wait()
- }()
-
- addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: addr.Addr.WithPrefix(),
- }
- if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
- }
-
- fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
- var wq waiter.Queue
- _, err := r.CreateEndpoint(&wq)
- if err != nil {
- t.Fatalf("r.CreateEndpoint() = %v", err)
- }
- // Endpoint will be closed in deferred s.Close (above).
- })
-
- s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket)
-
- tc, terr := connect(s, addr)
- if terr != nil {
- t.Fatalf("connect() = %v", terr)
- }
- c := NewTCPConn(tc.wq, tc.ep)
-
- if err := c.CloseRead(); err != nil {
- t.Errorf("c.CloseRead() = %v", err)
- }
-
- buf := make([]byte, 256)
- if n, err := c.Read(buf); err != io.EOF {
- t.Errorf("c.Read() = (%d, %v), want (0, io.EOF)", n, err)
- }
-
- if n, err := c.Write([]byte("abc123")); n != 6 || err != nil {
- t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, err)
- }
-}
-
-func TestCloseWrite(t *testing.T) {
- s, terr := newLoopbackStack()
- if terr != nil {
- t.Fatalf("newLoopbackStack() = %v", terr)
- }
- defer func() {
- s.Close()
- s.Wait()
- }()
-
- addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: addr.Addr.WithPrefix(),
- }
- if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
- }
-
- fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
- var wq waiter.Queue
- ep, err := r.CreateEndpoint(&wq)
- if err != nil {
- t.Fatalf("r.CreateEndpoint() = %v", err)
- }
- defer ep.Close()
- r.Complete(false)
-
- c := NewTCPConn(&wq, ep)
-
- n, e := c.Read(make([]byte, 256))
- if n != 0 || e != io.EOF {
- t.Errorf("c.Read() = (%d, %v), want (0, io.EOF)", n, e)
- }
-
- if n, e = c.Write([]byte("abc123")); n != 6 || e != nil {
- t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, e)
- }
- })
-
- s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket)
-
- tc, terr := connect(s, addr)
- if terr != nil {
- t.Fatalf("connect() = %v", terr)
- }
- c := NewTCPConn(tc.wq, tc.ep)
-
- if err := c.CloseWrite(); err != nil {
- t.Errorf("c.CloseWrite() = %v", err)
- }
-
- buf := make([]byte, 256)
- n, err := c.Read(buf)
- if err != nil || string(buf[:n]) != "abc123" {
- t.Fatalf("c.Read() = (%d, %v), want (6, nil)", n, err)
- }
-
- n, err = c.Write([]byte("abc123"))
- got, ok := err.(*net.OpError)
- want := "endpoint is closed for send"
- if n != 0 || !ok || got.Op != "write" || got.Err == nil || !strings.HasSuffix(got.Err.Error(), want) {
- t.Errorf("c.Write() = (%d, %v), want (0, OpError(Op: write, Err: %s))", n, err, want)
- }
-}
-
-func TestUDPForwarder(t *testing.T) {
- s, terr := newLoopbackStack()
- if terr != nil {
- t.Fatalf("newLoopbackStack() = %v", terr)
- }
- defer func() {
- s.Close()
- s.Wait()
- }()
-
- ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
- addr1 := tcpip.FullAddress{NICID, ip1, 11211}
- protocolAddr1 := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: ip1.WithPrefix(),
- }
- if err := s.AddProtocolAddress(NICID, protocolAddr1, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr1, err)
- }
- ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4())
- addr2 := tcpip.FullAddress{NICID, ip2, 11311}
- protocolAddr2 := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: ip2.WithPrefix(),
- }
- if err := s.AddProtocolAddress(NICID, protocolAddr2, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr2, err)
- }
-
- done := make(chan struct{})
- fwd := udp.NewForwarder(s, func(r *udp.ForwarderRequest) {
- defer close(done)
-
- var wq waiter.Queue
- ep, err := r.CreateEndpoint(&wq)
- if err != nil {
- t.Fatalf("r.CreateEndpoint() = %v", err)
- }
- defer ep.Close()
-
- c := NewTCPConn(&wq, ep)
-
- buf := make([]byte, 256)
- n, e := c.Read(buf)
- if e != nil {
- t.Errorf("c.Read() = %v", e)
- }
-
- if _, e := c.Write(buf[:n]); e != nil {
- t.Errorf("c.Write() = %v", e)
- }
- })
- s.SetTransportProtocolHandler(udp.ProtocolNumber, fwd.HandlePacket)
-
- c2, err := DialUDP(s, &addr2, nil, ipv4.ProtocolNumber)
- if err != nil {
- t.Fatal("DialUDP(bind port 5):", err)
- }
-
- sent := "abc123"
- sendAddr := fullToUDPAddr(addr1)
- if n, err := c2.WriteTo([]byte(sent), sendAddr); err != nil || n != len(sent) {
- t.Errorf("c1.WriteTo(%q, %v) = %d, %v, want = %d, %v", sent, sendAddr, n, err, len(sent), nil)
- }
-
- buf := make([]byte, 256)
- n, recvAddr, err := c2.ReadFrom(buf)
- if err != nil || recvAddr.String() != sendAddr.String() {
- t.Errorf("c1.ReadFrom() = %d, %v, %v, want = %d, %v, %v", n, recvAddr, err, len(sent), sendAddr, nil)
- }
-}
-
-// TestDeadlineChange tests that changing the deadline affects currently blocked reads.
-func TestDeadlineChange(t *testing.T) {
- s, err := newLoopbackStack()
- if err != nil {
- t.Fatalf("newLoopbackStack() = %v", err)
- }
- defer func() {
- s.Close()
- s.Wait()
- }()
-
- addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
-
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: addr.Addr.WithPrefix(),
- }
- if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
- }
-
- l, e := ListenTCP(s, addr, ipv4.ProtocolNumber)
- if e != nil {
- t.Fatalf("NewListener() = %v", e)
- }
- done := make(chan struct{})
- go func() {
- defer close(done)
- c, err := l.Accept()
- if err != nil {
- t.Errorf("l.Accept() = %v", err)
- // Cannot call Fatalf in goroutine. Just return from the goroutine.
- return
- }
-
- c.SetDeadline(time.Now().Add(time.Minute))
- // Give c.Read() a chance to block before closing the connection.
- time.AfterFunc(time.Millisecond*50, func() {
- c.SetDeadline(time.Now().Add(time.Millisecond * 10))
- })
-
- buf := make([]byte, 256)
- n, err := c.Read(buf)
- got, ok := err.(*net.OpError)
- want := "i/o timeout"
- if n != 0 || !ok || got.Err == nil || got.Err.Error() != want {
- t.Errorf("c.Read() = (%d, %v), want (0, OpError(%s))", n, err, want)
- }
- }()
- sender, err := connect(s, addr)
- if err != nil {
- t.Fatalf("connect() = %v", err)
- }
-
- select {
- case <-done:
- case <-time.After(time.Millisecond * 500):
- t.Errorf("c.Read() didn't unblock")
- }
- sender.close()
-}
-
-func TestPacketConnTransfer(t *testing.T) {
- s, e := newLoopbackStack()
- if e != nil {
- t.Fatalf("newLoopbackStack() = %v", e)
- }
- defer func() {
- s.Close()
- s.Wait()
- }()
-
- ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
- addr1 := tcpip.FullAddress{NICID, ip1, 11211}
- protocolAddr1 := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: ip1.WithPrefix(),
- }
- if err := s.AddProtocolAddress(NICID, protocolAddr1, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr1, err)
- }
- ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4())
- addr2 := tcpip.FullAddress{NICID, ip2, 11311}
- protocolAddr2 := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: ip2.WithPrefix(),
- }
- if err := s.AddProtocolAddress(NICID, protocolAddr2, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr2, err)
- }
-
- c1, err := DialUDP(s, &addr1, nil, ipv4.ProtocolNumber)
- if err != nil {
- t.Fatal("DialUDP(bind port 4):", err)
- }
- c2, err := DialUDP(s, &addr2, nil, ipv4.ProtocolNumber)
- if err != nil {
- t.Fatal("DialUDP(bind port 5):", err)
- }
-
- c1.SetDeadline(time.Now().Add(time.Second))
- c2.SetDeadline(time.Now().Add(time.Second))
-
- sent := "abc123"
- sendAddr := fullToUDPAddr(addr2)
- if n, err := c1.WriteTo([]byte(sent), sendAddr); err != nil || n != len(sent) {
- t.Errorf("got c1.WriteTo(%q, %v) = %d, %v, want = %d, %v", sent, sendAddr, n, err, len(sent), nil)
- }
- recv := make([]byte, len(sent))
- n, recvAddr, err := c2.ReadFrom(recv)
- if err != nil || n != len(recv) {
- t.Errorf("got c2.ReadFrom() = %d, %v, want = %d, %v", n, err, len(recv), nil)
- }
-
- if recv := string(recv); recv != sent {
- t.Errorf("got recv = %q, want = %q", recv, sent)
- }
-
- if want := fullToUDPAddr(addr1); !reflect.DeepEqual(recvAddr, want) {
- t.Errorf("got recvAddr = %v, want = %v", recvAddr, want)
- }
-
- if err := c1.Close(); err != nil {
- t.Error("c1.Close():", err)
- }
- if err := c2.Close(); err != nil {
- t.Error("c2.Close():", err)
- }
-}
-
-func TestConnectedPacketConnTransfer(t *testing.T) {
- s, e := newLoopbackStack()
- if e != nil {
- t.Fatalf("newLoopbackStack() = %v", e)
- }
- defer func() {
- s.Close()
- s.Wait()
- }()
-
- ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
- addr := tcpip.FullAddress{NICID, ip, 11211}
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: ip.WithPrefix(),
- }
- if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
- }
-
- c1, err := DialUDP(s, &addr, nil, ipv4.ProtocolNumber)
- if err != nil {
- t.Fatal("DialUDP(bind port 4):", err)
- }
- c2, err := DialUDP(s, nil, &addr, ipv4.ProtocolNumber)
- if err != nil {
- t.Fatal("DialUDP(bind port 5):", err)
- }
-
- c1.SetDeadline(time.Now().Add(time.Second))
- c2.SetDeadline(time.Now().Add(time.Second))
-
- sent := "abc123"
- if n, err := c2.Write([]byte(sent)); err != nil || n != len(sent) {
- t.Errorf("got c2.Write(%q) = %d, %v, want = %d, %v", sent, n, err, len(sent), nil)
- }
- recv := make([]byte, len(sent))
- n, err := c1.Read(recv)
- if err != nil || n != len(recv) {
- t.Errorf("got c1.Read() = %d, %v, want = %d, %v", n, err, len(recv), nil)
- }
-
- if recv := string(recv); recv != sent {
- t.Errorf("got recv = %q, want = %q", recv, sent)
- }
-
- if err := c1.Close(); err != nil {
- t.Error("c1.Close():", err)
- }
- if err := c2.Close(); err != nil {
- t.Error("c2.Close():", err)
- }
-}
-
-func makePipe() (c1, c2 net.Conn, stop func(), err error) {
- s, e := newLoopbackStack()
- if e != nil {
- return nil, nil, nil, fmt.Errorf("newLoopbackStack() = %v", e)
- }
-
- ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
- addr := tcpip.FullAddress{NICID, ip, 11211}
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: ip.WithPrefix(),
- }
- if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
- return nil, nil, nil, fmt.Errorf("AddProtocolAddress(%d, %+v, {}): %w", NICID, protocolAddr, err)
- }
-
- l, err := ListenTCP(s, addr, ipv4.ProtocolNumber)
- if err != nil {
- return nil, nil, nil, fmt.Errorf("NewListener: %w", err)
- }
-
- c1, err = DialTCP(s, addr, ipv4.ProtocolNumber)
- if err != nil {
- l.Close()
- return nil, nil, nil, fmt.Errorf("DialTCP: %w", err)
- }
-
- c2, err = l.Accept()
- if err != nil {
- l.Close()
- c1.Close()
- return nil, nil, nil, fmt.Errorf("l.Accept: %w", err)
- }
-
- stop = func() {
- c1.Close()
- c2.Close()
- s.Close()
- s.Wait()
- }
-
- if err := l.Close(); err != nil {
- stop()
- return nil, nil, nil, fmt.Errorf("l.Close(): %w", err)
- }
-
- return c1, c2, stop, nil
-}
-
-func TestTCPConnTransfer(t *testing.T) {
- c1, c2, _, err := makePipe()
- if err != nil {
- t.Fatal(err)
- }
- defer func() {
- if err := c1.Close(); err != nil {
- t.Error("c1.Close():", err)
- }
- if err := c2.Close(); err != nil {
- t.Error("c2.Close():", err)
- }
- }()
-
- c1.SetDeadline(time.Now().Add(time.Second))
- c2.SetDeadline(time.Now().Add(time.Second))
-
- const sent = "abc123"
-
- tests := []struct {
- name string
- c1 net.Conn
- c2 net.Conn
- }{
- {"connected to accepted", c1, c2},
- {"accepted to connected", c2, c1},
- }
-
- for _, test := range tests {
- if n, err := test.c1.Write([]byte(sent)); err != nil || n != len(sent) {
- t.Errorf("%s: got test.c1.Write(%q) = %d, %v, want = %d, %v", test.name, sent, n, err, len(sent), nil)
- continue
- }
-
- recv := make([]byte, len(sent))
- n, err := test.c2.Read(recv)
- if err != nil || n != len(recv) {
- t.Errorf("%s: got test.c2.Read() = %d, %v, want = %d, %v", test.name, n, err, len(recv), nil)
- continue
- }
-
- if recv := string(recv); recv != sent {
- t.Errorf("%s: got recv = %q, want = %q", test.name, recv, sent)
- }
- }
-}
-
-func TestTCPDialError(t *testing.T) {
- s, e := newLoopbackStack()
- if e != nil {
- t.Fatalf("newLoopbackStack() = %v", e)
- }
- defer func() {
- s.Close()
- s.Wait()
- }()
-
- ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
- addr := tcpip.FullAddress{NICID, ip, 11211}
-
- switch _, err := DialTCP(s, addr, ipv4.ProtocolNumber); err := err.(type) {
- case *net.OpError:
- if err.Err.Error() != (&tcpip.ErrNoRoute{}).String() {
- t.Errorf("got DialTCP() = %s, want = %s", err, &tcpip.ErrNoRoute{})
- }
- default:
- t.Errorf("got DialTCP(...) = %v, want %s", err, &tcpip.ErrNoRoute{})
- }
-}
-
-func TestDialContextTCPCanceled(t *testing.T) {
- s, err := newLoopbackStack()
- if err != nil {
- t.Fatalf("newLoopbackStack() = %v", err)
- }
- defer func() {
- s.Close()
- s.Wait()
- }()
-
- addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: addr.Addr.WithPrefix(),
- }
- if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
- }
-
- ctx := context.Background()
- ctx, cancel := context.WithCancel(ctx)
- cancel()
-
- if _, err := DialContextTCP(ctx, s, addr, ipv4.ProtocolNumber); err != context.Canceled {
- t.Errorf("got DialContextTCP(...) = %v, want = %v", err, context.Canceled)
- }
-}
-
-func TestDialContextTCPTimeout(t *testing.T) {
- s, err := newLoopbackStack()
- if err != nil {
- t.Fatalf("newLoopbackStack() = %v", err)
- }
- defer func() {
- s.Close()
- s.Wait()
- }()
-
- addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: addr.Addr.WithPrefix(),
- }
- if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
- }
-
- fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
- time.Sleep(time.Second)
- r.Complete(true)
- })
- s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket)
-
- ctx := context.Background()
- ctx, cancel := context.WithDeadline(ctx, time.Now().Add(100*time.Millisecond))
- defer cancel()
-
- if _, err := DialContextTCP(ctx, s, addr, ipv4.ProtocolNumber); err != context.DeadlineExceeded {
- t.Errorf("got DialContextTCP(...) = %v, want = %v", err, context.DeadlineExceeded)
- }
-}
-
-func TestNetTest(t *testing.T) {
- nettest.TestConn(t, makePipe)
-}
diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD
deleted file mode 100644
index 23aa0ad05..000000000
--- a/pkg/tcpip/buffer/BUILD
+++ /dev/null
@@ -1,25 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "buffer",
- srcs = [
- "prependable.go",
- "view.go",
- "view_unsafe.go",
- ],
- visibility = ["//visibility:public"],
-)
-
-go_test(
- name = "buffer_x_test",
- size = "small",
- srcs = [
- "view_test.go",
- ],
- deps = [
- ":buffer",
- "//pkg/tcpip",
- ],
-)
diff --git a/pkg/tcpip/buffer/buffer_state_autogen.go b/pkg/tcpip/buffer/buffer_state_autogen.go
new file mode 100644
index 000000000..51bfbff8a
--- /dev/null
+++ b/pkg/tcpip/buffer/buffer_state_autogen.go
@@ -0,0 +1,39 @@
+// automatically generated by stateify.
+
+package buffer
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+func (vv *VectorisedView) StateTypeName() string {
+ return "pkg/tcpip/buffer.VectorisedView"
+}
+
+func (vv *VectorisedView) StateFields() []string {
+ return []string{
+ "views",
+ "size",
+ }
+}
+
+func (vv *VectorisedView) beforeSave() {}
+
+// +checklocksignore
+func (vv *VectorisedView) StateSave(stateSinkObject state.Sink) {
+ vv.beforeSave()
+ stateSinkObject.Save(0, &vv.views)
+ stateSinkObject.Save(1, &vv.size)
+}
+
+func (vv *VectorisedView) afterLoad() {}
+
+// +checklocksignore
+func (vv *VectorisedView) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &vv.views)
+ stateSourceObject.Load(1, &vv.size)
+}
+
+func init() {
+ state.Register((*VectorisedView)(nil))
+}
diff --git a/pkg/tcpip/buffer/buffer_unsafe_state_autogen.go b/pkg/tcpip/buffer/buffer_unsafe_state_autogen.go
new file mode 100644
index 000000000..5a5c40722
--- /dev/null
+++ b/pkg/tcpip/buffer/buffer_unsafe_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package buffer
diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go
deleted file mode 100644
index d296d9c2b..000000000
--- a/pkg/tcpip/buffer/view_test.go
+++ /dev/null
@@ -1,629 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package buffer_test contains tests for the buffer.VectorisedView type.
-package buffer_test
-
-import (
- "bytes"
- "io"
- "reflect"
- "testing"
- "unsafe"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
-)
-
-// copy returns a deep-copy of the vectorised view.
-func copyVV(vv buffer.VectorisedView) buffer.VectorisedView {
- views := make([]buffer.View, 0, len(vv.Views()))
- for _, v := range vv.Views() {
- views = append(views, append(buffer.View(nil), v...))
- }
- return buffer.NewVectorisedView(vv.Size(), views)
-}
-
-// vv is an helper to build buffer.VectorisedView from different strings.
-func vv(size int, pieces ...string) buffer.VectorisedView {
- views := make([]buffer.View, len(pieces))
- for i, p := range pieces {
- views[i] = []byte(p)
- }
-
- return buffer.NewVectorisedView(size, views)
-}
-
-// v returns a buffer.View containing piece.
-func v(piece string) buffer.View {
- return buffer.View(piece)
-}
-
-var capLengthTestCases = []struct {
- comment string
- in buffer.VectorisedView
- length int
- want buffer.VectorisedView
-}{
- {
- comment: "Simple case",
- in: vv(2, "12"),
- length: 1,
- want: vv(1, "1"),
- },
- {
- comment: "Case spanning across two Views",
- in: vv(4, "123", "4"),
- length: 2,
- want: vv(2, "12"),
- },
- {
- comment: "Corner case with negative length",
- in: vv(1, "1"),
- length: -1,
- want: vv(0),
- },
- {
- comment: "Corner case with length = 0",
- in: vv(3, "12", "3"),
- length: 0,
- want: vv(0),
- },
- {
- comment: "Corner case with length = size",
- in: vv(1, "1"),
- length: 1,
- want: vv(1, "1"),
- },
- {
- comment: "Corner case with length > size",
- in: vv(1, "1"),
- length: 2,
- want: vv(1, "1"),
- },
-}
-
-func TestCapLength(t *testing.T) {
- for _, c := range capLengthTestCases {
- orig := copyVV(c.in)
- c.in.CapLength(c.length)
- if !reflect.DeepEqual(c.in, c.want) {
- t.Errorf("Test \"%s\" failed when calling CapLength(%d) on %v. Got %v. Want %v",
- c.comment, c.length, orig, c.in, c.want)
- }
- }
-}
-
-var trimFrontTestCases = []struct {
- comment string
- in buffer.VectorisedView
- count int
- want buffer.VectorisedView
-}{
- {
- comment: "Simple case",
- in: vv(2, "12"),
- count: 1,
- want: vv(1, "2"),
- },
- {
- comment: "Case where we trim an entire View",
- in: vv(2, "1", "2"),
- count: 1,
- want: vv(1, "2"),
- },
- {
- comment: "Case spanning across two Views",
- in: vv(3, "1", "23"),
- count: 2,
- want: vv(1, "3"),
- },
- {
- comment: "Case with one empty Views",
- in: vv(3, "1", "", "23"),
- count: 2,
- want: vv(1, "3"),
- },
- {
- comment: "Corner case with negative count",
- in: vv(1, "1"),
- count: -1,
- want: vv(1, "1"),
- },
- {
- comment: " Corner case with count = 0",
- in: vv(1, "1"),
- count: 0,
- want: vv(1, "1"),
- },
- {
- comment: "Corner case with count = size",
- in: vv(1, "1"),
- count: 1,
- want: vv(0),
- },
- {
- comment: "Corner case with count > size",
- in: vv(1, "1"),
- count: 2,
- want: vv(0),
- },
-}
-
-func TestTrimFront(t *testing.T) {
- for _, c := range trimFrontTestCases {
- orig := copyVV(c.in)
- c.in.TrimFront(c.count)
- if !reflect.DeepEqual(c.in, c.want) {
- t.Errorf("Test \"%s\" failed when calling TrimFront(%d) on %v. Got %v. Want %v",
- c.comment, c.count, orig, c.in, c.want)
- }
- }
-}
-
-var toViewCases = []struct {
- comment string
- in buffer.VectorisedView
- want buffer.View
-}{
- {
- comment: "Simple case",
- in: vv(2, "12"),
- want: []byte("12"),
- },
- {
- comment: "Case with multiple views",
- in: vv(2, "1", "2"),
- want: []byte("12"),
- },
- {
- comment: "Empty case",
- in: vv(0),
- want: []byte(""),
- },
-}
-
-func TestToView(t *testing.T) {
- for _, c := range toViewCases {
- got := c.in.ToView()
- if !reflect.DeepEqual(got, c.want) {
- t.Errorf("Test \"%s\" failed when calling ToView() on %v. Got %v. Want %v",
- c.comment, c.in, got, c.want)
- }
- }
-}
-
-var toCloneCases = []struct {
- comment string
- inView buffer.VectorisedView
- inBuffer []buffer.View
-}{
- {
- comment: "Simple case",
- inView: vv(1, "1"),
- inBuffer: make([]buffer.View, 1),
- },
- {
- comment: "Case with multiple views",
- inView: vv(2, "1", "2"),
- inBuffer: make([]buffer.View, 2),
- },
- {
- comment: "Case with buffer too small",
- inView: vv(2, "1", "2"),
- inBuffer: make([]buffer.View, 1),
- },
- {
- comment: "Case with buffer larger than needed",
- inView: vv(1, "1"),
- inBuffer: make([]buffer.View, 2),
- },
- {
- comment: "Case with nil buffer",
- inView: vv(1, "1"),
- inBuffer: nil,
- },
-}
-
-func TestToClone(t *testing.T) {
- for _, c := range toCloneCases {
- t.Run(c.comment, func(t *testing.T) {
- got := c.inView.Clone(c.inBuffer)
- if !reflect.DeepEqual(got, c.inView) {
- t.Fatalf("got (%+v).Clone(%+v) = %+v, want = %+v",
- c.inView, c.inBuffer, got, c.inView)
- }
- })
- }
-}
-
-type readToTestCases struct {
- comment string
- vv buffer.VectorisedView
- bytesToRead int
- wantBytes string
- leftVV buffer.VectorisedView
-}
-
-func createReadToTestCases() []readToTestCases {
- return []readToTestCases{
- {
- comment: "large VV, short read",
- vv: vv(30, "012345678901234567890123456789"),
- bytesToRead: 10,
- wantBytes: "0123456789",
- leftVV: vv(20, "01234567890123456789"),
- },
- {
- comment: "largeVV, multiple views, short read",
- vv: vv(13, "123", "345", "567", "8910"),
- bytesToRead: 6,
- wantBytes: "123345",
- leftVV: vv(7, "567", "8910"),
- },
- {
- comment: "smallVV (multiple views), large read",
- vv: vv(3, "1", "2", "3"),
- bytesToRead: 10,
- wantBytes: "123",
- leftVV: vv(0, ""),
- },
- {
- comment: "smallVV (single view), large read",
- vv: vv(1, "1"),
- bytesToRead: 10,
- wantBytes: "1",
- leftVV: vv(0, ""),
- },
- {
- comment: "emptyVV, large read",
- vv: vv(0, ""),
- bytesToRead: 10,
- wantBytes: "",
- leftVV: vv(0, ""),
- },
- }
-}
-
-func TestVVReadToVV(t *testing.T) {
- for _, tc := range createReadToTestCases() {
- t.Run(tc.comment, func(t *testing.T) {
- var readTo buffer.VectorisedView
- inSize := tc.vv.Size()
- copied := tc.vv.ReadToVV(&readTo, tc.bytesToRead)
- if got, want := copied, len(tc.wantBytes); got != want {
- t.Errorf("incorrect number of bytes copied returned in ReadToVV got: %d, want: %d, tc: %+v", got, want, tc)
- }
- if got, want := string(readTo.ToView()), tc.wantBytes; got != want {
- t.Errorf("unexpected content in readTo got: %s, want: %s", got, want)
- }
- if got, want := tc.vv.Size(), inSize-copied; got != want {
- t.Errorf("test VV has incorrect size after reading got: %d, want: %d, tc.vv: %+v", got, want, tc.vv)
- }
- if got, want := string(tc.vv.ToView()), string(tc.leftVV.ToView()); got != want {
- t.Errorf("unexpected data left in vv after read got: %+v, want: %+v", got, want)
- }
- })
- }
-}
-
-func TestVVReadTo(t *testing.T) {
- for _, tc := range createReadToTestCases() {
- t.Run(tc.comment, func(t *testing.T) {
- b := make([]byte, tc.bytesToRead)
- dst := tcpip.SliceWriter(b)
- origSize := tc.vv.Size()
- copied, err := tc.vv.ReadTo(&dst, false /* peek */)
- if err != nil && err != io.ErrShortWrite {
- t.Errorf("got ReadTo(&dst, false) = (_, %s); want nil or io.ErrShortWrite", err)
- }
- if got, want := copied, len(tc.wantBytes); got != want {
- t.Errorf("got ReadTo(&dst, false) = (%d, _); want %d", got, want)
- }
- if got, want := string(b[:copied]), tc.wantBytes; got != want {
- t.Errorf("got dst = %q, want %q", got, want)
- }
- if got, want := tc.vv.Size(), origSize-copied; got != want {
- t.Errorf("got after-read tc.vv.Size() = %d, want %d", got, want)
- }
- if got, want := string(tc.vv.ToView()), string(tc.leftVV.ToView()); got != want {
- t.Errorf("got after-read data in tc.vv = %q, want %q", got, want)
- }
- })
- }
-}
-
-func TestVVReadToPeek(t *testing.T) {
- for _, tc := range createReadToTestCases() {
- t.Run(tc.comment, func(t *testing.T) {
- b := make([]byte, tc.bytesToRead)
- dst := tcpip.SliceWriter(b)
- origSize := tc.vv.Size()
- origData := string(tc.vv.ToView())
- copied, err := tc.vv.ReadTo(&dst, true /* peek */)
- if err != nil && err != io.ErrShortWrite {
- t.Errorf("got ReadTo(&dst, true) = (_, %s); want nil or io.ErrShortWrite", err)
- }
- if got, want := copied, len(tc.wantBytes); got != want {
- t.Errorf("got ReadTo(&dst, true) = (%d, _); want %d", got, want)
- }
- if got, want := string(b[:copied]), tc.wantBytes; got != want {
- t.Errorf("got dst = %q, want %q", got, want)
- }
- // Expect tc.vv is unchanged.
- if got, want := tc.vv.Size(), origSize; got != want {
- t.Errorf("got after-read tc.vv.Size() = %d, want %d", got, want)
- }
- if got, want := string(tc.vv.ToView()), origData; got != want {
- t.Errorf("got after-read data in tc.vv = %q, want %q", got, want)
- }
- })
- }
-}
-
-func TestVVRead(t *testing.T) {
- testCases := []struct {
- comment string
- vv buffer.VectorisedView
- bytesToRead int
- readBytes string
- leftBytes string
- wantError bool
- }{
- {
- comment: "large VV, short read",
- vv: vv(30, "012345678901234567890123456789"),
- bytesToRead: 10,
- readBytes: "0123456789",
- leftBytes: "01234567890123456789",
- },
- {
- comment: "largeVV, multiple buffers, short read",
- vv: vv(13, "123", "345", "567", "8910"),
- bytesToRead: 6,
- readBytes: "123345",
- leftBytes: "5678910",
- },
- {
- comment: "smallVV, large read",
- vv: vv(3, "1", "2", "3"),
- bytesToRead: 10,
- readBytes: "123",
- leftBytes: "",
- },
- {
- comment: "smallVV, large read",
- vv: vv(1, "1"),
- bytesToRead: 10,
- readBytes: "1",
- leftBytes: "",
- },
- {
- comment: "emptyVV, large read",
- vv: vv(0, ""),
- bytesToRead: 10,
- readBytes: "",
- wantError: true,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.comment, func(t *testing.T) {
- readTo := buffer.NewView(tc.bytesToRead)
- inSize := tc.vv.Size()
- copied, err := tc.vv.Read(readTo)
- if !tc.wantError && err != nil {
- t.Fatalf("unexpected error in tc.vv.Read(..) = %s", err)
- }
- readTo = readTo[:copied]
- if got, want := copied, len(tc.readBytes); got != want {
- t.Errorf("incorrect number of bytes copied returned in ReadToVV got: %d, want: %d, tc.vv: %+v", got, want, tc.vv)
- }
- if got, want := string(readTo), tc.readBytes; got != want {
- t.Errorf("unexpected data in readTo got: %s, want: %s", got, want)
- }
- if got, want := tc.vv.Size(), inSize-copied; got != want {
- t.Errorf("test VV has incorrect size after reading got: %d, want: %d, tc.vv: %+v", got, want, tc.vv)
- }
- if got, want := string(tc.vv.ToView()), tc.leftBytes; got != want {
- t.Errorf("vv has incorrect data after Read got: %s, want: %s", got, want)
- }
- })
- }
-}
-
-var pullUpTestCases = []struct {
- comment string
- in buffer.VectorisedView
- count int
- want []byte
- result buffer.VectorisedView
- ok bool
-}{
- {
- comment: "simple case",
- in: vv(2, "12"),
- count: 1,
- want: []byte("1"),
- result: vv(2, "12"),
- ok: true,
- },
- {
- comment: "entire View",
- in: vv(2, "1", "2"),
- count: 1,
- want: []byte("1"),
- result: vv(2, "1", "2"),
- ok: true,
- },
- {
- comment: "spanning across two Views",
- in: vv(3, "1", "23"),
- count: 2,
- want: []byte("12"),
- result: vv(3, "12", "3"),
- ok: true,
- },
- {
- comment: "spanning across all Views",
- in: vv(5, "1", "23", "45"),
- count: 5,
- want: []byte("12345"),
- result: vv(5, "12345"),
- ok: true,
- },
- {
- comment: "count = 0",
- in: vv(1, "1"),
- count: 0,
- want: []byte{},
- result: vv(1, "1"),
- ok: true,
- },
- {
- comment: "count = size",
- in: vv(1, "1"),
- count: 1,
- want: []byte("1"),
- result: vv(1, "1"),
- ok: true,
- },
- {
- comment: "count too large",
- in: vv(3, "1", "23"),
- count: 4,
- want: nil,
- result: vv(3, "1", "23"),
- ok: false,
- },
- {
- comment: "empty vv",
- in: vv(0, ""),
- count: 1,
- want: nil,
- result: vv(0, ""),
- ok: false,
- },
- {
- comment: "empty vv, count = 0",
- in: vv(0, ""),
- count: 0,
- want: nil,
- result: vv(0, ""),
- ok: true,
- },
- {
- comment: "empty views",
- in: vv(3, "", "1", "", "23"),
- count: 2,
- want: []byte("12"),
- result: vv(3, "12", "3"),
- ok: true,
- },
-}
-
-func TestPullUp(t *testing.T) {
- for _, c := range pullUpTestCases {
- got, ok := c.in.PullUp(c.count)
-
- // Is the return value right?
- if ok != c.ok {
- t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got an ok of %t. Want %t",
- c.comment, c.count, c.in, ok, c.ok)
- }
- if bytes.Compare(got, buffer.View(c.want)) != 0 {
- t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got %v. Want %v",
- c.comment, c.count, c.in, got, c.want)
- }
-
- // Is the underlying structure right?
- if !reflect.DeepEqual(c.in, c.result) {
- t.Errorf("Test %q failed when calling PullUp(%d). Got vv with structure %v. Wanted %v",
- c.comment, c.count, c.in, c.result)
- }
- }
-}
-
-func TestToVectorisedView(t *testing.T) {
- testCases := []struct {
- in buffer.View
- want buffer.VectorisedView
- }{
- {nil, buffer.VectorisedView{}},
- {buffer.View{}, buffer.VectorisedView{}},
- {buffer.View{'a'}, buffer.NewVectorisedView(1, []buffer.View{{'a'}})},
- }
- for _, tc := range testCases {
- if got, want := tc.in.ToVectorisedView(), tc.want; !reflect.DeepEqual(got, want) {
- t.Errorf("(%v).ToVectorisedView failed got: %+v, want: %+v", tc.in, got, want)
- }
- }
-}
-
-func TestAppendView(t *testing.T) {
- testCases := []struct {
- vv buffer.VectorisedView
- in buffer.View
- want buffer.VectorisedView
- }{
- {vv(0), nil, vv(0)},
- {vv(0), v(""), vv(0)},
- {vv(4, "abcd"), nil, vv(4, "abcd")},
- {vv(4, "abcd"), v(""), vv(4, "abcd")},
- {vv(4, "abcd"), v("e"), vv(5, "abcd", "e")},
- }
- for _, tc := range testCases {
- tc.vv.AppendView(tc.in)
- if got, want := tc.vv, tc.want; !reflect.DeepEqual(got, want) {
- t.Errorf("(%v).ToVectorisedView failed got: %+v, want: %+v", tc.in, got, want)
- }
- }
-}
-
-func TestAppendViews(t *testing.T) {
- testCases := []struct {
- vv buffer.VectorisedView
- in []buffer.View
- want buffer.VectorisedView
- }{
- {vv(0), nil, vv(0)},
- {vv(0), []buffer.View{}, vv(0)},
- {vv(0), []buffer.View{v("")}, vv(0, "")},
- {vv(4, "abcd"), nil, vv(4, "abcd")},
- {vv(4, "abcd"), []buffer.View{}, vv(4, "abcd")},
- {vv(4, "abcd"), []buffer.View{v("")}, vv(4, "abcd", "")},
- {vv(4, "abcd"), []buffer.View{v("")}, vv(4, "abcd", "")},
- {vv(4, "abcd"), []buffer.View{v("e")}, vv(5, "abcd", "e")},
- {vv(4, "abcd"), []buffer.View{v("e"), v("fg")}, vv(7, "abcd", "e", "fg")},
- {vv(4, "abcd"), []buffer.View{v(""), v("fg")}, vv(6, "abcd", "", "fg")},
- }
- for _, tc := range testCases {
- tc.vv.AppendViews(tc.in)
- if got, want := tc.vv, tc.want; !reflect.DeepEqual(got, want) {
- t.Errorf("(%v).ToVectorisedView failed got: %+v, want: %+v", tc.in, got, want)
- }
- }
-}
-
-func TestMemSize(t *testing.T) {
- const perViewCap = 128
- views := make([]buffer.View, 2, 32)
- views[0] = make(buffer.View, 10, perViewCap)
- views[1] = make(buffer.View, 20, perViewCap)
- vv := buffer.NewVectorisedView(30, views)
- want := int(unsafe.Sizeof(vv)) + cap(views)*int(unsafe.Sizeof(views)) + 2*perViewCap
- if got := vv.MemSize(); got != want {
- t.Errorf("vv.MemSize() = %d, want %d", got, want)
- }
-}
diff --git a/pkg/tcpip/checker/BUILD b/pkg/tcpip/checker/BUILD
deleted file mode 100644
index c984470e6..000000000
--- a/pkg/tcpip/checker/BUILD
+++ /dev/null
@@ -1,17 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "checker",
- testonly = 1,
- srcs = ["checker.go"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/seqnum",
- "@com_github_google_go_cmp//cmp:go_default_library",
- ],
-)
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
deleted file mode 100644
index 24c2c3e6b..000000000
--- a/pkg/tcpip/checker/checker.go
+++ /dev/null
@@ -1,1638 +0,0 @@
-// Copyright 2021 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package checker provides helper functions to check networking packets for
-// validity.
-package checker
-
-import (
- "encoding/binary"
- "reflect"
- "testing"
- "time"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/seqnum"
-)
-
-// NetworkChecker is a function to check a property of a network packet.
-type NetworkChecker func(*testing.T, []header.Network)
-
-// TransportChecker is a function to check a property of a transport packet.
-type TransportChecker func(*testing.T, header.Transport)
-
-// ControlMessagesChecker is a function to check a property of ancillary data.
-type ControlMessagesChecker func(*testing.T, tcpip.ControlMessages)
-
-// IPv4 checks the validity and properties of the given IPv4 packet. It is
-// expected to be used in conjunction with other network checkers for specific
-// properties. For example, to check the source and destination address, one
-// would call:
-//
-// checker.IPv4(t, b, checker.SrcAddr(x), checker.DstAddr(y))
-func IPv4(t *testing.T, b []byte, checkers ...NetworkChecker) {
- t.Helper()
-
- ipv4 := header.IPv4(b)
-
- if !ipv4.IsValid(len(b)) {
- t.Fatalf("Not a valid IPv4 packet: %x", ipv4)
- }
-
- if !ipv4.IsChecksumValid() {
- t.Errorf("Bad checksum, got = %d", ipv4.Checksum())
- }
-
- for _, f := range checkers {
- f(t, []header.Network{ipv4})
- }
- if t.Failed() {
- t.FailNow()
- }
-}
-
-// IPv6 checks the validity and properties of the given IPv6 packet. The usage
-// is similar to IPv4.
-func IPv6(t *testing.T, b []byte, checkers ...NetworkChecker) {
- t.Helper()
-
- ipv6 := header.IPv6(b)
- if !ipv6.IsValid(len(b)) {
- t.Fatalf("Not a valid IPv6 packet: %x", ipv6)
- }
-
- for _, f := range checkers {
- f(t, []header.Network{ipv6})
- }
- if t.Failed() {
- t.FailNow()
- }
-}
-
-// SrcAddr creates a checker that checks the source address.
-func SrcAddr(addr tcpip.Address) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- if a := h[0].SourceAddress(); a != addr {
- t.Errorf("Bad source address, got %v, want %v", a, addr)
- }
- }
-}
-
-// DstAddr creates a checker that checks the destination address.
-func DstAddr(addr tcpip.Address) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- if a := h[0].DestinationAddress(); a != addr {
- t.Errorf("Bad destination address, got %v, want %v", a, addr)
- }
- }
-}
-
-// TTL creates a checker that checks the TTL (ipv4) or HopLimit (ipv6).
-func TTL(ttl uint8) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- var v uint8
- switch ip := h[0].(type) {
- case header.IPv4:
- v = ip.TTL()
- case header.IPv6:
- v = ip.HopLimit()
- case *ipv6HeaderWithExtHdr:
- v = ip.HopLimit()
- default:
- t.Fatalf("unrecognized header type %T for TTL evaluation", ip)
- }
- if v != ttl {
- t.Fatalf("Bad TTL, got = %d, want = %d", v, ttl)
- }
- }
-}
-
-// IPFullLength creates a checker for the full IP packet length. The
-// expected size is checked against both the Total Length in the
-// header and the number of bytes received.
-func IPFullLength(packetLength uint16) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- var v uint16
- var l uint16
- switch ip := h[0].(type) {
- case header.IPv4:
- v = ip.TotalLength()
- l = uint16(len(ip))
- case header.IPv6:
- v = ip.PayloadLength() + header.IPv6FixedHeaderSize
- l = uint16(len(ip))
- default:
- t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4 or header.IPv6", ip)
- }
- if l != packetLength {
- t.Errorf("bad packet length, got = %d, want = %d", l, packetLength)
- }
- if v != packetLength {
- t.Errorf("unexpected packet length in header, got = %d, want = %d", v, packetLength)
- }
- }
-}
-
-// IPv4HeaderLength creates a checker that checks the IPv4 Header length.
-func IPv4HeaderLength(headerLength int) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- switch ip := h[0].(type) {
- case header.IPv4:
- if hl := ip.HeaderLength(); hl != uint8(headerLength) {
- t.Errorf("Bad header length, got = %d, want = %d", hl, headerLength)
- }
- default:
- t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", ip)
- }
- }
-}
-
-// PayloadLen creates a checker that checks the payload length.
-func PayloadLen(payloadLength int) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- if l := len(h[0].Payload()); l != payloadLength {
- t.Errorf("Bad payload length, got = %d, want = %d", l, payloadLength)
- }
- }
-}
-
-// IPPayload creates a checker that checks the payload.
-func IPPayload(payload []byte) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- got := h[0].Payload()
-
- // cmp.Diff does not consider nil slices equal to empty slices, but we do.
- if len(got) == 0 && len(payload) == 0 {
- return
- }
-
- if diff := cmp.Diff(payload, got); diff != "" {
- t.Errorf("payload mismatch (-want +got):\n%s", diff)
- }
- }
-}
-
-// IPv4Options returns a checker that checks the options in an IPv4 packet.
-func IPv4Options(want header.IPv4Options) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- ip, ok := h[0].(header.IPv4)
- if !ok {
- t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0])
- }
- options := ip.Options()
- // cmp.Diff does not consider nil slices equal to empty slices, but we do.
- if len(want) == 0 && len(options) == 0 {
- return
- }
- if diff := cmp.Diff(want, options); diff != "" {
- t.Errorf("options mismatch (-want +got):\n%s", diff)
- }
- }
-}
-
-// IPv4RouterAlert returns a checker that checks that the RouterAlert option is
-// set in an IPv4 packet.
-func IPv4RouterAlert() NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
- ip, ok := h[0].(header.IPv4)
- if !ok {
- t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0])
- }
- iterator := ip.Options().MakeIterator()
- for {
- opt, done, err := iterator.Next()
- if err != nil {
- t.Fatalf("error acquiring next IPv4 option at offset %d", err.Pointer)
- }
- if done {
- break
- }
- if opt.Type() != header.IPv4OptionRouterAlertType {
- continue
- }
- want := [header.IPv4OptionRouterAlertLength]byte{
- byte(header.IPv4OptionRouterAlertType),
- header.IPv4OptionRouterAlertLength,
- header.IPv4OptionRouterAlertValue,
- header.IPv4OptionRouterAlertValue,
- }
- if diff := cmp.Diff(want[:], opt.Contents()); diff != "" {
- t.Errorf("router alert option mismatch (-want +got):\n%s", diff)
- }
- return
- }
- t.Errorf("failed to find router alert option in %v", ip.Options())
- }
-}
-
-// FragmentOffset creates a checker that checks the FragmentOffset field.
-func FragmentOffset(offset uint16) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- // We only do this for IPv4 for now.
- switch ip := h[0].(type) {
- case header.IPv4:
- if v := ip.FragmentOffset(); v != offset {
- t.Errorf("Bad fragment offset, got = %d, want = %d", v, offset)
- }
- }
- }
-}
-
-// FragmentFlags creates a checker that checks the fragment flags field.
-func FragmentFlags(flags uint8) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- // We only do this for IPv4 for now.
- switch ip := h[0].(type) {
- case header.IPv4:
- if v := ip.Flags(); v != flags {
- t.Errorf("Bad fragment offset, got = %d, want = %d", v, flags)
- }
- }
- }
-}
-
-// ReceiveTClass creates a checker that checks the TCLASS field in
-// ControlMessages.
-func ReceiveTClass(want uint32) ControlMessagesChecker {
- return func(t *testing.T, cm tcpip.ControlMessages) {
- t.Helper()
- if !cm.HasTClass {
- t.Errorf("got cm.HasTClass = %t, want = true", cm.HasTClass)
- } else if got := cm.TClass; got != want {
- t.Errorf("got cm.TClass = %d, want %d", got, want)
- }
- }
-}
-
-// ReceiveTOS creates a checker that checks the TOS field in ControlMessages.
-func ReceiveTOS(want uint8) ControlMessagesChecker {
- return func(t *testing.T, cm tcpip.ControlMessages) {
- t.Helper()
- if !cm.HasTOS {
- t.Errorf("got cm.HasTOS = %t, want = true", cm.HasTOS)
- } else if got := cm.TOS; got != want {
- t.Errorf("got cm.TOS = %d, want %d", got, want)
- }
- }
-}
-
-// ReceiveIPPacketInfo creates a checker that checks the PacketInfo field in
-// ControlMessages.
-func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker {
- return func(t *testing.T, cm tcpip.ControlMessages) {
- t.Helper()
- if !cm.HasIPPacketInfo {
- t.Errorf("got cm.HasIPPacketInfo = %t, want = true", cm.HasIPPacketInfo)
- } else if diff := cmp.Diff(want, cm.PacketInfo); diff != "" {
- t.Errorf("IPPacketInfo mismatch (-want +got):\n%s", diff)
- }
- }
-}
-
-// ReceiveIPv6PacketInfo creates a checker that checks the IPv6PacketInfo field
-// in ControlMessages.
-func ReceiveIPv6PacketInfo(want tcpip.IPv6PacketInfo) ControlMessagesChecker {
- return func(t *testing.T, cm tcpip.ControlMessages) {
- t.Helper()
- if !cm.HasIPv6PacketInfo {
- t.Errorf("got cm.HasIPv6PacketInfo = %t, want = true", cm.HasIPv6PacketInfo)
- } else if diff := cmp.Diff(want, cm.IPv6PacketInfo); diff != "" {
- t.Errorf("IPv6PacketInfo mismatch (-want +got):\n%s", diff)
- }
- }
-}
-
-// ReceiveOriginalDstAddr creates a checker that checks the OriginalDstAddress
-// field in ControlMessages.
-func ReceiveOriginalDstAddr(want tcpip.FullAddress) ControlMessagesChecker {
- return func(t *testing.T, cm tcpip.ControlMessages) {
- t.Helper()
- if !cm.HasOriginalDstAddress {
- t.Errorf("got cm.HasOriginalDstAddress = %t, want = true", cm.HasOriginalDstAddress)
- } else if diff := cmp.Diff(want, cm.OriginalDstAddress); diff != "" {
- t.Errorf("OriginalDstAddress mismatch (-want +got):\n%s", diff)
- }
- }
-}
-
-// TOS creates a checker that checks the TOS field.
-func TOS(tos uint8, label uint32) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- if v, l := h[0].TOS(); v != tos || l != label {
- t.Errorf("Bad TOS, got = (%d, %d), want = (%d,%d)", v, l, tos, label)
- }
- }
-}
-
-// Raw creates a checker that checks the bytes of payload.
-// The checker always checks the payload of the last network header.
-// For instance, in case of IPv6 fragments, the payload that will be checked
-// is the one containing the actual data that the packet is carrying, without
-// the bytes added by the IPv6 fragmentation.
-func Raw(want []byte) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- if got := h[len(h)-1].Payload(); !reflect.DeepEqual(got, want) {
- t.Errorf("Wrong payload, got %v, want %v", got, want)
- }
- }
-}
-
-// IPv6Fragment creates a checker that validates an IPv6 fragment.
-func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- if p := h[0].TransportProtocol(); p != header.IPv6FragmentHeader {
- t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber)
- }
-
- ipv6Frag := header.IPv6Fragment(h[0].Payload())
- if !ipv6Frag.IsValid() {
- t.Error("Not a valid IPv6 fragment")
- }
-
- for _, f := range checkers {
- f(t, []header.Network{h[0], ipv6Frag})
- }
- if t.Failed() {
- t.FailNow()
- }
- }
-}
-
-// TCP creates a checker that checks that the transport protocol is TCP and
-// potentially additional transport header fields.
-func TCP(checkers ...TransportChecker) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- first := h[0]
- last := h[len(h)-1]
-
- if p := last.TransportProtocol(); p != header.TCPProtocolNumber {
- t.Errorf("Bad protocol, got = %d, want = %d", p, header.TCPProtocolNumber)
- }
-
- tcp := header.TCP(last.Payload())
- payload := tcp.Payload()
- payloadChecksum := header.Checksum(payload, 0)
- if !tcp.IsChecksumValid(first.SourceAddress(), first.DestinationAddress(), payloadChecksum, uint16(len(payload))) {
- t.Errorf("Bad checksum, got = %d", tcp.Checksum())
- }
-
- // Run the transport checkers.
- for _, f := range checkers {
- f(t, tcp)
- }
- if t.Failed() {
- t.FailNow()
- }
- }
-}
-
-// UDP creates a checker that checks that the transport protocol is UDP and
-// potentially additional transport header fields.
-func UDP(checkers ...TransportChecker) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- last := h[len(h)-1]
-
- if p := last.TransportProtocol(); p != header.UDPProtocolNumber {
- t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber)
- }
-
- udp := header.UDP(last.Payload())
- for _, f := range checkers {
- f(t, udp)
- }
- if t.Failed() {
- t.FailNow()
- }
- }
-}
-
-// SrcPort creates a checker that checks the source port.
-func SrcPort(port uint16) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- if p := h.SourcePort(); p != port {
- t.Errorf("Bad source port, got = %d, want = %d", p, port)
- }
- }
-}
-
-// DstPort creates a checker that checks the destination port.
-func DstPort(port uint16) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- if p := h.DestinationPort(); p != port {
- t.Errorf("Bad destination port, got = %d, want = %d", p, port)
- }
- }
-}
-
-// NoChecksum creates a checker that checks if the checksum is zero.
-func NoChecksum(noChecksum bool) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- udp, ok := h.(header.UDP)
- if !ok {
- t.Fatalf("UDP header not found in h: %T", h)
- }
-
- if b := udp.Checksum() == 0; b != noChecksum {
- t.Errorf("bad checksum state, got %t, want %t", b, noChecksum)
- }
- }
-}
-
-// TCPSeqNum creates a checker that checks the sequence number.
-func TCPSeqNum(seq uint32) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- tcp, ok := h.(header.TCP)
- if !ok {
- t.Fatalf("TCP header not found in h: %T", h)
- }
-
- if s := tcp.SequenceNumber(); s != seq {
- t.Errorf("Bad sequence number, got = %d, want = %d", s, seq)
- }
- }
-}
-
-// TCPAckNum creates a checker that checks the ack number.
-func TCPAckNum(seq uint32) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- tcp, ok := h.(header.TCP)
- if !ok {
- t.Fatalf("TCP header not found in h: %T", h)
- }
-
- if s := tcp.AckNumber(); s != seq {
- t.Errorf("Bad ack number, got = %d, want = %d", s, seq)
- }
- }
-}
-
-// TCPWindow creates a checker that checks the tcp window.
-func TCPWindow(window uint16) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- tcp, ok := h.(header.TCP)
- if !ok {
- t.Fatalf("TCP header not found in hdr : %T", h)
- }
-
- if w := tcp.WindowSize(); w != window {
- t.Errorf("Bad window, got %d, want %d", w, window)
- }
- }
-}
-
-// TCPWindowGreaterThanEq creates a checker that checks that the TCP window
-// is greater than or equal to the provided value.
-func TCPWindowGreaterThanEq(window uint16) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- tcp, ok := h.(header.TCP)
- if !ok {
- t.Fatalf("TCP header not found in h: %T", h)
- }
-
- if w := tcp.WindowSize(); w < window {
- t.Errorf("Bad window, got %d, want > %d", w, window)
- }
- }
-}
-
-// TCPWindowLessThanEq creates a checker that checks that the tcp window
-// is less than or equal to the provided value.
-func TCPWindowLessThanEq(window uint16) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- tcp, ok := h.(header.TCP)
- if !ok {
- t.Fatalf("TCP header not found in h: %T", h)
- }
-
- if w := tcp.WindowSize(); w > window {
- t.Errorf("Bad window, got %d, want < %d", w, window)
- }
- }
-}
-
-// TCPFlags creates a checker that checks the tcp flags.
-func TCPFlags(flags header.TCPFlags) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- tcp, ok := h.(header.TCP)
- if !ok {
- t.Fatalf("TCP header not found in h: %T", h)
- }
-
- if got := tcp.Flags(); got != flags {
- t.Errorf("got tcp.Flags() = %s, want %s", got, flags)
- }
- }
-}
-
-// TCPFlagsMatch creates a checker that checks that the tcp flags, masked by the
-// given mask, match the supplied flags.
-func TCPFlagsMatch(flags, mask header.TCPFlags) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- tcp, ok := h.(header.TCP)
- if !ok {
- t.Fatalf("TCP header not found in h: %T", h)
- }
-
- if got := tcp.Flags(); (got & mask) != (flags & mask) {
- t.Errorf("got tcp.Flags() = %s, want %s, mask %s", got, flags, mask)
- }
- }
-}
-
-// TCPSynOptions creates a checker that checks the presence of TCP options in
-// SYN segments.
-//
-// If wndscale is negative, the window scale option must not be present.
-func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- tcp, ok := h.(header.TCP)
- if !ok {
- return
- }
- opts := tcp.Options()
- limit := len(opts)
- foundMSS := false
- foundWS := false
- foundTS := false
- foundSACKPermitted := false
- tsVal := uint32(0)
- tsEcr := uint32(0)
- for i := 0; i < limit; {
- switch opts[i] {
- case header.TCPOptionEOL:
- i = limit
- case header.TCPOptionNOP:
- i++
- case header.TCPOptionMSS:
- v := uint16(opts[i+2])<<8 | uint16(opts[i+3])
- if wantOpts.MSS != v {
- t.Errorf("Bad MSS, got = %d, want = %d", v, wantOpts.MSS)
- }
- foundMSS = true
- i += 4
- case header.TCPOptionWS:
- if wantOpts.WS < 0 {
- t.Error("WS present when it shouldn't be")
- }
- v := int(opts[i+2])
- if v != wantOpts.WS {
- t.Errorf("Bad WS, got = %d, want = %d", v, wantOpts.WS)
- }
- foundWS = true
- i += 3
- case header.TCPOptionTS:
- if i+9 >= limit {
- t.Errorf("TS Option truncated , option is only: %d bytes, want 10", limit-i)
- }
- if opts[i+1] != 10 {
- t.Errorf("Bad length %d for TS option, limit: %d", opts[i+1], limit)
- }
- tsVal = binary.BigEndian.Uint32(opts[i+2:])
- tsEcr = uint32(0)
- if tcp.Flags()&header.TCPFlagAck != 0 {
- // If the syn is an SYN-ACK then read
- // the tsEcr value as well.
- tsEcr = binary.BigEndian.Uint32(opts[i+6:])
- }
- foundTS = true
- i += 10
- case header.TCPOptionSACKPermitted:
- if i+1 >= limit {
- t.Errorf("SACKPermitted option truncated, option is only : %d bytes, want 2", limit-i)
- }
- if opts[i+1] != 2 {
- t.Errorf("Bad length %d for SACKPermitted option, limit: %d", opts[i+1], limit)
- }
- foundSACKPermitted = true
- i += 2
-
- default:
- i += int(opts[i+1])
- }
- }
-
- if !foundMSS {
- t.Errorf("MSS option not found. Options: %x", opts)
- }
-
- if !foundWS && wantOpts.WS >= 0 {
- t.Errorf("WS option not found. Options: %x", opts)
- }
- if wantOpts.TS && !foundTS {
- t.Errorf("TS option not found. Options: %x", opts)
- }
- if foundTS && tsVal == 0 {
- t.Error("TS option specified but the timestamp value is zero")
- }
- if foundTS && tsEcr == 0 && wantOpts.TSEcr != 0 {
- t.Errorf("TS option specified but TSEcr is incorrect, got = %d, want = %d", tsEcr, wantOpts.TSEcr)
- }
- if wantOpts.SACKPermitted && !foundSACKPermitted {
- t.Errorf("SACKPermitted option not found. Options: %x", opts)
- }
- }
-}
-
-// TCPTimestampChecker creates a checker that validates that a TCP segment has a
-// TCP Timestamp option if wantTS is true, it also compares the wantTSVal and
-// wantTSEcr values with those in the TCP segment (if present).
-//
-// If wantTSVal or wantTSEcr is zero then the corresponding comparison is
-// skipped.
-func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- tcp, ok := h.(header.TCP)
- if !ok {
- return
- }
- opts := tcp.Options()
- limit := len(opts)
- foundTS := false
- tsVal := uint32(0)
- tsEcr := uint32(0)
- for i := 0; i < limit; {
- switch opts[i] {
- case header.TCPOptionEOL:
- i = limit
- case header.TCPOptionNOP:
- i++
- case header.TCPOptionTS:
- if i+9 >= limit {
- t.Errorf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i)
- }
- if opts[i+1] != 10 {
- t.Errorf("TS option found, but bad length specified: got = %d, want = 10", opts[i+1])
- }
- tsVal = binary.BigEndian.Uint32(opts[i+2:])
- tsEcr = binary.BigEndian.Uint32(opts[i+6:])
- foundTS = true
- i += 10
- default:
- // We don't recognize this option, just skip over it.
- if i+2 > limit {
- return
- }
- l := int(opts[i+1])
- if l < 2 || i+l > limit {
- return
- }
- i += l
- }
- }
-
- if wantTS != foundTS {
- t.Errorf("TS Option mismatch, got TS= %t, want TS= %t", foundTS, wantTS)
- }
- if wantTS && wantTSVal != 0 && wantTSVal != tsVal {
- t.Errorf("Timestamp value is incorrect, got = %d, want = %d", tsVal, wantTSVal)
- }
- if wantTS && wantTSEcr != 0 && tsEcr != wantTSEcr {
- t.Errorf("Timestamp Echo Reply is incorrect, got = %d, want = %d", tsEcr, wantTSEcr)
- }
- }
-}
-
-// TCPSACKBlockChecker creates a checker that verifies that the segment does
-// contain the specified SACK blocks in the TCP options.
-func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
- tcp, ok := h.(header.TCP)
- if !ok {
- return
- }
- var gotSACKBlocks []header.SACKBlock
-
- opts := tcp.Options()
- limit := len(opts)
- for i := 0; i < limit; {
- switch opts[i] {
- case header.TCPOptionEOL:
- i = limit
- case header.TCPOptionNOP:
- i++
- case header.TCPOptionSACK:
- if i+2 > limit {
- // Malformed SACK block.
- t.Errorf("malformed SACK option in options: %v", opts)
- }
- sackOptionLen := int(opts[i+1])
- if i+sackOptionLen > limit || (sackOptionLen-2)%8 != 0 {
- // Malformed SACK block.
- t.Errorf("malformed SACK option length in options: %v", opts)
- }
- numBlocks := sackOptionLen / 8
- for j := 0; j < numBlocks; j++ {
- start := binary.BigEndian.Uint32(opts[i+2+j*8:])
- end := binary.BigEndian.Uint32(opts[i+2+j*8+4:])
- gotSACKBlocks = append(gotSACKBlocks, header.SACKBlock{
- Start: seqnum.Value(start),
- End: seqnum.Value(end),
- })
- }
- i += sackOptionLen
- default:
- // We don't recognize this option, just skip over it.
- if i+2 > limit {
- break
- }
- l := int(opts[i+1])
- if l < 2 || i+l > limit {
- break
- }
- i += l
- }
- }
-
- if !reflect.DeepEqual(gotSACKBlocks, sackBlocks) {
- t.Errorf("SACKBlocks are not equal, got = %v, want = %v", gotSACKBlocks, sackBlocks)
- }
- }
-}
-
-// Payload creates a checker that checks the payload.
-func Payload(want []byte) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- if got := h.Payload(); !reflect.DeepEqual(got, want) {
- t.Errorf("Wrong payload, got %v, want %v", got, want)
- }
- }
-}
-
-// ICMPv4 creates a checker that checks that the transport protocol is ICMPv4
-// and potentially additional ICMPv4 header fields.
-func ICMPv4(checkers ...TransportChecker) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- last := h[len(h)-1]
-
- if p := last.TransportProtocol(); p != header.ICMPv4ProtocolNumber {
- t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv4ProtocolNumber)
- }
-
- icmp := header.ICMPv4(last.Payload())
- for _, f := range checkers {
- f(t, icmp)
- }
- if t.Failed() {
- t.FailNow()
- }
- }
-}
-
-// ICMPv4Type creates a checker that checks the ICMPv4 Type field.
-func ICMPv4Type(want header.ICMPv4Type) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmpv4, ok := h.(header.ICMPv4)
- if !ok {
- t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
- }
- if got := icmpv4.Type(); got != want {
- t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want)
- }
- }
-}
-
-// ICMPv4Code creates a checker that checks the ICMPv4 Code field.
-func ICMPv4Code(want header.ICMPv4Code) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmpv4, ok := h.(header.ICMPv4)
- if !ok {
- t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
- }
- if got := icmpv4.Code(); got != want {
- t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want)
- }
- }
-}
-
-// ICMPv4Ident creates a checker that checks the ICMPv4 echo Ident.
-func ICMPv4Ident(want uint16) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmpv4, ok := h.(header.ICMPv4)
- if !ok {
- t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
- }
- if got := icmpv4.Ident(); got != want {
- t.Fatalf("unexpected ICMP ident, got = %d, want = %d", got, want)
- }
- }
-}
-
-// ICMPv4Seq creates a checker that checks the ICMPv4 echo Sequence.
-func ICMPv4Seq(want uint16) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmpv4, ok := h.(header.ICMPv4)
- if !ok {
- t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
- }
- if got := icmpv4.Sequence(); got != want {
- t.Fatalf("unexpected ICMP sequence, got = %d, want = %d", got, want)
- }
- }
-}
-
-// ICMPv4Pointer creates a checker that checks the ICMPv4 Param Problem pointer.
-func ICMPv4Pointer(want uint8) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmpv4, ok := h.(header.ICMPv4)
- if !ok {
- t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
- }
- if got := icmpv4.Pointer(); got != want {
- t.Fatalf("unexpected ICMP Param Problem pointer, got = %d, want = %d", got, want)
- }
- }
-}
-
-// ICMPv4Checksum creates a checker that checks the ICMPv4 Checksum.
-// This assumes that the payload exactly makes up the rest of the slice.
-func ICMPv4Checksum() TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmpv4, ok := h.(header.ICMPv4)
- if !ok {
- t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
- }
- heldChecksum := icmpv4.Checksum()
- icmpv4.SetChecksum(0)
- newChecksum := ^header.Checksum(icmpv4, 0)
- icmpv4.SetChecksum(heldChecksum)
- if heldChecksum != newChecksum {
- t.Errorf("unexpected ICMP checksum, got = %d, want = %d", heldChecksum, newChecksum)
- }
- }
-}
-
-// ICMPv4Payload creates a checker that checks the payload in an ICMPv4 packet.
-func ICMPv4Payload(want []byte) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmpv4, ok := h.(header.ICMPv4)
- if !ok {
- t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
- }
- payload := icmpv4.Payload()
-
- // cmp.Diff does not consider nil slices equal to empty slices, but we do.
- if len(want) == 0 && len(payload) == 0 {
- return
- }
-
- if diff := cmp.Diff(want, payload); diff != "" {
- t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff)
- }
- }
-}
-
-// ICMPv6 creates a checker that checks that the transport protocol is ICMPv6 and
-// potentially additional ICMPv6 header fields.
-//
-// ICMPv6 will validate the checksum field before calling checkers.
-func ICMPv6(checkers ...TransportChecker) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- last := h[len(h)-1]
-
- if p := last.TransportProtocol(); p != header.ICMPv6ProtocolNumber {
- t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv6ProtocolNumber)
- }
-
- icmp := header.ICMPv6(last.Payload())
- if got, want := icmp.Checksum(), header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: icmp,
- Src: last.SourceAddress(),
- Dst: last.DestinationAddress(),
- }); got != want {
- t.Fatalf("Bad ICMPv6 checksum; got %d, want %d", got, want)
- }
-
- for _, f := range checkers {
- f(t, icmp)
- }
- if t.Failed() {
- t.FailNow()
- }
- }
-}
-
-// ICMPv6Type creates a checker that checks the ICMPv6 Type field.
-func ICMPv6Type(want header.ICMPv6Type) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmpv6, ok := h.(header.ICMPv6)
- if !ok {
- t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
- }
- if got := icmpv6.Type(); got != want {
- t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want)
- }
- }
-}
-
-// ICMPv6Code creates a checker that checks the ICMPv6 Code field.
-func ICMPv6Code(want header.ICMPv6Code) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmpv6, ok := h.(header.ICMPv6)
- if !ok {
- t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
- }
- if got := icmpv6.Code(); got != want {
- t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want)
- }
- }
-}
-
-// ICMPv6TypeSpecific creates a checker that checks the ICMPv6 TypeSpecific
-// field.
-func ICMPv6TypeSpecific(want uint32) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmpv6, ok := h.(header.ICMPv6)
- if !ok {
- t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
- }
- if got := icmpv6.TypeSpecific(); got != want {
- t.Fatalf("unexpected ICMP TypeSpecific, got = %d, want = %d", got, want)
- }
- }
-}
-
-// ICMPv6Payload creates a checker that checks the payload in an ICMPv6 packet.
-func ICMPv6Payload(want []byte) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmpv6, ok := h.(header.ICMPv6)
- if !ok {
- t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
- }
- payload := icmpv6.Payload()
-
- // cmp.Diff does not consider nil slices equal to empty slices, but we do.
- if len(want) == 0 && len(payload) == 0 {
- return
- }
-
- if diff := cmp.Diff(want, payload); diff != "" {
- t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff)
- }
- }
-}
-
-// MLD creates a checker that checks that the packet contains a valid MLD
-// message for type of mldType, with potentially additional checks specified by
-// checkers.
-//
-// Checkers may assume that a valid ICMPv6 is passed to it containing a valid
-// MLD message as far as the size of the message (minSize) is concerned. The
-// values within the message are up to checkers to validate.
-func MLD(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- // Check normal ICMPv6 first.
- ICMPv6(
- ICMPv6Type(msgType),
- ICMPv6Code(0))(t, h)
-
- last := h[len(h)-1]
-
- icmp := header.ICMPv6(last.Payload())
- if got := len(icmp.MessageBody()); got < minSize {
- t.Fatalf("ICMPv6 MLD (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize)
- }
-
- for _, f := range checkers {
- f(t, icmp)
- }
- if t.Failed() {
- t.FailNow()
- }
- }
-}
-
-// MLDMaxRespDelay creates a checker that checks the Maximum Response Delay
-// field of a MLD message.
-//
-// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
-// containing a valid MLD message as far as the size is concerned.
-func MLDMaxRespDelay(want time.Duration) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmp := h.(header.ICMPv6)
- ns := header.MLD(icmp.MessageBody())
-
- if got := ns.MaximumResponseDelay(); got != want {
- t.Errorf("got %T.MaximumResponseDelay() = %s, want = %s", ns, got, want)
- }
- }
-}
-
-// MLDMulticastAddress creates a checker that checks the Multicast Address
-// field of a MLD message.
-//
-// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
-// containing a valid MLD message as far as the size is concerned.
-func MLDMulticastAddress(want tcpip.Address) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmp := h.(header.ICMPv6)
- ns := header.MLD(icmp.MessageBody())
-
- if got := ns.MulticastAddress(); got != want {
- t.Errorf("got %T.MulticastAddress() = %s, want = %s", ns, got, want)
- }
- }
-}
-
-// NDP creates a checker that checks that the packet contains a valid NDP
-// message for type of ty, with potentially additional checks specified by
-// checkers.
-//
-// Checkers may assume that a valid ICMPv6 is passed to it containing a valid
-// NDP message as far as the size of the message (minSize) is concerned. The
-// values within the message are up to checkers to validate.
-func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- // Check normal ICMPv6 first.
- ICMPv6(
- ICMPv6Type(msgType),
- ICMPv6Code(0))(t, h)
-
- last := h[len(h)-1]
-
- icmp := header.ICMPv6(last.Payload())
- if got := len(icmp.MessageBody()); got < minSize {
- t.Fatalf("ICMPv6 NDP (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize)
- }
-
- for _, f := range checkers {
- f(t, icmp)
- }
- if t.Failed() {
- t.FailNow()
- }
- }
-}
-
-// NDPNS creates a checker that checks that the packet contains a valid NDP
-// Neighbor Solicitation message (as per the raw wire format), with potentially
-// additional checks specified by checkers.
-//
-// Checkers may assume that a valid ICMPv6 is passed to it containing a valid
-// NDPNS message as far as the size of the message is concerned. The values
-// within the message are up to checkers to validate.
-func NDPNS(checkers ...TransportChecker) NetworkChecker {
- return NDP(header.ICMPv6NeighborSolicit, header.NDPNSMinimumSize, checkers...)
-}
-
-// NDPNSTargetAddress creates a checker that checks the Target Address field of
-// a header.NDPNeighborSolicit.
-//
-// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
-// containing a valid NDPNS message as far as the size is concerned.
-func NDPNSTargetAddress(want tcpip.Address) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmp := h.(header.ICMPv6)
- ns := header.NDPNeighborSolicit(icmp.MessageBody())
-
- if got := ns.TargetAddress(); got != want {
- t.Errorf("got %T.TargetAddress() = %s, want = %s", ns, got, want)
- }
- }
-}
-
-// NDPNA creates a checker that checks that the packet contains a valid NDP
-// Neighbor Advertisement message (as per the raw wire format), with potentially
-// additional checks specified by checkers.
-//
-// Checkers may assume that a valid ICMPv6 is passed to it containing a valid
-// NDPNA message as far as the size of the message is concerned. The values
-// within the message are up to checkers to validate.
-func NDPNA(checkers ...TransportChecker) NetworkChecker {
- return NDP(header.ICMPv6NeighborAdvert, header.NDPNAMinimumSize, checkers...)
-}
-
-// NDPNATargetAddress creates a checker that checks the Target Address field of
-// a header.NDPNeighborAdvert.
-//
-// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
-// containing a valid NDPNA message as far as the size is concerned.
-func NDPNATargetAddress(want tcpip.Address) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmp := h.(header.ICMPv6)
- na := header.NDPNeighborAdvert(icmp.MessageBody())
-
- if got := na.TargetAddress(); got != want {
- t.Errorf("got %T.TargetAddress() = %s, want = %s", na, got, want)
- }
- }
-}
-
-// NDPNASolicitedFlag creates a checker that checks the Solicited field of
-// a header.NDPNeighborAdvert.
-//
-// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
-// containing a valid NDPNA message as far as the size is concerned.
-func NDPNASolicitedFlag(want bool) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmp := h.(header.ICMPv6)
- na := header.NDPNeighborAdvert(icmp.MessageBody())
-
- if got := na.SolicitedFlag(); got != want {
- t.Errorf("got %T.SolicitedFlag = %t, want = %t", na, got, want)
- }
- }
-}
-
-// ndpOptions checks that optsBuf only contains opts.
-func ndpOptions(t *testing.T, optsBuf header.NDPOptions, opts []header.NDPOption) {
- t.Helper()
-
- it, err := optsBuf.Iter(true)
- if err != nil {
- t.Errorf("optsBuf.Iter(true): %s", err)
- return
- }
-
- i := 0
- for {
- opt, done, err := it.Next()
- if err != nil {
- // This should never happen as Iter(true) above did not return an error.
- t.Fatalf("unexpected error when iterating over NDP options: %s", err)
- }
- if done {
- break
- }
-
- if i >= len(opts) {
- t.Errorf("got unexpected option: %s", opt)
- continue
- }
-
- switch wantOpt := opts[i].(type) {
- case header.NDPSourceLinkLayerAddressOption:
- gotOpt, ok := opt.(header.NDPSourceLinkLayerAddressOption)
- if !ok {
- t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt)
- } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want {
- t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want)
- }
- case header.NDPTargetLinkLayerAddressOption:
- gotOpt, ok := opt.(header.NDPTargetLinkLayerAddressOption)
- if !ok {
- t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt)
- } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want {
- t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want)
- }
- case header.NDPNonceOption:
- gotOpt, ok := opt.(header.NDPNonceOption)
- if !ok {
- t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt)
- } else if diff := cmp.Diff(wantOpt.Nonce(), gotOpt.Nonce()); diff != "" {
- t.Errorf("nonce mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatalf("checker not implemented for expected NDP option: %T", wantOpt)
- }
-
- i++
- }
-
- if missing := opts[i:]; len(missing) > 0 {
- t.Errorf("missing options: %s", missing)
- }
-}
-
-// NDPNAOptions creates a checker that checks that the packet contains the
-// provided NDP options within an NDP Neighbor Solicitation message.
-//
-// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
-// containing a valid NDPNA message as far as the size is concerned.
-func NDPNAOptions(opts []header.NDPOption) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmp := h.(header.ICMPv6)
- na := header.NDPNeighborAdvert(icmp.MessageBody())
- ndpOptions(t, na.Options(), opts)
- }
-}
-
-// NDPNSOptions creates a checker that checks that the packet contains the
-// provided NDP options within an NDP Neighbor Solicitation message.
-//
-// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
-// containing a valid NDPNS message as far as the size is concerned.
-func NDPNSOptions(opts []header.NDPOption) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmp := h.(header.ICMPv6)
- ns := header.NDPNeighborSolicit(icmp.MessageBody())
- ndpOptions(t, ns.Options(), opts)
- }
-}
-
-// NDPRS creates a checker that checks that the packet contains a valid NDP
-// Router Solicitation message (as per the raw wire format).
-//
-// Checkers may assume that a valid ICMPv6 is passed to it containing a valid
-// NDPRS as far as the size of the message is concerned. The values within the
-// message are up to checkers to validate.
-func NDPRS(checkers ...TransportChecker) NetworkChecker {
- return NDP(header.ICMPv6RouterSolicit, header.NDPRSMinimumSize, checkers...)
-}
-
-// NDPRSOptions creates a checker that checks that the packet contains the
-// provided NDP options within an NDP Router Solicitation message.
-//
-// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
-// containing a valid NDPRS message as far as the size is concerned.
-func NDPRSOptions(opts []header.NDPOption) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmp := h.(header.ICMPv6)
- rs := header.NDPRouterSolicit(icmp.MessageBody())
- ndpOptions(t, rs.Options(), opts)
- }
-}
-
-// IGMP checks the validity and properties of the given IGMP packet. It is
-// expected to be used in conjunction with other IGMP transport checkers for
-// specific properties.
-func IGMP(checkers ...TransportChecker) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- last := h[len(h)-1]
-
- if p := last.TransportProtocol(); p != header.IGMPProtocolNumber {
- t.Fatalf("Bad protocol, got %d, want %d", p, header.IGMPProtocolNumber)
- }
-
- igmp := header.IGMP(last.Payload())
- for _, f := range checkers {
- f(t, igmp)
- }
- if t.Failed() {
- t.FailNow()
- }
- }
-}
-
-// IGMPType creates a checker that checks the IGMP Type field.
-func IGMPType(want header.IGMPType) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- igmp, ok := h.(header.IGMP)
- if !ok {
- t.Fatalf("got transport header = %T, want = header.IGMP", h)
- }
- if got := igmp.Type(); got != want {
- t.Errorf("got igmp.Type() = %d, want = %d", got, want)
- }
- }
-}
-
-// IGMPMaxRespTime creates a checker that checks the IGMP Max Resp Time field.
-func IGMPMaxRespTime(want time.Duration) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- igmp, ok := h.(header.IGMP)
- if !ok {
- t.Fatalf("got transport header = %T, want = header.IGMP", h)
- }
- if got := igmp.MaxRespTime(); got != want {
- t.Errorf("got igmp.MaxRespTime() = %s, want = %s", got, want)
- }
- }
-}
-
-// IGMPGroupAddress creates a checker that checks the IGMP Group Address field.
-func IGMPGroupAddress(want tcpip.Address) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- igmp, ok := h.(header.IGMP)
- if !ok {
- t.Fatalf("got transport header = %T, want = header.IGMP", h)
- }
- if got := igmp.GroupAddress(); got != want {
- t.Errorf("got igmp.GroupAddress() = %s, want = %s", got, want)
- }
- }
-}
-
-// IPv6ExtHdrChecker is a function to check an extension header.
-type IPv6ExtHdrChecker func(*testing.T, header.IPv6PayloadHeader)
-
-// IPv6WithExtHdr is like IPv6 but allows IPv6 packets with extension headers.
-func IPv6WithExtHdr(t *testing.T, b []byte, checkers ...NetworkChecker) {
- t.Helper()
-
- ipv6 := header.IPv6(b)
- if !ipv6.IsValid(len(b)) {
- t.Error("not a valid IPv6 packet")
- return
- }
-
- payloadIterator := header.MakeIPv6PayloadIterator(
- header.IPv6ExtensionHeaderIdentifier(ipv6.NextHeader()),
- buffer.View(ipv6.Payload()).ToVectorisedView(),
- )
-
- var rawPayloadHeader header.IPv6RawPayloadHeader
- for {
- h, done, err := payloadIterator.Next()
- if err != nil {
- t.Errorf("payloadIterator.Next(): %s", err)
- return
- }
- if done {
- t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, true, _)", h, done)
- return
- }
- r, ok := h.(header.IPv6RawPayloadHeader)
- if ok {
- rawPayloadHeader = r
- break
- }
- }
-
- networkHeader := ipv6HeaderWithExtHdr{
- IPv6: ipv6,
- transport: tcpip.TransportProtocolNumber(rawPayloadHeader.Identifier),
- payload: rawPayloadHeader.Buf.ToView(),
- }
-
- for _, checker := range checkers {
- checker(t, []header.Network{&networkHeader})
- }
-}
-
-// IPv6ExtHdr checks for the presence of extension headers.
-//
-// All the extension headers in headers will be checked exhaustively in the
-// order provided.
-func IPv6ExtHdr(headers ...IPv6ExtHdrChecker) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- extHdrs, ok := h[0].(*ipv6HeaderWithExtHdr)
- if !ok {
- t.Errorf("got network header = %T, want = *ipv6HeaderWithExtHdr", h[0])
- return
- }
-
- payloadIterator := header.MakeIPv6PayloadIterator(
- header.IPv6ExtensionHeaderIdentifier(extHdrs.IPv6.NextHeader()),
- buffer.View(extHdrs.IPv6.Payload()).ToVectorisedView(),
- )
-
- for _, check := range headers {
- h, done, err := payloadIterator.Next()
- if err != nil {
- t.Errorf("payloadIterator.Next(): %s", err)
- return
- }
- if done {
- t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, false, _)", h, done)
- return
- }
- check(t, h)
- }
- // Validate we consumed all headers.
- //
- // The next one over should be a raw payload and then iterator should
- // terminate.
- wantDone := false
- for {
- h, done, err := payloadIterator.Next()
- if err != nil {
- t.Errorf("payloadIterator.Next(): %s", err)
- return
- }
- if done != wantDone {
- t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, %t, _)", h, done, wantDone)
- return
- }
- if done {
- break
- }
- if _, ok := h.(header.IPv6RawPayloadHeader); !ok {
- t.Errorf("got payloadIterator.Next() = (%T, _, _), want = (header.IPv6RawPayloadHeader, _, _)", h)
- continue
- }
- wantDone = true
- }
- }
-}
-
-var _ header.Network = (*ipv6HeaderWithExtHdr)(nil)
-
-// ipv6HeaderWithExtHdr provides a header.Network implementation that takes
-// extension headers into consideration, which is not the case with vanilla
-// header.IPv6.
-type ipv6HeaderWithExtHdr struct {
- header.IPv6
- transport tcpip.TransportProtocolNumber
- payload []byte
-}
-
-// TransportProtocol implements header.Network.
-func (h *ipv6HeaderWithExtHdr) TransportProtocol() tcpip.TransportProtocolNumber {
- return h.transport
-}
-
-// Payload implements header.Network.
-func (h *ipv6HeaderWithExtHdr) Payload() []byte {
- return h.payload
-}
-
-// IPv6ExtHdrOptionChecker is a function to check an extension header option.
-type IPv6ExtHdrOptionChecker func(*testing.T, header.IPv6ExtHdrOption)
-
-// IPv6HopByHopExtensionHeader checks the extension header is a Hop by Hop
-// extension header and validates the containing options with checkers.
-//
-// checkers must exhaustively contain all the expected options.
-func IPv6HopByHopExtensionHeader(checkers ...IPv6ExtHdrOptionChecker) IPv6ExtHdrChecker {
- return func(t *testing.T, payloadHeader header.IPv6PayloadHeader) {
- t.Helper()
-
- hbh, ok := payloadHeader.(header.IPv6HopByHopOptionsExtHdr)
- if !ok {
- t.Errorf("unexpected IPv6 payload header, got = %T, want = header.IPv6HopByHopOptionsExtHdr", payloadHeader)
- return
- }
- optionsIterator := hbh.Iter()
- for _, f := range checkers {
- opt, done, err := optionsIterator.Next()
- if err != nil {
- t.Errorf("optionsIterator.Next(): %s", err)
- return
- }
- if done {
- t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, false, _)", opt, done)
- }
- f(t, opt)
- }
- // Validate all options were consumed.
- for {
- opt, done, err := optionsIterator.Next()
- if err != nil {
- t.Errorf("optionsIterator.Next(): %s", err)
- return
- }
- if !done {
- t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, true, _)", opt, done)
- }
- if done {
- break
- }
- }
- }
-}
-
-// IPv6RouterAlert validates that an extension header option is the RouterAlert
-// option and matches on its value.
-func IPv6RouterAlert(want header.IPv6RouterAlertValue) IPv6ExtHdrOptionChecker {
- return func(t *testing.T, opt header.IPv6ExtHdrOption) {
- routerAlert, ok := opt.(*header.IPv6RouterAlertOption)
- if !ok {
- t.Errorf("unexpected extension header option, got = %T, want = header.IPv6RouterAlertOption", opt)
- return
- }
- if routerAlert.Value != want {
- t.Errorf("got routerAlert.Value = %d, want = %d", routerAlert.Value, want)
- }
- }
-}
-
-// IPv6UnknownOption validates that an extension header option is the
-// unknown header option.
-func IPv6UnknownOption() IPv6ExtHdrOptionChecker {
- return func(t *testing.T, opt header.IPv6ExtHdrOption) {
- _, ok := opt.(*header.IPv6UnknownExtHdrOption)
- if !ok {
- t.Errorf("got = %T, want = header.IPv6UnknownExtHdrOption", opt)
- }
- }
-}
-
-// IgnoreCmpPath returns a cmp.Option that ignores listed field paths.
-func IgnoreCmpPath(paths ...string) cmp.Option {
- ignores := map[string]struct{}{}
- for _, path := range paths {
- ignores[path] = struct{}{}
- }
- return cmp.FilterPath(func(path cmp.Path) bool {
- _, ok := ignores[path.String()]
- return ok
- }, cmp.Ignore())
-}
diff --git a/pkg/tcpip/faketime/BUILD b/pkg/tcpip/faketime/BUILD
deleted file mode 100644
index bb9d44aff..000000000
--- a/pkg/tcpip/faketime/BUILD
+++ /dev/null
@@ -1,21 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "faketime",
- srcs = ["faketime.go"],
- visibility = ["//visibility:public"],
- deps = ["//pkg/tcpip"],
-)
-
-go_test(
- name = "faketime_test",
- size = "small",
- srcs = [
- "faketime_test.go",
- ],
- deps = [
- "//pkg/tcpip/faketime",
- ],
-)
diff --git a/pkg/tcpip/faketime/faketime_state_autogen.go b/pkg/tcpip/faketime/faketime_state_autogen.go
new file mode 100644
index 000000000..3de72f27d
--- /dev/null
+++ b/pkg/tcpip/faketime/faketime_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package faketime
diff --git a/pkg/tcpip/faketime/faketime_test.go b/pkg/tcpip/faketime/faketime_test.go
deleted file mode 100644
index fd2bb470a..000000000
--- a/pkg/tcpip/faketime/faketime_test.go
+++ /dev/null
@@ -1,95 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package faketime_test
-
-import (
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
-)
-
-func TestManualClockAdvance(t *testing.T) {
- const timeout = time.Millisecond
- clock := faketime.NewManualClock()
- start := clock.NowMonotonic()
- clock.Advance(timeout)
- if got, want := clock.NowMonotonic().Sub(start), timeout; got != want {
- t.Errorf("got = %d, want = %d", got, want)
- }
-}
-
-func TestManualClockAfterFunc(t *testing.T) {
- const (
- timeout1 = time.Millisecond // timeout for counter1
- timeout2 = 2 * time.Millisecond // timeout for counter2
- )
- tests := []struct {
- name string
- advance time.Duration
- wantCounter1 int
- wantCounter2 int
- }{
- {
- name: "before timeout1",
- advance: timeout1 - 1,
- wantCounter1: 0,
- wantCounter2: 0,
- },
- {
- name: "timeout1",
- advance: timeout1,
- wantCounter1: 1,
- wantCounter2: 0,
- },
- {
- name: "timeout2",
- advance: timeout2,
- wantCounter1: 1,
- wantCounter2: 1,
- },
- {
- name: "after timeout2",
- advance: timeout2 + 1,
- wantCounter1: 1,
- wantCounter2: 1,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- clock := faketime.NewManualClock()
- counter1 := 0
- counter2 := 0
- clock.AfterFunc(timeout1, func() {
- counter1++
- })
- clock.AfterFunc(timeout2, func() {
- counter2++
- })
- start := clock.NowMonotonic()
- clock.Advance(test.advance)
- if got, want := counter1, test.wantCounter1; got != want {
- t.Errorf("got counter1 = %d, want = %d", got, want)
- }
- if got, want := counter2, test.wantCounter2; got != want {
- t.Errorf("got counter2 = %d, want = %d", got, want)
- }
- if got, want := clock.NowMonotonic().Sub(start), test.advance; got != want {
- t.Errorf("got elapsed = %d, want = %d", got, want)
- }
- })
- }
-}
diff --git a/pkg/tcpip/hash/jenkins/BUILD b/pkg/tcpip/hash/jenkins/BUILD
deleted file mode 100644
index ff2719291..000000000
--- a/pkg/tcpip/hash/jenkins/BUILD
+++ /dev/null
@@ -1,18 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "jenkins",
- srcs = ["jenkins.go"],
- visibility = ["//visibility:public"],
-)
-
-go_test(
- name = "jenkins_test",
- size = "small",
- srcs = [
- "jenkins_test.go",
- ],
- library = ":jenkins",
-)
diff --git a/pkg/tcpip/hash/jenkins/jenkins_state_autogen.go b/pkg/tcpip/hash/jenkins/jenkins_state_autogen.go
new file mode 100644
index 000000000..216cc5a2e
--- /dev/null
+++ b/pkg/tcpip/hash/jenkins/jenkins_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package jenkins
diff --git a/pkg/tcpip/hash/jenkins/jenkins_test.go b/pkg/tcpip/hash/jenkins/jenkins_test.go
deleted file mode 100644
index 4c78b5808..000000000
--- a/pkg/tcpip/hash/jenkins/jenkins_test.go
+++ /dev/null
@@ -1,176 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-package jenkins
-
-import (
- "bytes"
- "encoding/binary"
- "hash"
- "hash/fnv"
- "math"
- "testing"
-)
-
-func TestGolden32(t *testing.T) {
- var golden32 = []struct {
- out []byte
- in string
- }{
- {[]byte{0x00, 0x00, 0x00, 0x00}, ""},
- {[]byte{0xca, 0x2e, 0x94, 0x42}, "a"},
- {[]byte{0x45, 0xe6, 0x1e, 0x58}, "ab"},
- {[]byte{0xed, 0x13, 0x1f, 0x5b}, "abc"},
- }
-
- hash := New32()
-
- for _, g := range golden32 {
- hash.Reset()
- done, error := hash.Write([]byte(g.in))
- if error != nil {
- t.Fatalf("write error: %s", error)
- }
- if done != len(g.in) {
- t.Fatalf("wrote only %d out of %d bytes", done, len(g.in))
- }
- if actual := hash.Sum(nil); !bytes.Equal(g.out, actual) {
- t.Errorf("hash(%q) = 0x%x want 0x%x", g.in, actual, g.out)
- }
- }
-}
-
-func TestIntegrity32(t *testing.T) {
- data := []byte{'1', '2', 3, 4, 5}
-
- h := New32()
- h.Write(data)
- sum := h.Sum(nil)
-
- if size := h.Size(); size != len(sum) {
- t.Fatalf("Size()=%d but len(Sum())=%d", size, len(sum))
- }
-
- if a := h.Sum(nil); !bytes.Equal(sum, a) {
- t.Fatalf("first Sum()=0x%x, second Sum()=0x%x", sum, a)
- }
-
- h.Reset()
- h.Write(data)
- if a := h.Sum(nil); !bytes.Equal(sum, a) {
- t.Fatalf("Sum()=0x%x, but after Reset() Sum()=0x%x", sum, a)
- }
-
- h.Reset()
- h.Write(data[:2])
- h.Write(data[2:])
- if a := h.Sum(nil); !bytes.Equal(sum, a) {
- t.Fatalf("Sum()=0x%x, but with partial writes, Sum()=0x%x", sum, a)
- }
-
- sum32 := h.(hash.Hash32).Sum32()
- if sum32 != binary.BigEndian.Uint32(sum) {
- t.Fatalf("Sum()=0x%x, but Sum32()=0x%x", sum, sum32)
- }
-}
-
-func BenchmarkJenkins32KB(b *testing.B) {
- h := New32()
-
- b.SetBytes(1024)
- data := make([]byte, 1024)
- for i := range data {
- data[i] = byte(i)
- }
- in := make([]byte, 0, h.Size())
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- h.Reset()
- h.Write(data)
- h.Sum(in)
- }
-}
-
-func BenchmarkFnv32(b *testing.B) {
- arr := make([]int64, 1000)
- for i := 0; i < b.N; i++ {
- var payload [8]byte
- binary.BigEndian.PutUint32(payload[:4], uint32(i))
- binary.BigEndian.PutUint32(payload[4:], uint32(i))
-
- h := fnv.New32()
- h.Write(payload[:])
- idx := int(h.Sum32()) % len(arr)
- arr[idx]++
- }
- b.StopTimer()
- c := 0
- if b.N > 1000000 {
- for i := 0; i < len(arr)-1; i++ {
- if math.Abs(float64(arr[i]-arr[i+1]))/float64(arr[i]) > float64(0.1) {
- if c == 0 {
- b.Logf("i %d val[i] %d val[i+1] %d b.N %b\n", i, arr[i], arr[i+1], b.N)
- }
- c++
- }
- }
- if c > 0 {
- b.Logf("Unbalanced buckets: %d", c)
- }
- }
-}
-
-func BenchmarkSum32(b *testing.B) {
- arr := make([]int64, 1000)
- for i := 0; i < b.N; i++ {
- var payload [8]byte
- binary.BigEndian.PutUint32(payload[:4], uint32(i))
- binary.BigEndian.PutUint32(payload[4:], uint32(i))
- h := Sum32(0)
- h.Write(payload[:])
- idx := int(h.Sum32()) % len(arr)
- arr[idx]++
- }
- b.StopTimer()
- if b.N > 1000000 {
- for i := 0; i < len(arr)-1; i++ {
- if math.Abs(float64(arr[i]-arr[i+1]))/float64(arr[i]) > float64(0.1) {
- b.Logf("val[%3d]=%8d\tval[%3d]=%8d\tb.N=%b\n", i, arr[i], i+1, arr[i+1], b.N)
- break
- }
- }
- }
-}
-
-func BenchmarkNew32(b *testing.B) {
- arr := make([]int64, 1000)
- for i := 0; i < b.N; i++ {
- var payload [8]byte
- binary.BigEndian.PutUint32(payload[:4], uint32(i))
- binary.BigEndian.PutUint32(payload[4:], uint32(i))
- h := New32()
- h.Write(payload[:])
- idx := int(h.Sum32()) % len(arr)
- arr[idx]++
- }
- b.StopTimer()
- if b.N > 1000000 {
- for i := 0; i < len(arr)-1; i++ {
- if math.Abs(float64(arr[i]-arr[i+1]))/float64(arr[i]) > float64(0.1) {
- b.Logf("val[%3d]=%8d\tval[%3d]=%8d\tb.N=%b\n", i, arr[i], i+1, arr[i+1], b.N)
- break
- }
- }
- }
-}
diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD
deleted file mode 100644
index 01240f5d0..000000000
--- a/pkg/tcpip/header/BUILD
+++ /dev/null
@@ -1,76 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "header",
- srcs = [
- "arp.go",
- "checksum.go",
- "eth.go",
- "gue.go",
- "icmpv4.go",
- "icmpv6.go",
- "igmp.go",
- "interfaces.go",
- "ipv4.go",
- "ipv6.go",
- "ipv6_extension_headers.go",
- "ipv6_fragment.go",
- "mld.go",
- "ndp_neighbor_advert.go",
- "ndp_neighbor_solicit.go",
- "ndp_options.go",
- "ndp_router_advert.go",
- "ndp_router_solicit.go",
- "ndpoptionidentifier_string.go",
- "tcp.go",
- "udp.go",
- ],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/seqnum",
- "@com_github_google_btree//:go_default_library",
- ],
-)
-
-go_test(
- name = "header_x_test",
- size = "small",
- srcs = [
- "checksum_test.go",
- "igmp_test.go",
- "ipv4_test.go",
- "ipv6_test.go",
- "ipversion_test.go",
- "tcp_test.go",
- ],
- deps = [
- ":header",
- "//pkg/rand",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/testutil",
- "@com_github_google_go_cmp//cmp:go_default_library",
- ],
-)
-
-go_test(
- name = "header_test",
- size = "small",
- srcs = [
- "eth_test.go",
- "ipv6_extension_headers_test.go",
- "mld_test.go",
- "ndp_test.go",
- ],
- library = ":header",
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/testutil",
- "@com_github_google_go_cmp//cmp:go_default_library",
- ],
-)
diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go
deleted file mode 100644
index 3445511f4..000000000
--- a/pkg/tcpip/header/checksum_test.go
+++ /dev/null
@@ -1,461 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package header provides the implementation of the encoding and decoding of
-// network protocol headers.
-package header_test
-
-import (
- "bytes"
- "fmt"
- "math/rand"
- "sync"
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
-)
-
-func TestChecksumer(t *testing.T) {
- testCases := []struct {
- name string
- data [][]byte
- want uint16
- }{
- {
- name: "empty",
- want: 0,
- },
- {
- name: "OneOddView",
- data: [][]byte{
- []byte{1, 9, 0, 5, 4},
- },
- want: 1294,
- },
- {
- name: "TwoOddViews",
- data: [][]byte{
- []byte{1, 9, 0, 5, 4},
- []byte{4, 3, 7, 1, 2, 123},
- },
- want: 33819,
- },
- {
- name: "OneEvenView",
- data: [][]byte{
- []byte{1, 9, 0, 5},
- },
- want: 270,
- },
- {
- name: "TwoEvenViews",
- data: [][]byte{
- buffer.NewViewFromBytes([]byte{98, 1, 9, 0}),
- buffer.NewViewFromBytes([]byte{9, 0, 5, 4}),
- },
- want: 30981,
- },
- {
- name: "ThreeViews",
- data: [][]byte{
- []byte{77, 11, 33, 0, 55, 44},
- []byte{98, 1, 9, 0, 5, 4},
- []byte{4, 3, 7, 1, 2, 123, 99},
- },
- want: 34236,
- },
- }
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- var all bytes.Buffer
- var c header.Checksumer
- for _, b := range tc.data {
- c.Add(b)
- // Append to the buffer. We will check the checksum as a whole later.
- if _, err := all.Write(b); err != nil {
- t.Fatalf("all.Write(b) = _, %s; want _, nil", err)
- }
- }
- if got, want := c.Checksum(), tc.want; got != want {
- t.Errorf("c.Checksum() = %d, want %d", got, want)
- }
- if got, want := header.Checksum(all.Bytes(), 0 /* initial */), tc.want; got != want {
- t.Errorf("Checksum(flatten tc.data) = %d, want %d", got, want)
- }
- })
- }
-}
-
-func TestChecksum(t *testing.T) {
- var bufSizes = []int{0, 1, 2, 3, 4, 7, 8, 15, 16, 31, 32, 63, 64, 127, 128, 255, 256, 257, 1023, 1024}
- type testCase struct {
- buf []byte
- initial uint16
- csumOrig uint16
- csumNew uint16
- }
- testCases := make([]testCase, 100000)
- // Ensure same buffer generation for test consistency.
- rnd := rand.New(rand.NewSource(42))
- for i := range testCases {
- testCases[i].buf = make([]byte, bufSizes[i%len(bufSizes)])
- testCases[i].initial = uint16(rnd.Intn(65536))
- rnd.Read(testCases[i].buf)
- }
-
- for i := range testCases {
- testCases[i].csumOrig = header.ChecksumOld(testCases[i].buf, testCases[i].initial)
- testCases[i].csumNew = header.Checksum(testCases[i].buf, testCases[i].initial)
- if got, want := testCases[i].csumNew, testCases[i].csumOrig; got != want {
- t.Fatalf("new checksum for (buf = %x, initial = %d) does not match old got: %d, want: %d", testCases[i].buf, testCases[i].initial, got, want)
- }
- }
-}
-
-func BenchmarkChecksum(b *testing.B) {
- var bufSizes = []int{64, 128, 256, 512, 1024, 1500, 2048, 4096, 8192, 16384, 32767, 32768, 65535, 65536}
-
- checkSumImpls := []struct {
- fn func([]byte, uint16) uint16
- name string
- }{
- {header.ChecksumOld, fmt.Sprintf("checksum_old")},
- {header.Checksum, fmt.Sprintf("checksum")},
- }
-
- for _, csumImpl := range checkSumImpls {
- // Ensure same buffer generation for test consistency.
- rnd := rand.New(rand.NewSource(42))
- for _, bufSz := range bufSizes {
- b.Run(fmt.Sprintf("%s_%d", csumImpl.name, bufSz), func(b *testing.B) {
- tc := struct {
- buf []byte
- initial uint16
- csum uint16
- }{
- buf: make([]byte, bufSz),
- initial: uint16(rnd.Intn(65536)),
- }
- rnd.Read(tc.buf)
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- tc.csum = csumImpl.fn(tc.buf, tc.initial)
- }
- })
- }
- }
-}
-
-func testICMPChecksum(t *testing.T, headerChecksum func() uint16, icmpChecksum func() uint16, want uint16, pktStr string) {
- // icmpChecksum should not do any modifications of the header to
- // calculate its checksum. Let's call it from a few go-routines and the
- // race detector will trigger a warning if there are any concurrent
- // read/write accesses.
-
- const concurrency = 5
- start := make(chan int)
- ready := make(chan bool, concurrency)
- var wg sync.WaitGroup
- wg.Add(concurrency)
- defer wg.Wait()
-
- for i := 0; i < concurrency; i++ {
- go func() {
- defer wg.Done()
-
- ready <- true
- <-start
-
- if got := headerChecksum(); want != got {
- t.Errorf("new checksum for %s does not match old got: %x, want: %x", pktStr, got, want)
- }
- if got := icmpChecksum(); want != got {
- t.Errorf("new checksum for %s does not match old got: %x, want: %x", pktStr, got, want)
- }
- }()
- }
- for i := 0; i < concurrency; i++ {
- <-ready
- }
- close(start)
-}
-
-func TestICMPv4Checksum(t *testing.T) {
- rnd := rand.New(rand.NewSource(42))
-
- h := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize))
- if _, err := rnd.Read(h); err != nil {
- t.Fatalf("rnd.Read failed: %v", err)
- }
- h.SetChecksum(0)
-
- buf := make([]byte, 13)
- if _, err := rnd.Read(buf); err != nil {
- t.Fatalf("rnd.Read failed: %v", err)
- }
- vv := buffer.NewVectorisedView(len(buf), []buffer.View{
- buffer.NewViewFromBytes(buf[:5]),
- buffer.NewViewFromBytes(buf[5:]),
- })
-
- want := header.Checksum(vv.ToView(), 0)
- want = ^header.Checksum(h, want)
- h.SetChecksum(want)
-
- testICMPChecksum(t, h.Checksum, func() uint16 {
- return header.ICMPv4Checksum(h, header.ChecksumVV(vv, 0))
- }, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView()))
-}
-
-func TestICMPv6Checksum(t *testing.T) {
- rnd := rand.New(rand.NewSource(42))
-
- h := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize))
- if _, err := rnd.Read(h); err != nil {
- t.Fatalf("rnd.Read failed: %v", err)
- }
- h.SetChecksum(0)
-
- buf := make([]byte, 13)
- if _, err := rnd.Read(buf); err != nil {
- t.Fatalf("rnd.Read failed: %v", err)
- }
- vv := buffer.NewVectorisedView(len(buf), []buffer.View{
- buffer.NewViewFromBytes(buf[:7]),
- buffer.NewViewFromBytes(buf[7:10]),
- buffer.NewViewFromBytes(buf[10:]),
- })
-
- dst := header.IPv6Loopback
- src := header.IPv6Loopback
-
- want := header.PseudoHeaderChecksum(header.ICMPv6ProtocolNumber, src, dst, uint16(len(h)+vv.Size()))
- want = header.Checksum(vv.ToView(), want)
- want = ^header.Checksum(h, want)
- h.SetChecksum(want)
-
- testICMPChecksum(t, h.Checksum, func() uint16 {
- return header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: h,
- Src: src,
- Dst: dst,
- PayloadCsum: header.ChecksumVV(vv, 0),
- PayloadLen: vv.Size(),
- })
- }, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView()))
-}
-
-func randomAddress(size int) tcpip.Address {
- s := make([]byte, size)
- for i := 0; i < size; i++ {
- s[i] = byte(rand.Uint32())
- }
- return tcpip.Address(s)
-}
-
-func TestChecksummableNetworkUpdateAddress(t *testing.T) {
- tests := []struct {
- name string
- update func(header.IPv4, tcpip.Address)
- }{
- {
- name: "SetSourceAddressWithChecksumUpdate",
- update: header.IPv4.SetSourceAddressWithChecksumUpdate,
- },
- {
- name: "SetDestinationAddressWithChecksumUpdate",
- update: header.IPv4.SetDestinationAddressWithChecksumUpdate,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- for i := 0; i < 1000; i++ {
- var origBytes [header.IPv4MinimumSize]byte
- header.IPv4(origBytes[:]).Encode(&header.IPv4Fields{
- TOS: 1,
- TotalLength: header.IPv4MinimumSize,
- ID: 2,
- Flags: 3,
- FragmentOffset: 4,
- TTL: 5,
- Protocol: 6,
- Checksum: 0,
- SrcAddr: randomAddress(header.IPv4AddressSize),
- DstAddr: randomAddress(header.IPv4AddressSize),
- })
-
- addr := randomAddress(header.IPv4AddressSize)
-
- bytesCopy := origBytes
- h := header.IPv4(bytesCopy[:])
- origXSum := h.CalculateChecksum()
- h.SetChecksum(^origXSum)
-
- test.update(h, addr)
- got := ^h.Checksum()
- h.SetChecksum(0)
- want := h.CalculateChecksum()
- if got != want {
- t.Errorf("got h.Checksum() = 0x%x, want = 0x%x; originalBytes = 0x%x, new addr = %s", got, want, origBytes, addr)
- }
- }
- })
- }
-}
-
-func TestChecksummableTransportUpdatePort(t *testing.T) {
- // The fields in the pseudo header is not tested here so we just use 0.
- const pseudoHeaderXSum = 0
-
- tests := []struct {
- name string
- transportHdr func(_, _ uint16) (header.ChecksummableTransport, func(uint16) uint16)
- proto tcpip.TransportProtocolNumber
- }{
- {
- name: "TCP",
- transportHdr: func(src, dst uint16) (header.ChecksummableTransport, func(uint16) uint16) {
- h := header.TCP(make([]byte, header.TCPMinimumSize))
- h.Encode(&header.TCPFields{
- SrcPort: src,
- DstPort: dst,
- SeqNum: 1,
- AckNum: 2,
- DataOffset: header.TCPMinimumSize,
- Flags: 3,
- WindowSize: 4,
- Checksum: 0,
- UrgentPointer: 5,
- })
- h.SetChecksum(^h.CalculateChecksum(pseudoHeaderXSum))
- return h, h.CalculateChecksum
- },
- proto: header.TCPProtocolNumber,
- },
- {
- name: "UDP",
- transportHdr: func(src, dst uint16) (header.ChecksummableTransport, func(uint16) uint16) {
- h := header.UDP(make([]byte, header.UDPMinimumSize))
- h.Encode(&header.UDPFields{
- SrcPort: src,
- DstPort: dst,
- Length: 0,
- Checksum: 0,
- })
- h.SetChecksum(^h.CalculateChecksum(pseudoHeaderXSum))
- return h, h.CalculateChecksum
- },
- proto: header.UDPProtocolNumber,
- },
- }
-
- for i := 0; i < 1000; i++ {
- origSrcPort := uint16(rand.Uint32())
- origDstPort := uint16(rand.Uint32())
- newPort := uint16(rand.Uint32())
-
- t.Run(fmt.Sprintf("OrigSrcPort=%d,OrigDstPort=%d,NewPort=%d", origSrcPort, origDstPort, newPort), func(*testing.T) {
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- for _, subTest := range []struct {
- name string
- update func(header.ChecksummableTransport)
- }{
- {
- name: "Source port",
- update: func(h header.ChecksummableTransport) { h.SetSourcePortWithChecksumUpdate(newPort) },
- },
- {
- name: "Destination port",
- update: func(h header.ChecksummableTransport) { h.SetDestinationPortWithChecksumUpdate(newPort) },
- },
- } {
- t.Run(subTest.name, func(t *testing.T) {
- h, calcXSum := test.transportHdr(origSrcPort, origDstPort)
- subTest.update(h)
- // TCP and UDP hold the 1s complement of the fully calculated
- // checksum.
- got := ^h.Checksum()
- h.SetChecksum(0)
-
- if want := calcXSum(pseudoHeaderXSum); got != want {
- h, _ := test.transportHdr(origSrcPort, origDstPort)
- t.Errorf("got Checksum() = 0x%x, want = 0x%x; originalBytes = %#v, new port = %d", got, want, h, newPort)
- }
- })
- }
- })
- }
- })
- }
-}
-
-func TestChecksummableTransportUpdatePseudoHeaderAddress(t *testing.T) {
- const addressSize = 6
-
- tests := []struct {
- name string
- transportHdr func() header.ChecksummableTransport
- proto tcpip.TransportProtocolNumber
- }{
- {
- name: "TCP",
- transportHdr: func() header.ChecksummableTransport { return header.TCP(make([]byte, header.TCPMinimumSize)) },
- proto: header.TCPProtocolNumber,
- },
- {
- name: "UDP",
- transportHdr: func() header.ChecksummableTransport { return header.UDP(make([]byte, header.UDPMinimumSize)) },
- proto: header.UDPProtocolNumber,
- },
- }
-
- for i := 0; i < 1000; i++ {
- permanent := randomAddress(addressSize)
- old := randomAddress(addressSize)
- new := randomAddress(addressSize)
-
- t.Run(fmt.Sprintf("Permanent=%q,Old=%q,New=%q", permanent, old, new), func(t *testing.T) {
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- for _, fullChecksum := range []bool{true, false} {
- t.Run(fmt.Sprintf("FullChecksum=%t", fullChecksum), func(t *testing.T) {
- initialXSum := header.PseudoHeaderChecksum(test.proto, permanent, old, 0)
- if fullChecksum {
- // TCP and UDP hold the 1s complement of the fully calculated
- // checksum.
- initialXSum = ^initialXSum
- }
-
- h := test.transportHdr()
- h.SetChecksum(initialXSum)
- h.UpdateChecksumPseudoHeaderAddress(old, new, fullChecksum)
-
- got := h.Checksum()
- if fullChecksum {
- got = ^got
- }
- if want := header.PseudoHeaderChecksum(test.proto, permanent, new, 0); got != want {
- t.Errorf("got Checksum() = 0x%x, want = 0x%x; h = %#v", got, want, h)
- }
- })
- }
- })
- }
- })
- }
-}
diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go
deleted file mode 100644
index adc04e855..000000000
--- a/pkg/tcpip/header/eth_test.go
+++ /dev/null
@@ -1,150 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package header
-
-import (
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
-)
-
-func TestIsValidUnicastEthernetAddress(t *testing.T) {
- tests := []struct {
- name string
- addr tcpip.LinkAddress
- expected bool
- }{
- {
- "Nil",
- tcpip.LinkAddress([]byte(nil)),
- false,
- },
- {
- "Empty",
- tcpip.LinkAddress(""),
- false,
- },
- {
- "InvalidLength",
- tcpip.LinkAddress("\x01\x02\x03"),
- false,
- },
- {
- "Unspecified",
- UnspecifiedEthernetAddress,
- false,
- },
- {
- "Multicast",
- tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
- false,
- },
- {
- "Valid",
- tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06"),
- true,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- if got := IsValidUnicastEthernetAddress(test.addr); got != test.expected {
- t.Fatalf("got IsValidUnicastEthernetAddress = %t, want = %t", got, test.expected)
- }
- })
- }
-}
-
-func TestIsMulticastEthernetAddress(t *testing.T) {
- tests := []struct {
- name string
- addr tcpip.LinkAddress
- expected bool
- }{
- {
- "Nil",
- tcpip.LinkAddress([]byte(nil)),
- false,
- },
- {
- "Empty",
- tcpip.LinkAddress(""),
- false,
- },
- {
- "InvalidLength",
- tcpip.LinkAddress("\x01\x02\x03"),
- false,
- },
- {
- "Unspecified",
- UnspecifiedEthernetAddress,
- false,
- },
- {
- "Multicast",
- tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
- true,
- },
- {
- "Unicast",
- tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06"),
- false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- if got := IsMulticastEthernetAddress(test.addr); got != test.expected {
- t.Fatalf("got IsMulticastEthernetAddress = %t, want = %t", got, test.expected)
- }
- })
- }
-}
-
-func TestEthernetAddressFromMulticastIPv4Address(t *testing.T) {
- tests := []struct {
- name string
- addr tcpip.Address
- expectedLinkAddr tcpip.LinkAddress
- }{
- {
- name: "IPv4 Multicast without 24th bit set",
- addr: "\xe0\x7e\xdc\xba",
- expectedLinkAddr: "\x01\x00\x5e\x7e\xdc\xba",
- },
- {
- name: "IPv4 Multicast with 24th bit set",
- addr: "\xe0\xfe\xdc\xba",
- expectedLinkAddr: "\x01\x00\x5e\x7e\xdc\xba",
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- if got := EthernetAddressFromMulticastIPv4Address(test.addr); got != test.expectedLinkAddr {
- t.Fatalf("got EthernetAddressFromMulticastIPv4Address(%s) = %s, want = %s", test.addr, got, test.expectedLinkAddr)
- }
- })
- }
-}
-
-func TestEthernetAddressFromMulticastIPv6Address(t *testing.T) {
- addr := testutil.MustParse6("ff02:304:506:708:90a:b0c:d0e:f1a")
- if got, want := EthernetAddressFromMulticastIPv6Address(addr), tcpip.LinkAddress("\x33\x33\x0d\x0e\x0f\x1a"); got != want {
- t.Fatalf("got EthernetAddressFromMulticastIPv6Address(%s) = %s, want = %s", addr, got, want)
- }
-}
diff --git a/pkg/tcpip/header/header_state_autogen.go b/pkg/tcpip/header/header_state_autogen.go
new file mode 100644
index 000000000..d6dd58874
--- /dev/null
+++ b/pkg/tcpip/header/header_state_autogen.go
@@ -0,0 +1,74 @@
+// automatically generated by stateify.
+
+package header
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+func (r *SACKBlock) StateTypeName() string {
+ return "pkg/tcpip/header.SACKBlock"
+}
+
+func (r *SACKBlock) StateFields() []string {
+ return []string{
+ "Start",
+ "End",
+ }
+}
+
+func (r *SACKBlock) beforeSave() {}
+
+// +checklocksignore
+func (r *SACKBlock) StateSave(stateSinkObject state.Sink) {
+ r.beforeSave()
+ stateSinkObject.Save(0, &r.Start)
+ stateSinkObject.Save(1, &r.End)
+}
+
+func (r *SACKBlock) afterLoad() {}
+
+// +checklocksignore
+func (r *SACKBlock) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &r.Start)
+ stateSourceObject.Load(1, &r.End)
+}
+
+func (t *TCPOptions) StateTypeName() string {
+ return "pkg/tcpip/header.TCPOptions"
+}
+
+func (t *TCPOptions) StateFields() []string {
+ return []string{
+ "TS",
+ "TSVal",
+ "TSEcr",
+ "SACKBlocks",
+ }
+}
+
+func (t *TCPOptions) beforeSave() {}
+
+// +checklocksignore
+func (t *TCPOptions) StateSave(stateSinkObject state.Sink) {
+ t.beforeSave()
+ stateSinkObject.Save(0, &t.TS)
+ stateSinkObject.Save(1, &t.TSVal)
+ stateSinkObject.Save(2, &t.TSEcr)
+ stateSinkObject.Save(3, &t.SACKBlocks)
+}
+
+func (t *TCPOptions) afterLoad() {}
+
+// +checklocksignore
+func (t *TCPOptions) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &t.TS)
+ stateSourceObject.Load(1, &t.TSVal)
+ stateSourceObject.Load(2, &t.TSEcr)
+ stateSourceObject.Load(3, &t.SACKBlocks)
+}
+
+func init() {
+ state.Register((*SACKBlock)(nil))
+ state.Register((*TCPOptions)(nil))
+}
diff --git a/pkg/tcpip/header/igmp_test.go b/pkg/tcpip/header/igmp_test.go
deleted file mode 100644
index 575604928..000000000
--- a/pkg/tcpip/header/igmp_test.go
+++ /dev/null
@@ -1,110 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package header_test
-
-import (
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
-)
-
-// TestIGMPHeader tests the functions within header.igmp
-func TestIGMPHeader(t *testing.T) {
- const maxRespTimeTenthSec = 0xF0
- b := []byte{
- 0x11, // IGMP Type, Membership Query
- maxRespTimeTenthSec, // Maximum Response Time
- 0xC0, 0xC0, // Checksum
- 0x01, 0x02, 0x03, 0x04, // Group Address
- }
-
- igmpHeader := header.IGMP(b)
-
- if got, want := igmpHeader.Type(), header.IGMPMembershipQuery; got != want {
- t.Errorf("got igmpHeader.Type() = %x, want = %x", got, want)
- }
-
- if got, want := igmpHeader.MaxRespTime(), header.DecisecondToDuration(maxRespTimeTenthSec); got != want {
- t.Errorf("got igmpHeader.MaxRespTime() = %s, want = %s", got, want)
- }
-
- if got, want := igmpHeader.Checksum(), uint16(0xC0C0); got != want {
- t.Errorf("got igmpHeader.Checksum() = %x, want = %x", got, want)
- }
-
- if got, want := igmpHeader.GroupAddress(), testutil.MustParse4("1.2.3.4"); got != want {
- t.Errorf("got igmpHeader.GroupAddress() = %s, want = %s", got, want)
- }
-
- igmpType := header.IGMPv2MembershipReport
- igmpHeader.SetType(igmpType)
- if got := igmpHeader.Type(); got != igmpType {
- t.Errorf("got igmpHeader.Type() = %x, want = %x", got, igmpType)
- }
- if got := header.IGMPType(b[0]); got != igmpType {
- t.Errorf("got IGMPtype in backing buffer = %x, want %x", got, igmpType)
- }
-
- respTime := byte(0x02)
- igmpHeader.SetMaxRespTime(respTime)
- if got, want := igmpHeader.MaxRespTime(), header.DecisecondToDuration(respTime); got != want {
- t.Errorf("got igmpHeader.MaxRespTime() = %s, want = %s", got, want)
- }
-
- checksum := uint16(0x0102)
- igmpHeader.SetChecksum(checksum)
- if got := igmpHeader.Checksum(); got != checksum {
- t.Errorf("got igmpHeader.Checksum() = %x, want = %x", got, checksum)
- }
-
- groupAddress := testutil.MustParse4("4.3.2.1")
- igmpHeader.SetGroupAddress(groupAddress)
- if got := igmpHeader.GroupAddress(); got != groupAddress {
- t.Errorf("got igmpHeader.GroupAddress() = %s, want = %s", got, groupAddress)
- }
-}
-
-// TestIGMPChecksum ensures that the checksum calculator produces the expected
-// checksum.
-func TestIGMPChecksum(t *testing.T) {
- b := []byte{
- 0x11, // IGMP Type, Membership Query
- 0xF0, // Maximum Response Time
- 0xC0, 0xC0, // Checksum
- 0x01, 0x02, 0x03, 0x04, // Group Address
- }
-
- igmpHeader := header.IGMP(b)
-
- // Calculate the initial checksum after setting the checksum temporarily to 0
- // to avoid checksumming the checksum.
- initialChecksum := igmpHeader.Checksum()
- igmpHeader.SetChecksum(0)
- checksum := ^header.Checksum(b, 0)
- igmpHeader.SetChecksum(initialChecksum)
-
- if got := header.IGMPCalculateChecksum(igmpHeader); got != checksum {
- t.Errorf("got IGMPCalculateChecksum = %x, want %x", got, checksum)
- }
-}
-
-func TestDecisecondToDuration(t *testing.T) {
- const valueInDeciseconds = 5
- if got, want := header.DecisecondToDuration(valueInDeciseconds), valueInDeciseconds*time.Second/10; got != want {
- t.Fatalf("got header.DecisecondToDuration(%d) = %s, want = %s", valueInDeciseconds, got, want)
- }
-}
diff --git a/pkg/tcpip/header/ipv4_test.go b/pkg/tcpip/header/ipv4_test.go
deleted file mode 100644
index c02fe898b..000000000
--- a/pkg/tcpip/header/ipv4_test.go
+++ /dev/null
@@ -1,254 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package header_test
-
-import (
- "testing"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
-)
-
-func TestIPv4OptionsSerializer(t *testing.T) {
- optCases := []struct {
- name string
- option []header.IPv4SerializableOption
- expect []byte
- }{
- {
- name: "NOP",
- option: []header.IPv4SerializableOption{
- &header.IPv4SerializableNOPOption{},
- },
- expect: []byte{1, 0, 0, 0},
- },
- {
- name: "ListEnd",
- option: []header.IPv4SerializableOption{
- &header.IPv4SerializableListEndOption{},
- },
- expect: []byte{0, 0, 0, 0},
- },
- {
- name: "RouterAlert",
- option: []header.IPv4SerializableOption{
- &header.IPv4SerializableRouterAlertOption{},
- },
- expect: []byte{148, 4, 0, 0},
- }, {
- name: "NOP and RouterAlert",
- option: []header.IPv4SerializableOption{
- &header.IPv4SerializableNOPOption{},
- &header.IPv4SerializableRouterAlertOption{},
- },
- expect: []byte{1, 148, 4, 0, 0, 0, 0, 0},
- },
- }
-
- for _, opt := range optCases {
- t.Run(opt.name, func(t *testing.T) {
- s := header.IPv4OptionsSerializer(opt.option)
- l := s.Length()
- if got := len(opt.expect); got != int(l) {
- t.Fatalf("s.Length() = %d, want = %d", got, l)
- }
- b := make([]byte, l)
- for i := range b {
- // Fill the buffer with full bytes to ensure padding is being set
- // correctly.
- b[i] = 0xFF
- }
- if serializedLength := s.Serialize(b); serializedLength != l {
- t.Fatalf("s.Serialize(_) = %d, want %d", serializedLength, l)
- }
- if diff := cmp.Diff(opt.expect, b); diff != "" {
- t.Errorf("mismatched serialized option (-want +got):\n%s", diff)
- }
- })
- }
-}
-
-// TestIPv4Encode checks that ipv4.Encode correctly fills out the requested
-// fields when options are supplied.
-func TestIPv4EncodeOptions(t *testing.T) {
- tests := []struct {
- name string
- numberOfNops int
- encodedOptions header.IPv4Options // reply should look like this
- wantIHL int
- }{
- {
- name: "valid no options",
- wantIHL: header.IPv4MinimumSize,
- },
- {
- name: "one byte options",
- numberOfNops: 1,
- encodedOptions: header.IPv4Options{1, 0, 0, 0},
- wantIHL: header.IPv4MinimumSize + 4,
- },
- {
- name: "two byte options",
- numberOfNops: 2,
- encodedOptions: header.IPv4Options{1, 1, 0, 0},
- wantIHL: header.IPv4MinimumSize + 4,
- },
- {
- name: "three byte options",
- numberOfNops: 3,
- encodedOptions: header.IPv4Options{1, 1, 1, 0},
- wantIHL: header.IPv4MinimumSize + 4,
- },
- {
- name: "four byte options",
- numberOfNops: 4,
- encodedOptions: header.IPv4Options{1, 1, 1, 1},
- wantIHL: header.IPv4MinimumSize + 4,
- },
- {
- name: "five byte options",
- numberOfNops: 5,
- encodedOptions: header.IPv4Options{1, 1, 1, 1, 1, 0, 0, 0},
- wantIHL: header.IPv4MinimumSize + 8,
- },
- {
- name: "thirty nine byte options",
- numberOfNops: 39,
- encodedOptions: header.IPv4Options{
- 1, 1, 1, 1, 1, 1, 1, 1,
- 1, 1, 1, 1, 1, 1, 1, 1,
- 1, 1, 1, 1, 1, 1, 1, 1,
- 1, 1, 1, 1, 1, 1, 1, 1,
- 1, 1, 1, 1, 1, 1, 1, 0,
- },
- wantIHL: header.IPv4MinimumSize + 40,
- },
- }
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- serializeOpts := header.IPv4OptionsSerializer(make([]header.IPv4SerializableOption, test.numberOfNops))
- for i := range serializeOpts {
- serializeOpts[i] = &header.IPv4SerializableNOPOption{}
- }
- paddedOptionLength := serializeOpts.Length()
- ipHeaderLength := int(header.IPv4MinimumSize + paddedOptionLength)
- if ipHeaderLength > header.IPv4MaximumHeaderSize {
- t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize)
- }
- totalLen := uint16(ipHeaderLength)
- hdr := buffer.NewPrependable(int(totalLen))
- ip := header.IPv4(hdr.Prepend(ipHeaderLength))
- // To check the padding works, poison the last byte of the options space.
- if paddedOptionLength != serializeOpts.Length() {
- ip.SetHeaderLength(uint8(ipHeaderLength))
- ip.Options()[paddedOptionLength-1] = 0xff
- ip.SetHeaderLength(0)
- }
- ip.Encode(&header.IPv4Fields{
- Options: serializeOpts,
- })
- options := ip.Options()
- wantOptions := test.encodedOptions
- if got, want := int(ip.HeaderLength()), test.wantIHL; got != want {
- t.Errorf("got IHL of %d, want %d", got, want)
- }
-
- // cmp.Diff does not consider nil slices equal to empty slices, but we do.
- if len(wantOptions) == 0 && len(options) == 0 {
- return
- }
-
- if diff := cmp.Diff(wantOptions, options); diff != "" {
- t.Errorf("options mismatch (-want +got):\n%s", diff)
- }
- })
- }
-}
-
-func TestIsV4LinkLocalUnicastAddress(t *testing.T) {
- tests := []struct {
- name string
- addr tcpip.Address
- expected bool
- }{
- {
- name: "Valid (lowest)",
- addr: "\xa9\xfe\x00\x00",
- expected: true,
- },
- {
- name: "Valid (highest)",
- addr: "\xa9\xfe\xff\xff",
- expected: true,
- },
- {
- name: "Invalid (before subnet)",
- addr: "\xa9\xfd\xff\xff",
- expected: false,
- },
- {
- name: "Invalid (after subnet)",
- addr: "\xa9\xff\x00\x00",
- expected: false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- if got := header.IsV4LinkLocalUnicastAddress(test.addr); got != test.expected {
- t.Errorf("got header.IsV4LinkLocalUnicastAddress(%s) = %t, want = %t", test.addr, got, test.expected)
- }
- })
- }
-}
-
-func TestIsV4LinkLocalMulticastAddress(t *testing.T) {
- tests := []struct {
- name string
- addr tcpip.Address
- expected bool
- }{
- {
- name: "Valid (lowest)",
- addr: "\xe0\x00\x00\x00",
- expected: true,
- },
- {
- name: "Valid (highest)",
- addr: "\xe0\x00\x00\xff",
- expected: true,
- },
- {
- name: "Invalid (before subnet)",
- addr: "\xdf\xff\xff\xff",
- expected: false,
- },
- {
- name: "Invalid (after subnet)",
- addr: "\xe0\x00\x01\x00",
- expected: false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- if got := header.IsV4LinkLocalMulticastAddress(test.addr); got != test.expected {
- t.Errorf("got header.IsV4LinkLocalMulticastAddress(%s) = %t, want = %t", test.addr, got, test.expected)
- }
- })
- }
-}
diff --git a/pkg/tcpip/header/ipv6_extension_headers_test.go b/pkg/tcpip/header/ipv6_extension_headers_test.go
deleted file mode 100644
index 65adc6250..000000000
--- a/pkg/tcpip/header/ipv6_extension_headers_test.go
+++ /dev/null
@@ -1,1346 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package header
-
-import (
- "bytes"
- "errors"
- "io"
- "testing"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
-)
-
-// Equal returns true of a and b are equivalent.
-//
-// Note, Equal will return true if a and b hold the same Identifier value and
-// contain the same bytes in Buf, even if the bytes are split across views
-// differently.
-//
-// Needed to use cmp.Equal on IPv6RawPayloadHeader as it contains unexported
-// fields.
-func (a IPv6RawPayloadHeader) Equal(b IPv6RawPayloadHeader) bool {
- return a.Identifier == b.Identifier && bytes.Equal(a.Buf.ToView(), b.Buf.ToView())
-}
-
-// Equal returns true of a and b are equivalent.
-//
-// Note, Equal will return true if a and b hold equivalent ipv6OptionsExtHdrs.
-//
-// Needed to use cmp.Equal on IPv6RawPayloadHeader as it contains unexported
-// fields.
-func (a IPv6HopByHopOptionsExtHdr) Equal(b IPv6HopByHopOptionsExtHdr) bool {
- return bytes.Equal(a.ipv6OptionsExtHdr, b.ipv6OptionsExtHdr)
-}
-
-// Equal returns true of a and b are equivalent.
-//
-// Note, Equal will return true if a and b hold equivalent ipv6OptionsExtHdrs.
-//
-// Needed to use cmp.Equal on IPv6RawPayloadHeader as it contains unexported
-// fields.
-func (a IPv6DestinationOptionsExtHdr) Equal(b IPv6DestinationOptionsExtHdr) bool {
- return bytes.Equal(a.ipv6OptionsExtHdr, b.ipv6OptionsExtHdr)
-}
-
-func TestIPv6UnknownExtHdrOption(t *testing.T) {
- tests := []struct {
- name string
- identifier IPv6ExtHdrOptionIdentifier
- expectedUnknownAction IPv6OptionUnknownAction
- }{
- {
- name: "Skip with zero LSBs",
- identifier: 0,
- expectedUnknownAction: IPv6OptionUnknownActionSkip,
- },
- {
- name: "Discard with zero LSBs",
- identifier: 64,
- expectedUnknownAction: IPv6OptionUnknownActionDiscard,
- },
- {
- name: "Discard and ICMP with zero LSBs",
- identifier: 128,
- expectedUnknownAction: IPv6OptionUnknownActionDiscardSendICMP,
- },
- {
- name: "Discard and ICMP for non multicast destination with zero LSBs",
- identifier: 192,
- expectedUnknownAction: IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest,
- },
- {
- name: "Skip with non-zero LSBs",
- identifier: 63,
- expectedUnknownAction: IPv6OptionUnknownActionSkip,
- },
- {
- name: "Discard with non-zero LSBs",
- identifier: 127,
- expectedUnknownAction: IPv6OptionUnknownActionDiscard,
- },
- {
- name: "Discard and ICMP with non-zero LSBs",
- identifier: 191,
- expectedUnknownAction: IPv6OptionUnknownActionDiscardSendICMP,
- },
- {
- name: "Discard and ICMP for non multicast destination with non-zero LSBs",
- identifier: 255,
- expectedUnknownAction: IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- opt := &IPv6UnknownExtHdrOption{Identifier: test.identifier, Data: []byte{1, 2, 3, 4}}
- if a := opt.UnknownAction(); a != test.expectedUnknownAction {
- t.Fatalf("got UnknownAction() = %d, want = %d", a, test.expectedUnknownAction)
- }
- })
- }
-
-}
-
-func TestIPv6OptionsExtHdrIterErr(t *testing.T) {
- tests := []struct {
- name string
- bytes []byte
- err error
- }{
- {
- name: "Single unknown with zero length",
- bytes: []byte{255, 0},
- },
- {
- name: "Single unknown with non-zero length",
- bytes: []byte{255, 3, 1, 2, 3},
- },
- {
- name: "Two options",
- bytes: []byte{
- 255, 0,
- 254, 1, 1,
- },
- },
- {
- name: "Three options",
- bytes: []byte{
- 255, 0,
- 254, 1, 1,
- 253, 4, 2, 3, 4, 5,
- },
- },
- {
- name: "Single unknown only identifier",
- bytes: []byte{255},
- err: io.ErrUnexpectedEOF,
- },
- {
- name: "Single unknown too small with length = 1",
- bytes: []byte{255, 1},
- err: io.ErrUnexpectedEOF,
- },
- {
- name: "Single unknown too small with length = 2",
- bytes: []byte{255, 2, 1},
- err: io.ErrUnexpectedEOF,
- },
- {
- name: "Valid first with second unknown only identifier",
- bytes: []byte{
- 255, 0,
- 254,
- },
- err: io.ErrUnexpectedEOF,
- },
- {
- name: "Valid first with second unknown missing data",
- bytes: []byte{
- 255, 0,
- 254, 1,
- },
- err: io.ErrUnexpectedEOF,
- },
- {
- name: "Valid first with second unknown too small",
- bytes: []byte{
- 255, 0,
- 254, 2, 1,
- },
- err: io.ErrUnexpectedEOF,
- },
- {
- name: "One Pad1",
- bytes: []byte{0},
- },
- {
- name: "Multiple Pad1",
- bytes: []byte{0, 0, 0},
- },
- {
- name: "Multiple PadN",
- bytes: []byte{
- // Pad3
- 1, 1, 1,
-
- // Pad5
- 1, 3, 1, 2, 3,
- },
- },
- {
- name: "Pad5 too small middle of data buffer",
- bytes: []byte{1, 3, 1, 2},
- err: io.ErrUnexpectedEOF,
- },
- {
- name: "Pad5 no data",
- bytes: []byte{1, 3},
- err: io.ErrUnexpectedEOF,
- },
- {
- name: "Router alert without data",
- bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 0},
- err: ErrMalformedIPv6ExtHdrOption,
- },
- {
- name: "Router alert with partial data",
- bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 1, 1},
- err: ErrMalformedIPv6ExtHdrOption,
- },
- {
- name: "Router alert with partial data and Pad1",
- bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 1, 1, 0},
- err: ErrMalformedIPv6ExtHdrOption,
- },
- {
- name: "Router alert with extra data",
- bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 3, 1, 2, 3},
- err: ErrMalformedIPv6ExtHdrOption,
- },
- {
- name: "Router alert with missing data",
- bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 1},
- err: io.ErrUnexpectedEOF,
- },
- }
-
- check := func(t *testing.T, it IPv6OptionsExtHdrOptionsIterator, expectedErr error) {
- for i := 0; ; i++ {
- _, done, err := it.Next()
- if err != nil {
- // If we encountered a non-nil error while iterating, make sure it is
- // is the same error as expectedErr.
- if !errors.Is(err, expectedErr) {
- t.Fatalf("got %d-th Next() = %v, want = %v", i, err, expectedErr)
- }
-
- return
- }
- if done {
- // If we are done (without an error), make sure that we did not expect
- // an error.
- if expectedErr != nil {
- t.Fatalf("expected error when iterating; want = %s", expectedErr)
- }
-
- return
- }
- }
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- t.Run("Hop By Hop", func(t *testing.T) {
- extHdr := IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: test.bytes}
- check(t, extHdr.Iter(), test.err)
- })
-
- t.Run("Destination", func(t *testing.T) {
- extHdr := IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr: test.bytes}
- check(t, extHdr.Iter(), test.err)
- })
- })
- }
-}
-
-func TestIPv6OptionsExtHdrIter(t *testing.T) {
- tests := []struct {
- name string
- bytes []byte
- expected []IPv6ExtHdrOption
- }{
- {
- name: "Single unknown with zero length",
- bytes: []byte{255, 0},
- expected: []IPv6ExtHdrOption{
- &IPv6UnknownExtHdrOption{Identifier: 255, Data: []byte{}},
- },
- },
- {
- name: "Single unknown with non-zero length",
- bytes: []byte{255, 3, 1, 2, 3},
- expected: []IPv6ExtHdrOption{
- &IPv6UnknownExtHdrOption{Identifier: 255, Data: []byte{1, 2, 3}},
- },
- },
- {
- name: "Single Pad1",
- bytes: []byte{0},
- },
- {
- name: "Two Pad1",
- bytes: []byte{0, 0},
- },
- {
- name: "Single Pad3",
- bytes: []byte{1, 1, 1},
- },
- {
- name: "Single Pad5",
- bytes: []byte{1, 3, 1, 2, 3},
- },
- {
- name: "Multiple Pad",
- bytes: []byte{
- // Pad1
- 0,
-
- // Pad2
- 1, 0,
-
- // Pad3
- 1, 1, 1,
-
- // Pad4
- 1, 2, 1, 2,
-
- // Pad5
- 1, 3, 1, 2, 3,
- },
- },
- {
- name: "Multiple options",
- bytes: []byte{
- // Pad1
- 0,
-
- // Unknown
- 255, 0,
-
- // Pad2
- 1, 0,
-
- // Unknown
- 254, 1, 1,
-
- // Pad3
- 1, 1, 1,
-
- // Unknown
- 253, 4, 2, 3, 4, 5,
-
- // Pad4
- 1, 2, 1, 2,
- },
- expected: []IPv6ExtHdrOption{
- &IPv6UnknownExtHdrOption{Identifier: 255, Data: []byte{}},
- &IPv6UnknownExtHdrOption{Identifier: 254, Data: []byte{1}},
- &IPv6UnknownExtHdrOption{Identifier: 253, Data: []byte{2, 3, 4, 5}},
- },
- },
- }
-
- checkIter := func(t *testing.T, it IPv6OptionsExtHdrOptionsIterator, expected []IPv6ExtHdrOption) {
- for i, e := range expected {
- opt, done, err := it.Next()
- if err != nil {
- t.Errorf("(i=%d) Next(): %s", i, err)
- }
- if done {
- t.Errorf("(i=%d) unexpectedly done iterating", i)
- }
- if diff := cmp.Diff(e, opt); diff != "" {
- t.Errorf("(i=%d) got option mismatch (-want +got):\n%s", i, diff)
- }
-
- if t.Failed() {
- t.FailNow()
- }
- }
-
- opt, done, err := it.Next()
- if err != nil {
- t.Errorf("(last) Next(): %s", err)
- }
- if !done {
- t.Errorf("(last) iterator unexpectedly not done")
- }
- if opt != nil {
- t.Errorf("(last) got Next() = %T, want = nil", opt)
- }
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- t.Run("Hop By Hop", func(t *testing.T) {
- extHdr := IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: test.bytes}
- checkIter(t, extHdr.Iter(), test.expected)
- })
-
- t.Run("Destination", func(t *testing.T) {
- extHdr := IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr: test.bytes}
- checkIter(t, extHdr.Iter(), test.expected)
- })
- })
- }
-}
-
-func TestIPv6RoutingExtHdr(t *testing.T) {
- tests := []struct {
- name string
- bytes []byte
- segmentsLeft uint8
- }{
- {
- name: "Zeroes",
- bytes: []byte{0, 0, 0, 0, 0, 0},
- segmentsLeft: 0,
- },
- {
- name: "Ones",
- bytes: []byte{1, 1, 1, 1, 1, 1},
- segmentsLeft: 1,
- },
- {
- name: "Mixed",
- bytes: []byte{1, 2, 3, 4, 5, 6},
- segmentsLeft: 2,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- extHdr := IPv6RoutingExtHdr(test.bytes)
- if got := extHdr.SegmentsLeft(); got != test.segmentsLeft {
- t.Errorf("got SegmentsLeft() = %d, want = %d", got, test.segmentsLeft)
- }
- })
- }
-}
-
-func TestIPv6FragmentExtHdr(t *testing.T) {
- tests := []struct {
- name string
- bytes [6]byte
- fragmentOffset uint16
- more bool
- id uint32
- }{
- {
- name: "Zeroes",
- bytes: [6]byte{0, 0, 0, 0, 0, 0},
- fragmentOffset: 0,
- more: false,
- id: 0,
- },
- {
- name: "Ones",
- bytes: [6]byte{0, 9, 0, 0, 0, 1},
- fragmentOffset: 1,
- more: true,
- id: 1,
- },
- {
- name: "Mixed",
- bytes: [6]byte{68, 9, 128, 4, 2, 1},
- fragmentOffset: 2177,
- more: true,
- id: 2147746305,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- extHdr := IPv6FragmentExtHdr(test.bytes)
- if got := extHdr.FragmentOffset(); got != test.fragmentOffset {
- t.Errorf("got FragmentOffset() = %d, want = %d", got, test.fragmentOffset)
- }
- if got := extHdr.More(); got != test.more {
- t.Errorf("got More() = %t, want = %t", got, test.more)
- }
- if got := extHdr.ID(); got != test.id {
- t.Errorf("got ID() = %d, want = %d", got, test.id)
- }
- })
- }
-}
-
-func makeVectorisedViewFromByteBuffers(bs ...[]byte) buffer.VectorisedView {
- size := 0
- var vs []buffer.View
-
- for _, b := range bs {
- vs = append(vs, buffer.View(b))
- size += len(b)
- }
-
- return buffer.NewVectorisedView(size, vs)
-}
-
-func TestIPv6ExtHdrIterErr(t *testing.T) {
- tests := []struct {
- name string
- firstNextHdr IPv6ExtensionHeaderIdentifier
- payload buffer.VectorisedView
- err error
- }{
- {
- name: "Upper layer only without data",
- firstNextHdr: 255,
- },
- {
- name: "Upper layer only with data",
- firstNextHdr: 255,
- payload: makeVectorisedViewFromByteBuffers([]byte{1, 2, 3, 4}),
- },
- {
- name: "No next header",
- firstNextHdr: IPv6NoNextHeaderIdentifier,
- },
- {
- name: "No next header with data",
- firstNextHdr: IPv6NoNextHeaderIdentifier,
- payload: makeVectorisedViewFromByteBuffers([]byte{1, 2, 3, 4}),
- },
- {
- name: "Valid single hop by hop",
- firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
- payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 4, 1, 2, 3, 4}),
- },
- {
- name: "Hop by hop too small",
- firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
- payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 4, 1, 2, 3}),
- err: io.ErrUnexpectedEOF,
- },
- {
- name: "Valid single fragment",
- firstNextHdr: IPv6FragmentExtHdrIdentifier,
- payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 68, 9, 128, 4, 2, 1}),
- },
- {
- name: "Fragment too small",
- firstNextHdr: IPv6FragmentExtHdrIdentifier,
- payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 68, 9, 128, 4, 2}),
- err: io.ErrUnexpectedEOF,
- },
- {
- name: "Valid single destination",
- firstNextHdr: IPv6DestinationOptionsExtHdrIdentifier,
- payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 4, 1, 2, 3, 4}),
- },
- {
- name: "Destination too small",
- firstNextHdr: IPv6DestinationOptionsExtHdrIdentifier,
- payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 4, 1, 2, 3}),
- err: io.ErrUnexpectedEOF,
- },
- {
- name: "Valid single routing",
- firstNextHdr: IPv6RoutingExtHdrIdentifier,
- payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 2, 3, 4, 5, 6}),
- },
- {
- name: "Valid single routing across views",
- firstNextHdr: IPv6RoutingExtHdrIdentifier,
- payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 2}, []byte{3, 4, 5, 6}),
- },
- {
- name: "Routing too small with zero length field",
- firstNextHdr: IPv6RoutingExtHdrIdentifier,
- payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 2, 3, 4, 5}),
- err: io.ErrUnexpectedEOF,
- },
- {
- name: "Valid routing with non-zero length field",
- firstNextHdr: IPv6RoutingExtHdrIdentifier,
- payload: makeVectorisedViewFromByteBuffers([]byte{255, 1, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8}),
- },
- {
- name: "Valid routing with non-zero length field across views",
- firstNextHdr: IPv6RoutingExtHdrIdentifier,
- payload: makeVectorisedViewFromByteBuffers([]byte{255, 1, 1, 2, 3, 4, 5, 6}, []byte{1, 2, 3, 4, 5, 6, 7, 8}),
- },
- {
- name: "Routing too small with non-zero length field",
- firstNextHdr: IPv6RoutingExtHdrIdentifier,
- payload: makeVectorisedViewFromByteBuffers([]byte{255, 1, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7}),
- err: io.ErrUnexpectedEOF,
- },
- {
- name: "Routing too small with non-zero length field across views",
- firstNextHdr: IPv6RoutingExtHdrIdentifier,
- payload: makeVectorisedViewFromByteBuffers([]byte{255, 1, 1, 2, 3, 4, 5, 6}, []byte{1, 2, 3, 4, 5, 6, 7}),
- err: io.ErrUnexpectedEOF,
- },
- {
- name: "Mixed",
- firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
- payload: makeVectorisedViewFromByteBuffers([]byte{
- // Hop By Hop Options extension header.
- uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4,
-
- // (Atomic) Fragment extension header.
- //
- // Reserved bits are 1 which should not affect anything.
- uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6, 128, 4, 2, 1,
-
- // Routing extension header.
- uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6,
-
- // Destination Options extension header.
- 255, 0, 255, 4, 1, 2, 3, 4,
-
- // Upper layer data.
- 1, 2, 3, 4,
- }),
- },
- {
- name: "Mixed without upper layer data",
- firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
- payload: makeVectorisedViewFromByteBuffers([]byte{
- // Hop By Hop Options extension header.
- uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4,
-
- // (Atomic) Fragment extension header.
- //
- // Reserved bits are 1 which should not affect anything.
- uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6, 128, 4, 2, 1,
-
- // Routing extension header.
- uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6,
-
- // Destination Options extension header.
- 255, 0, 255, 4, 1, 2, 3, 4,
- }),
- },
- {
- name: "Mixed without upper layer data but last ext hdr too small",
- firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
- payload: makeVectorisedViewFromByteBuffers([]byte{
- // Hop By Hop Options extension header.
- uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4,
-
- // (Atomic) Fragment extension header.
- //
- // Reserved bits are 1 which should not affect anything.
- uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6, 128, 4, 2, 1,
-
- // Routing extension header.
- uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6,
-
- // Destination Options extension header.
- 255, 0, 255, 4, 1, 2, 3,
- }),
- err: io.ErrUnexpectedEOF,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- it := MakeIPv6PayloadIterator(test.firstNextHdr, test.payload)
-
- for i := 0; ; i++ {
- _, done, err := it.Next()
- if err != nil {
- // If we encountered a non-nil error while iterating, make sure it is
- // is the same error as test.err.
- if !errors.Is(err, test.err) {
- t.Fatalf("got %d-th Next() = %v, want = %v", i, err, test.err)
- }
-
- return
- }
- if done {
- // If we are done (without an error), make sure that we did not expect
- // an error.
- if test.err != nil {
- t.Fatalf("expected error when iterating; want = %s", test.err)
- }
-
- return
- }
- }
- })
- }
-}
-
-func TestIPv6ExtHdrIter(t *testing.T) {
- routingExtHdrWithUpperLayerData := buffer.View([]byte{255, 0, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4})
- upperLayerData := buffer.View([]byte{1, 2, 3, 4})
- tests := []struct {
- name string
- firstNextHdr IPv6ExtensionHeaderIdentifier
- payload buffer.VectorisedView
- expected []IPv6PayloadHeader
- }{
- // With a non-atomic fragment that is not the first fragment, the payload
- // after the fragment will not be parsed because the payload is expected to
- // only hold upper layer data.
- {
- name: "hopbyhop - fragment (not first) - routing - upper",
- firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
- payload: makeVectorisedViewFromByteBuffers([]byte{
- // Hop By Hop extension header.
- uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4,
-
- // Fragment extension header.
- //
- // More = 1, Fragment Offset = 2117, ID = 2147746305
- uint8(IPv6RoutingExtHdrIdentifier), 0, 68, 9, 128, 4, 2, 1,
-
- // Routing extension header.
- //
- // Even though we have a routing ext header here, it should be
- // be interpretted as raw bytes as only the first fragment is expected
- // to hold headers.
- 255, 0, 1, 2, 3, 4, 5, 6,
-
- // Upper layer data.
- 1, 2, 3, 4,
- }),
- expected: []IPv6PayloadHeader{
- IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}},
- IPv6FragmentExtHdr([6]byte{68, 9, 128, 4, 2, 1}),
- IPv6RawPayloadHeader{
- Identifier: IPv6RoutingExtHdrIdentifier,
- Buf: routingExtHdrWithUpperLayerData.ToVectorisedView(),
- },
- },
- },
- {
- name: "hopbyhop - fragment (first) - routing - upper",
- firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
- payload: makeVectorisedViewFromByteBuffers([]byte{
- // Hop By Hop extension header.
- uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4,
-
- // Fragment extension header.
- //
- // More = 1, Fragment Offset = 0, ID = 2147746305
- uint8(IPv6RoutingExtHdrIdentifier), 0, 0, 1, 128, 4, 2, 1,
-
- // Routing extension header.
- 255, 0, 1, 2, 3, 4, 5, 6,
-
- // Upper layer data.
- 1, 2, 3, 4,
- }),
- expected: []IPv6PayloadHeader{
- IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}},
- IPv6FragmentExtHdr([6]byte{0, 1, 128, 4, 2, 1}),
- IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}),
- IPv6RawPayloadHeader{
- Identifier: 255,
- Buf: upperLayerData.ToVectorisedView(),
- },
- },
- },
- {
- name: "fragment - routing - upper (across views)",
- firstNextHdr: IPv6FragmentExtHdrIdentifier,
- payload: makeVectorisedViewFromByteBuffers([]byte{
- // Fragment extension header.
- uint8(IPv6RoutingExtHdrIdentifier), 0, 68, 9, 128, 4, 2, 1,
-
- // Routing extension header.
- 255, 0, 1, 2}, []byte{3, 4, 5, 6,
-
- // Upper layer data.
- 1, 2, 3, 4,
- }),
- expected: []IPv6PayloadHeader{
- IPv6FragmentExtHdr([6]byte{68, 9, 128, 4, 2, 1}),
- IPv6RawPayloadHeader{
- Identifier: IPv6RoutingExtHdrIdentifier,
- Buf: routingExtHdrWithUpperLayerData.ToVectorisedView(),
- },
- },
- },
-
- // If we have an atomic fragment, the payload following the fragment
- // extension header should be parsed normally.
- {
- name: "atomic fragment - routing - destination - upper",
- firstNextHdr: IPv6FragmentExtHdrIdentifier,
- payload: makeVectorisedViewFromByteBuffers([]byte{
- // Fragment extension header.
- //
- // Reserved bits are 1 which should not affect anything.
- uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6, 128, 4, 2, 1,
-
- // Routing extension header.
- uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6,
-
- // Destination Options extension header.
- 255, 0, 1, 4, 1, 2, 3, 4,
-
- // Upper layer data.
- 1, 2, 3, 4,
- }),
- expected: []IPv6PayloadHeader{
- IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}),
- IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}),
- IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}},
- IPv6RawPayloadHeader{
- Identifier: 255,
- Buf: upperLayerData.ToVectorisedView(),
- },
- },
- },
- {
- name: "atomic fragment - routing - upper (across views)",
- firstNextHdr: IPv6FragmentExtHdrIdentifier,
- payload: makeVectorisedViewFromByteBuffers([]byte{
- // Fragment extension header.
- //
- // Reserved bits are 1 which should not affect anything.
- uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6}, []byte{128, 4, 2, 1,
-
- // Routing extension header.
- 255, 0, 1, 2}, []byte{3, 4, 5, 6,
-
- // Upper layer data.
- 1, 2}, []byte{3, 4}),
- expected: []IPv6PayloadHeader{
- IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}),
- IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}),
- IPv6RawPayloadHeader{
- Identifier: 255,
- Buf: makeVectorisedViewFromByteBuffers(upperLayerData[:2], upperLayerData[2:]),
- },
- },
- },
- {
- name: "atomic fragment - destination - no next header",
- firstNextHdr: IPv6FragmentExtHdrIdentifier,
- payload: makeVectorisedViewFromByteBuffers([]byte{
- // Fragment extension header.
- //
- // Res (Reserved) bits are 1 which should not affect anything.
- uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 0, 6, 128, 4, 2, 1,
-
- // Destination Options extension header.
- uint8(IPv6NoNextHeaderIdentifier), 0, 1, 4, 1, 2, 3, 4,
-
- // Random data.
- 1, 2, 3, 4,
- }),
- expected: []IPv6PayloadHeader{
- IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}),
- IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}},
- },
- },
- {
- name: "routing - atomic fragment - no next header",
- firstNextHdr: IPv6RoutingExtHdrIdentifier,
- payload: makeVectorisedViewFromByteBuffers([]byte{
- // Routing extension header.
- uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6,
-
- // Fragment extension header.
- //
- // Reserved bits are 1 which should not affect anything.
- uint8(IPv6NoNextHeaderIdentifier), 0, 0, 6, 128, 4, 2, 1,
-
- // Random data.
- 1, 2, 3, 4,
- }),
- expected: []IPv6PayloadHeader{
- IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}),
- IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}),
- },
- },
- {
- name: "routing - atomic fragment - no next header (across views)",
- firstNextHdr: IPv6RoutingExtHdrIdentifier,
- payload: makeVectorisedViewFromByteBuffers([]byte{
- // Routing extension header.
- uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6,
-
- // Fragment extension header.
- //
- // Reserved bits are 1 which should not affect anything.
- uint8(IPv6NoNextHeaderIdentifier), 255, 0, 6}, []byte{128, 4, 2, 1,
-
- // Random data.
- 1, 2, 3, 4,
- }),
- expected: []IPv6PayloadHeader{
- IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}),
- IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}),
- },
- },
- {
- name: "hopbyhop - routing - fragment - no next header",
- firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
- payload: makeVectorisedViewFromByteBuffers([]byte{
- // Hop By Hop Options extension header.
- uint8(IPv6RoutingExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4,
-
- // Routing extension header.
- uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6,
-
- // Fragment extension header.
- //
- // Fragment Offset = 32; Res = 6.
- uint8(IPv6NoNextHeaderIdentifier), 0, 1, 6, 128, 4, 2, 1,
-
- // Random data.
- 1, 2, 3, 4,
- }),
- expected: []IPv6PayloadHeader{
- IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}},
- IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}),
- IPv6FragmentExtHdr([6]byte{1, 6, 128, 4, 2, 1}),
- IPv6RawPayloadHeader{
- Identifier: IPv6NoNextHeaderIdentifier,
- Buf: upperLayerData.ToVectorisedView(),
- },
- },
- },
-
- // Test the raw payload for common transport layer protocol numbers.
- {
- name: "TCP raw payload",
- firstNextHdr: IPv6ExtensionHeaderIdentifier(TCPProtocolNumber),
- payload: makeVectorisedViewFromByteBuffers(upperLayerData),
- expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{
- Identifier: IPv6ExtensionHeaderIdentifier(TCPProtocolNumber),
- Buf: upperLayerData.ToVectorisedView(),
- }},
- },
- {
- name: "UDP raw payload",
- firstNextHdr: IPv6ExtensionHeaderIdentifier(UDPProtocolNumber),
- payload: makeVectorisedViewFromByteBuffers(upperLayerData),
- expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{
- Identifier: IPv6ExtensionHeaderIdentifier(UDPProtocolNumber),
- Buf: upperLayerData.ToVectorisedView(),
- }},
- },
- {
- name: "ICMPv4 raw payload",
- firstNextHdr: IPv6ExtensionHeaderIdentifier(ICMPv4ProtocolNumber),
- payload: makeVectorisedViewFromByteBuffers(upperLayerData),
- expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{
- Identifier: IPv6ExtensionHeaderIdentifier(ICMPv4ProtocolNumber),
- Buf: upperLayerData.ToVectorisedView(),
- }},
- },
- {
- name: "ICMPv6 raw payload",
- firstNextHdr: IPv6ExtensionHeaderIdentifier(ICMPv6ProtocolNumber),
- payload: makeVectorisedViewFromByteBuffers(upperLayerData),
- expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{
- Identifier: IPv6ExtensionHeaderIdentifier(ICMPv6ProtocolNumber),
- Buf: upperLayerData.ToVectorisedView(),
- }},
- },
- {
- name: "Unknwon next header raw payload",
- firstNextHdr: 255,
- payload: makeVectorisedViewFromByteBuffers(upperLayerData),
- expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{
- Identifier: 255,
- Buf: upperLayerData.ToVectorisedView(),
- }},
- },
- {
- name: "Unknwon next header raw payload (across views)",
- firstNextHdr: 255,
- payload: makeVectorisedViewFromByteBuffers(upperLayerData[:2], upperLayerData[2:]),
- expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{
- Identifier: 255,
- Buf: makeVectorisedViewFromByteBuffers(upperLayerData[:2], upperLayerData[2:]),
- }},
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- it := MakeIPv6PayloadIterator(test.firstNextHdr, test.payload)
-
- for i, e := range test.expected {
- extHdr, done, err := it.Next()
- if err != nil {
- t.Errorf("(i=%d) Next(): %s", i, err)
- }
- if done {
- t.Errorf("(i=%d) unexpectedly done iterating", i)
- }
- if diff := cmp.Diff(e, extHdr); diff != "" {
- t.Errorf("(i=%d) got ext hdr mismatch (-want +got):\n%s", i, diff)
- }
-
- if t.Failed() {
- t.FailNow()
- }
- }
-
- extHdr, done, err := it.Next()
- if err != nil {
- t.Errorf("(last) Next(): %s", err)
- }
- if !done {
- t.Errorf("(last) iterator unexpectedly not done")
- }
- if extHdr != nil {
- t.Errorf("(last) got Next() = %T, want = nil", extHdr)
- }
- })
- }
-}
-
-var _ IPv6SerializableHopByHopOption = (*dummyHbHOptionSerializer)(nil)
-
-// dummyHbHOptionSerializer provides a generic implementation of
-// IPv6SerializableHopByHopOption for use in tests.
-type dummyHbHOptionSerializer struct {
- id IPv6ExtHdrOptionIdentifier
- payload []byte
- align int
- alignOffset int
-}
-
-// identifier implements IPv6SerializableHopByHopOption.
-func (s *dummyHbHOptionSerializer) identifier() IPv6ExtHdrOptionIdentifier {
- return s.id
-}
-
-// length implements IPv6SerializableHopByHopOption.
-func (s *dummyHbHOptionSerializer) length() uint8 {
- return uint8(len(s.payload))
-}
-
-// alignment implements IPv6SerializableHopByHopOption.
-func (s *dummyHbHOptionSerializer) alignment() (int, int) {
- align := 1
- if s.align != 0 {
- align = s.align
- }
- return align, s.alignOffset
-}
-
-// serializeInto implements IPv6SerializableHopByHopOption.
-func (s *dummyHbHOptionSerializer) serializeInto(b []byte) uint8 {
- return uint8(copy(b, s.payload))
-}
-
-func TestIPv6HopByHopSerializer(t *testing.T) {
- validateDummies := func(t *testing.T, serializable IPv6SerializableHopByHopOption, deserialized IPv6ExtHdrOption) {
- t.Helper()
- dummy, ok := serializable.(*dummyHbHOptionSerializer)
- if !ok {
- t.Fatalf("got serializable = %T, want = *dummyHbHOptionSerializer", serializable)
- }
- unknown, ok := deserialized.(*IPv6UnknownExtHdrOption)
- if !ok {
- t.Fatalf("got deserialized = %T, want = %T", deserialized, &IPv6UnknownExtHdrOption{})
- }
- if dummy.id != unknown.Identifier {
- t.Errorf("got deserialized identifier = %d, want = %d", unknown.Identifier, dummy.id)
- }
- if diff := cmp.Diff(dummy.payload, unknown.Data); diff != "" {
- t.Errorf("option payload deserialization mismatch (-want +got):\n%s", diff)
- }
- }
- tests := []struct {
- name string
- nextHeader uint8
- options []IPv6SerializableHopByHopOption
- expect []byte
- validate func(*testing.T, IPv6SerializableHopByHopOption, IPv6ExtHdrOption)
- }{
- {
- name: "single option",
- nextHeader: 13,
- options: []IPv6SerializableHopByHopOption{
- &dummyHbHOptionSerializer{
- id: 15,
- payload: []byte{9, 8, 7, 6},
- },
- },
- expect: []byte{13, 0, 15, 4, 9, 8, 7, 6},
- validate: validateDummies,
- },
- {
- name: "short option padN zero",
- nextHeader: 88,
- options: []IPv6SerializableHopByHopOption{
- &dummyHbHOptionSerializer{
- id: 22,
- payload: []byte{4, 5},
- },
- },
- expect: []byte{88, 0, 22, 2, 4, 5, 1, 0},
- validate: validateDummies,
- },
- {
- name: "short option pad1",
- nextHeader: 11,
- options: []IPv6SerializableHopByHopOption{
- &dummyHbHOptionSerializer{
- id: 33,
- payload: []byte{1, 2, 3},
- },
- },
- expect: []byte{11, 0, 33, 3, 1, 2, 3, 0},
- validate: validateDummies,
- },
- {
- name: "long option padN",
- nextHeader: 55,
- options: []IPv6SerializableHopByHopOption{
- &dummyHbHOptionSerializer{
- id: 77,
- payload: []byte{1, 2, 3, 4, 5, 6, 7, 8},
- },
- },
- expect: []byte{55, 1, 77, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 0, 0},
- validate: validateDummies,
- },
- {
- name: "two options",
- nextHeader: 33,
- options: []IPv6SerializableHopByHopOption{
- &dummyHbHOptionSerializer{
- id: 11,
- payload: []byte{1, 2, 3},
- },
- &dummyHbHOptionSerializer{
- id: 22,
- payload: []byte{4, 5, 6},
- },
- },
- expect: []byte{33, 1, 11, 3, 1, 2, 3, 22, 3, 4, 5, 6, 1, 2, 0, 0},
- validate: validateDummies,
- },
- {
- name: "two options align 2n",
- nextHeader: 33,
- options: []IPv6SerializableHopByHopOption{
- &dummyHbHOptionSerializer{
- id: 11,
- payload: []byte{1, 2, 3},
- },
- &dummyHbHOptionSerializer{
- id: 22,
- payload: []byte{4, 5, 6},
- align: 2,
- },
- },
- expect: []byte{33, 1, 11, 3, 1, 2, 3, 0, 22, 3, 4, 5, 6, 1, 1, 0},
- validate: validateDummies,
- },
- {
- name: "two options align 8n+1",
- nextHeader: 33,
- options: []IPv6SerializableHopByHopOption{
- &dummyHbHOptionSerializer{
- id: 11,
- payload: []byte{1, 2},
- },
- &dummyHbHOptionSerializer{
- id: 22,
- payload: []byte{4, 5, 6},
- align: 8,
- alignOffset: 1,
- },
- },
- expect: []byte{33, 1, 11, 2, 1, 2, 1, 1, 0, 22, 3, 4, 5, 6, 1, 0},
- validate: validateDummies,
- },
- {
- name: "no options",
- nextHeader: 33,
- options: []IPv6SerializableHopByHopOption{},
- expect: []byte{33, 0, 1, 4, 0, 0, 0, 0},
- },
- {
- name: "Router Alert",
- nextHeader: 33,
- options: []IPv6SerializableHopByHopOption{&IPv6RouterAlertOption{Value: IPv6RouterAlertMLD}},
- expect: []byte{33, 0, 5, 2, 0, 0, 1, 0},
- validate: func(t *testing.T, _ IPv6SerializableHopByHopOption, deserialized IPv6ExtHdrOption) {
- t.Helper()
- routerAlert, ok := deserialized.(*IPv6RouterAlertOption)
- if !ok {
- t.Fatalf("got deserialized = %T, want = *IPv6RouterAlertOption", deserialized)
- }
- if routerAlert.Value != IPv6RouterAlertMLD {
- t.Errorf("got routerAlert.Value = %d, want = %d", routerAlert.Value, IPv6RouterAlertMLD)
- }
- },
- },
- }
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := IPv6SerializableHopByHopExtHdr(test.options)
- length := s.length()
- if length != len(test.expect) {
- t.Fatalf("got s.length() = %d, want = %d", length, len(test.expect))
- }
- b := make([]byte, length)
- for i := range b {
- // Fill the buffer with ones to ensure all padding is correctly set.
- b[i] = 0xFF
- }
- if got := s.serializeInto(test.nextHeader, b); got != length {
- t.Fatalf("got s.serializeInto(..) = %d, want = %d", got, length)
- }
- if diff := cmp.Diff(test.expect, b); diff != "" {
- t.Fatalf("serialization mismatch (-want +got):\n%s", diff)
- }
-
- // Deserialize the options and verify them.
- optLen := (b[ipv6HopByHopExtHdrLengthOffset] + ipv6HopByHopExtHdrUnaccountedLenWords) * ipv6ExtHdrLenBytesPerUnit
- iter := ipv6OptionsExtHdr(b[ipv6HopByHopExtHdrOptionsOffset:optLen]).Iter()
- for _, testOpt := range test.options {
- opt, done, err := iter.Next()
- if err != nil {
- t.Fatalf("iter.Next(): %s", err)
- }
- if done {
- t.Fatalf("got iter.Next() = (%T, %t, _), want = (_, false, _)", opt, done)
- }
- test.validate(t, testOpt, opt)
- }
- opt, done, err := iter.Next()
- if err != nil {
- t.Fatalf("iter.Next(): %s", err)
- }
- if !done {
- t.Fatalf("got iter.Next() = (%T, %t, _), want = (_, true, _)", opt, done)
- }
- })
- }
-}
-
-var _ IPv6SerializableExtHdr = (*dummyIPv6ExtHdrSerializer)(nil)
-
-// dummyIPv6ExtHdrSerializer provides a generic implementation of
-// IPv6SerializableExtHdr for use in tests.
-//
-// The dummy header always carries the nextHeader value in the first byte.
-type dummyIPv6ExtHdrSerializer struct {
- id IPv6ExtensionHeaderIdentifier
- headerContents []byte
-}
-
-// identifier implements IPv6SerializableExtHdr.
-func (s *dummyIPv6ExtHdrSerializer) identifier() IPv6ExtensionHeaderIdentifier {
- return s.id
-}
-
-// length implements IPv6SerializableExtHdr.
-func (s *dummyIPv6ExtHdrSerializer) length() int {
- return len(s.headerContents) + 1
-}
-
-// serializeInto implements IPv6SerializableExtHdr.
-func (s *dummyIPv6ExtHdrSerializer) serializeInto(nextHeader uint8, b []byte) int {
- b[0] = nextHeader
- return copy(b[1:], s.headerContents) + 1
-}
-
-func TestIPv6ExtHdrSerializer(t *testing.T) {
- tests := []struct {
- name string
- headers []IPv6SerializableExtHdr
- nextHeader tcpip.TransportProtocolNumber
- expectSerialized []byte
- expectNextHeader uint8
- }{
- {
- name: "one header",
- headers: []IPv6SerializableExtHdr{
- &dummyIPv6ExtHdrSerializer{
- id: 15,
- headerContents: []byte{1, 2, 3, 4},
- },
- },
- nextHeader: TCPProtocolNumber,
- expectSerialized: []byte{byte(TCPProtocolNumber), 1, 2, 3, 4},
- expectNextHeader: 15,
- },
- {
- name: "two headers",
- headers: []IPv6SerializableExtHdr{
- &dummyIPv6ExtHdrSerializer{
- id: 22,
- headerContents: []byte{1, 2, 3},
- },
- &dummyIPv6ExtHdrSerializer{
- id: 23,
- headerContents: []byte{4, 5, 6},
- },
- },
- nextHeader: ICMPv6ProtocolNumber,
- expectSerialized: []byte{
- 23, 1, 2, 3,
- byte(ICMPv6ProtocolNumber), 4, 5, 6,
- },
- expectNextHeader: 22,
- },
- {
- name: "no headers",
- headers: []IPv6SerializableExtHdr{},
- nextHeader: UDPProtocolNumber,
- expectSerialized: []byte{},
- expectNextHeader: byte(UDPProtocolNumber),
- },
- }
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := IPv6ExtHdrSerializer(test.headers)
- l := s.Length()
- if got, want := l, len(test.expectSerialized); got != want {
- t.Fatalf("got serialized length = %d, want = %d", got, want)
- }
- b := make([]byte, l)
- for i := range b {
- // Fill the buffer with garbage to make sure we're writing to all bytes.
- b[i] = 0xFF
- }
- nextHeader, serializedLen := s.Serialize(test.nextHeader, b)
- if serializedLen != len(test.expectSerialized) || nextHeader != test.expectNextHeader {
- t.Errorf(
- "got s.Serialize(..) = (%d, %d), want = (%d, %d)",
- nextHeader,
- serializedLen,
- test.expectNextHeader,
- len(test.expectSerialized),
- )
- }
- if diff := cmp.Diff(test.expectSerialized, b); diff != "" {
- t.Errorf("serialization mismatch (-want +got):\n%s", diff)
- }
- })
- }
-}
diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go
deleted file mode 100644
index 89be84068..000000000
--- a/pkg/tcpip/header/ipv6_test.go
+++ /dev/null
@@ -1,457 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package header_test
-
-import (
- "bytes"
- "crypto/sha256"
- "fmt"
- "testing"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/rand"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
-)
-
-const linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
-
-var (
- linkLocalAddr = testutil.MustParse6("fe80::1")
- linkLocalMulticastAddr = testutil.MustParse6("ff02::1")
- uniqueLocalAddr1 = testutil.MustParse6("fc00::1")
- uniqueLocalAddr2 = testutil.MustParse6("fd00::2")
- globalAddr = testutil.MustParse6("a000::1")
-)
-
-func TestEthernetAdddressToModifiedEUI64(t *testing.T) {
- expectedIID := [header.IIDSize]byte{0, 2, 3, 255, 254, 4, 5, 6}
-
- if diff := cmp.Diff(expectedIID, header.EthernetAddressToModifiedEUI64(linkAddr)); diff != "" {
- t.Errorf("EthernetAddressToModifiedEUI64(%s) mismatch (-want +got):\n%s", linkAddr, diff)
- }
-
- var buf [header.IIDSize]byte
- header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, buf[:])
- if diff := cmp.Diff(expectedIID, buf); diff != "" {
- t.Errorf("EthernetAddressToModifiedEUI64IntoBuf(%s, _) mismatch (-want +got):\n%s", linkAddr, diff)
- }
-}
-
-func TestLinkLocalAddr(t *testing.T) {
- if got, want := header.LinkLocalAddr(linkAddr), testutil.MustParse6("fe80::2:3ff:fe04:506"); got != want {
- t.Errorf("got LinkLocalAddr(%s) = %s, want = %s", linkAddr, got, want)
- }
-}
-
-func TestAppendOpaqueInterfaceIdentifier(t *testing.T) {
- var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes * 2]byte
- if n, err := rand.Read(secretKeyBuf[:]); err != nil {
- t.Fatalf("rand.Read(_): %s", err)
- } else if want := header.OpaqueIIDSecretKeyMinBytes * 2; n != want {
- t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", want, n)
- }
-
- tests := []struct {
- name string
- prefix tcpip.Subnet
- nicName string
- dadCounter uint8
- secretKey []byte
- }{
- {
- name: "SecretKey of minimum size",
- prefix: header.IPv6LinkLocalPrefix.Subnet(),
- nicName: "eth0",
- dadCounter: 0,
- secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes],
- },
- {
- name: "SecretKey of less than minimum size",
- prefix: func() tcpip.Subnet {
- addrWithPrefix := tcpip.AddressWithPrefix{
- Address: "\x01\x02\x03\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
- PrefixLen: header.IIDOffsetInIPv6Address * 8,
- }
- return addrWithPrefix.Subnet()
- }(),
- nicName: "eth10",
- dadCounter: 1,
- secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes/2],
- },
- {
- name: "SecretKey of more than minimum size",
- prefix: func() tcpip.Subnet {
- addrWithPrefix := tcpip.AddressWithPrefix{
- Address: "\x01\x02\x03\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
- PrefixLen: header.IIDOffsetInIPv6Address * 8,
- }
- return addrWithPrefix.Subnet()
- }(),
- nicName: "eth11",
- dadCounter: 2,
- secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes*2],
- },
- {
- name: "Nil SecretKey and empty nicName",
- prefix: func() tcpip.Subnet {
- addrWithPrefix := tcpip.AddressWithPrefix{
- Address: "\x01\x02\x03\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
- PrefixLen: header.IIDOffsetInIPv6Address * 8,
- }
- return addrWithPrefix.Subnet()
- }(),
- nicName: "",
- dadCounter: 3,
- secretKey: nil,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- h := sha256.New()
- h.Write([]byte(test.prefix.ID()[:header.IIDOffsetInIPv6Address]))
- h.Write([]byte(test.nicName))
- h.Write([]byte{test.dadCounter})
- if k := test.secretKey; k != nil {
- h.Write(k)
- }
- var hashSum [sha256.Size]byte
- h.Sum(hashSum[:0])
- want := hashSum[:header.IIDSize]
-
- // Passing a nil buffer should result in a new buffer returned with the
- // IID.
- if got := header.AppendOpaqueInterfaceIdentifier(nil, test.prefix, test.nicName, test.dadCounter, test.secretKey); !bytes.Equal(got, want) {
- t.Errorf("got AppendOpaqueInterfaceIdentifier(nil, %s, %s, %d, %x) = %x, want = %x", test.prefix, test.nicName, test.dadCounter, test.secretKey, got, want)
- }
-
- // Passing a buffer with sufficient capacity for the IID should populate
- // the buffer provided.
- var iidBuf [header.IIDSize]byte
- if got := header.AppendOpaqueInterfaceIdentifier(iidBuf[:0], test.prefix, test.nicName, test.dadCounter, test.secretKey); !bytes.Equal(got, want) {
- t.Errorf("got AppendOpaqueInterfaceIdentifier(iidBuf[:0], %s, %s, %d, %x) = %x, want = %x", test.prefix, test.nicName, test.dadCounter, test.secretKey, got, want)
- }
- if got := iidBuf[:]; !bytes.Equal(got, want) {
- t.Errorf("got iidBuf = %x, want = %x", got, want)
- }
- })
- }
-}
-
-func TestLinkLocalAddrWithOpaqueIID(t *testing.T) {
- var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes * 2]byte
- if n, err := rand.Read(secretKeyBuf[:]); err != nil {
- t.Fatalf("rand.Read(_): %s", err)
- } else if want := header.OpaqueIIDSecretKeyMinBytes * 2; n != want {
- t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", want, n)
- }
-
- prefix := header.IPv6LinkLocalPrefix.Subnet()
-
- tests := []struct {
- name string
- prefix tcpip.Subnet
- nicName string
- dadCounter uint8
- secretKey []byte
- }{
- {
- name: "SecretKey of minimum size",
- nicName: "eth0",
- dadCounter: 0,
- secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes],
- },
- {
- name: "SecretKey of less than minimum size",
- nicName: "eth10",
- dadCounter: 1,
- secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes/2],
- },
- {
- name: "SecretKey of more than minimum size",
- nicName: "eth11",
- dadCounter: 2,
- secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes*2],
- },
- {
- name: "Nil SecretKey and empty nicName",
- nicName: "",
- dadCounter: 3,
- secretKey: nil,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- addrBytes := [header.IPv6AddressSize]byte{
- 0: 0xFE,
- 1: 0x80,
- }
-
- want := tcpip.Address(header.AppendOpaqueInterfaceIdentifier(
- addrBytes[:header.IIDOffsetInIPv6Address],
- prefix,
- test.nicName,
- test.dadCounter,
- test.secretKey,
- ))
-
- if got := header.LinkLocalAddrWithOpaqueIID(test.nicName, test.dadCounter, test.secretKey); got != want {
- t.Errorf("got LinkLocalAddrWithOpaqueIID(%s, %d, %x) = %s, want = %s", test.nicName, test.dadCounter, test.secretKey, got, want)
- }
- })
- }
-}
-
-func TestIsV6LinkLocalMulticastAddress(t *testing.T) {
- tests := []struct {
- name string
- addr tcpip.Address
- expected bool
- }{
- {
- name: "Valid Link Local Multicast",
- addr: linkLocalMulticastAddr,
- expected: true,
- },
- {
- name: "Valid Link Local Multicast with flags",
- addr: "\xff\xf2\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
- expected: true,
- },
- {
- name: "Link Local Unicast",
- addr: linkLocalAddr,
- expected: false,
- },
- {
- name: "IPv4 Multicast",
- addr: "\xe0\x00\x00\x01",
- expected: false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- if got := header.IsV6LinkLocalMulticastAddress(test.addr); got != test.expected {
- t.Errorf("got header.IsV6LinkLocalMulticastAddress(%s) = %t, want = %t", test.addr, got, test.expected)
- }
- })
- }
-}
-
-func TestIsV6LinkLocalUnicastAddress(t *testing.T) {
- tests := []struct {
- name string
- addr tcpip.Address
- expected bool
- }{
- {
- name: "Valid Link Local Unicast",
- addr: linkLocalAddr,
- expected: true,
- },
- {
- name: "Link Local Multicast",
- addr: linkLocalMulticastAddr,
- expected: false,
- },
- {
- name: "Unique Local",
- addr: uniqueLocalAddr1,
- expected: false,
- },
- {
- name: "Global",
- addr: globalAddr,
- expected: false,
- },
- {
- name: "IPv4 Link Local",
- addr: "\xa9\xfe\x00\x01",
- expected: false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- if got := header.IsV6LinkLocalUnicastAddress(test.addr); got != test.expected {
- t.Errorf("got header.IsV6LinkLocalUnicastAddress(%s) = %t, want = %t", test.addr, got, test.expected)
- }
- })
- }
-}
-
-func TestScopeForIPv6Address(t *testing.T) {
- tests := []struct {
- name string
- addr tcpip.Address
- scope header.IPv6AddressScope
- err tcpip.Error
- }{
- {
- name: "Unique Local",
- addr: uniqueLocalAddr1,
- scope: header.GlobalScope,
- err: nil,
- },
- {
- name: "Link Local Unicast",
- addr: linkLocalAddr,
- scope: header.LinkLocalScope,
- err: nil,
- },
- {
- name: "Link Local Multicast",
- addr: linkLocalMulticastAddr,
- scope: header.LinkLocalScope,
- err: nil,
- },
- {
- name: "Global",
- addr: globalAddr,
- scope: header.GlobalScope,
- err: nil,
- },
- {
- name: "IPv4",
- addr: "\x01\x02\x03\x04",
- scope: header.GlobalScope,
- err: &tcpip.ErrBadAddress{},
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- got, err := header.ScopeForIPv6Address(test.addr)
- if diff := cmp.Diff(test.err, err); diff != "" {
- t.Errorf("unexpected error from header.IsV6UniqueLocalAddress(%s), (-want, +got):\n%s", test.addr, diff)
- }
- if got != test.scope {
- t.Errorf("got header.IsV6UniqueLocalAddress(%s) = (%d, _), want = (%d, _)", test.addr, got, test.scope)
- }
- })
- }
-}
-
-func TestSolicitedNodeAddr(t *testing.T) {
- tests := []struct {
- addr tcpip.Address
- want tcpip.Address
- }{
- {
- addr: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\xa0",
- want: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\x0e\x0f\xa0",
- },
- {
- addr: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\xdd\x0e\x0f\xa0",
- want: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\x0e\x0f\xa0",
- },
- {
- addr: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\xdd\x01\x02\x03",
- want: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\x01\x02\x03",
- },
- }
-
- for _, test := range tests {
- t.Run(fmt.Sprintf("%s", test.addr), func(t *testing.T) {
- if got := header.SolicitedNodeAddr(test.addr); got != test.want {
- t.Fatalf("got header.SolicitedNodeAddr(%s) = %s, want = %s", test.addr, got, test.want)
- }
- })
- }
-}
-
-func TestV6MulticastScope(t *testing.T) {
- tests := []struct {
- addr tcpip.Address
- want header.IPv6MulticastScope
- }{
- {
- addr: "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
- want: header.IPv6Reserved0MulticastScope,
- },
- {
- addr: "\xff\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
- want: header.IPv6InterfaceLocalMulticastScope,
- },
- {
- addr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
- want: header.IPv6LinkLocalMulticastScope,
- },
- {
- addr: "\xff\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
- want: header.IPv6RealmLocalMulticastScope,
- },
- {
- addr: "\xff\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
- want: header.IPv6AdminLocalMulticastScope,
- },
- {
- addr: "\xff\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
- want: header.IPv6SiteLocalMulticastScope,
- },
- {
- addr: "\xff\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
- want: header.IPv6MulticastScope(6),
- },
- {
- addr: "\xff\x07\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
- want: header.IPv6MulticastScope(7),
- },
- {
- addr: "\xff\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
- want: header.IPv6OrganizationLocalMulticastScope,
- },
- {
- addr: "\xff\x09\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
- want: header.IPv6MulticastScope(9),
- },
- {
- addr: "\xff\x0a\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
- want: header.IPv6MulticastScope(10),
- },
- {
- addr: "\xff\x0b\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
- want: header.IPv6MulticastScope(11),
- },
- {
- addr: "\xff\x0c\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
- want: header.IPv6MulticastScope(12),
- },
- {
- addr: "\xff\x0d\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
- want: header.IPv6MulticastScope(13),
- },
- {
- addr: "\xff\x0e\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
- want: header.IPv6GlobalMulticastScope,
- },
- {
- addr: "\xff\x0f\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
- want: header.IPv6ReservedFMulticastScope,
- },
- }
-
- for _, test := range tests {
- t.Run(fmt.Sprintf("%s", test.addr), func(t *testing.T) {
- if got := header.V6MulticastScope(test.addr); got != test.want {
- t.Fatalf("got header.V6MulticastScope(%s) = %d, want = %d", test.addr, got, test.want)
- }
- })
- }
-}
diff --git a/pkg/tcpip/header/ipversion_test.go b/pkg/tcpip/header/ipversion_test.go
deleted file mode 100644
index b5540bf66..000000000
--- a/pkg/tcpip/header/ipversion_test.go
+++ /dev/null
@@ -1,67 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package header_test
-
-import (
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip/header"
-)
-
-func TestIPv4(t *testing.T) {
- b := header.IPv4(make([]byte, header.IPv4MinimumSize))
- b.Encode(&header.IPv4Fields{})
-
- const want = header.IPv4Version
- if v := header.IPVersion(b); v != want {
- t.Fatalf("Bad version, want %v, got %v", want, v)
- }
-}
-
-func TestIPv6(t *testing.T) {
- b := header.IPv6(make([]byte, header.IPv6MinimumSize))
- b.Encode(&header.IPv6Fields{})
-
- const want = header.IPv6Version
- if v := header.IPVersion(b); v != want {
- t.Fatalf("Bad version, want %v, got %v", want, v)
- }
-}
-
-func TestOtherVersion(t *testing.T) {
- const want = header.IPv4Version + header.IPv6Version
- b := make([]byte, 1)
- b[0] = want << 4
-
- if v := header.IPVersion(b); v != want {
- t.Fatalf("Bad version, want %v, got %v", want, v)
- }
-}
-
-func TestTooShort(t *testing.T) {
- b := make([]byte, 1)
- b[0] = (header.IPv4Version + header.IPv6Version) << 4
-
- // Get the version of a zero-length slice.
- const want = -1
- if v := header.IPVersion(b[:0]); v != want {
- t.Fatalf("Bad version, want %v, got %v", want, v)
- }
-
- // Get the version of a nil slice.
- if v := header.IPVersion(nil); v != want {
- t.Fatalf("Bad version, want %v, got %v", want, v)
- }
-}
diff --git a/pkg/tcpip/header/mld_test.go b/pkg/tcpip/header/mld_test.go
deleted file mode 100644
index 0cecf10d4..000000000
--- a/pkg/tcpip/header/mld_test.go
+++ /dev/null
@@ -1,61 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package header
-
-import (
- "encoding/binary"
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/tcpip"
-)
-
-func TestMLD(t *testing.T) {
- b := []byte{
- // Maximum Response Delay
- 0, 0,
-
- // Reserved
- 0, 0,
-
- // MulticastAddress
- 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6,
- }
-
- const maxRespDelay = 513
- binary.BigEndian.PutUint16(b, maxRespDelay)
-
- mld := MLD(b)
-
- if got, want := mld.MaximumResponseDelay(), maxRespDelay*time.Millisecond; got != want {
- t.Errorf("got mld.MaximumResponseDelay() = %s, want = %s", got, want)
- }
-
- const newMaxRespDelay = 1234
- mld.SetMaximumResponseDelay(newMaxRespDelay)
- if got, want := mld.MaximumResponseDelay(), newMaxRespDelay*time.Millisecond; got != want {
- t.Errorf("got mld.MaximumResponseDelay() = %s, want = %s", got, want)
- }
-
- if got, want := mld.MulticastAddress(), tcpip.Address([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6}); got != want {
- t.Errorf("got mld.MulticastAddress() = %s, want = %s", got, want)
- }
-
- multicastAddress := tcpip.Address([]byte{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0})
- mld.SetMulticastAddress(multicastAddress)
- if got := mld.MulticastAddress(); got != multicastAddress {
- t.Errorf("got mld.MulticastAddress() = %s, want = %s", got, multicastAddress)
- }
-}
diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go
deleted file mode 100644
index 2a897e938..000000000
--- a/pkg/tcpip/header/ndp_test.go
+++ /dev/null
@@ -1,1748 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package header
-
-import (
- "bytes"
- "encoding/binary"
- "errors"
- "fmt"
- "io"
- "regexp"
- "strings"
- "testing"
- "time"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
-)
-
-// TestNDPNeighborSolicit tests the functions of NDPNeighborSolicit.
-func TestNDPNeighborSolicit(t *testing.T) {
- b := []byte{
- 0, 0, 0, 0,
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- }
-
- // Test getting the Target Address.
- ns := NDPNeighborSolicit(b)
- addr := testutil.MustParse6("102:304:506:708:90a:b0c:d0e:f10")
- if got := ns.TargetAddress(); got != addr {
- t.Errorf("got ns.TargetAddress = %s, want %s", got, addr)
- }
-
- // Test updating the Target Address.
- addr2 := testutil.MustParse6("1112:1314:1516:1718:191a:1b1c:1d1e:1f11")
- ns.SetTargetAddress(addr2)
- if got := ns.TargetAddress(); got != addr2 {
- t.Errorf("got ns.TargetAddress = %s, want %s", got, addr2)
- }
- // Make sure the address got updated in the backing buffer.
- if got := tcpip.Address(b[ndpNSTargetAddessOffset:][:IPv6AddressSize]); got != addr2 {
- t.Errorf("got targetaddress buffer = %s, want %s", got, addr2)
- }
-}
-
-func TestNDPRouteInformationOption(t *testing.T) {
- tests := []struct {
- name string
-
- length uint8
- prefixLength uint8
- prf NDPRoutePreference
- lifetimeS uint32
- prefixBytes []byte
- expectedPrefix tcpip.Subnet
-
- expectedErr error
- }{
- {
- name: "Length=1 with Prefix Length = 0",
- length: 1,
- prefixLength: 0,
- prf: MediumRoutePreference,
- lifetimeS: 1,
- prefixBytes: nil,
- expectedPrefix: IPv6EmptySubnet,
- },
- {
- name: "Length=1 but Prefix Length > 0",
- length: 1,
- prefixLength: 1,
- prf: MediumRoutePreference,
- lifetimeS: 1,
- prefixBytes: nil,
- expectedErr: ErrNDPOptMalformedBody,
- },
- {
- name: "Length=2 with Prefix Length = 0",
- length: 2,
- prefixLength: 0,
- prf: MediumRoutePreference,
- lifetimeS: 1,
- prefixBytes: nil,
- expectedPrefix: IPv6EmptySubnet,
- },
- {
- name: "Length=2 with Prefix Length in [1, 64] (1)",
- length: 2,
- prefixLength: 1,
- prf: LowRoutePreference,
- lifetimeS: 1,
- prefixBytes: nil,
- expectedPrefix: tcpip.AddressWithPrefix{
- Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)),
- PrefixLen: 1,
- }.Subnet(),
- },
- {
- name: "Length=2 with Prefix Length in [1, 64] (64)",
- length: 2,
- prefixLength: 64,
- prf: HighRoutePreference,
- lifetimeS: 1,
- prefixBytes: nil,
- expectedPrefix: tcpip.AddressWithPrefix{
- Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)),
- PrefixLen: 64,
- }.Subnet(),
- },
- {
- name: "Length=2 with Prefix Length > 64",
- length: 2,
- prefixLength: 65,
- prf: HighRoutePreference,
- lifetimeS: 1,
- prefixBytes: nil,
- expectedErr: ErrNDPOptMalformedBody,
- },
- {
- name: "Length=3 with Prefix Length = 0",
- length: 3,
- prefixLength: 0,
- prf: MediumRoutePreference,
- lifetimeS: 1,
- prefixBytes: nil,
- expectedPrefix: IPv6EmptySubnet,
- },
- {
- name: "Length=3 with Prefix Length in [1, 64] (1)",
- length: 3,
- prefixLength: 1,
- prf: LowRoutePreference,
- lifetimeS: 1,
- prefixBytes: nil,
- expectedPrefix: tcpip.AddressWithPrefix{
- Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)),
- PrefixLen: 1,
- }.Subnet(),
- },
- {
- name: "Length=3 with Prefix Length in [1, 64] (64)",
- length: 3,
- prefixLength: 64,
- prf: HighRoutePreference,
- lifetimeS: 1,
- prefixBytes: nil,
- expectedPrefix: tcpip.AddressWithPrefix{
- Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)),
- PrefixLen: 64,
- }.Subnet(),
- },
- {
- name: "Length=3 with Prefix Length in [65, 128] (65)",
- length: 3,
- prefixLength: 65,
- prf: HighRoutePreference,
- lifetimeS: 1,
- prefixBytes: nil,
- expectedPrefix: tcpip.AddressWithPrefix{
- Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)),
- PrefixLen: 65,
- }.Subnet(),
- },
- {
- name: "Length=3 with Prefix Length in [65, 128] (128)",
- length: 3,
- prefixLength: 128,
- prf: HighRoutePreference,
- lifetimeS: 1,
- prefixBytes: nil,
- expectedPrefix: tcpip.AddressWithPrefix{
- Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)),
- PrefixLen: 128,
- }.Subnet(),
- },
- {
- name: "Length=3 with (invalid) Prefix Length > 128",
- length: 3,
- prefixLength: 129,
- prf: HighRoutePreference,
- lifetimeS: 1,
- prefixBytes: nil,
- expectedErr: ErrNDPOptMalformedBody,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- expectedRouteInformationBytes := [...]byte{
- // Type, Length
- 24, test.length,
-
- // Prefix Length, Prf
- uint8(test.prefixLength), uint8(test.prf) << 3,
-
- // Route Lifetime
- 0, 0, 0, 0,
-
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- }
- binary.BigEndian.PutUint32(expectedRouteInformationBytes[4:], test.lifetimeS)
- _ = copy(expectedRouteInformationBytes[8:], test.prefixBytes)
-
- opts := NDPOptions(expectedRouteInformationBytes[:test.length*lengthByteUnits])
- it, err := opts.Iter(false)
- if err != nil {
- t.Fatalf("got Iter(false) = (_, %s), want = (_, nil)", err)
- }
- opt, done, err := it.Next()
- if !errors.Is(err, test.expectedErr) {
- t.Fatalf("got Next() = (_, _, %s), want = (_, _, %s)", err, test.expectedErr)
- }
- if want := test.expectedErr != nil; done != want {
- t.Fatalf("got Next() = (_, %t, _), want = (_, %t, _)", done, want)
- }
- if test.expectedErr != nil {
- return
- }
-
- if got := opt.kind(); got != ndpRouteInformationType {
- t.Errorf("got kind() = %d, want = %d", got, ndpRouteInformationType)
- }
-
- ri, ok := opt.(NDPRouteInformation)
- if !ok {
- t.Fatalf("got opt = %T, want = NDPRouteInformation", opt)
- }
-
- if got := ri.PrefixLength(); got != test.prefixLength {
- t.Errorf("got PrefixLength() = %d, want = %d", got, test.prefixLength)
- }
- if got := ri.RoutePreference(); got != test.prf {
- t.Errorf("got RoutePreference() = %d, want = %d", got, test.prf)
- }
- if got, want := ri.RouteLifetime(), time.Duration(test.lifetimeS)*time.Second; got != want {
- t.Errorf("got RouteLifetime() = %s, want = %s", got, want)
- }
- if got, err := ri.Prefix(); err != nil {
- t.Errorf("Prefix(): %s", err)
- } else if got != test.expectedPrefix {
- t.Errorf("got Prefix() = %s, want = %s", got, test.expectedPrefix)
- }
-
- // Iterator should not return anything else.
- {
- next, done, err := it.Next()
- if err != nil {
- t.Errorf("got Next() = (_, _, %s), want = (_, _, nil)", err)
- }
- if !done {
- t.Error("got Next() = (_, false, _), want = (_, true, _)")
- }
- if next != nil {
- t.Errorf("got Next() = (%x, _, _), want = (nil, _, _)", next)
- }
- }
- })
- }
-}
-
-// TestNDPNeighborAdvert tests the functions of NDPNeighborAdvert.
-func TestNDPNeighborAdvert(t *testing.T) {
- b := []byte{
- 160, 0, 0, 0,
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- }
-
- // Test getting the Target Address.
- na := NDPNeighborAdvert(b)
- addr := testutil.MustParse6("102:304:506:708:90a:b0c:d0e:f10")
- if got := na.TargetAddress(); got != addr {
- t.Errorf("got TargetAddress = %s, want %s", got, addr)
- }
-
- // Test getting the Router Flag.
- if got := na.RouterFlag(); !got {
- t.Errorf("got RouterFlag = false, want = true")
- }
-
- // Test getting the Solicited Flag.
- if got := na.SolicitedFlag(); got {
- t.Errorf("got SolicitedFlag = true, want = false")
- }
-
- // Test getting the Override Flag.
- if got := na.OverrideFlag(); !got {
- t.Errorf("got OverrideFlag = false, want = true")
- }
-
- // Test updating the Target Address.
- addr2 := testutil.MustParse6("1112:1314:1516:1718:191a:1b1c:1d1e:1f11")
- na.SetTargetAddress(addr2)
- if got := na.TargetAddress(); got != addr2 {
- t.Errorf("got TargetAddress = %s, want %s", got, addr2)
- }
- // Make sure the address got updated in the backing buffer.
- if got := tcpip.Address(b[ndpNATargetAddressOffset:][:IPv6AddressSize]); got != addr2 {
- t.Errorf("got targetaddress buffer = %s, want %s", got, addr2)
- }
-
- // Test updating the Router Flag.
- na.SetRouterFlag(false)
- if got := na.RouterFlag(); got {
- t.Errorf("got RouterFlag = true, want = false")
- }
-
- // Test updating the Solicited Flag.
- na.SetSolicitedFlag(true)
- if got := na.SolicitedFlag(); !got {
- t.Errorf("got SolicitedFlag = false, want = true")
- }
-
- // Test updating the Override Flag.
- na.SetOverrideFlag(false)
- if got := na.OverrideFlag(); got {
- t.Errorf("got OverrideFlag = true, want = false")
- }
-
- // Make sure flags got updated in the backing buffer.
- if got := b[ndpNAFlagsOffset]; got != 64 {
- t.Errorf("got flags byte = %d, want = 64", got)
- }
-}
-
-func TestNDPRouterAdvert(t *testing.T) {
- tests := []struct {
- hopLimit uint8
- managedFlag, otherConfFlag bool
- prf NDPRoutePreference
- routerLifetimeS uint16
- reachableTimeMS, retransTimerMS uint32
- }{
- {
- hopLimit: 1,
- managedFlag: false,
- otherConfFlag: true,
- prf: HighRoutePreference,
- routerLifetimeS: 2,
- reachableTimeMS: 3,
- retransTimerMS: 4,
- },
- {
- hopLimit: 64,
- managedFlag: true,
- otherConfFlag: false,
- prf: LowRoutePreference,
- routerLifetimeS: 258,
- reachableTimeMS: 78492,
- retransTimerMS: 13213,
- },
- }
-
- for i, test := range tests {
- t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
- flags := uint8(0)
- if test.managedFlag {
- flags |= 1 << 7
- }
- if test.otherConfFlag {
- flags |= 1 << 6
- }
- flags |= uint8(test.prf) << 3
-
- b := []byte{
- test.hopLimit, flags, 1, 2,
- 3, 4, 5, 6,
- 7, 8, 9, 10,
- }
- binary.BigEndian.PutUint16(b[2:], test.routerLifetimeS)
- binary.BigEndian.PutUint32(b[4:], test.reachableTimeMS)
- binary.BigEndian.PutUint32(b[8:], test.retransTimerMS)
-
- ra := NDPRouterAdvert(b)
-
- if got := ra.CurrHopLimit(); got != test.hopLimit {
- t.Errorf("got ra.CurrHopLimit() = %d, want = %d", got, test.hopLimit)
- }
-
- if got := ra.ManagedAddrConfFlag(); got != test.managedFlag {
- t.Errorf("got ManagedAddrConfFlag() = %t, want = %t", got, test.managedFlag)
- }
-
- if got := ra.OtherConfFlag(); got != test.otherConfFlag {
- t.Errorf("got OtherConfFlag() = %t, want = %t", got, test.otherConfFlag)
- }
-
- if got := ra.DefaultRouterPreference(); got != test.prf {
- t.Errorf("got DefaultRouterPreference() = %d, want = %d", got, test.prf)
- }
-
- if got, want := ra.RouterLifetime(), time.Second*time.Duration(test.routerLifetimeS); got != want {
- t.Errorf("got ra.RouterLifetime() = %d, want = %d", got, want)
- }
-
- if got, want := ra.ReachableTime(), time.Millisecond*time.Duration(test.reachableTimeMS); got != want {
- t.Errorf("got ra.ReachableTime() = %d, want = %d", got, want)
- }
-
- if got, want := ra.RetransTimer(), time.Millisecond*time.Duration(test.retransTimerMS); got != want {
- t.Errorf("got ra.RetransTimer() = %d, want = %d", got, want)
- }
- })
- }
-}
-
-// TestNDPSourceLinkLayerAddressOptionEthernetAddress tests getting the
-// Ethernet address from an NDPSourceLinkLayerAddressOption.
-func TestNDPSourceLinkLayerAddressOptionEthernetAddress(t *testing.T) {
- tests := []struct {
- name string
- buf []byte
- expected tcpip.LinkAddress
- }{
- {
- "ValidMAC",
- []byte{1, 2, 3, 4, 5, 6},
- tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
- },
- {
- "SLLBodyTooShort",
- []byte{1, 2, 3, 4, 5},
- tcpip.LinkAddress([]byte(nil)),
- },
- {
- "SLLBodyLargerThanNeeded",
- []byte{1, 2, 3, 4, 5, 6, 7, 8},
- tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- sll := NDPSourceLinkLayerAddressOption(test.buf)
- if got := sll.EthernetAddress(); got != test.expected {
- t.Errorf("got sll.EthernetAddress = %s, want = %s", got, test.expected)
- }
- })
- }
-}
-
-// TestNDPTargetLinkLayerAddressOptionEthernetAddress tests getting the
-// Ethernet address from an NDPTargetLinkLayerAddressOption.
-func TestNDPTargetLinkLayerAddressOptionEthernetAddress(t *testing.T) {
- tests := []struct {
- name string
- buf []byte
- expected tcpip.LinkAddress
- }{
- {
- "ValidMAC",
- []byte{1, 2, 3, 4, 5, 6},
- tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
- },
- {
- "TLLBodyTooShort",
- []byte{1, 2, 3, 4, 5},
- tcpip.LinkAddress([]byte(nil)),
- },
- {
- "TLLBodyLargerThanNeeded",
- []byte{1, 2, 3, 4, 5, 6, 7, 8},
- tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- tll := NDPTargetLinkLayerAddressOption(test.buf)
- if got := tll.EthernetAddress(); got != test.expected {
- t.Errorf("got tll.EthernetAddress = %s, want = %s", got, test.expected)
- }
- })
- }
-}
-
-func TestOpts(t *testing.T) {
- const optionHeaderLen = 2
-
- checkNonce := func(expectedNonce []byte) func(*testing.T, NDPOption) {
- return func(t *testing.T, opt NDPOption) {
- if got := opt.kind(); got != ndpNonceOptionType {
- t.Errorf("got kind() = %d, want = %d", got, ndpNonceOptionType)
- }
- nonce, ok := opt.(NDPNonceOption)
- if !ok {
- t.Fatalf("got nonce = %T, want = NDPNonceOption", opt)
- }
- if diff := cmp.Diff(expectedNonce, nonce.Nonce()); diff != "" {
- t.Errorf("nonce mismatch (-want +got):\n%s", diff)
- }
- }
- }
-
- checkTLL := func(expectedAddr tcpip.LinkAddress) func(*testing.T, NDPOption) {
- return func(t *testing.T, opt NDPOption) {
- if got := opt.kind(); got != ndpTargetLinkLayerAddressOptionType {
- t.Errorf("got kind() = %d, want = %d", got, ndpTargetLinkLayerAddressOptionType)
- }
- tll, ok := opt.(NDPTargetLinkLayerAddressOption)
- if !ok {
- t.Fatalf("got tll = %T, want = NDPTargetLinkLayerAddressOption", opt)
- }
- if got, want := tll.EthernetAddress(), expectedAddr; got != want {
- t.Errorf("got tll.EthernetAddress = %s, want = %s", got, want)
- }
- }
- }
-
- checkSLL := func(expectedAddr tcpip.LinkAddress) func(*testing.T, NDPOption) {
- return func(t *testing.T, opt NDPOption) {
- if got := opt.kind(); got != ndpSourceLinkLayerAddressOptionType {
- t.Errorf("got kind() = %d, want = %d", got, ndpSourceLinkLayerAddressOptionType)
- }
- sll, ok := opt.(NDPSourceLinkLayerAddressOption)
- if !ok {
- t.Fatalf("got sll = %T, want = NDPSourceLinkLayerAddressOption", opt)
- }
- if got, want := sll.EthernetAddress(), expectedAddr; got != want {
- t.Errorf("got sll.EthernetAddress = %s, want = %s", got, want)
- }
- }
- }
-
- const validLifetimeSeconds = 16909060
- address := testutil.MustParse6("90a:b0c:d0e:f10:1112:1314:1516:1718")
-
- expectedRDNSSBytes := [...]byte{
- // Type, Length
- 25, 3,
-
- // Reserved
- 0, 0,
-
- // Lifetime
- 1, 2, 4, 8,
-
- // Address
- 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
- }
- binary.BigEndian.PutUint32(expectedRDNSSBytes[4:], validLifetimeSeconds)
- if n := copy(expectedRDNSSBytes[8:], address); n != IPv6AddressSize {
- t.Fatalf("got copy(...) = %d, want = %d", n, IPv6AddressSize)
- }
- // Update reserved fields to non zero values to make sure serializing sets
- // them to zero.
- rdnssBytes := expectedRDNSSBytes
- rdnssBytes[1] = 1
- rdnssBytes[2] = 2
-
- const searchListPaddingBytes = 3
- const domainName = "abc.abcd.e"
- expectedSearchListBytes := [...]byte{
- // Type, Length
- 31, 3,
-
- // Reserved
- 0, 0,
-
- // Lifetime
- 1, 0, 0, 0,
-
- // Domain names
- 3, 'a', 'b', 'c',
- 4, 'a', 'b', 'c', 'd',
- 1, 'e',
- 0,
- 0, 0, 0, 0,
- }
- binary.BigEndian.PutUint32(expectedSearchListBytes[4:], validLifetimeSeconds)
- // Update reserved fields to non zero values to make sure serializing sets
- // them to zero.
- searchListBytes := expectedSearchListBytes
- searchListBytes[2] = 1
- searchListBytes[3] = 2
-
- const prefixLength = 43
- const onLinkFlag = false
- const slaacFlag = true
- const preferredLifetimeSeconds = 84281096
- const onLinkFlagBit = 7
- const slaacFlagBit = 6
- boolToByte := func(v bool) byte {
- if v {
- return 1
- }
- return 0
- }
- flags := boolToByte(onLinkFlag)<<onLinkFlagBit | boolToByte(slaacFlag)<<slaacFlagBit
- expectedPrefixInformationBytes := [...]byte{
- // Type, Length
- 3, 4,
-
- prefixLength, flags,
-
- // Valid Lifetime
- 1, 2, 3, 4,
-
- // Preferred Lifetime
- 5, 6, 7, 8,
-
- // Reserved2
- 0, 0, 0, 0,
-
- // Address
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- 17, 18, 19, 20,
- 21, 22, 23, 24,
- }
- binary.BigEndian.PutUint32(expectedPrefixInformationBytes[4:], validLifetimeSeconds)
- binary.BigEndian.PutUint32(expectedPrefixInformationBytes[8:], preferredLifetimeSeconds)
- if n := copy(expectedPrefixInformationBytes[16:], address); n != IPv6AddressSize {
- t.Fatalf("got copy(...) = %d, want = %d", n, IPv6AddressSize)
- }
- // Update reserved fields to non zero values to make sure serializing sets
- // them to zero.
- prefixInformationBytes := expectedPrefixInformationBytes
- prefixInformationBytes[3] |= (1 << slaacFlagBit) - 1
- binary.BigEndian.PutUint32(prefixInformationBytes[12:], validLifetimeSeconds+1)
- tests := []struct {
- name string
- buf []byte
- opt NDPOption
- expectedBuf []byte
- check func(*testing.T, NDPOption)
- }{
- {
- name: "Nonce",
- buf: make([]byte, 8),
- opt: NDPNonceOption([]byte{1, 2, 3, 4, 5, 6}),
- expectedBuf: []byte{14, 1, 1, 2, 3, 4, 5, 6},
- check: checkNonce([]byte{1, 2, 3, 4, 5, 6}),
- },
- {
- name: "Nonce with padding",
- buf: []byte{1, 1, 1, 1, 1, 1, 1, 1},
- opt: NDPNonceOption([]byte{1, 2, 3, 4, 5}),
- expectedBuf: []byte{14, 1, 1, 2, 3, 4, 5, 0},
- check: checkNonce([]byte{1, 2, 3, 4, 5, 0}),
- },
-
- {
- name: "TLL Ethernet",
- buf: make([]byte, 8),
- opt: NDPTargetLinkLayerAddressOption("\x01\x02\x03\x04\x05\x06"),
- expectedBuf: []byte{2, 1, 1, 2, 3, 4, 5, 6},
- check: checkTLL("\x01\x02\x03\x04\x05\x06"),
- },
- {
- name: "TLL Padding",
- buf: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
- opt: NDPTargetLinkLayerAddressOption("\x01\x02\x03\x04\x05\x06\x07\x08"),
- expectedBuf: []byte{2, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0},
- check: checkTLL("\x01\x02\x03\x04\x05\x06"),
- },
- {
- name: "TLL Empty",
- buf: nil,
- opt: NDPTargetLinkLayerAddressOption(""),
- expectedBuf: nil,
- },
-
- {
- name: "SLL Ethernet",
- buf: make([]byte, 8),
- opt: NDPSourceLinkLayerAddressOption("\x01\x02\x03\x04\x05\x06"),
- expectedBuf: []byte{1, 1, 1, 2, 3, 4, 5, 6},
- check: checkSLL("\x01\x02\x03\x04\x05\x06"),
- },
- {
- name: "SLL Padding",
- buf: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
- opt: NDPSourceLinkLayerAddressOption("\x01\x02\x03\x04\x05\x06\x07\x08"),
- expectedBuf: []byte{1, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0},
- check: checkSLL("\x01\x02\x03\x04\x05\x06"),
- },
- {
- name: "SLL Empty",
- buf: nil,
- opt: NDPSourceLinkLayerAddressOption(""),
- expectedBuf: nil,
- },
-
- {
- name: "RDNSS",
- buf: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
- // NDPRecursiveDNSServer holds the option after the header bytes.
- opt: NDPRecursiveDNSServer(rdnssBytes[optionHeaderLen:]),
- expectedBuf: expectedRDNSSBytes[:],
- check: func(t *testing.T, opt NDPOption) {
- if got := opt.kind(); got != ndpRecursiveDNSServerOptionType {
- t.Errorf("got kind() = %d, want = %d", got, ndpRecursiveDNSServerOptionType)
- }
- rdnss, ok := opt.(NDPRecursiveDNSServer)
- if !ok {
- t.Fatalf("got opt = %T, want = NDPRecursiveDNSServer", opt)
- }
- if got, want := rdnss.length(), len(expectedRDNSSBytes[optionHeaderLen:]); got != want {
- t.Errorf("got length() = %d, want = %d", got, want)
- }
- if got, want := rdnss.Lifetime(), validLifetimeSeconds*time.Second; got != want {
- t.Errorf("got Lifetime() = %s, want = %s", got, want)
- }
- if addrs, err := rdnss.Addresses(); err != nil {
- t.Errorf("Addresses(): %s", err)
- } else if diff := cmp.Diff([]tcpip.Address{address}, addrs); diff != "" {
- t.Errorf("mismatched addresses (-want +got):\n%s", diff)
- }
- },
- },
-
- {
- name: "Search list",
- buf: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
- opt: NDPDNSSearchList(searchListBytes[optionHeaderLen:]),
- expectedBuf: expectedSearchListBytes[:],
- check: func(t *testing.T, opt NDPOption) {
- if got := opt.kind(); got != ndpDNSSearchListOptionType {
- t.Errorf("got kind() = %d, want = %d", got, ndpDNSSearchListOptionType)
- }
-
- dnssl, ok := opt.(NDPDNSSearchList)
- if !ok {
- t.Fatalf("got opt = %T, want = NDPDNSSearchList", opt)
- }
- if got, want := dnssl.length(), len(expectedRDNSSBytes[optionHeaderLen:]); got != want {
- t.Errorf("got length() = %d, want = %d", got, want)
- }
- if got, want := dnssl.Lifetime(), validLifetimeSeconds*time.Second; got != want {
- t.Errorf("got Lifetime() = %s, want = %s", got, want)
- }
-
- if domainNames, err := dnssl.DomainNames(); err != nil {
- t.Errorf("DomainNames(): %s", err)
- } else if diff := cmp.Diff([]string{domainName}, domainNames); diff != "" {
- t.Errorf("domain names mismatch (-want +got):\n%s", diff)
- }
- },
- },
-
- {
- name: "Prefix Information",
- buf: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
- // NDPPrefixInformation holds the option after the header bytes.
- opt: NDPPrefixInformation(prefixInformationBytes[optionHeaderLen:]),
- expectedBuf: expectedPrefixInformationBytes[:],
- check: func(t *testing.T, opt NDPOption) {
- if got := opt.kind(); got != ndpPrefixInformationType {
- t.Errorf("got kind() = %d, want = %d", got, ndpPrefixInformationType)
- }
-
- pi, ok := opt.(NDPPrefixInformation)
- if !ok {
- t.Fatalf("got opt = %T, want = NDPPrefixInformation", opt)
- }
-
- if got, want := pi.length(), len(expectedPrefixInformationBytes[optionHeaderLen:]); got != want {
- t.Errorf("got length() = %d, want = %d", got, want)
- }
- if got := pi.PrefixLength(); got != prefixLength {
- t.Errorf("got PrefixLength() = %d, want = %d", got, prefixLength)
- }
- if got := pi.OnLinkFlag(); got != onLinkFlag {
- t.Errorf("got OnLinkFlag() = %t, want = %t", got, onLinkFlag)
- }
- if got := pi.AutonomousAddressConfigurationFlag(); got != slaacFlag {
- t.Errorf("got AutonomousAddressConfigurationFlag() = %t, want = %t", got, slaacFlag)
- }
- if got, want := pi.ValidLifetime(), validLifetimeSeconds*time.Second; got != want {
- t.Errorf("got ValidLifetime() = %s, want = %s", got, want)
- }
- if got, want := pi.PreferredLifetime(), preferredLifetimeSeconds*time.Second; got != want {
- t.Errorf("got PreferredLifetime() = %s, want = %s", got, want)
- }
- if got := pi.Prefix(); got != address {
- t.Errorf("got Prefix() = %s, want = %s", got, address)
- }
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- opts := NDPOptions(test.buf)
- serializer := NDPOptionsSerializer{
- test.opt,
- }
- if got, want := int(serializer.Length()), len(test.expectedBuf); got != want {
- t.Fatalf("got Length() = %d, want = %d", got, want)
- }
- opts.Serialize(serializer)
- if diff := cmp.Diff(test.expectedBuf, test.buf); diff != "" {
- t.Fatalf("serialized buffer mismatch (-want +got):\n%s", diff)
- }
-
- it, err := opts.Iter(true)
- if err != nil {
- t.Fatalf("got Iter(true) = (_, %s), want = (_, nil)", err)
- }
-
- if len(test.expectedBuf) > 0 {
- next, done, err := it.Next()
- if err != nil {
- t.Fatalf("got Next() = (_, _, %s), want = (_, _, nil)", err)
- }
- if done {
- t.Fatal("got Next() = (_, true, _), want = (_, false, _)")
- }
- test.check(t, next)
- }
-
- // Iterator should not return anything else.
- next, done, err := it.Next()
- if err != nil {
- t.Errorf("got Next() = (_, _, %s), want = (_, _, nil)", err)
- }
- if !done {
- t.Error("got Next() = (_, false, _), want = (_, true, _)")
- }
- if next != nil {
- t.Errorf("got Next() = (%x, _, _), want = (nil, _, _)", next)
- }
- })
- }
-}
-
-func TestNDPRecursiveDNSServerOption(t *testing.T) {
- tests := []struct {
- name string
- buf []byte
- lifetime time.Duration
- addrs []tcpip.Address
- }{
- {
- "Valid1Addr",
- []byte{
- 25, 3, 0, 0,
- 0, 0, 0, 0,
- 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
- },
- 0,
- []tcpip.Address{
- "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
- },
- },
- {
- "Valid2Addr",
- []byte{
- 25, 5, 0, 0,
- 0, 0, 0, 0,
- 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
- 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16,
- },
- 0,
- []tcpip.Address{
- "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
- "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x10",
- },
- },
- {
- "Valid3Addr",
- []byte{
- 25, 7, 0, 0,
- 0, 0, 0, 0,
- 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
- 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16,
- 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 17,
- },
- 0,
- []tcpip.Address{
- "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
- "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x10",
- "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x11",
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- opts := NDPOptions(test.buf)
- it, err := opts.Iter(true)
- if err != nil {
- t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
- }
-
- // Iterator should get our option.
- next, done, err := it.Next()
- if err != nil {
- t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
- }
- if done {
- t.Fatal("got Next = (_, true, _), want = (_, false, _)")
- }
- if got := next.kind(); got != ndpRecursiveDNSServerOptionType {
- t.Fatalf("got Type = %d, want = %d", got, ndpRecursiveDNSServerOptionType)
- }
-
- opt, ok := next.(NDPRecursiveDNSServer)
- if !ok {
- t.Fatalf("next (type = %T) cannot be casted to an NDPRecursiveDNSServer", next)
- }
- if got := opt.Lifetime(); got != test.lifetime {
- t.Errorf("got Lifetime = %d, want = %d", got, test.lifetime)
- }
- addrs, err := opt.Addresses()
- if err != nil {
- t.Errorf("opt.Addresses() = %s", err)
- }
- if diff := cmp.Diff(addrs, test.addrs); diff != "" {
- t.Errorf("mismatched addresses (-want +got):\n%s", diff)
- }
-
- // Iterator should not return anything else.
- next, done, err = it.Next()
- if err != nil {
- t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
- }
- if !done {
- t.Error("got Next = (_, false, _), want = (_, true, _)")
- }
- if next != nil {
- t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
- }
- })
- }
-}
-
-// TestNDPDNSSearchListOption tests the getters of NDPDNSSearchList.
-func TestNDPDNSSearchListOption(t *testing.T) {
- tests := []struct {
- name string
- buf []byte
- lifetime time.Duration
- domainNames []string
- err error
- }{
- {
- name: "Valid1Label",
- buf: []byte{
- 0, 0,
- 0, 0, 0, 1,
- 3, 'a', 'b', 'c',
- 0,
- 0, 0, 0,
- },
- lifetime: time.Second,
- domainNames: []string{
- "abc",
- },
- err: nil,
- },
- {
- name: "Valid2Label",
- buf: []byte{
- 0, 0,
- 0, 0, 0, 5,
- 3, 'a', 'b', 'c',
- 4, 'a', 'b', 'c', 'd',
- 0,
- 0, 0, 0, 0, 0, 0,
- },
- lifetime: 5 * time.Second,
- domainNames: []string{
- "abc.abcd",
- },
- err: nil,
- },
- {
- name: "Valid3Label",
- buf: []byte{
- 0, 0,
- 1, 0, 0, 0,
- 3, 'a', 'b', 'c',
- 4, 'a', 'b', 'c', 'd',
- 1, 'e',
- 0,
- 0, 0, 0, 0,
- },
- lifetime: 16777216 * time.Second,
- domainNames: []string{
- "abc.abcd.e",
- },
- err: nil,
- },
- {
- name: "Valid2Domains",
- buf: []byte{
- 0, 0,
- 1, 2, 3, 4,
- 3, 'a', 'b', 'c',
- 0,
- 2, 'd', 'e',
- 3, 'x', 'y', 'z',
- 0,
- 0, 0, 0,
- },
- lifetime: 16909060 * time.Second,
- domainNames: []string{
- "abc",
- "de.xyz",
- },
- err: nil,
- },
- {
- name: "Valid3DomainsMixedCase",
- buf: []byte{
- 0, 0,
- 0, 0, 0, 0,
- 3, 'a', 'B', 'c',
- 0,
- 2, 'd', 'E',
- 3, 'X', 'y', 'z',
- 0,
- 1, 'J',
- 0,
- },
- lifetime: 0,
- domainNames: []string{
- "abc",
- "de.xyz",
- "j",
- },
- err: nil,
- },
- {
- name: "ValidDomainAfterNULL",
- buf: []byte{
- 0, 0,
- 0, 0, 0, 0,
- 3, 'a', 'B', 'c',
- 0, 0, 0, 0,
- 2, 'd', 'E',
- 3, 'X', 'y', 'z',
- 0,
- },
- lifetime: 0,
- domainNames: []string{
- "abc",
- "de.xyz",
- },
- err: nil,
- },
- {
- name: "Valid0Domains",
- buf: []byte{
- 0, 0,
- 0, 0, 0, 0,
- 0,
- 0, 0, 0, 0, 0, 0, 0,
- },
- lifetime: 0,
- domainNames: nil,
- err: nil,
- },
- {
- name: "NoTrailingNull",
- buf: []byte{
- 0, 0,
- 0, 0, 0, 0,
- 7, 'a', 'b', 'c', 'd', 'e', 'f', 'g',
- },
- lifetime: 0,
- domainNames: nil,
- err: io.ErrUnexpectedEOF,
- },
- {
- name: "IncorrectLength",
- buf: []byte{
- 0, 0,
- 0, 0, 0, 0,
- 8, 'a', 'b', 'c', 'd', 'e', 'f', 'g',
- },
- lifetime: 0,
- domainNames: nil,
- err: io.ErrUnexpectedEOF,
- },
- {
- name: "IncorrectLengthWithNULL",
- buf: []byte{
- 0, 0,
- 0, 0, 0, 0,
- 7, 'a', 'b', 'c', 'd', 'e', 'f',
- 0,
- },
- lifetime: 0,
- domainNames: nil,
- err: ErrNDPOptMalformedBody,
- },
- {
- name: "LabelOfLength63",
- buf: []byte{
- 0, 0,
- 0, 0, 0, 0,
- 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
- 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
- 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
- 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
- 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k',
- 0,
- },
- lifetime: 0,
- domainNames: []string{
- "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk",
- },
- err: nil,
- },
- {
- name: "LabelOfLength64",
- buf: []byte{
- 0, 0,
- 0, 0, 0, 0,
- 64, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
- 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
- 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
- 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
- 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k', 'l',
- 0,
- },
- lifetime: 0,
- domainNames: nil,
- err: ErrNDPOptMalformedBody,
- },
- {
- name: "DomainNameOfLength255",
- buf: []byte{
- 0, 0,
- 0, 0, 0, 0,
- 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
- 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
- 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
- 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
- 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k',
- 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
- 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
- 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
- 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
- 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k',
- 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
- 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
- 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
- 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
- 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k',
- 62, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
- 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
- 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
- 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
- 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j',
- 0,
- },
- lifetime: 0,
- domainNames: []string{
- "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk.abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk.abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk.abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghij",
- },
- err: nil,
- },
- {
- name: "DomainNameOfLength256",
- buf: []byte{
- 0, 0,
- 0, 0, 0, 0,
- 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
- 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
- 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
- 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
- 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k',
- 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
- 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
- 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
- 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
- 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k',
- 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
- 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
- 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
- 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
- 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k',
- 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
- 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
- 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
- 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
- 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k',
- 0,
- },
- lifetime: 0,
- domainNames: nil,
- err: ErrNDPOptMalformedBody,
- },
- {
- name: "StartingDigitForLabel",
- buf: []byte{
- 0, 0,
- 0, 0, 0, 1,
- 3, '9', 'b', 'c',
- 0,
- 0, 0, 0,
- },
- lifetime: time.Second,
- domainNames: nil,
- err: ErrNDPOptMalformedBody,
- },
- {
- name: "StartingHyphenForLabel",
- buf: []byte{
- 0, 0,
- 0, 0, 0, 1,
- 3, '-', 'b', 'c',
- 0,
- 0, 0, 0,
- },
- lifetime: time.Second,
- domainNames: nil,
- err: ErrNDPOptMalformedBody,
- },
- {
- name: "EndingHyphenForLabel",
- buf: []byte{
- 0, 0,
- 0, 0, 0, 1,
- 3, 'a', 'b', '-',
- 0,
- 0, 0, 0,
- },
- lifetime: time.Second,
- domainNames: nil,
- err: ErrNDPOptMalformedBody,
- },
- {
- name: "EndingDigitForLabel",
- buf: []byte{
- 0, 0,
- 0, 0, 0, 1,
- 3, 'a', 'b', '9',
- 0,
- 0, 0, 0,
- },
- lifetime: time.Second,
- domainNames: []string{
- "ab9",
- },
- err: nil,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- opt := NDPDNSSearchList(test.buf)
-
- if got := opt.Lifetime(); got != test.lifetime {
- t.Errorf("got Lifetime = %d, want = %d", got, test.lifetime)
- }
- domainNames, err := opt.DomainNames()
- if !errors.Is(err, test.err) {
- t.Errorf("opt.DomainNames() = %s", err)
- }
- if diff := cmp.Diff(domainNames, test.domainNames); diff != "" {
- t.Errorf("mismatched domain names (-want +got):\n%s", diff)
- }
- })
- }
-}
-
-func TestNDPSearchListOptionDomainNameLabelInvalidSymbols(t *testing.T) {
- for r := rune(0); r <= 255; r++ {
- t.Run(fmt.Sprintf("RuneVal=%d", r), func(t *testing.T) {
- buf := []byte{
- 0, 0,
- 0, 0, 0, 0,
- 3, 'a', 0 /* will be replaced */, 'c',
- 0,
- 0, 0, 0,
- }
- buf[8] = uint8(r)
- opt := NDPDNSSearchList(buf)
-
- // As per RFC 1035 section 2.3.1, the label must only include ASCII
- // letters, digits and hyphens (a-z, A-Z, 0-9, -).
- var expectedErr error
- re := regexp.MustCompile(`[a-zA-Z0-9-]`)
- if !re.Match([]byte{byte(r)}) {
- expectedErr = ErrNDPOptMalformedBody
- }
-
- if domainNames, err := opt.DomainNames(); !errors.Is(err, expectedErr) {
- t.Errorf("got opt.DomainNames() = (%s, %v), want = (_, %v)", domainNames, err, ErrNDPOptMalformedBody)
- }
- })
- }
-}
-
-// TestNDPOptionsIterCheck tests that Iter will return false if the NDPOptions
-// the iterator was returned for is malformed.
-func TestNDPOptionsIterCheck(t *testing.T) {
- tests := []struct {
- name string
- buf []byte
- expectedErr error
- }{
- {
- name: "ZeroLengthField",
- buf: []byte{0, 0, 0, 0, 0, 0, 0, 0},
- expectedErr: ErrNDPOptMalformedHeader,
- },
- {
- name: "ValidSourceLinkLayerAddressOption",
- buf: []byte{1, 1, 1, 2, 3, 4, 5, 6},
- expectedErr: nil,
- },
- {
- name: "TooSmallSourceLinkLayerAddressOption",
- buf: []byte{1, 1, 1, 2, 3, 4, 5},
- expectedErr: io.ErrUnexpectedEOF,
- },
- {
- name: "ValidTargetLinkLayerAddressOption",
- buf: []byte{2, 1, 1, 2, 3, 4, 5, 6},
- expectedErr: nil,
- },
- {
- name: "TooSmallTargetLinkLayerAddressOption",
- buf: []byte{2, 1, 1, 2, 3, 4, 5},
- expectedErr: io.ErrUnexpectedEOF,
- },
- {
- name: "ValidPrefixInformation",
- buf: []byte{
- 3, 4, 43, 64,
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- 0, 0, 0, 0,
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- 17, 18, 19, 20,
- 21, 22, 23, 24,
- },
- expectedErr: nil,
- },
- {
- name: "TooSmallPrefixInformation",
- buf: []byte{
- 3, 4, 43, 64,
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- 0, 0, 0, 0,
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- 17, 18, 19, 20,
- 21, 22, 23,
- },
- expectedErr: io.ErrUnexpectedEOF,
- },
- {
- name: "InvalidPrefixInformationLength",
- buf: []byte{
- 3, 3, 43, 64,
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- 0, 0, 0, 0,
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- },
- expectedErr: ErrNDPOptMalformedBody,
- },
- {
- name: "ValidSourceAndTargetLinkLayerAddressWithPrefixInformation",
- buf: []byte{
- // Source Link-Layer Address.
- 1, 1, 1, 2, 3, 4, 5, 6,
-
- // Target Link-Layer Address.
- 2, 1, 7, 8, 9, 10, 11, 12,
-
- // Prefix information.
- 3, 4, 43, 64,
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- 0, 0, 0, 0,
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- 17, 18, 19, 20,
- 21, 22, 23, 24,
- },
- expectedErr: nil,
- },
- {
- name: "ValidSourceAndTargetLinkLayerAddressWithPrefixInformationWithUnrecognized",
- buf: []byte{
- // Source Link-Layer Address.
- 1, 1, 1, 2, 3, 4, 5, 6,
-
- // Target Link-Layer Address.
- 2, 1, 7, 8, 9, 10, 11, 12,
-
- // 255 is an unrecognized type. If 255 ends up
- // being the type for some recognized type,
- // update 255 to some other unrecognized value.
- 255, 2, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8,
-
- // Prefix information.
- 3, 4, 43, 64,
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- 0, 0, 0, 0,
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- 17, 18, 19, 20,
- 21, 22, 23, 24,
- },
- expectedErr: nil,
- },
- {
- name: "InvalidRecursiveDNSServerCutsOffAddress",
- buf: []byte{
- 25, 4, 0, 0,
- 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
- 0, 1, 2, 3, 4, 5, 6, 7,
- },
- expectedErr: ErrNDPOptMalformedBody,
- },
- {
- name: "InvalidRecursiveDNSServerInvalidLengthField",
- buf: []byte{
- 25, 2, 0, 0,
- 0, 0, 0, 0,
- 0, 1, 2, 3, 4, 5, 6, 7, 8,
- },
- expectedErr: io.ErrUnexpectedEOF,
- },
- {
- name: "RecursiveDNSServerTooSmall",
- buf: []byte{
- 25, 1, 0, 0,
- 0, 0, 0,
- },
- expectedErr: io.ErrUnexpectedEOF,
- },
- {
- name: "RecursiveDNSServerMulticast",
- buf: []byte{
- 25, 3, 0, 0,
- 0, 0, 0, 0,
- 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
- },
- expectedErr: ErrNDPOptMalformedBody,
- },
- {
- name: "RecursiveDNSServerUnspecified",
- buf: []byte{
- 25, 3, 0, 0,
- 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- },
- expectedErr: ErrNDPOptMalformedBody,
- },
- {
- name: "DNSSearchListLargeCompliantRFC1035",
- buf: []byte{
- 31, 33, 0, 0,
- 0, 0, 0, 0,
- 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
- 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
- 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
- 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
- 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k',
- 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
- 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
- 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
- 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
- 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k',
- 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
- 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
- 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
- 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
- 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k',
- 62, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
- 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
- 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
- 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
- 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j',
- 0,
- },
- expectedErr: nil,
- },
- {
- name: "DNSSearchListNonCompliantRFC1035",
- buf: []byte{
- 31, 33, 0, 0,
- 0, 0, 0, 0,
- 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
- 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
- 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
- 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
- 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k',
- 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
- 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
- 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
- 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
- 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k',
- 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
- 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
- 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
- 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
- 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k',
- 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
- 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
- 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
- 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
- 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
- 'i', 'j', 'k',
- 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- },
- expectedErr: ErrNDPOptMalformedBody,
- },
- {
- name: "DNSSearchListValidSmall",
- buf: []byte{
- 31, 2, 0, 0,
- 0, 0, 0, 0,
- 6, 'a', 'b', 'c', 'd', 'e', 'f',
- 0,
- },
- expectedErr: nil,
- },
- {
- name: "DNSSearchListTooSmall",
- buf: []byte{
- 31, 1, 0, 0,
- 0, 0, 0,
- },
- expectedErr: io.ErrUnexpectedEOF,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- opts := NDPOptions(test.buf)
-
- if _, err := opts.Iter(true); !errors.Is(err, test.expectedErr) {
- t.Fatalf("got Iter(true) = (_, %v), want = (_, %v)", err, test.expectedErr)
- }
-
- // test.buf may be malformed but we chose not to check
- // the iterator so it must return true.
- if _, err := opts.Iter(false); err != nil {
- t.Fatalf("got Iter(false) = (_, %s), want = (_, nil)", err)
- }
- })
- }
-}
-
-// TestNDPOptionsIter tests that we can iterator over a valid NDPOptions. Note,
-// this test does not actually check any of the option's getters, it simply
-// checks the option Type and Body. We have other tests that tests the option
-// field gettings given an option body and don't need to duplicate those tests
-// here.
-func TestNDPOptionsIter(t *testing.T) {
- buf := []byte{
- // Source Link-Layer Address.
- 1, 1, 1, 2, 3, 4, 5, 6,
-
- // Target Link-Layer Address.
- 2, 1, 7, 8, 9, 10, 11, 12,
-
- // 255 is an unrecognized type. If 255 ends up being the type
- // for some recognized type, update 255 to some other
- // unrecognized value. Note, this option should be skipped when
- // iterating.
- 255, 2, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8,
-
- // Prefix information.
- 3, 4, 43, 64,
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- 0, 0, 0, 0,
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- 17, 18, 19, 20,
- 21, 22, 23, 24,
- }
-
- opts := NDPOptions(buf)
- it, err := opts.Iter(true)
- if err != nil {
- t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
- }
-
- // Test the first (Source Link-Layer) option.
- next, done, err := it.Next()
- if err != nil {
- t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
- }
- if done {
- t.Fatal("got Next = (_, true, _), want = (_, false, _)")
- }
- if got, want := []byte(next.(NDPSourceLinkLayerAddressOption)), buf[2:][:6]; !bytes.Equal(got, want) {
- t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
- }
- if got := next.kind(); got != ndpSourceLinkLayerAddressOptionType {
- t.Errorf("got Type = %d, want = %d", got, ndpSourceLinkLayerAddressOptionType)
- }
-
- // Test the next (Target Link-Layer) option.
- next, done, err = it.Next()
- if err != nil {
- t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
- }
- if done {
- t.Fatal("got Next = (_, true, _), want = (_, false, _)")
- }
- if got, want := []byte(next.(NDPTargetLinkLayerAddressOption)), buf[10:][:6]; !bytes.Equal(got, want) {
- t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
- }
- if got := next.kind(); got != ndpTargetLinkLayerAddressOptionType {
- t.Errorf("got Type = %d, want = %d", got, ndpTargetLinkLayerAddressOptionType)
- }
-
- // Test the next (Prefix Information) option.
- // Note, the unrecognized option should be skipped.
- next, done, err = it.Next()
- if err != nil {
- t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
- }
- if done {
- t.Fatal("got Next = (_, true, _), want = (_, false, _)")
- }
- if got, want := next.(NDPPrefixInformation), buf[34:][:30]; !bytes.Equal(got, want) {
- t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
- }
- if got := next.kind(); got != ndpPrefixInformationType {
- t.Errorf("got Type = %d, want = %d", got, ndpPrefixInformationType)
- }
-
- // Iterator should not return anything else.
- next, done, err = it.Next()
- if err != nil {
- t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
- }
- if !done {
- t.Error("got Next = (_, false, _), want = (_, true, _)")
- }
- if next != nil {
- t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
- }
-}
-
-func TestNDPRoutePreferenceStringer(t *testing.T) {
- p := NDPRoutePreference(0)
- for {
- var wantStr string
- switch p {
- case 0b01:
- wantStr = "HighRoutePreference"
- case 0b00:
- wantStr = "MediumRoutePreference"
- case 0b11:
- wantStr = "LowRoutePreference"
- case 0b10:
- wantStr = "ReservedRoutePreference"
- default:
- wantStr = fmt.Sprintf("NDPRoutePreference(%d)", p)
- }
-
- if gotStr := p.String(); gotStr != wantStr {
- t.Errorf("got NDPRoutePreference(%d).String() = %s, want = %s", p, gotStr, wantStr)
- }
-
- p++
- if p == 0 {
- // Overflowed, we hit all values.
- break
- }
- }
-}
diff --git a/pkg/tcpip/header/parse/BUILD b/pkg/tcpip/header/parse/BUILD
deleted file mode 100644
index 2adee9288..000000000
--- a/pkg/tcpip/header/parse/BUILD
+++ /dev/null
@@ -1,15 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "parse",
- srcs = ["parse.go"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/stack",
- ],
-)
diff --git a/pkg/tcpip/header/parse/parse_state_autogen.go b/pkg/tcpip/header/parse/parse_state_autogen.go
new file mode 100644
index 000000000..ad047be32
--- /dev/null
+++ b/pkg/tcpip/header/parse/parse_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package parse
diff --git a/pkg/tcpip/header/tcp_test.go b/pkg/tcpip/header/tcp_test.go
deleted file mode 100644
index 96db8460f..000000000
--- a/pkg/tcpip/header/tcp_test.go
+++ /dev/null
@@ -1,168 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package header_test
-
-import (
- "reflect"
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip/header"
-)
-
-func TestEncodeSACKBlocks(t *testing.T) {
- testCases := []struct {
- sackBlocks []header.SACKBlock
- want []header.SACKBlock
- bufSize int
- }{
- {
- []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}},
- []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}},
- 40,
- },
- {
- []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}},
- []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}},
- 30,
- },
- {
- []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}},
- []header.SACKBlock{{10, 20}, {22, 30}},
- 20,
- },
- {
- []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}},
- []header.SACKBlock{{10, 20}},
- 10,
- },
- {
- []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}},
- nil,
- 8,
- },
- {
- []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}},
- []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}},
- 60,
- },
- }
- for _, tc := range testCases {
- b := make([]byte, tc.bufSize)
- t.Logf("testing: %v", tc)
- header.EncodeSACKBlocks(tc.sackBlocks, b)
- opts := header.ParseTCPOptions(b)
- if got, want := opts.SACKBlocks, tc.want; !reflect.DeepEqual(got, want) {
- t.Errorf("header.EncodeSACKBlocks(%v, %v), encoded blocks got: %v, want: %v", tc.sackBlocks, b, got, want)
- }
- }
-}
-
-func TestTCPParseOptions(t *testing.T) {
- type tsOption struct {
- tsVal uint32
- tsEcr uint32
- }
-
- generateOptions := func(tsOpt *tsOption, sackBlocks []header.SACKBlock) []byte {
- l := 0
- if tsOpt != nil {
- l += 10
- }
- if len(sackBlocks) != 0 {
- l += len(sackBlocks)*8 + 2
- }
- b := make([]byte, l)
- offset := 0
- if tsOpt != nil {
- offset = header.EncodeTSOption(tsOpt.tsVal, tsOpt.tsEcr, b)
- }
- header.EncodeSACKBlocks(sackBlocks, b[offset:])
- return b
- }
-
- testCases := []struct {
- b []byte
- want header.TCPOptions
- }{
- // Trivial cases.
- {nil, header.TCPOptions{false, 0, 0, nil}},
- {[]byte{header.TCPOptionNOP}, header.TCPOptions{false, 0, 0, nil}},
- {[]byte{header.TCPOptionNOP, header.TCPOptionNOP}, header.TCPOptions{false, 0, 0, nil}},
- {[]byte{header.TCPOptionEOL}, header.TCPOptions{false, 0, 0, nil}},
- {[]byte{header.TCPOptionNOP, header.TCPOptionEOL, header.TCPOptionTS, 10, 1, 1}, header.TCPOptions{false, 0, 0, nil}},
-
- // Test timestamp parsing.
- {[]byte{header.TCPOptionNOP, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{true, 1, 1, nil}},
- {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{true, 1, 1, nil}},
-
- // Test malformed timestamp option.
- {[]byte{header.TCPOptionTS, 8, 1, 1}, header.TCPOptions{false, 0, 0, nil}},
- {[]byte{header.TCPOptionNOP, header.TCPOptionTS, 8, 1, 1}, header.TCPOptions{false, 0, 0, nil}},
- {[]byte{header.TCPOptionNOP, header.TCPOptionTS, 8, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, nil}},
-
- // Test SACKBlock parsing.
- {[]byte{header.TCPOptionSACK, 10, 0, 0, 0, 1, 0, 0, 0, 10}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{1, 10}}}},
- {[]byte{header.TCPOptionSACK, 18, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, 12}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{1, 10}, {11, 12}}}},
-
- // Test malformed SACK option.
- {[]byte{header.TCPOptionSACK, 0}, header.TCPOptions{false, 0, 0, nil}},
- {[]byte{header.TCPOptionSACK, 8, 0, 0, 0, 1, 0, 0, 0, 10}, header.TCPOptions{false, 0, 0, nil}},
- {[]byte{header.TCPOptionSACK, 11, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, 12}, header.TCPOptions{false, 0, 0, nil}},
- {[]byte{header.TCPOptionSACK, 17, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, 12}, header.TCPOptions{false, 0, 0, nil}},
- {[]byte{header.TCPOptionSACK}, header.TCPOptions{false, 0, 0, nil}},
- {[]byte{header.TCPOptionSACK, 10}, header.TCPOptions{false, 0, 0, nil}},
- {[]byte{header.TCPOptionSACK, 10, 0, 0, 0, 1, 0, 0, 0}, header.TCPOptions{false, 0, 0, nil}},
-
- // Test Timestamp + SACK block parsing.
- {generateOptions(&tsOption{1, 1}, []header.SACKBlock{{1, 10}, {11, 12}}), header.TCPOptions{true, 1, 1, []header.SACKBlock{{1, 10}, {11, 12}}}},
- {generateOptions(&tsOption{1, 2}, []header.SACKBlock{{1, 10}, {11, 12}}), header.TCPOptions{true, 1, 2, []header.SACKBlock{{1, 10}, {11, 12}}}},
- {generateOptions(&tsOption{1, 3}, []header.SACKBlock{{1, 10}, {11, 12}, {13, 14}, {14, 15}, {15, 16}}), header.TCPOptions{true, 1, 3, []header.SACKBlock{{1, 10}, {11, 12}, {13, 14}, {14, 15}}}},
-
- // Test valid timestamp + malformed SACK block parsing.
- {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK}, header.TCPOptions{true, 1, 1, nil}},
- {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK, 10}, header.TCPOptions{true, 1, 1, nil}},
- {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK, 10, 0, 0, 0}, header.TCPOptions{true, 1, 1, nil}},
- {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK, 11, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{true, 1, 1, nil}},
- {[]byte{header.TCPOptionSACK, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, nil}},
- {[]byte{header.TCPOptionSACK, 10, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{134873088, 65536}}}},
- {[]byte{header.TCPOptionSACK, 10, 0, 0, 0, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{8, 167772160}}}},
- {[]byte{header.TCPOptionSACK, 11, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, nil}},
- }
- for _, tc := range testCases {
- if got, want := header.ParseTCPOptions(tc.b), tc.want; !reflect.DeepEqual(got, want) {
- t.Errorf("ParseTCPOptions(%v) = %v, want: %v", tc.b, got, tc.want)
- }
- }
-}
-
-func TestTCPFlags(t *testing.T) {
- for _, tt := range []struct {
- flags header.TCPFlags
- want string
- }{
- {header.TCPFlagFin, "F "},
- {header.TCPFlagSyn, " S "},
- {header.TCPFlagRst, " R "},
- {header.TCPFlagPsh, " P "},
- {header.TCPFlagAck, " A "},
- {header.TCPFlagUrg, " U"},
- {header.TCPFlagSyn | header.TCPFlagAck, " S A "},
- {header.TCPFlagFin | header.TCPFlagAck, "F A "},
- } {
- if got := tt.flags.String(); got != tt.want {
- t.Errorf("got TCPFlags(%#b).String() = %s, want = %s", tt.flags, got, tt.want)
- }
- }
-}
diff --git a/pkg/tcpip/internal/tcp/BUILD b/pkg/tcpip/internal/tcp/BUILD
deleted file mode 100644
index 9ae258a0b..000000000
--- a/pkg/tcpip/internal/tcp/BUILD
+++ /dev/null
@@ -1,12 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "tcp",
- srcs = ["tcp.go"],
- visibility = ["//pkg/tcpip:__subpackages__"],
- deps = [
- "//pkg/tcpip",
- ],
-)
diff --git a/pkg/tcpip/internal/tcp/tcp_state_autogen.go b/pkg/tcpip/internal/tcp/tcp_state_autogen.go
new file mode 100644
index 000000000..e973a7bbd
--- /dev/null
+++ b/pkg/tcpip/internal/tcp/tcp_state_autogen.go
@@ -0,0 +1,36 @@
+// automatically generated by stateify.
+
+package tcp
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+func (offset *TSOffset) StateTypeName() string {
+ return "pkg/tcpip/internal/tcp.TSOffset"
+}
+
+func (offset *TSOffset) StateFields() []string {
+ return []string{
+ "milliseconds",
+ }
+}
+
+func (offset *TSOffset) beforeSave() {}
+
+// +checklocksignore
+func (offset *TSOffset) StateSave(stateSinkObject state.Sink) {
+ offset.beforeSave()
+ stateSinkObject.Save(0, &offset.milliseconds)
+}
+
+func (offset *TSOffset) afterLoad() {}
+
+// +checklocksignore
+func (offset *TSOffset) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &offset.milliseconds)
+}
+
+func init() {
+ state.Register((*TSOffset)(nil))
+}
diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD
deleted file mode 100644
index 973f06cbc..000000000
--- a/pkg/tcpip/link/channel/BUILD
+++ /dev/null
@@ -1,15 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "channel",
- srcs = ["channel.go"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/header",
- "//pkg/tcpip/stack",
- ],
-)
diff --git a/pkg/tcpip/link/channel/channel_state_autogen.go b/pkg/tcpip/link/channel/channel_state_autogen.go
new file mode 100644
index 000000000..7730b59b8
--- /dev/null
+++ b/pkg/tcpip/link/channel/channel_state_autogen.go
@@ -0,0 +1,36 @@
+// automatically generated by stateify.
+
+package channel
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+func (n *NotificationHandle) StateTypeName() string {
+ return "pkg/tcpip/link/channel.NotificationHandle"
+}
+
+func (n *NotificationHandle) StateFields() []string {
+ return []string{
+ "n",
+ }
+}
+
+func (n *NotificationHandle) beforeSave() {}
+
+// +checklocksignore
+func (n *NotificationHandle) StateSave(stateSinkObject state.Sink) {
+ n.beforeSave()
+ stateSinkObject.Save(0, &n.n)
+}
+
+func (n *NotificationHandle) afterLoad() {}
+
+// +checklocksignore
+func (n *NotificationHandle) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &n.n)
+}
+
+func init() {
+ state.Register((*NotificationHandle)(nil))
+}
diff --git a/pkg/tcpip/link/ethernet/BUILD b/pkg/tcpip/link/ethernet/BUILD
deleted file mode 100644
index 0ae0d201a..000000000
--- a/pkg/tcpip/link/ethernet/BUILD
+++ /dev/null
@@ -1,29 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "ethernet",
- srcs = ["ethernet.go"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/nested",
- "//pkg/tcpip/stack",
- ],
-)
-
-go_test(
- name = "ethernet_test",
- size = "small",
- srcs = ["ethernet_test.go"],
- deps = [
- ":ethernet",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/stack",
- ],
-)
diff --git a/pkg/tcpip/link/ethernet/ethernet_state_autogen.go b/pkg/tcpip/link/ethernet/ethernet_state_autogen.go
new file mode 100644
index 000000000..71d255c20
--- /dev/null
+++ b/pkg/tcpip/link/ethernet/ethernet_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package ethernet
diff --git a/pkg/tcpip/link/ethernet/ethernet_test.go b/pkg/tcpip/link/ethernet/ethernet_test.go
deleted file mode 100644
index 08a7f1ce1..000000000
--- a/pkg/tcpip/link/ethernet/ethernet_test.go
+++ /dev/null
@@ -1,71 +0,0 @@
-// Copyright 2021 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package ethernet_test
-
-import (
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/ethernet"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
-)
-
-var _ stack.NetworkDispatcher = (*testNetworkDispatcher)(nil)
-
-type testNetworkDispatcher struct {
- networkPackets int
-}
-
-func (t *testNetworkDispatcher) DeliverNetworkPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
- t.networkPackets++
-}
-
-func (*testNetworkDispatcher) DeliverOutboundPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
-}
-
-func TestDeliverNetworkPacket(t *testing.T) {
- const (
- linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
- otherLinkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07")
- otherLinkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08")
- )
-
- e := ethernet.New(channel.New(0, 0, linkAddr))
- var networkDispatcher testNetworkDispatcher
- e.Attach(&networkDispatcher)
-
- if networkDispatcher.networkPackets != 0 {
- t.Fatalf("got networkDispatcher.networkPackets = %d, want = 0", networkDispatcher.networkPackets)
- }
-
- // An ethernet frame with a destination link address that is not assigned to
- // our ethernet link endpoint should still be delivered to the network
- // dispatcher since the ethernet endpoint is not expected to filter frames.
- eth := buffer.NewView(header.EthernetMinimumSize)
- header.Ethernet(eth).Encode(&header.EthernetFields{
- SrcAddr: otherLinkAddr1,
- DstAddr: otherLinkAddr2,
- Type: header.IPv4ProtocolNumber,
- })
- e.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: eth.ToVectorisedView(),
- }))
- if networkDispatcher.networkPackets != 1 {
- t.Fatalf("got networkDispatcher.networkPackets = %d, want = 1", networkDispatcher.networkPackets)
- }
-}
diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD
deleted file mode 100644
index 1d0163823..000000000
--- a/pkg/tcpip/link/fdbased/BUILD
+++ /dev/null
@@ -1,40 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "fdbased",
- srcs = [
- "endpoint.go",
- "endpoint_unsafe.go",
- "mmap.go",
- "mmap_stub.go",
- "mmap_unsafe.go",
- "packet_dispatchers.go",
- ],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/rawfile",
- "//pkg/tcpip/stack",
- "@org_golang_x_sys//unix:go_default_library",
- ],
-)
-
-go_test(
- name = "fdbased_test",
- size = "small",
- srcs = ["endpoint_test.go"],
- library = ":fdbased",
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/stack",
- "@com_github_google_go_cmp//cmp:go_default_library",
- "@org_golang_x_sys//unix:go_default_library",
- ],
-)
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go
deleted file mode 100644
index eccd21579..000000000
--- a/pkg/tcpip/link/fdbased/endpoint_test.go
+++ /dev/null
@@ -1,624 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-//go:build linux
-// +build linux
-
-package fdbased
-
-import (
- "bytes"
- "fmt"
- "math/rand"
- "reflect"
- "testing"
- "time"
- "unsafe"
-
- "github.com/google/go-cmp/cmp"
- "golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
-)
-
-const (
- mtu = 1500
- laddr = tcpip.LinkAddress("\x11\x22\x33\x44\x55\x66")
- raddr = tcpip.LinkAddress("\x77\x88\x99\xaa\xbb\xcc")
- proto = 10
- csumOffset = 48
- gsoMSS = 500
-)
-
-type packetInfo struct {
- Raddr tcpip.LinkAddress
- Proto tcpip.NetworkProtocolNumber
- Contents *stack.PacketBuffer
-}
-
-type packetContents struct {
- LinkHeader buffer.View
- NetworkHeader buffer.View
- TransportHeader buffer.View
- Data buffer.View
-}
-
-func checkPacketInfoEqual(t *testing.T, got, want packetInfo) {
- t.Helper()
- if diff := cmp.Diff(
- want, got,
- cmp.Transformer("ExtractPacketBuffer", func(pk *stack.PacketBuffer) *packetContents {
- if pk == nil {
- return nil
- }
- return &packetContents{
- LinkHeader: pk.LinkHeader().View(),
- NetworkHeader: pk.NetworkHeader().View(),
- TransportHeader: pk.TransportHeader().View(),
- Data: pk.Data().AsRange().ToOwnedView(),
- }
- }),
- ); diff != "" {
- t.Errorf("unexpected packetInfo (-want +got):\n%s", diff)
- }
-}
-
-type context struct {
- t *testing.T
- readFDs []int
- writeFDs []int
- ep stack.LinkEndpoint
- ch chan packetInfo
- done chan struct{}
-}
-
-func newContext(t *testing.T, opt *Options) *context {
- firstFDPair, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_SEQPACKET, 0)
- if err != nil {
- t.Fatalf("Socketpair failed: %v", err)
- }
- secondFDPair, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_SEQPACKET, 0)
- if err != nil {
- t.Fatalf("Socketpair failed: %v", err)
- }
-
- done := make(chan struct{}, 2)
- opt.ClosedFunc = func(tcpip.Error) {
- done <- struct{}{}
- }
-
- opt.FDs = []int{firstFDPair[1], secondFDPair[1]}
- ep, err := New(opt)
- if err != nil {
- t.Fatalf("Failed to create FD endpoint: %v", err)
- }
-
- c := &context{
- t: t,
- readFDs: []int{firstFDPair[0], secondFDPair[0]},
- writeFDs: opt.FDs,
- ep: ep,
- ch: make(chan packetInfo, 100),
- done: done,
- }
-
- ep.Attach(c)
-
- return c
-}
-
-func (c *context) cleanup() {
- for _, fd := range c.readFDs {
- unix.Close(fd)
- }
- <-c.done
- <-c.done
- for _, fd := range c.writeFDs {
- unix.Close(fd)
- }
-}
-
-func (c *context) DeliverNetworkPacket(remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
- c.ch <- packetInfo{remote, protocol, pkt}
-}
-
-func (c *context) DeliverOutboundPacket(remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
- panic("unimplemented")
-}
-
-func TestNoEthernetProperties(t *testing.T) {
- c := newContext(t, &Options{MTU: mtu})
- defer c.cleanup()
-
- if want, v := uint16(0), c.ep.MaxHeaderLength(); want != v {
- t.Fatalf("MaxHeaderLength() = %v, want %v", v, want)
- }
-
- if want, v := uint32(mtu), c.ep.MTU(); want != v {
- t.Fatalf("MTU() = %v, want %v", v, want)
- }
-}
-
-func TestEthernetProperties(t *testing.T) {
- c := newContext(t, &Options{EthernetHeader: true, MTU: mtu})
- defer c.cleanup()
-
- if want, v := uint16(header.EthernetMinimumSize), c.ep.MaxHeaderLength(); want != v {
- t.Fatalf("MaxHeaderLength() = %v, want %v", v, want)
- }
-
- if want, v := uint32(mtu), c.ep.MTU(); want != v {
- t.Fatalf("MTU() = %v, want %v", v, want)
- }
-}
-
-func TestAddress(t *testing.T) {
- addrs := []tcpip.LinkAddress{"", "abc", "def"}
- for _, a := range addrs {
- t.Run(fmt.Sprintf("Address: %q", a), func(t *testing.T) {
- c := newContext(t, &Options{Address: a, MTU: mtu})
- defer c.cleanup()
-
- if want, v := a, c.ep.LinkAddress(); want != v {
- t.Fatalf("LinkAddress() = %v, want %v", v, want)
- }
- })
- }
-}
-
-func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash uint32) {
- c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth, GSOMaxSize: gsoMaxSize})
- defer c.cleanup()
-
- var r stack.RouteInfo
- r.RemoteLinkAddress = raddr
-
- // Build payload.
- payload := buffer.NewView(plen)
- if _, err := rand.Read(payload); err != nil {
- t.Fatalf("rand.Read(payload): %s", err)
- }
-
- // Build packet buffer.
- const netHdrLen = 100
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(c.ep.MaxHeaderLength()) + netHdrLen,
- Data: payload.ToVectorisedView(),
- })
- pkt.Hash = hash
-
- // Build header.
- b := pkt.NetworkHeader().Push(netHdrLen)
- if _, err := rand.Read(b); err != nil {
- t.Fatalf("rand.Read(b): %s", err)
- }
-
- // Write.
- want := append(append(buffer.View(nil), b...), payload...)
- const l3HdrLen = header.IPv6MinimumSize
- if gsoMaxSize != 0 {
- pkt.GSOOptions = stack.GSO{
- Type: stack.GSOTCPv6,
- NeedsCsum: true,
- CsumOffset: csumOffset,
- MSS: gsoMSS,
- L3HdrLen: l3HdrLen,
- }
- }
- if err := c.ep.WritePacket(r, proto, pkt); err != nil {
- t.Fatalf("WritePacket failed: %v", err)
- }
-
- // Read from the corresponding FD, then compare with what we wrote.
- b = make([]byte, mtu)
- fd := c.readFDs[hash%uint32(len(c.readFDs))]
- n, err := unix.Read(fd, b)
- if err != nil {
- t.Fatalf("Read failed: %v", err)
- }
- b = b[:n]
- if gsoMaxSize != 0 {
- vnetHdr := *(*virtioNetHdr)(unsafe.Pointer(&b[0]))
- if vnetHdr.flags&_VIRTIO_NET_HDR_F_NEEDS_CSUM == 0 {
- t.Fatalf("virtioNetHdr.flags %v doesn't contain %v", vnetHdr.flags, _VIRTIO_NET_HDR_F_NEEDS_CSUM)
- }
- const csumStart = header.EthernetMinimumSize + l3HdrLen
- if vnetHdr.csumStart != csumStart {
- t.Fatalf("vnetHdr.csumStart = %v, want %v", vnetHdr.csumStart, csumStart)
- }
- if vnetHdr.csumOffset != csumOffset {
- t.Fatalf("vnetHdr.csumOffset = %v, want %v", vnetHdr.csumOffset, csumOffset)
- }
- gsoType := uint8(0)
- if plen > gsoMSS {
- gsoType = _VIRTIO_NET_HDR_GSO_TCPV6
- }
- if vnetHdr.gsoType != gsoType {
- t.Fatalf("vnetHdr.gsoType = %v, want %v", vnetHdr.gsoType, gsoType)
- }
- b = b[virtioNetHdrSize:]
- }
- if eth {
- h := header.Ethernet(b)
- b = b[header.EthernetMinimumSize:]
-
- if a := h.SourceAddress(); a != laddr {
- t.Fatalf("SourceAddress() = %v, want %v", a, laddr)
- }
-
- if a := h.DestinationAddress(); a != raddr {
- t.Fatalf("DestinationAddress() = %v, want %v", a, raddr)
- }
-
- if et := h.Type(); et != proto {
- t.Fatalf("Type() = %v, want %v", et, proto)
- }
- }
- if len(b) != len(want) {
- t.Fatalf("Read returned %v bytes, want %v", len(b), len(want))
- }
- if !bytes.Equal(b, want) {
- t.Fatalf("Read returned %x, want %x", b, want)
- }
-}
-
-func TestWritePacket(t *testing.T) {
- lengths := []int{0, 100, 1000}
- eths := []bool{true, false}
- gsos := []uint32{0, 32768}
-
- for _, eth := range eths {
- for _, plen := range lengths {
- for _, gso := range gsos {
- t.Run(
- fmt.Sprintf("Eth=%v,PayloadLen=%v,GSOMaxSize=%v", eth, plen, gso),
- func(t *testing.T) {
- testWritePacket(t, plen, eth, gso, 0)
- },
- )
- }
- }
- }
-}
-
-func TestHashedWritePacket(t *testing.T) {
- lengths := []int{0, 100, 1000}
- eths := []bool{true, false}
- gsos := []uint32{0, 32768}
- hashes := []uint32{0, 1}
- for _, eth := range eths {
- for _, plen := range lengths {
- for _, gso := range gsos {
- for _, hash := range hashes {
- t.Run(
- fmt.Sprintf("Eth=%v,PayloadLen=%v,GSOMaxSize=%v,Hash=%d", eth, plen, gso, hash),
- func(t *testing.T) {
- testWritePacket(t, plen, eth, gso, hash)
- },
- )
- }
- }
- }
- }
-}
-
-func TestPreserveSrcAddress(t *testing.T) {
- baddr := tcpip.LinkAddress("\xcc\xbb\xaa\x77\x88\x99")
-
- c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: true})
- defer c.cleanup()
-
- // Set LocalLinkAddress in route to the value of the bridged address.
- var r stack.RouteInfo
- r.LocalLinkAddress = baddr
- r.RemoteLinkAddress = raddr
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- // WritePacket panics given a prependable with anything less than
- // the minimum size of the ethernet header.
- // TODO(b/153685824): Figure out if this should use c.ep.MaxHeaderLength().
- ReserveHeaderBytes: header.EthernetMinimumSize,
- Data: buffer.VectorisedView{},
- })
- if err := c.ep.WritePacket(r, proto, pkt); err != nil {
- t.Fatalf("WritePacket failed: %v", err)
- }
-
- // Read from the FD, then compare with what we wrote.
- b := make([]byte, mtu)
- n, err := unix.Read(c.readFDs[0], b)
- if err != nil {
- t.Fatalf("Read failed: %v", err)
- }
- b = b[:n]
- h := header.Ethernet(b)
-
- if a := h.SourceAddress(); a != baddr {
- t.Fatalf("SourceAddress() = %v, want %v", a, baddr)
- }
-}
-
-func TestDeliverPacket(t *testing.T) {
- lengths := []int{100, 1000}
- eths := []bool{true, false}
-
- for _, eth := range eths {
- for _, plen := range lengths {
- t.Run(fmt.Sprintf("Eth=%v,PayloadLen=%v", eth, plen), func(t *testing.T) {
- c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth})
- defer c.cleanup()
-
- // Build packet.
- all := make([]byte, plen)
- if _, err := rand.Read(all); err != nil {
- t.Fatalf("rand.Read(all): %s", err)
- }
- // Make it look like an IPv4 packet.
- all[0] = 0x40
-
- wantPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.EthernetMinimumSize,
- Data: buffer.NewViewFromBytes(all).ToVectorisedView(),
- })
- if eth {
- hdr := header.Ethernet(wantPkt.LinkHeader().Push(header.EthernetMinimumSize))
- hdr.Encode(&header.EthernetFields{
- SrcAddr: raddr,
- DstAddr: laddr,
- Type: proto,
- })
- all = append(hdr, all...)
- }
-
- // Write packet via the file descriptor.
- if _, err := unix.Write(c.readFDs[0], all); err != nil {
- t.Fatalf("Write failed: %v", err)
- }
-
- // Receive packet through the endpoint.
- select {
- case pi := <-c.ch:
- want := packetInfo{
- Raddr: raddr,
- Proto: proto,
- Contents: wantPkt,
- }
- if !eth {
- want.Proto = header.IPv4ProtocolNumber
- want.Raddr = ""
- }
- checkPacketInfoEqual(t, pi, want)
- case <-time.After(10 * time.Second):
- t.Fatalf("Timed out waiting for packet")
- }
- })
- }
- }
-}
-
-func TestBufConfigMaxLength(t *testing.T) {
- got := 0
- for _, i := range BufConfig {
- got += i
- }
- want := header.MaxIPPacketSize // maximum TCP packet size
- if got < want {
- t.Errorf("total buffer size is invalid: got %d, want >= %d", got, want)
- }
-}
-
-func TestBufConfigFirst(t *testing.T) {
- // The stack assumes that the TCP/IP header is enterily contained in the first view.
- // Therefore, the first view needs to be large enough to contain the maximum TCP/IP
- // header, which is 120 bytes (60 bytes for IP + 60 bytes for TCP).
- want := 120
- got := BufConfig[0]
- if got < want {
- t.Errorf("first view has an invalid size: got %d, want >= %d", got, want)
- }
-}
-
-var capLengthTestCases = []struct {
- comment string
- config []int
- n int
- wantUsed int
- wantLengths []int
-}{
- {
- comment: "Single slice",
- config: []int{2},
- n: 1,
- wantUsed: 1,
- wantLengths: []int{1},
- },
- {
- comment: "Multiple slices",
- config: []int{1, 2},
- n: 2,
- wantUsed: 2,
- wantLengths: []int{1, 1},
- },
- {
- comment: "Entire buffer",
- config: []int{1, 2},
- n: 3,
- wantUsed: 2,
- wantLengths: []int{1, 2},
- },
- {
- comment: "Entire buffer but not on the last slice",
- config: []int{1, 2, 3},
- n: 3,
- wantUsed: 2,
- wantLengths: []int{1, 2},
- },
-}
-
-func TestIovecBuffer(t *testing.T) {
- for _, c := range capLengthTestCases {
- t.Run(c.comment, func(t *testing.T) {
- b := newIovecBuffer(c.config, false /* skipsVnetHdr */)
-
- // Test initial allocation.
- iovecs := b.nextIovecs()
- if got, want := len(iovecs), len(c.config); got != want {
- t.Fatalf("len(iovecs) = %d, want %d", got, want)
- }
-
- // Make a copy as iovecs points to internal slice. We will need this state
- // later.
- oldIovecs := append([]unix.Iovec(nil), iovecs...)
-
- // Test the views that get pulled.
- vv := b.pullViews(c.n)
- var lengths []int
- for _, v := range vv.Views() {
- lengths = append(lengths, len(v))
- }
- if !reflect.DeepEqual(lengths, c.wantLengths) {
- t.Errorf("Pulled view lengths = %v, want %v", lengths, c.wantLengths)
- }
-
- // Test that new views get reallocated.
- for i, newIov := range b.nextIovecs() {
- if i < c.wantUsed {
- if newIov.Base == oldIovecs[i].Base {
- t.Errorf("b.views[%d] should have been reallocated", i)
- }
- } else {
- if newIov.Base != oldIovecs[i].Base {
- t.Errorf("b.views[%d] should not have been reallocated", i)
- }
- }
- }
- })
- }
-}
-
-func TestIovecBufferSkipVnetHdr(t *testing.T) {
- for _, test := range []struct {
- desc string
- readN int
- wantLen int
- }{
- {
- desc: "nothing read",
- readN: 0,
- wantLen: 0,
- },
- {
- desc: "smaller than vnet header",
- readN: virtioNetHdrSize - 1,
- wantLen: 0,
- },
- {
- desc: "header skipped",
- readN: virtioNetHdrSize + 100,
- wantLen: 100,
- },
- } {
- t.Run(test.desc, func(t *testing.T) {
- b := newIovecBuffer([]int{10, 20, 50, 50}, true)
- // Pretend a read happend.
- b.nextIovecs()
- vv := b.pullViews(test.readN)
- if got, want := vv.Size(), test.wantLen; got != want {
- t.Errorf("b.pullView(%d).Size() = %d; want %d", test.readN, got, want)
- }
- if got, want := len(vv.ToOwnedView()), test.wantLen; got != want {
- t.Errorf("b.pullView(%d).ToOwnedView() has length %d; want %d", test.readN, got, want)
- }
- })
- }
-}
-
-// fakeNetworkDispatcher delivers packets to pkts.
-type fakeNetworkDispatcher struct {
- pkts []*stack.PacketBuffer
-}
-
-func (d *fakeNetworkDispatcher) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
- d.pkts = append(d.pkts, pkt)
-}
-
-func (d *fakeNetworkDispatcher) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
- panic("unimplemented")
-}
-
-func TestDispatchPacketFormat(t *testing.T) {
- for _, test := range []struct {
- name string
- newDispatcher func(fd int, e *endpoint) (linkDispatcher, error)
- }{
- {
- name: "readVDispatcher",
- newDispatcher: newReadVDispatcher,
- },
- {
- name: "recvMMsgDispatcher",
- newDispatcher: newRecvMMsgDispatcher,
- },
- } {
- t.Run(test.name, func(t *testing.T) {
- // Create a socket pair to send/recv.
- fds, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_DGRAM, 0)
- if err != nil {
- t.Fatal(err)
- }
- defer unix.Close(fds[0])
- defer unix.Close(fds[1])
-
- data := []byte{
- // Ethernet header.
- 1, 2, 3, 4, 5, 60,
- 1, 2, 3, 4, 5, 61,
- 8, 0,
- // Mock network header.
- 40, 41, 42, 43,
- }
- err = unix.Sendmsg(fds[1], data, nil, nil, 0)
- if err != nil {
- t.Fatal(err)
- }
-
- // Create and run dispatcher once.
- sink := &fakeNetworkDispatcher{}
- d, err := test.newDispatcher(fds[0], &endpoint{
- hdrSize: header.EthernetMinimumSize,
- dispatcher: sink,
- })
- if err != nil {
- t.Fatal(err)
- }
- if ok, err := d.dispatch(); !ok || err != nil {
- t.Fatalf("d.dispatch() = %v, %v", ok, err)
- }
-
- // Verify packet.
- if got, want := len(sink.pkts), 1; got != want {
- t.Fatalf("len(sink.pkts) = %d, want %d", got, want)
- }
- pkt := sink.pkts[0]
- if got, want := pkt.LinkHeader().View().Size(), header.EthernetMinimumSize; got != want {
- t.Errorf("pkt.LinkHeader().View().Size() = %d, want %d", got, want)
- }
- if got, want := pkt.Data().Size(), 4; got != want {
- t.Errorf("pkt.Data().Size() = %d, want %d", got, want)
- }
- })
- }
-}
diff --git a/pkg/tcpip/link/fdbased/fdbased_state_autogen.go b/pkg/tcpip/link/fdbased/fdbased_state_autogen.go
new file mode 100644
index 000000000..586f166a4
--- /dev/null
+++ b/pkg/tcpip/link/fdbased/fdbased_state_autogen.go
@@ -0,0 +1,9 @@
+// automatically generated by stateify.
+
+//go:build linux && ((linux && amd64) || (linux && arm64)) && (!linux || (!amd64 && !arm64)) && linux
+// +build linux
+// +build linux,amd64 linux,arm64
+// +build !linux !amd64,!arm64
+// +build linux
+
+package fdbased
diff --git a/pkg/tcpip/link/fdbased/fdbased_unsafe_state_autogen.go b/pkg/tcpip/link/fdbased/fdbased_unsafe_state_autogen.go
new file mode 100644
index 000000000..6a5ed4a3c
--- /dev/null
+++ b/pkg/tcpip/link/fdbased/fdbased_unsafe_state_autogen.go
@@ -0,0 +1,7 @@
+// automatically generated by stateify.
+
+//go:build linux && ((linux && amd64) || (linux && arm64))
+// +build linux
+// +build linux,amd64 linux,arm64
+
+package fdbased
diff --git a/pkg/tcpip/link/loopback/BUILD b/pkg/tcpip/link/loopback/BUILD
deleted file mode 100644
index 6bf3805b7..000000000
--- a/pkg/tcpip/link/loopback/BUILD
+++ /dev/null
@@ -1,15 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "loopback",
- srcs = ["loopback.go"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/stack",
- ],
-)
diff --git a/pkg/tcpip/link/loopback/loopback_state_autogen.go b/pkg/tcpip/link/loopback/loopback_state_autogen.go
new file mode 100644
index 000000000..c00fd9f19
--- /dev/null
+++ b/pkg/tcpip/link/loopback/loopback_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package loopback
diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD
deleted file mode 100644
index 193524525..000000000
--- a/pkg/tcpip/link/muxed/BUILD
+++ /dev/null
@@ -1,29 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "muxed",
- srcs = ["injectable.go"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/header",
- "//pkg/tcpip/stack",
- ],
-)
-
-go_test(
- name = "muxed_test",
- size = "small",
- srcs = ["injectable_test.go"],
- library = ":muxed",
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/link/fdbased",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/stack",
- "@org_golang_x_sys//unix:go_default_library",
- ],
-)
diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go
deleted file mode 100644
index 040e3a35b..000000000
--- a/pkg/tcpip/link/muxed/injectable_test.go
+++ /dev/null
@@ -1,101 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package muxed
-
-import (
- "bytes"
- "net"
- "os"
- "testing"
-
- "golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/link/fdbased"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
-)
-
-func TestInjectableEndpointRawDispatch(t *testing.T) {
- endpoint, sock, dstIP := makeTestInjectableEndpoint(t)
-
- endpoint.InjectOutbound(dstIP, []byte{0xFA})
-
- buf := make([]byte, ipv4.MaxTotalSize)
- bytesRead, err := sock.Read(buf)
- if err != nil {
- t.Fatalf("Unable to read from socketpair: %v", err)
- }
- if got, want := buf[:bytesRead], []byte{0xFA}; !bytes.Equal(got, want) {
- t.Fatalf("Read %v from the socketpair, wanted %v", got, want)
- }
-}
-
-func TestInjectableEndpointDispatch(t *testing.T) {
- endpoint, sock, dstIP := makeTestInjectableEndpoint(t)
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: 1,
- Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(),
- })
- pkt.TransportHeader().Push(1)[0] = 0xFA
- var packetRoute stack.RouteInfo
- packetRoute.RemoteAddress = dstIP
-
- endpoint.WritePacket(packetRoute, ipv4.ProtocolNumber, pkt)
-
- buf := make([]byte, 6500)
- bytesRead, err := sock.Read(buf)
- if err != nil {
- t.Fatalf("Unable to read from socketpair: %v", err)
- }
- if got, want := buf[:bytesRead], []byte{0xFA, 0xFB}; !bytes.Equal(got, want) {
- t.Fatalf("Read %v from the socketpair, wanted %v", got, want)
- }
-}
-
-func TestInjectableEndpointDispatchHdrOnly(t *testing.T) {
- endpoint, sock, dstIP := makeTestInjectableEndpoint(t)
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: 1,
- Data: buffer.NewView(0).ToVectorisedView(),
- })
- pkt.TransportHeader().Push(1)[0] = 0xFA
- var packetRoute stack.RouteInfo
- packetRoute.RemoteAddress = dstIP
- endpoint.WritePacket(packetRoute, ipv4.ProtocolNumber, pkt)
- buf := make([]byte, 6500)
- bytesRead, err := sock.Read(buf)
- if err != nil {
- t.Fatalf("Unable to read from socketpair: %v", err)
- }
- if got, want := buf[:bytesRead], []byte{0xFA}; !bytes.Equal(got, want) {
- t.Fatalf("Read %v from the socketpair, wanted %v", got, want)
- }
-}
-
-func makeTestInjectableEndpoint(t *testing.T) (*InjectableEndpoint, *os.File, tcpip.Address) {
- dstIP := tcpip.Address(net.ParseIP("1.2.3.4").To4())
- pair, err := unix.Socketpair(unix.AF_UNIX,
- unix.SOCK_SEQPACKET|unix.SOCK_CLOEXEC|unix.SOCK_NONBLOCK, 0)
- if err != nil {
- t.Fatal("Failed to create socket pair:", err)
- }
- underlyingEndpoint := fdbased.NewInjectable(pair[1], 6500, stack.CapabilityNone)
- routes := map[tcpip.Address]stack.InjectableLinkEndpoint{dstIP: underlyingEndpoint}
- endpoint := NewInjectableEndpoint(routes)
- return endpoint, os.NewFile(uintptr(pair[0]), "test route end"), dstIP
-}
diff --git a/pkg/tcpip/link/muxed/muxed_state_autogen.go b/pkg/tcpip/link/muxed/muxed_state_autogen.go
new file mode 100644
index 000000000..56330e2a5
--- /dev/null
+++ b/pkg/tcpip/link/muxed/muxed_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package muxed
diff --git a/pkg/tcpip/link/nested/BUILD b/pkg/tcpip/link/nested/BUILD
deleted file mode 100644
index 00b42b924..000000000
--- a/pkg/tcpip/link/nested/BUILD
+++ /dev/null
@@ -1,31 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "nested",
- srcs = [
- "nested.go",
- ],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/header",
- "//pkg/tcpip/stack",
- ],
-)
-
-go_test(
- name = "nested_test",
- size = "small",
- srcs = [
- "nested_test.go",
- ],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/nested",
- "//pkg/tcpip/stack",
- ],
-)
diff --git a/pkg/tcpip/link/nested/nested_state_autogen.go b/pkg/tcpip/link/nested/nested_state_autogen.go
new file mode 100644
index 000000000..9e1b5ca4e
--- /dev/null
+++ b/pkg/tcpip/link/nested/nested_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package nested
diff --git a/pkg/tcpip/link/nested/nested_test.go b/pkg/tcpip/link/nested/nested_test.go
deleted file mode 100644
index c1f9d308c..000000000
--- a/pkg/tcpip/link/nested/nested_test.go
+++ /dev/null
@@ -1,109 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package nested_test
-
-import (
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/nested"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
-)
-
-type parentEndpoint struct {
- nested.Endpoint
-}
-
-var _ stack.LinkEndpoint = (*parentEndpoint)(nil)
-var _ stack.NetworkDispatcher = (*parentEndpoint)(nil)
-
-type childEndpoint struct {
- stack.LinkEndpoint
- dispatcher stack.NetworkDispatcher
-}
-
-var _ stack.LinkEndpoint = (*childEndpoint)(nil)
-
-func (c *childEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
- c.dispatcher = dispatcher
-}
-
-func (c *childEndpoint) IsAttached() bool {
- return c.dispatcher != nil
-}
-
-type counterDispatcher struct {
- count int
-}
-
-var _ stack.NetworkDispatcher = (*counterDispatcher)(nil)
-
-func (d *counterDispatcher) DeliverNetworkPacket(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) {
- d.count++
-}
-
-func (d *counterDispatcher) DeliverOutboundPacket(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) {
- panic("unimplemented")
-}
-
-func TestNestedLinkEndpoint(t *testing.T) {
- const emptyAddress = tcpip.LinkAddress("")
-
- var (
- childEP childEndpoint
- nestedEP parentEndpoint
- disp counterDispatcher
- )
- nestedEP.Endpoint.Init(&childEP, &nestedEP)
-
- if childEP.IsAttached() {
- t.Error("On init, childEP.IsAttached() = true, want = false")
- }
- if nestedEP.IsAttached() {
- t.Error("On init, nestedEP.IsAttached() = true, want = false")
- }
-
- nestedEP.Attach(&disp)
- if disp.count != 0 {
- t.Fatalf("After attach, got disp.count = %d, want = 0", disp.count)
- }
- if !childEP.IsAttached() {
- t.Error("After attach, childEP.IsAttached() = false, want = true")
- }
- if !nestedEP.IsAttached() {
- t.Error("After attach, nestedEP.IsAttached() = false, want = true")
- }
-
- nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{}))
- if disp.count != 1 {
- t.Errorf("After first packet with dispatcher attached, got disp.count = %d, want = 1", disp.count)
- }
-
- nestedEP.Attach(nil)
- if childEP.IsAttached() {
- t.Error("After detach, childEP.IsAttached() = true, want = false")
- }
- if nestedEP.IsAttached() {
- t.Error("After detach, nestedEP.IsAttached() = true, want = false")
- }
-
- disp.count = 0
- nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{}))
- if disp.count != 0 {
- t.Errorf("After second packet with dispatcher detached, got disp.count = %d, want = 0", disp.count)
- }
-
-}
diff --git a/pkg/tcpip/link/pipe/BUILD b/pkg/tcpip/link/pipe/BUILD
deleted file mode 100644
index 9f31c1ffc..000000000
--- a/pkg/tcpip/link/pipe/BUILD
+++ /dev/null
@@ -1,15 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "pipe",
- srcs = ["pipe.go"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/stack",
- ],
-)
diff --git a/pkg/tcpip/link/pipe/pipe_state_autogen.go b/pkg/tcpip/link/pipe/pipe_state_autogen.go
new file mode 100644
index 000000000..d3b40feb4
--- /dev/null
+++ b/pkg/tcpip/link/pipe/pipe_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package pipe
diff --git a/pkg/tcpip/link/qdisc/fifo/BUILD b/pkg/tcpip/link/qdisc/fifo/BUILD
deleted file mode 100644
index 5bea598eb..000000000
--- a/pkg/tcpip/link/qdisc/fifo/BUILD
+++ /dev/null
@@ -1,19 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "fifo",
- srcs = [
- "endpoint.go",
- "packet_buffer_queue.go",
- ],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/sleep",
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/header",
- "//pkg/tcpip/stack",
- ],
-)
diff --git a/pkg/tcpip/link/qdisc/fifo/fifo_state_autogen.go b/pkg/tcpip/link/qdisc/fifo/fifo_state_autogen.go
new file mode 100644
index 000000000..9eb52b1cb
--- /dev/null
+++ b/pkg/tcpip/link/qdisc/fifo/fifo_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package fifo
diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD
deleted file mode 100644
index 4efd7c45e..000000000
--- a/pkg/tcpip/link/rawfile/BUILD
+++ /dev/null
@@ -1,33 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "rawfile",
- srcs = [
- "blockingpoll_amd64.s",
- "blockingpoll_arm64.s",
- "blockingpoll_noyield_unsafe.go",
- "blockingpoll_yield_unsafe.go",
- "errors.go",
- "rawfile_unsafe.go",
- ],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/tcpip",
- "@org_golang_x_sys//unix:go_default_library",
- ],
-)
-
-go_test(
- name = "rawfile_test",
- srcs = [
- "errors_test.go",
- ],
- library = "rawfile",
- deps = [
- "//pkg/tcpip",
- "@com_github_google_go_cmp//cmp:go_default_library",
- "@org_golang_x_sys//unix:go_default_library",
- ],
-)
diff --git a/pkg/tcpip/link/rawfile/errors_test.go b/pkg/tcpip/link/rawfile/errors_test.go
deleted file mode 100644
index 1b88c309b..000000000
--- a/pkg/tcpip/link/rawfile/errors_test.go
+++ /dev/null
@@ -1,55 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-//go:build linux
-// +build linux
-
-package rawfile
-
-import (
- "testing"
-
- "github.com/google/go-cmp/cmp"
- "golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/tcpip"
-)
-
-func TestTranslateErrno(t *testing.T) {
- for _, test := range []struct {
- errno unix.Errno
- translated tcpip.Error
- }{
- {
- errno: unix.Errno(0),
- translated: &tcpip.ErrInvalidEndpointState{},
- },
- {
- errno: unix.Errno(maxErrno),
- translated: &tcpip.ErrInvalidEndpointState{},
- },
- {
- errno: unix.Errno(514),
- translated: &tcpip.ErrInvalidEndpointState{},
- },
- {
- errno: unix.EEXIST,
- translated: &tcpip.ErrDuplicateAddress{},
- },
- } {
- got := TranslateErrno(test.errno)
- if diff := cmp.Diff(test.translated, got); diff != "" {
- t.Errorf("unexpected result from TranslateErrno(%q), (-want, +got):\n%s", test.errno, diff)
- }
- }
-}
diff --git a/pkg/tcpip/link/rawfile/rawfile_state_autogen.go b/pkg/tcpip/link/rawfile/rawfile_state_autogen.go
new file mode 100644
index 000000000..00708246f
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/rawfile_state_autogen.go
@@ -0,0 +1,6 @@
+// automatically generated by stateify.
+
+//go:build linux
+// +build linux
+
+package rawfile
diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe_state_autogen.go b/pkg/tcpip/link/rawfile/rawfile_unsafe_state_autogen.go
new file mode 100644
index 000000000..c42f3a3b6
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/rawfile_unsafe_state_autogen.go
@@ -0,0 +1,11 @@
+// automatically generated by stateify.
+
+//go:build linux && !amd64 && !arm64 && ((linux && amd64) || (linux && arm64)) && go1.12 && linux
+// +build linux
+// +build !amd64
+// +build !arm64
+// +build linux,amd64 linux,arm64
+// +build go1.12
+// +build linux
+
+package rawfile
diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD
deleted file mode 100644
index 4215ee852..000000000
--- a/pkg/tcpip/link/sharedmem/BUILD
+++ /dev/null
@@ -1,43 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "sharedmem",
- srcs = [
- "rx.go",
- "sharedmem.go",
- "sharedmem_unsafe.go",
- "tx.go",
- ],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/log",
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/rawfile",
- "//pkg/tcpip/link/sharedmem/queue",
- "//pkg/tcpip/stack",
- "@org_golang_x_sys//unix:go_default_library",
- ],
-)
-
-go_test(
- name = "sharedmem_test",
- srcs = [
- "sharedmem_test.go",
- ],
- library = ":sharedmem",
- deps = [
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/sharedmem/pipe",
- "//pkg/tcpip/link/sharedmem/queue",
- "//pkg/tcpip/stack",
- "@org_golang_x_sys//unix:go_default_library",
- ],
-)
diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD
deleted file mode 100644
index 87020ec08..000000000
--- a/pkg/tcpip/link/sharedmem/pipe/BUILD
+++ /dev/null
@@ -1,23 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "pipe",
- srcs = [
- "pipe.go",
- "pipe_unsafe.go",
- "rx.go",
- "tx.go",
- ],
- visibility = ["//visibility:public"],
-)
-
-go_test(
- name = "pipe_test",
- srcs = [
- "pipe_test.go",
- ],
- library = ":pipe",
- deps = ["//pkg/sync"],
-)
diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_state_autogen.go b/pkg/tcpip/link/sharedmem/pipe/pipe_state_autogen.go
new file mode 100644
index 000000000..d3b40feb4
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/pipe/pipe_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package pipe
diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go
deleted file mode 100644
index 2777f1411..000000000
--- a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go
+++ /dev/null
@@ -1,512 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package pipe
-
-import (
- "math/rand"
- "reflect"
- "runtime"
- "testing"
-
- "gvisor.dev/gvisor/pkg/sync"
-)
-
-func TestSimpleReadWrite(t *testing.T) {
- // Check that a simple write can be properly read from the rx side.
- tr := rand.New(rand.NewSource(99))
- rr := rand.New(rand.NewSource(99))
-
- b := make([]byte, 100)
- var tx Tx
- tx.Init(b)
-
- wb := tx.Push(10)
- if wb == nil {
- t.Fatalf("Push failed on empty pipe")
- }
- for i := range wb {
- wb[i] = byte(tr.Intn(256))
- }
- tx.Flush()
-
- var rx Rx
- rx.Init(b)
- rb := rx.Pull()
- if len(rb) != 10 {
- t.Fatalf("Bad buffer size returned: got %v, want %v", len(rb), 10)
- }
-
- for i := range rb {
- if v := byte(rr.Intn(256)); v != rb[i] {
- t.Fatalf("Bad read buffer at index %v: got %v, want %v", i, rb[i], v)
- }
- }
- rx.Flush()
-}
-
-func TestEmptyRead(t *testing.T) {
- // Check that pulling from an empty pipe fails.
- b := make([]byte, 100)
- var tx Tx
- tx.Init(b)
-
- var rx Rx
- rx.Init(b)
- if rb := rx.Pull(); rb != nil {
- t.Fatalf("Pull succeeded on empty pipe")
- }
-}
-
-func TestTooLargeWrite(t *testing.T) {
- // Check that writes that are too large are properly rejected.
- b := make([]byte, 96)
- var tx Tx
- tx.Init(b)
-
- if wb := tx.Push(96); wb != nil {
- t.Fatalf("Write of 96 bytes succeeded on 96-byte pipe")
- }
-
- if wb := tx.Push(88); wb != nil {
- t.Fatalf("Write of 88 bytes succeeded on 96-byte pipe")
- }
-
- if wb := tx.Push(80); wb == nil {
- t.Fatalf("Write of 80 bytes failed on 96-byte pipe")
- }
-}
-
-func TestFullWrite(t *testing.T) {
- // Check that writes fail when the pipe is full.
- b := make([]byte, 100)
- var tx Tx
- tx.Init(b)
-
- if wb := tx.Push(80); wb == nil {
- t.Fatalf("Write of 80 bytes failed on 96-byte pipe")
- }
-
- if wb := tx.Push(1); wb != nil {
- t.Fatalf("Write succeeded on full pipe")
- }
-}
-
-func TestFullAndFlushedWrite(t *testing.T) {
- // Check that writes fail when the pipe is full and has already been
- // flushed.
- b := make([]byte, 100)
- var tx Tx
- tx.Init(b)
-
- if wb := tx.Push(80); wb == nil {
- t.Fatalf("Write of 80 bytes failed on 96-byte pipe")
- }
-
- tx.Flush()
-
- if wb := tx.Push(1); wb != nil {
- t.Fatalf("Write succeeded on full pipe")
- }
-}
-
-func TestTxFlushTwice(t *testing.T) {
- // Checks that a second consecutive tx flush is a no-op.
- b := make([]byte, 100)
- var tx Tx
- tx.Init(b)
-
- if wb := tx.Push(50); wb == nil {
- t.Fatalf("Push failed on empty pipe")
- }
- tx.Flush()
-
- // Make copy of original tx queue, flush it, then check that it didn't
- // change.
- orig := tx
- tx.Flush()
-
- if !reflect.DeepEqual(orig, tx) {
- t.Fatalf("Flush mutated tx pipe: got %v, want %v", tx, orig)
- }
-}
-
-func TestRxFlushTwice(t *testing.T) {
- // Checks that a second consecutive rx flush is a no-op.
- b := make([]byte, 100)
- var tx Tx
- tx.Init(b)
-
- if wb := tx.Push(50); wb == nil {
- t.Fatalf("Push failed on empty pipe")
- }
- tx.Flush()
-
- var rx Rx
- rx.Init(b)
- if rb := rx.Pull(); rb == nil {
- t.Fatalf("Pull failed on non-empty pipe")
- }
- rx.Flush()
-
- // Make copy of original rx queue, flush it, then check that it didn't
- // change.
- orig := rx
- rx.Flush()
-
- if !reflect.DeepEqual(orig, rx) {
- t.Fatalf("Flush mutated rx pipe: got %v, want %v", rx, orig)
- }
-}
-
-func TestWrapInMiddleOfTransaction(t *testing.T) {
- // Check that writes are not flushed when we need to wrap the buffer
- // around.
- b := make([]byte, 100)
- var tx Tx
- tx.Init(b)
-
- if wb := tx.Push(50); wb == nil {
- t.Fatalf("Push failed on empty pipe")
- }
- tx.Flush()
-
- var rx Rx
- rx.Init(b)
- if rb := rx.Pull(); rb == nil {
- t.Fatalf("Pull failed on non-empty pipe")
- }
- rx.Flush()
-
- // At this point the ring buffer is empty, but the write is at offset
- // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment).
- if wb := tx.Push(10); wb == nil {
- t.Fatalf("Push failed on empty pipe")
- }
-
- if wb := tx.Push(50); wb == nil {
- t.Fatalf("Push failed on non-full pipe")
- }
-
- // We haven't flushed yet, so pull must return nil.
- if rb := rx.Pull(); rb != nil {
- t.Fatalf("Pull succeeded on non-flushed pipe")
- }
-
- tx.Flush()
-
- // The two buffers must be available now.
- if rb := rx.Pull(); rb == nil {
- t.Fatalf("Pull failed on non-empty pipe")
- }
-
- if rb := rx.Pull(); rb == nil {
- t.Fatalf("Pull failed on non-empty pipe")
- }
-}
-
-func TestWriteAbort(t *testing.T) {
- // Check that a read fails on a pipe that has had data pushed to it but
- // has aborted the push.
- b := make([]byte, 100)
- var tx Tx
- tx.Init(b)
-
- if wb := tx.Push(10); wb == nil {
- t.Fatalf("Write failed on empty pipe")
- }
-
- var rx Rx
- rx.Init(b)
- if rb := rx.Pull(); rb != nil {
- t.Fatalf("Pull succeeded on empty pipe")
- }
-
- tx.Abort()
- if rb := rx.Pull(); rb != nil {
- t.Fatalf("Pull succeeded on empty pipe")
- }
-}
-
-func TestWrappedWriteAbort(t *testing.T) {
- // Check that writes are properly aborted even if the writes wrap
- // around.
- b := make([]byte, 100)
- var tx Tx
- tx.Init(b)
-
- if wb := tx.Push(50); wb == nil {
- t.Fatalf("Push failed on empty pipe")
- }
- tx.Flush()
-
- var rx Rx
- rx.Init(b)
- if rb := rx.Pull(); rb == nil {
- t.Fatalf("Pull failed on non-empty pipe")
- }
- rx.Flush()
-
- // At this point the ring buffer is empty, but the write is at offset
- // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment).
- if wb := tx.Push(10); wb == nil {
- t.Fatalf("Push failed on empty pipe")
- }
-
- if wb := tx.Push(50); wb == nil {
- t.Fatalf("Push failed on non-full pipe")
- }
-
- // We haven't flushed yet, so pull must return nil.
- if rb := rx.Pull(); rb != nil {
- t.Fatalf("Pull succeeded on non-flushed pipe")
- }
-
- tx.Abort()
-
- // The pushes were aborted, so no data should be readable.
- if rb := rx.Pull(); rb != nil {
- t.Fatalf("Pull succeeded on non-flushed pipe")
- }
-
- // Try the same transactions again, but flush this time.
- if wb := tx.Push(10); wb == nil {
- t.Fatalf("Push failed on empty pipe")
- }
-
- if wb := tx.Push(50); wb == nil {
- t.Fatalf("Push failed on non-full pipe")
- }
-
- tx.Flush()
-
- // The two buffers must be available now.
- if rb := rx.Pull(); rb == nil {
- t.Fatalf("Pull failed on non-empty pipe")
- }
-
- if rb := rx.Pull(); rb == nil {
- t.Fatalf("Pull failed on non-empty pipe")
- }
-}
-
-func TestEmptyReadOnNonFlushedWrite(t *testing.T) {
- // Check that a read fails on a pipe that has had data pushed to it
- // but not yet flushed.
- b := make([]byte, 100)
- var tx Tx
- tx.Init(b)
-
- if wb := tx.Push(10); wb == nil {
- t.Fatalf("Write failed on empty pipe")
- }
-
- var rx Rx
- rx.Init(b)
- if rb := rx.Pull(); rb != nil {
- t.Fatalf("Pull succeeded on empty pipe")
- }
-
- tx.Flush()
- if rb := rx.Pull(); rb == nil {
- t.Fatalf("Pull on failed on non-empty pipe")
- }
-}
-
-func TestPullAfterPullingEntirePipe(t *testing.T) {
- // Check that Pull fails when the pipe is full, but all of it has
- // already been pulled but not yet flushed.
- b := make([]byte, 100)
- var tx Tx
- tx.Init(b)
-
- if wb := tx.Push(50); wb == nil {
- t.Fatalf("Push failed on empty pipe")
- }
- tx.Flush()
-
- var rx Rx
- rx.Init(b)
- if rb := rx.Pull(); rb == nil {
- t.Fatalf("Pull failed on non-empty pipe")
- }
- rx.Flush()
-
- // At this point the ring buffer is empty, but the write is at offset
- // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment). Write 3
- // buffers that will fill the pipe.
- if wb := tx.Push(10); wb == nil {
- t.Fatalf("Push failed on empty pipe")
- }
-
- if wb := tx.Push(20); wb == nil {
- t.Fatalf("Push failed on non-full pipe")
- }
-
- if wb := tx.Push(24); wb == nil {
- t.Fatalf("Push failed on non-full pipe")
- }
-
- tx.Flush()
-
- // The three buffers must be available now.
- if rb := rx.Pull(); rb == nil {
- t.Fatalf("Pull failed on non-empty pipe")
- }
-
- if rb := rx.Pull(); rb == nil {
- t.Fatalf("Pull failed on non-empty pipe")
- }
-
- if rb := rx.Pull(); rb == nil {
- t.Fatalf("Pull failed on non-empty pipe")
- }
-
- // Fourth pull must fail.
- if rb := rx.Pull(); rb != nil {
- t.Fatalf("Pull succeeded on empty pipe")
- }
-}
-
-func TestNoRoomToWrapOnPush(t *testing.T) {
- // Check that Push fails when it tries to allocate room to add a wrap
- // message.
- b := make([]byte, 100)
- var tx Tx
- tx.Init(b)
-
- if wb := tx.Push(50); wb == nil {
- t.Fatalf("Push failed on empty pipe")
- }
- tx.Flush()
-
- var rx Rx
- rx.Init(b)
- if rb := rx.Pull(); rb == nil {
- t.Fatalf("Pull failed on non-empty pipe")
- }
- rx.Flush()
-
- // At this point the ring buffer is empty, but the write is at offset
- // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment). Write 20,
- // which won't fit (64+20+8+padding = 96, which wouldn't leave room for
- // the padding), so it wraps around.
- if wb := tx.Push(20); wb == nil {
- t.Fatalf("Push failed on empty pipe")
- }
-
- tx.Flush()
-
- // Buffer offset is at 28. Try to write 70, which would require a wrap
- // slot which cannot be created now.
- if wb := tx.Push(70); wb != nil {
- t.Fatalf("Push succeeded on pipe with no room for wrap message")
- }
-}
-
-func TestRxImplicitFlushOfWrapMessage(t *testing.T) {
- // Check if the first read is that of a wrapping message, that it gets
- // immediately flushed.
- b := make([]byte, 100)
- var tx Tx
- tx.Init(b)
-
- if wb := tx.Push(50); wb == nil {
- t.Fatalf("Push failed on empty pipe")
- }
- tx.Flush()
-
- // This will cause a wrapping message to written.
- if wb := tx.Push(60); wb != nil {
- t.Fatalf("Push succeeded when there is no room in pipe")
- }
-
- var rx Rx
- rx.Init(b)
-
- // Read the first message.
- if rb := rx.Pull(); rb == nil {
- t.Fatalf("Pull failed on non-empty pipe")
- }
- rx.Flush()
-
- // This should fail because of the wrapping message is taking up space.
- if wb := tx.Push(60); wb != nil {
- t.Fatalf("Push succeeded when there is no room in pipe")
- }
-
- // Try to read the next one. This should consume the wrapping message.
- rx.Pull()
-
- // This must now succeed.
- if wb := tx.Push(60); wb == nil {
- t.Fatalf("Push failed on empty pipe")
- }
-}
-
-func TestConcurrentReaderWriter(t *testing.T) {
- // Push a million buffers of random sizes and random contents. Check
- // that buffers read match what was written.
- tr := rand.New(rand.NewSource(99))
- rr := rand.New(rand.NewSource(99))
-
- b := make([]byte, 100)
- var tx Tx
- tx.Init(b)
-
- var rx Rx
- rx.Init(b)
-
- const count = 1000000
- var wg sync.WaitGroup
- defer wg.Wait()
- wg.Add(1)
- go func() {
- defer wg.Done()
- runtime.Gosched()
- for i := 0; i < count; i++ {
- n := 1 + tr.Intn(80)
- wb := tx.Push(uint64(n))
- for wb == nil {
- wb = tx.Push(uint64(n))
- }
-
- for j := range wb {
- wb[j] = byte(tr.Intn(256))
- }
-
- tx.Flush()
- }
- }()
-
- for i := 0; i < count; i++ {
- n := 1 + rr.Intn(80)
- rb := rx.Pull()
- for rb == nil {
- rb = rx.Pull()
- }
-
- if n != len(rb) {
- t.Fatalf("Bad %v-th buffer length: got %v, want %v", i, len(rb), n)
- }
-
- for j := range rb {
- if v := byte(rr.Intn(256)); v != rb[j] {
- t.Fatalf("Bad %v-th read buffer at index %v: got %v, want %v", i, j, rb[j], v)
- }
- }
-
- rx.Flush()
- }
-}
diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_unsafe_state_autogen.go b/pkg/tcpip/link/sharedmem/pipe/pipe_unsafe_state_autogen.go
new file mode 100644
index 000000000..d3b40feb4
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/pipe/pipe_unsafe_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package pipe
diff --git a/pkg/tcpip/link/sharedmem/queue/BUILD b/pkg/tcpip/link/sharedmem/queue/BUILD
deleted file mode 100644
index 3ba06af73..000000000
--- a/pkg/tcpip/link/sharedmem/queue/BUILD
+++ /dev/null
@@ -1,27 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "queue",
- srcs = [
- "rx.go",
- "tx.go",
- ],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/log",
- "//pkg/tcpip/link/sharedmem/pipe",
- ],
-)
-
-go_test(
- name = "queue_test",
- srcs = [
- "queue_test.go",
- ],
- library = ":queue",
- deps = [
- "//pkg/tcpip/link/sharedmem/pipe",
- ],
-)
diff --git a/pkg/tcpip/link/sharedmem/queue/queue_state_autogen.go b/pkg/tcpip/link/sharedmem/queue/queue_state_autogen.go
new file mode 100644
index 000000000..563d4fbb4
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/queue/queue_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package queue
diff --git a/pkg/tcpip/link/sharedmem/queue/queue_test.go b/pkg/tcpip/link/sharedmem/queue/queue_test.go
deleted file mode 100644
index 9a0aad5d7..000000000
--- a/pkg/tcpip/link/sharedmem/queue/queue_test.go
+++ /dev/null
@@ -1,517 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package queue
-
-import (
- "encoding/binary"
- "reflect"
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe"
-)
-
-func TestBasicTxQueue(t *testing.T) {
- // Tests that a basic transmit on a queue works, and that completion
- // gets properly reported as well.
- pb1 := make([]byte, 100)
- pb2 := make([]byte, 100)
-
- var rxp pipe.Rx
- rxp.Init(pb1)
-
- var txp pipe.Tx
- txp.Init(pb2)
-
- var q Tx
- q.Init(pb1, pb2)
-
- // Enqueue two buffers.
- b := []TxBuffer{
- {nil, 100, 60},
- {nil, 200, 40},
- }
-
- b[0].Next = &b[1]
-
- const usedID = 1002
- const usedTotalSize = 100
- if !q.Enqueue(usedID, usedTotalSize, 2, &b[0]) {
- t.Fatalf("Enqueue failed on empty queue")
- }
-
- // Check the contents of the pipe.
- d := rxp.Pull()
- if d == nil {
- t.Fatalf("Tx pipe is empty after Enqueue")
- }
-
- want := []byte{
- 234, 3, 0, 0, 0, 0, 0, 0, // id
- 100, 0, 0, 0, // total size
- 0, 0, 0, 0, // reserved
- 100, 0, 0, 0, 0, 0, 0, 0, // offset 1
- 60, 0, 0, 0, // size 1
- 200, 0, 0, 0, 0, 0, 0, 0, // offset 2
- 40, 0, 0, 0, // size 2
- }
-
- if !reflect.DeepEqual(want, d) {
- t.Fatalf("Bad posted packet: got %v, want %v", d, want)
- }
-
- rxp.Flush()
-
- // Check that there are no completions yet.
- if _, ok := q.CompletedPacket(); ok {
- t.Fatalf("Packet reported as completed too soon")
- }
-
- // Post a completion.
- d = txp.Push(8)
- if d == nil {
- t.Fatalf("Unable to push to rx pipe")
- }
- binary.LittleEndian.PutUint64(d, usedID)
- txp.Flush()
-
- // Check that completion is properly reported.
- id, ok := q.CompletedPacket()
- if !ok {
- t.Fatalf("Completion not reported")
- }
-
- if id != usedID {
- t.Fatalf("Bad completion id: got %v, want %v", id, usedID)
- }
-}
-
-func TestBasicRxQueue(t *testing.T) {
- // Tests that a basic receive on a queue works.
- pb1 := make([]byte, 100)
- pb2 := make([]byte, 100)
-
- var rxp pipe.Rx
- rxp.Init(pb1)
-
- var txp pipe.Tx
- txp.Init(pb2)
-
- var q Rx
- q.Init(pb1, pb2, nil)
-
- // Post two buffers.
- b := []RxBuffer{
- {100, 60, 1077, 0},
- {200, 40, 2123, 0},
- }
-
- if !q.PostBuffers(b) {
- t.Fatalf("PostBuffers failed on empty queue")
- }
-
- // Check the contents of the pipe.
- want := [][]byte{
- {
- 100, 0, 0, 0, 0, 0, 0, 0, // Offset1
- 60, 0, 0, 0, // Size1
- 0, 0, 0, 0, // Remaining in group 1
- 0, 0, 0, 0, 0, 0, 0, 0, // User data 1
- 53, 4, 0, 0, 0, 0, 0, 0, // ID 1
- },
- {
- 200, 0, 0, 0, 0, 0, 0, 0, // Offset2
- 40, 0, 0, 0, // Size2
- 0, 0, 0, 0, // Remaining in group 2
- 0, 0, 0, 0, 0, 0, 0, 0, // User data 2
- 75, 8, 0, 0, 0, 0, 0, 0, // ID 2
- },
- }
-
- for i := range b {
- d := rxp.Pull()
- if d == nil {
- t.Fatalf("Tx pipe is empty after PostBuffers")
- }
-
- if !reflect.DeepEqual(want[i], d) {
- t.Fatalf("Bad posted packet: got %v, want %v", d, want[i])
- }
-
- rxp.Flush()
- }
-
- // Check that there are no completions.
- if _, n := q.Dequeue(nil); n != 0 {
- t.Fatalf("Packet reported as received too soon")
- }
-
- // Post a completion.
- d := txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer)
- if d == nil {
- t.Fatalf("Unable to push to rx pipe")
- }
-
- copy(d, []byte{
- 100, 0, 0, 0, // packet size
- 0, 0, 0, 0, // reserved
-
- 100, 0, 0, 0, 0, 0, 0, 0, // offset 1
- 60, 0, 0, 0, // size 1
- 0, 0, 0, 0, 0, 0, 0, 0, // user data 1
- 53, 4, 0, 0, 0, 0, 0, 0, // ID 1
-
- 200, 0, 0, 0, 0, 0, 0, 0, // offset 2
- 40, 0, 0, 0, // size 2
- 0, 0, 0, 0, 0, 0, 0, 0, // user data 2
- 75, 8, 0, 0, 0, 0, 0, 0, // ID 2
- })
-
- txp.Flush()
-
- // Check that completion is properly reported.
- bufs, n := q.Dequeue(nil)
- if n != 100 {
- t.Fatalf("Bad packet size: got %v, want %v", n, 100)
- }
-
- if !reflect.DeepEqual(bufs, b) {
- t.Fatalf("Bad returned buffers: got %v, want %v", bufs, b)
- }
-}
-
-func TestBadTxCompletion(t *testing.T) {
- // Check that tx completions with bad sizes are properly ignored.
- pb1 := make([]byte, 100)
- pb2 := make([]byte, 100)
-
- var rxp pipe.Rx
- rxp.Init(pb1)
-
- var txp pipe.Tx
- txp.Init(pb2)
-
- var q Tx
- q.Init(pb1, pb2)
-
- // Post a completion that is too short, and check that it is ignored.
- if d := txp.Push(7); d == nil {
- t.Fatalf("Unable to push to rx pipe")
- }
- txp.Flush()
-
- if _, ok := q.CompletedPacket(); ok {
- t.Fatalf("Bad completion not ignored")
- }
-
- // Post a completion that is too long, and check that it is ignored.
- if d := txp.Push(10); d == nil {
- t.Fatalf("Unable to push to rx pipe")
- }
- txp.Flush()
-
- if _, ok := q.CompletedPacket(); ok {
- t.Fatalf("Bad completion not ignored")
- }
-}
-
-func TestBadRxCompletion(t *testing.T) {
- // Check that bad rx completions are properly ignored.
- pb1 := make([]byte, 100)
- pb2 := make([]byte, 100)
-
- var rxp pipe.Rx
- rxp.Init(pb1)
-
- var txp pipe.Tx
- txp.Init(pb2)
-
- var q Rx
- q.Init(pb1, pb2, nil)
-
- // Post a completion that is too short, and check that it is ignored.
- if d := txp.Push(7); d == nil {
- t.Fatalf("Unable to push to rx pipe")
- }
- txp.Flush()
-
- if b, _ := q.Dequeue(nil); b != nil {
- t.Fatalf("Bad completion not ignored")
- }
-
- // Post a completion whose buffer sizes add up to less than the total
- // size.
- d := txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer)
- if d == nil {
- t.Fatalf("Unable to push to rx pipe")
- }
-
- copy(d, []byte{
- 100, 0, 0, 0, // packet size
- 0, 0, 0, 0, // reserved
-
- 100, 0, 0, 0, 0, 0, 0, 0, // offset 1
- 10, 0, 0, 0, // size 1
- 0, 0, 0, 0, 0, 0, 0, 0, // user data 1
- 53, 4, 0, 0, 0, 0, 0, 0, // ID 1
-
- 200, 0, 0, 0, 0, 0, 0, 0, // offset 2
- 10, 0, 0, 0, // size 2
- 0, 0, 0, 0, 0, 0, 0, 0, // user data 2
- 75, 8, 0, 0, 0, 0, 0, 0, // ID 2
- })
-
- txp.Flush()
- if b, _ := q.Dequeue(nil); b != nil {
- t.Fatalf("Bad completion not ignored")
- }
-
- // Post a completion whose buffer sizes will cause a 32-bit overflow,
- // but adds up to the right number.
- d = txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer)
- if d == nil {
- t.Fatalf("Unable to push to rx pipe")
- }
-
- copy(d, []byte{
- 100, 0, 0, 0, // packet size
- 0, 0, 0, 0, // reserved
-
- 100, 0, 0, 0, 0, 0, 0, 0, // offset 1
- 255, 255, 255, 255, // size 1
- 0, 0, 0, 0, 0, 0, 0, 0, // user data 1
- 53, 4, 0, 0, 0, 0, 0, 0, // ID 1
-
- 200, 0, 0, 0, 0, 0, 0, 0, // offset 2
- 101, 0, 0, 0, // size 2
- 0, 0, 0, 0, 0, 0, 0, 0, // user data 2
- 75, 8, 0, 0, 0, 0, 0, 0, // ID 2
- })
-
- txp.Flush()
- if b, _ := q.Dequeue(nil); b != nil {
- t.Fatalf("Bad completion not ignored")
- }
-}
-
-func TestFillTxPipe(t *testing.T) {
- // Check that transmitting a new buffer when the buffer pipe is full
- // fails gracefully.
- pb1 := make([]byte, 104)
- pb2 := make([]byte, 104)
-
- var rxp pipe.Rx
- rxp.Init(pb1)
-
- var txp pipe.Tx
- txp.Init(pb2)
-
- var q Tx
- q.Init(pb1, pb2)
-
- // Transmit twice, which should fill the tx pipe.
- b := []TxBuffer{
- {nil, 100, 60},
- {nil, 200, 40},
- }
-
- b[0].Next = &b[1]
-
- const usedID = 1002
- const usedTotalSize = 100
- for i := uint64(0); i < 2; i++ {
- if !q.Enqueue(usedID+i, usedTotalSize, 2, &b[0]) {
- t.Fatalf("Failed to transmit buffer")
- }
- }
-
- // Transmit another packet now that the tx pipe is full.
- if q.Enqueue(usedID+2, usedTotalSize, 2, &b[0]) {
- t.Fatalf("Enqueue succeeded when tx pipe is full")
- }
-}
-
-func TestFillRxPipe(t *testing.T) {
- // Check that posting a new buffer when the buffer pipe is full fails
- // gracefully.
- pb1 := make([]byte, 100)
- pb2 := make([]byte, 100)
-
- var rxp pipe.Rx
- rxp.Init(pb1)
-
- var txp pipe.Tx
- txp.Init(pb2)
-
- var q Rx
- q.Init(pb1, pb2, nil)
-
- // Post a buffer twice, it should fill the tx pipe.
- b := []RxBuffer{
- {100, 60, 1077, 0},
- }
-
- for i := 0; i < 2; i++ {
- if !q.PostBuffers(b) {
- t.Fatalf("PostBuffers failed on non-full queue")
- }
- }
-
- // Post another buffer now that the tx pipe is full.
- if q.PostBuffers(b) {
- t.Fatalf("PostBuffers succeeded on full queue")
- }
-}
-
-func TestLotsOfTransmissions(t *testing.T) {
- // Make sure pipes are being properly flushed when transmitting packets.
- pb1 := make([]byte, 100)
- pb2 := make([]byte, 100)
-
- var rxp pipe.Rx
- rxp.Init(pb1)
-
- var txp pipe.Tx
- txp.Init(pb2)
-
- var q Tx
- q.Init(pb1, pb2)
-
- // Prepare packet with two buffers.
- b := []TxBuffer{
- {nil, 100, 60},
- {nil, 200, 40},
- }
-
- b[0].Next = &b[1]
-
- const usedID = 1002
- const usedTotalSize = 100
-
- // Post 100000 packets and completions.
- for i := 100000; i > 0; i-- {
- if !q.Enqueue(usedID, usedTotalSize, 2, &b[0]) {
- t.Fatalf("Enqueue failed on non-full queue")
- }
-
- if d := rxp.Pull(); d == nil {
- t.Fatalf("Tx pipe is empty after Enqueue")
- }
- rxp.Flush()
-
- d := txp.Push(8)
- if d == nil {
- t.Fatalf("Unable to write to rx pipe")
- }
- binary.LittleEndian.PutUint64(d, usedID)
- txp.Flush()
- if _, ok := q.CompletedPacket(); !ok {
- t.Fatalf("Completion not returned")
- }
- }
-}
-
-func TestLotsOfReceptions(t *testing.T) {
- // Make sure pipes are being properly flushed when receiving packets.
- pb1 := make([]byte, 100)
- pb2 := make([]byte, 100)
-
- var rxp pipe.Rx
- rxp.Init(pb1)
-
- var txp pipe.Tx
- txp.Init(pb2)
-
- var q Rx
- q.Init(pb1, pb2, nil)
-
- // Prepare for posting two buffers.
- b := []RxBuffer{
- {100, 60, 1077, 0},
- {200, 40, 2123, 0},
- }
-
- // Post 100000 buffers and completions.
- for i := 100000; i > 0; i-- {
- if !q.PostBuffers(b) {
- t.Fatalf("PostBuffers failed on non-full queue")
- }
-
- if d := rxp.Pull(); d == nil {
- t.Fatalf("Tx pipe is empty after PostBuffers")
- }
- rxp.Flush()
-
- if d := rxp.Pull(); d == nil {
- t.Fatalf("Tx pipe is empty after PostBuffers")
- }
- rxp.Flush()
-
- d := txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer)
- if d == nil {
- t.Fatalf("Unable to push to rx pipe")
- }
-
- copy(d, []byte{
- 100, 0, 0, 0, // packet size
- 0, 0, 0, 0, // reserved
-
- 100, 0, 0, 0, 0, 0, 0, 0, // offset 1
- 60, 0, 0, 0, // size 1
- 0, 0, 0, 0, 0, 0, 0, 0, // user data 1
- 53, 4, 0, 0, 0, 0, 0, 0, // ID 1
-
- 200, 0, 0, 0, 0, 0, 0, 0, // offset 2
- 40, 0, 0, 0, // size 2
- 0, 0, 0, 0, 0, 0, 0, 0, // user data 2
- 75, 8, 0, 0, 0, 0, 0, 0, // ID 2
- })
-
- txp.Flush()
-
- if _, n := q.Dequeue(nil); n == 0 {
- t.Fatalf("Dequeue failed when there is a completion")
- }
- }
-}
-
-func TestRxEnableNotification(t *testing.T) {
- // Check that enabling nofifications results in properly updated state.
- pb1 := make([]byte, 100)
- pb2 := make([]byte, 100)
-
- var state uint32
- var q Rx
- q.Init(pb1, pb2, &state)
-
- q.EnableNotification()
- if state != eventFDEnabled {
- t.Fatalf("Bad value in shared state: got %v, want %v", state, eventFDEnabled)
- }
-}
-
-func TestRxDisableNotification(t *testing.T) {
- // Check that disabling nofifications results in properly updated state.
- pb1 := make([]byte, 100)
- pb2 := make([]byte, 100)
-
- var state uint32
- var q Rx
- q.Init(pb1, pb2, &state)
-
- q.DisableNotification()
- if state != eventFDDisabled {
- t.Fatalf("Bad value in shared state: got %v, want %v", state, eventFDDisabled)
- }
-}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_state_autogen.go b/pkg/tcpip/link/sharedmem/sharedmem_state_autogen.go
new file mode 100644
index 000000000..86551c9f5
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/sharedmem_state_autogen.go
@@ -0,0 +1,6 @@
+// automatically generated by stateify.
+
+//go:build linux && linux
+// +build linux,linux
+
+package sharedmem
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
deleted file mode 100644
index d6d953085..000000000
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ /dev/null
@@ -1,815 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-//go:build linux
-// +build linux
-
-package sharedmem
-
-import (
- "bytes"
- "io/ioutil"
- "math/rand"
- "os"
- "strings"
- "testing"
- "time"
-
- "golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe"
- "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
-)
-
-const (
- localLinkAddr = "\xde\xad\xbe\xef\x56\x78"
- remoteLinkAddr = "\xde\xad\xbe\xef\x12\x34"
-
- queueDataSize = 1024 * 1024
- queuePipeSize = 4096
-)
-
-type queueBuffers struct {
- data []byte
- rx pipe.Tx
- tx pipe.Rx
-}
-
-func initQueue(t *testing.T, q *queueBuffers, c *QueueConfig) {
- // Prepare tx pipe.
- b, err := getBuffer(c.TxPipeFD)
- if err != nil {
- t.Fatalf("getBuffer failed: %v", err)
- }
- q.tx.Init(b)
-
- // Prepare rx pipe.
- b, err = getBuffer(c.RxPipeFD)
- if err != nil {
- t.Fatalf("getBuffer failed: %v", err)
- }
- q.rx.Init(b)
-
- // Get data slice.
- q.data, err = getBuffer(c.DataFD)
- if err != nil {
- t.Fatalf("getBuffer failed: %v", err)
- }
-}
-
-func (q *queueBuffers) cleanup() {
- unix.Munmap(q.tx.Bytes())
- unix.Munmap(q.rx.Bytes())
- unix.Munmap(q.data)
-}
-
-type packetInfo struct {
- addr tcpip.LinkAddress
- proto tcpip.NetworkProtocolNumber
- data buffer.View
- linkHeader buffer.View
-}
-
-type testContext struct {
- t *testing.T
- ep *endpoint
- txCfg QueueConfig
- rxCfg QueueConfig
- txq queueBuffers
- rxq queueBuffers
-
- packetCh chan struct{}
- mu sync.Mutex
- packets []packetInfo
-}
-
-func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress) *testContext {
- var err error
- c := &testContext{
- t: t,
- packetCh: make(chan struct{}, 1000000),
- }
- c.txCfg = createQueueFDs(t, queueSizes{
- dataSize: queueDataSize,
- txPipeSize: queuePipeSize,
- rxPipeSize: queuePipeSize,
- sharedDataSize: 4096,
- })
-
- c.rxCfg = createQueueFDs(t, queueSizes{
- dataSize: queueDataSize,
- txPipeSize: queuePipeSize,
- rxPipeSize: queuePipeSize,
- sharedDataSize: 4096,
- })
-
- initQueue(t, &c.txq, &c.txCfg)
- initQueue(t, &c.rxq, &c.rxCfg)
-
- ep, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg)
- if err != nil {
- t.Fatalf("New failed: %v", err)
- }
-
- c.ep = ep.(*endpoint)
- c.ep.Attach(c)
-
- return c
-}
-
-func (c *testContext) DeliverNetworkPacket(remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
- c.mu.Lock()
- c.packets = append(c.packets, packetInfo{
- addr: remoteLinkAddr,
- proto: proto,
- data: pkt.Data().AsRange().ToOwnedView(),
- })
- c.mu.Unlock()
-
- c.packetCh <- struct{}{}
-}
-
-func (c *testContext) DeliverOutboundPacket(remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
- panic("unimplemented")
-}
-
-func (c *testContext) cleanup() {
- c.ep.Close()
- closeFDs(&c.txCfg)
- closeFDs(&c.rxCfg)
- c.txq.cleanup()
- c.rxq.cleanup()
-}
-
-func (c *testContext) waitForPackets(n int, to <-chan time.Time, errorStr string) {
- for i := 0; i < n; i++ {
- select {
- case <-c.packetCh:
- case <-to:
- c.t.Fatalf(errorStr)
- }
- }
-}
-
-func (c *testContext) pushRxCompletion(size uint32, bs []queue.RxBuffer) {
- b := c.rxq.rx.Push(queue.RxCompletionSize(len(bs)))
- queue.EncodeRxCompletion(b, size, 0)
- for i := range bs {
- queue.EncodeRxCompletionBuffer(b, i, queue.RxBuffer{
- Offset: bs[i].Offset,
- Size: bs[i].Size,
- ID: bs[i].ID,
- })
- }
-}
-
-func randomFill(b []byte) {
- for i := range b {
- b[i] = byte(rand.Intn(256))
- }
-}
-
-func shuffle(b []int) {
- for i := len(b) - 1; i >= 0; i-- {
- j := rand.Intn(i + 1)
- b[i], b[j] = b[j], b[i]
- }
-}
-
-func createFile(t *testing.T, size int64, initQueue bool) int {
- tmpDir, ok := os.LookupEnv("TEST_TMPDIR")
- if !ok {
- tmpDir = os.Getenv("TMPDIR")
- }
- f, err := ioutil.TempFile(tmpDir, "sharedmem_test")
- if err != nil {
- t.Fatalf("TempFile failed: %v", err)
- }
- defer f.Close()
- unix.Unlink(f.Name())
-
- if initQueue {
- // Write the "slot-free" flag in the initial queue.
- _, err := f.WriteAt([]byte{0, 0, 0, 0, 0, 0, 0, 0x80}, 0)
- if err != nil {
- t.Fatalf("WriteAt failed: %v", err)
- }
- }
-
- fd, err := unix.Dup(int(f.Fd()))
- if err != nil {
- t.Fatalf("Dup failed: %v", err)
- }
-
- if err := unix.Ftruncate(fd, size); err != nil {
- unix.Close(fd)
- t.Fatalf("Ftruncate failed: %v", err)
- }
-
- return fd
-}
-
-func closeFDs(c *QueueConfig) {
- unix.Close(c.DataFD)
- unix.Close(c.EventFD)
- unix.Close(c.TxPipeFD)
- unix.Close(c.RxPipeFD)
- unix.Close(c.SharedDataFD)
-}
-
-type queueSizes struct {
- dataSize int64
- txPipeSize int64
- rxPipeSize int64
- sharedDataSize int64
-}
-
-func createQueueFDs(t *testing.T, s queueSizes) QueueConfig {
- fd, _, err := unix.RawSyscall(unix.SYS_EVENTFD2, 0, 0, 0)
- if err != 0 {
- t.Fatalf("eventfd failed: %v", error(err))
- }
-
- return QueueConfig{
- EventFD: int(fd),
- DataFD: createFile(t, s.dataSize, false),
- TxPipeFD: createFile(t, s.txPipeSize, true),
- RxPipeFD: createFile(t, s.rxPipeSize, true),
- SharedDataFD: createFile(t, s.sharedDataSize, false),
- }
-}
-
-// TestSimpleSend sends 1000 packets with random header and payload sizes,
-// then checks that the right payload is received on the shared memory queues.
-func TestSimpleSend(t *testing.T) {
- c := newTestContext(t, 20000, 1500, localLinkAddr)
- defer c.cleanup()
-
- // Prepare route.
- var r stack.RouteInfo
- r.RemoteLinkAddress = remoteLinkAddr
-
- for iters := 1000; iters > 0; iters-- {
- func() {
- hdrLen, dataLen := rand.Intn(10000), rand.Intn(10000)
-
- // Prepare and send packet.
- hdrBuf := buffer.NewView(hdrLen)
- randomFill(hdrBuf)
-
- data := buffer.NewView(dataLen)
- randomFill(data)
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: hdrLen + int(c.ep.MaxHeaderLength()),
- Data: data.ToVectorisedView(),
- })
- copy(pkt.NetworkHeader().Push(hdrLen), hdrBuf)
-
- proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000))
- if err := c.ep.WritePacket(r, proto, pkt); err != nil {
- t.Fatalf("WritePacket failed: %v", err)
- }
-
- // Receive packet.
- desc := c.txq.tx.Pull()
- pi := queue.DecodeTxPacketHeader(desc)
- if pi.Reserved != 0 {
- t.Fatalf("Reserved value is non-zero: 0x%x", pi.Reserved)
- }
- contents := make([]byte, 0, pi.Size)
- for i := 0; i < pi.BufferCount; i++ {
- bi := queue.DecodeTxBufferHeader(desc, i)
- contents = append(contents, c.txq.data[bi.Offset:][:bi.Size]...)
- }
- c.txq.tx.Flush()
-
- defer func() {
- // Tell the endpoint about the completion of the write.
- b := c.txq.rx.Push(8)
- queue.EncodeTxCompletion(b, pi.ID)
- c.txq.rx.Flush()
- }()
-
- // Check the ethernet header.
- ethTemplate := make(header.Ethernet, header.EthernetMinimumSize)
- ethTemplate.Encode(&header.EthernetFields{
- SrcAddr: localLinkAddr,
- DstAddr: remoteLinkAddr,
- Type: proto,
- })
- if got := contents[:header.EthernetMinimumSize]; !bytes.Equal(got, []byte(ethTemplate)) {
- t.Fatalf("Bad ethernet header in packet: got %x, want %x", got, ethTemplate)
- }
-
- // Compare contents skipping the ethernet header added by the
- // endpoint.
- merged := append(hdrBuf, data...)
- if uint32(len(contents)) < pi.Size {
- t.Fatalf("Sum of buffers is less than packet size: %v < %v", len(contents), pi.Size)
- }
- contents = contents[:pi.Size][header.EthernetMinimumSize:]
-
- if !bytes.Equal(contents, merged) {
- t.Fatalf("Buffers are different: got %x (%v bytes), want %x (%v bytes)", contents, len(contents), merged, len(merged))
- }
- }()
- }
-}
-
-// TestPreserveSrcAddressInSend calls WritePacket once with LocalLinkAddress
-// set in Route (using much of the same code as TestSimpleSend), then checks
-// that the encoded ethernet header received includes the correct SrcAddr.
-func TestPreserveSrcAddressInSend(t *testing.T) {
- c := newTestContext(t, 20000, 1500, localLinkAddr)
- defer c.cleanup()
-
- newLocalLinkAddress := tcpip.LinkAddress(strings.Repeat("0xFE", 6))
- // Set both remote and local link address in route.
- var r stack.RouteInfo
- r.LocalLinkAddress = newLocalLinkAddress
- r.RemoteLinkAddress = remoteLinkAddr
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- // WritePacket panics given a prependable with anything less than
- // the minimum size of the ethernet header.
- ReserveHeaderBytes: header.EthernetMinimumSize,
- })
-
- proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000))
- if err := c.ep.WritePacket(r, proto, pkt); err != nil {
- t.Fatalf("WritePacket failed: %v", err)
- }
-
- // Receive packet.
- desc := c.txq.tx.Pull()
- pi := queue.DecodeTxPacketHeader(desc)
- if pi.Reserved != 0 {
- t.Fatalf("Reserved value is non-zero: 0x%x", pi.Reserved)
- }
- contents := make([]byte, 0, pi.Size)
- for i := 0; i < pi.BufferCount; i++ {
- bi := queue.DecodeTxBufferHeader(desc, i)
- contents = append(contents, c.txq.data[bi.Offset:][:bi.Size]...)
- }
- c.txq.tx.Flush()
-
- defer func() {
- // Tell the endpoint about the completion of the write.
- b := c.txq.rx.Push(8)
- queue.EncodeTxCompletion(b, pi.ID)
- c.txq.rx.Flush()
- }()
-
- // Check that the ethernet header contains the expected SrcAddr.
- ethTemplate := make(header.Ethernet, header.EthernetMinimumSize)
- ethTemplate.Encode(&header.EthernetFields{
- SrcAddr: newLocalLinkAddress,
- DstAddr: remoteLinkAddr,
- Type: proto,
- })
- if got := contents[:header.EthernetMinimumSize]; !bytes.Equal(got, []byte(ethTemplate)) {
- t.Fatalf("Bad ethernet header in packet: got %x, want %x", got, ethTemplate)
- }
-}
-
-// TestFillTxQueue sends packets until the queue is full.
-func TestFillTxQueue(t *testing.T) {
- c := newTestContext(t, 20000, 1500, localLinkAddr)
- defer c.cleanup()
-
- // Prepare to send a packet.
- var r stack.RouteInfo
- r.RemoteLinkAddress = remoteLinkAddr
-
- buf := buffer.NewView(100)
-
- // Each packet is uses no more than 40 bytes, so write that many packets
- // until the tx queue if full.
- ids := make(map[uint64]struct{})
- for i := queuePipeSize / 40; i > 0; i-- {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
- Data: buf.ToVectorisedView(),
- })
-
- if err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt); err != nil {
- t.Fatalf("WritePacket failed unexpectedly: %v", err)
- }
-
- // Check that they have different IDs.
- desc := c.txq.tx.Pull()
- pi := queue.DecodeTxPacketHeader(desc)
- if _, ok := ids[pi.ID]; ok {
- t.Fatalf("ID (%v) reused", pi.ID)
- }
- ids[pi.ID] = struct{}{}
- }
-
- // Next attempt to write must fail.
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
- Data: buf.ToVectorisedView(),
- })
- err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt)
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{})
- }
-}
-
-// TestFillTxQueueAfterBadCompletion sends a bad completion, then sends packets
-// until the queue is full.
-func TestFillTxQueueAfterBadCompletion(t *testing.T) {
- c := newTestContext(t, 20000, 1500, localLinkAddr)
- defer c.cleanup()
-
- // Send a bad completion.
- queue.EncodeTxCompletion(c.txq.rx.Push(8), 1)
- c.txq.rx.Flush()
-
- // Prepare to send a packet.
- var r stack.RouteInfo
- r.RemoteLinkAddress = remoteLinkAddr
-
- buf := buffer.NewView(100)
-
- // Send two packets so that the id slice has at least two slots.
- for i := 2; i > 0; i-- {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
- Data: buf.ToVectorisedView(),
- })
- if err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt); err != nil {
- t.Fatalf("WritePacket failed unexpectedly: %v", err)
- }
- }
-
- // Complete the two writes twice.
- for i := 2; i > 0; i-- {
- pi := queue.DecodeTxPacketHeader(c.txq.tx.Pull())
-
- queue.EncodeTxCompletion(c.txq.rx.Push(8), pi.ID)
- queue.EncodeTxCompletion(c.txq.rx.Push(8), pi.ID)
- c.txq.rx.Flush()
- }
- c.txq.tx.Flush()
-
- // Each packet is uses no more than 40 bytes, so write that many packets
- // until the tx queue if full.
- ids := make(map[uint64]struct{})
- for i := queuePipeSize / 40; i > 0; i-- {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
- Data: buf.ToVectorisedView(),
- })
- if err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt); err != nil {
- t.Fatalf("WritePacket failed unexpectedly: %v", err)
- }
-
- // Check that they have different IDs.
- desc := c.txq.tx.Pull()
- pi := queue.DecodeTxPacketHeader(desc)
- if _, ok := ids[pi.ID]; ok {
- t.Fatalf("ID (%v) reused", pi.ID)
- }
- ids[pi.ID] = struct{}{}
- }
-
- // Next attempt to write must fail.
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
- Data: buf.ToVectorisedView(),
- })
- err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt)
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{})
- }
-}
-
-// TestFillTxMemory sends packets until the we run out of shared memory.
-func TestFillTxMemory(t *testing.T) {
- const bufferSize = 1500
- c := newTestContext(t, 20000, bufferSize, localLinkAddr)
- defer c.cleanup()
-
- // Prepare to send a packet.
- var r stack.RouteInfo
- r.RemoteLinkAddress = remoteLinkAddr
-
- buf := buffer.NewView(100)
-
- // Each packet is uses up one buffer, so write as many as possible until
- // we fill the memory.
- ids := make(map[uint64]struct{})
- for i := queueDataSize / bufferSize; i > 0; i-- {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
- Data: buf.ToVectorisedView(),
- })
- if err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt); err != nil {
- t.Fatalf("WritePacket failed unexpectedly: %v", err)
- }
-
- // Check that they have different IDs.
- desc := c.txq.tx.Pull()
- pi := queue.DecodeTxPacketHeader(desc)
- if _, ok := ids[pi.ID]; ok {
- t.Fatalf("ID (%v) reused", pi.ID)
- }
- ids[pi.ID] = struct{}{}
- c.txq.tx.Flush()
- }
-
- // Next attempt to write must fail.
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
- Data: buf.ToVectorisedView(),
- })
- err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt)
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{})
- }
-}
-
-// TestFillTxMemoryWithMultiBuffer sends packets until the we run out of
-// shared memory for a 2-buffer packet, but still with room for a 1-buffer
-// packet.
-func TestFillTxMemoryWithMultiBuffer(t *testing.T) {
- const bufferSize = 1500
- c := newTestContext(t, 20000, bufferSize, localLinkAddr)
- defer c.cleanup()
-
- // Prepare to send a packet.
- var r stack.RouteInfo
- r.RemoteLinkAddress = remoteLinkAddr
-
- buf := buffer.NewView(100)
-
- // Each packet is uses up one buffer, so write as many as possible
- // until there is only one buffer left.
- for i := queueDataSize/bufferSize - 1; i > 0; i-- {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
- Data: buf.ToVectorisedView(),
- })
- if err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt); err != nil {
- t.Fatalf("WritePacket failed unexpectedly: %v", err)
- }
-
- // Pull the posted buffer.
- c.txq.tx.Pull()
- c.txq.tx.Flush()
- }
-
- // Attempt to write a two-buffer packet. It must fail.
- {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
- Data: buffer.NewView(bufferSize).ToVectorisedView(),
- })
- err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt)
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{})
- }
- }
-
- // Attempt to write the one-buffer packet again. It must succeed.
- {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
- Data: buf.ToVectorisedView(),
- })
- if err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt); err != nil {
- t.Fatalf("WritePacket failed unexpectedly: %v", err)
- }
- }
-}
-
-func pollPull(t *testing.T, p *pipe.Rx, to <-chan time.Time, errStr string) []byte {
- t.Helper()
-
- for {
- b := p.Pull()
- if b != nil {
- return b
- }
-
- select {
- case <-time.After(10 * time.Millisecond):
- case <-to:
- t.Fatal(errStr)
- }
- }
-}
-
-// TestSimpleReceive completes 1000 different receives with random payload and
-// random number of buffers. It checks that the contents match the expected
-// values.
-func TestSimpleReceive(t *testing.T) {
- const bufferSize = 1500
- c := newTestContext(t, 20000, bufferSize, localLinkAddr)
- defer c.cleanup()
-
- // Check that buffers have been posted.
- limit := c.ep.rx.q.PostedBuffersLimit()
- for i := uint64(0); i < limit; i++ {
- timeout := time.After(2 * time.Second)
- bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for all buffers to be posted"))
-
- if want := i * bufferSize; want != bi.Offset {
- t.Fatalf("Bad posted offset: got %v, want %v", bi.Offset, want)
- }
-
- if want := i; want != bi.ID {
- t.Fatalf("Bad posted ID: got %v, want %v", bi.ID, want)
- }
-
- if bufferSize != bi.Size {
- t.Fatalf("Bad posted bufferSize: got %v, want %v", bi.Size, bufferSize)
- }
- }
- c.rxq.tx.Flush()
-
- // Create a slice with the indices 0..limit-1.
- idx := make([]int, limit)
- for i := range idx {
- idx[i] = i
- }
-
- // Complete random packets 1000 times.
- for iters := 1000; iters > 0; iters-- {
- timeout := time.After(2 * time.Second)
- // Prepare a random packet.
- shuffle(idx)
- n := 1 + rand.Intn(10)
- bufs := make([]queue.RxBuffer, n)
- contents := make([]byte, bufferSize*n-rand.Intn(500))
- randomFill(contents)
- for i := range bufs {
- j := idx[i]
- bufs[i].Size = bufferSize
- bufs[i].Offset = uint64(bufferSize * j)
- bufs[i].ID = uint64(j)
-
- copy(c.rxq.data[bufs[i].Offset:][:bufferSize], contents[i*bufferSize:])
- }
-
- // Push completion.
- c.pushRxCompletion(uint32(len(contents)), bufs)
- c.rxq.rx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
-
- // Wait for packet to be received, then check it.
- c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet")
- c.mu.Lock()
- rcvd := []byte(c.packets[0].data)
- c.packets = c.packets[:0]
- c.mu.Unlock()
-
- if contents := contents[header.EthernetMinimumSize:]; !bytes.Equal(contents, rcvd) {
- t.Fatalf("Unexpected buffer contents: got %x, want %x", rcvd, contents)
- }
-
- // Check that buffers have been reposted.
- for i := range bufs {
- bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffers to be reposted"))
- if bi != bufs[i] {
- t.Fatalf("Unexpected buffer reposted: got %x, want %x", bi, bufs[i])
- }
- }
- c.rxq.tx.Flush()
- }
-}
-
-// TestRxBuffersReposted tests that rx buffers get reposted after they have been
-// completed.
-func TestRxBuffersReposted(t *testing.T) {
- const bufferSize = 1500
- c := newTestContext(t, 20000, bufferSize, localLinkAddr)
- defer c.cleanup()
-
- // Receive all posted buffers.
- limit := c.ep.rx.q.PostedBuffersLimit()
- buffers := make([]queue.RxBuffer, 0, limit)
- for i := limit; i > 0; i-- {
- timeout := time.After(2 * time.Second)
- buffers = append(buffers, queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for all buffers")))
- }
- c.rxq.tx.Flush()
-
- // Check that all buffers are reposted when individually completed.
- for i := range buffers {
- timeout := time.After(2 * time.Second)
- // Complete the buffer.
- c.pushRxCompletion(buffers[i].Size, buffers[i:][:1])
- c.rxq.rx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
-
- // Wait for it to be reposted.
- bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffer to be reposted"))
- if bi != buffers[i] {
- t.Fatalf("Different buffer posted: got %v, want %v", bi, buffers[i])
- }
- }
- c.rxq.tx.Flush()
-
- // Check that all buffers are reposted when completed in pairs.
- for i := 0; i < len(buffers)/2; i++ {
- timeout := time.After(2 * time.Second)
- // Complete with two buffers.
- c.pushRxCompletion(2*bufferSize, buffers[2*i:][:2])
- c.rxq.rx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
-
- // Wait for them to be reposted.
- for j := 0; j < 2; j++ {
- bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffer to be reposted"))
- if bi != buffers[2*i+j] {
- t.Fatalf("Different buffer posted: got %v, want %v", bi, buffers[2*i+j])
- }
- }
- }
- c.rxq.tx.Flush()
-}
-
-// TestReceivePostingIsFull checks that the endpoint will properly handle the
-// case when a received buffer cannot be immediately reposted because it hasn't
-// been pulled from the tx pipe yet.
-func TestReceivePostingIsFull(t *testing.T) {
- const bufferSize = 1500
- c := newTestContext(t, 20000, bufferSize, localLinkAddr)
- defer c.cleanup()
-
- // Complete first posted buffer before flushing it from the tx pipe.
- first := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for first buffer to be posted"))
- c.pushRxCompletion(first.Size, []queue.RxBuffer{first})
- c.rxq.rx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
-
- // Check that packet is received.
- c.waitForPackets(1, time.After(time.Second), "Timeout waiting for completed packet")
-
- // Complete another buffer.
- second := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for second buffer to be posted"))
- c.pushRxCompletion(second.Size, []queue.RxBuffer{second})
- c.rxq.rx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
-
- // Check that no packet is received yet, as the worker is blocked trying
- // to repost.
- select {
- case <-time.After(500 * time.Millisecond):
- case <-c.packetCh:
- t.Fatalf("Unexpected packet received")
- }
-
- // Flush tx queue, which will allow the first buffer to be reposted,
- // and the second completion to be pulled.
- c.rxq.tx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
-
- // Check that second packet completes.
- c.waitForPackets(1, time.After(time.Second), "Timeout waiting for second completed packet")
-}
-
-// TestCloseWhileWaitingToPost closes the endpoint while it is waiting to
-// repost a buffer. Make sure it backs out.
-func TestCloseWhileWaitingToPost(t *testing.T) {
- const bufferSize = 1500
- c := newTestContext(t, 20000, bufferSize, localLinkAddr)
- cleaned := false
- defer func() {
- if !cleaned {
- c.cleanup()
- }
- }()
-
- // Complete first posted buffer before flushing it from the tx pipe.
- bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for initial buffer to be posted"))
- c.pushRxCompletion(bi.Size, []queue.RxBuffer{bi})
- c.rxq.rx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
-
- // Wait for packet to be indicated.
- c.waitForPackets(1, time.After(time.Second), "Timeout waiting for completed packet")
-
- // Cleanup and wait for worker to complete.
- c.cleanup()
- cleaned = true
- c.ep.Wait()
-}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_unsafe_state_autogen.go b/pkg/tcpip/link/sharedmem/sharedmem_unsafe_state_autogen.go
new file mode 100644
index 000000000..ac3a66520
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/sharedmem_unsafe_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package sharedmem
diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD
deleted file mode 100644
index 4aac12a8c..000000000
--- a/pkg/tcpip/link/sniffer/BUILD
+++ /dev/null
@@ -1,21 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "sniffer",
- srcs = [
- "pcap.go",
- "sniffer.go",
- ],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/log",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/header/parse",
- "//pkg/tcpip/link/nested",
- "//pkg/tcpip/stack",
- ],
-)
diff --git a/pkg/tcpip/link/sniffer/sniffer_state_autogen.go b/pkg/tcpip/link/sniffer/sniffer_state_autogen.go
new file mode 100644
index 000000000..8d79defea
--- /dev/null
+++ b/pkg/tcpip/link/sniffer/sniffer_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package sniffer
diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD
deleted file mode 100644
index c3e4c3455..000000000
--- a/pkg/tcpip/link/tun/BUILD
+++ /dev/null
@@ -1,42 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
-package(licenses = ["notice"])
-
-go_template_instance(
- name = "tun_endpoint_refs",
- out = "tun_endpoint_refs.go",
- package = "tun",
- prefix = "tunEndpoint",
- template = "//pkg/refsvfs2:refs_template",
- types = {
- "T": "tunEndpoint",
- },
-)
-
-go_library(
- name = "tun",
- srcs = [
- "device.go",
- "protocol.go",
- "tun_endpoint_refs.go",
- "tun_unsafe.go",
- ],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/abi/linux",
- "//pkg/context",
- "//pkg/errors/linuxerr",
- "//pkg/log",
- "//pkg/refs",
- "//pkg/refsvfs2",
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/stack",
- "//pkg/waiter",
- "@org_golang_x_sys//unix:go_default_library",
- ],
-)
diff --git a/pkg/tcpip/link/tun/tun_endpoint_refs.go b/pkg/tcpip/link/tun/tun_endpoint_refs.go
new file mode 100644
index 000000000..a3bee1c05
--- /dev/null
+++ b/pkg/tcpip/link/tun/tun_endpoint_refs.go
@@ -0,0 +1,140 @@
+package tun
+
+import (
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/refsvfs2"
+)
+
+// enableLogging indicates whether reference-related events should be logged (with
+// stack traces). This is false by default and should only be set to true for
+// debugging purposes, as it can generate an extremely large amount of output
+// and drastically degrade performance.
+const tunEndpointenableLogging = false
+
+// obj is used to customize logging. Note that we use a pointer to T so that
+// we do not copy the entire object when passed as a format parameter.
+var tunEndpointobj *tunEndpoint
+
+// Refs implements refs.RefCounter. It keeps a reference count using atomic
+// operations and calls the destructor when the count reaches zero.
+//
+// NOTE: Do not introduce additional fields to the Refs struct. It is used by
+// many filesystem objects, and we want to keep it as small as possible (i.e.,
+// the same size as using an int64 directly) to avoid taking up extra cache
+// space. In general, this template should not be extended at the cost of
+// performance. If it does not offer enough flexibility for a particular object
+// (example: b/187877947), we should implement the RefCounter/CheckedObject
+// interfaces manually.
+//
+// +stateify savable
+type tunEndpointRefs struct {
+ // refCount is composed of two fields:
+ //
+ // [32-bit speculative references]:[32-bit real references]
+ //
+ // Speculative references are used for TryIncRef, to avoid a CompareAndSwap
+ // loop. See IncRef, DecRef and TryIncRef for details of how these fields are
+ // used.
+ refCount int64
+}
+
+// InitRefs initializes r with one reference and, if enabled, activates leak
+// checking.
+func (r *tunEndpointRefs) InitRefs() {
+ atomic.StoreInt64(&r.refCount, 1)
+ refsvfs2.Register(r)
+}
+
+// RefType implements refsvfs2.CheckedObject.RefType.
+func (r *tunEndpointRefs) RefType() string {
+ return fmt.Sprintf("%T", tunEndpointobj)[1:]
+}
+
+// LeakMessage implements refsvfs2.CheckedObject.LeakMessage.
+func (r *tunEndpointRefs) LeakMessage() string {
+ return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs())
+}
+
+// LogRefs implements refsvfs2.CheckedObject.LogRefs.
+func (r *tunEndpointRefs) LogRefs() bool {
+ return tunEndpointenableLogging
+}
+
+// ReadRefs returns the current number of references. The returned count is
+// inherently racy and is unsafe to use without external synchronization.
+func (r *tunEndpointRefs) ReadRefs() int64 {
+ return atomic.LoadInt64(&r.refCount)
+}
+
+// IncRef implements refs.RefCounter.IncRef.
+//
+//go:nosplit
+func (r *tunEndpointRefs) IncRef() {
+ v := atomic.AddInt64(&r.refCount, 1)
+ if tunEndpointenableLogging {
+ refsvfs2.LogIncRef(r, v)
+ }
+ if v <= 1 {
+ panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType()))
+ }
+}
+
+// TryIncRef implements refs.TryRefCounter.TryIncRef.
+//
+// To do this safely without a loop, a speculative reference is first acquired
+// on the object. This allows multiple concurrent TryIncRef calls to distinguish
+// other TryIncRef calls from genuine references held.
+//
+//go:nosplit
+func (r *tunEndpointRefs) TryIncRef() bool {
+ const speculativeRef = 1 << 32
+ if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) == 0 {
+
+ atomic.AddInt64(&r.refCount, -speculativeRef)
+ return false
+ }
+
+ v := atomic.AddInt64(&r.refCount, -speculativeRef+1)
+ if tunEndpointenableLogging {
+ refsvfs2.LogTryIncRef(r, v)
+ }
+ return true
+}
+
+// DecRef implements refs.RefCounter.DecRef.
+//
+// Note that speculative references are counted here. Since they were added
+// prior to real references reaching zero, they will successfully convert to
+// real references. In other words, we see speculative references only in the
+// following case:
+//
+// A: TryIncRef [speculative increase => sees non-negative references]
+// B: DecRef [real decrease]
+// A: TryIncRef [transform speculative to real]
+//
+//go:nosplit
+func (r *tunEndpointRefs) DecRef(destroy func()) {
+ v := atomic.AddInt64(&r.refCount, -1)
+ if tunEndpointenableLogging {
+ refsvfs2.LogDecRef(r, v)
+ }
+ switch {
+ case v < 0:
+ panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType()))
+
+ case v == 0:
+ refsvfs2.Unregister(r)
+
+ if destroy != nil {
+ destroy()
+ }
+ }
+}
+
+func (r *tunEndpointRefs) afterLoad() {
+ if r.ReadRefs() > 0 {
+ refsvfs2.Register(r)
+ }
+}
diff --git a/pkg/tcpip/link/tun/tun_state_autogen.go b/pkg/tcpip/link/tun/tun_state_autogen.go
new file mode 100644
index 000000000..c5773cc11
--- /dev/null
+++ b/pkg/tcpip/link/tun/tun_state_autogen.go
@@ -0,0 +1,68 @@
+// automatically generated by stateify.
+
+package tun
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+func (d *Device) StateTypeName() string {
+ return "pkg/tcpip/link/tun.Device"
+}
+
+func (d *Device) StateFields() []string {
+ return []string{
+ "Queue",
+ "endpoint",
+ "notifyHandle",
+ "flags",
+ }
+}
+
+// +checklocksignore
+func (d *Device) StateSave(stateSinkObject state.Sink) {
+ d.beforeSave()
+ stateSinkObject.Save(0, &d.Queue)
+ stateSinkObject.Save(1, &d.endpoint)
+ stateSinkObject.Save(2, &d.notifyHandle)
+ stateSinkObject.Save(3, &d.flags)
+}
+
+func (d *Device) afterLoad() {}
+
+// +checklocksignore
+func (d *Device) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &d.Queue)
+ stateSourceObject.Load(1, &d.endpoint)
+ stateSourceObject.Load(2, &d.notifyHandle)
+ stateSourceObject.Load(3, &d.flags)
+}
+
+func (r *tunEndpointRefs) StateTypeName() string {
+ return "pkg/tcpip/link/tun.tunEndpointRefs"
+}
+
+func (r *tunEndpointRefs) StateFields() []string {
+ return []string{
+ "refCount",
+ }
+}
+
+func (r *tunEndpointRefs) beforeSave() {}
+
+// +checklocksignore
+func (r *tunEndpointRefs) StateSave(stateSinkObject state.Sink) {
+ r.beforeSave()
+ stateSinkObject.Save(0, &r.refCount)
+}
+
+// +checklocksignore
+func (r *tunEndpointRefs) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &r.refCount)
+ stateSourceObject.AfterLoad(r.afterLoad)
+}
+
+func init() {
+ state.Register((*Device)(nil))
+ state.Register((*tunEndpointRefs)(nil))
+}
diff --git a/pkg/tcpip/link/tun/tun_unsafe_state_autogen.go b/pkg/tcpip/link/tun/tun_unsafe_state_autogen.go
new file mode 100644
index 000000000..8d82ad324
--- /dev/null
+++ b/pkg/tcpip/link/tun/tun_unsafe_state_autogen.go
@@ -0,0 +1,6 @@
+// automatically generated by stateify.
+
+//go:build linux
+// +build linux
+
+package tun
diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD
deleted file mode 100644
index b8d417b7d..000000000
--- a/pkg/tcpip/link/waitable/BUILD
+++ /dev/null
@@ -1,30 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "waitable",
- srcs = [
- "waitable.go",
- ],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/header",
- "//pkg/tcpip/stack",
- ],
-)
-
-go_test(
- name = "waitable_test",
- srcs = [
- "waitable_test.go",
- ],
- library = ":waitable",
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/header",
- "//pkg/tcpip/stack",
- ],
-)
diff --git a/pkg/tcpip/link/waitable/waitable_state_autogen.go b/pkg/tcpip/link/waitable/waitable_state_autogen.go
new file mode 100644
index 000000000..059424fa0
--- /dev/null
+++ b/pkg/tcpip/link/waitable/waitable_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package waitable
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
deleted file mode 100644
index b0e4237bd..000000000
--- a/pkg/tcpip/link/waitable/waitable_test.go
+++ /dev/null
@@ -1,187 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package waitable
-
-import (
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
-)
-
-type countedEndpoint struct {
- dispatchCount int
- writeCount int
- attachCount int
-
- mtu uint32
- capabilities stack.LinkEndpointCapabilities
- hdrLen uint16
- linkAddr tcpip.LinkAddress
-
- dispatcher stack.NetworkDispatcher
-}
-
-func (e *countedEndpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
- e.dispatchCount++
-}
-
-func (e *countedEndpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
- panic("unimplemented")
-}
-
-func (e *countedEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
- e.attachCount++
- e.dispatcher = dispatcher
-}
-
-// IsAttached implements stack.LinkEndpoint.IsAttached.
-func (e *countedEndpoint) IsAttached() bool {
- return e.dispatcher != nil
-}
-
-func (e *countedEndpoint) MTU() uint32 {
- return e.mtu
-}
-
-func (e *countedEndpoint) Capabilities() stack.LinkEndpointCapabilities {
- return e.capabilities
-}
-
-func (e *countedEndpoint) MaxHeaderLength() uint16 {
- return e.hdrLen
-}
-
-func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress {
- return e.linkAddr
-}
-
-func (e *countedEndpoint) WritePacket(stack.RouteInfo, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error {
- e.writeCount++
- return nil
-}
-
-// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (e *countedEndpoint) WritePackets(_ stack.RouteInfo, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
- e.writeCount += pkts.Len()
- return pkts.Len(), nil
-}
-
-// WriteRawPacket implements stack.LinkEndpoint.
-func (*countedEndpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error {
- return &tcpip.ErrNotSupported{}
-}
-
-// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
-func (*countedEndpoint) ARPHardwareType() header.ARPHardwareType {
- panic("unimplemented")
-}
-
-// Wait implements stack.LinkEndpoint.Wait.
-func (*countedEndpoint) Wait() {}
-
-// AddHeader implements stack.LinkEndpoint.AddHeader.
-func (e *countedEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
- panic("unimplemented")
-}
-
-func TestWaitWrite(t *testing.T) {
- ep := &countedEndpoint{}
- wep := New(ep)
-
- // Write and check that it goes through.
- wep.WritePacket(stack.RouteInfo{}, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{}))
- if want := 1; ep.writeCount != want {
- t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
- }
-
- // Wait on dispatches, then try to write. It must go through.
- wep.WaitDispatch()
- wep.WritePacket(stack.RouteInfo{}, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{}))
- if want := 2; ep.writeCount != want {
- t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
- }
-
- // Wait on writes, then try to write. It must not go through.
- wep.WaitWrite()
- wep.WritePacket(stack.RouteInfo{}, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{}))
- if want := 2; ep.writeCount != want {
- t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
- }
-}
-
-func TestWaitDispatch(t *testing.T) {
- ep := &countedEndpoint{}
- wep := New(ep)
-
- // Check that attach happens.
- wep.Attach(ep)
- if want := 1; ep.attachCount != want {
- t.Fatalf("Unexpected attachCount: got=%v, want=%v", ep.attachCount, want)
- }
-
- // Dispatch and check that it goes through.
- ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{}))
- if want := 1; ep.dispatchCount != want {
- t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
- }
-
- // Wait on writes, then try to dispatch. It must go through.
- wep.WaitWrite()
- ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{}))
- if want := 2; ep.dispatchCount != want {
- t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
- }
-
- // Wait on dispatches, then try to dispatch. It must not go through.
- wep.WaitDispatch()
- ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{}))
- if want := 2; ep.dispatchCount != want {
- t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
- }
-}
-
-func TestOtherMethods(t *testing.T) {
- const (
- mtu = 0xdead
- capabilities = 0xbeef
- hdrLen = 0x1234
- linkAddr = "test address"
- )
- ep := &countedEndpoint{
- mtu: mtu,
- capabilities: capabilities,
- hdrLen: hdrLen,
- linkAddr: linkAddr,
- }
- wep := New(ep)
-
- if v := wep.MTU(); v != mtu {
- t.Fatalf("Unexpected mtu: got=%v, want=%v", v, mtu)
- }
-
- if v := wep.Capabilities(); v != capabilities {
- t.Fatalf("Unexpected capabilities: got=%v, want=%v", v, capabilities)
- }
-
- if v := wep.MaxHeaderLength(); v != hdrLen {
- t.Fatalf("Unexpected MaxHeaderLength: got=%v, want=%v", v, hdrLen)
- }
-
- if v := wep.LinkAddress(); v != linkAddr {
- t.Fatalf("Unexpected LinkAddress: got=%q, want=%q", v, linkAddr)
- }
-}
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
deleted file mode 100644
index c0179104a..000000000
--- a/pkg/tcpip/network/BUILD
+++ /dev/null
@@ -1,32 +0,0 @@
-load("//tools:defs.bzl", "go_test")
-
-package(licenses = ["notice"])
-
-go_test(
- name = "ip_test",
- size = "small",
- srcs = [
- "ip_test.go",
- "multicast_group_test.go",
- ],
- deps = [
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/checker",
- "//pkg/tcpip/faketime",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/link/loopback",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/testutil",
- "//pkg/tcpip/transport/icmp",
- "//pkg/tcpip/transport/raw",
- "//pkg/tcpip/transport/tcp",
- "//pkg/tcpip/transport/udp",
- "//pkg/waiter",
- "@com_github_google_go_cmp//cmp:go_default_library",
- ],
-)
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
deleted file mode 100644
index 6fa1aee18..000000000
--- a/pkg/tcpip/network/arp/BUILD
+++ /dev/null
@@ -1,53 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "arp",
- srcs = [
- "arp.go",
- "stats.go",
- ],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/header/parse",
- "//pkg/tcpip/network/internal/ip",
- "//pkg/tcpip/stack",
- ],
-)
-
-go_test(
- name = "arp_test",
- size = "small",
- srcs = ["arp_test.go"],
- deps = [
- ":arp",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/faketime",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/link/sniffer",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/testutil",
- "@com_github_google_go_cmp//cmp:go_default_library",
- "@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
- ],
-)
-
-go_test(
- name = "stats_test",
- size = "small",
- srcs = ["stats_test.go"],
- library = ":arp",
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/testutil",
- ],
-)
diff --git a/pkg/tcpip/network/arp/arp_state_autogen.go b/pkg/tcpip/network/arp/arp_state_autogen.go
new file mode 100644
index 000000000..5cd8535e3
--- /dev/null
+++ b/pkg/tcpip/network/arp/arp_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package arp
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
deleted file mode 100644
index 061cc35ae..000000000
--- a/pkg/tcpip/network/arp/arp_test.go
+++ /dev/null
@@ -1,688 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package arp_test
-
-import (
- "fmt"
- "testing"
-
- "github.com/google/go-cmp/cmp"
- "github.com/google/go-cmp/cmp/cmpopts"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
- "gvisor.dev/gvisor/pkg/tcpip/network/arp"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
-)
-
-const (
- nicID = 1
-
- stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c")
- remoteLinkAddr = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06")
-)
-
-var (
- stackAddr = testutil.MustParse4("10.0.0.1")
- remoteAddr = testutil.MustParse4("10.0.0.2")
- unknownAddr = testutil.MustParse4("10.0.0.3")
-)
-
-type eventType uint8
-
-const (
- entryAdded eventType = iota
- entryChanged
- entryRemoved
-)
-
-func (t eventType) String() string {
- switch t {
- case entryAdded:
- return "add"
- case entryChanged:
- return "change"
- case entryRemoved:
- return "remove"
- default:
- return fmt.Sprintf("unknown (%d)", t)
- }
-}
-
-type eventInfo struct {
- eventType eventType
- nicID tcpip.NICID
- entry stack.NeighborEntry
-}
-
-func (e eventInfo) String() string {
- return fmt.Sprintf("%s event for NIC #%d, %#v", e.eventType, e.nicID, e.entry)
-}
-
-// arpDispatcher implements NUDDispatcher to validate the dispatching of
-// events upon certain NUD state machine events.
-type arpDispatcher struct {
- // C is where events are queued
- C chan eventInfo
-}
-
-var _ stack.NUDDispatcher = (*arpDispatcher)(nil)
-
-func (d *arpDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry stack.NeighborEntry) {
- e := eventInfo{
- eventType: entryAdded,
- nicID: nicID,
- entry: entry,
- }
- d.C <- e
-}
-
-func (d *arpDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry stack.NeighborEntry) {
- e := eventInfo{
- eventType: entryChanged,
- nicID: nicID,
- entry: entry,
- }
- d.C <- e
-}
-
-func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.NeighborEntry) {
- e := eventInfo{
- eventType: entryRemoved,
- nicID: nicID,
- entry: entry,
- }
- d.C <- e
-}
-
-func (d *arpDispatcher) nextEvent() (eventInfo, bool) {
- select {
- case event := <-d.C:
- return event, true
- default:
- return eventInfo{}, false
- }
-}
-
-type testContext struct {
- s *stack.Stack
- linkEP *channel.Endpoint
- nudDisp arpDispatcher
-}
-
-func makeTestContext(t *testing.T, eventDepth int, packetDepth int) testContext {
- t.Helper()
-
- tc := testContext{
- nudDisp: arpDispatcher{
- C: make(chan eventInfo, eventDepth),
- },
- }
-
- tc.s = stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol},
- NUDDisp: &tc.nudDisp,
- Clock: &faketime.NullClock{},
- })
-
- tc.linkEP = channel.New(packetDepth, header.IPv4MinimumMTU, stackLinkAddr)
- tc.linkEP.LinkEPCapabilities |= stack.CapabilityResolutionRequired
-
- wep := stack.LinkEndpoint(tc.linkEP)
- if testing.Verbose() {
- wep = sniffer.New(wep)
- }
- if err := tc.s.CreateNIC(nicID, wep); err != nil {
- t.Fatalf("CreateNIC failed: %s", err)
- }
-
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: stackAddr.WithPrefix(),
- }
- if err := tc.s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
-
- tc.s.SetRouteTable([]tcpip.Route{{
- Destination: header.IPv4EmptySubnet,
- NIC: nicID,
- }})
-
- return tc
-}
-
-func (c *testContext) cleanup() {
- c.linkEP.Close()
-}
-
-func TestMalformedPacket(t *testing.T) {
- c := makeTestContext(t, 0, 0)
- defer c.cleanup()
-
- v := make(buffer.View, header.ARPSize)
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: v.ToVectorisedView(),
- })
-
- c.linkEP.InjectInbound(arp.ProtocolNumber, pkt)
-
- if got := c.s.Stats().ARP.PacketsReceived.Value(); got != 1 {
- t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = 1", got)
- }
- if got := c.s.Stats().ARP.MalformedPacketsReceived.Value(); got != 1 {
- t.Errorf("got c.s.Stats().ARP.MalformedPacketsReceived.Value() = %d, want = 1", got)
- }
-}
-
-func TestDisabledEndpoint(t *testing.T) {
- c := makeTestContext(t, 0, 0)
- defer c.cleanup()
-
- ep, err := c.s.GetNetworkEndpoint(nicID, header.ARPProtocolNumber)
- if err != nil {
- t.Fatalf("GetNetworkEndpoint(%d, header.ARPProtocolNumber) failed: %s", nicID, err)
- }
- ep.Disable()
-
- v := make(buffer.View, header.ARPSize)
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: v.ToVectorisedView(),
- })
-
- c.linkEP.InjectInbound(arp.ProtocolNumber, pkt)
-
- if got := c.s.Stats().ARP.PacketsReceived.Value(); got != 1 {
- t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = 1", got)
- }
- if got := c.s.Stats().ARP.DisabledPacketsReceived.Value(); got != 1 {
- t.Errorf("got c.s.Stats().ARP.DisabledPacketsReceived.Value() = %d, want = 1", got)
- }
-}
-
-func TestDirectReply(t *testing.T) {
- c := makeTestContext(t, 0, 0)
- defer c.cleanup()
-
- const senderMAC = "\x01\x02\x03\x04\x05\x06"
- const senderIPv4 = "\x0a\x00\x00\x02"
-
- v := make(buffer.View, header.ARPSize)
- h := header.ARP(v)
- h.SetIPv4OverEthernet()
- h.SetOp(header.ARPReply)
-
- copy(h.HardwareAddressSender(), senderMAC)
- copy(h.ProtocolAddressSender(), senderIPv4)
- copy(h.HardwareAddressTarget(), stackLinkAddr)
- copy(h.ProtocolAddressTarget(), stackAddr)
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: v.ToVectorisedView(),
- })
-
- c.linkEP.InjectInbound(arp.ProtocolNumber, pkt)
-
- if got := c.s.Stats().ARP.PacketsReceived.Value(); got != 1 {
- t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = 1", got)
- }
- if got := c.s.Stats().ARP.RepliesReceived.Value(); got != 1 {
- t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = 1", got)
- }
-}
-
-func TestDirectRequest(t *testing.T) {
- c := makeTestContext(t, 1, 1)
- defer c.cleanup()
-
- tests := []struct {
- name string
- senderAddr tcpip.Address
- senderLinkAddr tcpip.LinkAddress
- targetAddr tcpip.Address
- isValid bool
- }{
- {
- name: "Loopback",
- senderAddr: stackAddr,
- senderLinkAddr: stackLinkAddr,
- targetAddr: stackAddr,
- isValid: true,
- },
- {
- name: "Remote",
- senderAddr: remoteAddr,
- senderLinkAddr: remoteLinkAddr,
- targetAddr: stackAddr,
- isValid: true,
- },
- {
- name: "RemoteInvalidTarget",
- senderAddr: remoteAddr,
- senderLinkAddr: remoteLinkAddr,
- targetAddr: unknownAddr,
- isValid: false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- packetsRecv := c.s.Stats().ARP.PacketsReceived.Value()
- requestsRecv := c.s.Stats().ARP.RequestsReceived.Value()
- requestsRecvUnknownAddr := c.s.Stats().ARP.RequestsReceivedUnknownTargetAddress.Value()
- outgoingReplies := c.s.Stats().ARP.OutgoingRepliesSent.Value()
-
- // Inject an incoming ARP request.
- v := make(buffer.View, header.ARPSize)
- h := header.ARP(v)
- h.SetIPv4OverEthernet()
- h.SetOp(header.ARPRequest)
- copy(h.HardwareAddressSender(), test.senderLinkAddr)
- copy(h.ProtocolAddressSender(), test.senderAddr)
- copy(h.ProtocolAddressTarget(), test.targetAddr)
- c.linkEP.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: v.ToVectorisedView(),
- }))
-
- if got, want := c.s.Stats().ARP.PacketsReceived.Value(), packetsRecv+1; got != want {
- t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = %d", got, want)
- }
- if got, want := c.s.Stats().ARP.RequestsReceived.Value(), requestsRecv+1; got != want {
- t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = %d", got, want)
- }
-
- if !test.isValid {
- // No packets should be sent after receiving an invalid ARP request.
- // There is no need to perform a blocking read here, since packets are
- // sent in the same function that handles ARP requests.
- if pkt, ok := c.linkEP.Read(); ok {
- t.Errorf("unexpected packet sent with network protocol number %d", pkt.Proto)
- }
- if got, want := c.s.Stats().ARP.RequestsReceivedUnknownTargetAddress.Value(), requestsRecvUnknownAddr+1; got != want {
- t.Errorf("got c.s.Stats().ARP.RequestsReceivedUnknownTargetAddress.Value() = %d, want = %d", got, want)
- }
- if got, want := c.s.Stats().ARP.OutgoingRepliesSent.Value(), outgoingReplies; got != want {
- t.Errorf("got c.s.Stats().ARP.OutgoingRepliesSent.Value() = %d, want = %d", got, want)
- }
-
- return
- }
-
- if got, want := c.s.Stats().ARP.OutgoingRepliesSent.Value(), outgoingReplies+1; got != want {
- t.Errorf("got c.s.Stats().ARP.OutgoingRepliesSent.Value() = %d, want = %d", got, want)
- }
-
- // Verify an ARP response was sent.
- pi, ok := c.linkEP.Read()
- if !ok {
- t.Fatal("expected ARP response to be sent, got none")
- }
-
- if pi.Proto != arp.ProtocolNumber {
- t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto)
- }
- rep := header.ARP(pi.Pkt.NetworkHeader().View())
- if !rep.IsValid() {
- t.Fatalf("invalid ARP response: len = %d; response = %x", len(rep), rep)
- }
- if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want {
- t.Errorf("got HardwareAddressSender() = %s, want = %s", got, want)
- }
- if got, want := tcpip.Address(rep.ProtocolAddressSender()), tcpip.Address(h.ProtocolAddressTarget()); got != want {
- t.Errorf("got ProtocolAddressSender() = %s, want = %s", got, want)
- }
- if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress(h.HardwareAddressSender()); got != want {
- t.Errorf("got HardwareAddressTarget() = %s, want = %s", got, want)
- }
- if got, want := tcpip.Address(rep.ProtocolAddressTarget()), tcpip.Address(h.ProtocolAddressSender()); got != want {
- t.Errorf("got ProtocolAddressTarget() = %s, want = %s", got, want)
- }
-
- // Verify the sender was saved in the neighbor cache.
- if got, ok := c.nudDisp.nextEvent(); ok {
- want := eventInfo{
- eventType: entryAdded,
- nicID: nicID,
- entry: stack.NeighborEntry{
- Addr: test.senderAddr,
- LinkAddr: test.senderLinkAddr,
- State: stack.Stale,
- },
- }
- if diff := cmp.Diff(want, got, cmp.AllowUnexported(eventInfo{}), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAt")); diff != "" {
- t.Errorf("got invalid event (-want +got):\n%s", diff)
- }
- } else {
- t.Fatal("event didn't arrive")
- }
-
- neighbors, err := c.s.Neighbors(nicID, ipv4.ProtocolNumber)
- if err != nil {
- t.Fatalf("c.s.Neighbors(%d, %d): %s", nicID, ipv4.ProtocolNumber, err)
- }
-
- neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry)
- for _, n := range neighbors {
- if existing, ok := neighborByAddr[n.Addr]; ok {
- if diff := cmp.Diff(existing, n); diff != "" {
- t.Fatalf("duplicate neighbor entry found (-existing +got):\n%s", diff)
- }
- t.Fatalf("exact neighbor entry duplicate found for addr=%s", n.Addr)
- }
- neighborByAddr[n.Addr] = n
- }
-
- neigh, ok := neighborByAddr[test.senderAddr]
- if !ok {
- t.Fatalf("expected neighbor entry with Addr = %s", test.senderAddr)
- }
- if got, want := neigh.LinkAddr, test.senderLinkAddr; got != want {
- t.Errorf("got neighbor LinkAddr = %s, want = %s", got, want)
- }
- if got, want := neigh.State, stack.Stale; got != want {
- t.Errorf("got neighbor State = %s, want = %s", got, want)
- }
-
- // No more events should be dispatched
- for {
- event, ok := c.nudDisp.nextEvent()
- if !ok {
- break
- }
- t.Errorf("unexpected %s", event)
- }
- })
- }
-}
-
-var _ stack.LinkEndpoint = (*testLinkEndpoint)(nil)
-
-type testLinkEndpoint struct {
- stack.LinkEndpoint
-
- writeErr tcpip.Error
-}
-
-func (t *testLinkEndpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- if t.writeErr != nil {
- return t.writeErr
- }
-
- return t.LinkEndpoint.WritePacket(r, protocol, pkt)
-}
-
-func TestLinkAddressRequest(t *testing.T) {
- const nicID = 1
-
- testAddr := tcpip.Address([]byte{1, 2, 3, 4})
-
- tests := []struct {
- name string
- nicAddr tcpip.Address
- localAddr tcpip.Address
- remoteLinkAddr tcpip.LinkAddress
- linkErr tcpip.Error
- expectedErr tcpip.Error
- expectedLocalAddr tcpip.Address
- expectedRemoteLinkAddr tcpip.LinkAddress
- expectedRequestsSent uint64
- expectedRequestBadLocalAddressErrors uint64
- expectedRequestInterfaceHasNoLocalAddressErrors uint64
- expectedRequestDroppedErrors uint64
- }{
- {
- name: "Unicast",
- nicAddr: stackAddr,
- localAddr: stackAddr,
- remoteLinkAddr: remoteLinkAddr,
- expectedLocalAddr: stackAddr,
- expectedRemoteLinkAddr: remoteLinkAddr,
- expectedRequestsSent: 1,
- expectedRequestBadLocalAddressErrors: 0,
- expectedRequestInterfaceHasNoLocalAddressErrors: 0,
- expectedRequestDroppedErrors: 0,
- },
- {
- name: "Multicast",
- nicAddr: stackAddr,
- localAddr: stackAddr,
- remoteLinkAddr: "",
- expectedLocalAddr: stackAddr,
- expectedRemoteLinkAddr: header.EthernetBroadcastAddress,
- expectedRequestsSent: 1,
- expectedRequestBadLocalAddressErrors: 0,
- expectedRequestInterfaceHasNoLocalAddressErrors: 0,
- expectedRequestDroppedErrors: 0,
- },
- {
- name: "Unicast with unspecified source",
- nicAddr: stackAddr,
- localAddr: "",
- remoteLinkAddr: remoteLinkAddr,
- expectedLocalAddr: stackAddr,
- expectedRemoteLinkAddr: remoteLinkAddr,
- expectedRequestsSent: 1,
- expectedRequestBadLocalAddressErrors: 0,
- expectedRequestInterfaceHasNoLocalAddressErrors: 0,
- expectedRequestDroppedErrors: 0,
- },
- {
- name: "Multicast with unspecified source",
- nicAddr: stackAddr,
- localAddr: "",
- remoteLinkAddr: "",
- expectedLocalAddr: stackAddr,
- expectedRemoteLinkAddr: header.EthernetBroadcastAddress,
- expectedRequestsSent: 1,
- expectedRequestBadLocalAddressErrors: 0,
- expectedRequestInterfaceHasNoLocalAddressErrors: 0,
- expectedRequestDroppedErrors: 0,
- },
- {
- name: "Unicast with unassigned address",
- nicAddr: stackAddr,
- localAddr: testAddr,
- remoteLinkAddr: remoteLinkAddr,
- expectedErr: &tcpip.ErrBadLocalAddress{},
- expectedRequestsSent: 0,
- expectedRequestBadLocalAddressErrors: 1,
- expectedRequestInterfaceHasNoLocalAddressErrors: 0,
- expectedRequestDroppedErrors: 0,
- },
- {
- name: "Multicast with unassigned address",
- nicAddr: stackAddr,
- localAddr: testAddr,
- remoteLinkAddr: "",
- expectedErr: &tcpip.ErrBadLocalAddress{},
- expectedRequestsSent: 0,
- expectedRequestBadLocalAddressErrors: 1,
- expectedRequestInterfaceHasNoLocalAddressErrors: 0,
- expectedRequestDroppedErrors: 0,
- },
- {
- name: "Unicast with no local address available",
- nicAddr: "",
- localAddr: "",
- remoteLinkAddr: remoteLinkAddr,
- expectedErr: &tcpip.ErrNetworkUnreachable{},
- expectedRequestsSent: 0,
- expectedRequestBadLocalAddressErrors: 0,
- expectedRequestInterfaceHasNoLocalAddressErrors: 1,
- expectedRequestDroppedErrors: 0,
- },
- {
- name: "Multicast with no local address available",
- nicAddr: "",
- localAddr: "",
- remoteLinkAddr: "",
- expectedErr: &tcpip.ErrNetworkUnreachable{},
- expectedRequestsSent: 0,
- expectedRequestBadLocalAddressErrors: 0,
- expectedRequestInterfaceHasNoLocalAddressErrors: 1,
- expectedRequestDroppedErrors: 0,
- },
- {
- name: "Link error",
- nicAddr: stackAddr,
- localAddr: stackAddr,
- remoteLinkAddr: remoteLinkAddr,
- linkErr: &tcpip.ErrInvalidEndpointState{},
- expectedErr: &tcpip.ErrInvalidEndpointState{},
- expectedRequestsSent: 0,
- expectedRequestBadLocalAddressErrors: 0,
- expectedRequestInterfaceHasNoLocalAddressErrors: 0,
- expectedRequestDroppedErrors: 1,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
- })
- linkEP := channel.New(1, header.IPv4MinimumMTU, stackLinkAddr)
- if err := s.CreateNIC(nicID, &testLinkEndpoint{LinkEndpoint: linkEP, writeErr: test.linkErr}); err != nil {
- t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
- }
-
- ep, err := s.GetNetworkEndpoint(nicID, arp.ProtocolNumber)
- if err != nil {
- t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, arp.ProtocolNumber, err)
- }
- linkRes, ok := ep.(stack.LinkAddressResolver)
- if !ok {
- t.Fatalf("expected %T to implement stack.LinkAddressResolver", ep)
- }
-
- if len(test.nicAddr) != 0 {
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: test.nicAddr.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
- }
-
- {
- err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr)
- if diff := cmp.Diff(test.expectedErr, err); diff != "" {
- t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", remoteAddr, test.localAddr, test.remoteLinkAddr, diff)
- }
- }
-
- if got := s.Stats().ARP.OutgoingRequestsSent.Value(); got != test.expectedRequestsSent {
- t.Errorf("got s.Stats().ARP.OutgoingRequestsSent.Value() = %d, want = %d", got, test.expectedRequestsSent)
- }
- if got := s.Stats().ARP.OutgoingRequestInterfaceHasNoLocalAddressErrors.Value(); got != test.expectedRequestInterfaceHasNoLocalAddressErrors {
- t.Errorf("got s.Stats().ARP.OutgoingRequestInterfaceHasNoLocalAddressErrors.Value() = %d, want = %d", got, test.expectedRequestInterfaceHasNoLocalAddressErrors)
- }
- if got := s.Stats().ARP.OutgoingRequestBadLocalAddressErrors.Value(); got != test.expectedRequestBadLocalAddressErrors {
- t.Errorf("got s.Stats().ARP.OutgoingRequestBadLocalAddressErrors.Value() = %d, want = %d", got, test.expectedRequestBadLocalAddressErrors)
- }
- if got := s.Stats().ARP.OutgoingRequestsDropped.Value(); got != test.expectedRequestDroppedErrors {
- t.Errorf("got s.Stats().ARP.OutgoingRequestsDropped.Value() = %d, want = %d", got, test.expectedRequestDroppedErrors)
- }
-
- if test.expectedErr != nil {
- return
- }
-
- pkt, ok := linkEP.Read()
- if !ok {
- t.Fatal("expected to send a link address request")
- }
-
- if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr {
- t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr)
- }
-
- rep := header.ARP(stack.PayloadSince(pkt.Pkt.NetworkHeader()))
- if got := rep.Op(); got != header.ARPRequest {
- t.Errorf("got Op = %d, want = %d", got, header.ARPRequest)
- }
- if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr {
- t.Errorf("got HardwareAddressSender = %s, want = %s", got, stackLinkAddr)
- }
- if got := tcpip.Address(rep.ProtocolAddressSender()); got != test.expectedLocalAddr {
- t.Errorf("got ProtocolAddressSender = %s, want = %s", got, test.expectedLocalAddr)
- }
- if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"); got != want {
- t.Errorf("got HardwareAddressTarget = %s, want = %s", got, want)
- }
- if got := tcpip.Address(rep.ProtocolAddressTarget()); got != remoteAddr {
- t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, remoteAddr)
- }
- })
- }
-}
-
-func TestDADARPRequestPacket(t *testing.T) {
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocolWithOptions(arp.Options{
- DADConfigs: stack.DADConfigurations{
- DupAddrDetectTransmits: 1,
- },
- }), ipv4.NewProtocol},
- Clock: clock,
- })
- e := channel.New(1, header.IPv4MinimumMTU, stackLinkAddr)
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
- }
-
- if res, err := s.CheckDuplicateAddress(nicID, header.IPv4ProtocolNumber, remoteAddr, func(stack.DADResult) {}); err != nil {
- t.Fatalf("s.CheckDuplicateAddress(%d, %d, %s, _): %s", nicID, header.IPv4ProtocolNumber, remoteAddr, err)
- } else if res != stack.DADStarting {
- t.Fatalf("got s.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", nicID, header.IPv4ProtocolNumber, remoteAddr, res, stack.DADStarting)
- }
-
- clock.RunImmediatelyScheduledJobs()
- pkt, ok := e.Read()
- if !ok {
- t.Fatal("expected to send an ARP request")
- }
-
- if pkt.Route.RemoteLinkAddress != header.EthernetBroadcastAddress {
- t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, header.EthernetBroadcastAddress)
- }
-
- req := header.ARP(stack.PayloadSince(pkt.Pkt.NetworkHeader()))
- if !req.IsValid() {
- t.Errorf("got req.IsValid() = false, want = true")
- }
- if got := req.Op(); got != header.ARPRequest {
- t.Errorf("got req.Op() = %d, want = %d", got, header.ARPRequest)
- }
- if got := tcpip.LinkAddress(req.HardwareAddressSender()); got != stackLinkAddr {
- t.Errorf("got req.HardwareAddressSender() = %s, want = %s", got, stackLinkAddr)
- }
- if got := tcpip.Address(req.ProtocolAddressSender()); got != header.IPv4Any {
- t.Errorf("got req.ProtocolAddressSender() = %s, want = %s", got, header.IPv4Any)
- }
- if got, want := tcpip.LinkAddress(req.HardwareAddressTarget()), tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"); got != want {
- t.Errorf("got req.HardwareAddressTarget() = %s, want = %s", got, want)
- }
- if got := tcpip.Address(req.ProtocolAddressTarget()); got != remoteAddr {
- t.Errorf("got req.ProtocolAddressTarget() = %s, want = %s", got, remoteAddr)
- }
-}
diff --git a/pkg/tcpip/network/arp/stats_test.go b/pkg/tcpip/network/arp/stats_test.go
deleted file mode 100644
index 0df39ae81..000000000
--- a/pkg/tcpip/network/arp/stats_test.go
+++ /dev/null
@@ -1,51 +0,0 @@
-// Copyright 2021 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package arp
-
-import (
- "reflect"
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
-)
-
-var _ stack.NetworkInterface = (*testInterface)(nil)
-
-type testInterface struct {
- stack.NetworkInterface
- nicID tcpip.NICID
-}
-
-func (t *testInterface) ID() tcpip.NICID {
- return t.nicID
-}
-
-func TestMultiCounterStatsInitialization(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- })
- proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
- var nic testInterface
- ep := proto.NewEndpoint(&nic, nil).(*endpoint)
- // At this point, the Stack's stats and the NetworkEndpoint's stats are
- // expected to be bound by a MultiCounterStat.
- refStack := s.Stats()
- refEP := ep.stats.localStats
- if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.arp).Elem(), []reflect.Value{reflect.ValueOf(&refEP.ARP).Elem(), reflect.ValueOf(&refStack.ARP).Elem()}); err != nil {
- t.Error(err)
- }
-}
diff --git a/pkg/tcpip/network/hash/BUILD b/pkg/tcpip/network/hash/BUILD
deleted file mode 100644
index 872165866..000000000
--- a/pkg/tcpip/network/hash/BUILD
+++ /dev/null
@@ -1,13 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "hash",
- srcs = ["hash.go"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/rand",
- "//pkg/tcpip/header",
- ],
-)
diff --git a/pkg/tcpip/network/hash/hash_state_autogen.go b/pkg/tcpip/network/hash/hash_state_autogen.go
new file mode 100644
index 000000000..9467fe298
--- /dev/null
+++ b/pkg/tcpip/network/hash/hash_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package hash
diff --git a/pkg/tcpip/network/internal/fragmentation/BUILD b/pkg/tcpip/network/internal/fragmentation/BUILD
deleted file mode 100644
index 274f09092..000000000
--- a/pkg/tcpip/network/internal/fragmentation/BUILD
+++ /dev/null
@@ -1,54 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
-package(licenses = ["notice"])
-
-go_template_instance(
- name = "reassembler_list",
- out = "reassembler_list.go",
- package = "fragmentation",
- prefix = "reassembler",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*reassembler",
- "Linker": "*reassembler",
- },
-)
-
-go_library(
- name = "fragmentation",
- srcs = [
- "fragmentation.go",
- "reassembler.go",
- "reassembler_list.go",
- ],
- visibility = [
- "//pkg/tcpip/network/ipv4:__pkg__",
- "//pkg/tcpip/network/ipv6:__pkg__",
- ],
- deps = [
- "//pkg/log",
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/stack",
- ],
-)
-
-go_test(
- name = "fragmentation_test",
- size = "small",
- srcs = [
- "fragmentation_test.go",
- "reassembler_test.go",
- ],
- library = ":fragmentation",
- deps = [
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/faketime",
- "//pkg/tcpip/network/internal/testutil",
- "//pkg/tcpip/stack",
- "@com_github_google_go_cmp//cmp:go_default_library",
- ],
-)
diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation_state_autogen.go b/pkg/tcpip/network/internal/fragmentation/fragmentation_state_autogen.go
new file mode 100644
index 000000000..21c5774e9
--- /dev/null
+++ b/pkg/tcpip/network/internal/fragmentation/fragmentation_state_autogen.go
@@ -0,0 +1,68 @@
+// automatically generated by stateify.
+
+package fragmentation
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+func (l *reassemblerList) StateTypeName() string {
+ return "pkg/tcpip/network/internal/fragmentation.reassemblerList"
+}
+
+func (l *reassemblerList) StateFields() []string {
+ return []string{
+ "head",
+ "tail",
+ }
+}
+
+func (l *reassemblerList) beforeSave() {}
+
+// +checklocksignore
+func (l *reassemblerList) StateSave(stateSinkObject state.Sink) {
+ l.beforeSave()
+ stateSinkObject.Save(0, &l.head)
+ stateSinkObject.Save(1, &l.tail)
+}
+
+func (l *reassemblerList) afterLoad() {}
+
+// +checklocksignore
+func (l *reassemblerList) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &l.head)
+ stateSourceObject.Load(1, &l.tail)
+}
+
+func (e *reassemblerEntry) StateTypeName() string {
+ return "pkg/tcpip/network/internal/fragmentation.reassemblerEntry"
+}
+
+func (e *reassemblerEntry) StateFields() []string {
+ return []string{
+ "next",
+ "prev",
+ }
+}
+
+func (e *reassemblerEntry) beforeSave() {}
+
+// +checklocksignore
+func (e *reassemblerEntry) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+ stateSinkObject.Save(0, &e.next)
+ stateSinkObject.Save(1, &e.prev)
+}
+
+func (e *reassemblerEntry) afterLoad() {}
+
+// +checklocksignore
+func (e *reassemblerEntry) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &e.next)
+ stateSourceObject.Load(1, &e.prev)
+}
+
+func init() {
+ state.Register((*reassemblerList)(nil))
+ state.Register((*reassemblerEntry)(nil))
+}
diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go
deleted file mode 100644
index dadfc28cc..000000000
--- a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go
+++ /dev/null
@@ -1,648 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package fragmentation
-
-import (
- "errors"
- "testing"
- "time"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
-)
-
-// reassembleTimeout is dummy timeout used for testing, where the clock never
-// advances.
-const reassembleTimeout = 1
-
-// vv is a helper to build VectorisedView from different strings.
-func vv(size int, pieces ...string) buffer.VectorisedView {
- views := make([]buffer.View, len(pieces))
- for i, p := range pieces {
- views[i] = []byte(p)
- }
-
- return buffer.NewVectorisedView(size, views)
-}
-
-func pkt(size int, pieces ...string) *stack.PacketBuffer {
- return stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv(size, pieces...),
- })
-}
-
-type processInput struct {
- id FragmentID
- first uint16
- last uint16
- more bool
- proto uint8
- pkt *stack.PacketBuffer
-}
-
-type processOutput struct {
- vv buffer.VectorisedView
- proto uint8
- done bool
-}
-
-var processTestCases = []struct {
- comment string
- in []processInput
- out []processOutput
-}{
- {
- comment: "One ID",
- in: []processInput{
- {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, pkt: pkt(2, "01")},
- {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, pkt: pkt(2, "23")},
- },
- out: []processOutput{
- {vv: buffer.VectorisedView{}, done: false},
- {vv: vv(4, "01", "23"), done: true},
- },
- },
- {
- comment: "Next Header protocol mismatch",
- in: []processInput{
- {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, proto: 6, pkt: pkt(2, "01")},
- {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, proto: 17, pkt: pkt(2, "23")},
- },
- out: []processOutput{
- {vv: buffer.VectorisedView{}, done: false},
- {vv: vv(4, "01", "23"), proto: 6, done: true},
- },
- },
- {
- comment: "Two IDs",
- in: []processInput{
- {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, pkt: pkt(2, "01")},
- {id: FragmentID{ID: 1}, first: 0, last: 1, more: true, pkt: pkt(2, "ab")},
- {id: FragmentID{ID: 1}, first: 2, last: 3, more: false, pkt: pkt(2, "cd")},
- {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, pkt: pkt(2, "23")},
- },
- out: []processOutput{
- {vv: buffer.VectorisedView{}, done: false},
- {vv: buffer.VectorisedView{}, done: false},
- {vv: vv(4, "ab", "cd"), done: true},
- {vv: vv(4, "01", "23"), done: true},
- },
- },
-}
-
-func TestFragmentationProcess(t *testing.T) {
- for _, c := range processTestCases {
- t.Run(c.comment, func(t *testing.T) {
- f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{}, nil)
- firstFragmentProto := c.in[0].proto
- for i, in := range c.in {
- resPkt, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.pkt)
- if err != nil {
- t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %#v) failed: %s",
- in.id, in.first, in.last, in.more, in.proto, in.pkt, err)
- }
- if done != c.out[i].done {
- t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, _, %t, _), want = (_, _, %t, _)",
- in.id, in.first, in.last, in.more, in.proto, done, c.out[i].done)
- }
- if c.out[i].done {
- if diff := cmp.Diff(c.out[i].vv.ToOwnedView(), resPkt.Data().AsRange().ToOwnedView()); diff != "" {
- t.Errorf("got Process(%+v, %d, %d, %t, %d, %#v) result mismatch (-want, +got):\n%s",
- in.id, in.first, in.last, in.more, in.proto, in.pkt, diff)
- }
- if firstFragmentProto != proto {
- t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, %d, _, _), want = (_, %d, _, _)",
- in.id, in.first, in.last, in.more, in.proto, proto, firstFragmentProto)
- }
- if _, ok := f.reassemblers[in.id]; ok {
- t.Errorf("Process(%d) did not remove buffer from reassemblers", i)
- }
- for n := f.rList.Front(); n != nil; n = n.Next() {
- if n.id == in.id {
- t.Errorf("Process(%d) did not remove buffer from rList", i)
- }
- }
- }
- }
- })
- }
-}
-
-func TestReassemblingTimeout(t *testing.T) {
- const (
- reassemblyTimeout = time.Millisecond
- protocol = 0xff
- )
-
- type fragment struct {
- first uint16
- last uint16
- more bool
- data string
- }
-
- type event struct {
- // name is a nickname of this event.
- name string
-
- // clockAdvance is a duration to advance the clock. The clock advances
- // before a fragment specified in the fragment field is processed.
- clockAdvance time.Duration
-
- // fragment is a fragment to process. This can be nil if there is no
- // fragment to process.
- fragment *fragment
-
- // expectDone is true if the fragmentation instance should report the
- // reassembly is done after the fragment is processd.
- expectDone bool
-
- // memSizeAfterEvent is the expected memory size of the fragmentation
- // instance after the event.
- memSizeAfterEvent int
- }
-
- memSizeOfFrags := func(frags ...*fragment) int {
- var size int
- for _, frag := range frags {
- size += pkt(len(frag.data), frag.data).MemSize()
- }
- return size
- }
-
- half1 := &fragment{first: 0, last: 0, more: true, data: "0"}
- half2 := &fragment{first: 1, last: 1, more: false, data: "1"}
-
- tests := []struct {
- name string
- events []event
- }{
- {
- name: "half1 and half2 are reassembled successfully",
- events: []event{
- {
- name: "half1",
- fragment: half1,
- expectDone: false,
- memSizeAfterEvent: memSizeOfFrags(half1),
- },
- {
- name: "half2",
- fragment: half2,
- expectDone: true,
- memSizeAfterEvent: 0,
- },
- },
- },
- {
- name: "half1 timeout, half2 timeout",
- events: []event{
- {
- name: "half1",
- fragment: half1,
- expectDone: false,
- memSizeAfterEvent: memSizeOfFrags(half1),
- },
- {
- name: "half1 just before reassembly timeout",
- clockAdvance: reassemblyTimeout - 1,
- memSizeAfterEvent: memSizeOfFrags(half1),
- },
- {
- name: "half1 reassembly timeout",
- clockAdvance: 1,
- memSizeAfterEvent: 0,
- },
- {
- name: "half2",
- fragment: half2,
- expectDone: false,
- memSizeAfterEvent: memSizeOfFrags(half2),
- },
- {
- name: "half2 just before reassembly timeout",
- clockAdvance: reassemblyTimeout - 1,
- memSizeAfterEvent: memSizeOfFrags(half2),
- },
- {
- name: "half2 reassembly timeout",
- clockAdvance: 1,
- memSizeAfterEvent: 0,
- },
- },
- },
- }
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- clock := faketime.NewManualClock()
- f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassemblyTimeout, clock, nil)
- for _, event := range test.events {
- clock.Advance(event.clockAdvance)
- if frag := event.fragment; frag != nil {
- _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, pkt(len(frag.data), frag.data))
- if err != nil {
- t.Fatalf("%s: f.Process failed: %s", event.name, err)
- }
- if done != event.expectDone {
- t.Fatalf("%s: got done = %t, want = %t", event.name, done, event.expectDone)
- }
- }
- if got, want := f.memSize, event.memSizeAfterEvent; got != want {
- t.Errorf("%s: got f.memSize = %d, want = %d", event.name, got, want)
- }
- }
- })
- }
-}
-
-func TestMemoryLimits(t *testing.T) {
- lowLimit := pkt(1, "0").MemSize()
- highLimit := 3 * lowLimit // Allow at most 3 such packets.
- f := NewFragmentation(minBlockSize, highLimit, lowLimit, reassembleTimeout, &faketime.NullClock{}, nil)
- // Send first fragment with id = 0.
- if _, _, _, err := f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil {
- t.Fatal(err)
- }
- // Send first fragment with id = 1.
- if _, _, _, err := f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, pkt(1, "1")); err != nil {
- t.Fatal(err)
- }
- // Send first fragment with id = 2.
- if _, _, _, err := f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, pkt(1, "2")); err != nil {
- t.Fatal(err)
- }
-
- // Send first fragment with id = 3. This should caused id = 0 and id = 1 to be
- // evicted.
- if _, _, _, err := f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, pkt(1, "3")); err != nil {
- t.Fatal(err)
- }
-
- if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok {
- t.Errorf("Memory limits are not respected: id=0 has not been evicted.")
- }
- if _, ok := f.reassemblers[FragmentID{ID: 1}]; ok {
- t.Errorf("Memory limits are not respected: id=1 has not been evicted.")
- }
- if _, ok := f.reassemblers[FragmentID{ID: 3}]; !ok {
- t.Errorf("Implementation of memory limits is wrong: id=3 is not present.")
- }
-}
-
-func TestMemoryLimitsIgnoresDuplicates(t *testing.T) {
- memSize := pkt(1, "0").MemSize()
- f := NewFragmentation(minBlockSize, memSize, 0, reassembleTimeout, &faketime.NullClock{}, nil)
- // Send first fragment with id = 0.
- if _, _, _, err := f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil {
- t.Fatal(err)
- }
- // Send the same packet again.
- if _, _, _, err := f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil {
- t.Fatal(err)
- }
-
- if got, want := f.memSize, memSize; got != want {
- t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want)
- }
-}
-
-func TestErrors(t *testing.T) {
- tests := []struct {
- name string
- blockSize uint16
- first uint16
- last uint16
- more bool
- data string
- err error
- }{
- {
- name: "exact block size without more",
- blockSize: 2,
- first: 2,
- last: 3,
- more: false,
- data: "01",
- },
- {
- name: "exact block size with more",
- blockSize: 2,
- first: 2,
- last: 3,
- more: true,
- data: "01",
- },
- {
- name: "exact block size with more and extra data",
- blockSize: 2,
- first: 2,
- last: 3,
- more: true,
- data: "012",
- err: ErrInvalidArgs,
- },
- {
- name: "exact block size with more and too little data",
- blockSize: 2,
- first: 2,
- last: 3,
- more: true,
- data: "0",
- err: ErrInvalidArgs,
- },
- {
- name: "not exact block size with more",
- blockSize: 2,
- first: 2,
- last: 2,
- more: true,
- data: "0",
- err: ErrInvalidArgs,
- },
- {
- name: "not exact block size without more",
- blockSize: 2,
- first: 2,
- last: 2,
- more: false,
- data: "0",
- },
- {
- name: "first not a multiple of block size",
- blockSize: 2,
- first: 3,
- last: 4,
- more: true,
- data: "01",
- err: ErrInvalidArgs,
- },
- {
- name: "first more than last",
- blockSize: 2,
- first: 4,
- last: 3,
- more: true,
- data: "01",
- err: ErrInvalidArgs,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}, nil)
- _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, pkt(len(test.data), test.data))
- if !errors.Is(err, test.err) {
- t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err)
- }
- if done {
- t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, true, _), want = (_, _, false, _)", test.first, test.last, test.more, test.data)
- }
- })
- }
-}
-
-type fragmentInfo struct {
- remaining int
- copied int
- offset int
- more bool
-}
-
-func TestPacketFragmenter(t *testing.T) {
- const (
- reserve = 60
- proto = 0
- )
-
- tests := []struct {
- name string
- fragmentPayloadLen uint32
- transportHeaderLen int
- payloadSize int
- wantFragments []fragmentInfo
- }{
- {
- name: "Packet exactly fits in MTU",
- fragmentPayloadLen: 1280,
- transportHeaderLen: 0,
- payloadSize: 1280,
- wantFragments: []fragmentInfo{
- {remaining: 0, copied: 1280, offset: 0, more: false},
- },
- },
- {
- name: "Packet exactly does not fit in MTU",
- fragmentPayloadLen: 1000,
- transportHeaderLen: 0,
- payloadSize: 1001,
- wantFragments: []fragmentInfo{
- {remaining: 1, copied: 1000, offset: 0, more: true},
- {remaining: 0, copied: 1, offset: 1000, more: false},
- },
- },
- {
- name: "Packet has a transport header",
- fragmentPayloadLen: 560,
- transportHeaderLen: 40,
- payloadSize: 560,
- wantFragments: []fragmentInfo{
- {remaining: 1, copied: 560, offset: 0, more: true},
- {remaining: 0, copied: 40, offset: 560, more: false},
- },
- },
- {
- name: "Packet has a huge transport header",
- fragmentPayloadLen: 500,
- transportHeaderLen: 1300,
- payloadSize: 500,
- wantFragments: []fragmentInfo{
- {remaining: 3, copied: 500, offset: 0, more: true},
- {remaining: 2, copied: 500, offset: 500, more: true},
- {remaining: 1, copied: 500, offset: 1000, more: true},
- {remaining: 0, copied: 300, offset: 1500, more: false},
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- pkt := testutil.MakeRandPkt(test.transportHeaderLen, reserve, []int{test.payloadSize}, proto)
- originalPayload := stack.PayloadSince(pkt.TransportHeader())
- var reassembledPayload buffer.VectorisedView
- pf := MakePacketFragmenter(pkt, test.fragmentPayloadLen, reserve)
- for i := 0; ; i++ {
- fragPkt, offset, copied, more := pf.BuildNextFragment()
- wantFragment := test.wantFragments[i]
- if got := pf.RemainingFragmentCount(); got != wantFragment.remaining {
- t.Errorf("(fragment #%d) got pf.RemainingFragmentCount() = %d, want = %d", i, got, wantFragment.remaining)
- }
- if copied != wantFragment.copied {
- t.Errorf("(fragment #%d) got copied = %d, want = %d", i, copied, wantFragment.copied)
- }
- if offset != wantFragment.offset {
- t.Errorf("(fragment #%d) got offset = %d, want = %d", i, offset, wantFragment.offset)
- }
- if more != wantFragment.more {
- t.Errorf("(fragment #%d) got more = %t, want = %t", i, more, wantFragment.more)
- }
- if got := uint32(fragPkt.Size()); got > test.fragmentPayloadLen {
- t.Errorf("(fragment #%d) got fragPkt.Size() = %d, want <= %d", i, got, test.fragmentPayloadLen)
- }
- if got := fragPkt.AvailableHeaderBytes(); got != reserve {
- t.Errorf("(fragment #%d) got fragPkt.AvailableHeaderBytes() = %d, want = %d", i, got, reserve)
- }
- if got := fragPkt.TransportHeader().View().Size(); got != 0 {
- t.Errorf("(fragment #%d) got fragPkt.TransportHeader().View().Size() = %d, want = 0", i, got)
- }
- reassembledPayload.AppendViews(fragPkt.Data().Views())
- if !more {
- if i != len(test.wantFragments)-1 {
- t.Errorf("got fragment count = %d, want = %d", i, len(test.wantFragments)-1)
- }
- break
- }
- }
- if diff := cmp.Diff(reassembledPayload.ToView(), originalPayload); diff != "" {
- t.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff)
- }
- })
- }
-}
-
-type testTimeoutHandler struct {
- pkt *stack.PacketBuffer
-}
-
-func (h *testTimeoutHandler) OnReassemblyTimeout(pkt *stack.PacketBuffer) {
- h.pkt = pkt
-}
-
-func TestTimeoutHandler(t *testing.T) {
- const (
- proto = 99
- )
-
- pk1 := pkt(1, "1")
- pk2 := pkt(1, "2")
-
- type processParam struct {
- first uint16
- last uint16
- more bool
- pkt *stack.PacketBuffer
- }
-
- tests := []struct {
- name string
- params []processParam
- wantError bool
- wantPkt *stack.PacketBuffer
- }{
- {
- name: "onTimeout runs",
- params: []processParam{
- {
- first: 0,
- last: 0,
- more: true,
- pkt: pk1,
- },
- },
- wantError: false,
- wantPkt: pk1,
- },
- {
- name: "no first fragment",
- params: []processParam{
- {
- first: 1,
- last: 1,
- more: true,
- pkt: pk1,
- },
- },
- wantError: false,
- wantPkt: nil,
- },
- {
- name: "second pkt is ignored",
- params: []processParam{
- {
- first: 0,
- last: 0,
- more: true,
- pkt: pk1,
- },
- {
- first: 0,
- last: 0,
- more: true,
- pkt: pk2,
- },
- },
- wantError: false,
- wantPkt: pk1,
- },
- {
- name: "invalid args - first is greater than last",
- params: []processParam{
- {
- first: 1,
- last: 0,
- more: true,
- pkt: pk1,
- },
- },
- wantError: true,
- wantPkt: nil,
- },
- }
-
- id := FragmentID{ID: 0}
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- handler := &testTimeoutHandler{pkt: nil}
-
- f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}, handler)
-
- for _, p := range test.params {
- if _, _, _, err := f.Process(id, p.first, p.last, p.more, proto, p.pkt); err != nil && !test.wantError {
- t.Errorf("f.Process error = %s", err)
- }
- }
- if !test.wantError {
- r, ok := f.reassemblers[id]
- if !ok {
- t.Fatal("Reassembler not found")
- }
- f.release(r, true)
- }
- switch {
- case handler.pkt != nil && test.wantPkt == nil:
- t.Errorf("got handler.pkt = not nil (pkt.Data = %x), want = nil", handler.pkt.Data().AsRange().ToOwnedView())
- case handler.pkt == nil && test.wantPkt != nil:
- t.Errorf("got handler.pkt = nil, want = not nil (pkt.Data = %x)", test.wantPkt.Data().AsRange().ToOwnedView())
- case handler.pkt != nil && test.wantPkt != nil:
- if diff := cmp.Diff(test.wantPkt.Data().AsRange().ToOwnedView(), handler.pkt.Data().AsRange().ToOwnedView()); diff != "" {
- t.Errorf("pkt.Data mismatch (-want, +got):\n%s", diff)
- }
- }
- })
- }
-}
diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler_list.go b/pkg/tcpip/network/internal/fragmentation/reassembler_list.go
new file mode 100644
index 000000000..673bb11b0
--- /dev/null
+++ b/pkg/tcpip/network/internal/fragmentation/reassembler_list.go
@@ -0,0 +1,221 @@
+package fragmentation
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type reassemblerElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (reassemblerElementMapper) linkerFor(elem *reassembler) *reassembler { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type reassemblerList struct {
+ head *reassembler
+ tail *reassembler
+}
+
+// Reset resets list l to the empty state.
+func (l *reassemblerList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+//
+//go:nosplit
+func (l *reassemblerList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+//
+//go:nosplit
+func (l *reassemblerList) Front() *reassembler {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+//
+//go:nosplit
+func (l *reassemblerList) Back() *reassembler {
+ return l.tail
+}
+
+// Len returns the number of elements in the list.
+//
+// NOTE: This is an O(n) operation.
+//
+//go:nosplit
+func (l *reassemblerList) Len() (count int) {
+ for e := l.Front(); e != nil; e = (reassemblerElementMapper{}.linkerFor(e)).Next() {
+ count++
+ }
+ return count
+}
+
+// PushFront inserts the element e at the front of list l.
+//
+//go:nosplit
+func (l *reassemblerList) PushFront(e *reassembler) {
+ linker := reassemblerElementMapper{}.linkerFor(e)
+ linker.SetNext(l.head)
+ linker.SetPrev(nil)
+ if l.head != nil {
+ reassemblerElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+//
+//go:nosplit
+func (l *reassemblerList) PushBack(e *reassembler) {
+ linker := reassemblerElementMapper{}.linkerFor(e)
+ linker.SetNext(nil)
+ linker.SetPrev(l.tail)
+ if l.tail != nil {
+ reassemblerElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+//
+//go:nosplit
+func (l *reassemblerList) PushBackList(m *reassemblerList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ reassemblerElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ reassemblerElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+//
+//go:nosplit
+func (l *reassemblerList) InsertAfter(b, e *reassembler) {
+ bLinker := reassemblerElementMapper{}.linkerFor(b)
+ eLinker := reassemblerElementMapper{}.linkerFor(e)
+
+ a := bLinker.Next()
+
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ bLinker.SetNext(e)
+
+ if a != nil {
+ reassemblerElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+//
+//go:nosplit
+func (l *reassemblerList) InsertBefore(a, e *reassembler) {
+ aLinker := reassemblerElementMapper{}.linkerFor(a)
+ eLinker := reassemblerElementMapper{}.linkerFor(e)
+
+ b := aLinker.Prev()
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ aLinker.SetPrev(e)
+
+ if b != nil {
+ reassemblerElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+//
+//go:nosplit
+func (l *reassemblerList) Remove(e *reassembler) {
+ linker := reassemblerElementMapper{}.linkerFor(e)
+ prev := linker.Prev()
+ next := linker.Next()
+
+ if prev != nil {
+ reassemblerElementMapper{}.linkerFor(prev).SetNext(next)
+ } else if l.head == e {
+ l.head = next
+ }
+
+ if next != nil {
+ reassemblerElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else if l.tail == e {
+ l.tail = prev
+ }
+
+ linker.SetNext(nil)
+ linker.SetPrev(nil)
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type reassemblerEntry struct {
+ next *reassembler
+ prev *reassembler
+}
+
+// Next returns the entry that follows e in the list.
+//
+//go:nosplit
+func (e *reassemblerEntry) Next() *reassembler {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *reassemblerEntry) Prev() *reassembler {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+//
+//go:nosplit
+func (e *reassemblerEntry) SetNext(elem *reassembler) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *reassemblerEntry) SetPrev(elem *reassembler) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler_test.go b/pkg/tcpip/network/internal/fragmentation/reassembler_test.go
deleted file mode 100644
index cfd9f00ef..000000000
--- a/pkg/tcpip/network/internal/fragmentation/reassembler_test.go
+++ /dev/null
@@ -1,233 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package fragmentation
-
-import (
- "bytes"
- "math"
- "testing"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
-)
-
-type processParams struct {
- first uint16
- last uint16
- more bool
- pkt *stack.PacketBuffer
- wantDone bool
- wantError error
-}
-
-func TestReassemblerProcess(t *testing.T) {
- const proto = 99
-
- v := func(size int) buffer.View {
- payload := buffer.NewView(size)
- for i := 1; i < size; i++ {
- payload[i] = uint8(i) * 3
- }
- return payload
- }
-
- pkt := func(sizes ...int) *stack.PacketBuffer {
- var vv buffer.VectorisedView
- for _, size := range sizes {
- vv.AppendView(v(size))
- }
- return stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- })
- }
-
- var tests = []struct {
- name string
- params []processParams
- want []hole
- wantPkt *stack.PacketBuffer
- }{
- {
- name: "No fragments",
- params: nil,
- want: []hole{{first: 0, last: math.MaxUint16, filled: false, final: true}},
- },
- {
- name: "One fragment at beginning",
- params: []processParams{{first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}},
- want: []hole{
- {first: 0, last: 1, filled: true, final: false, pkt: pkt(2)},
- {first: 2, last: math.MaxUint16, filled: false, final: true},
- },
- },
- {
- name: "One fragment in the middle",
- params: []processParams{{first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil}},
- want: []hole{
- {first: 1, last: 2, filled: true, final: false, pkt: pkt(2)},
- {first: 0, last: 0, filled: false, final: false},
- {first: 3, last: math.MaxUint16, filled: false, final: true},
- },
- },
- {
- name: "One fragment at the end",
- params: []processParams{{first: 1, last: 2, more: false, pkt: pkt(2), wantDone: false, wantError: nil}},
- want: []hole{
- {first: 1, last: 2, filled: true, final: true, pkt: pkt(2)},
- {first: 0, last: 0, filled: false},
- },
- },
- {
- name: "One fragment completing a packet",
- params: []processParams{{first: 0, last: 1, more: false, pkt: pkt(2), wantDone: true, wantError: nil}},
- want: []hole{
- {first: 0, last: 1, filled: true, final: true},
- },
- wantPkt: pkt(2),
- },
- {
- name: "Two fragments completing a packet",
- params: []processParams{
- {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil},
- {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil},
- },
- want: []hole{
- {first: 0, last: 1, filled: true, final: false},
- {first: 2, last: 3, filled: true, final: true},
- },
- wantPkt: pkt(2, 2),
- },
- {
- name: "Two fragments completing a packet with a duplicate",
- params: []processParams{
- {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil},
- {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil},
- {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil},
- },
- want: []hole{
- {first: 0, last: 1, filled: true, final: false},
- {first: 2, last: 3, filled: true, final: true},
- },
- wantPkt: pkt(2, 2),
- },
- {
- name: "Two fragments completing a packet with a partial duplicate",
- params: []processParams{
- {first: 0, last: 3, more: true, pkt: pkt(4), wantDone: false, wantError: nil},
- {first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil},
- {first: 4, last: 5, more: false, pkt: pkt(2), wantDone: true, wantError: nil},
- },
- want: []hole{
- {first: 0, last: 3, filled: true, final: false},
- {first: 4, last: 5, filled: true, final: true},
- },
- wantPkt: pkt(4, 2),
- },
- {
- name: "Two overlapping fragments",
- params: []processParams{
- {first: 0, last: 10, more: true, pkt: pkt(11), wantDone: false, wantError: nil},
- {first: 5, last: 15, more: false, pkt: pkt(11), wantDone: false, wantError: ErrFragmentOverlap},
- },
- want: []hole{
- {first: 0, last: 10, filled: true, final: false, pkt: pkt(11)},
- {first: 11, last: math.MaxUint16, filled: false, final: true},
- },
- },
- {
- name: "Two final fragments with different ends",
- params: []processParams{
- {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil},
- {first: 0, last: 9, more: false, pkt: pkt(10), wantDone: false, wantError: ErrFragmentConflict},
- },
- want: []hole{
- {first: 10, last: 14, filled: true, final: true, pkt: pkt(5)},
- {first: 0, last: 9, filled: false, final: false},
- },
- },
- {
- name: "Two final fragments - duplicate",
- params: []processParams{
- {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil},
- {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil},
- },
- want: []hole{
- {first: 5, last: 14, filled: true, final: true, pkt: pkt(10)},
- {first: 0, last: 4, filled: false, final: false},
- },
- },
- {
- name: "Two final fragments - duplicate, with different ends",
- params: []processParams{
- {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil},
- {first: 10, last: 13, more: false, pkt: pkt(4), wantDone: false, wantError: ErrFragmentConflict},
- },
- want: []hole{
- {first: 5, last: 14, filled: true, final: true, pkt: pkt(10)},
- {first: 0, last: 4, filled: false, final: false},
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- r := newReassembler(FragmentID{}, &faketime.NullClock{})
- var resPkt *stack.PacketBuffer
- var isDone bool
- for _, param := range test.params {
- pkt, _, done, _, err := r.process(param.first, param.last, param.more, proto, param.pkt)
- if done != param.wantDone || err != param.wantError {
- t.Errorf("got r.process(%d, %d, %t, %d, _) = (_, _, %t, _, %v), want = (%t, %v)", param.first, param.last, param.more, proto, done, err, param.wantDone, param.wantError)
- }
- if done {
- resPkt = pkt
- isDone = true
- }
- }
-
- ignorePkt := func(a, b *stack.PacketBuffer) bool { return true }
- cmpPktData := func(a, b *stack.PacketBuffer) bool {
- if a == nil || b == nil {
- return a == b
- }
- return bytes.Equal(a.Data().AsRange().ToOwnedView(), b.Data().AsRange().ToOwnedView())
- }
-
- if isDone {
- if diff := cmp.Diff(
- test.want, r.holes,
- cmp.AllowUnexported(hole{}),
- // Do not compare pkt in hole. Data will be altered.
- cmp.Comparer(ignorePkt),
- ); diff != "" {
- t.Errorf("r.holes mismatch (-want +got):\n%s", diff)
- }
- if diff := cmp.Diff(test.wantPkt, resPkt, cmp.Comparer(cmpPktData)); diff != "" {
- t.Errorf("Reassembled pkt mismatch (-want +got):\n%s", diff)
- }
- } else {
- if diff := cmp.Diff(
- test.want, r.holes,
- cmp.AllowUnexported(hole{}),
- cmp.Comparer(cmpPktData),
- ); diff != "" {
- t.Errorf("r.holes mismatch (-want +got):\n%s", diff)
- }
- }
- })
- }
-}
diff --git a/pkg/tcpip/network/internal/ip/BUILD b/pkg/tcpip/network/internal/ip/BUILD
deleted file mode 100644
index fd944ce99..000000000
--- a/pkg/tcpip/network/internal/ip/BUILD
+++ /dev/null
@@ -1,40 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "ip",
- srcs = [
- "duplicate_address_detection.go",
- "errors.go",
- "generic_multicast_protocol.go",
- "stats.go",
- ],
- visibility = [
- "//pkg/tcpip/network/arp:__pkg__",
- "//pkg/tcpip/network/ipv4:__pkg__",
- "//pkg/tcpip/network/ipv6:__pkg__",
- ],
- deps = [
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/stack",
- ],
-)
-
-go_test(
- name = "ip_x_test",
- size = "small",
- srcs = [
- "duplicate_address_detection_test.go",
- "generic_multicast_protocol_test.go",
- ],
- deps = [
- ":ip",
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/faketime",
- "//pkg/tcpip/stack",
- "@com_github_google_go_cmp//cmp:go_default_library",
- ],
-)
diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go
deleted file mode 100644
index 24687cf06..000000000
--- a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go
+++ /dev/null
@@ -1,381 +0,0 @@
-// Copyright 2021 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package ip_test
-
-import (
- "bytes"
- "testing"
- "time"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
-)
-
-type mockDADProtocol struct {
- t *testing.T
-
- mu struct {
- sync.Mutex
-
- dad ip.DAD
- sentNonces map[tcpip.Address][][]byte
- }
-}
-
-func (m *mockDADProtocol) init(t *testing.T, c stack.DADConfigurations, opts ip.DADOptions) {
- m.mu.Lock()
- defer m.mu.Unlock()
-
- m.t = t
- opts.Protocol = m
- m.mu.dad.Init(&m.mu, c, opts)
- m.initLocked()
-}
-
-func (m *mockDADProtocol) initLocked() {
- m.mu.sentNonces = make(map[tcpip.Address][][]byte)
-}
-
-func (m *mockDADProtocol) SendDADMessage(addr tcpip.Address, nonce []byte) tcpip.Error {
- m.mu.Lock()
- defer m.mu.Unlock()
- m.mu.sentNonces[addr] = append(m.mu.sentNonces[addr], nonce)
- return nil
-}
-
-func (m *mockDADProtocol) check(addrs []tcpip.Address) string {
- sentNonces := make(map[tcpip.Address][][]byte)
- for _, a := range addrs {
- sentNonces[a] = append(sentNonces[a], nil)
- }
-
- return m.checkWithNonce(sentNonces)
-}
-
-func (m *mockDADProtocol) checkWithNonce(expectedSentNonces map[tcpip.Address][][]byte) string {
- m.mu.Lock()
- defer m.mu.Unlock()
-
- diff := cmp.Diff(expectedSentNonces, m.mu.sentNonces)
- m.initLocked()
- return diff
-}
-
-func (m *mockDADProtocol) checkDuplicateAddress(addr tcpip.Address, h stack.DADCompletionHandler) stack.DADCheckAddressDisposition {
- m.mu.Lock()
- defer m.mu.Unlock()
- return m.mu.dad.CheckDuplicateAddressLocked(addr, h)
-}
-
-func (m *mockDADProtocol) stop(addr tcpip.Address, reason stack.DADResult) {
- m.mu.Lock()
- defer m.mu.Unlock()
- m.mu.dad.StopLocked(addr, reason)
-}
-
-func (m *mockDADProtocol) extendIfNonceEqual(addr tcpip.Address, nonce []byte) ip.ExtendIfNonceEqualLockedDisposition {
- m.mu.Lock()
- defer m.mu.Unlock()
- return m.mu.dad.ExtendIfNonceEqualLocked(addr, nonce)
-}
-
-func (m *mockDADProtocol) setConfigs(c stack.DADConfigurations) {
- m.mu.Lock()
- defer m.mu.Unlock()
- m.mu.dad.SetConfigsLocked(c)
-}
-
-const (
- addr1 = tcpip.Address("\x01")
- addr2 = tcpip.Address("\x02")
- addr3 = tcpip.Address("\x03")
- addr4 = tcpip.Address("\x04")
-)
-
-type dadResult struct {
- Addr tcpip.Address
- R stack.DADResult
-}
-
-func handler(ch chan<- dadResult, a tcpip.Address) func(stack.DADResult) {
- return func(r stack.DADResult) {
- ch <- dadResult{Addr: a, R: r}
- }
-}
-
-func TestDADCheckDuplicateAddress(t *testing.T) {
- var dad mockDADProtocol
- clock := faketime.NewManualClock()
- dad.init(t, stack.DADConfigurations{}, ip.DADOptions{
- Clock: clock,
- })
-
- ch := make(chan dadResult, 2)
-
- // DAD should initially be disabled.
- if res := dad.checkDuplicateAddress(addr1, handler(nil, "")); res != stack.DADDisabled {
- t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADDisabled)
- }
- // Wait for any initially fired timers to complete.
- clock.RunImmediatelyScheduledJobs()
- if diff := dad.check(nil); diff != "" {
- t.Errorf("dad check mismatch (-want +got):\n%s", diff)
- }
-
- // Enable and request DAD.
- dadConfigs1 := stack.DADConfigurations{
- DupAddrDetectTransmits: 1,
- RetransmitTimer: time.Second,
- }
- dad.setConfigs(dadConfigs1)
- if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting {
- t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting)
- }
- clock.RunImmediatelyScheduledJobs()
- if diff := dad.check([]tcpip.Address{addr1}); diff != "" {
- t.Errorf("dad check mismatch (-want +got):\n%s", diff)
- }
- // The second request for DAD on the same address should use the original
- // request since it has not completed yet.
- if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADAlreadyRunning {
- t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADAlreadyRunning)
- }
- clock.RunImmediatelyScheduledJobs()
- if diff := dad.check(nil); diff != "" {
- t.Errorf("dad check mismatch (-want +got):\n%s", diff)
- }
-
- dadConfigs2 := stack.DADConfigurations{
- DupAddrDetectTransmits: 2,
- RetransmitTimer: time.Second,
- }
- dad.setConfigs(dadConfigs2)
- // A new address should start a new DAD process.
- if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting {
- t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting)
- }
- clock.RunImmediatelyScheduledJobs()
- if diff := dad.check([]tcpip.Address{addr2}); diff != "" {
- t.Errorf("dad check mismatch (-want +got):\n%s", diff)
- }
-
- // Make sure DAD for addr1 only resolves after the expected timeout.
- const delta = time.Nanosecond
- dadConfig1Duration := time.Duration(dadConfigs1.DupAddrDetectTransmits) * dadConfigs1.RetransmitTimer
- clock.Advance(dadConfig1Duration - delta)
- select {
- case r := <-ch:
- t.Fatalf("unexpectedly got a DAD result before the expected timeout of %s; r = %#v", dadConfig1Duration, r)
- default:
- }
- clock.Advance(delta)
- for i := 0; i < 2; i++ {
- if diff := cmp.Diff(dadResult{Addr: addr1, R: &stack.DADSucceeded{}}, <-ch); diff != "" {
- t.Errorf("(i=%d) dad result mismatch (-want +got):\n%s", i, diff)
- }
- }
-
- // Make sure DAD for addr2 only resolves after the expected timeout.
- dadConfig2Duration := time.Duration(dadConfigs2.DupAddrDetectTransmits) * dadConfigs2.RetransmitTimer
- clock.Advance(dadConfig2Duration - dadConfig1Duration - delta)
- select {
- case r := <-ch:
- t.Fatalf("unexpectedly got a DAD result before the expected timeout of %s; r = %#v", dadConfig2Duration, r)
- default:
- }
- clock.Advance(delta)
- if diff := cmp.Diff(dadResult{Addr: addr2, R: &stack.DADSucceeded{}}, <-ch); diff != "" {
- t.Errorf("dad result mismatch (-want +got):\n%s", diff)
- }
-
- // Should be able to restart DAD for addr2 after it resolved.
- if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting {
- t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting)
- }
- clock.RunImmediatelyScheduledJobs()
- if diff := dad.check([]tcpip.Address{addr2, addr2}); diff != "" {
- t.Errorf("dad check mismatch (-want +got):\n%s", diff)
- }
- clock.Advance(dadConfig2Duration)
- if diff := cmp.Diff(dadResult{Addr: addr2, R: &stack.DADSucceeded{}}, <-ch); diff != "" {
- t.Errorf("dad result mismatch (-want +got):\n%s", diff)
- }
-
- // Should not have anymore results.
- select {
- case r := <-ch:
- t.Fatalf("unexpectedly got an extra DAD result; r = %#v", r)
- default:
- }
-}
-
-func TestDADStop(t *testing.T) {
- var dad mockDADProtocol
- clock := faketime.NewManualClock()
- dadConfigs := stack.DADConfigurations{
- DupAddrDetectTransmits: 1,
- RetransmitTimer: time.Second,
- }
- dad.init(t, dadConfigs, ip.DADOptions{
- Clock: clock,
- })
-
- ch := make(chan dadResult, 1)
-
- if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting {
- t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting)
- }
- if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting {
- t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting)
- }
- if res := dad.checkDuplicateAddress(addr3, handler(ch, addr3)); res != stack.DADStarting {
- t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting)
- }
- clock.RunImmediatelyScheduledJobs()
- if diff := dad.check([]tcpip.Address{addr1, addr2, addr3}); diff != "" {
- t.Errorf("dad check mismatch (-want +got):\n%s", diff)
- }
-
- dad.stop(addr1, &stack.DADAborted{})
- if diff := cmp.Diff(dadResult{Addr: addr1, R: &stack.DADAborted{}}, <-ch); diff != "" {
- t.Errorf("dad result mismatch (-want +got):\n%s", diff)
- }
-
- dad.stop(addr2, &stack.DADDupAddrDetected{})
- if diff := cmp.Diff(dadResult{Addr: addr2, R: &stack.DADDupAddrDetected{}}, <-ch); diff != "" {
- t.Errorf("dad result mismatch (-want +got):\n%s", diff)
- }
-
- dadResolutionDuration := time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer
- clock.Advance(dadResolutionDuration)
- if diff := cmp.Diff(dadResult{Addr: addr3, R: &stack.DADSucceeded{}}, <-ch); diff != "" {
- t.Errorf("dad result mismatch (-want +got):\n%s", diff)
- }
-
- // Should be able to restart DAD for an address we stopped DAD on.
- if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting {
- t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting)
- }
- clock.RunImmediatelyScheduledJobs()
- if diff := dad.check([]tcpip.Address{addr1}); diff != "" {
- t.Errorf("dad check mismatch (-want +got):\n%s", diff)
- }
- clock.Advance(dadResolutionDuration)
- if diff := cmp.Diff(dadResult{Addr: addr1, R: &stack.DADSucceeded{}}, <-ch); diff != "" {
- t.Errorf("dad result mismatch (-want +got):\n%s", diff)
- }
-
- // Should not have anymore updates.
- select {
- case r := <-ch:
- t.Fatalf("unexpectedly got an extra DAD result; r = %#v", r)
- default:
- }
-}
-
-func TestNonce(t *testing.T) {
- const (
- nonceSize = 2
-
- extendRequestAttempts = 2
-
- dupAddrDetectTransmits = 2
- extendTransmits = 5
- )
-
- var secureRNGBytes [nonceSize * (dupAddrDetectTransmits + extendTransmits)]byte
- for i := range secureRNGBytes {
- secureRNGBytes[i] = byte(i)
- }
-
- tests := []struct {
- name string
- mockedReceivedNonce []byte
- expectedResults [extendRequestAttempts]ip.ExtendIfNonceEqualLockedDisposition
- expectedTransmits int
- }{
- {
- name: "not matching",
- mockedReceivedNonce: []byte{0, 0},
- expectedResults: [extendRequestAttempts]ip.ExtendIfNonceEqualLockedDisposition{ip.NonceNotEqual, ip.NonceNotEqual},
- expectedTransmits: dupAddrDetectTransmits,
- },
- {
- name: "matching nonce",
- mockedReceivedNonce: secureRNGBytes[:nonceSize],
- expectedResults: [extendRequestAttempts]ip.ExtendIfNonceEqualLockedDisposition{ip.Extended, ip.AlreadyExtended},
- expectedTransmits: dupAddrDetectTransmits + extendTransmits,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- var dad mockDADProtocol
- clock := faketime.NewManualClock()
- dadConfigs := stack.DADConfigurations{
- DupAddrDetectTransmits: dupAddrDetectTransmits,
- RetransmitTimer: time.Second,
- }
-
- var secureRNG bytes.Reader
- secureRNG.Reset(secureRNGBytes[:])
- dad.init(t, dadConfigs, ip.DADOptions{
- Clock: clock,
- SecureRNG: &secureRNG,
- NonceSize: nonceSize,
- ExtendDADTransmits: extendTransmits,
- })
-
- ch := make(chan dadResult, 1)
- if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting {
- t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting)
- }
-
- clock.RunImmediatelyScheduledJobs()
- for i, want := range test.expectedResults {
- if got := dad.extendIfNonceEqual(addr1, test.mockedReceivedNonce); got != want {
- t.Errorf("(i=%d) got dad.extendIfNonceEqual(%s, _) = %d, want = %d", i, addr1, got, want)
- }
- }
-
- for i := 0; i < test.expectedTransmits; i++ {
- if diff := dad.checkWithNonce(map[tcpip.Address][][]byte{
- addr1: {
- secureRNGBytes[nonceSize*i:][:nonceSize],
- },
- }); diff != "" {
- t.Errorf("(i=%d) dad check mismatch (-want +got):\n%s", i, diff)
- }
-
- clock.Advance(dadConfigs.RetransmitTimer)
- }
-
- if diff := cmp.Diff(dadResult{Addr: addr1, R: &stack.DADSucceeded{}}, <-ch); diff != "" {
- t.Errorf("dad result mismatch (-want +got):\n%s", diff)
- }
-
- // Should not have anymore updates.
- select {
- case r := <-ch:
- t.Fatalf("unexpectedly got an extra DAD result; r = %#v", r)
- default:
- }
- })
- }
-}
diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go
deleted file mode 100644
index 1261ad414..000000000
--- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go
+++ /dev/null
@@ -1,808 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package ip_test
-
-import (
- "math/rand"
- "testing"
- "time"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip"
-)
-
-const maxUnsolicitedReportDelay = time.Second
-
-var _ ip.MulticastGroupProtocol = (*mockMulticastGroupProtocol)(nil)
-
-type mockMulticastGroupProtocolProtectedFields struct {
- sync.RWMutex
-
- genericMulticastGroup ip.GenericMulticastProtocolState
- sendReportGroupAddrCount map[tcpip.Address]int
- sendLeaveGroupAddrCount map[tcpip.Address]int
- makeQueuePackets bool
- disabled bool
-}
-
-type mockMulticastGroupProtocol struct {
- t *testing.T
-
- skipProtocolAddress tcpip.Address
-
- mu mockMulticastGroupProtocolProtectedFields
-}
-
-func (m *mockMulticastGroupProtocol) init(opts ip.GenericMulticastProtocolOptions) {
- m.mu.Lock()
- defer m.mu.Unlock()
- m.initLocked()
- opts.Protocol = m
- m.mu.genericMulticastGroup.Init(&m.mu.RWMutex, opts)
-}
-
-func (m *mockMulticastGroupProtocol) initLocked() {
- m.mu.sendReportGroupAddrCount = make(map[tcpip.Address]int)
- m.mu.sendLeaveGroupAddrCount = make(map[tcpip.Address]int)
-}
-
-func (m *mockMulticastGroupProtocol) setEnabled(v bool) {
- m.mu.Lock()
- defer m.mu.Unlock()
- m.mu.disabled = !v
-}
-
-func (m *mockMulticastGroupProtocol) setQueuePackets(v bool) {
- m.mu.Lock()
- defer m.mu.Unlock()
- m.mu.makeQueuePackets = v
-}
-
-func (m *mockMulticastGroupProtocol) joinGroup(addr tcpip.Address) {
- m.mu.Lock()
- defer m.mu.Unlock()
- m.mu.genericMulticastGroup.JoinGroupLocked(addr)
-}
-
-func (m *mockMulticastGroupProtocol) leaveGroup(addr tcpip.Address) bool {
- m.mu.Lock()
- defer m.mu.Unlock()
- return m.mu.genericMulticastGroup.LeaveGroupLocked(addr)
-}
-
-func (m *mockMulticastGroupProtocol) handleReport(addr tcpip.Address) {
- m.mu.Lock()
- defer m.mu.Unlock()
- m.mu.genericMulticastGroup.HandleReportLocked(addr)
-}
-
-func (m *mockMulticastGroupProtocol) handleQuery(addr tcpip.Address, maxRespTime time.Duration) {
- m.mu.Lock()
- defer m.mu.Unlock()
- m.mu.genericMulticastGroup.HandleQueryLocked(addr, maxRespTime)
-}
-
-func (m *mockMulticastGroupProtocol) isLocallyJoined(addr tcpip.Address) bool {
- m.mu.RLock()
- defer m.mu.RUnlock()
- return m.mu.genericMulticastGroup.IsLocallyJoinedRLocked(addr)
-}
-
-func (m *mockMulticastGroupProtocol) makeAllNonMember() {
- m.mu.Lock()
- defer m.mu.Unlock()
- m.mu.genericMulticastGroup.MakeAllNonMemberLocked()
-}
-
-func (m *mockMulticastGroupProtocol) initializeGroups() {
- m.mu.Lock()
- defer m.mu.Unlock()
- m.mu.genericMulticastGroup.InitializeGroupsLocked()
-}
-
-func (m *mockMulticastGroupProtocol) sendQueuedReports() {
- m.mu.Lock()
- defer m.mu.Unlock()
- m.mu.genericMulticastGroup.SendQueuedReportsLocked()
-}
-
-// Enabled implements ip.MulticastGroupProtocol.
-//
-// Precondition: m.mu must be read locked.
-func (m *mockMulticastGroupProtocol) Enabled() bool {
- if m.mu.TryLock() {
- m.mu.Unlock() // +checklocksforce: TryLock.
- m.t.Fatal("got write lock, expected to not take the lock; generic multicast protocol must take the read or write lock before calling Enabled")
- }
-
- return !m.mu.disabled
-}
-
-// SendReport implements ip.MulticastGroupProtocol.
-//
-// Precondition: m.mu must be locked.
-func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) {
- if m.mu.TryLock() {
- m.mu.Unlock() // +checklocksforce: TryLock.
- m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress)
- }
- if m.mu.TryRLock() {
- m.mu.RUnlock() // +checklocksforce: TryLock.
- m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress)
- }
-
- m.mu.sendReportGroupAddrCount[groupAddress]++
- return !m.mu.makeQueuePackets, nil
-}
-
-// SendLeave implements ip.MulticastGroupProtocol.
-//
-// Precondition: m.mu must be locked.
-func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) tcpip.Error {
- if m.mu.TryLock() {
- m.mu.Unlock() // +checklocksforce: TryLock.
- m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress)
- }
- if m.mu.TryRLock() {
- m.mu.RUnlock() // +checklocksforce: TryLock.
- m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress)
- }
-
- m.mu.sendLeaveGroupAddrCount[groupAddress]++
- return nil
-}
-
-// ShouldPerformProtocol implements ip.MulticastGroupProtocol.
-func (m *mockMulticastGroupProtocol) ShouldPerformProtocol(groupAddress tcpip.Address) bool {
- return groupAddress != m.skipProtocolAddress
-}
-
-func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string {
- m.mu.Lock()
- defer m.mu.Unlock()
-
- sendReportGroupAddrCount := make(map[tcpip.Address]int)
- for _, a := range sendReportGroupAddresses {
- sendReportGroupAddrCount[a] = 1
- }
-
- sendLeaveGroupAddrCount := make(map[tcpip.Address]int)
- for _, a := range sendLeaveGroupAddresses {
- sendLeaveGroupAddrCount[a] = 1
- }
-
- diff := cmp.Diff(
- &mockMulticastGroupProtocol{
- mu: mockMulticastGroupProtocolProtectedFields{
- sendReportGroupAddrCount: sendReportGroupAddrCount,
- sendLeaveGroupAddrCount: sendLeaveGroupAddrCount,
- },
- },
- m,
- cmp.AllowUnexported(mockMulticastGroupProtocol{}),
- cmp.AllowUnexported(mockMulticastGroupProtocolProtectedFields{}),
- // ignore mockMulticastGroupProtocol.mu and mockMulticastGroupProtocol.t
- cmp.FilterPath(
- func(p cmp.Path) bool {
- switch p.Last().String() {
- case ".RWMutex", ".t", ".makeQueuePackets", ".disabled", ".genericMulticastGroup", ".skipProtocolAddress":
- return true
- default:
- return false
- }
- },
- cmp.Ignore(),
- ),
- )
- m.initLocked()
- return diff
-}
-
-func TestJoinGroup(t *testing.T) {
- tests := []struct {
- name string
- addr tcpip.Address
- shouldSendReports bool
- }{
- {
- name: "Normal group",
- addr: addr1,
- shouldSendReports: true,
- },
- {
- name: "All-nodes group",
- addr: addr2,
- shouldSendReports: false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr2}
- clock := faketime.NewManualClock()
-
- mgp.init(ip.GenericMulticastProtocolOptions{
- Rand: rand.New(rand.NewSource(0)),
- Clock: clock,
- MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
- })
-
- // Joining a group should send a report immediately and another after
- // a random interval between 0 and the maximum unsolicited report delay.
- mgp.joinGroup(test.addr)
- if test.shouldSendReports {
- if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // Generic multicast protocol timers are expected to take the job mutex.
- clock.Advance(maxUnsolicitedReportDelay)
- if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- }
-
- // Should have no more messages to send.
- clock.Advance(time.Hour)
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- })
- }
-}
-
-func TestLeaveGroup(t *testing.T) {
- tests := []struct {
- name string
- addr tcpip.Address
- shouldSendMessages bool
- }{
- {
- name: "Normal group",
- addr: addr1,
- shouldSendMessages: true,
- },
- {
- name: "All-nodes group",
- addr: addr2,
- shouldSendMessages: false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr2}
- clock := faketime.NewManualClock()
-
- mgp.init(ip.GenericMulticastProtocolOptions{
- Rand: rand.New(rand.NewSource(1)),
- Clock: clock,
- MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
- })
-
- mgp.joinGroup(test.addr)
- if test.shouldSendMessages {
- if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- }
-
- // Leaving a group should send a leave report immediately and cancel any
- // delayed reports.
- {
-
- if !mgp.leaveGroup(test.addr) {
- t.Fatalf("got mgp.leaveGroup(%s) = false, want = true", test.addr)
- }
- }
- if test.shouldSendMessages {
- if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{test.addr} /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- }
-
- // Should have no more messages to send.
- //
- // Generic multicast protocol timers are expected to take the job mutex.
- clock.Advance(time.Hour)
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- })
- }
-}
-
-func TestHandleReport(t *testing.T) {
- tests := []struct {
- name string
- reportAddr tcpip.Address
- expectReportsFor []tcpip.Address
- }{
- {
- name: "Unpecified empty",
- reportAddr: "",
- expectReportsFor: []tcpip.Address{addr1, addr2},
- },
- {
- name: "Unpecified any",
- reportAddr: "\x00",
- expectReportsFor: []tcpip.Address{addr1, addr2},
- },
- {
- name: "Specified",
- reportAddr: addr1,
- expectReportsFor: []tcpip.Address{addr2},
- },
- {
- name: "Specified all-nodes",
- reportAddr: addr3,
- expectReportsFor: []tcpip.Address{addr1, addr2},
- },
- {
- name: "Specified other",
- reportAddr: addr4,
- expectReportsFor: []tcpip.Address{addr1, addr2},
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3}
- clock := faketime.NewManualClock()
-
- mgp.init(ip.GenericMulticastProtocolOptions{
- Rand: rand.New(rand.NewSource(2)),
- Clock: clock,
- MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
- })
-
- mgp.joinGroup(addr1)
- if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- mgp.joinGroup(addr2)
- if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- mgp.joinGroup(addr3)
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // Receiving a report for a group we have a timer scheduled for should
- // cancel our delayed report timer for the group.
- mgp.handleReport(test.reportAddr)
- if len(test.expectReportsFor) != 0 {
- // Generic multicast protocol timers are expected to take the job mutex.
- clock.Advance(maxUnsolicitedReportDelay)
- if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- }
-
- // Should have no more messages to send.
- clock.Advance(time.Hour)
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- })
- }
-}
-
-func TestHandleQuery(t *testing.T) {
- tests := []struct {
- name string
- queryAddr tcpip.Address
- maxDelay time.Duration
- expectQueriedReportsFor []tcpip.Address
- expectDelayedReportsFor []tcpip.Address
- }{
- {
- name: "Unpecified empty",
- queryAddr: "",
- maxDelay: 0,
- expectQueriedReportsFor: []tcpip.Address{addr1, addr2},
- expectDelayedReportsFor: nil,
- },
- {
- name: "Unpecified any",
- queryAddr: "\x00",
- maxDelay: 1,
- expectQueriedReportsFor: []tcpip.Address{addr1, addr2},
- expectDelayedReportsFor: nil,
- },
- {
- name: "Specified",
- queryAddr: addr1,
- maxDelay: 2,
- expectQueriedReportsFor: []tcpip.Address{addr1},
- expectDelayedReportsFor: []tcpip.Address{addr2},
- },
- {
- name: "Specified all-nodes",
- queryAddr: addr3,
- maxDelay: 3,
- expectQueriedReportsFor: nil,
- expectDelayedReportsFor: []tcpip.Address{addr1, addr2},
- },
- {
- name: "Specified other",
- queryAddr: addr4,
- maxDelay: 4,
- expectQueriedReportsFor: nil,
- expectDelayedReportsFor: []tcpip.Address{addr1, addr2},
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3}
- clock := faketime.NewManualClock()
-
- mgp.init(ip.GenericMulticastProtocolOptions{
- Rand: rand.New(rand.NewSource(3)),
- Clock: clock,
- MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
- })
-
- mgp.joinGroup(addr1)
- if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- mgp.joinGroup(addr2)
- if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- mgp.joinGroup(addr3)
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // Receiving a query should make us reschedule our delayed report timer
- // to some time within the new max response delay.
- mgp.handleQuery(test.queryAddr, test.maxDelay)
- clock.Advance(test.maxDelay)
- if diff := mgp.check(test.expectQueriedReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // The groups that were not affected by the query should still send a
- // report after the max unsolicited report delay.
- clock.Advance(maxUnsolicitedReportDelay)
- if diff := mgp.check(test.expectDelayedReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // Should have no more messages to send.
- clock.Advance(time.Hour)
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- })
- }
-}
-
-func TestJoinCount(t *testing.T) {
- mgp := mockMulticastGroupProtocol{t: t}
- clock := faketime.NewManualClock()
-
- mgp.init(ip.GenericMulticastProtocolOptions{
- Rand: rand.New(rand.NewSource(4)),
- Clock: clock,
- MaxUnsolicitedReportDelay: time.Second,
- })
-
- // Set the join count to 2 for a group.
- mgp.joinGroup(addr1)
- if !mgp.isLocallyJoined(addr1) {
- t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1)
- }
- // Only the first join should trigger a report to be sent.
- if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- mgp.joinGroup(addr1)
- if !mgp.isLocallyJoined(addr1) {
- t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1)
- }
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- if t.Failed() {
- t.FailNow()
- }
-
- // Group should still be considered joined after leaving once.
- if !mgp.leaveGroup(addr1) {
- t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1)
- }
- if !mgp.isLocallyJoined(addr1) {
- t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1)
- }
- // A leave report should only be sent once the join count reaches 0.
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- if t.Failed() {
- t.FailNow()
- }
-
- // Leaving once more should actually remove us from the group.
- if !mgp.leaveGroup(addr1) {
- t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1)
- }
- if mgp.isLocallyJoined(addr1) {
- t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1)
- }
- if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1} /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- if t.Failed() {
- t.FailNow()
- }
-
- // Group should no longer be joined so we should not have anything to
- // leave.
- if mgp.leaveGroup(addr1) {
- t.Errorf("got mgp.leaveGroup(%s) = true, want = false", addr1)
- }
- if mgp.isLocallyJoined(addr1) {
- t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1)
- }
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // Should have no more messages to send.
- //
- // Generic multicast protocol timers are expected to take the job mutex.
- clock.Advance(time.Hour)
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-}
-
-func TestMakeAllNonMemberAndInitialize(t *testing.T) {
- mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3}
- clock := faketime.NewManualClock()
-
- mgp.init(ip.GenericMulticastProtocolOptions{
- Rand: rand.New(rand.NewSource(3)),
- Clock: clock,
- MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
- })
-
- mgp.joinGroup(addr1)
- if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- mgp.joinGroup(addr2)
- if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- mgp.joinGroup(addr3)
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // Should send the leave reports for each but still consider them locally
- // joined.
- mgp.makeAllNonMember()
- if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1, addr2} /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- // Generic multicast protocol timers are expected to take the job mutex.
- clock.Advance(time.Hour)
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- for _, group := range []tcpip.Address{addr1, addr2, addr3} {
- if !mgp.isLocallyJoined(group) {
- t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", group)
- }
- }
-
- // Should send the initial set of unsolcited reports.
- mgp.initializeGroups()
- if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- clock.Advance(maxUnsolicitedReportDelay)
- if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // Should have no more messages to send.
- clock.Advance(time.Hour)
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-}
-
-// TestGroupStateNonMember tests that groups do not send packets when in the
-// non-member state, but are still considered locally joined.
-func TestGroupStateNonMember(t *testing.T) {
- mgp := mockMulticastGroupProtocol{t: t}
- clock := faketime.NewManualClock()
-
- mgp.init(ip.GenericMulticastProtocolOptions{
- Rand: rand.New(rand.NewSource(3)),
- Clock: clock,
- MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
- })
- mgp.setEnabled(false)
-
- // Joining groups should not send any reports.
- mgp.joinGroup(addr1)
- if !mgp.isLocallyJoined(addr1) {
- t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1)
- }
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- mgp.joinGroup(addr2)
- if !mgp.isLocallyJoined(addr1) {
- t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr2)
- }
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // Receiving a query should not send any reports.
- mgp.handleQuery(addr1, time.Nanosecond)
- // Generic multicast protocol timers are expected to take the job mutex.
- clock.Advance(time.Nanosecond)
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // Leaving groups should not send any leave messages.
- if !mgp.leaveGroup(addr1) {
- t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr2)
- }
- if mgp.isLocallyJoined(addr1) {
- t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr2)
- }
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- clock.Advance(time.Hour)
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-}
-
-func TestQueuedPackets(t *testing.T) {
- clock := faketime.NewManualClock()
- mgp := mockMulticastGroupProtocol{t: t}
- mgp.init(ip.GenericMulticastProtocolOptions{
- Rand: rand.New(rand.NewSource(4)),
- Clock: clock,
- MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
- })
-
- // Joining should trigger a SendReport, but mgp should report that we did not
- // send the packet.
- mgp.setQueuePackets(true)
- mgp.joinGroup(addr1)
- if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // The delayed report timer should have been cancelled since we did not send
- // the initial report earlier.
- clock.Advance(time.Hour)
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // Mock being able to successfully send the report.
- mgp.setQueuePackets(false)
- mgp.sendQueuedReports()
- if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // The delayed report (sent after the initial report) should now be sent.
- clock.Advance(maxUnsolicitedReportDelay)
- if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // Should not have anything else to send (we should be idle).
- mgp.sendQueuedReports()
- clock.Advance(time.Hour)
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // Receive a query but mock being unable to send reports again.
- mgp.setQueuePackets(true)
- mgp.handleQuery(addr1, time.Nanosecond)
- clock.Advance(time.Nanosecond)
- if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // Mock being able to send reports again - we should have a packet queued to
- // send.
- mgp.setQueuePackets(false)
- mgp.sendQueuedReports()
- if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // Should not have anything else to send.
- mgp.sendQueuedReports()
- clock.Advance(time.Hour)
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // Receive a query again, but mock being unable to send reports.
- mgp.setQueuePackets(true)
- mgp.handleQuery(addr1, time.Nanosecond)
- clock.Advance(time.Nanosecond)
- if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // Receiving a report should should transition us into the idle member state,
- // even if we had a packet queued. We should no longer have any packets to
- // send.
- mgp.handleReport(addr1)
- mgp.sendQueuedReports()
- clock.Advance(time.Hour)
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // When we fail to send the initial set of reports, incoming reports should
- // not affect a newly joined group's reports from being sent.
- mgp.setQueuePackets(true)
- mgp.joinGroup(addr2)
- if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- mgp.handleReport(addr2)
- // Attempting to send queued reports while still unable to send reports should
- // not change the host state.
- mgp.sendQueuedReports()
- if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- // Mock being able to successfully send the report.
- mgp.setQueuePackets(false)
- mgp.sendQueuedReports()
- if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
- // The delayed report (sent after the initial report) should now be sent.
- clock.Advance(maxUnsolicitedReportDelay)
- if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-
- // Should not have anything else to send.
- mgp.sendQueuedReports()
- clock.Advance(time.Hour)
- if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
- }
-}
diff --git a/pkg/tcpip/network/internal/ip/ip_state_autogen.go b/pkg/tcpip/network/internal/ip/ip_state_autogen.go
new file mode 100644
index 000000000..360922bfe
--- /dev/null
+++ b/pkg/tcpip/network/internal/ip/ip_state_autogen.go
@@ -0,0 +1,32 @@
+// automatically generated by stateify.
+
+package ip
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+func (e *ErrMessageTooLong) StateTypeName() string {
+ return "pkg/tcpip/network/internal/ip.ErrMessageTooLong"
+}
+
+func (e *ErrMessageTooLong) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrMessageTooLong) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrMessageTooLong) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrMessageTooLong) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrMessageTooLong) StateLoad(stateSourceObject state.Source) {
+}
+
+func init() {
+ state.Register((*ErrMessageTooLong)(nil))
+}
diff --git a/pkg/tcpip/network/internal/testutil/BUILD b/pkg/tcpip/network/internal/testutil/BUILD
deleted file mode 100644
index a180e5c75..000000000
--- a/pkg/tcpip/network/internal/testutil/BUILD
+++ /dev/null
@@ -1,21 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "testutil",
- srcs = ["testutil.go"],
- visibility = [
- "//pkg/tcpip/network/arp:__pkg__",
- "//pkg/tcpip/network/internal/fragmentation:__pkg__",
- "//pkg/tcpip/network/ipv4:__pkg__",
- "//pkg/tcpip/network/ipv6:__pkg__",
- "//pkg/tcpip/tests/integration:__pkg__",
- ],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/stack",
- ],
-)
diff --git a/pkg/tcpip/network/internal/testutil/testutil.go b/pkg/tcpip/network/internal/testutil/testutil.go
deleted file mode 100644
index 4d4d98caf..000000000
--- a/pkg/tcpip/network/internal/testutil/testutil.go
+++ /dev/null
@@ -1,134 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package testutil defines types and functions used to test Network Layer
-// functionality such as IP fragmentation.
-package testutil
-
-import (
- "fmt"
- "math/rand"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
-)
-
-// MockLinkEndpoint is an endpoint used for testing, it stores packets written
-// to it and can mock errors.
-type MockLinkEndpoint struct {
- // WrittenPackets is where packets written to the endpoint are stored.
- WrittenPackets []*stack.PacketBuffer
-
- mtu uint32
- err tcpip.Error
- allowPackets int
-}
-
-// NewMockLinkEndpoint creates a new MockLinkEndpoint.
-//
-// err is the error that will be returned once allowPackets packets are written
-// to the endpoint.
-func NewMockLinkEndpoint(mtu uint32, err tcpip.Error, allowPackets int) *MockLinkEndpoint {
- return &MockLinkEndpoint{
- mtu: mtu,
- err: err,
- allowPackets: allowPackets,
- }
-}
-
-// MTU implements LinkEndpoint.MTU.
-func (ep *MockLinkEndpoint) MTU() uint32 { return ep.mtu }
-
-// Capabilities implements LinkEndpoint.Capabilities.
-func (*MockLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return 0 }
-
-// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength.
-func (*MockLinkEndpoint) MaxHeaderLength() uint16 { return 0 }
-
-// LinkAddress implements LinkEndpoint.LinkAddress.
-func (*MockLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" }
-
-// WritePacket implements LinkEndpoint.WritePacket.
-func (ep *MockLinkEndpoint) WritePacket(_ stack.RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- if ep.allowPackets == 0 {
- return ep.err
- }
- ep.allowPackets--
- ep.WrittenPackets = append(ep.WrittenPackets, pkt)
- return nil
-}
-
-// WritePackets implements LinkEndpoint.WritePackets.
-func (ep *MockLinkEndpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
- var n int
-
- for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- if err := ep.WritePacket(r, protocol, pkt); err != nil {
- return n, err
- }
- n++
- }
-
- return n, nil
-}
-
-// Attach implements LinkEndpoint.Attach.
-func (*MockLinkEndpoint) Attach(stack.NetworkDispatcher) {}
-
-// IsAttached implements LinkEndpoint.IsAttached.
-func (*MockLinkEndpoint) IsAttached() bool { return false }
-
-// Wait implements LinkEndpoint.Wait.
-func (*MockLinkEndpoint) Wait() {}
-
-// ARPHardwareType implements LinkEndpoint.ARPHardwareType.
-func (*MockLinkEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone }
-
-// AddHeader implements LinkEndpoint.AddHeader.
-func (*MockLinkEndpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
-}
-
-// WriteRawPacket implements stack.LinkEndpoint.
-func (*MockLinkEndpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error {
- return &tcpip.ErrNotSupported{}
-}
-
-// MakeRandPkt generates a randomized packet. transportHeaderLength indicates
-// how many random bytes will be copied in the Transport Header.
-// extraHeaderReserveLength indicates how much extra space will be reserved for
-// the other headers. The payload is made from Views of the sizes listed in
-// viewSizes.
-func MakeRandPkt(transportHeaderLength int, extraHeaderReserveLength int, viewSizes []int, proto tcpip.NetworkProtocolNumber) *stack.PacketBuffer {
- var views buffer.VectorisedView
-
- for _, s := range viewSizes {
- newView := buffer.NewView(s)
- if _, err := rand.Read(newView); err != nil {
- panic(fmt.Sprintf("rand.Read: %s", err))
- }
- views.AppendView(newView)
- }
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: transportHeaderLength + extraHeaderReserveLength,
- Data: views,
- })
- pkt.NetworkProtocolNumber = proto
- if _, err := rand.Read(pkt.TransportHeader().Push(transportHeaderLength)); err != nil {
- panic(fmt.Sprintf("rand.Read: %s", err))
- }
- return pkt
-}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
deleted file mode 100644
index 87f650661..000000000
--- a/pkg/tcpip/network/ip_test.go
+++ /dev/null
@@ -1,2147 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package ip_test
-
-import (
- "bytes"
- "fmt"
- "strings"
- "testing"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/raw"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-const nicID = 1
-
-var (
- localIPv4Addr = testutil.MustParse4("10.0.0.1")
- remoteIPv4Addr = testutil.MustParse4("10.0.0.2")
- ipv4SubnetAddr = testutil.MustParse4("10.0.0.0")
- ipv4SubnetMask = testutil.MustParse4("255.255.255.0")
- ipv4Gateway = testutil.MustParse4("10.0.0.3")
- localIPv6Addr = testutil.MustParse6("a00::1")
- remoteIPv6Addr = testutil.MustParse6("a00::2")
- ipv6SubnetAddr = testutil.MustParse6("a00::")
- ipv6SubnetMask = testutil.MustParse6("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ff00")
- ipv6Gateway = testutil.MustParse6("a00::3")
-)
-
-var localIPv4AddrWithPrefix = tcpip.AddressWithPrefix{
- Address: localIPv4Addr,
- PrefixLen: 24,
-}
-
-var localIPv6AddrWithPrefix = tcpip.AddressWithPrefix{
- Address: localIPv6Addr,
- PrefixLen: 120,
-}
-
-type transportError struct {
- origin tcpip.SockErrOrigin
- typ uint8
- code uint8
- info uint32
- kind stack.TransportErrorKind
-}
-
-// testObject implements two interfaces: LinkEndpoint and TransportDispatcher.
-// The former is used to pretend that it's a link endpoint so that we can
-// inspect packets written by the network endpoints. The latter is used to
-// pretend that it's the network stack so that it can inspect incoming packets
-// that have been handled by the network endpoints.
-//
-// Packets are checked by comparing their fields/values against the expected
-// values stored in the test object itself.
-type testObject struct {
- t *testing.T
- protocol tcpip.TransportProtocolNumber
- contents []byte
- srcAddr tcpip.Address
- dstAddr tcpip.Address
- v4 bool
- transErr transportError
-
- dataCalls int
- controlCalls int
- rawCalls int
-}
-
-// checkValues verifies that the transport protocol, data contents, src & dst
-// addresses of a packet match what's expected. If any field doesn't match, the
-// test fails.
-func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, v buffer.View, srcAddr, dstAddr tcpip.Address) {
- if protocol != t.protocol {
- t.t.Errorf("protocol = %v, want %v", protocol, t.protocol)
- }
-
- if srcAddr != t.srcAddr {
- t.t.Errorf("srcAddr = %v, want %v", srcAddr, t.srcAddr)
- }
-
- if dstAddr != t.dstAddr {
- t.t.Errorf("dstAddr = %v, want %v", dstAddr, t.dstAddr)
- }
-
- if len(v) != len(t.contents) {
- t.t.Fatalf("len(payload) = %v, want %v", len(v), len(t.contents))
- }
-
- for i := range t.contents {
- if t.contents[i] != v[i] {
- t.t.Fatalf("payload[%v] = %v, want %v", i, v[i], t.contents[i])
- }
- }
-}
-
-// DeliverTransportPacket is called by network endpoints after parsing incoming
-// packets. This is used by the test object to verify that the results of the
-// parsing are expected.
-func (t *testObject) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) stack.TransportPacketDisposition {
- netHdr := pkt.Network()
- t.checkValues(protocol, pkt.Data().AsRange().ToOwnedView(), netHdr.SourceAddress(), netHdr.DestinationAddress())
- t.dataCalls++
- return stack.TransportPacketHandled
-}
-
-// DeliverTransportError is called by network endpoints after parsing
-// incoming control (ICMP) packets. This is used by the test object to verify
-// that the results of the parsing are expected.
-func (t *testObject) DeliverTransportError(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr stack.TransportError, pkt *stack.PacketBuffer) {
- t.checkValues(trans, pkt.Data().AsRange().ToOwnedView(), remote, local)
- if diff := cmp.Diff(
- t.transErr,
- transportError{
- origin: transErr.Origin(),
- typ: transErr.Type(),
- code: transErr.Code(),
- info: transErr.Info(),
- kind: transErr.Kind(),
- },
- cmp.AllowUnexported(transportError{}),
- ); diff != "" {
- t.t.Errorf("transport error mismatch (-want +got):\n%s", diff)
- }
- t.controlCalls++
-}
-
-func (t *testObject) DeliverRawPacket(tcpip.TransportProtocolNumber, *stack.PacketBuffer) {
- t.rawCalls++
-}
-
-// Attach is only implemented to satisfy the LinkEndpoint interface.
-func (*testObject) Attach(stack.NetworkDispatcher) {}
-
-// IsAttached implements stack.LinkEndpoint.IsAttached.
-func (*testObject) IsAttached() bool {
- return true
-}
-
-// MTU implements stack.LinkEndpoint.MTU. It just returns a constant that
-// matches the linux loopback MTU.
-func (*testObject) MTU() uint32 {
- return 65536
-}
-
-// Capabilities implements stack.LinkEndpoint.Capabilities.
-func (*testObject) Capabilities() stack.LinkEndpointCapabilities {
- return 0
-}
-
-// MaxHeaderLength is only implemented to satisfy the LinkEndpoint interface.
-func (*testObject) MaxHeaderLength() uint16 {
- return 0
-}
-
-// LinkAddress returns the link address of this endpoint.
-func (*testObject) LinkAddress() tcpip.LinkAddress {
- return ""
-}
-
-// Wait implements stack.LinkEndpoint.Wait.
-func (*testObject) Wait() {}
-
-// WritePacket is called by network endpoints after producing a packet and
-// writing it to the link endpoint. This is used by the test object to verify
-// that the produced packet is as expected.
-func (t *testObject) WritePacket(_ *stack.Route, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- var prot tcpip.TransportProtocolNumber
- var srcAddr tcpip.Address
- var dstAddr tcpip.Address
-
- if t.v4 {
- h := header.IPv4(pkt.NetworkHeader().View())
- prot = tcpip.TransportProtocolNumber(h.Protocol())
- srcAddr = h.SourceAddress()
- dstAddr = h.DestinationAddress()
-
- } else {
- h := header.IPv6(pkt.NetworkHeader().View())
- prot = tcpip.TransportProtocolNumber(h.NextHeader())
- srcAddr = h.SourceAddress()
- dstAddr = h.DestinationAddress()
- }
- t.checkValues(prot, pkt.Data().AsRange().ToOwnedView(), srcAddr, dstAddr)
- return nil
-}
-
-// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (*testObject) WritePackets(_ *stack.Route, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
- panic("not implemented")
-}
-
-// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
-func (*testObject) ARPHardwareType() header.ARPHardwareType {
- panic("not implemented")
-}
-
-// AddHeader implements stack.LinkEndpoint.AddHeader.
-func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
- panic("not implemented")
-}
-
-func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
- })
- s.CreateNIC(nicID, loopback.New())
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: local.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- return nil, err
- }
- s.SetRouteTable([]tcpip.Route{{
- Destination: header.IPv4EmptySubnet,
- Gateway: ipv4Gateway,
- NIC: 1,
- }})
-
- return s.FindRoute(nicID, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */)
-}
-
-func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
- })
- s.CreateNIC(nicID, loopback.New())
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: local.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- return nil, err
- }
- s.SetRouteTable([]tcpip.Route{{
- Destination: header.IPv6EmptySubnet,
- Gateway: ipv6Gateway,
- NIC: 1,
- }})
-
- return s.FindRoute(nicID, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */)
-}
-
-func buildDummyStackWithLinkEndpoint(t *testing.T, mtu uint32) (*stack.Stack, *channel.Endpoint) {
- t.Helper()
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
- })
- e := channel.New(1, mtu, "")
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- v4Addr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: localIPv4AddrWithPrefix}
- if err := s.AddProtocolAddress(nicID, v4Addr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, v4Addr, err)
- }
-
- v6Addr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: localIPv6AddrWithPrefix}
- if err := s.AddProtocolAddress(nicID, v6Addr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, v6Addr, err)
- }
-
- return s, e
-}
-
-func buildDummyStack(t *testing.T) *stack.Stack {
- t.Helper()
-
- s, _ := buildDummyStackWithLinkEndpoint(t, header.IPv6MinimumMTU)
- return s
-}
-
-var _ stack.NetworkInterface = (*testInterface)(nil)
-
-type testInterface struct {
- testObject
-
- mu struct {
- sync.RWMutex
- disabled bool
- }
-}
-
-func (*testInterface) ID() tcpip.NICID {
- return nicID
-}
-
-func (*testInterface) IsLoopback() bool {
- return false
-}
-
-func (*testInterface) Name() string {
- return ""
-}
-
-func (t *testInterface) Enabled() bool {
- t.mu.RLock()
- defer t.mu.RUnlock()
- return !t.mu.disabled
-}
-
-func (*testInterface) Promiscuous() bool {
- return false
-}
-
-func (*testInterface) Spoofing() bool {
- return false
-}
-
-func (t *testInterface) setEnabled(v bool) {
- t.mu.Lock()
- defer t.mu.Unlock()
- t.mu.disabled = !v
-}
-
-func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error {
- return &tcpip.ErrNotSupported{}
-}
-
-func (*testInterface) HandleNeighborProbe(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress) tcpip.Error {
- return nil
-}
-
-func (*testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) tcpip.Error {
- return nil
-}
-
-func (*testInterface) PrimaryAddress(tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, tcpip.Error) {
- return tcpip.AddressWithPrefix{}, nil
-}
-
-func (*testInterface) CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool {
- return false
-}
-
-func TestSourceAddressValidation(t *testing.T) {
- rxIPv4ICMP := func(e *channel.Endpoint, src tcpip.Address) {
- totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
- hdr := buffer.NewPrependable(totalLen)
- pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
- pkt.SetType(header.ICMPv4Echo)
- pkt.SetCode(0)
- pkt.SetChecksum(0)
- pkt.SetChecksum(^header.Checksum(pkt, 0))
- ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
- ip.Encode(&header.IPv4Fields{
- TotalLength: uint16(totalLen),
- Protocol: uint8(icmp.ProtocolNumber4),
- TTL: ipv4.DefaultTTL,
- SrcAddr: src,
- DstAddr: localIPv4Addr,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
-
- e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
- }
-
- rxIPv6ICMP := func(e *channel.Endpoint, src tcpip.Address) {
- totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize
- hdr := buffer.NewPrependable(totalLen)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
- pkt.SetType(header.ICMPv6EchoRequest)
- pkt.SetCode(0)
- pkt.SetChecksum(0)
- pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: pkt,
- Src: src,
- Dst: localIPv6Addr,
- }))
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: header.ICMPv6MinimumSize,
- TransportProtocol: icmp.ProtocolNumber6,
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: src,
- DstAddr: localIPv6Addr,
- })
- e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
- }
-
- tests := []struct {
- name string
- srcAddress tcpip.Address
- rxICMP func(*channel.Endpoint, tcpip.Address)
- valid bool
- }{
- {
- name: "IPv4 valid",
- srcAddress: "\x01\x02\x03\x04",
- rxICMP: rxIPv4ICMP,
- valid: true,
- },
- {
- name: "IPv6 valid",
- srcAddress: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10",
- rxICMP: rxIPv6ICMP,
- valid: true,
- },
- {
- name: "IPv4 unspecified",
- srcAddress: header.IPv4Any,
- rxICMP: rxIPv4ICMP,
- valid: true,
- },
- {
- name: "IPv6 unspecified",
- srcAddress: header.IPv4Any,
- rxICMP: rxIPv6ICMP,
- valid: true,
- },
- {
- name: "IPv4 multicast",
- srcAddress: "\xe0\x00\x00\x01",
- rxICMP: rxIPv4ICMP,
- valid: false,
- },
- {
- name: "IPv6 multicast",
- srcAddress: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
- rxICMP: rxIPv6ICMP,
- valid: false,
- },
- {
- name: "IPv4 broadcast",
- srcAddress: header.IPv4Broadcast,
- rxICMP: rxIPv4ICMP,
- valid: false,
- },
- {
- name: "IPv4 subnet broadcast",
- srcAddress: func() tcpip.Address {
- subnet := localIPv4AddrWithPrefix.Subnet()
- return subnet.Broadcast()
- }(),
- rxICMP: rxIPv4ICMP,
- valid: false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s, e := buildDummyStackWithLinkEndpoint(t, header.IPv6MinimumMTU)
- test.rxICMP(e, test.srcAddress)
-
- var wantValid uint64
- if test.valid {
- wantValid = 1
- }
-
- if got, want := s.Stats().IP.InvalidSourceAddressesReceived.Value(), 1-wantValid; got != want {
- t.Errorf("got s.Stats().IP.InvalidSourceAddressesReceived.Value() = %d, want = %d", got, want)
- }
- if got := s.Stats().IP.PacketsDelivered.Value(); got != wantValid {
- t.Errorf("got s.Stats().IP.PacketsDelivered.Value() = %d, want = %d", got, wantValid)
- }
- })
- }
-}
-
-func TestEnableWhenNICDisabled(t *testing.T) {
- tests := []struct {
- name string
- protocolFactory stack.NetworkProtocolFactory
- protoNum tcpip.NetworkProtocolNumber
- }{
- {
- name: "IPv4",
- protocolFactory: ipv4.NewProtocol,
- protoNum: ipv4.ProtocolNumber,
- },
- {
- name: "IPv6",
- protocolFactory: ipv6.NewProtocol,
- protoNum: ipv6.ProtocolNumber,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- var nic testInterface
- nic.setEnabled(false)
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{test.protocolFactory},
- })
- p := s.NetworkProtocolInstance(test.protoNum)
-
- // We pass nil for all parameters except the NetworkInterface and Stack
- // since Enable only depends on these.
- ep := p.NewEndpoint(&nic, nil)
-
- // The endpoint should initially be disabled, regardless the NIC's enabled
- // status.
- if ep.Enabled() {
- t.Fatal("got ep.Enabled() = true, want = false")
- }
- nic.setEnabled(true)
- if ep.Enabled() {
- t.Fatal("got ep.Enabled() = true, want = false")
- }
-
- // Attempting to enable the endpoint while the NIC is disabled should
- // fail.
- nic.setEnabled(false)
- err := ep.Enable()
- if _, ok := err.(*tcpip.ErrNotPermitted); !ok {
- t.Fatalf("got ep.Enable() = %s, want = %s", err, &tcpip.ErrNotPermitted{})
- }
- // ep should consider the NIC's enabled status when determining its own
- // enabled status so we "enable" the NIC to read just the endpoint's
- // enabled status.
- nic.setEnabled(true)
- if ep.Enabled() {
- t.Fatal("got ep.Enabled() = true, want = false")
- }
-
- // Enabling the interface after the NIC has been enabled should succeed.
- if err := ep.Enable(); err != nil {
- t.Fatalf("ep.Enable(): %s", err)
- }
- if !ep.Enabled() {
- t.Fatal("got ep.Enabled() = false, want = true")
- }
-
- // ep should consider the NIC's enabled status when determining its own
- // enabled status.
- nic.setEnabled(false)
- if ep.Enabled() {
- t.Fatal("got ep.Enabled() = true, want = false")
- }
-
- // Disabling the endpoint when the NIC is enabled should make the endpoint
- // disabled.
- nic.setEnabled(true)
- ep.Disable()
- if ep.Enabled() {
- t.Fatal("got ep.Enabled() = true, want = false")
- }
- })
- }
-}
-
-func TestIPv4Send(t *testing.T) {
- s := buildDummyStack(t)
- proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
- nic := testInterface{
- testObject: testObject{
- t: t,
- v4: true,
- },
- }
- ep := proto.NewEndpoint(&nic, nil)
- defer ep.Close()
-
- // Allocate and initialize the payload view.
- payload := buffer.NewView(100)
- for i := 0; i < len(payload); i++ {
- payload[i] = uint8(i)
- }
-
- // Setup the packet buffer.
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(ep.MaxHeaderLength()),
- Data: payload.ToVectorisedView(),
- })
-
- // Issue the write.
- nic.testObject.protocol = 123
- nic.testObject.srcAddr = localIPv4Addr
- nic.testObject.dstAddr = remoteIPv4Addr
- nic.testObject.contents = payload
-
- r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr)
- if err != nil {
- t.Fatalf("could not find route: %v", err)
- }
- if err := ep.WritePacket(r, stack.NetworkHeaderParams{
- Protocol: 123,
- TTL: 123,
- TOS: stack.DefaultTOS,
- }, pkt); err != nil {
- t.Fatalf("WritePacket failed: %v", err)
- }
-}
-
-func TestReceive(t *testing.T) {
- tests := []struct {
- name string
- protoFactory stack.NetworkProtocolFactory
- protoNum tcpip.NetworkProtocolNumber
- v4 bool
- epAddr tcpip.AddressWithPrefix
- handlePacket func(*testing.T, stack.NetworkEndpoint, *testInterface)
- }{
- {
- name: "IPv4",
- protoFactory: ipv4.NewProtocol,
- protoNum: ipv4.ProtocolNumber,
- v4: true,
- epAddr: localIPv4Addr.WithPrefix(),
- handlePacket: func(t *testing.T, ep stack.NetworkEndpoint, nic *testInterface) {
- const totalLen = header.IPv4MinimumSize + 30 /* payload length */
-
- view := buffer.NewView(totalLen)
- ip := header.IPv4(view)
- ip.Encode(&header.IPv4Fields{
- TotalLength: totalLen,
- TTL: ipv4.DefaultTTL,
- Protocol: 10,
- SrcAddr: remoteIPv4Addr,
- DstAddr: localIPv4Addr,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
-
- // Make payload be non-zero.
- for i := header.IPv4MinimumSize; i < len(view); i++ {
- view[i] = uint8(i)
- }
-
- // Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
- nic.testObject.protocol = 10
- nic.testObject.srcAddr = remoteIPv4Addr
- nic.testObject.dstAddr = localIPv4Addr
- nic.testObject.contents = view[header.IPv4MinimumSize:totalLen]
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: view.ToVectorisedView(),
- })
- ep.HandlePacket(pkt)
- },
- },
- {
- name: "IPv6",
- protoFactory: ipv6.NewProtocol,
- protoNum: ipv6.ProtocolNumber,
- v4: false,
- epAddr: localIPv6Addr.WithPrefix(),
- handlePacket: func(t *testing.T, ep stack.NetworkEndpoint, nic *testInterface) {
- const payloadLen = 30
- view := buffer.NewView(header.IPv6MinimumSize + payloadLen)
- ip := header.IPv6(view)
- ip.Encode(&header.IPv6Fields{
- PayloadLength: payloadLen,
- TransportProtocol: 10,
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: remoteIPv6Addr,
- DstAddr: localIPv6Addr,
- })
-
- // Make payload be non-zero.
- for i := header.IPv6MinimumSize; i < len(view); i++ {
- view[i] = uint8(i)
- }
-
- // Give packet to ipv6 endpoint, dispatcher will validate that it's ok.
- nic.testObject.protocol = 10
- nic.testObject.srcAddr = remoteIPv6Addr
- nic.testObject.dstAddr = localIPv6Addr
- nic.testObject.contents = view[header.IPv6MinimumSize:][:payloadLen]
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: view.ToVectorisedView(),
- })
- ep.HandlePacket(pkt)
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory},
- })
- nic := testInterface{
- testObject: testObject{
- t: t,
- v4: test.v4,
- },
- }
- ep := s.NetworkProtocolInstance(test.protoNum).NewEndpoint(&nic, &nic.testObject)
- defer ep.Close()
-
- if err := ep.Enable(); err != nil {
- t.Fatalf("ep.Enable(): %s", err)
- }
-
- addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
- if !ok {
- t.Fatalf("expected network endpoint with number = %d to implement stack.AddressableEndpoint", test.protoNum)
- }
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(test.epAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", test.epAddr, err)
- } else {
- ep.DecRef()
- }
-
- stat := s.Stats().IP.PacketsReceived
- if got := stat.Value(); got != 0 {
- t.Fatalf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 0", got)
- }
- test.handlePacket(t, ep, &nic)
- if nic.testObject.dataCalls != 1 {
- t.Errorf("Bad number of data calls: got %d, want 1", nic.testObject.dataCalls)
- }
- if nic.testObject.rawCalls != 1 {
- t.Errorf("Bad number of raw calls: got %d, want 1", nic.testObject.rawCalls)
- }
- if got := stat.Value(); got != 1 {
- t.Errorf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 1", got)
- }
- })
- }
-}
-
-func TestIPv4ReceiveControl(t *testing.T) {
- const (
- mtu = 0xbeef - header.IPv4MinimumSize
- dataLen = 8
- )
-
- cases := []struct {
- name string
- expectedCount int
- fragmentOffset uint16
- code header.ICMPv4Code
- transErr transportError
- trunc int
- }{
- {
- name: "FragmentationNeeded",
- expectedCount: 1,
- fragmentOffset: 0,
- code: header.ICMPv4FragmentationNeeded,
- transErr: transportError{
- origin: tcpip.SockExtErrorOriginICMP,
- typ: uint8(header.ICMPv4DstUnreachable),
- code: uint8(header.ICMPv4FragmentationNeeded),
- info: mtu,
- kind: stack.PacketTooBigTransportError,
- },
- trunc: 0,
- },
- {
- name: "Truncated (missing IPv4 header)",
- expectedCount: 0,
- fragmentOffset: 0,
- code: header.ICMPv4FragmentationNeeded,
- trunc: header.IPv4MinimumSize + header.ICMPv4MinimumSize,
- },
- {
- name: "Truncated (partial offending packet's IP header)",
- expectedCount: 0,
- fragmentOffset: 0,
- code: header.ICMPv4FragmentationNeeded,
- trunc: header.IPv4MinimumSize + header.ICMPv4MinimumSize + header.IPv4MinimumSize - 1,
- },
- {
- name: "Truncated (partial offending packet's data)",
- expectedCount: 0,
- fragmentOffset: 0,
- code: header.ICMPv4FragmentationNeeded,
- trunc: header.ICMPv4MinimumSize + header.ICMPv4MinimumSize + header.IPv4MinimumSize + dataLen - 1,
- },
- {
- name: "Port unreachable",
- expectedCount: 1,
- fragmentOffset: 0,
- code: header.ICMPv4PortUnreachable,
- transErr: transportError{
- origin: tcpip.SockExtErrorOriginICMP,
- typ: uint8(header.ICMPv4DstUnreachable),
- code: uint8(header.ICMPv4PortUnreachable),
- kind: stack.DestinationPortUnreachableTransportError,
- },
- trunc: 0,
- },
- {
- name: "Non-zero fragment offset",
- expectedCount: 0,
- fragmentOffset: 100,
- code: header.ICMPv4PortUnreachable,
- trunc: 0,
- },
- {
- name: "Zero-length packet",
- expectedCount: 0,
- fragmentOffset: 100,
- code: header.ICMPv4PortUnreachable,
- trunc: 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + dataLen,
- },
- }
- for _, c := range cases {
- t.Run(c.name, func(t *testing.T) {
- s := buildDummyStack(t)
- proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
- nic := testInterface{
- testObject: testObject{
- t: t,
- },
- }
- ep := proto.NewEndpoint(&nic, &nic.testObject)
- defer ep.Close()
-
- if err := ep.Enable(); err != nil {
- t.Fatalf("ep.Enable(): %s", err)
- }
-
- const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize
- view := buffer.NewView(dataOffset + dataLen)
-
- // Create the outer IPv4 header.
- ip := header.IPv4(view)
- ip.Encode(&header.IPv4Fields{
- TotalLength: uint16(len(view) - c.trunc),
- TTL: 20,
- Protocol: uint8(header.ICMPv4ProtocolNumber),
- SrcAddr: "\x0a\x00\x00\xbb",
- DstAddr: localIPv4Addr,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
-
- // Create the ICMP header.
- icmp := header.ICMPv4(view[header.IPv4MinimumSize:])
- icmp.SetType(header.ICMPv4DstUnreachable)
- icmp.SetCode(c.code)
- icmp.SetIdent(0xdead)
- icmp.SetSequence(0xbeef)
-
- // Create the inner IPv4 header.
- ip = header.IPv4(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize:])
- ip.Encode(&header.IPv4Fields{
- TotalLength: 100,
- TTL: 20,
- Protocol: 10,
- FragmentOffset: c.fragmentOffset,
- SrcAddr: localIPv4Addr,
- DstAddr: remoteIPv4Addr,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
-
- // Make payload be non-zero.
- for i := dataOffset; i < len(view); i++ {
- view[i] = uint8(i)
- }
-
- icmp.SetChecksum(0)
- checksum := ^header.Checksum(icmp, 0 /* initial */)
- icmp.SetChecksum(checksum)
-
- // Give packet to IPv4 endpoint, dispatcher will validate that
- // it's ok.
- nic.testObject.protocol = 10
- nic.testObject.srcAddr = remoteIPv4Addr
- nic.testObject.dstAddr = localIPv4Addr
- nic.testObject.contents = view[dataOffset:]
- nic.testObject.transErr = c.transErr
-
- addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
- if !ok {
- t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint")
- }
- addr := localIPv4Addr.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err)
- } else {
- ep.DecRef()
- }
-
- pkt := truncatedPacket(view, c.trunc, header.IPv4MinimumSize)
- ep.HandlePacket(pkt)
- if want := c.expectedCount; nic.testObject.controlCalls != want {
- t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want)
- }
- })
- }
-}
-
-func TestIPv4FragmentationReceive(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- })
- proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
- nic := testInterface{
- testObject: testObject{
- t: t,
- v4: true,
- },
- }
- ep := proto.NewEndpoint(&nic, &nic.testObject)
- defer ep.Close()
-
- if err := ep.Enable(); err != nil {
- t.Fatalf("ep.Enable(): %s", err)
- }
-
- totalLen := header.IPv4MinimumSize + 24
-
- frag1 := buffer.NewView(totalLen)
- ip1 := header.IPv4(frag1)
- ip1.Encode(&header.IPv4Fields{
- TotalLength: uint16(totalLen),
- TTL: 20,
- Protocol: 10,
- FragmentOffset: 0,
- Flags: header.IPv4FlagMoreFragments,
- SrcAddr: remoteIPv4Addr,
- DstAddr: localIPv4Addr,
- })
- ip1.SetChecksum(^ip1.CalculateChecksum())
-
- // Make payload be non-zero.
- for i := header.IPv4MinimumSize; i < totalLen; i++ {
- frag1[i] = uint8(i)
- }
-
- frag2 := buffer.NewView(totalLen)
- ip2 := header.IPv4(frag2)
- ip2.Encode(&header.IPv4Fields{
- TotalLength: uint16(totalLen),
- TTL: 20,
- Protocol: 10,
- FragmentOffset: 24,
- SrcAddr: remoteIPv4Addr,
- DstAddr: localIPv4Addr,
- })
- ip2.SetChecksum(^ip2.CalculateChecksum())
-
- // Make payload be non-zero.
- for i := header.IPv4MinimumSize; i < totalLen; i++ {
- frag2[i] = uint8(i)
- }
-
- // Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
- nic.testObject.protocol = 10
- nic.testObject.srcAddr = remoteIPv4Addr
- nic.testObject.dstAddr = localIPv4Addr
- nic.testObject.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...)
-
- // Send first segment.
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: frag1.ToVectorisedView(),
- })
-
- addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
- if !ok {
- t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint")
- }
- addr := localIPv4Addr.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err)
- } else {
- ep.DecRef()
- }
-
- ep.HandlePacket(pkt)
- if nic.testObject.dataCalls != 0 {
- t.Fatalf("Bad number of data calls: got %d, want 0", nic.testObject.dataCalls)
- }
- if nic.testObject.rawCalls != 0 {
- t.Errorf("Bad number of raw calls: got %d, want 0", nic.testObject.rawCalls)
- }
-
- // Send second segment.
- pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: frag2.ToVectorisedView(),
- })
- ep.HandlePacket(pkt)
- if nic.testObject.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %d, want 1", nic.testObject.dataCalls)
- }
- if nic.testObject.rawCalls != 1 {
- t.Errorf("Bad number of raw calls: got %d, want 1", nic.testObject.rawCalls)
- }
-}
-
-func TestIPv6Send(t *testing.T) {
- s := buildDummyStack(t)
- proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
- nic := testInterface{
- testObject: testObject{
- t: t,
- },
- }
- ep := proto.NewEndpoint(&nic, nil)
- defer ep.Close()
-
- if err := ep.Enable(); err != nil {
- t.Fatalf("ep.Enable(): %s", err)
- }
-
- // Allocate and initialize the payload view.
- payload := buffer.NewView(100)
- for i := 0; i < len(payload); i++ {
- payload[i] = uint8(i)
- }
-
- // Setup the packet buffer.
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(ep.MaxHeaderLength()),
- Data: payload.ToVectorisedView(),
- })
-
- // Issue the write.
- nic.testObject.protocol = 123
- nic.testObject.srcAddr = localIPv6Addr
- nic.testObject.dstAddr = remoteIPv6Addr
- nic.testObject.contents = payload
-
- r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr)
- if err != nil {
- t.Fatalf("could not find route: %v", err)
- }
- if err := ep.WritePacket(r, stack.NetworkHeaderParams{
- Protocol: 123,
- TTL: 123,
- TOS: stack.DefaultTOS,
- }, pkt); err != nil {
- t.Fatalf("WritePacket failed: %v", err)
- }
-}
-
-func TestIPv6ReceiveControl(t *testing.T) {
- const (
- mtu = 0xffff
- outerSrcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa"
- dataLen = 8
- )
-
- newUint16 := func(v uint16) *uint16 { return &v }
-
- portUnreachableTransErr := transportError{
- origin: tcpip.SockExtErrorOriginICMP6,
- typ: uint8(header.ICMPv6DstUnreachable),
- code: uint8(header.ICMPv6PortUnreachable),
- kind: stack.DestinationPortUnreachableTransportError,
- }
-
- cases := []struct {
- name string
- expectedCount int
- fragmentOffset *uint16
- typ header.ICMPv6Type
- code header.ICMPv6Code
- transErr transportError
- trunc int
- }{
- {
- name: "PacketTooBig",
- expectedCount: 1,
- fragmentOffset: nil,
- typ: header.ICMPv6PacketTooBig,
- code: header.ICMPv6UnusedCode,
- transErr: transportError{
- origin: tcpip.SockExtErrorOriginICMP6,
- typ: uint8(header.ICMPv6PacketTooBig),
- code: uint8(header.ICMPv6UnusedCode),
- info: mtu,
- kind: stack.PacketTooBigTransportError,
- },
- trunc: 0,
- },
- {
- name: "Truncated (missing offending packet's IPv6 header)",
- expectedCount: 0,
- fragmentOffset: nil,
- typ: header.ICMPv6PacketTooBig,
- code: header.ICMPv6UnusedCode,
- trunc: header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize,
- },
- {
- name: "Truncated PacketTooBig (partial offending packet's IPv6 header)",
- expectedCount: 0,
- fragmentOffset: nil,
- typ: header.ICMPv6PacketTooBig,
- code: header.ICMPv6UnusedCode,
- trunc: header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize - 1,
- },
- {
- name: "Truncated (partial offending packet's data)",
- expectedCount: 0,
- fragmentOffset: nil,
- typ: header.ICMPv6PacketTooBig,
- code: header.ICMPv6UnusedCode,
- trunc: header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize + dataLen - 1,
- },
- {
- name: "Port unreachable",
- expectedCount: 1,
- fragmentOffset: nil,
- typ: header.ICMPv6DstUnreachable,
- code: header.ICMPv6PortUnreachable,
- transErr: portUnreachableTransErr,
- trunc: 0,
- },
- {
- name: "Truncated DstPortUnreachable (partial offending packet's IP header)",
- expectedCount: 0,
- fragmentOffset: nil,
- typ: header.ICMPv6DstUnreachable,
- code: header.ICMPv6PortUnreachable,
- trunc: header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + header.IPv6MinimumSize - 1,
- },
- {
- name: "DstPortUnreachable for Fragmented, zero offset",
- expectedCount: 1,
- fragmentOffset: newUint16(0),
- typ: header.ICMPv6DstUnreachable,
- code: header.ICMPv6PortUnreachable,
- transErr: portUnreachableTransErr,
- trunc: 0,
- },
- {
- name: "DstPortUnreachable for Non-zero fragment offset",
- expectedCount: 0,
- fragmentOffset: newUint16(100),
- typ: header.ICMPv6DstUnreachable,
- code: header.ICMPv6PortUnreachable,
- transErr: portUnreachableTransErr,
- trunc: 0,
- },
- {
- name: "Zero-length packet",
- expectedCount: 0,
- fragmentOffset: nil,
- typ: header.ICMPv6DstUnreachable,
- code: header.ICMPv6PortUnreachable,
- trunc: 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + dataLen,
- },
- }
- for _, c := range cases {
- t.Run(c.name, func(t *testing.T) {
- s := buildDummyStack(t)
- proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
- nic := testInterface{
- testObject: testObject{
- t: t,
- },
- }
- ep := proto.NewEndpoint(&nic, &nic.testObject)
- defer ep.Close()
-
- if err := ep.Enable(); err != nil {
- t.Fatalf("ep.Enable(): %s", err)
- }
-
- dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize
- if c.fragmentOffset != nil {
- dataOffset += header.IPv6FragmentHeaderSize
- }
- view := buffer.NewView(dataOffset + dataLen)
-
- // Create the outer IPv6 header.
- ip := header.IPv6(view)
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc),
- TransportProtocol: header.ICMPv6ProtocolNumber,
- HopLimit: 20,
- SrcAddr: outerSrcAddr,
- DstAddr: localIPv6Addr,
- })
-
- // Create the ICMP header.
- icmp := header.ICMPv6(view[header.IPv6MinimumSize:])
- icmp.SetType(c.typ)
- icmp.SetCode(c.code)
- icmp.SetIdent(0xdead)
- icmp.SetSequence(0xbeef)
-
- var extHdrs header.IPv6ExtHdrSerializer
- // Build the fragmentation header if needed.
- if c.fragmentOffset != nil {
- extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{
- FragmentOffset: *c.fragmentOffset,
- M: true,
- Identification: 0x12345678,
- })
- }
-
- // Create the inner IPv6 header.
- ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:])
- ip.Encode(&header.IPv6Fields{
- PayloadLength: 100,
- TransportProtocol: 10,
- HopLimit: 20,
- SrcAddr: localIPv6Addr,
- DstAddr: remoteIPv6Addr,
- ExtensionHeaders: extHdrs,
- })
-
- // Make payload be non-zero.
- for i := dataOffset; i < len(view); i++ {
- view[i] = uint8(i)
- }
-
- // Give packet to IPv6 endpoint, dispatcher will validate that
- // it's ok.
- nic.testObject.protocol = 10
- nic.testObject.srcAddr = remoteIPv6Addr
- nic.testObject.dstAddr = localIPv6Addr
- nic.testObject.contents = view[dataOffset:]
- nic.testObject.transErr = c.transErr
-
- // Set ICMPv6 checksum.
- icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: icmp,
- Src: outerSrcAddr,
- Dst: localIPv6Addr,
- }))
-
- addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
- if !ok {
- t.Fatal("expected IPv6 network endpoint to implement stack.AddressableEndpoint")
- }
- addr := localIPv6Addr.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err)
- } else {
- ep.DecRef()
- }
- pkt := truncatedPacket(view, c.trunc, header.IPv6MinimumSize)
- ep.HandlePacket(pkt)
- if want := c.expectedCount; nic.testObject.controlCalls != want {
- t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want)
- }
- })
- }
-}
-
-// truncatedPacket returns a PacketBuffer based on a truncated view. If view,
-// after truncation, is large enough to hold a network header, it makes part of
-// view the packet's NetworkHeader and the rest its Data. Otherwise all of view
-// becomes Data.
-func truncatedPacket(view buffer.View, trunc, netHdrLen int) *stack.PacketBuffer {
- v := view[:len(view)-trunc]
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: v.ToVectorisedView(),
- })
- return pkt
-}
-
-func TestWriteHeaderIncludedPacket(t *testing.T) {
- const (
- nicID = 1
- transportProto = 5
-
- dataLen = 4
- )
-
- dataBuf := [dataLen]byte{1, 2, 3, 4}
- data := dataBuf[:]
-
- ipv4Options := header.IPv4OptionsSerializer{
- &header.IPv4SerializableListEndOption{},
- &header.IPv4SerializableNOPOption{},
- &header.IPv4SerializableListEndOption{},
- &header.IPv4SerializableNOPOption{},
- }
-
- expectOptions := header.IPv4Options{
- byte(header.IPv4OptionListEndType),
- byte(header.IPv4OptionNOPType),
- byte(header.IPv4OptionListEndType),
- byte(header.IPv4OptionNOPType),
- }
-
- ipv6FragmentExtHdrBuf := [header.IPv6FragmentExtHdrLength]byte{transportProto, 0, 62, 4, 1, 2, 3, 4}
- ipv6FragmentExtHdr := ipv6FragmentExtHdrBuf[:]
-
- var ipv6PayloadWithExtHdrBuf [dataLen + header.IPv6FragmentExtHdrLength]byte
- ipv6PayloadWithExtHdr := ipv6PayloadWithExtHdrBuf[:]
- if n := copy(ipv6PayloadWithExtHdr, ipv6FragmentExtHdr); n != len(ipv6FragmentExtHdr) {
- t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv6FragmentExtHdr))
- }
- if n := copy(ipv6PayloadWithExtHdr[header.IPv6FragmentExtHdrLength:], data); n != len(data) {
- t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
- }
-
- tests := []struct {
- name string
- protoFactory stack.NetworkProtocolFactory
- protoNum tcpip.NetworkProtocolNumber
- nicAddr tcpip.AddressWithPrefix
- remoteAddr tcpip.Address
- pktGen func(*testing.T, tcpip.Address) buffer.VectorisedView
- checker func(*testing.T, *stack.PacketBuffer, tcpip.Address)
- expectedErr tcpip.Error
- }{
- {
- name: "IPv4",
- protoFactory: ipv4.NewProtocol,
- protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4AddrWithPrefix,
- remoteAddr: remoteIPv4Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
- totalLen := header.IPv4MinimumSize + len(data)
- hdr := buffer.NewPrependable(totalLen)
- if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
- t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
- }
- ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
- ip.Encode(&header.IPv4Fields{
- Protocol: transportProto,
- TTL: ipv4.DefaultTTL,
- SrcAddr: src,
- DstAddr: remoteIPv4Addr,
- })
- return hdr.View().ToVectorisedView()
- },
- checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
- if src == header.IPv4Any {
- src = localIPv4Addr
- }
-
- netHdr := pkt.NetworkHeader()
-
- if len(netHdr.View()) != header.IPv4MinimumSize {
- t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv4MinimumSize)
- }
-
- checker.IPv4(t, stack.PayloadSince(netHdr),
- checker.SrcAddr(src),
- checker.DstAddr(remoteIPv4Addr),
- checker.IPv4HeaderLength(header.IPv4MinimumSize),
- checker.IPFullLength(uint16(header.IPv4MinimumSize+len(data))),
- checker.IPPayload(data),
- )
- },
- },
- {
- name: "IPv4 with IHL too small",
- protoFactory: ipv4.NewProtocol,
- protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4AddrWithPrefix,
- remoteAddr: remoteIPv4Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
- totalLen := header.IPv4MinimumSize + len(data)
- hdr := buffer.NewPrependable(totalLen)
- if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
- t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
- }
- ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
- ip.Encode(&header.IPv4Fields{
- Protocol: transportProto,
- TTL: ipv4.DefaultTTL,
- SrcAddr: src,
- DstAddr: remoteIPv4Addr,
- })
- ip.SetHeaderLength(header.IPv4MinimumSize - 1)
- return hdr.View().ToVectorisedView()
- },
- expectedErr: &tcpip.ErrMalformedHeader{},
- },
- {
- name: "IPv4 too small",
- protoFactory: ipv4.NewProtocol,
- protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4AddrWithPrefix,
- remoteAddr: remoteIPv4Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
- ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
- ip.Encode(&header.IPv4Fields{
- Protocol: transportProto,
- TTL: ipv4.DefaultTTL,
- SrcAddr: src,
- DstAddr: remoteIPv4Addr,
- })
- return buffer.View(ip[:len(ip)-1]).ToVectorisedView()
- },
- expectedErr: &tcpip.ErrMalformedHeader{},
- },
- {
- name: "IPv4 minimum size",
- protoFactory: ipv4.NewProtocol,
- protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4AddrWithPrefix,
- remoteAddr: remoteIPv4Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
- ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
- ip.Encode(&header.IPv4Fields{
- Protocol: transportProto,
- TTL: ipv4.DefaultTTL,
- SrcAddr: src,
- DstAddr: remoteIPv4Addr,
- })
- return buffer.View(ip).ToVectorisedView()
- },
- checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
- if src == header.IPv4Any {
- src = localIPv4Addr
- }
-
- netHdr := pkt.NetworkHeader()
-
- if len(netHdr.View()) != header.IPv4MinimumSize {
- t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv4MinimumSize)
- }
-
- checker.IPv4(t, stack.PayloadSince(netHdr),
- checker.SrcAddr(src),
- checker.DstAddr(remoteIPv4Addr),
- checker.IPv4HeaderLength(header.IPv4MinimumSize),
- checker.IPFullLength(header.IPv4MinimumSize),
- checker.IPPayload(nil),
- )
- },
- },
- {
- name: "IPv4 with options",
- protoFactory: ipv4.NewProtocol,
- protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4AddrWithPrefix,
- remoteAddr: remoteIPv4Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
- ipHdrLen := int(header.IPv4MinimumSize + ipv4Options.Length())
- totalLen := ipHdrLen + len(data)
- hdr := buffer.NewPrependable(totalLen)
- if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
- t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
- }
- ip := header.IPv4(hdr.Prepend(ipHdrLen))
- ip.Encode(&header.IPv4Fields{
- Protocol: transportProto,
- TTL: ipv4.DefaultTTL,
- SrcAddr: src,
- DstAddr: remoteIPv4Addr,
- Options: ipv4Options,
- })
- return hdr.View().ToVectorisedView()
- },
- checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
- if src == header.IPv4Any {
- src = localIPv4Addr
- }
-
- netHdr := pkt.NetworkHeader()
-
- hdrLen := int(header.IPv4MinimumSize + ipv4Options.Length())
- if len(netHdr.View()) != hdrLen {
- t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen)
- }
-
- checker.IPv4(t, stack.PayloadSince(netHdr),
- checker.SrcAddr(src),
- checker.DstAddr(remoteIPv4Addr),
- checker.IPv4HeaderLength(hdrLen),
- checker.IPFullLength(uint16(hdrLen+len(data))),
- checker.IPv4Options(expectOptions),
- checker.IPPayload(data),
- )
- },
- },
- {
- name: "IPv4 with options and data across views",
- protoFactory: ipv4.NewProtocol,
- protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4AddrWithPrefix,
- remoteAddr: remoteIPv4Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
- ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.Length()))
- ip.Encode(&header.IPv4Fields{
- Protocol: transportProto,
- TTL: ipv4.DefaultTTL,
- SrcAddr: src,
- DstAddr: remoteIPv4Addr,
- Options: ipv4Options,
- })
- vv := buffer.View(ip).ToVectorisedView()
- vv.AppendView(data)
- return vv
- },
- checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
- if src == header.IPv4Any {
- src = localIPv4Addr
- }
-
- netHdr := pkt.NetworkHeader()
-
- hdrLen := int(header.IPv4MinimumSize + ipv4Options.Length())
- if len(netHdr.View()) != hdrLen {
- t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen)
- }
-
- checker.IPv4(t, stack.PayloadSince(netHdr),
- checker.SrcAddr(src),
- checker.DstAddr(remoteIPv4Addr),
- checker.IPv4HeaderLength(hdrLen),
- checker.IPFullLength(uint16(hdrLen+len(data))),
- checker.IPv4Options(expectOptions),
- checker.IPPayload(data),
- )
- },
- },
- {
- name: "IPv6",
- protoFactory: ipv6.NewProtocol,
- protoNum: ipv6.ProtocolNumber,
- nicAddr: localIPv6AddrWithPrefix,
- remoteAddr: remoteIPv6Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
- totalLen := header.IPv6MinimumSize + len(data)
- hdr := buffer.NewPrependable(totalLen)
- if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
- t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
- }
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- TransportProtocol: transportProto,
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: src,
- DstAddr: remoteIPv6Addr,
- })
- return hdr.View().ToVectorisedView()
- },
- checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
- if src == header.IPv6Any {
- src = localIPv6Addr
- }
-
- netHdr := pkt.NetworkHeader()
-
- if len(netHdr.View()) != header.IPv6MinimumSize {
- t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv6MinimumSize)
- }
-
- checker.IPv6(t, stack.PayloadSince(netHdr),
- checker.SrcAddr(src),
- checker.DstAddr(remoteIPv6Addr),
- checker.IPFullLength(uint16(header.IPv6MinimumSize+len(data))),
- checker.IPPayload(data),
- )
- },
- },
- {
- name: "IPv6 with extension header",
- protoFactory: ipv6.NewProtocol,
- protoNum: ipv6.ProtocolNumber,
- nicAddr: localIPv6AddrWithPrefix,
- remoteAddr: remoteIPv6Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
- totalLen := header.IPv6MinimumSize + len(ipv6FragmentExtHdr) + len(data)
- hdr := buffer.NewPrependable(totalLen)
- if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
- t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
- }
- if n := copy(hdr.Prepend(len(ipv6FragmentExtHdr)), ipv6FragmentExtHdr); n != len(ipv6FragmentExtHdr) {
- t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv6FragmentExtHdr))
- }
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- // NB: we're lying about transport protocol here to verify the raw
- // fragment header bytes.
- TransportProtocol: tcpip.TransportProtocolNumber(header.IPv6FragmentExtHdrIdentifier),
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: src,
- DstAddr: remoteIPv6Addr,
- })
- return hdr.View().ToVectorisedView()
- },
- checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
- if src == header.IPv6Any {
- src = localIPv6Addr
- }
-
- netHdr := pkt.NetworkHeader()
-
- if want := header.IPv6MinimumSize + len(ipv6FragmentExtHdr); len(netHdr.View()) != want {
- t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), want)
- }
-
- checker.IPv6(t, stack.PayloadSince(netHdr),
- checker.SrcAddr(src),
- checker.DstAddr(remoteIPv6Addr),
- checker.IPFullLength(uint16(header.IPv6MinimumSize+len(ipv6PayloadWithExtHdr))),
- checker.IPPayload(ipv6PayloadWithExtHdr),
- )
- },
- },
- {
- name: "IPv6 minimum size",
- protoFactory: ipv6.NewProtocol,
- protoNum: ipv6.ProtocolNumber,
- nicAddr: localIPv6AddrWithPrefix,
- remoteAddr: remoteIPv6Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
- ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- TransportProtocol: transportProto,
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: src,
- DstAddr: remoteIPv6Addr,
- })
- return buffer.View(ip).ToVectorisedView()
- },
- checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
- if src == header.IPv6Any {
- src = localIPv6Addr
- }
-
- netHdr := pkt.NetworkHeader()
-
- if len(netHdr.View()) != header.IPv6MinimumSize {
- t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv6MinimumSize)
- }
-
- checker.IPv6(t, stack.PayloadSince(netHdr),
- checker.SrcAddr(src),
- checker.DstAddr(remoteIPv6Addr),
- checker.IPFullLength(header.IPv6MinimumSize),
- checker.IPPayload(nil),
- )
- },
- },
- {
- name: "IPv6 too small",
- protoFactory: ipv6.NewProtocol,
- protoNum: ipv6.ProtocolNumber,
- nicAddr: localIPv6AddrWithPrefix,
- remoteAddr: remoteIPv6Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
- ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- TransportProtocol: transportProto,
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: src,
- DstAddr: remoteIPv4Addr,
- })
- return buffer.View(ip[:len(ip)-1]).ToVectorisedView()
- },
- expectedErr: &tcpip.ErrMalformedHeader{},
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- subTests := []struct {
- name string
- srcAddr tcpip.Address
- }{
- {
- name: "unspecified source",
- srcAddr: tcpip.Address(strings.Repeat("\x00", len(test.nicAddr.Address))),
- },
- {
- name: "random source",
- srcAddr: tcpip.Address(strings.Repeat("\xab", len(test.nicAddr.Address))),
- },
- }
-
- for _, subTest := range subTests {
- t.Run(subTest.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory},
- })
- e := channel.New(1, header.IPv6MinimumMTU, "")
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
- }
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: test.protoNum,
- AddressWithPrefix: test.nicAddr,
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
-
- s.SetRouteTable([]tcpip.Route{{Destination: test.remoteAddr.WithPrefix().Subnet(), NIC: nicID}})
-
- r, err := s.FindRoute(nicID, test.nicAddr.Address, test.remoteAddr, test.protoNum, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.remoteAddr, test.nicAddr.Address, test.protoNum, err)
- }
- defer r.Release()
-
- {
- err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: test.pktGen(t, subTest.srcAddr),
- }))
- if diff := cmp.Diff(test.expectedErr, err); diff != "" {
- t.Fatalf("unexpected error from r.WriteHeaderIncludedPacket(_), (-want, +got):\n%s", diff)
- }
- }
-
- if test.expectedErr != nil {
- return
- }
-
- pkt, ok := e.Read()
- if !ok {
- t.Fatal("expected a packet to be written")
- }
- test.checker(t, pkt.Pkt, subTest.srcAddr)
- })
- }
- })
- }
-}
-
-// Test that the included data in an ICMP error packet conforms to the
-// requirements of RFC 972, RFC 4443 section 2.4 and RFC 1812 Section 4.3.2.3
-func TestICMPInclusionSize(t *testing.T) {
- const (
- replyHeaderLength4 = header.IPv4MinimumSize + header.IPv4MinimumSize + header.ICMPv4MinimumSize
- replyHeaderLength6 = header.IPv6MinimumSize + header.IPv6MinimumSize + header.ICMPv6MinimumSize
- targetSize4 = header.IPv4MinimumProcessableDatagramSize
- targetSize6 = header.IPv6MinimumMTU
- // A protocol number that will cause an error response.
- reservedProtocol = 254
- )
-
- // IPv4 function to create a IP packet and send it to the stack.
- // The packet should generate an error response. We can do that by using an
- // unknown transport protocol (254).
- rxIPv4Bad := func(e *channel.Endpoint, src tcpip.Address, payload []byte) buffer.View {
- totalLen := header.IPv4MinimumSize + len(payload)
- hdr := buffer.NewPrependable(header.IPv4MinimumSize)
- ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
- ip.Encode(&header.IPv4Fields{
- TotalLength: uint16(totalLen),
- Protocol: reservedProtocol,
- TTL: ipv4.DefaultTTL,
- SrcAddr: src,
- DstAddr: localIPv4Addr,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
- vv := hdr.View().ToVectorisedView()
- vv.AppendView(buffer.View(payload))
- // Take a copy before InjectInbound takes ownership of vv
- // as vv may be changed during the call.
- v := vv.ToView()
- e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- }))
- return v
- }
-
- // IPv6 function to create a packet and send it to the stack.
- // The packet should be errant in a way that causes the stack to send an
- // ICMP error response and have enough data to allow the testing of the
- // inclusion of the errant packet. Use `unknown next header' to generate
- // the error.
- rxIPv6Bad := func(e *channel.Endpoint, src tcpip.Address, payload []byte) buffer.View {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize)
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(payload)),
- TransportProtocol: reservedProtocol,
- HopLimit: ipv6.DefaultTTL,
- SrcAddr: src,
- DstAddr: localIPv6Addr,
- })
- vv := hdr.View().ToVectorisedView()
- vv.AppendView(buffer.View(payload))
- // Take a copy before InjectInbound takes ownership of vv
- // as vv may be changed during the call.
- v := vv.ToView()
-
- e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- }))
- return v
- }
-
- v4Checker := func(t *testing.T, pkt *stack.PacketBuffer, payload buffer.View) {
- // We already know the entire packet is the right size so we can use its
- // length to calculate the right payload size to check.
- expectedPayloadLength := pkt.Size() - header.IPv4MinimumSize - header.ICMPv4MinimumSize
- checker.IPv4(t, stack.PayloadSince(pkt.NetworkHeader()),
- checker.SrcAddr(localIPv4Addr),
- checker.DstAddr(remoteIPv4Addr),
- checker.IPv4HeaderLength(header.IPv4MinimumSize),
- checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+expectedPayloadLength)),
- checker.ICMPv4(
- checker.ICMPv4Checksum(),
- checker.ICMPv4Type(header.ICMPv4DstUnreachable),
- checker.ICMPv4Code(header.ICMPv4ProtoUnreachable),
- checker.ICMPv4Payload(payload[:expectedPayloadLength]),
- ),
- )
- }
-
- v6Checker := func(t *testing.T, pkt *stack.PacketBuffer, payload buffer.View) {
- // We already know the entire packet is the right size so we can use its
- // length to calculate the right payload size to check.
- expectedPayloadLength := pkt.Size() - header.IPv6MinimumSize - header.ICMPv6MinimumSize
- checker.IPv6(t, stack.PayloadSince(pkt.NetworkHeader()),
- checker.SrcAddr(localIPv6Addr),
- checker.DstAddr(remoteIPv6Addr),
- checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+expectedPayloadLength)),
- checker.ICMPv6(
- checker.ICMPv6Type(header.ICMPv6ParamProblem),
- checker.ICMPv6Code(header.ICMPv6UnknownHeader),
- checker.ICMPv6Payload(payload[:expectedPayloadLength]),
- ),
- )
- }
- tests := []struct {
- name string
- srcAddress tcpip.Address
- injector func(*channel.Endpoint, tcpip.Address, []byte) buffer.View
- checker func(*testing.T, *stack.PacketBuffer, buffer.View)
- payloadLength int // Not including IP header.
- linkMTU uint32 // Largest IP packet that the link can send as payload.
- replyLength int // Total size of IP/ICMP packet expected back.
- }{
- {
- name: "IPv4 exact match",
- srcAddress: remoteIPv4Addr,
- injector: rxIPv4Bad,
- checker: v4Checker,
- payloadLength: targetSize4 - replyHeaderLength4,
- linkMTU: targetSize4,
- replyLength: targetSize4,
- },
- {
- name: "IPv4 larger MTU",
- srcAddress: remoteIPv4Addr,
- injector: rxIPv4Bad,
- checker: v4Checker,
- payloadLength: targetSize4,
- linkMTU: targetSize4 + 1000,
- replyLength: targetSize4,
- },
- {
- name: "IPv4 smaller MTU",
- srcAddress: remoteIPv4Addr,
- injector: rxIPv4Bad,
- checker: v4Checker,
- payloadLength: targetSize4,
- linkMTU: targetSize4 - 50,
- replyLength: targetSize4 - 50,
- },
- {
- name: "IPv4 payload exceeds",
- srcAddress: remoteIPv4Addr,
- injector: rxIPv4Bad,
- checker: v4Checker,
- payloadLength: targetSize4 + 10,
- linkMTU: targetSize4,
- replyLength: targetSize4,
- },
- {
- name: "IPv4 1 byte less",
- srcAddress: remoteIPv4Addr,
- injector: rxIPv4Bad,
- checker: v4Checker,
- payloadLength: targetSize4 - replyHeaderLength4 - 1,
- linkMTU: targetSize4,
- replyLength: targetSize4 - 1,
- },
- {
- name: "IPv4 No payload",
- srcAddress: remoteIPv4Addr,
- injector: rxIPv4Bad,
- checker: v4Checker,
- payloadLength: 0,
- linkMTU: targetSize4,
- replyLength: replyHeaderLength4,
- },
- {
- name: "IPv6 exact match",
- srcAddress: remoteIPv6Addr,
- injector: rxIPv6Bad,
- checker: v6Checker,
- payloadLength: targetSize6 - replyHeaderLength6,
- linkMTU: targetSize6,
- replyLength: targetSize6,
- },
- {
- name: "IPv6 larger MTU",
- srcAddress: remoteIPv6Addr,
- injector: rxIPv6Bad,
- checker: v6Checker,
- payloadLength: targetSize6,
- linkMTU: targetSize6 + 400,
- replyLength: targetSize6,
- },
- // NB. No "smaller MTU" test here as less than 1280 is not permitted
- // in IPv6.
- {
- name: "IPv6 payload exceeds",
- srcAddress: remoteIPv6Addr,
- injector: rxIPv6Bad,
- checker: v6Checker,
- payloadLength: targetSize6,
- linkMTU: targetSize6,
- replyLength: targetSize6,
- },
- {
- name: "IPv6 1 byte less",
- srcAddress: remoteIPv6Addr,
- injector: rxIPv6Bad,
- checker: v6Checker,
- payloadLength: targetSize6 - replyHeaderLength6 - 1,
- linkMTU: targetSize6,
- replyLength: targetSize6 - 1,
- },
- {
- name: "IPv6 no payload",
- srcAddress: remoteIPv6Addr,
- injector: rxIPv6Bad,
- checker: v6Checker,
- payloadLength: 0,
- linkMTU: targetSize6,
- replyLength: replyHeaderLength6,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s, e := buildDummyStackWithLinkEndpoint(t, test.linkMTU)
- // Allocate and initialize the payload view.
- payload := buffer.NewView(test.payloadLength)
- for i := 0; i < len(payload); i++ {
- payload[i] = uint8(i)
- }
- // Default routes for IPv4&6 so ICMP can find a route to the remote
- // node when attempting to send the ICMP error Reply.
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv4EmptySubnet,
- NIC: nicID,
- },
- {
- Destination: header.IPv6EmptySubnet,
- NIC: nicID,
- },
- })
- v := test.injector(e, test.srcAddress, payload)
- pkt, ok := e.Read()
- if !ok {
- t.Fatal("expected a packet to be written")
- }
- if got, want := pkt.Pkt.Size(), test.replyLength; got != want {
- t.Fatalf("got %d bytes of icmp error packet, want %d", got, want)
- }
- test.checker(t, pkt.Pkt, v)
- })
- }
-}
-
-func TestJoinLeaveAllRoutersGroup(t *testing.T) {
- const nicID = 1
-
- tests := []struct {
- name string
- netProto tcpip.NetworkProtocolNumber
- protoFactory stack.NetworkProtocolFactory
- allRoutersAddr tcpip.Address
- }{
- {
- name: "IPv4",
- netProto: ipv4.ProtocolNumber,
- protoFactory: ipv4.NewProtocol,
- allRoutersAddr: header.IPv4AllRoutersGroup,
- },
- {
- name: "IPv6 Interface Local",
- netProto: ipv6.ProtocolNumber,
- protoFactory: ipv6.NewProtocol,
- allRoutersAddr: header.IPv6AllRoutersInterfaceLocalMulticastAddress,
- },
- {
- name: "IPv6 Link Local",
- netProto: ipv6.ProtocolNumber,
- protoFactory: ipv6.NewProtocol,
- allRoutersAddr: header.IPv6AllRoutersLinkLocalMulticastAddress,
- },
- {
- name: "IPv6 Site Local",
- netProto: ipv6.ProtocolNumber,
- protoFactory: ipv6.NewProtocol,
- allRoutersAddr: header.IPv6AllRoutersSiteLocalMulticastAddress,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- for _, nicDisabled := range [...]bool{true, false} {
- t.Run(fmt.Sprintf("NIC Disabled = %t", nicDisabled), func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
- })
- opts := stack.NICOptions{Disabled: nicDisabled}
- if err := s.CreateNICWithOptions(nicID, channel.New(0, 0, ""), opts); err != nil {
- t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, opts, err)
- }
-
- if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil {
- t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err)
- } else if got {
- t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr)
- }
-
- if err := s.SetForwardingDefaultAndAllNICs(test.netProto, true); err != nil {
- t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", test.netProto, err)
- }
- if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil {
- t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err)
- } else if !got {
- t.Fatalf("got s.IsInGroup(%d, %s) = false, want = true", nicID, test.allRoutersAddr)
- }
-
- if err := s.SetForwardingDefaultAndAllNICs(test.netProto, false); err != nil {
- t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, false): %s", test.netProto, err)
- }
- if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil {
- t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err)
- } else if got {
- t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr)
- }
- })
- }
- })
- }
-}
-
-func TestSetNICIDBeforeDeliveringToRawEndpoint(t *testing.T) {
- const nicID = 1
-
- tests := []struct {
- name string
- proto tcpip.NetworkProtocolNumber
- addr tcpip.AddressWithPrefix
- payloadOffset int
- }{
- {
- name: "IPv4",
- proto: header.IPv4ProtocolNumber,
- addr: localIPv4AddrWithPrefix,
- payloadOffset: header.IPv4MinimumSize,
- },
- {
- name: "IPv6",
- proto: header.IPv6ProtocolNumber,
- addr: localIPv6AddrWithPrefix,
- payloadOffset: 0,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{
- ipv4.NewProtocol,
- ipv6.NewProtocol,
- },
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- RawFactory: raw.EndpointFactory{},
- })
- if err := s.CreateNIC(nicID, loopback.New()); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: test.proto,
- AddressWithPrefix: test.addr,
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: test.addr.Subnet(),
- NIC: nicID,
- },
- })
-
- var wq waiter.Queue
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- ep, err := s.NewRawEndpoint(udp.ProtocolNumber, test.proto, &wq, true /* associated */)
- if err != nil {
- t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.proto, err)
- }
- defer ep.Close()
-
- writeOpts := tcpip.WriteOptions{
- To: &tcpip.FullAddress{
- Addr: test.addr.Address,
- },
- }
- data := []byte{1, 2, 3, 4}
- var r bytes.Reader
- r.Reset(data)
- if n, err := ep.Write(&r, writeOpts); err != nil {
- t.Fatalf("ep.Write(_, _): %s", err)
- } else if want := int64(len(data)); n != want {
- t.Fatalf("got ep.Write(_, _) = (%d, nil), want = (%d, nil)", n, want)
- }
-
- // Wait for the endpoint to become readable.
- <-ch
-
- var w bytes.Buffer
- rr, err := ep.Read(&w, tcpip.ReadOptions{
- NeedRemoteAddr: true,
- })
- if err != nil {
- t.Fatalf("ep.Read(...): %s", err)
- }
- if diff := cmp.Diff(data, w.Bytes()[test.payloadOffset:]); diff != "" {
- t.Errorf("payload mismatch (-want +got):\n%s", diff)
- }
- if diff := cmp.Diff(tcpip.FullAddress{Addr: test.addr.Address, NIC: nicID}, rr.RemoteAddr); diff != "" {
- t.Errorf("remote addr mismatch (-want +got):\n%s", diff)
- }
- })
- }
-}
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
deleted file mode 100644
index 2257f728e..000000000
--- a/pkg/tcpip/network/ipv4/BUILD
+++ /dev/null
@@ -1,67 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "ipv4",
- srcs = [
- "icmp.go",
- "igmp.go",
- "ipv4.go",
- "stats.go",
- ],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/header/parse",
- "//pkg/tcpip/network/hash",
- "//pkg/tcpip/network/internal/fragmentation",
- "//pkg/tcpip/network/internal/ip",
- "//pkg/tcpip/stack",
- ],
-)
-
-go_test(
- name = "ipv4_test",
- size = "small",
- srcs = [
- "igmp_test.go",
- "ipv4_test.go",
- ],
- deps = [
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/checker",
- "//pkg/tcpip/faketime",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/link/sniffer",
- "//pkg/tcpip/network/arp",
- "//pkg/tcpip/network/internal/testutil",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/testutil",
- "//pkg/tcpip/transport/icmp",
- "//pkg/tcpip/transport/raw",
- "//pkg/tcpip/transport/tcp",
- "//pkg/tcpip/transport/udp",
- "//pkg/waiter",
- "@com_github_google_go_cmp//cmp:go_default_library",
- ],
-)
-
-go_test(
- name = "stats_test",
- size = "small",
- srcs = ["stats_test.go"],
- library = ":ipv4",
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/testutil",
- ],
-)
diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go
deleted file mode 100644
index c6576fcbc..000000000
--- a/pkg/tcpip/network/ipv4/igmp_test.go
+++ /dev/null
@@ -1,401 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package ipv4_test
-
-import (
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
-)
-
-const (
- linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
- nicID = 1
- defaultTTL = 1
- defaultPrefixLength = 24
-)
-
-var (
- stackAddr = testutil.MustParse4("10.0.0.1")
- remoteAddr = testutil.MustParse4("10.0.0.2")
- multicastAddr = testutil.MustParse4("224.0.0.3")
-)
-
-// validateIgmpPacket checks that a passed PacketInfo is an IPv4 IGMP packet
-// sent to the provided address with the passed fields set. Raises a t.Error if
-// any field does not match.
-func validateIgmpPacket(t *testing.T, p channel.PacketInfo, igmpType header.IGMPType, maxRespTime byte, srcAddr, dstAddr, groupAddress tcpip.Address) {
- t.Helper()
-
- payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader()))
- checker.IPv4(t, payload,
- checker.SrcAddr(srcAddr),
- checker.DstAddr(dstAddr),
- // TTL for an IGMP message must be 1 as per RFC 2236 section 2.
- checker.TTL(1),
- checker.IPv4RouterAlert(),
- checker.IGMP(
- checker.IGMPType(igmpType),
- checker.IGMPMaxRespTime(header.DecisecondToDuration(maxRespTime)),
- checker.IGMPGroupAddress(groupAddress),
- ),
- )
-}
-
-func createStack(t *testing.T, igmpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) {
- t.Helper()
-
- // Create an endpoint of queue size 1, since no more than 1 packets are ever
- // queued in the tests in this file.
- e := channel.New(1, 1280, linkAddr)
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocolWithOptions(ipv4.Options{
- IGMP: ipv4.IGMPOptions{
- Enabled: igmpEnabled,
- },
- })},
- Clock: clock,
- })
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- return e, s, clock
-}
-
-func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, maxRespTime byte, ttl uint8, srcAddr, dstAddr, groupAddress tcpip.Address, hasRouterAlertOption bool) {
- var options header.IPv4OptionsSerializer
- if hasRouterAlertOption {
- options = header.IPv4OptionsSerializer{
- &header.IPv4SerializableRouterAlertOption{},
- }
- }
- buf := buffer.NewView(header.IPv4MinimumSize + int(options.Length()) + header.IGMPQueryMinimumSize)
-
- ip := header.IPv4(buf)
- ip.Encode(&header.IPv4Fields{
- TotalLength: uint16(len(buf)),
- TTL: ttl,
- Protocol: uint8(header.IGMPProtocolNumber),
- SrcAddr: srcAddr,
- DstAddr: dstAddr,
- Options: options,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
-
- igmp := header.IGMP(ip.Payload())
- igmp.SetType(igmpType)
- igmp.SetMaxRespTime(maxRespTime)
- igmp.SetGroupAddress(groupAddress)
- igmp.SetChecksum(header.IGMPCalculateChecksum(igmp))
-
- e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-}
-
-// TestIGMPV1Present tests the node's ability to fallback to V1 when a V1
-// router is detected. V1 present status is expected to be reset when the NIC
-// cycles.
-func TestIGMPV1Present(t *testing.T) {
- e, s, clock := createStack(t, true)
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{Address: stackAddr, PrefixLen: defaultPrefixLength},
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
-
- if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
- t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err)
- }
-
- // This NIC will send an IGMPv2 report immediately, before this test can get
- // the IGMPv1 General Membership Query in.
- {
- p, ok := e.Read()
- if !ok {
- t.Fatal("unable to Read IGMP packet, expected V2MembershipReport")
- }
- if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 {
- t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got)
- }
- validateIgmpPacket(t, p, header.IGMPv2MembershipReport, 0, stackAddr, multicastAddr, multicastAddr)
- }
- if t.Failed() {
- t.FailNow()
- }
-
- // Inject an IGMPv1 General Membership Query which is identical to a standard
- // membership query except the Max Response Time is set to 0, which will tell
- // the stack that this is a router using IGMPv1. Send it to the all systems
- // group which is the only group this host belongs to.
- createAndInjectIGMPPacket(e, header.IGMPMembershipQuery, 0, defaultTTL, remoteAddr, stackAddr, header.IPv4AllSystems, true /* hasRouterAlertOption */)
- if got := s.Stats().IGMP.PacketsReceived.MembershipQuery.Value(); got != 1 {
- t.Fatalf("got Membership Queries received = %d, want = 1", got)
- }
-
- // Before advancing the clock, verify that this host has not sent a
- // V1MembershipReport yet.
- if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 0 {
- t.Fatalf("got V1MembershipReport messages sent = %d, want = 0", got)
- }
-
- // Verify the solicited Membership Report is sent. Now that this NIC has seen
- // an IGMPv1 query, it should send an IGMPv1 Membership Report.
- if p, ok := e.Read(); ok {
- t.Fatalf("sent unexpected packet, expected V1MembershipReport only after advancing the clock = %+v", p.Pkt)
- }
- clock.Advance(ipv4.UnsolicitedReportIntervalMax)
- {
- p, ok := e.Read()
- if !ok {
- t.Fatal("unable to Read IGMP packet, expected V1MembershipReport")
- }
- if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 1 {
- t.Fatalf("got V1MembershipReport messages sent = %d, want = 1", got)
- }
- validateIgmpPacket(t, p, header.IGMPv1MembershipReport, 0, stackAddr, multicastAddr, multicastAddr)
- }
-
- // Cycling the interface should reset the V1 present flag.
- if err := s.DisableNIC(nicID); err != nil {
- t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
- }
- if err := s.EnableNIC(nicID); err != nil {
- t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
- }
- {
- p, ok := e.Read()
- if !ok {
- t.Fatal("unable to Read IGMP packet, expected V2MembershipReport")
- }
- if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 2 {
- t.Fatalf("got V2MembershipReport messages sent = %d, want = 2", got)
- }
- validateIgmpPacket(t, p, header.IGMPv2MembershipReport, 0, stackAddr, multicastAddr, multicastAddr)
- }
-}
-
-func TestSendQueuedIGMPReports(t *testing.T) {
- e, s, clock := createStack(t, true)
-
- // Joining a group without an assigned address should queue IGMP packets; none
- // should be sent without an assigned address.
- if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
- t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv4.ProtocolNumber, nicID, multicastAddr, err)
- }
- reportStat := s.Stats().IGMP.PacketsSent.V2MembershipReport
- if got := reportStat.Value(); got != 0 {
- t.Errorf("got reportStat.Value() = %d, want = 0", got)
- }
- clock.Advance(time.Hour)
- if p, ok := e.Read(); ok {
- t.Fatalf("got unexpected packet = %#v", p)
- }
-
- // The initial set of IGMP reports that were queued should be sent once an
- // address is assigned.
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: stackAddr,
- PrefixLen: defaultPrefixLength,
- },
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
- if got := reportStat.Value(); got != 1 {
- t.Errorf("got reportStat.Value() = %d, want = 1", got)
- }
- if p, ok := e.Read(); !ok {
- t.Error("expected to send an IGMP membership report")
- } else {
- validateIgmpPacket(t, p, header.IGMPv2MembershipReport, 0, stackAddr, multicastAddr, multicastAddr)
- }
- if t.Failed() {
- t.FailNow()
- }
- clock.Advance(ipv4.UnsolicitedReportIntervalMax)
- if got := reportStat.Value(); got != 2 {
- t.Errorf("got reportStat.Value() = %d, want = 2", got)
- }
- if p, ok := e.Read(); !ok {
- t.Error("expected to send an IGMP membership report")
- } else {
- validateIgmpPacket(t, p, header.IGMPv2MembershipReport, 0, stackAddr, multicastAddr, multicastAddr)
- }
- if t.Failed() {
- t.FailNow()
- }
-
- // Should have no more packets to send after the initial set of unsolicited
- // reports.
- clock.Advance(time.Hour)
- if p, ok := e.Read(); ok {
- t.Fatalf("got unexpected packet = %#v", p)
- }
-}
-
-func TestIGMPPacketValidation(t *testing.T) {
- tests := []struct {
- name string
- messageType header.IGMPType
- stackAddresses []tcpip.AddressWithPrefix
- srcAddr tcpip.Address
- includeRouterAlertOption bool
- ttl uint8
- expectValidIGMP bool
- getMessageTypeStatValue func(tcpip.Stats) uint64
- }{
- {
- name: "valid",
- messageType: header.IGMPLeaveGroup,
- includeRouterAlertOption: true,
- stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}},
- srcAddr: remoteAddr,
- ttl: 1,
- expectValidIGMP: true,
- getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.LeaveGroup.Value() },
- },
- {
- name: "bad ttl",
- messageType: header.IGMPv1MembershipReport,
- includeRouterAlertOption: true,
- stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}},
- srcAddr: remoteAddr,
- ttl: 2,
- expectValidIGMP: false,
- getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V1MembershipReport.Value() },
- },
- {
- name: "missing router alert ip option",
- messageType: header.IGMPv2MembershipReport,
- includeRouterAlertOption: false,
- stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}},
- srcAddr: remoteAddr,
- ttl: 1,
- expectValidIGMP: false,
- getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V2MembershipReport.Value() },
- },
- {
- name: "igmp leave group and src ip does not belong to nic subnet",
- messageType: header.IGMPLeaveGroup,
- includeRouterAlertOption: true,
- stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}},
- srcAddr: testutil.MustParse4("10.0.1.2"),
- ttl: 1,
- expectValidIGMP: false,
- getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.LeaveGroup.Value() },
- },
- {
- name: "igmp query and src ip does not belong to nic subnet",
- messageType: header.IGMPMembershipQuery,
- includeRouterAlertOption: true,
- stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}},
- srcAddr: testutil.MustParse4("10.0.1.2"),
- ttl: 1,
- expectValidIGMP: true,
- getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.MembershipQuery.Value() },
- },
- {
- name: "igmp report v1 and src ip does not belong to nic subnet",
- messageType: header.IGMPv1MembershipReport,
- includeRouterAlertOption: true,
- stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}},
- srcAddr: testutil.MustParse4("10.0.1.2"),
- ttl: 1,
- expectValidIGMP: false,
- getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V1MembershipReport.Value() },
- },
- {
- name: "igmp report v2 and src ip does not belong to nic subnet",
- messageType: header.IGMPv2MembershipReport,
- includeRouterAlertOption: true,
- stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}},
- srcAddr: testutil.MustParse4("10.0.1.2"),
- ttl: 1,
- expectValidIGMP: false,
- getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V2MembershipReport.Value() },
- },
- {
- name: "src ip belongs to the subnet of the nic's second address",
- messageType: header.IGMPv2MembershipReport,
- includeRouterAlertOption: true,
- stackAddresses: []tcpip.AddressWithPrefix{
- {Address: testutil.MustParse4("10.0.15.1"), PrefixLen: 24},
- {Address: stackAddr, PrefixLen: 24},
- },
- srcAddr: remoteAddr,
- ttl: 1,
- expectValidIGMP: true,
- getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V2MembershipReport.Value() },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- e, s, _ := createStack(t, true)
- for _, address := range test.stackAddresses {
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: address,
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
- }
- stats := s.Stats()
- // Verify that every relevant stats is zero'd before we send a packet.
- if got := test.getMessageTypeStatValue(s.Stats()); got != 0 {
- t.Errorf("got test.getMessageTypeStatValue(s.Stats()) = %d, want = 0", got)
- }
- if got := stats.IGMP.PacketsReceived.Invalid.Value(); got != 0 {
- t.Errorf("got stats.IGMP.PacketsReceived.Invalid.Value() = %d, want = 0", got)
- }
- if got := stats.IP.PacketsDelivered.Value(); got != 0 {
- t.Fatalf("got stats.IP.PacketsDelivered.Value() = %d, want = 0", got)
- }
- createAndInjectIGMPPacket(e, test.messageType, 0, test.ttl, test.srcAddr, header.IPv4AllSystems, header.IPv4AllSystems, test.includeRouterAlertOption)
- // We always expect the packet to pass IP validation.
- if got := stats.IP.PacketsDelivered.Value(); got != 1 {
- t.Fatalf("got stats.IP.PacketsDelivered.Value() = %d, want = 1", got)
- }
- // Even when the IGMP-specific validation checks fail, we expect the
- // corresponding IGMP counter to be incremented.
- if got := test.getMessageTypeStatValue(s.Stats()); got != 1 {
- t.Errorf("got test.getMessageTypeStatValue(s.Stats()) = %d, want = 1", got)
- }
- var expectedInvalidCount uint64
- if !test.expectValidIGMP {
- expectedInvalidCount = 1
- }
- if got := stats.IGMP.PacketsReceived.Invalid.Value(); got != expectedInvalidCount {
- t.Errorf("got stats.IGMP.PacketsReceived.Invalid.Value() = %d, want = %d", got, expectedInvalidCount)
- }
- })
- }
-}
diff --git a/pkg/tcpip/network/ipv4/ipv4_state_autogen.go b/pkg/tcpip/network/ipv4/ipv4_state_autogen.go
new file mode 100644
index 000000000..19b672251
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/ipv4_state_autogen.go
@@ -0,0 +1,113 @@
+// automatically generated by stateify.
+
+package ipv4
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+func (i *icmpv4DestinationUnreachableSockError) StateTypeName() string {
+ return "pkg/tcpip/network/ipv4.icmpv4DestinationUnreachableSockError"
+}
+
+func (i *icmpv4DestinationUnreachableSockError) StateFields() []string {
+ return []string{}
+}
+
+func (i *icmpv4DestinationUnreachableSockError) beforeSave() {}
+
+// +checklocksignore
+func (i *icmpv4DestinationUnreachableSockError) StateSave(stateSinkObject state.Sink) {
+ i.beforeSave()
+}
+
+func (i *icmpv4DestinationUnreachableSockError) afterLoad() {}
+
+// +checklocksignore
+func (i *icmpv4DestinationUnreachableSockError) StateLoad(stateSourceObject state.Source) {
+}
+
+func (i *icmpv4DestinationHostUnreachableSockError) StateTypeName() string {
+ return "pkg/tcpip/network/ipv4.icmpv4DestinationHostUnreachableSockError"
+}
+
+func (i *icmpv4DestinationHostUnreachableSockError) StateFields() []string {
+ return []string{
+ "icmpv4DestinationUnreachableSockError",
+ }
+}
+
+func (i *icmpv4DestinationHostUnreachableSockError) beforeSave() {}
+
+// +checklocksignore
+func (i *icmpv4DestinationHostUnreachableSockError) StateSave(stateSinkObject state.Sink) {
+ i.beforeSave()
+ stateSinkObject.Save(0, &i.icmpv4DestinationUnreachableSockError)
+}
+
+func (i *icmpv4DestinationHostUnreachableSockError) afterLoad() {}
+
+// +checklocksignore
+func (i *icmpv4DestinationHostUnreachableSockError) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &i.icmpv4DestinationUnreachableSockError)
+}
+
+func (i *icmpv4DestinationPortUnreachableSockError) StateTypeName() string {
+ return "pkg/tcpip/network/ipv4.icmpv4DestinationPortUnreachableSockError"
+}
+
+func (i *icmpv4DestinationPortUnreachableSockError) StateFields() []string {
+ return []string{
+ "icmpv4DestinationUnreachableSockError",
+ }
+}
+
+func (i *icmpv4DestinationPortUnreachableSockError) beforeSave() {}
+
+// +checklocksignore
+func (i *icmpv4DestinationPortUnreachableSockError) StateSave(stateSinkObject state.Sink) {
+ i.beforeSave()
+ stateSinkObject.Save(0, &i.icmpv4DestinationUnreachableSockError)
+}
+
+func (i *icmpv4DestinationPortUnreachableSockError) afterLoad() {}
+
+// +checklocksignore
+func (i *icmpv4DestinationPortUnreachableSockError) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &i.icmpv4DestinationUnreachableSockError)
+}
+
+func (e *icmpv4FragmentationNeededSockError) StateTypeName() string {
+ return "pkg/tcpip/network/ipv4.icmpv4FragmentationNeededSockError"
+}
+
+func (e *icmpv4FragmentationNeededSockError) StateFields() []string {
+ return []string{
+ "icmpv4DestinationUnreachableSockError",
+ "mtu",
+ }
+}
+
+func (e *icmpv4FragmentationNeededSockError) beforeSave() {}
+
+// +checklocksignore
+func (e *icmpv4FragmentationNeededSockError) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+ stateSinkObject.Save(0, &e.icmpv4DestinationUnreachableSockError)
+ stateSinkObject.Save(1, &e.mtu)
+}
+
+func (e *icmpv4FragmentationNeededSockError) afterLoad() {}
+
+// +checklocksignore
+func (e *icmpv4FragmentationNeededSockError) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &e.icmpv4DestinationUnreachableSockError)
+ stateSourceObject.Load(1, &e.mtu)
+}
+
+func init() {
+ state.Register((*icmpv4DestinationUnreachableSockError)(nil))
+ state.Register((*icmpv4DestinationHostUnreachableSockError)(nil))
+ state.Register((*icmpv4DestinationPortUnreachableSockError)(nil))
+ state.Register((*icmpv4FragmentationNeededSockError)(nil))
+}
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
deleted file mode 100644
index e7b5b3ea2..000000000
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ /dev/null
@@ -1,3375 +0,0 @@
-// Copyright 2021 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package ipv4_test
-
-import (
- "bytes"
- "encoding/hex"
- "fmt"
- "io/ioutil"
- "math"
- "net"
- "testing"
- "time"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
- "gvisor.dev/gvisor/pkg/tcpip/network/arp"
- iptestutil "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/raw"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-const (
- extraHeaderReserve = 50
- defaultMTU = 65536
-)
-
-func TestExcludeBroadcast(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- })
-
- ep := stack.LinkEndpoint(channel.New(256, defaultMTU, ""))
- if testing.Verbose() {
- ep = sniffer.New(ep)
- }
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
-
- s.SetRouteTable([]tcpip.Route{{
- Destination: header.IPv4EmptySubnet,
- NIC: 1,
- }})
-
- randomAddr := tcpip.FullAddress{NIC: 1, Addr: "\x0a\x00\x00\x01", Port: 53}
-
- var wq waiter.Queue
- t.Run("WithoutPrimaryAddress", func(t *testing.T) {
- ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
- if err != nil {
- t.Fatal(err)
- }
- defer ep.Close()
-
- // Cannot connect using a broadcast address as the source.
- {
- err := ep.Connect(randomAddr)
- if _, ok := err.(*tcpip.ErrNoRoute); !ok {
- t.Errorf("got ep.Connect(...) = %v, want = %v", err, &tcpip.ErrNoRoute{})
- }
- }
-
- // However, we can bind to a broadcast address to listen.
- if err := ep.Bind(tcpip.FullAddress{Addr: header.IPv4Broadcast, Port: 53, NIC: 1}); err != nil {
- t.Errorf("Bind failed: %v", err)
- }
- })
-
- t.Run("WithPrimaryAddress", func(t *testing.T) {
- ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
- if err != nil {
- t.Fatal(err)
- }
- defer ep.Close()
-
- // Add a valid primary endpoint address, now we can connect.
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: tcpip.Address("\x0a\x00\x00\x02").WithPrefix(),
- }
- if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
- }
- if err := ep.Connect(randomAddr); err != nil {
- t.Errorf("Connect failed: %v", err)
- }
- })
-}
-
-func TestForwarding(t *testing.T) {
- const (
- incomingNICID = 1
- outgoingNICID = 2
- randomSequence = 123
- randomIdent = 42
- randomTimeOffset = 0x10203040
- )
-
- incomingIPv4Addr := tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()),
- PrefixLen: 8,
- }
- outgoingIPv4Addr := tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("11.0.0.1").To4()),
- PrefixLen: 8,
- }
- outgoingLinkAddr := tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
- remoteIPv4Addr1 := testutil.MustParse4("10.0.0.2")
- remoteIPv4Addr2 := testutil.MustParse4("11.0.0.2")
- unreachableIPv4Addr := testutil.MustParse4("12.0.0.2")
- multicastIPv4Addr := testutil.MustParse4("225.0.0.0")
- linkLocalIPv4Addr := testutil.MustParse4("169.254.0.0")
-
- tests := []struct {
- name string
- TTL uint8
- sourceAddr tcpip.Address
- destAddr tcpip.Address
- expectErrorICMP bool
- ipFlags uint8
- mtu uint32
- payloadLength int
- options header.IPv4Options
- forwardedOptions header.IPv4Options
- icmpType header.ICMPv4Type
- icmpCode header.ICMPv4Code
- expectPacketUnrouteableError bool
- expectLinkLocalSourceError bool
- expectLinkLocalDestError bool
- expectPacketForwarded bool
- expectedFragmentsForwarded []fragmentInfo
- }{
- {
- name: "TTL of zero",
- TTL: 0,
- sourceAddr: remoteIPv4Addr1,
- destAddr: remoteIPv4Addr2,
- expectErrorICMP: true,
- icmpType: header.ICMPv4TimeExceeded,
- icmpCode: header.ICMPv4TTLExceeded,
- mtu: ipv4.MaxTotalSize,
- },
- {
- name: "TTL of one",
- TTL: 1,
- sourceAddr: remoteIPv4Addr1,
- destAddr: remoteIPv4Addr2,
- expectPacketForwarded: true,
- mtu: ipv4.MaxTotalSize,
- },
- {
- name: "TTL of two",
- TTL: 2,
- sourceAddr: remoteIPv4Addr1,
- destAddr: remoteIPv4Addr2,
- expectPacketForwarded: true,
- mtu: ipv4.MaxTotalSize,
- },
- {
- name: "Max TTL",
- TTL: math.MaxUint8,
- sourceAddr: remoteIPv4Addr1,
- destAddr: remoteIPv4Addr2,
- expectPacketForwarded: true,
- mtu: ipv4.MaxTotalSize,
- },
- {
- name: "four EOL options",
- TTL: 2,
- sourceAddr: remoteIPv4Addr1,
- destAddr: remoteIPv4Addr2,
- expectPacketForwarded: true,
- mtu: ipv4.MaxTotalSize,
- options: header.IPv4Options{0, 0, 0, 0},
- forwardedOptions: header.IPv4Options{0, 0, 0, 0},
- },
- {
- name: "TS type 1 full",
- TTL: 2,
- sourceAddr: remoteIPv4Addr1,
- destAddr: remoteIPv4Addr2,
- mtu: ipv4.MaxTotalSize,
- options: header.IPv4Options{
- 68, 12, 13, 0xF1,
- 192, 168, 1, 12,
- 1, 2, 3, 4,
- },
- expectErrorICMP: true,
- icmpType: header.ICMPv4ParamProblem,
- icmpCode: header.ICMPv4UnusedCode,
- },
- {
- name: "TS type 0",
- TTL: 2,
- sourceAddr: remoteIPv4Addr1,
- destAddr: remoteIPv4Addr2,
- mtu: ipv4.MaxTotalSize,
- options: header.IPv4Options{
- 68, 24, 21, 0x00,
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- 0, 0, 0, 0,
- },
- forwardedOptions: header.IPv4Options{
- 68, 24, 25, 0x00,
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock
- },
- expectPacketForwarded: true,
- },
- {
- name: "end of options list",
- TTL: 2,
- sourceAddr: remoteIPv4Addr1,
- destAddr: remoteIPv4Addr2,
- mtu: ipv4.MaxTotalSize,
- options: header.IPv4Options{
- 68, 12, 13, 0x11,
- 192, 168, 1, 12,
- 1, 2, 3, 4,
- 0, 10, 3, 99, // EOL followed by junk
- 1, 2, 3, 4,
- },
- forwardedOptions: header.IPv4Options{
- 68, 12, 13, 0x21,
- 192, 168, 1, 12,
- 1, 2, 3, 4,
- 0, // End of Options hides following bytes.
- 0, 0, 0, // 7 bytes unknown option removed.
- 0, 0, 0, 0,
- },
- expectPacketForwarded: true,
- },
- {
- name: "Network unreachable",
- TTL: 2,
- sourceAddr: remoteIPv4Addr1,
- destAddr: unreachableIPv4Addr,
- expectErrorICMP: true,
- mtu: ipv4.MaxTotalSize,
- icmpType: header.ICMPv4DstUnreachable,
- icmpCode: header.ICMPv4NetUnreachable,
- expectPacketUnrouteableError: true,
- },
- {
- name: "Multicast destination",
- TTL: 2,
- destAddr: multicastIPv4Addr,
- expectPacketUnrouteableError: true,
- },
- {
- name: "Link local destination",
- TTL: 2,
- sourceAddr: remoteIPv4Addr1,
- destAddr: linkLocalIPv4Addr,
- expectLinkLocalDestError: true,
- },
- {
- name: "Link local source",
- TTL: 2,
- sourceAddr: linkLocalIPv4Addr,
- destAddr: remoteIPv4Addr2,
- expectLinkLocalSourceError: true,
- },
- {
- name: "Fragmentation needed and DF set",
- TTL: 2,
- sourceAddr: remoteIPv4Addr1,
- destAddr: remoteIPv4Addr2,
- ipFlags: header.IPv4FlagDontFragment,
- // We've picked this MTU because it is:
- //
- // 1) Greater than the minimum MTU that IPv4 hosts are required to process
- // (576 bytes). As per RFC 1812, Section 4.3.2.3:
- //
- // The ICMP datagram SHOULD contain as much of the original datagram as
- // possible without the length of the ICMP datagram exceeding 576 bytes.
- //
- // Therefore, setting an MTU greater than 576 bytes ensures that we can fit a
- // complete ICMP packet on the incoming endpoint (and make assertions about
- // it).
- //
- // 2) Less than `ipv4.MaxTotalSize`, which lets us build an IPv4 packet whose
- // size exceeds the MTU.
- mtu: 1000,
- payloadLength: 1004,
- expectErrorICMP: true,
- icmpType: header.ICMPv4DstUnreachable,
- icmpCode: header.ICMPv4FragmentationNeeded,
- },
- {
- name: "Fragmentation needed and DF not set",
- TTL: 2,
- sourceAddr: remoteIPv4Addr1,
- destAddr: remoteIPv4Addr2,
- mtu: 1000,
- payloadLength: 1004,
- expectPacketForwarded: true,
- // Combined, these fragments have length of 1012 octets, which is equal to
- // the length of the payload (1004 octets), plus the length of the ICMP
- // header (8 octets).
- expectedFragmentsForwarded: []fragmentInfo{
- // The first fragment has a length of the greatest multiple of 8 which is
- // less than or equal to to `mtu - header.IPv4MinimumSize`.
- {offset: 0, payloadSize: uint16(976), more: true},
- // The next fragment holds the rest of the packet.
- {offset: uint16(976), payloadSize: 36, more: false},
- },
- },
- }
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- clock := faketime.NewManualClock()
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
- Clock: clock,
- })
-
- // Advance the clock by some unimportant amount to make
- // it give a more recognisable signature than 00,00,00,00.
- clock.Advance(time.Millisecond * randomTimeOffset)
-
- // We expect at most a single packet in response to our ICMP Echo Request.
- incomingEndpoint := channel.New(1, test.mtu, "")
- if err := s.CreateNIC(incomingNICID, incomingEndpoint); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err)
- }
- incomingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: incomingIPv4Addr}
- if err := s.AddProtocolAddress(incomingNICID, incomingIPv4ProtoAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", incomingNICID, incomingIPv4ProtoAddr, err)
- }
-
- expectedEmittedPacketCount := 1
- if len(test.expectedFragmentsForwarded) > expectedEmittedPacketCount {
- expectedEmittedPacketCount = len(test.expectedFragmentsForwarded)
- }
- outgoingEndpoint := channel.New(expectedEmittedPacketCount, test.mtu, outgoingLinkAddr)
- if err := s.CreateNIC(outgoingNICID, outgoingEndpoint); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err)
- }
- outgoingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: outgoingIPv4Addr}
- if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv4ProtoAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", outgoingNICID, outgoingIPv4ProtoAddr, err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: incomingIPv4Addr.Subnet(),
- NIC: incomingNICID,
- },
- {
- Destination: outgoingIPv4Addr.Subnet(),
- NIC: outgoingNICID,
- },
- })
-
- if err := s.SetForwardingDefaultAndAllNICs(header.IPv4ProtocolNumber, true); err != nil {
- t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", header.IPv4ProtocolNumber, err)
- }
-
- ipHeaderLength := header.IPv4MinimumSize + len(test.options)
- if ipHeaderLength > header.IPv4MaximumHeaderSize {
- t.Fatalf("got ipHeaderLength = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize)
- }
- icmpHeaderLength := header.ICMPv4MinimumSize
- totalLength := ipHeaderLength + icmpHeaderLength + test.payloadLength
- hdr := buffer.NewPrependable(totalLength)
- hdr.Prepend(test.payloadLength)
- icmpH := header.ICMPv4(hdr.Prepend(icmpHeaderLength))
- icmpH.SetIdent(randomIdent)
- icmpH.SetSequence(randomSequence)
- icmpH.SetType(header.ICMPv4Echo)
- icmpH.SetCode(header.ICMPv4UnusedCode)
- icmpH.SetChecksum(0)
- icmpH.SetChecksum(^header.Checksum(icmpH, 0))
- ip := header.IPv4(hdr.Prepend(ipHeaderLength))
- ip.Encode(&header.IPv4Fields{
- TotalLength: uint16(totalLength),
- Protocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: test.TTL,
- SrcAddr: test.sourceAddr,
- DstAddr: test.destAddr,
- Flags: test.ipFlags,
- })
- if len(test.options) != 0 {
- ip.SetHeaderLength(uint8(ipHeaderLength))
- // Copy options manually. We do not use Encode for options so we can
- // verify malformed options with handcrafted payloads.
- if want, got := copy(ip.Options(), test.options), len(test.options); want != got {
- t.Fatalf("got copy(ip.Options(), test.options) = %d, want = %d", got, want)
- }
- }
- ip.SetChecksum(0)
- ip.SetChecksum(^ip.CalculateChecksum())
- requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- })
- requestPkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
- incomingEndpoint.InjectInbound(header.IPv4ProtocolNumber, requestPkt)
-
- reply, ok := incomingEndpoint.Read()
-
- if test.expectErrorICMP {
- if !ok {
- t.Fatalf("expected ICMP packet type %d through incoming NIC", test.icmpType)
- }
-
- // We expect the ICMP packet to contain as much of the original packet as
- // possible up to a limit of 576 bytes, split between payload, IP header,
- // and ICMP header.
- expectedICMPPayloadLength := func() int {
- maxICMPPacketLength := header.IPv4MinimumProcessableDatagramSize
- maxICMPPayloadLength := maxICMPPacketLength - icmpHeaderLength - ipHeaderLength
- if len(hdr.View()) > maxICMPPayloadLength {
- return maxICMPPayloadLength
- }
- return len(hdr.View())
- }
-
- checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()),
- checker.SrcAddr(incomingIPv4Addr.Address),
- checker.DstAddr(test.sourceAddr),
- checker.TTL(ipv4.DefaultTTL),
- checker.ICMPv4(
- checker.ICMPv4Checksum(),
- checker.ICMPv4Type(test.icmpType),
- checker.ICMPv4Code(test.icmpCode),
- checker.ICMPv4Payload(hdr.View()[:expectedICMPPayloadLength()]),
- ),
- )
- } else if ok {
- t.Fatalf("expected no ICMP packet through incoming NIC, instead found: %#v", reply)
- }
-
- if test.expectPacketForwarded {
- if len(test.expectedFragmentsForwarded) != 0 {
- var fragmentedPackets []*stack.PacketBuffer
- for i := 0; i < len(test.expectedFragmentsForwarded); i++ {
- reply, ok = outgoingEndpoint.Read()
- if !ok {
- t.Fatal("expected ICMP Echo fragment through outgoing NIC")
- }
- fragmentedPackets = append(fragmentedPackets, reply.Pkt)
- }
-
- // The forwarded packet's TTL will have been decremented.
- ipHeader := header.IPv4(requestPkt.NetworkHeader().View())
- ipHeader.SetTTL(ipHeader.TTL() - 1)
-
- // Forwarded packets have available header bytes equalling the sum of the
- // maximum IP header size and the maximum size allocated for link layer
- // headers. In this case, no size is allocated for link layer headers.
- expectedAvailableHeaderBytes := header.IPv4MaximumHeaderSize
- if err := compareFragments(fragmentedPackets, requestPkt, test.mtu, test.expectedFragmentsForwarded, header.ICMPv4ProtocolNumber, true /* withIPHeader */, expectedAvailableHeaderBytes); err != nil {
- t.Error(err)
- }
- } else {
- reply, ok = outgoingEndpoint.Read()
- if !ok {
- t.Fatal("expected ICMP Echo packet through outgoing NIC")
- }
-
- checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()),
- checker.SrcAddr(test.sourceAddr),
- checker.DstAddr(test.destAddr),
- checker.TTL(test.TTL-1),
- checker.IPv4Options(test.forwardedOptions),
- checker.ICMPv4(
- checker.ICMPv4Checksum(),
- checker.ICMPv4Type(header.ICMPv4Echo),
- checker.ICMPv4Code(header.ICMPv4UnusedCode),
- checker.ICMPv4Payload(nil),
- ),
- )
- }
- } else {
- if reply, ok = outgoingEndpoint.Read(); ok {
- t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", reply)
- }
- }
- boolToInt := func(val bool) uint64 {
- if val {
- return 1
- }
- return 0
- }
-
- if got, want := s.Stats().IP.Forwarding.LinkLocalSource.Value(), boolToInt(test.expectLinkLocalSourceError); got != want {
- t.Errorf("got s.Stats().IP.Forwarding.LinkLocalSource.Value() = %d, want = %d", got, want)
- }
-
- if got, want := s.Stats().IP.Forwarding.LinkLocalDestination.Value(), boolToInt(test.expectLinkLocalDestError); got != want {
- t.Errorf("got s.Stats().IP.Forwarding.LinkLocalDestination.Value() = %d, want = %d", got, want)
- }
-
- if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), boolToInt(test.icmpType == header.ICMPv4ParamProblem); got != want {
- t.Errorf("got s.Stats().IP.MalformedPacketsReceived.Value() = %d, want = %d", got, want)
- }
-
- if got, want := s.Stats().IP.Forwarding.ExhaustedTTL.Value(), boolToInt(test.TTL <= 0); got != want {
- t.Errorf("got s.Stats().IP.Forwarding.ExhaustedTTL.Value() = %d, want = %d", got, want)
- }
-
- if got, want := s.Stats().IP.Forwarding.Unrouteable.Value(), boolToInt(test.expectPacketUnrouteableError); got != want {
- t.Errorf("got s.Stats().IP.Forwarding.Unrouteable.Value() = %d, want = %d", got, want)
- }
-
- if got, want := s.Stats().IP.Forwarding.Errors.Value(), boolToInt(!test.expectPacketForwarded); got != want {
- t.Errorf("got s.Stats().IP.Forwarding.Errors.Value() = %d, want = %d", got, want)
- }
-
- if got, want := s.Stats().IP.Forwarding.PacketTooBig.Value(), boolToInt(test.icmpCode == header.ICMPv4FragmentationNeeded); got != want {
- t.Errorf("got s.Stats().IP.Forwarding.PacketTooBig.Value() = %d, want = %d", got, want)
- }
- })
- }
-}
-
-// TestIPv4Sanity sends IP/ICMP packets with various problems to the stack and
-// checks the response.
-func TestIPv4Sanity(t *testing.T) {
- const (
- ttl = 255
- nicID = 1
- randomSequence = 123
- randomIdent = 42
- // In some cases Linux sets the error pointer to the start of the option
- // (offset 0) instead of the actual wrong value, which is the length byte
- // (offset 1). For compatibility we must do the same. Use this constant
- // to indicate where this happens.
- pointerOffsetForInvalidLength = 0
- randomTimeOffset = 0x10203040
- )
- var (
- ipv4Addr = tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()),
- PrefixLen: 24,
- }
- remoteIPv4Addr = tcpip.Address(net.ParseIP("10.0.0.1").To4())
- )
-
- tests := []struct {
- name string
- headerLength uint8 // value of 0 means "use correct size"
- badHeaderChecksum bool
- maxTotalLength uint16
- transportProtocol uint8
- TTL uint8
- options header.IPv4Options
- replyOptions header.IPv4Options // reply should look like this
- shouldFail bool
- expectErrorICMP bool
- ICMPType header.ICMPv4Type
- ICMPCode header.ICMPv4Code
- paramProblemPointer uint8
- }{
- {
- name: "valid no options",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- },
- {
- name: "bad header checksum",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- badHeaderChecksum: true,
- shouldFail: true,
- },
- // The TTL tests check that we are not rejecting an incoming packet
- // with a zero or one TTL, which has been a point of confusion in the
- // past as RFC 791 says: "If this field contains the value zero, then the
- // datagram must be destroyed". However RFC 1122 section 3.2.1.7 clarifies
- // for the case of the destination host, stating as follows.
- //
- // A host MUST NOT send a datagram with a Time-to-Live (TTL)
- // value of zero.
- //
- // A host MUST NOT discard a datagram just because it was
- // received with TTL less than 2.
- {
- name: "zero TTL",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: 0,
- },
- {
- name: "one TTL",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: 1,
- },
- {
- name: "End options",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{0, 0, 0, 0},
- replyOptions: header.IPv4Options{0, 0, 0, 0},
- },
- {
- name: "NOP options",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{1, 1, 1, 1},
- replyOptions: header.IPv4Options{1, 1, 1, 1},
- },
- {
- name: "NOP and End options",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{1, 1, 0, 0},
- replyOptions: header.IPv4Options{1, 1, 0, 0},
- },
- {
- name: "bad header length",
- headerLength: header.IPv4MinimumSize - 1,
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- shouldFail: true,
- },
- {
- name: "bad total length (0)",
- maxTotalLength: 0,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- shouldFail: true,
- },
- {
- name: "bad total length (ip - 1)",
- maxTotalLength: uint16(header.IPv4MinimumSize - 1),
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- shouldFail: true,
- },
- {
- name: "bad total length (ip + icmp - 1)",
- maxTotalLength: uint16(header.IPv4MinimumSize + header.ICMPv4MinimumSize - 1),
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- shouldFail: true,
- },
- {
- name: "bad protocol",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: 99,
- TTL: ttl,
- shouldFail: true,
- expectErrorICMP: true,
- ICMPType: header.ICMPv4DstUnreachable,
- ICMPCode: header.ICMPv4ProtoUnreachable,
- },
- {
- name: "timestamp option overflow",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 68, 12, 13, 0x11,
- 192, 168, 1, 12,
- 1, 2, 3, 4,
- },
- replyOptions: header.IPv4Options{
- 68, 12, 13, 0x21,
- 192, 168, 1, 12,
- 1, 2, 3, 4,
- },
- },
- {
- name: "timestamp option overflow full",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 68, 12, 13, 0xF1,
- // ^ Counter full (15/0xF)
- 192, 168, 1, 12,
- 1, 2, 3, 4,
- },
- shouldFail: true,
- expectErrorICMP: true,
- ICMPType: header.ICMPv4ParamProblem,
- ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + 3,
- replyOptions: header.IPv4Options{},
- },
- {
- name: "unknown option",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{10, 4, 9, 0},
- // ^^
- // The unknown option should be stripped out of the reply.
- replyOptions: header.IPv4Options{},
- },
- {
- name: "bad option - no length",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 1, 1, 1, 68,
- // ^-start of timestamp.. but no length..
- },
- shouldFail: true,
- expectErrorICMP: true,
- ICMPType: header.ICMPv4ParamProblem,
- ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + 3,
- },
- {
- name: "bad option - length 0",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 68, 0, 9, 0,
- // ^
- 1, 2, 3, 4,
- },
- shouldFail: true,
- expectErrorICMP: true,
- ICMPType: header.ICMPv4ParamProblem,
- ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength,
- },
- {
- name: "bad option - length 1",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 68, 1, 9, 0,
- // ^
- 1, 2, 3, 4,
- },
- shouldFail: true,
- expectErrorICMP: true,
- ICMPType: header.ICMPv4ParamProblem,
- ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength,
- },
- {
- name: "bad option - length big",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 68, 9, 9, 0,
- // ^
- // There are only 8 bytes allocated to options so 9 bytes of timestamp
- // space is not possible. (Second byte)
- 1, 2, 3, 4,
- },
- shouldFail: true,
- expectErrorICMP: true,
- ICMPType: header.ICMPv4ParamProblem,
- ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength,
- },
- {
- // This tests for some linux compatible behaviour.
- // The ICMP pointer returned is 22 for Linux but the
- // error is actually in spot 21.
- name: "bad option - length bad",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- // Timestamps are in multiples of 4 or 8 but never 7.
- // The option space should be padded out.
- options: header.IPv4Options{
- 68, 7, 5, 0,
- // ^ ^ Linux points here which is wrong.
- // | Not a multiple of 4
- 1, 2, 3, 0,
- },
- shouldFail: true,
- expectErrorICMP: true,
- ICMPType: header.ICMPv4ParamProblem,
- ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptTSPointerOffset,
- },
- {
- name: "multiple type 0 with room",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 68, 24, 21, 0x00,
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- 0, 0, 0, 0,
- },
- replyOptions: header.IPv4Options{
- 68, 24, 25, 0x00,
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock
- },
- },
- {
- // The timestamp area is full so add to the overflow count.
- name: "multiple type 1 timestamps",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 68, 20, 21, 0x11,
- // ^
- 192, 168, 1, 12,
- 1, 2, 3, 4,
- 192, 168, 1, 13,
- 5, 6, 7, 8,
- },
- // Overflow count is the top nibble of the 4th byte.
- replyOptions: header.IPv4Options{
- 68, 20, 21, 0x21,
- // ^
- 192, 168, 1, 12,
- 1, 2, 3, 4,
- 192, 168, 1, 13,
- 5, 6, 7, 8,
- },
- },
- {
- name: "multiple type 1 timestamps with room",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 68, 28, 21, 0x01,
- 192, 168, 1, 12,
- 1, 2, 3, 4,
- 192, 168, 1, 13,
- 5, 6, 7, 8,
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- },
- replyOptions: header.IPv4Options{
- 68, 28, 29, 0x01,
- 192, 168, 1, 12,
- 1, 2, 3, 4,
- 192, 168, 1, 13,
- 5, 6, 7, 8,
- 192, 168, 1, 58, // New IP Address.
- 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock
- },
- },
- {
- // Timestamp pointer uses one based counting so 0 is invalid.
- name: "timestamp pointer invalid",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 68, 8, 0, 0x00,
- // ^ 0 instead of 5 or more.
- 0, 0, 0, 0,
- },
- shouldFail: true,
- expectErrorICMP: true,
- ICMPType: header.ICMPv4ParamProblem,
- ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + 2,
- },
- {
- // Timestamp pointer cannot be less than 5. It must point past the header
- // which is 4 bytes. (1 based counting)
- name: "timestamp pointer too small by 1",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 68, 8, header.IPv4OptionTimestampHdrLength, 0x00,
- // ^ header is 4 bytes, so 4 should fail.
- 0, 0, 0, 0,
- },
- shouldFail: true,
- expectErrorICMP: true,
- ICMPType: header.ICMPv4ParamProblem,
- ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptTSPointerOffset,
- },
- {
- name: "valid timestamp pointer",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 68, 8, header.IPv4OptionTimestampHdrLength + 1, 0x00,
- // ^ header is 4 bytes, so 5 should succeed.
- 0, 0, 0, 0,
- },
- replyOptions: header.IPv4Options{
- 68, 8, 9, 0x00,
- 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock
- },
- },
- {
- // Needs 8 bytes for a type 1 timestamp but there are only 4 free.
- name: "bad timer element alignment",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 68, 20, 17, 0x01,
- // ^^ ^^ 20 byte area, next free spot at 17.
- 192, 168, 1, 12,
- 1, 2, 3, 4,
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- },
- shouldFail: true,
- expectErrorICMP: true,
- ICMPType: header.ICMPv4ParamProblem,
- ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptTSPointerOffset,
- },
- // End of option list with illegal option after it, which should be ignored.
- {
- name: "end of options list",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 68, 12, 13, 0x11,
- 192, 168, 1, 12,
- 1, 2, 3, 4,
- 0, 10, 3, 99, // EOL followed by junk
- },
- replyOptions: header.IPv4Options{
- 68, 12, 13, 0x21,
- 192, 168, 1, 12,
- 1, 2, 3, 4,
- 0, // End of Options hides following bytes.
- 0, 0, 0, // 3 bytes unknown option removed.
- },
- },
- {
- // Timestamp with a size much too small.
- name: "timestamp truncated",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 68, 1, 0, 0,
- // ^ Smallest possible is 8. Linux points at the 68.
- },
- shouldFail: true,
- expectErrorICMP: true,
- ICMPType: header.ICMPv4ParamProblem,
- ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength,
- },
- {
- name: "single record route with room",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 7, 7, 4, // 3 byte header
- 0, 0, 0, 0,
- 0,
- },
- replyOptions: header.IPv4Options{
- 7, 7, 8, // 3 byte header
- 192, 168, 1, 58, // New IP Address.
- 0, // padding to multiple of 4 bytes.
- },
- },
- {
- name: "multiple record route with room",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 7, 23, 20, // 3 byte header
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- 0, 0, 0, 0,
- 0,
- },
- replyOptions: header.IPv4Options{
- 7, 23, 24,
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- 192, 168, 1, 58, // New IP Address.
- 0, // padding to multiple of 4 bytes.
- },
- },
- {
- name: "single record route with no room",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 7, 7, 8, // 3 byte header
- 1, 2, 3, 4,
- 0,
- },
- replyOptions: header.IPv4Options{
- 7, 7, 8, // 3 byte header
- 1, 2, 3, 4,
- 0, // padding to multiple of 4 bytes.
- },
- },
- {
- // Unlike timestamp, this should just succeed.
- name: "multiple record route with no room",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 7, 23, 24, // 3 byte header
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- 17, 18, 19, 20,
- 0,
- },
- replyOptions: header.IPv4Options{
- 7, 23, 24,
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- 17, 18, 19, 20,
- 0, // padding to multiple of 4 bytes.
- },
- },
- {
- // Pointer uses one based counting so 0 is invalid.
- name: "record route pointer zero",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 7, 8, 0, // 3 byte header
- 0, 0, 0, 0,
- 0,
- },
- shouldFail: true,
- expectErrorICMP: true,
- ICMPType: header.ICMPv4ParamProblem,
- ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptRRPointerOffset,
- },
- {
- // Pointer must be 4 or more as it must point past the 3 byte header
- // using 1 based counting. 3 should fail.
- name: "record route pointer too small by 1",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 7, 8, header.IPv4OptionRecordRouteHdrLength, // 3 byte header
- 0, 0, 0, 0,
- 0,
- },
- shouldFail: true,
- expectErrorICMP: true,
- ICMPType: header.ICMPv4ParamProblem,
- ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptRRPointerOffset,
- },
- {
- // Pointer must be 4 or more as it must point past the 3 byte header
- // using 1 based counting. Check 4 passes. (Duplicates "single
- // record route with room")
- name: "valid record route pointer",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 7, 7, header.IPv4OptionRecordRouteHdrLength + 1, // 3 byte header
- 0, 0, 0, 0,
- 0,
- },
- replyOptions: header.IPv4Options{
- 7, 7, 8, // 3 byte header
- 192, 168, 1, 58, // New IP Address.
- 0, // padding to multiple of 4 bytes.
- },
- },
- {
- // Confirm Linux bug for bug compatibility.
- // Linux returns slot 22 but the error is in slot 21.
- name: "multiple record route with not enough room",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 7, 8, 8, // 3 byte header
- // ^ ^ Linux points here. We must too.
- // | Not enough room. 1 byte free, need 4.
- 1, 2, 3, 4,
- 0,
- },
- shouldFail: true,
- expectErrorICMP: true,
- ICMPType: header.ICMPv4ParamProblem,
- ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptRRPointerOffset,
- },
- {
- name: "duplicate record route",
- maxTotalLength: ipv4.MaxTotalSize,
- transportProtocol: uint8(header.ICMPv4ProtocolNumber),
- TTL: ttl,
- options: header.IPv4Options{
- 7, 7, 8, // 3 byte header
- 1, 2, 3, 4,
- 7, 7, 8, // 3 byte header
- 1, 2, 3, 4,
- 0, 0, // pad
- },
- shouldFail: true,
- expectErrorICMP: true,
- ICMPType: header.ICMPv4ParamProblem,
- ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + 7,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
- Clock: clock,
- })
- // Advance the clock by some unimportant amount to make
- // it give a more recognisable signature than 00,00,00,00.
- clock.Advance(time.Millisecond * randomTimeOffset)
-
- // We expect at most a single packet in response to our ICMP Echo Request.
- e := channel.New(1, ipv4.MaxTotalSize, "")
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
- ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr}
- if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtoAddr, err)
- }
-
- // Default routes for IPv4 so ICMP can find a route to the remote
- // node when attempting to send the ICMP Echo Reply.
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv4EmptySubnet,
- NIC: nicID,
- },
- })
-
- if len(test.options)%4 != 0 {
- t.Fatalf("options must be aligned to 32 bits, invalid test options: %x (len=%d)", test.options, len(test.options))
- }
- ipHeaderLength := header.IPv4MinimumSize + len(test.options)
- if ipHeaderLength > header.IPv4MaximumHeaderSize {
- t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize)
- }
- totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize)
- hdr := buffer.NewPrependable(int(totalLen))
- icmpH := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
-
- // Specify ident/seq to make sure we get the same in the response.
- icmpH.SetIdent(randomIdent)
- icmpH.SetSequence(randomSequence)
- icmpH.SetType(header.ICMPv4Echo)
- icmpH.SetCode(header.ICMPv4UnusedCode)
- icmpH.SetChecksum(0)
- icmpH.SetChecksum(^header.Checksum(icmpH, 0))
- ip := header.IPv4(hdr.Prepend(ipHeaderLength))
- if test.maxTotalLength < totalLen {
- totalLen = test.maxTotalLength
- }
- ip.Encode(&header.IPv4Fields{
- TotalLength: totalLen,
- Protocol: test.transportProtocol,
- TTL: test.TTL,
- SrcAddr: remoteIPv4Addr,
- DstAddr: ipv4Addr.Address,
- })
- if test.headerLength != 0 {
- ip.SetHeaderLength(test.headerLength)
- } else {
- // Set the calculated header length, since we may manually add options.
- ip.SetHeaderLength(uint8(ipHeaderLength))
- }
- if len(test.options) != 0 {
- // Copy options manually. We do not use Encode for options so we can
- // verify malformed options with handcrafted payloads.
- if want, got := copy(ip.Options(), test.options), len(test.options); want != got {
- t.Fatalf("got copy(ip.Options(), test.options) = %d, want = %d", got, want)
- }
- }
- ip.SetChecksum(0)
- ipHeaderChecksum := ip.CalculateChecksum()
- if test.badHeaderChecksum {
- ipHeaderChecksum += 42
- }
- ip.SetChecksum(^ipHeaderChecksum)
- requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- })
- e.InjectInbound(header.IPv4ProtocolNumber, requestPkt)
- reply, ok := e.Read()
- if !ok {
- if test.shouldFail {
- if test.expectErrorICMP {
- t.Fatalf("ICMP error response (type %d, code %d) missing", test.ICMPType, test.ICMPCode)
- }
- return // Expected silent failure.
- }
- t.Fatal("expected ICMP echo reply missing")
- }
-
- // We didn't expect a packet. Register our surprise but carry on to
- // provide more information about what we got.
- if test.shouldFail && !test.expectErrorICMP {
- t.Error("unexpected packet response")
- }
-
- // Check the route that brought the packet to us.
- if reply.Route.LocalAddress != ipv4Addr.Address {
- t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", reply.Route.LocalAddress, ipv4Addr.Address)
- }
- if reply.Route.RemoteAddress != remoteIPv4Addr {
- t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", reply.Route.RemoteAddress, remoteIPv4Addr)
- }
-
- // Make sure it's all in one buffer for checker.
- replyIPHeader := header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader()))
-
- // At this stage we only know it's probably an IP+ICMP header so verify
- // that much.
- checker.IPv4(t, replyIPHeader,
- checker.SrcAddr(ipv4Addr.Address),
- checker.DstAddr(remoteIPv4Addr),
- checker.ICMPv4(
- checker.ICMPv4Checksum(),
- ),
- )
-
- // Don't proceed any further if the checker found problems.
- if t.Failed() {
- t.FailNow()
- }
-
- // OK it's ICMP. We can safely look at the type now.
- replyICMPHeader := header.ICMPv4(replyIPHeader.Payload())
- switch replyICMPHeader.Type() {
- case header.ICMPv4ParamProblem:
- if !test.shouldFail {
- t.Fatalf("got Parameter Problem with pointer %d, wanted Echo Reply", replyICMPHeader.Pointer())
- }
- if !test.expectErrorICMP {
- t.Fatalf("got Parameter Problem with pointer %d, wanted no response", replyICMPHeader.Pointer())
- }
- checker.IPv4(t, replyIPHeader,
- checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())),
- checker.IPv4HeaderLength(header.IPv4MinimumSize),
- checker.ICMPv4(
- checker.ICMPv4Type(test.ICMPType),
- checker.ICMPv4Code(test.ICMPCode),
- checker.ICMPv4Pointer(test.paramProblemPointer),
- checker.ICMPv4Payload(hdr.View()),
- ),
- )
- return
- case header.ICMPv4DstUnreachable:
- if !test.shouldFail {
- t.Fatalf("got ICMP error packet type %d, code %d, wanted Echo Reply",
- header.ICMPv4DstUnreachable, replyICMPHeader.Code())
- }
- if !test.expectErrorICMP {
- t.Fatalf("got ICMP error packet type %d, code %d, wanted no response",
- header.ICMPv4DstUnreachable, replyICMPHeader.Code())
- }
- checker.IPv4(t, replyIPHeader,
- checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())),
- checker.IPv4HeaderLength(header.IPv4MinimumSize),
- checker.ICMPv4(
- checker.ICMPv4Type(test.ICMPType),
- checker.ICMPv4Code(test.ICMPCode),
- checker.ICMPv4Payload(hdr.View()),
- ),
- )
- return
- case header.ICMPv4EchoReply:
- if test.shouldFail {
- if !test.expectErrorICMP {
- t.Error("got Echo Reply packet, want no response")
- } else {
- t.Errorf("got Echo Reply, want ICMP error type %d, code %d", test.ICMPType, test.ICMPCode)
- }
- }
- // If the IP options change size then the packet will change size, so
- // some IP header fields will need to be adjusted for the checks.
- sizeChange := len(test.replyOptions) - len(test.options)
-
- checker.IPv4(t, replyIPHeader,
- checker.IPv4HeaderLength(ipHeaderLength+sizeChange),
- checker.IPv4Options(test.replyOptions),
- checker.IPFullLength(uint16(requestPkt.Size()+sizeChange)),
- checker.ICMPv4(
- checker.ICMPv4Checksum(),
- checker.ICMPv4Code(header.ICMPv4UnusedCode),
- checker.ICMPv4Seq(randomSequence),
- checker.ICMPv4Ident(randomIdent),
- ),
- )
- default:
- t.Fatalf("unexpected ICMP response, got type %d, want = %d, %d or %d",
- replyICMPHeader.Type(), header.ICMPv4EchoReply, header.ICMPv4DstUnreachable, header.ICMPv4ParamProblem)
- }
- })
- }
-}
-
-// compareFragments compares the contents of a set of fragmented packets against
-// the contents of a source packet.
-//
-// If withIPHeader is set to true, we will validate the fragmented packets' IP
-// headers against the source packet's IP header. If set to false, we validate
-// the fragmented packets' IP headers against each other.
-func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber, withIPHeader bool, expectedAvailableHeaderBytes int) error {
- // Make a complete array of the sourcePacket packet.
- var source header.IPv4
- vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views())
-
- // If the packet to be fragmented contains an IPv4 header, use that header for
- // validating fragment headers. Else, use the header of the first fragment.
- if withIPHeader {
- source = header.IPv4(vv.ToView())
- } else {
- source = header.IPv4(packets[0].NetworkHeader().View())
- source = append(source, vv.ToView()...)
- }
-
- // Make a copy of the IP header, which will be modified in some fields to make
- // an expected header.
- sourceCopy := header.IPv4(append(buffer.View(nil), source[:source.HeaderLength()]...))
- sourceCopy.SetChecksum(0)
- sourceCopy.SetFlagsFragmentOffset(0, 0)
- sourceCopy.SetTotalLength(0)
- // Build up an array of the bytes sent.
- var reassembledPayload buffer.VectorisedView
- for i, packet := range packets {
- // Confirm that the packet is valid.
- allBytes := buffer.NewVectorisedView(packet.Size(), packet.Views())
- fragmentIPHeader := header.IPv4(allBytes.ToView())
- if !fragmentIPHeader.IsValid(len(fragmentIPHeader)) {
- return fmt.Errorf("fragment #%d: IP packet is invalid:\n%s", i, hex.Dump(fragmentIPHeader))
- }
- if got := len(fragmentIPHeader); got > int(mtu) {
- return fmt.Errorf("fragment #%d: got len(fragmentIPHeader) = %d, want <= %d", i, got, mtu)
- }
- if got := fragmentIPHeader.TransportProtocol(); got != proto {
- return fmt.Errorf("fragment #%d: got fragmentIPHeader.TransportProtocol() = %d, want = %d", i, got, uint8(proto))
- }
- if got, want := packet.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber; got != want {
- return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, got, want)
- }
- if got := packet.AvailableHeaderBytes(); got != expectedAvailableHeaderBytes {
- return fmt.Errorf("fragment #%d: got packet.AvailableHeaderBytes() = %d, want = %d", i, got, expectedAvailableHeaderBytes)
- }
- if got, want := fragmentIPHeader.CalculateChecksum(), uint16(0xffff); got != want {
- return fmt.Errorf("fragment #%d: got ip.CalculateChecksum() = %#x, want = %#x", i, got, want)
- }
- if wantFragments[i].more {
- sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, wantFragments[i].offset)
- } else {
- sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, wantFragments[i].offset)
- }
- reassembledPayload.AppendView(packet.TransportHeader().View())
- reassembledPayload.AppendView(packet.Data().AsRange().ToOwnedView())
- // Clear out the checksum and length from the ip because we can't compare
- // it.
- sourceCopy.SetTotalLength(wantFragments[i].payloadSize + header.IPv4MinimumSize)
- sourceCopy.SetChecksum(0)
- sourceCopy.SetChecksum(^sourceCopy.CalculateChecksum())
-
- // If we are validating against the original IP header, we should exclude the
- // ID field, which will only be set fo fragmented packets.
- if withIPHeader {
- fragmentIPHeader.SetID(0)
- fragmentIPHeader.SetChecksum(0)
- fragmentIPHeader.SetChecksum(^fragmentIPHeader.CalculateChecksum())
- }
- if diff := cmp.Diff(fragmentIPHeader[:fragmentIPHeader.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]); diff != "" {
- return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff)
- }
- }
-
- expected := buffer.View(source[source.HeaderLength():])
- if diff := cmp.Diff(expected, reassembledPayload.ToView()); diff != "" {
- return fmt.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff)
- }
-
- return nil
-}
-
-type fragmentInfo struct {
- offset uint16
- more bool
- payloadSize uint16
-}
-
-var fragmentationTests = []struct {
- description string
- mtu uint32
- transportHeaderLength int
- payloadSize int
- wantFragments []fragmentInfo
-}{
- {
- description: "No fragmentation",
- mtu: 1280,
- transportHeaderLength: 0,
- payloadSize: 1000,
- wantFragments: []fragmentInfo{
- {offset: 0, payloadSize: 1000, more: false},
- },
- },
- {
- description: "Fragmented",
- mtu: 1280,
- transportHeaderLength: 0,
- payloadSize: 2000,
- wantFragments: []fragmentInfo{
- {offset: 0, payloadSize: 1256, more: true},
- {offset: 1256, payloadSize: 744, more: false},
- },
- },
- {
- description: "Fragmented with the minimum mtu",
- mtu: header.IPv4MinimumMTU,
- transportHeaderLength: 0,
- payloadSize: 100,
- wantFragments: []fragmentInfo{
- {offset: 0, payloadSize: 48, more: true},
- {offset: 48, payloadSize: 48, more: true},
- {offset: 96, payloadSize: 4, more: false},
- },
- },
- {
- description: "Fragmented with mtu not a multiple of 8",
- mtu: header.IPv4MinimumMTU + 1,
- transportHeaderLength: 0,
- payloadSize: 100,
- wantFragments: []fragmentInfo{
- {offset: 0, payloadSize: 48, more: true},
- {offset: 48, payloadSize: 48, more: true},
- {offset: 96, payloadSize: 4, more: false},
- },
- },
- {
- description: "No fragmentation with big header",
- mtu: 2000,
- transportHeaderLength: 100,
- payloadSize: 1000,
- wantFragments: []fragmentInfo{
- {offset: 0, payloadSize: 1100, more: false},
- },
- },
- {
- description: "Fragmented with big header",
- mtu: 1280,
- transportHeaderLength: 100,
- payloadSize: 1200,
- wantFragments: []fragmentInfo{
- {offset: 0, payloadSize: 1256, more: true},
- {offset: 1256, payloadSize: 44, more: false},
- },
- },
- {
- description: "Fragmented with MTU smaller than header",
- mtu: 300,
- transportHeaderLength: 1000,
- payloadSize: 500,
- wantFragments: []fragmentInfo{
- {offset: 0, payloadSize: 280, more: true},
- {offset: 280, payloadSize: 280, more: true},
- {offset: 560, payloadSize: 280, more: true},
- {offset: 840, payloadSize: 280, more: true},
- {offset: 1120, payloadSize: 280, more: true},
- {offset: 1400, payloadSize: 100, more: false},
- },
- },
-}
-
-func TestFragmentationWritePacket(t *testing.T) {
- const ttl = 42
-
- for _, ft := range fragmentationTests {
- t.Run(ft.description, func(t *testing.T) {
- ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
- r := buildRoute(t, ep)
- pkt := iptestutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber)
- source := pkt.Clone()
- err := r.WritePacket(stack.NetworkHeaderParams{
- Protocol: tcp.ProtocolNumber,
- TTL: ttl,
- TOS: stack.DefaultTOS,
- }, pkt)
- if err != nil {
- t.Fatalf("r.WritePacket(...): %s", err)
- }
- if got := len(ep.WrittenPackets); got != len(ft.wantFragments) {
- t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, len(ft.wantFragments))
- }
- if got := int(r.Stats().IP.PacketsSent.Value()); got != len(ft.wantFragments) {
- t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, len(ft.wantFragments))
- }
- if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 {
- t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
- }
- if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber, false /* withIPHeader */, extraHeaderReserve); err != nil {
- t.Error(err)
- }
- })
- }
-}
-
-func TestFragmentationWritePackets(t *testing.T) {
- const ttl = 42
- writePacketsTests := []struct {
- description string
- insertBefore int
- insertAfter int
- }{
- {
- description: "Single packet",
- insertBefore: 0,
- insertAfter: 0,
- },
- {
- description: "With packet before",
- insertBefore: 1,
- insertAfter: 0,
- },
- {
- description: "With packet after",
- insertBefore: 0,
- insertAfter: 1,
- },
- {
- description: "With packet before and after",
- insertBefore: 1,
- insertAfter: 1,
- },
- }
- tinyPacket := iptestutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv4MinimumSize, []int{1}, header.IPv4ProtocolNumber)
-
- for _, test := range writePacketsTests {
- t.Run(test.description, func(t *testing.T) {
- for _, ft := range fragmentationTests {
- t.Run(ft.description, func(t *testing.T) {
- var pkts stack.PacketBufferList
- for i := 0; i < test.insertBefore; i++ {
- pkts.PushBack(tinyPacket.Clone())
- }
- pkt := iptestutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber)
- pkts.PushBack(pkt.Clone())
- for i := 0; i < test.insertAfter; i++ {
- pkts.PushBack(tinyPacket.Clone())
- }
-
- ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
- r := buildRoute(t, ep)
-
- wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter
- n, err := r.WritePackets(pkts, stack.NetworkHeaderParams{
- Protocol: tcp.ProtocolNumber,
- TTL: ttl,
- TOS: stack.DefaultTOS,
- })
- if err != nil {
- t.Errorf("got WritePackets(_, _, _) = (_, %s), want = (_, nil)", err)
- }
- if n != wantTotalPackets {
- t.Errorf("got WritePackets(_, _, _) = (%d, _), want = (%d, _)", n, wantTotalPackets)
- }
- if got := len(ep.WrittenPackets); got != wantTotalPackets {
- t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, wantTotalPackets)
- }
- if got := int(r.Stats().IP.PacketsSent.Value()); got != wantTotalPackets {
- t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, wantTotalPackets)
- }
- if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != 0 {
- t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
- }
-
- if wantTotalPackets == 0 {
- return
- }
-
- fragments := ep.WrittenPackets[test.insertBefore : len(ft.wantFragments)+test.insertBefore]
- if err := compareFragments(fragments, pkt, ft.mtu, ft.wantFragments, tcp.ProtocolNumber, false /* withIPHeader */, extraHeaderReserve); err != nil {
- t.Error(err)
- }
- })
- }
- })
- }
-}
-
-// TestFragmentationErrors checks that errors are returned from WritePacket
-// correctly.
-func TestFragmentationErrors(t *testing.T) {
- const ttl = 42
-
- tests := []struct {
- description string
- mtu uint32
- transportHeaderLength int
- payloadSize int
- allowPackets int
- outgoingErrors int
- mockError tcpip.Error
- wantError tcpip.Error
- }{
- {
- description: "No frag",
- mtu: 2000,
- payloadSize: 1000,
- transportHeaderLength: 0,
- allowPackets: 0,
- outgoingErrors: 1,
- mockError: &tcpip.ErrAborted{},
- wantError: &tcpip.ErrAborted{},
- },
- {
- description: "Error on first frag",
- mtu: 500,
- payloadSize: 1000,
- transportHeaderLength: 0,
- allowPackets: 0,
- outgoingErrors: 3,
- mockError: &tcpip.ErrAborted{},
- wantError: &tcpip.ErrAborted{},
- },
- {
- description: "Error on second frag",
- mtu: 500,
- payloadSize: 1000,
- transportHeaderLength: 0,
- allowPackets: 1,
- outgoingErrors: 2,
- mockError: &tcpip.ErrAborted{},
- wantError: &tcpip.ErrAborted{},
- },
- {
- description: "Error on first frag MTU smaller than header",
- mtu: 500,
- transportHeaderLength: 1000,
- payloadSize: 500,
- allowPackets: 0,
- outgoingErrors: 4,
- mockError: &tcpip.ErrAborted{},
- wantError: &tcpip.ErrAborted{},
- },
- {
- description: "Error when MTU is smaller than IPv4 minimum MTU",
- mtu: header.IPv4MinimumMTU - 1,
- transportHeaderLength: 0,
- payloadSize: 500,
- allowPackets: 0,
- outgoingErrors: 1,
- mockError: nil,
- wantError: &tcpip.ErrInvalidEndpointState{},
- },
- }
-
- for _, ft := range tests {
- t.Run(ft.description, func(t *testing.T) {
- pkt := iptestutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber)
- ep := iptestutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets)
- r := buildRoute(t, ep)
- err := r.WritePacket(stack.NetworkHeaderParams{
- Protocol: tcp.ProtocolNumber,
- TTL: ttl,
- TOS: stack.DefaultTOS,
- }, pkt)
- if diff := cmp.Diff(ft.wantError, err); diff != "" {
- t.Fatalf("unexpected error from r.WritePacket(_, _, _), (-want, +got):\n%s", diff)
- }
- if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets {
- t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets)
- }
- if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != ft.outgoingErrors {
- t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, ft.outgoingErrors)
- }
- })
- }
-}
-
-func TestInvalidFragments(t *testing.T) {
- const (
- nicID = 1
- linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
- addr1 = tcpip.Address("\x0a\x00\x00\x01")
- addr2 = tcpip.Address("\x0a\x00\x00\x02")
- tos = 0
- ident = 1
- ttl = 48
- protocol = 6
- )
-
- payloadGen := func(payloadLen int) []byte {
- payload := make([]byte, payloadLen)
- for i := 0; i < len(payload); i++ {
- payload[i] = 0x30
- }
- return payload
- }
-
- type fragmentData struct {
- ipv4fields header.IPv4Fields
- // 0 means insert the correct IHL. Non 0 means override the correct IHL.
- overrideIHL int // For 0 use 1 as it is an int and will be divided by 4.
- payload []byte
- autoChecksum bool // If true, the Checksum field will be overwritten.
- }
-
- tests := []struct {
- name string
- fragments []fragmentData
- wantMalformedIPPackets uint64
- wantMalformedFragments uint64
- }{
- {
- name: "IHL and TotalLength zero, FragmentOffset non-zero",
- fragments: []fragmentData{
- {
- ipv4fields: header.IPv4Fields{
- TOS: tos,
- TotalLength: 0,
- ID: ident,
- Flags: header.IPv4FlagDontFragment | header.IPv4FlagMoreFragments,
- FragmentOffset: 59776,
- TTL: ttl,
- Protocol: protocol,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- overrideIHL: 1, // See note above.
- payload: payloadGen(12),
- autoChecksum: true,
- },
- },
- wantMalformedIPPackets: 1,
- wantMalformedFragments: 0,
- },
- {
- name: "IHL and TotalLength zero, FragmentOffset zero",
- fragments: []fragmentData{
- {
- ipv4fields: header.IPv4Fields{
- TOS: tos,
- TotalLength: 0,
- ID: ident,
- Flags: header.IPv4FlagMoreFragments,
- FragmentOffset: 0,
- TTL: ttl,
- Protocol: protocol,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- overrideIHL: 1, // See note above.
- payload: payloadGen(12),
- autoChecksum: true,
- },
- },
- wantMalformedIPPackets: 1,
- wantMalformedFragments: 0,
- },
- {
- // Payload 17 octets and Fragment offset 65520
- // Leading to the fragment end to be past 65536.
- name: "fragment ends past 65536",
- fragments: []fragmentData{
- {
- ipv4fields: header.IPv4Fields{
- TOS: tos,
- TotalLength: header.IPv4MinimumSize + 17,
- ID: ident,
- Flags: 0,
- FragmentOffset: 65520,
- TTL: ttl,
- Protocol: protocol,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- payload: payloadGen(17),
- autoChecksum: true,
- },
- },
- wantMalformedIPPackets: 1,
- wantMalformedFragments: 1,
- },
- {
- // Payload 16 octets and fragment offset 65520
- // Leading to the fragment end to be exactly 65536.
- name: "fragment ends exactly at 65536",
- fragments: []fragmentData{
- {
- ipv4fields: header.IPv4Fields{
- TOS: tos,
- TotalLength: header.IPv4MinimumSize + 16,
- ID: ident,
- Flags: 0,
- FragmentOffset: 65520,
- TTL: ttl,
- Protocol: protocol,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- payload: payloadGen(16),
- autoChecksum: true,
- },
- },
- wantMalformedIPPackets: 0,
- wantMalformedFragments: 0,
- },
- {
- name: "IHL less than IPv4 minimum size",
- fragments: []fragmentData{
- {
- ipv4fields: header.IPv4Fields{
- TOS: tos,
- TotalLength: header.IPv4MinimumSize + 28,
- ID: ident,
- Flags: 0,
- FragmentOffset: 1944,
- TTL: ttl,
- Protocol: protocol,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- payload: payloadGen(28),
- overrideIHL: header.IPv4MinimumSize - 12,
- autoChecksum: true,
- },
- {
- ipv4fields: header.IPv4Fields{
- TOS: tos,
- TotalLength: header.IPv4MinimumSize - 12,
- ID: ident,
- Flags: header.IPv4FlagMoreFragments,
- FragmentOffset: 0,
- TTL: ttl,
- Protocol: protocol,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- payload: payloadGen(28),
- overrideIHL: header.IPv4MinimumSize - 12,
- autoChecksum: true,
- },
- },
- wantMalformedIPPackets: 2,
- wantMalformedFragments: 0,
- },
- {
- name: "fragment with short TotalLength and extra payload",
- fragments: []fragmentData{
- {
- ipv4fields: header.IPv4Fields{
- TOS: tos,
- TotalLength: header.IPv4MinimumSize + 28,
- ID: ident,
- Flags: 0,
- FragmentOffset: 28816,
- TTL: ttl,
- Protocol: protocol,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- payload: payloadGen(28),
- overrideIHL: header.IPv4MinimumSize + 4,
- autoChecksum: true,
- },
- {
- ipv4fields: header.IPv4Fields{
- TOS: tos,
- TotalLength: header.IPv4MinimumSize + 4,
- ID: ident,
- Flags: header.IPv4FlagMoreFragments,
- FragmentOffset: 0,
- TTL: ttl,
- Protocol: protocol,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- payload: payloadGen(28),
- overrideIHL: header.IPv4MinimumSize + 4,
- autoChecksum: true,
- },
- },
- wantMalformedIPPackets: 1,
- wantMalformedFragments: 1,
- },
- {
- name: "multiple fragments with More Fragments flag set to false",
- fragments: []fragmentData{
- {
- ipv4fields: header.IPv4Fields{
- TOS: tos,
- TotalLength: header.IPv4MinimumSize + 8,
- ID: ident,
- Flags: 0,
- FragmentOffset: 128,
- TTL: ttl,
- Protocol: protocol,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- payload: payloadGen(8),
- autoChecksum: true,
- },
- {
- ipv4fields: header.IPv4Fields{
- TOS: tos,
- TotalLength: header.IPv4MinimumSize + 8,
- ID: ident,
- Flags: 0,
- FragmentOffset: 8,
- TTL: ttl,
- Protocol: protocol,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- payload: payloadGen(8),
- autoChecksum: true,
- },
- {
- ipv4fields: header.IPv4Fields{
- TOS: tos,
- TotalLength: header.IPv4MinimumSize + 8,
- ID: ident,
- Flags: header.IPv4FlagMoreFragments,
- FragmentOffset: 0,
- TTL: ttl,
- Protocol: protocol,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- payload: payloadGen(8),
- autoChecksum: true,
- },
- },
- wantMalformedIPPackets: 1,
- wantMalformedFragments: 1,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{
- ipv4.NewProtocol,
- },
- })
- e := channel.New(0, 1500, linkAddr)
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: addr2.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
-
- for _, f := range test.fragments {
- pktSize := header.IPv4MinimumSize + len(f.payload)
- hdr := buffer.NewPrependable(pktSize)
-
- ip := header.IPv4(hdr.Prepend(pktSize))
- ip.Encode(&f.ipv4fields)
- if want, got := len(f.payload), copy(ip[header.IPv4MinimumSize:], f.payload); want != got {
- t.Fatalf("copied %d bytes, expected %d bytes.", got, want)
- }
- // Encode sets this up correctly. If we want a different value for
- // testing then we need to overwrite the good value.
- if f.overrideIHL != 0 {
- ip.SetHeaderLength(uint8(f.overrideIHL))
- // If we are asked to add options (type not specified) then pad
- // with 0 (EOL). RFC 791 page 23 says "The padding is zero".
- for i := header.IPv4MinimumSize; i < f.overrideIHL; i++ {
- ip[i] = byte(header.IPv4OptionListEndType)
- }
- }
-
- if f.autoChecksum {
- ip.SetChecksum(0)
- ip.SetChecksum(^ip.CalculateChecksum())
- }
-
- vv := hdr.View().ToVectorisedView()
- e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- }))
- }
-
- if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), test.wantMalformedIPPackets; got != want {
- t.Errorf("incorrect Stats.IP.MalformedPacketsReceived, got: %d, want: %d", got, want)
- }
- if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), test.wantMalformedFragments; got != want {
- t.Errorf("incorrect Stats.IP.MalformedFragmentsReceived, got: %d, want: %d", got, want)
- }
- })
- }
-}
-
-func TestFragmentReassemblyTimeout(t *testing.T) {
- const (
- nicID = 1
- linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
- addr1 = tcpip.Address("\x0a\x00\x00\x01")
- addr2 = tcpip.Address("\x0a\x00\x00\x02")
- tos = 0
- ident = 1
- ttl = 48
- protocol = 99
- data = "TEST_FRAGMENT_REASSEMBLY_TIMEOUT"
- )
-
- type fragmentData struct {
- ipv4fields header.IPv4Fields
- payload []byte
- }
-
- tests := []struct {
- name string
- fragments []fragmentData
- expectICMP bool
- }{
- {
- name: "first fragment only",
- fragments: []fragmentData{
- {
- ipv4fields: header.IPv4Fields{
- TOS: tos,
- TotalLength: header.IPv4MinimumSize + 16,
- ID: ident,
- Flags: header.IPv4FlagMoreFragments,
- FragmentOffset: 0,
- TTL: ttl,
- Protocol: protocol,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- payload: []byte(data)[:16],
- },
- },
- expectICMP: true,
- },
- {
- name: "two first fragments",
- fragments: []fragmentData{
- {
- ipv4fields: header.IPv4Fields{
- TOS: tos,
- TotalLength: header.IPv4MinimumSize + 16,
- ID: ident,
- Flags: header.IPv4FlagMoreFragments,
- FragmentOffset: 0,
- TTL: ttl,
- Protocol: protocol,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- payload: []byte(data)[:16],
- },
- {
- ipv4fields: header.IPv4Fields{
- TOS: tos,
- TotalLength: header.IPv4MinimumSize + 16,
- ID: ident,
- Flags: header.IPv4FlagMoreFragments,
- FragmentOffset: 0,
- TTL: ttl,
- Protocol: protocol,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- payload: []byte(data)[:16],
- },
- },
- expectICMP: true,
- },
- {
- name: "second fragment only",
- fragments: []fragmentData{
- {
- ipv4fields: header.IPv4Fields{
- TOS: tos,
- TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16),
- ID: ident,
- Flags: 0,
- FragmentOffset: 8,
- TTL: ttl,
- Protocol: protocol,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- payload: []byte(data)[16:],
- },
- },
- expectICMP: false,
- },
- {
- name: "two fragments with a gap",
- fragments: []fragmentData{
- {
- ipv4fields: header.IPv4Fields{
- TOS: tos,
- TotalLength: header.IPv4MinimumSize + 8,
- ID: ident,
- Flags: header.IPv4FlagMoreFragments,
- FragmentOffset: 0,
- TTL: ttl,
- Protocol: protocol,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- payload: []byte(data)[:8],
- },
- {
- ipv4fields: header.IPv4Fields{
- TOS: tos,
- TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16),
- ID: ident,
- Flags: 0,
- FragmentOffset: 16,
- TTL: ttl,
- Protocol: protocol,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- payload: []byte(data)[16:],
- },
- },
- expectICMP: true,
- },
- {
- name: "two fragments with a gap in reverse order",
- fragments: []fragmentData{
- {
- ipv4fields: header.IPv4Fields{
- TOS: tos,
- TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16),
- ID: ident,
- Flags: 0,
- FragmentOffset: 16,
- TTL: ttl,
- Protocol: protocol,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- payload: []byte(data)[16:],
- },
- {
- ipv4fields: header.IPv4Fields{
- TOS: tos,
- TotalLength: header.IPv4MinimumSize + 8,
- ID: ident,
- Flags: header.IPv4FlagMoreFragments,
- FragmentOffset: 0,
- TTL: ttl,
- Protocol: protocol,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- payload: []byte(data)[:8],
- },
- },
- expectICMP: true,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{
- ipv4.NewProtocol,
- },
- Clock: clock,
- })
- e := channel.New(1, 1500, linkAddr)
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: addr2.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
- s.SetRouteTable([]tcpip.Route{{
- Destination: header.IPv4EmptySubnet,
- NIC: nicID,
- }})
-
- var firstFragmentSent buffer.View
- for _, f := range test.fragments {
- pktSize := header.IPv4MinimumSize
- hdr := buffer.NewPrependable(pktSize)
-
- ip := header.IPv4(hdr.Prepend(pktSize))
- ip.Encode(&f.ipv4fields)
-
- ip.SetChecksum(0)
- ip.SetChecksum(^ip.CalculateChecksum())
-
- vv := hdr.View().ToVectorisedView()
- vv.AppendView(f.payload)
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- })
-
- if firstFragmentSent == nil && ip.FragmentOffset() == 0 {
- firstFragmentSent = stack.PayloadSince(pkt.NetworkHeader())
- }
-
- e.InjectInbound(header.IPv4ProtocolNumber, pkt)
- }
-
- clock.Advance(ipv4.ReassembleTimeout)
-
- reply, ok := e.Read()
- if !test.expectICMP {
- if ok {
- t.Fatalf("unexpected ICMP error message received: %#v", reply)
- }
- return
- }
- if !ok {
- t.Fatal("expected ICMP error message missing")
- }
- if firstFragmentSent == nil {
- t.Fatalf("unexpected ICMP error message received: %#v", reply)
- }
-
- checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()),
- checker.SrcAddr(addr2),
- checker.DstAddr(addr1),
- checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+firstFragmentSent.Size())),
- checker.IPv4HeaderLength(header.IPv4MinimumSize),
- checker.ICMPv4(
- checker.ICMPv4Type(header.ICMPv4TimeExceeded),
- checker.ICMPv4Code(header.ICMPv4ReassemblyTimeout),
- checker.ICMPv4Checksum(),
- checker.ICMPv4Payload(firstFragmentSent),
- ),
- )
- })
- }
-}
-
-// TestReceiveFragments feeds fragments in through the incoming packet path to
-// test reassembly
-func TestReceiveFragments(t *testing.T) {
- const (
- nicID = 1
-
- addr1 = tcpip.Address("\x0c\xa8\x00\x01") // 192.168.0.1
- addr2 = tcpip.Address("\x0c\xa8\x00\x02") // 192.168.0.2
- addr3 = tcpip.Address("\x0c\xa8\x00\x03") // 192.168.0.3
- )
-
- // Build and return a UDP header containing payload.
- udpGen := func(payloadLen int, multiplier uint8, src, dst tcpip.Address) buffer.View {
- payload := buffer.NewView(payloadLen)
- for i := 0; i < len(payload); i++ {
- payload[i] = uint8(i) * multiplier
- }
-
- udpLength := header.UDPMinimumSize + len(payload)
-
- hdr := buffer.NewPrependable(udpLength)
- u := header.UDP(hdr.Prepend(udpLength))
- u.Encode(&header.UDPFields{
- SrcPort: 5555,
- DstPort: 80,
- Length: uint16(udpLength),
- })
- copy(u.Payload(), payload)
- sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(udpLength))
- sum = header.Checksum(payload, sum)
- u.SetChecksum(^u.CalculateChecksum(sum))
- return hdr.View()
- }
-
- // UDP header plus a payload of 0..256
- ipv4Payload1Addr1ToAddr2 := udpGen(256, 1, addr1, addr2)
- udpPayload1Addr1ToAddr2 := ipv4Payload1Addr1ToAddr2[header.UDPMinimumSize:]
- ipv4Payload1Addr3ToAddr2 := udpGen(256, 1, addr3, addr2)
- udpPayload1Addr3ToAddr2 := ipv4Payload1Addr3ToAddr2[header.UDPMinimumSize:]
- // UDP header plus a payload of 0..256 in increments of 2.
- ipv4Payload2Addr1ToAddr2 := udpGen(128, 2, addr1, addr2)
- udpPayload2Addr1ToAddr2 := ipv4Payload2Addr1ToAddr2[header.UDPMinimumSize:]
- // UDP header plus a payload of 0..256 in increments of 3.
- // Used to test cases where the fragment blocks are not a multiple of
- // the fragment block size of 8 (RFC 791 section 3.1 page 14).
- ipv4Payload3Addr1ToAddr2 := udpGen(127, 3, addr1, addr2)
- udpPayload3Addr1ToAddr2 := ipv4Payload3Addr1ToAddr2[header.UDPMinimumSize:]
- // Used to test the max reassembled IPv4 payload length.
- ipv4Payload4Addr1ToAddr2 := udpGen(header.UDPMaximumSize-header.UDPMinimumSize, 4, addr1, addr2)
- udpPayload4Addr1ToAddr2 := ipv4Payload4Addr1ToAddr2[header.UDPMinimumSize:]
-
- type fragmentData struct {
- srcAddr tcpip.Address
- dstAddr tcpip.Address
- id uint16
- flags uint8
- fragmentOffset uint16
- payload buffer.View
- }
-
- tests := []struct {
- name string
- fragments []fragmentData
- expectedPayloads [][]byte
- }{
- {
- name: "No fragmentation",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: 0,
- fragmentOffset: 0,
- payload: ipv4Payload1Addr1ToAddr2,
- },
- },
- expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
- },
- {
- name: "No fragmentation with size not a multiple of fragment block size",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: 0,
- fragmentOffset: 0,
- payload: ipv4Payload3Addr1ToAddr2,
- },
- },
- expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2},
- },
- {
- name: "More fragments without payload",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: header.IPv4FlagMoreFragments,
- fragmentOffset: 0,
- payload: ipv4Payload1Addr1ToAddr2,
- },
- },
- expectedPayloads: nil,
- },
- {
- name: "Non-zero fragment offset without payload",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: 0,
- fragmentOffset: 8,
- payload: ipv4Payload1Addr1ToAddr2,
- },
- },
- expectedPayloads: nil,
- },
- {
- name: "Two fragments",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: header.IPv4FlagMoreFragments,
- fragmentOffset: 0,
- payload: ipv4Payload1Addr1ToAddr2[:64],
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: 0,
- fragmentOffset: 64,
- payload: ipv4Payload1Addr1ToAddr2[64:],
- },
- },
- expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
- },
- {
- name: "Two fragments out of order",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: 0,
- fragmentOffset: 64,
- payload: ipv4Payload1Addr1ToAddr2[64:],
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: header.IPv4FlagMoreFragments,
- fragmentOffset: 0,
- payload: ipv4Payload1Addr1ToAddr2[:64],
- },
- },
- expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
- },
- {
- name: "Two fragments with last fragment size not a multiple of fragment block size",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: header.IPv4FlagMoreFragments,
- fragmentOffset: 0,
- payload: ipv4Payload3Addr1ToAddr2[:64],
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: 0,
- fragmentOffset: 64,
- payload: ipv4Payload3Addr1ToAddr2[64:],
- },
- },
- expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2},
- },
- {
- name: "Two fragments with first fragment size not a multiple of fragment block size",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: header.IPv4FlagMoreFragments,
- fragmentOffset: 0,
- payload: ipv4Payload3Addr1ToAddr2[:63],
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: 0,
- fragmentOffset: 63,
- payload: ipv4Payload3Addr1ToAddr2[63:],
- },
- },
- expectedPayloads: nil,
- },
- {
- name: "Second fragment has MoreFlags set",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: header.IPv4FlagMoreFragments,
- fragmentOffset: 0,
- payload: ipv4Payload1Addr1ToAddr2[:64],
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: header.IPv4FlagMoreFragments,
- fragmentOffset: 64,
- payload: ipv4Payload1Addr1ToAddr2[64:],
- },
- },
- expectedPayloads: nil,
- },
- {
- name: "Two fragments with different IDs",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: header.IPv4FlagMoreFragments,
- fragmentOffset: 0,
- payload: ipv4Payload1Addr1ToAddr2[:64],
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 2,
- flags: 0,
- fragmentOffset: 64,
- payload: ipv4Payload1Addr1ToAddr2[64:],
- },
- },
- expectedPayloads: nil,
- },
- {
- name: "Two interleaved fragmented packets",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: header.IPv4FlagMoreFragments,
- fragmentOffset: 0,
- payload: ipv4Payload1Addr1ToAddr2[:64],
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 2,
- flags: header.IPv4FlagMoreFragments,
- fragmentOffset: 0,
- payload: ipv4Payload2Addr1ToAddr2[:64],
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: 0,
- fragmentOffset: 64,
- payload: ipv4Payload1Addr1ToAddr2[64:],
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 2,
- flags: 0,
- fragmentOffset: 64,
- payload: ipv4Payload2Addr1ToAddr2[64:],
- },
- },
- expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload2Addr1ToAddr2},
- },
- {
- name: "Two interleaved fragmented packets from different sources but with same ID",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: header.IPv4FlagMoreFragments,
- fragmentOffset: 0,
- payload: ipv4Payload1Addr1ToAddr2[:64],
- },
- {
- srcAddr: addr3,
- dstAddr: addr2,
- id: 1,
- flags: header.IPv4FlagMoreFragments,
- fragmentOffset: 0,
- payload: ipv4Payload1Addr3ToAddr2[:32],
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: 0,
- fragmentOffset: 64,
- payload: ipv4Payload1Addr1ToAddr2[64:],
- },
- {
- srcAddr: addr3,
- dstAddr: addr2,
- id: 1,
- flags: 0,
- fragmentOffset: 32,
- payload: ipv4Payload1Addr3ToAddr2[32:],
- },
- },
- expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload1Addr3ToAddr2},
- },
- {
- name: "Fragment without followup",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: header.IPv4FlagMoreFragments,
- fragmentOffset: 0,
- payload: ipv4Payload1Addr1ToAddr2[:64],
- },
- },
- expectedPayloads: nil,
- },
- {
- name: "Two fragments reassembled into a maximum UDP packet",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: header.IPv4FlagMoreFragments,
- fragmentOffset: 0,
- payload: ipv4Payload4Addr1ToAddr2[:65512],
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: 0,
- fragmentOffset: 65512,
- payload: ipv4Payload4Addr1ToAddr2[65512:],
- },
- },
- expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2},
- },
- {
- name: "Two fragments with MF flag reassembled into a maximum UDP packet",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: header.IPv4FlagMoreFragments,
- fragmentOffset: 0,
- payload: ipv4Payload4Addr1ToAddr2[:65512],
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- id: 1,
- flags: header.IPv4FlagMoreFragments,
- fragmentOffset: 65512,
- payload: ipv4Payload4Addr1ToAddr2[65512:],
- },
- },
- expectedPayloads: nil,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- // Setup a stack and endpoint.
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- RawFactory: raw.EndpointFactory{},
- })
- e := channel.New(0, 1280, "\xf0\x00")
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: addr2.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
-
- wq := waiter.Queue{}
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
- defer close(ch)
- ep, err := s.NewEndpoint(udp.ProtocolNumber, header.IPv4ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, header.IPv4ProtocolNumber, err)
- }
- defer ep.Close()
-
- bindAddr := tcpip.FullAddress{Addr: addr2, Port: 80}
- if err := ep.Bind(bindAddr); err != nil {
- t.Fatalf("Bind(%+v): %s", bindAddr, err)
- }
-
- // Bring up a raw endpoint so we can examine network headers.
- epRaw, err := s.NewRawEndpoint(udp.ProtocolNumber, header.IPv4ProtocolNumber, &wq, true /* associated */)
- if err != nil {
- t.Fatalf("NewRawEndpoint(%d, %d, _, true): %s", udp.ProtocolNumber, header.IPv4ProtocolNumber, err)
- }
- defer epRaw.Close()
-
- // Prepare and send the fragments.
- for _, frag := range test.fragments {
- hdr := buffer.NewPrependable(header.IPv4MinimumSize)
-
- // Serialize IPv4 fixed header.
- ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
- ip.Encode(&header.IPv4Fields{
- TotalLength: header.IPv4MinimumSize + uint16(len(frag.payload)),
- ID: frag.id,
- Flags: frag.flags,
- FragmentOffset: frag.fragmentOffset,
- TTL: 64,
- Protocol: uint8(header.UDPProtocolNumber),
- SrcAddr: frag.srcAddr,
- DstAddr: frag.dstAddr,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
-
- vv := hdr.View().ToVectorisedView()
- vv.AppendView(frag.payload)
-
- e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- }))
- }
-
- if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want {
- t.Errorf("got UDP Rx Packets = %d, want = %d", got, want)
- }
-
- for i, expectedPayload := range test.expectedPayloads {
- // Check UDP payload delivered by UDP endpoint.
- var buf bytes.Buffer
- result, err := ep.Read(&buf, tcpip.ReadOptions{})
- if err != nil {
- t.Fatalf("(i=%d) ep.Read: %s", i, err)
- }
- if diff := cmp.Diff(tcpip.ReadResult{
- Count: len(expectedPayload),
- Total: len(expectedPayload),
- }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
- t.Errorf("(i=%d) ep.Read: unexpected result (-want +got):\n%s", i, diff)
- }
- if diff := cmp.Diff(expectedPayload, buf.Bytes()); diff != "" {
- t.Errorf("(i=%d) ep.Read: UDP payload mismatch (-want +got):\n%s", i, diff)
- }
-
- // Check IPv4 header in packet delivered by raw endpoint.
- buf.Reset()
- result, err = epRaw.Read(&buf, tcpip.ReadOptions{})
- if err != nil {
- t.Fatalf("(i=%d) epRaw.Read: %s", i, err)
- }
- // Reassambly does not take care of checksum. Here we write our own
- // check routine instead of using checker.IPv4.
- ip := header.IPv4(buf.Bytes())
- for _, check := range []checker.NetworkChecker{
- checker.FragmentFlags(0),
- checker.FragmentOffset(0),
- checker.IPFullLength(uint16(header.IPv4MinimumSize + header.UDPMinimumSize + len(expectedPayload))),
- } {
- check(t, []header.Network{ip})
- }
- }
-
- res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{})
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("(last) got Read = (%#v, %v), want = (_, %s)", res, err, &tcpip.ErrWouldBlock{})
- }
- })
- }
-}
-
-func TestWriteStats(t *testing.T) {
- const nPackets = 3
-
- tests := []struct {
- name string
- setup func(*testing.T, *stack.Stack)
- allowPackets int
- expectSent int
- expectOutputDropped int
- expectPostroutingDropped int
- expectWritten int
- }{
- {
- name: "Accept all",
- // No setup needed, tables accept everything by default.
- setup: func(*testing.T, *stack.Stack) {},
- allowPackets: math.MaxInt32,
- expectSent: nPackets,
- expectOutputDropped: 0,
- expectPostroutingDropped: 0,
- expectWritten: nPackets,
- }, {
- name: "Accept all with error",
- // No setup needed, tables accept everything by default.
- setup: func(*testing.T, *stack.Stack) {},
- allowPackets: nPackets - 1,
- expectSent: nPackets - 1,
- expectOutputDropped: 0,
- expectPostroutingDropped: 0,
- expectWritten: nPackets - 1,
- }, {
- name: "Drop all with Output chain",
- setup: func(t *testing.T, stk *stack.Stack) {
- // Install Output DROP rule.
- ipt := stk.IPTables()
- filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
- ruleIdx := filter.BuiltinChains[stack.Output]
- filter.Rules[ruleIdx].Target = &stack.DropTarget{}
- if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
- t.Fatalf("failed to replace table: %s", err)
- }
- },
- allowPackets: math.MaxInt32,
- expectSent: 0,
- expectOutputDropped: nPackets,
- expectPostroutingDropped: 0,
- expectWritten: nPackets,
- }, {
- name: "Drop all with Postrouting chain",
- setup: func(t *testing.T, stk *stack.Stack) {
- ipt := stk.IPTables()
- filter := ipt.GetTable(stack.NATID, false /* ipv6 */)
- ruleIdx := filter.BuiltinChains[stack.Postrouting]
- filter.Rules[ruleIdx].Target = &stack.DropTarget{}
- if err := ipt.ReplaceTable(stack.NATID, filter, false /* ipv6 */); err != nil {
- t.Fatalf("failed to replace table: %s", err)
- }
- },
- allowPackets: math.MaxInt32,
- expectSent: 0,
- expectOutputDropped: 0,
- expectPostroutingDropped: nPackets,
- expectWritten: nPackets,
- }, {
- name: "Drop some with Output chain",
- setup: func(t *testing.T, stk *stack.Stack) {
- // Install Output DROP rule that matches only 1
- // of the 3 packets.
- ipt := stk.IPTables()
- filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
- // We'll match and DROP the last packet.
- ruleIdx := filter.BuiltinChains[stack.Output]
- filter.Rules[ruleIdx].Target = &stack.DropTarget{}
- filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
- // Make sure the next rule is ACCEPT.
- filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
- if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
- t.Fatalf("failed to replace table: %s", err)
- }
- },
- allowPackets: math.MaxInt32,
- expectSent: nPackets - 1,
- expectOutputDropped: 1,
- expectPostroutingDropped: 0,
- expectWritten: nPackets,
- }, {
- name: "Drop some with Postrouting chain",
- setup: func(t *testing.T, stk *stack.Stack) {
- // Install Postrouting DROP rule that matches only 1
- // of the 3 packets.
- ipt := stk.IPTables()
- filter := ipt.GetTable(stack.NATID, false /* ipv6 */)
- // We'll match and DROP the last packet.
- ruleIdx := filter.BuiltinChains[stack.Postrouting]
- filter.Rules[ruleIdx].Target = &stack.DropTarget{}
- filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
- // Make sure the next rule is ACCEPT.
- filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
- if err := ipt.ReplaceTable(stack.NATID, filter, false /* ipv6 */); err != nil {
- t.Fatalf("failed to replace table: %s", err)
- }
- },
- allowPackets: math.MaxInt32,
- expectSent: nPackets - 1,
- expectOutputDropped: 0,
- expectPostroutingDropped: 1,
- expectWritten: nPackets,
- },
- }
-
- // Parameterize the tests to run with both WritePacket and WritePackets.
- writers := []struct {
- name string
- writePackets func(*stack.Route, stack.PacketBufferList) (int, tcpip.Error)
- }{
- {
- name: "WritePacket",
- writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) {
- nWritten := 0
- for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- if err := rt.WritePacket(stack.NetworkHeaderParams{}, pkt); err != nil {
- return nWritten, err
- }
- nWritten++
- }
- return nWritten, nil
- },
- }, {
- name: "WritePackets",
- writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) {
- return rt.WritePackets(pkts, stack.NetworkHeaderParams{})
- },
- },
- }
-
- for _, writer := range writers {
- t.Run(writer.name, func(t *testing.T) {
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- ep := iptestutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets)
- rt := buildRoute(t, ep)
-
- var pkts stack.PacketBufferList
- for i := 0; i < nPackets; i++ {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.UDPMinimumSize + int(rt.MaxHeaderLength()),
- Data: buffer.NewView(0).ToVectorisedView(),
- })
- pkt.TransportHeader().Push(header.UDPMinimumSize)
- pkts.PushBack(pkt)
- }
-
- test.setup(t, rt.Stack())
-
- nWritten, _ := writer.writePackets(rt, pkts)
-
- if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent {
- t.Errorf("got rt.Stats().IP.PacketsSent.Value() = %d, want = %d", got, test.expectSent)
- }
- if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectOutputDropped {
- t.Errorf("got rt.Stats().IP.IPTablesOutputDropped.Value() = %d, want = %d", got, test.expectOutputDropped)
- }
- if got := int(rt.Stats().IP.IPTablesPostroutingDropped.Value()); got != test.expectPostroutingDropped {
- t.Errorf("got rt.Stats().IP.IPTablesPostroutingDropped.Value() = %d, want = %d", got, test.expectPostroutingDropped)
- }
- if nWritten != test.expectWritten {
- t.Errorf("got nWritten = %d, want = %d", nWritten, test.expectWritten)
- }
- })
- }
- })
- }
-}
-
-func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- })
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatalf("CreateNIC(1, _) failed: %s", err)
- }
- const (
- src = tcpip.Address("\x10\x00\x00\x01")
- dst = tcpip.Address("\x10\x00\x00\x02")
- )
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: src.WithPrefix(),
- }
- if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
- }
- {
- mask := tcpip.AddressMask(header.IPv4Broadcast)
- subnet, err := tcpip.NewSubnet(dst, mask)
- if err != nil {
- t.Fatalf("NewSubnet(%s, %s) failed: %v", dst, mask, err)
- }
- s.SetRouteTable([]tcpip.Route{{
- Destination: subnet,
- NIC: 1,
- }})
- }
- rt, err := s.FindRoute(1, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(1, %s, %s, %d, false) = %s", src, dst, ipv4.ProtocolNumber, err)
- }
- return rt
-}
-
-// limitedMatcher is an iptables matcher that matches after a certain number of
-// packets are checked against it.
-type limitedMatcher struct {
- limit int
-}
-
-// Name implements Matcher.Name.
-func (*limitedMatcher) Name() string {
- return "limitedMatcher"
-}
-
-// Match implements Matcher.Match.
-func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string, string) (bool, bool) {
- if lm.limit == 0 {
- return true, false
- }
- lm.limit--
- return false, false
-}
-
-func TestPacketQueuing(t *testing.T) {
- const nicID = 1
-
- var (
- host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
- host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
-
- host1IPv4Addr = tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()),
- PrefixLen: 24,
- },
- }
- host2IPv4Addr = tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()),
- PrefixLen: 8,
- },
- }
- )
-
- tests := []struct {
- name string
- rxPkt func(*channel.Endpoint)
- checkResp func(*testing.T, *channel.Endpoint)
- }{
- {
- name: "ICMP Error",
- rxPkt: func(e *channel.Endpoint) {
- hdr := buffer.NewPrependable(header.IPv4MinimumSize + header.UDPMinimumSize)
- u := header.UDP(hdr.Prepend(header.UDPMinimumSize))
- u.Encode(&header.UDPFields{
- SrcPort: 5555,
- DstPort: 80,
- Length: header.UDPMinimumSize,
- })
- sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv4Addr.AddressWithPrefix.Address, host1IPv4Addr.AddressWithPrefix.Address, header.UDPMinimumSize)
- sum = header.Checksum(nil, sum)
- u.SetChecksum(^u.CalculateChecksum(sum))
- ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
- ip.Encode(&header.IPv4Fields{
- TotalLength: header.IPv4MinimumSize + header.UDPMinimumSize,
- TTL: ipv4.DefaultTTL,
- Protocol: uint8(udp.ProtocolNumber),
- SrcAddr: host2IPv4Addr.AddressWithPrefix.Address,
- DstAddr: host1IPv4Addr.AddressWithPrefix.Address,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
- e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
- },
- checkResp: func(t *testing.T, e *channel.Endpoint) {
- p, ok := e.Read()
- if !ok {
- t.Fatalf("timed out waiting for packet")
- }
- if p.Proto != header.IPv4ProtocolNumber {
- t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber)
- }
- if p.Route.RemoteLinkAddress != host2NICLinkAddr {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
- }
- checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
- checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address),
- checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address),
- checker.ICMPv4(
- checker.ICMPv4Type(header.ICMPv4DstUnreachable),
- checker.ICMPv4Code(header.ICMPv4PortUnreachable)))
- },
- },
-
- {
- name: "Ping",
- rxPkt: func(e *channel.Endpoint) {
- totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
- hdr := buffer.NewPrependable(totalLen)
- pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
- pkt.SetType(header.ICMPv4Echo)
- pkt.SetCode(0)
- pkt.SetChecksum(0)
- pkt.SetChecksum(^header.Checksum(pkt, 0))
- ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
- ip.Encode(&header.IPv4Fields{
- TotalLength: uint16(totalLen),
- Protocol: uint8(icmp.ProtocolNumber4),
- TTL: ipv4.DefaultTTL,
- SrcAddr: host2IPv4Addr.AddressWithPrefix.Address,
- DstAddr: host1IPv4Addr.AddressWithPrefix.Address,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
- e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
- },
- checkResp: func(t *testing.T, e *channel.Endpoint) {
- p, ok := e.Read()
- if !ok {
- t.Fatalf("timed out waiting for packet")
- }
- if p.Proto != header.IPv4ProtocolNumber {
- t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber)
- }
- if p.Route.RemoteLinkAddress != host2NICLinkAddr {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
- }
- checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
- checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address),
- checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address),
- checker.ICMPv4(
- checker.ICMPv4Type(header.ICMPv4EchoReply),
- checker.ICMPv4Code(header.ICMPv4UnusedCode)))
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- e := channel.New(1, defaultMTU, host1NICLinkAddr)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- Clock: clock,
- })
-
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
- }
- if err := s.AddProtocolAddress(nicID, host1IPv4Addr, stack.AddressProperties{}); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv4Addr, err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
- NIC: nicID,
- },
- })
-
- // Receive a packet to trigger link resolution before a response is sent.
- test.rxPkt(e)
-
- // Wait for a ARP request since link address resolution should be
- // performed.
- {
- clock.RunImmediatelyScheduledJobs()
- p, ok := e.Read()
- if !ok {
- t.Fatalf("timed out waiting for packet")
- }
- if p.Proto != arp.ProtocolNumber {
- t.Errorf("got p.Proto = %d, want = %d", p.Proto, arp.ProtocolNumber)
- }
- if p.Route.RemoteLinkAddress != header.EthernetBroadcastAddress {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, header.EthernetBroadcastAddress)
- }
- rep := header.ARP(p.Pkt.NetworkHeader().View())
- if got := rep.Op(); got != header.ARPRequest {
- t.Errorf("got Op() = %d, want = %d", got, header.ARPRequest)
- }
- if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != host1NICLinkAddr {
- t.Errorf("got HardwareAddressSender = %s, want = %s", got, host1NICLinkAddr)
- }
- if got := tcpip.Address(rep.ProtocolAddressSender()); got != host1IPv4Addr.AddressWithPrefix.Address {
- t.Errorf("got ProtocolAddressSender = %s, want = %s", got, host1IPv4Addr.AddressWithPrefix.Address)
- }
- if got := tcpip.Address(rep.ProtocolAddressTarget()); got != host2IPv4Addr.AddressWithPrefix.Address {
- t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, host2IPv4Addr.AddressWithPrefix.Address)
- }
- }
-
- // Send an ARP reply to complete link address resolution.
- {
- hdr := buffer.View(make([]byte, header.ARPSize))
- packet := header.ARP(hdr)
- packet.SetIPv4OverEthernet()
- packet.SetOp(header.ARPReply)
- copy(packet.HardwareAddressSender(), host2NICLinkAddr)
- copy(packet.ProtocolAddressSender(), host2IPv4Addr.AddressWithPrefix.Address)
- copy(packet.HardwareAddressTarget(), host1NICLinkAddr)
- copy(packet.ProtocolAddressTarget(), host1IPv4Addr.AddressWithPrefix.Address)
- e.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.ToVectorisedView(),
- }))
- }
-
- // Expect the response now that the link address has resolved.
- clock.RunImmediatelyScheduledJobs()
- test.checkResp(t, e)
-
- // Since link resolution was already performed, it shouldn't be performed
- // again.
- test.rxPkt(e)
- test.checkResp(t, e)
- })
- }
-}
-
-// TestCloseLocking test that lock ordering is followed when closing an
-// endpoint.
-func TestCloseLocking(t *testing.T) {
- const (
- nicID1 = 1
- nicID2 = 2
-
- iterations = 1000
- )
-
- var (
- src = testutil.MustParse4("16.0.0.1")
- dst = testutil.MustParse4("16.0.0.2")
- )
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- })
-
- // Perform NAT so that the endpoint tries to search for a sibling endpoint
- // which ends up taking the protocol and endpoint lock (in that order).
- table := stack.Table{
- Rules: []stack.Rule{
- {Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
- {Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
- {Target: &stack.RedirectTarget{Port: 5, NetworkProtocol: header.IPv4ProtocolNumber}},
- {Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
- {Target: &stack.ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
- },
- BuiltinChains: [stack.NumHooks]int{
- stack.Prerouting: 0,
- stack.Input: 1,
- stack.Forward: stack.HookUnset,
- stack.Output: 2,
- stack.Postrouting: 3,
- },
- Underflows: [stack.NumHooks]int{
- stack.Prerouting: 0,
- stack.Input: 1,
- stack.Forward: stack.HookUnset,
- stack.Output: 2,
- stack.Postrouting: 3,
- },
- }
- if err := s.IPTables().ReplaceTable(stack.NATID, table, false /* ipv6 */); err != nil {
- t.Fatalf("s.IPTables().ReplaceTable(...): %s", err)
- }
-
- e := channel.New(0, defaultMTU, "")
- if err := s.CreateNIC(nicID1, e); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
- }
-
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: src.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID1, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, protocolAddr, err)
- }
-
- s.SetRouteTable([]tcpip.Route{{
- Destination: header.IPv4EmptySubnet,
- NIC: nicID1,
- }})
-
- var wq waiter.Queue
- ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
- if err != nil {
- t.Fatal(err)
- }
- defer ep.Close()
-
- addr := tcpip.FullAddress{NIC: nicID1, Addr: dst, Port: 53}
- if err := ep.Connect(addr); err != nil {
- t.Errorf("ep.Connect(%#v): %s", addr, err)
- }
-
- var wg sync.WaitGroup
- defer wg.Wait()
-
- // Writing packets should trigger NAT which requires the stack to search the
- // protocol for network endpoints with the destination address.
- //
- // Creating and removing interfaces should modify the protocol and endpoint
- // which requires taking the locks of each.
- //
- // We expect the protocol > endpoint lock ordering to be followed here.
- wg.Add(2)
- go func() {
- defer wg.Done()
-
- data := []byte{1, 2, 3, 4}
-
- for i := 0; i < iterations; i++ {
- var r bytes.Reader
- r.Reset(data)
- if n, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Errorf("ep.Write(_, _): %s", err)
- return
- } else if want := int64(len(data)); n != want {
- t.Errorf("got ep.Write(_, _) = (%d, _), want = (%d, _)", n, want)
- return
- }
- }
- }()
- go func() {
- defer wg.Done()
-
- for i := 0; i < iterations; i++ {
- if err := s.CreateNIC(nicID2, stack.LinkEndpoint(channel.New(0, defaultMTU, ""))); err != nil {
- t.Errorf("CreateNIC(%d, _): %s", nicID2, err)
- return
- }
- if err := s.RemoveNIC(nicID2); err != nil {
- t.Errorf("RemoveNIC(%d): %s", nicID2, err)
- return
- }
- }
- }()
-}
diff --git a/pkg/tcpip/network/ipv4/stats_test.go b/pkg/tcpip/network/ipv4/stats_test.go
deleted file mode 100644
index d1f9e3cf5..000000000
--- a/pkg/tcpip/network/ipv4/stats_test.go
+++ /dev/null
@@ -1,99 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package ipv4
-
-import (
- "reflect"
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
-)
-
-var _ stack.NetworkInterface = (*testInterface)(nil)
-
-type testInterface struct {
- stack.NetworkInterface
- nicID tcpip.NICID
-}
-
-func (t *testInterface) ID() tcpip.NICID {
- return t.nicID
-}
-
-func knownNICIDs(proto *protocol) []tcpip.NICID {
- var nicIDs []tcpip.NICID
-
- for k := range proto.mu.eps {
- nicIDs = append(nicIDs, k)
- }
-
- return nicIDs
-}
-
-func TestClearEndpointFromProtocolOnClose(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- })
- proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
- nic := testInterface{nicID: 1}
- ep := proto.NewEndpoint(&nic, nil).(*endpoint)
- var nicIDs []tcpip.NICID
-
- proto.mu.Lock()
- foundEP, hasEndpointBeforeClose := proto.mu.eps[nic.ID()]
- nicIDs = knownNICIDs(proto)
- proto.mu.Unlock()
-
- if !hasEndpointBeforeClose {
- t.Fatalf("expected to find the nic id %d in the protocol's endpoint map (%v)", nic.ID(), nicIDs)
- }
- if foundEP != ep {
- t.Fatalf("found an incorrect endpoint mapped to nic id %d", nic.ID())
- }
-
- ep.Close()
-
- proto.mu.Lock()
- _, hasEP := proto.mu.eps[nic.ID()]
- nicIDs = knownNICIDs(proto)
- proto.mu.Unlock()
- if hasEP {
- t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs)
- }
-}
-
-func TestMultiCounterStatsInitialization(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- })
- proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
- var nic testInterface
- ep := proto.NewEndpoint(&nic, nil).(*endpoint)
- // At this point, the Stack's stats and the NetworkEndpoint's stats are
- // expected to be bound by a MultiCounterStat.
- refStack := s.Stats()
- refEP := ep.stats.localStats
- if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.ip).Elem(), []reflect.Value{reflect.ValueOf(&refEP.IP).Elem(), reflect.ValueOf(&refStack.IP).Elem()}); err != nil {
- t.Error(err)
- }
- if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.icmp).Elem(), []reflect.Value{reflect.ValueOf(&refEP.ICMP).Elem(), reflect.ValueOf(&refStack.ICMP.V4).Elem()}); err != nil {
- t.Error(err)
- }
- if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.igmp).Elem(), []reflect.Value{reflect.ValueOf(&refEP.IGMP).Elem(), reflect.ValueOf(&refStack.IGMP).Elem()}); err != nil {
- t.Error(err)
- }
-}
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
deleted file mode 100644
index f99cbf8f3..000000000
--- a/pkg/tcpip/network/ipv6/BUILD
+++ /dev/null
@@ -1,72 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "ipv6",
- srcs = [
- "dhcpv6configurationfromndpra_string.go",
- "icmp.go",
- "ipv6.go",
- "mld.go",
- "ndp.go",
- "stats.go",
- ],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/header/parse",
- "//pkg/tcpip/network/hash",
- "//pkg/tcpip/network/internal/fragmentation",
- "//pkg/tcpip/network/internal/ip",
- "//pkg/tcpip/stack",
- ],
-)
-
-go_test(
- name = "ipv6_test",
- size = "small",
- srcs = [
- "icmp_test.go",
- "ipv6_test.go",
- "ndp_test.go",
- ],
- library = ":ipv6",
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/checker",
- "//pkg/tcpip/faketime",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/link/sniffer",
- "//pkg/tcpip/network/internal/testutil",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/testutil",
- "//pkg/tcpip/transport/icmp",
- "//pkg/tcpip/transport/tcp",
- "//pkg/tcpip/transport/udp",
- "//pkg/waiter",
- "@com_github_google_go_cmp//cmp:go_default_library",
- ],
-)
-
-go_test(
- name = "ipv6_x_test",
- size = "small",
- srcs = ["mld_test.go"],
- deps = [
- ":ipv6",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/checker",
- "//pkg/tcpip/faketime",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/testutil",
- ],
-)
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
deleted file mode 100644
index 3b4c235fa..000000000
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ /dev/null
@@ -1,1758 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package ipv6
-
-import (
- "bytes"
- "net"
- "reflect"
- "strings"
- "testing"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-const (
- nicID = 1
-
- linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
- linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
- linkAddr2 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f")
-
- defaultChannelSize = 1
- defaultMTU = 65536
-
- arbitraryHopLimit = 42
-)
-
-var (
- lladdr0 = header.LinkLocalAddr(linkAddr0)
- lladdr1 = header.LinkLocalAddr(linkAddr1)
-)
-
-type stubLinkEndpoint struct {
- stack.LinkEndpoint
-}
-
-func (*stubLinkEndpoint) MTU() uint32 {
- return defaultMTU
-}
-
-func (*stubLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
- // Indicate that resolution for link layer addresses is required to send
- // packets over this link. This is needed so the NIC knows to allocate a
- // neighbor table.
- return stack.CapabilityResolutionRequired
-}
-
-func (*stubLinkEndpoint) MaxHeaderLength() uint16 {
- return 0
-}
-
-func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress {
- return ""
-}
-
-func (*stubLinkEndpoint) WritePacket(stack.RouteInfo, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error {
- return nil
-}
-
-func (*stubLinkEndpoint) Attach(stack.NetworkDispatcher) {}
-
-type stubDispatcher struct {
- stack.TransportDispatcher
-}
-
-func (*stubDispatcher) DeliverTransportPacket(tcpip.TransportProtocolNumber, *stack.PacketBuffer) stack.TransportPacketDisposition {
- return stack.TransportPacketHandled
-}
-
-func (*stubDispatcher) DeliverRawPacket(tcpip.TransportProtocolNumber, *stack.PacketBuffer) {
- // No-op.
-}
-
-var _ stack.NetworkInterface = (*testInterface)(nil)
-
-type testInterface struct {
- stack.LinkEndpoint
-
- probeCount int
- confirmationCount int
-
- nicID tcpip.NICID
-}
-
-func (*testInterface) ID() tcpip.NICID {
- return nicID
-}
-
-func (*testInterface) IsLoopback() bool {
- return false
-}
-
-func (*testInterface) Name() string {
- return ""
-}
-
-func (*testInterface) Enabled() bool {
- return true
-}
-
-func (*testInterface) Promiscuous() bool {
- return false
-}
-
-func (*testInterface) Spoofing() bool {
- return false
-}
-
-func (t *testInterface) WritePacket(r *stack.Route, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- return t.LinkEndpoint.WritePacket(r.Fields(), protocol, pkt)
-}
-
-func (t *testInterface) WritePackets(r *stack.Route, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
- return t.LinkEndpoint.WritePackets(r.Fields(), pkts, protocol)
-}
-
-func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- var r stack.RouteInfo
- r.NetProto = protocol
- r.RemoteLinkAddress = remoteLinkAddr
- return t.LinkEndpoint.WritePacket(r, protocol, pkt)
-}
-
-func (t *testInterface) HandleNeighborProbe(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress) tcpip.Error {
- t.probeCount++
- return nil
-}
-
-func (t *testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) tcpip.Error {
- t.confirmationCount++
- return nil
-}
-
-func (*testInterface) PrimaryAddress(tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, tcpip.Error) {
- return tcpip.AddressWithPrefix{}, nil
-}
-
-func (*testInterface) CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool {
- return false
-}
-
-func handleICMPInIPv6(ep stack.NetworkEndpoint, src, dst tcpip.Address, icmp header.ICMPv6, hopLimit uint8, includeRouterAlert bool) {
- var extensionHeaders header.IPv6ExtHdrSerializer
- if includeRouterAlert {
- extensionHeaders = header.IPv6ExtHdrSerializer{
- header.IPv6SerializableHopByHopExtHdr{
- &header.IPv6RouterAlertOption{Value: header.IPv6RouterAlertMLD},
- },
- }
- }
- ip := buffer.NewView(header.IPv6MinimumSize + extensionHeaders.Length())
- header.IPv6(ip).Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(icmp)),
- TransportProtocol: header.ICMPv6ProtocolNumber,
- HopLimit: hopLimit,
- SrcAddr: src,
- DstAddr: dst,
- ExtensionHeaders: extensionHeaders,
- })
-
- vv := ip.ToVectorisedView()
- vv.AppendView(buffer.View(icmp))
- ep.HandlePacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- }))
-}
-
-func TestICMPCounts(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
- })
- if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
- t.Fatalf("CreateNIC(_, _) = %s", err)
- }
- {
- subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable(
- []tcpip.Route{{
- Destination: subnet,
- NIC: nicID,
- }},
- )
- }
-
- netProto := s.NetworkProtocolInstance(ProtocolNumber)
- if netProto == nil {
- t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
- }
- ep := netProto.NewEndpoint(&testInterface{}, &stubDispatcher{})
- defer ep.Close()
-
- if err := ep.Enable(); err != nil {
- t.Fatalf("ep.Enable(): %s", err)
- }
-
- addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
- if !ok {
- t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint")
- }
- addr := lladdr0.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err)
- } else {
- ep.DecRef()
- }
-
- var tllData [header.NDPLinkLayerAddressSize]byte
- header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
- header.NDPTargetLinkLayerAddressOption(linkAddr1),
- })
-
- types := []struct {
- typ header.ICMPv6Type
- hopLimit uint8
- includeRouterAlert bool
- size int
- extraData []byte
- }{
- {
- typ: header.ICMPv6DstUnreachable,
- hopLimit: arbitraryHopLimit,
- size: header.ICMPv6DstUnreachableMinimumSize,
- },
- {
- typ: header.ICMPv6PacketTooBig,
- hopLimit: arbitraryHopLimit,
- size: header.ICMPv6PacketTooBigMinimumSize,
- },
- {
- typ: header.ICMPv6TimeExceeded,
- hopLimit: arbitraryHopLimit,
- size: header.ICMPv6MinimumSize,
- },
- {
- typ: header.ICMPv6ParamProblem,
- hopLimit: arbitraryHopLimit,
- size: header.ICMPv6MinimumSize,
- },
- {
- typ: header.ICMPv6EchoRequest,
- hopLimit: arbitraryHopLimit,
- size: header.ICMPv6EchoMinimumSize,
- },
- {
- typ: header.ICMPv6EchoReply,
- hopLimit: arbitraryHopLimit,
- size: header.ICMPv6EchoMinimumSize,
- },
- {
- typ: header.ICMPv6RouterSolicit,
- hopLimit: header.NDPHopLimit,
- size: header.ICMPv6MinimumSize,
- },
- {
- typ: header.ICMPv6RouterAdvert,
- hopLimit: header.NDPHopLimit,
- size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
- },
- {
- typ: header.ICMPv6NeighborSolicit,
- hopLimit: header.NDPHopLimit,
- size: header.ICMPv6NeighborSolicitMinimumSize,
- },
- {
- typ: header.ICMPv6NeighborAdvert,
- hopLimit: header.NDPHopLimit,
- size: header.ICMPv6NeighborAdvertMinimumSize,
- extraData: tllData[:],
- },
- {
- typ: header.ICMPv6RedirectMsg,
- hopLimit: header.NDPHopLimit,
- size: header.ICMPv6MinimumSize,
- },
- {
- typ: header.ICMPv6MulticastListenerQuery,
- hopLimit: header.MLDHopLimit,
- includeRouterAlert: true,
- size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
- },
- {
- typ: header.ICMPv6MulticastListenerReport,
- hopLimit: header.MLDHopLimit,
- includeRouterAlert: true,
- size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
- },
- {
- typ: header.ICMPv6MulticastListenerDone,
- hopLimit: header.MLDHopLimit,
- includeRouterAlert: true,
- size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
- },
- {
- typ: 255, /* Unrecognized */
- size: 50,
- },
- }
-
- for _, typ := range types {
- icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
- copy(icmp[typ.size:], typ.extraData)
- icmp.SetType(typ.typ)
- icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: icmp[:typ.size],
- Src: lladdr0,
- Dst: lladdr1,
- PayloadCsum: header.Checksum(typ.extraData, 0 /* initial */),
- PayloadLen: len(typ.extraData),
- }))
- handleICMPInIPv6(ep, lladdr1, lladdr0, icmp, typ.hopLimit, typ.includeRouterAlert)
- }
-
- // Construct an empty ICMP packet so that
- // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented.
- handleICMPInIPv6(ep, lladdr1, lladdr0, header.ICMPv6(buffer.NewView(header.IPv6MinimumSize)), arbitraryHopLimit, false)
-
- icmpv6Stats := s.Stats().ICMP.V6.PacketsReceived
- visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) {
- if got, want := s.Value(), uint64(1); got != want {
- t.Errorf("got %s = %d, want = %d", name, got, want)
- }
- })
- if t.Failed() {
- t.Logf("stats:\n%+v", s.Stats())
- }
-}
-
-func visitStats(v reflect.Value, f func(string, *tcpip.StatCounter)) {
- t := v.Type()
- for i := 0; i < v.NumField(); i++ {
- v := v.Field(i)
- if s, ok := v.Interface().(*tcpip.StatCounter); ok {
- f(t.Field(i).Name, s)
- } else {
- visitStats(v, f)
- }
- }
-}
-
-type testContext struct {
- s0 *stack.Stack
- s1 *stack.Stack
-
- linkEP0 *channel.Endpoint
- linkEP1 *channel.Endpoint
-
- clock *faketime.ManualClock
-}
-
-type endpointWithResolutionCapability struct {
- stack.LinkEndpoint
-}
-
-func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapabilities {
- return e.LinkEndpoint.Capabilities() | stack.CapabilityResolutionRequired
-}
-
-func newTestContext(t *testing.T) *testContext {
- clock := faketime.NewManualClock()
- c := &testContext{
- s0: stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
- Clock: clock,
- }),
- s1: stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
- Clock: clock,
- }),
- clock: clock,
- }
-
- c.linkEP0 = channel.New(defaultChannelSize, defaultMTU, linkAddr0)
-
- wrappedEP0 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP0})
- if testing.Verbose() {
- wrappedEP0 = sniffer.New(wrappedEP0)
- }
- if err := c.s0.CreateNIC(nicID, wrappedEP0); err != nil {
- t.Fatalf("CreateNIC s0: %v", err)
- }
- llProtocolAddr0 := tcpip.ProtocolAddress{
- Protocol: ProtocolNumber,
- AddressWithPrefix: lladdr0.WithPrefix(),
- }
- if err := c.s0.AddProtocolAddress(nicID, llProtocolAddr0, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, llProtocolAddr0, err)
- }
-
- c.linkEP1 = channel.New(defaultChannelSize, defaultMTU, linkAddr1)
- wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1})
- if err := c.s1.CreateNIC(nicID, wrappedEP1); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
- llProtocolAddr1 := tcpip.ProtocolAddress{
- Protocol: ProtocolNumber,
- AddressWithPrefix: lladdr1.WithPrefix(),
- }
- if err := c.s1.AddProtocolAddress(nicID, llProtocolAddr1, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, llProtocolAddr1, err)
- }
-
- subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
- if err != nil {
- t.Fatal(err)
- }
- c.s0.SetRouteTable(
- []tcpip.Route{{
- Destination: subnet0,
- NIC: nicID,
- }},
- )
- subnet1, err := tcpip.NewSubnet(lladdr0, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr0))))
- if err != nil {
- t.Fatal(err)
- }
- c.s1.SetRouteTable(
- []tcpip.Route{{
- Destination: subnet1,
- NIC: nicID,
- }},
- )
-
- t.Cleanup(func() {
- if err := c.s0.RemoveNIC(nicID); err != nil {
- t.Errorf("c.s0.RemoveNIC(%d): %s", nicID, err)
- }
- if err := c.s1.RemoveNIC(nicID); err != nil {
- t.Errorf("c.s1.RemoveNIC(%d): %s", nicID, err)
- }
-
- c.linkEP0.Close()
- c.linkEP1.Close()
- })
-
- return c
-}
-
-type routeArgs struct {
- src, dst *channel.Endpoint
- typ header.ICMPv6Type
- remoteLinkAddr tcpip.LinkAddress
-}
-
-func routeICMPv6Packet(t *testing.T, clock *faketime.ManualClock, args routeArgs, fn func(*testing.T, header.ICMPv6)) {
- t.Helper()
-
- clock.RunImmediatelyScheduledJobs()
- pi, ok := args.src.Read()
- if !ok {
- t.Fatal("packet didn't arrive")
- }
-
- {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buffer.NewVectorisedView(pi.Pkt.Size(), pi.Pkt.Views()),
- })
- args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), pkt)
- }
-
- if pi.Proto != ProtocolNumber {
- t.Errorf("unexpected protocol number %d", pi.Proto)
- return
- }
-
- if len(args.remoteLinkAddr) != 0 && pi.Route.RemoteLinkAddress != args.remoteLinkAddr {
- t.Errorf("got remote link address = %s, want = %s", pi.Route.RemoteLinkAddress, args.remoteLinkAddr)
- }
-
- // Pull the full payload since network header. Needed for header.IPv6 to
- // extract its payload.
- ipv6 := header.IPv6(stack.PayloadSince(pi.Pkt.NetworkHeader()))
- transProto := tcpip.TransportProtocolNumber(ipv6.NextHeader())
- if transProto != header.ICMPv6ProtocolNumber {
- t.Errorf("unexpected transport protocol number %d", transProto)
- return
- }
- icmpv6 := header.ICMPv6(ipv6.Payload())
- if got, want := icmpv6.Type(), args.typ; got != want {
- t.Errorf("got ICMPv6 type = %d, want = %d", got, want)
- return
- }
- if fn != nil {
- fn(t, icmpv6)
- }
-}
-
-func TestLinkResolution(t *testing.T) {
- c := newTestContext(t)
-
- r, err := c.s0.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err)
- }
- defer r.Release()
-
- hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6EchoMinimumSize)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
- pkt.SetType(header.ICMPv6EchoRequest)
- pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: pkt,
- Src: r.LocalAddress(),
- Dst: r.RemoteAddress(),
- }))
-
- // We can't send our payload directly over the route because that
- // doesn't provoke NDP discovery.
- var wq waiter.Queue
- ep, err := c.s0.NewEndpoint(header.ICMPv6ProtocolNumber, ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint(_) = (_, %s), want = (_, nil)", err)
- }
-
- {
- var r bytes.Reader
- r.Reset(hdr.View())
- if _, err := ep.Write(&r, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: nicID, Addr: lladdr1}}); err != nil {
- t.Fatalf("ep.Write(_): %s", err)
- }
- }
- for _, args := range []routeArgs{
- {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))},
- {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6NeighborAdvert},
- } {
- routeICMPv6Packet(t, c.clock, args, func(t *testing.T, icmpv6 header.ICMPv6) {
- if got, want := tcpip.Address(icmpv6[8:][:16]), lladdr1; got != want {
- t.Errorf("%d: got target = %s, want = %s", icmpv6.Type(), got, want)
- }
- })
- }
-
- for _, args := range []routeArgs{
- {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6EchoRequest},
- {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6EchoReply},
- } {
- routeICMPv6Packet(t, c.clock, args, nil)
- }
-}
-
-func TestICMPChecksumValidationSimple(t *testing.T) {
- var tllData [header.NDPLinkLayerAddressSize]byte
- header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
- header.NDPTargetLinkLayerAddressOption(linkAddr1),
- })
-
- types := []struct {
- name string
- typ header.ICMPv6Type
- size int
- extraData []byte
- statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
- routerOnly bool
- }{
- {
- name: "DstUnreachable",
- typ: header.ICMPv6DstUnreachable,
- size: header.ICMPv6DstUnreachableMinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.DstUnreachable
- },
- },
- {
- name: "PacketTooBig",
- typ: header.ICMPv6PacketTooBig,
- size: header.ICMPv6PacketTooBigMinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.PacketTooBig
- },
- },
- {
- name: "TimeExceeded",
- typ: header.ICMPv6TimeExceeded,
- size: header.ICMPv6MinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.TimeExceeded
- },
- },
- {
- name: "ParamProblem",
- typ: header.ICMPv6ParamProblem,
- size: header.ICMPv6MinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.ParamProblem
- },
- },
- {
- name: "EchoRequest",
- typ: header.ICMPv6EchoRequest,
- size: header.ICMPv6EchoMinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.EchoRequest
- },
- },
- {
- name: "EchoReply",
- typ: header.ICMPv6EchoReply,
- size: header.ICMPv6EchoMinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.EchoReply
- },
- },
- {
- name: "RouterSolicit",
- typ: header.ICMPv6RouterSolicit,
- size: header.ICMPv6MinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RouterSolicit
- },
- // Hosts MUST silently discard any received Router Solicitation messages.
- routerOnly: true,
- },
- {
- name: "RouterAdvert",
- typ: header.ICMPv6RouterAdvert,
- size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RouterAdvert
- },
- },
- {
- name: "NeighborSolicit",
- typ: header.ICMPv6NeighborSolicit,
- size: header.ICMPv6NeighborSolicitMinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.NeighborSolicit
- },
- },
- {
- name: "NeighborAdvert",
- typ: header.ICMPv6NeighborAdvert,
- size: header.ICMPv6NeighborAdvertMinimumSize,
- extraData: tllData[:],
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.NeighborAdvert
- },
- },
- {
- name: "RedirectMsg",
- typ: header.ICMPv6RedirectMsg,
- size: header.ICMPv6MinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RedirectMsg
- },
- },
- }
-
- for _, typ := range types {
- for _, isRouter := range []bool{false, true} {
- name := typ.name
- if isRouter {
- name += " (Router)"
- }
- t.Run(name, func(t *testing.T) {
- e := channel.New(0, 1280, linkAddr0)
-
- // Indicate that resolution for link layer addresses is required to
- // send packets over this link. This is needed so the NIC knows to
- // allocate a neighbor table.
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- })
- if isRouter {
- if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil {
- t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ProtocolNumber, err)
- }
- }
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(_, _) = %s", err)
- }
-
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ProtocolNumber,
- AddressWithPrefix: lladdr0.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
- {
- subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable(
- []tcpip.Route{{
- Destination: subnet,
- NIC: nicID,
- }},
- )
- }
-
- handleIPv6Payload := func(checksum bool) {
- icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
- copy(icmp[typ.size:], typ.extraData)
- icmp.SetType(typ.typ)
- if checksum {
- icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: icmp,
- Src: lladdr1,
- Dst: lladdr0,
- }))
- }
- ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(icmp)),
- TransportProtocol: header.ICMPv6ProtocolNumber,
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
- })
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}),
- })
- e.InjectInbound(ProtocolNumber, pkt)
- }
-
- stats := s.Stats().ICMP.V6.PacketsReceived
- invalid := stats.Invalid
- routerOnly := stats.RouterOnlyPacketsDroppedByHost
- typStat := typ.statCounter(stats)
-
- // Initial stat counts should be 0.
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
- if got := routerOnly.Value(); got != 0 {
- t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got)
- }
- if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
- }
-
- // Without setting checksum, the incoming packet should
- // be invalid.
- handleIPv6Payload(false)
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
- // Router only count should not have increased.
- if got := routerOnly.Value(); got != 0 {
- t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got)
- }
- // Rx count of type typ.typ should not have increased.
- if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
- }
-
- // When checksum is set, it should be received.
- handleIPv6Payload(true)
- if got := typStat.Value(); got != 1 {
- t.Fatalf("got %s = %d, want = 1", typ.name, got)
- }
- // Invalid count should not have increased again.
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
- if !isRouter && typ.routerOnly {
- // Router only count should have increased.
- if got := routerOnly.Value(); got != 1 {
- t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 1", got)
- }
- }
- })
- }
- }
-}
-
-func TestICMPChecksumValidationWithPayload(t *testing.T) {
- const simpleBodySize = 64
- simpleBody := func(view buffer.View) {
- for i := 0; i < simpleBodySize; i++ {
- view[i] = uint8(i)
- }
- }
-
- const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize
- errorICMPBody := func(view buffer.View) {
- ip := header.IPv6(view)
- ip.Encode(&header.IPv6Fields{
- PayloadLength: simpleBodySize,
- TransportProtocol: 10,
- HopLimit: 20,
- SrcAddr: lladdr0,
- DstAddr: lladdr1,
- })
- simpleBody(view[header.IPv6MinimumSize:])
- }
-
- types := []struct {
- name string
- typ header.ICMPv6Type
- size int
- statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
- payloadSize int
- payload func(buffer.View)
- }{
- {
- "DstUnreachable",
- header.ICMPv6DstUnreachable,
- header.ICMPv6DstUnreachableMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.DstUnreachable
- },
- errorICMPBodySize,
- errorICMPBody,
- },
- {
- "PacketTooBig",
- header.ICMPv6PacketTooBig,
- header.ICMPv6PacketTooBigMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.PacketTooBig
- },
- errorICMPBodySize,
- errorICMPBody,
- },
- {
- "TimeExceeded",
- header.ICMPv6TimeExceeded,
- header.ICMPv6MinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.TimeExceeded
- },
- errorICMPBodySize,
- errorICMPBody,
- },
- {
- "ParamProblem",
- header.ICMPv6ParamProblem,
- header.ICMPv6MinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.ParamProblem
- },
- errorICMPBodySize,
- errorICMPBody,
- },
- {
- "EchoRequest",
- header.ICMPv6EchoRequest,
- header.ICMPv6EchoMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.EchoRequest
- },
- simpleBodySize,
- simpleBody,
- },
- {
- "EchoReply",
- header.ICMPv6EchoReply,
- header.ICMPv6EchoMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.EchoReply
- },
- simpleBodySize,
- simpleBody,
- },
- }
-
- for _, typ := range types {
- t.Run(typ.name, func(t *testing.T) {
- e := channel.New(10, 1280, linkAddr0)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- })
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(_, _) = %s", err)
- }
-
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ProtocolNumber,
- AddressWithPrefix: lladdr0.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
- {
- subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable(
- []tcpip.Route{{
- Destination: subnet,
- NIC: nicID,
- }},
- )
- }
-
- handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) {
- icmpSize := size + payloadSize
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
- icmpHdr := header.ICMPv6(hdr.Prepend(icmpSize))
- icmpHdr.SetType(typ)
- payloadFn(icmpHdr.Payload())
-
- if checksum {
- icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: icmpHdr,
- Src: lladdr1,
- Dst: lladdr0,
- }))
- }
-
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(icmpSize),
- TransportProtocol: header.ICMPv6ProtocolNumber,
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
- })
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- })
- e.InjectInbound(ProtocolNumber, pkt)
- }
-
- stats := s.Stats().ICMP.V6.PacketsReceived
- invalid := stats.Invalid
- typStat := typ.statCounter(stats)
-
- // Initial stat counts should be 0.
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
- if got := typStat.Value(); got != 0 {
- t.Fatalf("got = %d, want = 0", got)
- }
-
- // Without setting checksum, the incoming packet should
- // be invalid.
- handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false)
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
- // Rx count of type typ.typ should not have increased.
- if got := typStat.Value(); got != 0 {
- t.Fatalf("got = %d, want = 0", got)
- }
-
- // When checksum is set, it should be received.
- handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true)
- if got := typStat.Value(); got != 1 {
- t.Fatalf("got = %d, want = 0", got)
- }
- // Invalid count should not have increased again.
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
- })
- }
-}
-
-func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
- const simpleBodySize = 64
- simpleBody := func(view buffer.View) {
- for i := 0; i < simpleBodySize; i++ {
- view[i] = uint8(i)
- }
- }
-
- const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize
- errorICMPBody := func(view buffer.View) {
- ip := header.IPv6(view)
- ip.Encode(&header.IPv6Fields{
- PayloadLength: simpleBodySize,
- TransportProtocol: 10,
- HopLimit: 20,
- SrcAddr: lladdr0,
- DstAddr: lladdr1,
- })
- simpleBody(view[header.IPv6MinimumSize:])
- }
-
- types := []struct {
- name string
- typ header.ICMPv6Type
- size int
- statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
- payloadSize int
- payload func(buffer.View)
- }{
- {
- "DstUnreachable",
- header.ICMPv6DstUnreachable,
- header.ICMPv6DstUnreachableMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.DstUnreachable
- },
- errorICMPBodySize,
- errorICMPBody,
- },
- {
- "PacketTooBig",
- header.ICMPv6PacketTooBig,
- header.ICMPv6PacketTooBigMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.PacketTooBig
- },
- errorICMPBodySize,
- errorICMPBody,
- },
- {
- "TimeExceeded",
- header.ICMPv6TimeExceeded,
- header.ICMPv6MinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.TimeExceeded
- },
- errorICMPBodySize,
- errorICMPBody,
- },
- {
- "ParamProblem",
- header.ICMPv6ParamProblem,
- header.ICMPv6MinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.ParamProblem
- },
- errorICMPBodySize,
- errorICMPBody,
- },
- {
- "EchoRequest",
- header.ICMPv6EchoRequest,
- header.ICMPv6EchoMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.EchoRequest
- },
- simpleBodySize,
- simpleBody,
- },
- {
- "EchoReply",
- header.ICMPv6EchoReply,
- header.ICMPv6EchoMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.EchoReply
- },
- simpleBodySize,
- simpleBody,
- },
- }
-
- for _, typ := range types {
- t.Run(typ.name, func(t *testing.T) {
- e := channel.New(10, 1280, linkAddr0)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- })
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ProtocolNumber,
- AddressWithPrefix: lladdr0.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
- {
- subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable(
- []tcpip.Route{{
- Destination: subnet,
- NIC: nicID,
- }},
- )
- }
-
- handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + size)
- icmpHdr := header.ICMPv6(hdr.Prepend(size))
- icmpHdr.SetType(typ)
-
- payload := buffer.NewView(payloadSize)
- payloadFn(payload)
-
- if checksum {
- icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: icmpHdr,
- Src: lladdr1,
- Dst: lladdr0,
- PayloadCsum: header.Checksum(payload, 0 /* initial */),
- PayloadLen: len(payload),
- }))
- }
-
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(size + payloadSize),
- TransportProtocol: header.ICMPv6ProtocolNumber,
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
- })
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}),
- })
- e.InjectInbound(ProtocolNumber, pkt)
- }
-
- stats := s.Stats().ICMP.V6.PacketsReceived
- invalid := stats.Invalid
- typStat := typ.statCounter(stats)
-
- // Initial stat counts should be 0.
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
- if got := typStat.Value(); got != 0 {
- t.Fatalf("got = %d, want = 0", got)
- }
-
- // Without setting checksum, the incoming packet should
- // be invalid.
- handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false)
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
- // Rx count of type typ.typ should not have increased.
- if got := typStat.Value(); got != 0 {
- t.Fatalf("got = %d, want = 0", got)
- }
-
- // When checksum is set, it should be received.
- handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true)
- if got := typStat.Value(); got != 1 {
- t.Fatalf("got = %d, want = 0", got)
- }
- // Invalid count should not have increased again.
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
- })
- }
-}
-
-func TestLinkAddressRequest(t *testing.T) {
- const nicID = 1
-
- snaddr := header.SolicitedNodeAddr(lladdr0)
- mcaddr := header.EthernetAddressFromMulticastIPv6Address(snaddr)
-
- tests := []struct {
- name string
- nicAddr tcpip.Address
- localAddr tcpip.Address
- remoteLinkAddr tcpip.LinkAddress
-
- expectedErr tcpip.Error
- expectedRemoteAddr tcpip.Address
- expectedRemoteLinkAddr tcpip.LinkAddress
- }{
- {
- name: "Unicast",
- nicAddr: lladdr1,
- localAddr: lladdr1,
- remoteLinkAddr: linkAddr1,
- expectedRemoteAddr: lladdr0,
- expectedRemoteLinkAddr: linkAddr1,
- },
- {
- name: "Multicast",
- nicAddr: lladdr1,
- localAddr: lladdr1,
- remoteLinkAddr: "",
- expectedRemoteAddr: snaddr,
- expectedRemoteLinkAddr: mcaddr,
- },
- {
- name: "Unicast with unspecified source",
- nicAddr: lladdr1,
- remoteLinkAddr: linkAddr1,
- expectedRemoteAddr: lladdr0,
- expectedRemoteLinkAddr: linkAddr1,
- },
- {
- name: "Multicast with unspecified source",
- nicAddr: lladdr1,
- remoteLinkAddr: "",
- expectedRemoteAddr: snaddr,
- expectedRemoteLinkAddr: mcaddr,
- },
- {
- name: "Unicast with unassigned address",
- localAddr: lladdr1,
- remoteLinkAddr: linkAddr1,
- expectedErr: &tcpip.ErrBadLocalAddress{},
- },
- {
- name: "Multicast with unassigned address",
- localAddr: lladdr1,
- remoteLinkAddr: "",
- expectedErr: &tcpip.ErrBadLocalAddress{},
- },
- {
- name: "Unicast with no local address available",
- remoteLinkAddr: linkAddr1,
- expectedErr: &tcpip.ErrNetworkUnreachable{},
- },
- {
- name: "Multicast with no local address available",
- remoteLinkAddr: "",
- expectedErr: &tcpip.ErrNetworkUnreachable{},
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- })
-
- linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0)
- if err := s.CreateNIC(nicID, linkEP); err != nil {
- t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
- }
-
- ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber)
- if err != nil {
- t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, ProtocolNumber, err)
- }
- linkRes, ok := ep.(stack.LinkAddressResolver)
- if !ok {
- t.Fatalf("expected %T to implement stack.LinkAddressResolver", ep)
- }
-
- if len(test.nicAddr) != 0 {
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ProtocolNumber,
- AddressWithPrefix: test.nicAddr.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
- }
-
- {
- err := linkRes.LinkAddressRequest(lladdr0, test.localAddr, test.remoteLinkAddr)
- if diff := cmp.Diff(test.expectedErr, err); diff != "" {
- t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", lladdr0, test.localAddr, test.remoteLinkAddr, diff)
- }
- }
-
- if test.expectedErr != nil {
- return
- }
-
- pkt, ok := linkEP.Read()
- if !ok {
- t.Fatal("expected to send a link address request")
- }
-
- var want stack.RouteInfo
- want.NetProto = ProtocolNumber
- want.RemoteLinkAddress = test.expectedRemoteLinkAddr
- if diff := cmp.Diff(want, pkt.Route, cmp.AllowUnexported(want)); diff != "" {
- t.Errorf("route info mismatch (-want +got):\n%s", diff)
- }
- checker.IPv6(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()),
- checker.SrcAddr(lladdr1),
- checker.DstAddr(test.expectedRemoteAddr),
- checker.TTL(header.NDPHopLimit),
- checker.NDPNS(
- checker.NDPNSTargetAddress(lladdr0),
- checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(linkAddr0)}),
- ))
- })
- }
-}
-
-func TestPacketQueing(t *testing.T) {
- const nicID = 1
-
- var (
- host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
- host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
-
- host1IPv6Addr = tcpip.ProtocolAddress{
- Protocol: ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("a::1").To16()),
- PrefixLen: 64,
- },
- }
- host2IPv6Addr = tcpip.ProtocolAddress{
- Protocol: ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("a::2").To16()),
- PrefixLen: 64,
- },
- }
- )
-
- tests := []struct {
- name string
- rxPkt func(*channel.Endpoint)
- checkResp func(*testing.T, *channel.Endpoint)
- }{
- {
- name: "ICMP Error",
- rxPkt: func(e *channel.Endpoint) {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize)
- u := header.UDP(hdr.Prepend(header.UDPMinimumSize))
- u.Encode(&header.UDPFields{
- SrcPort: 5555,
- DstPort: 80,
- Length: header.UDPMinimumSize,
- })
- sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, header.UDPMinimumSize)
- sum = header.Checksum(nil, sum)
- u.SetChecksum(^u.CalculateChecksum(sum))
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- TransportProtocol: udp.ProtocolNumber,
- HopLimit: DefaultTTL,
- SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
- DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
- })
- e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
- },
- checkResp: func(t *testing.T, e *channel.Endpoint) {
- p, ok := e.Read()
- if !ok {
- t.Fatalf("timed out waiting for packet")
- }
- if p.Proto != ProtocolNumber {
- t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber)
- }
- if p.Route.RemoteLinkAddress != host2NICLinkAddr {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
- }
- checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
- checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
- checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address),
- checker.ICMPv6(
- checker.ICMPv6Type(header.ICMPv6DstUnreachable),
- checker.ICMPv6Code(header.ICMPv6PortUnreachable)))
- },
- },
-
- {
- name: "Ping",
- rxPkt: func(e *channel.Endpoint) {
- totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize
- hdr := buffer.NewPrependable(totalLen)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
- pkt.SetType(header.ICMPv6EchoRequest)
- pkt.SetCode(0)
- pkt.SetChecksum(0)
- pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: pkt,
- Src: host2IPv6Addr.AddressWithPrefix.Address,
- Dst: host1IPv6Addr.AddressWithPrefix.Address,
- }))
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: header.ICMPv6MinimumSize,
- TransportProtocol: icmp.ProtocolNumber6,
- HopLimit: DefaultTTL,
- SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
- DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
- })
- e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
- },
- checkResp: func(t *testing.T, e *channel.Endpoint) {
- p, ok := e.Read()
- if !ok {
- t.Fatalf("timed out waiting for packet")
- }
- if p.Proto != ProtocolNumber {
- t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber)
- }
- if p.Route.RemoteLinkAddress != host2NICLinkAddr {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
- }
- checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
- checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
- checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address),
- checker.ICMPv6(
- checker.ICMPv6Type(header.ICMPv6EchoReply),
- checker.ICMPv6Code(header.ICMPv6UnusedCode)))
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
-
- e := channel.New(1, header.IPv6MinimumMTU, host1NICLinkAddr)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- Clock: clock,
- })
-
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
- }
- if err := s.AddProtocolAddress(nicID, host1IPv6Addr, stack.AddressProperties{}); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv6Addr, err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
- NIC: nicID,
- },
- })
-
- // Receive a packet to trigger link resolution before a response is sent.
- test.rxPkt(e)
-
- // Wait for a neighbor solicitation since link address resolution should
- // be performed.
- {
- clock.RunImmediatelyScheduledJobs()
- p, ok := e.Read()
- if !ok {
- t.Fatalf("timed out waiting for packet")
- }
- if p.Proto != ProtocolNumber {
- t.Errorf("got Proto = %d, want = %d", p.Proto, ProtocolNumber)
- }
- snmc := header.SolicitedNodeAddr(host2IPv6Addr.AddressWithPrefix.Address)
- if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want)
- }
- checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
- checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
- checker.DstAddr(snmc),
- checker.TTL(header.NDPHopLimit),
- checker.NDPNS(
- checker.NDPNSTargetAddress(host2IPv6Addr.AddressWithPrefix.Address),
- checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(host1NICLinkAddr)}),
- ))
- }
-
- // Send a neighbor advertisement to complete link address resolution.
- {
- naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize)
- pkt := header.ICMPv6(hdr.Prepend(naSize))
- pkt.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(pkt.MessageBody())
- na.SetSolicitedFlag(true)
- na.SetOverrideFlag(true)
- na.SetTargetAddress(host2IPv6Addr.AddressWithPrefix.Address)
- na.Options().Serialize(header.NDPOptionsSerializer{
- header.NDPTargetLinkLayerAddressOption(host2NICLinkAddr),
- })
- pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: pkt,
- Src: host2IPv6Addr.AddressWithPrefix.Address,
- Dst: host1IPv6Addr.AddressWithPrefix.Address,
- }))
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- TransportProtocol: icmp.ProtocolNumber6,
- HopLimit: header.NDPHopLimit,
- SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
- DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
- })
- e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
- }
-
- // Expect the response now that the link address has resolved.
- clock.RunImmediatelyScheduledJobs()
- test.checkResp(t, e)
-
- // Since link resolution was already performed, it shouldn't be performed
- // again.
- test.rxPkt(e)
- test.checkResp(t, e)
- })
- }
-}
-
-func TestCallsToNeighborCache(t *testing.T) {
- tests := []struct {
- name string
- createPacket func() header.ICMPv6
- multicast bool
- source tcpip.Address
- destination tcpip.Address
- wantProbeCount int
- wantConfirmationCount int
- }{
- {
- name: "Unicast Neighbor Solicitation without source link-layer address option",
- createPacket: func() header.ICMPv6 {
- nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
- icmp := header.ICMPv6(buffer.NewView(nsSize))
- icmp.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmp.MessageBody())
- ns.SetTargetAddress(lladdr0)
- return icmp
- },
- source: lladdr1,
- destination: lladdr0,
- // "The source link-layer address option SHOULD be included in unicast
- // solicitations." - RFC 4861 section 4.3
- //
- // A Neighbor Advertisement needs to be sent in response, but the
- // Neighbor Cache shouldn't be updated since we have no useful
- // information about the sender.
- wantProbeCount: 0,
- },
- {
- name: "Unicast Neighbor Solicitation with source link-layer address option",
- createPacket: func() header.ICMPv6 {
- nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
- icmp := header.ICMPv6(buffer.NewView(nsSize))
- icmp.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmp.MessageBody())
- ns.SetTargetAddress(lladdr0)
- ns.Options().Serialize(header.NDPOptionsSerializer{
- header.NDPSourceLinkLayerAddressOption(linkAddr1),
- })
- return icmp
- },
- source: lladdr1,
- destination: lladdr0,
- wantProbeCount: 1,
- },
- {
- name: "Multicast Neighbor Solicitation without source link-layer address option",
- createPacket: func() header.ICMPv6 {
- nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
- icmp := header.ICMPv6(buffer.NewView(nsSize))
- icmp.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmp.MessageBody())
- ns.SetTargetAddress(lladdr0)
- return icmp
- },
- source: lladdr1,
- destination: header.SolicitedNodeAddr(lladdr0),
- // "The source link-layer address option MUST be included in multicast
- // solicitations." - RFC 4861 section 4.3
- wantProbeCount: 0,
- },
- {
- name: "Multicast Neighbor Solicitation with source link-layer address option",
- createPacket: func() header.ICMPv6 {
- nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
- icmp := header.ICMPv6(buffer.NewView(nsSize))
- icmp.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmp.MessageBody())
- ns.SetTargetAddress(lladdr0)
- ns.Options().Serialize(header.NDPOptionsSerializer{
- header.NDPSourceLinkLayerAddressOption(linkAddr1),
- })
- return icmp
- },
- source: lladdr1,
- destination: header.SolicitedNodeAddr(lladdr0),
- wantProbeCount: 1,
- },
- {
- name: "Unicast Neighbor Advertisement without target link-layer address option",
- createPacket: func() header.ICMPv6 {
- naSize := header.ICMPv6NeighborAdvertMinimumSize
- icmp := header.ICMPv6(buffer.NewView(naSize))
- icmp.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(icmp.MessageBody())
- na.SetSolicitedFlag(true)
- na.SetOverrideFlag(false)
- na.SetTargetAddress(lladdr1)
- return icmp
- },
- source: lladdr1,
- destination: lladdr0,
- // "When responding to unicast solicitations, the target link-layer
- // address option can be omitted since the sender of the solicitation has
- // the correct link-layer address; otherwise, it would not be able to
- // send the unicast solicitation in the first place."
- // - RFC 4861 section 4.4
- wantConfirmationCount: 1,
- },
- {
- name: "Unicast Neighbor Advertisement with target link-layer address option",
- createPacket: func() header.ICMPv6 {
- naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
- icmp := header.ICMPv6(buffer.NewView(naSize))
- icmp.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(icmp.MessageBody())
- na.SetSolicitedFlag(true)
- na.SetOverrideFlag(false)
- na.SetTargetAddress(lladdr1)
- na.Options().Serialize(header.NDPOptionsSerializer{
- header.NDPTargetLinkLayerAddressOption(linkAddr1),
- })
- return icmp
- },
- source: lladdr1,
- destination: lladdr0,
- wantConfirmationCount: 1,
- },
- {
- name: "Multicast Neighbor Advertisement without target link-layer address option",
- createPacket: func() header.ICMPv6 {
- naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
- icmp := header.ICMPv6(buffer.NewView(naSize))
- icmp.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(icmp.MessageBody())
- na.SetSolicitedFlag(false)
- na.SetOverrideFlag(false)
- na.SetTargetAddress(lladdr1)
- return icmp
- },
- source: lladdr1,
- destination: header.IPv6AllNodesMulticastAddress,
- // "Target link-layer address MUST be included for multicast solicitations
- // in order to avoid infinite Neighbor Solicitation "recursion" when the
- // peer node does not have a cache entry to return a Neighbor
- // Advertisements message." - RFC 4861 section 4.4
- wantConfirmationCount: 0,
- },
- {
- name: "Multicast Neighbor Advertisement with target link-layer address option",
- createPacket: func() header.ICMPv6 {
- naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
- icmp := header.ICMPv6(buffer.NewView(naSize))
- icmp.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(icmp.MessageBody())
- na.SetSolicitedFlag(false)
- na.SetOverrideFlag(false)
- na.SetTargetAddress(lladdr1)
- na.Options().Serialize(header.NDPOptionsSerializer{
- header.NDPTargetLinkLayerAddressOption(linkAddr1),
- })
- return icmp
- },
- source: lladdr1,
- destination: header.IPv6AllNodesMulticastAddress,
- wantConfirmationCount: 1,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
- })
- {
- if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
- t.Fatalf("CreateNIC(_, _) = %s", err)
- }
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ProtocolNumber,
- AddressWithPrefix: lladdr0.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
- }
- {
- subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable(
- []tcpip.Route{{
- Destination: subnet,
- NIC: nicID,
- }},
- )
- }
-
- netProto := s.NetworkProtocolInstance(ProtocolNumber)
- if netProto == nil {
- t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
- }
-
- testInterface := testInterface{LinkEndpoint: channel.New(0, header.IPv6MinimumMTU, linkAddr0)}
- ep := netProto.NewEndpoint(&testInterface, &stubDispatcher{})
- defer ep.Close()
-
- if err := ep.Enable(); err != nil {
- t.Fatalf("ep.Enable(): %s", err)
- }
-
- addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
- if !ok {
- t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint")
- }
- addr := lladdr0.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err)
- } else {
- ep.DecRef()
- }
-
- icmp := test.createPacket()
- icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: icmp,
- Src: test.source,
- Dst: test.destination,
- }))
- handleICMPInIPv6(ep, test.source, test.destination, icmp, header.NDPHopLimit, false)
-
- // Confirm the endpoint calls the correct NUDHandler method.
- if testInterface.probeCount != test.wantProbeCount {
- t.Errorf("got testInterface.probeCount = %d, want = %d", testInterface.probeCount, test.wantProbeCount)
- }
- if testInterface.confirmationCount != test.wantConfirmationCount {
- t.Errorf("got testInterface.confirmationCount = %d, want = %d", testInterface.confirmationCount, test.wantConfirmationCount)
- }
- })
- }
-}
diff --git a/pkg/tcpip/network/ipv6/ipv6_state_autogen.go b/pkg/tcpip/network/ipv6/ipv6_state_autogen.go
new file mode 100644
index 000000000..13d427822
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/ipv6_state_autogen.go
@@ -0,0 +1,136 @@
+// automatically generated by stateify.
+
+package ipv6
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+func (i *icmpv6DestinationUnreachableSockError) StateTypeName() string {
+ return "pkg/tcpip/network/ipv6.icmpv6DestinationUnreachableSockError"
+}
+
+func (i *icmpv6DestinationUnreachableSockError) StateFields() []string {
+ return []string{}
+}
+
+func (i *icmpv6DestinationUnreachableSockError) beforeSave() {}
+
+// +checklocksignore
+func (i *icmpv6DestinationUnreachableSockError) StateSave(stateSinkObject state.Sink) {
+ i.beforeSave()
+}
+
+func (i *icmpv6DestinationUnreachableSockError) afterLoad() {}
+
+// +checklocksignore
+func (i *icmpv6DestinationUnreachableSockError) StateLoad(stateSourceObject state.Source) {
+}
+
+func (i *icmpv6DestinationNetworkUnreachableSockError) StateTypeName() string {
+ return "pkg/tcpip/network/ipv6.icmpv6DestinationNetworkUnreachableSockError"
+}
+
+func (i *icmpv6DestinationNetworkUnreachableSockError) StateFields() []string {
+ return []string{
+ "icmpv6DestinationUnreachableSockError",
+ }
+}
+
+func (i *icmpv6DestinationNetworkUnreachableSockError) beforeSave() {}
+
+// +checklocksignore
+func (i *icmpv6DestinationNetworkUnreachableSockError) StateSave(stateSinkObject state.Sink) {
+ i.beforeSave()
+ stateSinkObject.Save(0, &i.icmpv6DestinationUnreachableSockError)
+}
+
+func (i *icmpv6DestinationNetworkUnreachableSockError) afterLoad() {}
+
+// +checklocksignore
+func (i *icmpv6DestinationNetworkUnreachableSockError) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &i.icmpv6DestinationUnreachableSockError)
+}
+
+func (i *icmpv6DestinationPortUnreachableSockError) StateTypeName() string {
+ return "pkg/tcpip/network/ipv6.icmpv6DestinationPortUnreachableSockError"
+}
+
+func (i *icmpv6DestinationPortUnreachableSockError) StateFields() []string {
+ return []string{
+ "icmpv6DestinationUnreachableSockError",
+ }
+}
+
+func (i *icmpv6DestinationPortUnreachableSockError) beforeSave() {}
+
+// +checklocksignore
+func (i *icmpv6DestinationPortUnreachableSockError) StateSave(stateSinkObject state.Sink) {
+ i.beforeSave()
+ stateSinkObject.Save(0, &i.icmpv6DestinationUnreachableSockError)
+}
+
+func (i *icmpv6DestinationPortUnreachableSockError) afterLoad() {}
+
+// +checklocksignore
+func (i *icmpv6DestinationPortUnreachableSockError) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &i.icmpv6DestinationUnreachableSockError)
+}
+
+func (i *icmpv6DestinationAddressUnreachableSockError) StateTypeName() string {
+ return "pkg/tcpip/network/ipv6.icmpv6DestinationAddressUnreachableSockError"
+}
+
+func (i *icmpv6DestinationAddressUnreachableSockError) StateFields() []string {
+ return []string{
+ "icmpv6DestinationUnreachableSockError",
+ }
+}
+
+func (i *icmpv6DestinationAddressUnreachableSockError) beforeSave() {}
+
+// +checklocksignore
+func (i *icmpv6DestinationAddressUnreachableSockError) StateSave(stateSinkObject state.Sink) {
+ i.beforeSave()
+ stateSinkObject.Save(0, &i.icmpv6DestinationUnreachableSockError)
+}
+
+func (i *icmpv6DestinationAddressUnreachableSockError) afterLoad() {}
+
+// +checklocksignore
+func (i *icmpv6DestinationAddressUnreachableSockError) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &i.icmpv6DestinationUnreachableSockError)
+}
+
+func (e *icmpv6PacketTooBigSockError) StateTypeName() string {
+ return "pkg/tcpip/network/ipv6.icmpv6PacketTooBigSockError"
+}
+
+func (e *icmpv6PacketTooBigSockError) StateFields() []string {
+ return []string{
+ "mtu",
+ }
+}
+
+func (e *icmpv6PacketTooBigSockError) beforeSave() {}
+
+// +checklocksignore
+func (e *icmpv6PacketTooBigSockError) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+ stateSinkObject.Save(0, &e.mtu)
+}
+
+func (e *icmpv6PacketTooBigSockError) afterLoad() {}
+
+// +checklocksignore
+func (e *icmpv6PacketTooBigSockError) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &e.mtu)
+}
+
+func init() {
+ state.Register((*icmpv6DestinationUnreachableSockError)(nil))
+ state.Register((*icmpv6DestinationNetworkUnreachableSockError)(nil))
+ state.Register((*icmpv6DestinationPortUnreachableSockError)(nil))
+ state.Register((*icmpv6DestinationAddressUnreachableSockError)(nil))
+ state.Register((*icmpv6PacketTooBigSockError)(nil))
+}
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
deleted file mode 100644
index 0735ebb23..000000000
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ /dev/null
@@ -1,3523 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package ipv6
-
-import (
- "bytes"
- "encoding/hex"
- "fmt"
- "io/ioutil"
- "math"
- "net"
- "reflect"
- "testing"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- iptestutil "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-const (
- addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- // The least significant 3 bytes are the same as addr2 so both addr2 and
- // addr3 will have the same solicited-node address.
- addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02")
- addr4 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x03")
-
- // Tests use the extension header identifier values as uint8 instead of
- // header.IPv6ExtensionHeaderIdentifier.
- hopByHopExtHdrID = uint8(header.IPv6HopByHopOptionsExtHdrIdentifier)
- routingExtHdrID = uint8(header.IPv6RoutingExtHdrIdentifier)
- fragmentExtHdrID = uint8(header.IPv6FragmentExtHdrIdentifier)
- destinationExtHdrID = uint8(header.IPv6DestinationOptionsExtHdrIdentifier)
- noNextHdrID = uint8(header.IPv6NoNextHeaderIdentifier)
- unknownHdrID = uint8(header.IPv6UnknownExtHdrIdentifier)
-
- extraHeaderReserve = 50
-)
-
-// testReceiveICMP tests receiving an ICMP packet from src to dst. want is the
-// expected Neighbor Advertisement received count after receiving the packet.
-func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) {
- t.Helper()
-
- // Receive ICMP packet.
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertMinimumSize)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertMinimumSize))
- pkt.SetType(header.ICMPv6NeighborAdvert)
- pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: pkt,
- Src: src,
- Dst: dst,
- }))
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- TransportProtocol: header.ICMPv6ProtocolNumber,
- HopLimit: 255,
- SrcAddr: src,
- DstAddr: dst,
- })
-
- e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
-
- stats := s.Stats().ICMP.V6.PacketsReceived
-
- if got := stats.NeighborAdvert.Value(); got != want {
- t.Fatalf("got NeighborAdvert = %d, want = %d", got, want)
- }
-}
-
-// testReceiveUDP tests receiving a UDP packet from src to dst. want is the
-// expected UDP received count after receiving the packet.
-func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) {
- t.Helper()
-
- wq := waiter.Queue{}
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
- defer close(ch)
-
- ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
- defer ep.Close()
-
- if err := ep.Bind(tcpip.FullAddress{Addr: dst, Port: 80}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %v", err)
- }
-
- // Receive UDP Packet.
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize)
- u := header.UDP(hdr.Prepend(header.UDPMinimumSize))
- u.Encode(&header.UDPFields{
- SrcPort: 5555,
- DstPort: 80,
- Length: header.UDPMinimumSize,
- })
-
- // UDP pseudo-header checksum.
- sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, header.UDPMinimumSize)
-
- // UDP checksum
- sum = header.Checksum(nil, sum)
- u.SetChecksum(^u.CalculateChecksum(sum))
-
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- TransportProtocol: udp.ProtocolNumber,
- HopLimit: 255,
- SrcAddr: src,
- DstAddr: dst,
- })
-
- e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
-
- stat := s.Stats().UDP.PacketsReceived
-
- if got := stat.Value(); got != want {
- t.Fatalf("got UDPPacketsReceived = %d, want = %d", got, want)
- }
-}
-
-func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber) error {
- // sourcePacket does not have its IP Header populated. Let's copy the one
- // from the first fragment.
- source := header.IPv6(packets[0].NetworkHeader().View())
- sourceIPHeadersLen := len(source)
- vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views())
- source = append(source, vv.ToView()...)
-
- var reassembledPayload buffer.VectorisedView
- for i, fragment := range packets {
- // Confirm that the packet is valid.
- allBytes := buffer.NewVectorisedView(fragment.Size(), fragment.Views())
- fragmentIPHeaders := header.IPv6(allBytes.ToView())
- if !fragmentIPHeaders.IsValid(len(fragmentIPHeaders)) {
- return fmt.Errorf("fragment #%d: IP packet is invalid:\n%s", i, hex.Dump(fragmentIPHeaders))
- }
-
- fragmentIPHeadersLength := fragment.NetworkHeader().View().Size()
- if fragmentIPHeadersLength != sourceIPHeadersLen {
- return fmt.Errorf("fragment #%d: got fragmentIPHeadersLength = %d, want = %d", i, fragmentIPHeadersLength, sourceIPHeadersLen)
- }
-
- if got := len(fragmentIPHeaders); got > int(mtu) {
- return fmt.Errorf("fragment #%d: got len(fragmentIPHeaders) = %d, want <= %d", i, got, mtu)
- }
-
- sourceIPHeader := source[:header.IPv6MinimumSize]
- fragmentIPHeader := fragmentIPHeaders[:header.IPv6MinimumSize]
-
- if got := fragmentIPHeaders.PayloadLength(); got != wantFragments[i].payloadSize {
- return fmt.Errorf("fragment #%d: got fragmentIPHeaders.PayloadLength() = %d, want = %d", i, got, wantFragments[i].payloadSize)
- }
-
- // We expect the IPv6 Header to be similar across each fragment, besides the
- // payload length.
- sourceIPHeader.SetPayloadLength(0)
- fragmentIPHeader.SetPayloadLength(0)
- if diff := cmp.Diff(fragmentIPHeader, sourceIPHeader); diff != "" {
- return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff)
- }
-
- if got := fragment.AvailableHeaderBytes(); got != extraHeaderReserve {
- return fmt.Errorf("fragment #%d: got packet.AvailableHeaderBytes() = %d, want = %d", i, got, extraHeaderReserve)
- }
- if fragment.NetworkProtocolNumber != sourcePacket.NetworkProtocolNumber {
- return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, fragment.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber)
- }
-
- if len(packets) > 1 {
- // If the source packet was big enough that it needed fragmentation, let's
- // inspect the fragment header. Because no other extension headers are
- // supported, it will always be the last extension header.
- fragmentHeader := header.IPv6Fragment(fragmentIPHeaders[fragmentIPHeadersLength-header.IPv6FragmentHeaderSize : fragmentIPHeadersLength])
-
- if got := fragmentHeader.More(); got != wantFragments[i].more {
- return fmt.Errorf("fragment #%d: got fragmentHeader.More() = %t, want = %t", i, got, wantFragments[i].more)
- }
- if got := fragmentHeader.FragmentOffset(); got != wantFragments[i].offset {
- return fmt.Errorf("fragment #%d: got fragmentHeader.FragmentOffset() = %d, want = %d", i, got, wantFragments[i].offset)
- }
- if got := fragmentHeader.NextHeader(); got != uint8(proto) {
- return fmt.Errorf("fragment #%d: got fragmentHeader.NextHeader() = %d, want = %d", i, got, uint8(proto))
- }
- }
-
- // Store the reassembled payload as we parse each fragment. The payload
- // includes the Transport header and everything after.
- reassembledPayload.AppendView(fragment.TransportHeader().View())
- reassembledPayload.AppendView(fragment.Data().AsRange().ToOwnedView())
- }
-
- if diff := cmp.Diff(buffer.View(source[sourceIPHeadersLen:]), reassembledPayload.ToView()); diff != "" {
- return fmt.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff)
- }
-
- return nil
-}
-
-// TestReceiveOnAllNodesMulticastAddr tests that IPv6 endpoints receive ICMP and
-// UDP packets destined to the IPv6 link-local all-nodes multicast address.
-func TestReceiveOnAllNodesMulticastAddr(t *testing.T) {
- tests := []struct {
- name string
- protocolFactory stack.TransportProtocolFactory
- rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64)
- }{
- {"ICMP", icmp.NewProtocol6, testReceiveICMP},
- {"UDP", udp.NewProtocol, testReceiveUDP},
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory},
- })
- e := channel.New(10, header.IPv6MinimumMTU, linkAddr1)
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
- }
-
- // Should receive a packet destined to the all-nodes
- // multicast address.
- test.rxf(t, s, e, addr1, header.IPv6AllNodesMulticastAddress, 1)
- })
- }
-}
-
-// TestReceiveOnSolicitedNodeAddr tests that IPv6 endpoints receive ICMP and UDP
-// packets destined to the IPv6 solicited-node address of an assigned IPv6
-// address.
-func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
- tests := []struct {
- name string
- protocolFactory stack.TransportProtocolFactory
- rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64)
- }{
- {"ICMP", icmp.NewProtocol6, testReceiveICMP},
- {"UDP", udp.NewProtocol, testReceiveUDP},
- }
-
- snmc := header.SolicitedNodeAddr(addr2)
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory},
- })
- e := channel.New(1, header.IPv6MinimumMTU, linkAddr1)
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv6EmptySubnet,
- NIC: nicID,
- },
- })
-
- // Should not receive a packet destined to the solicited node address of
- // addr2/addr3 yet as we haven't added those addresses.
- test.rxf(t, s, e, addr1, snmc, 0)
-
- protocolAddr2 := tcpip.ProtocolAddress{
- Protocol: ProtocolNumber,
- AddressWithPrefix: addr2.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr2, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr2, err)
- }
-
- // Should receive a packet destined to the solicited node address of
- // addr2/addr3 now that we have added added addr2.
- test.rxf(t, s, e, addr1, snmc, 1)
-
- protocolAddr3 := tcpip.ProtocolAddress{
- Protocol: ProtocolNumber,
- AddressWithPrefix: addr3.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr3, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr3, err)
- }
-
- // Should still receive a packet destined to the solicited node address of
- // addr2/addr3 now that we have added addr3.
- test.rxf(t, s, e, addr1, snmc, 2)
-
- if err := s.RemoveAddress(nicID, addr2); err != nil {
- t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr2, err)
- }
-
- // Should still receive a packet destined to the solicited node address of
- // addr2/addr3 now that we have removed addr2.
- test.rxf(t, s, e, addr1, snmc, 3)
-
- // Make sure addr3's endpoint does not get removed from the NIC by
- // incrementing its reference count with a route.
- r, err := s.FindRoute(nicID, addr3, addr4, ProtocolNumber, false)
- if err != nil {
- t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr3, addr4, ProtocolNumber, err)
- }
- defer r.Release()
-
- if err := s.RemoveAddress(nicID, addr3); err != nil {
- t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr3, err)
- }
-
- // Should not receive a packet destined to the solicited node address of
- // addr2/addr3 yet as both of them got removed, even though a route using
- // addr3 exists.
- test.rxf(t, s, e, addr1, snmc, 3)
- })
- }
-}
-
-// TestAddIpv6Address tests adding IPv6 addresses.
-func TestAddIpv6Address(t *testing.T) {
- const nicID = 1
-
- tests := []struct {
- name string
- addr tcpip.Address
- }{
- // This test is in response to b/140943433.
- {
- "Nil",
- tcpip.Address([]byte(nil)),
- },
- {
- "ValidUnicast",
- addr1,
- },
- {
- "ValidLinkLocalUnicast",
- lladdr0,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- })
- if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ProtocolNumber,
- AddressWithPrefix: test.addr.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
-
- if addr, err := s.GetMainNICAddress(nicID, ProtocolNumber); err != nil {
- t.Fatalf("stack.GetMainNICAddress(%d, %d): %s", nicID, ProtocolNumber, err)
- } else if addr.Address != test.addr {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, ProtocolNumber, addr.Address, test.addr)
- }
- })
- }
-}
-
-func TestReceiveIPv6ExtHdrs(t *testing.T) {
- tests := []struct {
- name string
- extHdr func(nextHdr uint8) ([]byte, uint8)
- shouldAccept bool
- countersToBeIncremented func(*tcpip.Stats) []*tcpip.StatCounter
- // Should we expect an ICMP response and if so, with what contents?
- expectICMP bool
- ICMPType header.ICMPv6Type
- ICMPCode header.ICMPv6Code
- pointer uint32
- multicast bool
- }{
- {
- name: "None",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return nil, nextHdr },
- shouldAccept: true,
- expectICMP: false,
- },
- {
- name: "hopbyhop with router alert option",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 0,
-
- // Router Alert option.
- 5, 2, 0, 0, 0, 0,
- }, hopByHopExtHdrID
- },
- shouldAccept: true,
- countersToBeIncremented: func(stats *tcpip.Stats) []*tcpip.StatCounter {
- return []*tcpip.StatCounter{stats.IP.OptionRouterAlertReceived}
- },
- },
- {
- name: "hopbyhop with two router alert options",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 1,
-
- // Router Alert option.
- 5, 2, 0, 0, 0, 0,
-
- // Router Alert option.
- 5, 2, 0, 0, 0, 0, 0, 0,
- }, hopByHopExtHdrID
- },
- shouldAccept: false,
- countersToBeIncremented: func(stats *tcpip.Stats) []*tcpip.StatCounter {
- return []*tcpip.StatCounter{
- stats.IP.OptionRouterAlertReceived,
- stats.IP.MalformedPacketsReceived,
- }
- },
- },
- {
- name: "hopbyhop with unknown option skippable action",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 1,
-
- // Skippable unknown.
- 63, 4, 1, 2, 3, 4,
-
- // Skippable unknown.
- 62, 6, 1, 2, 3, 4, 5, 6,
- }, hopByHopExtHdrID
- },
- shouldAccept: true,
- },
- {
- name: "hopbyhop with unknown option discard action",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 1,
-
- // Skippable unknown.
- 63, 4, 1, 2, 3, 4,
-
- // Discard unknown.
- 127, 6, 1, 2, 3, 4, 5, 6,
- }, hopByHopExtHdrID
- },
- shouldAccept: false,
- expectICMP: false,
- },
- {
- name: "hopbyhop with unknown option discard and send icmp action (unicast)",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 1,
-
- // Skippable unknown.
- 63, 4, 1, 2, 3, 4,
-
- // Discard & send ICMP if option is unknown.
- 191, 6, 1, 2, 3, 4, 5, 6,
- //^ Unknown option.
- }, hopByHopExtHdrID
- },
- shouldAccept: false,
- expectICMP: true,
- ICMPType: header.ICMPv6ParamProblem,
- ICMPCode: header.ICMPv6UnknownOption,
- pointer: header.IPv6FixedHeaderSize + 8,
- },
- {
- name: "hopbyhop with unknown option discard and send icmp action (multicast)",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 1,
-
- // Skippable unknown.
- 63, 4, 1, 2, 3, 4,
-
- // Discard & send ICMP if option is unknown.
- 191, 6, 1, 2, 3, 4, 5, 6,
- //^ Unknown option.
- }, hopByHopExtHdrID
- },
- multicast: true,
- shouldAccept: false,
- expectICMP: true,
- ICMPType: header.ICMPv6ParamProblem,
- ICMPCode: header.ICMPv6UnknownOption,
- pointer: header.IPv6FixedHeaderSize + 8,
- },
- {
- name: "hopbyhop with unknown option discard and send icmp action unless multicast dest (unicast)",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 1,
-
- // Skippable unknown.
- 63, 4, 1, 2, 3, 4,
-
- // Discard & send ICMP unless packet is for multicast destination if
- // option is unknown.
- 255, 6, 1, 2, 3, 4, 5, 6,
- //^ Unknown option.
- }, hopByHopExtHdrID
- },
- expectICMP: true,
- ICMPType: header.ICMPv6ParamProblem,
- ICMPCode: header.ICMPv6UnknownOption,
- pointer: header.IPv6FixedHeaderSize + 8,
- },
- {
- name: "hopbyhop with unknown option discard and send icmp action unless multicast dest (multicast)",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 1,
-
- // Skippable unknown.
- 63, 4, 1, 2, 3, 4,
-
- // Discard & send ICMP unless packet is for multicast destination if
- // option is unknown.
- 255, 6, 1, 2, 3, 4, 5, 6,
- //^ Unknown option.
- }, hopByHopExtHdrID
- },
- multicast: true,
- shouldAccept: false,
- expectICMP: false,
- },
- {
- name: "routing with zero segments left",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 0,
- 1, 0, 2, 3, 4, 5,
- }, routingExtHdrID
- },
- shouldAccept: true,
- },
- {
- name: "routing with non-zero segments left",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 0,
- 1, 1, 2, 3, 4, 5,
- }, routingExtHdrID
- },
- shouldAccept: false,
- expectICMP: true,
- ICMPType: header.ICMPv6ParamProblem,
- ICMPCode: header.ICMPv6ErroneousHeader,
- pointer: header.IPv6FixedHeaderSize + 2,
- },
- {
- name: "atomic fragment with zero ID",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 0,
- 0, 0, 0, 0, 0, 0,
- }, fragmentExtHdrID
- },
- shouldAccept: true,
- },
- {
- name: "atomic fragment with non-zero ID",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 0,
- 0, 0, 1, 2, 3, 4,
- }, fragmentExtHdrID
- },
- shouldAccept: true,
- expectICMP: false,
- },
- {
- name: "fragment",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 0,
- 1, 0, 1, 2, 3, 4,
- }, fragmentExtHdrID
- },
- shouldAccept: false,
- expectICMP: false,
- },
- {
- name: "No next header",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return nil, noNextHdrID
- },
- shouldAccept: false,
- expectICMP: false,
- },
- {
- name: "unknown next header (first)",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 0, 63, 4, 1, 2, 3, 4,
- }, unknownHdrID
- },
- shouldAccept: false,
- expectICMP: true,
- ICMPType: header.ICMPv6ParamProblem,
- ICMPCode: header.ICMPv6UnknownHeader,
- pointer: header.IPv6NextHeaderOffset,
- },
- {
- name: "unknown next header (not first)",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- unknownHdrID, 0,
- 63, 4, 1, 2, 3, 4,
- }, hopByHopExtHdrID
- },
- shouldAccept: false,
- expectICMP: true,
- ICMPType: header.ICMPv6ParamProblem,
- ICMPCode: header.ICMPv6UnknownHeader,
- pointer: header.IPv6FixedHeaderSize,
- },
- {
- name: "destination with unknown option skippable action",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 1,
-
- // Skippable unknown.
- 63, 4, 1, 2, 3, 4,
-
- // Skippable unknown.
- 62, 6, 1, 2, 3, 4, 5, 6,
- }, destinationExtHdrID
- },
- shouldAccept: true,
- expectICMP: false,
- },
- {
- name: "destination with unknown option discard action",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 1,
-
- // Skippable unknown.
- 63, 4, 1, 2, 3, 4,
-
- // Discard unknown.
- 127, 6, 1, 2, 3, 4, 5, 6,
- }, destinationExtHdrID
- },
- shouldAccept: false,
- expectICMP: false,
- },
- {
- name: "destination with unknown option discard and send icmp action (unicast)",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 1,
-
- // Skippable unknown.
- 63, 4, 1, 2, 3, 4,
-
- // Discard & send ICMP if option is unknown.
- 191, 6, 1, 2, 3, 4, 5, 6,
- //^ 191 is an unknown option.
- }, destinationExtHdrID
- },
- shouldAccept: false,
- expectICMP: true,
- ICMPType: header.ICMPv6ParamProblem,
- ICMPCode: header.ICMPv6UnknownOption,
- pointer: header.IPv6FixedHeaderSize + 8,
- },
- {
- name: "destination with unknown option discard and send icmp action (muilticast)",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 1,
-
- // Skippable unknown.
- 63, 4, 1, 2, 3, 4,
-
- // Discard & send ICMP if option is unknown.
- 191, 6, 1, 2, 3, 4, 5, 6,
- //^ 191 is an unknown option.
- }, destinationExtHdrID
- },
- multicast: true,
- shouldAccept: false,
- expectICMP: true,
- ICMPType: header.ICMPv6ParamProblem,
- ICMPCode: header.ICMPv6UnknownOption,
- pointer: header.IPv6FixedHeaderSize + 8,
- },
- {
- name: "destination with unknown option discard and send icmp action unless multicast dest (unicast)",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 1,
-
- // Skippable unknown.
- 63, 4, 1, 2, 3, 4,
-
- // Discard & send ICMP unless packet is for multicast destination if
- // option is unknown.
- 255, 6, 1, 2, 3, 4, 5, 6,
- //^ 255 is unknown.
- }, destinationExtHdrID
- },
- shouldAccept: false,
- expectICMP: true,
- ICMPType: header.ICMPv6ParamProblem,
- ICMPCode: header.ICMPv6UnknownOption,
- pointer: header.IPv6FixedHeaderSize + 8,
- },
- {
- name: "destination with unknown option discard and send icmp action unless multicast dest (multicast)",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- nextHdr, 1,
-
- // Skippable unknown.
- 63, 4, 1, 2, 3, 4,
-
- // Discard & send ICMP unless packet is for multicast destination if
- // option is unknown.
- 255, 6, 1, 2, 3, 4, 5, 6,
- //^ 255 is unknown.
- }, destinationExtHdrID
- },
- shouldAccept: false,
- expectICMP: false,
- multicast: true,
- },
- {
- name: "atomic fragment - routing",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- // Fragment extension header.
- routingExtHdrID, 0, 0, 0, 1, 2, 3, 4,
-
- // Routing extension header.
- nextHdr, 0, 1, 0, 2, 3, 4, 5,
- }, fragmentExtHdrID
- },
- shouldAccept: true,
- },
- {
- name: "hop by hop (with skippable unknown) - routing",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- // Hop By Hop extension header with skippable unknown option.
- routingExtHdrID, 0, 62, 4, 1, 2, 3, 4,
-
- // Routing extension header.
- nextHdr, 0, 1, 0, 2, 3, 4, 5,
- }, hopByHopExtHdrID
- },
- shouldAccept: true,
- },
- {
- name: "routing - hop by hop (with skippable unknown)",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- // Routing extension header.
- hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5,
- // ^^^ The HopByHop extension header may not appear after the first
- // extension header.
-
- // Hop By Hop extension header with skippable unknown option.
- nextHdr, 0, 62, 4, 1, 2, 3, 4,
- }, routingExtHdrID
- },
- shouldAccept: false,
- expectICMP: true,
- ICMPType: header.ICMPv6ParamProblem,
- ICMPCode: header.ICMPv6UnknownHeader,
- pointer: header.IPv6FixedHeaderSize,
- },
- {
- name: "routing - hop by hop (with send icmp unknown)",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- // Routing extension header.
- hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5,
- // ^^^ The HopByHop extension header may not appear after the first
- // extension header.
-
- nextHdr, 1,
-
- // Skippable unknown.
- 63, 4, 1, 2, 3, 4,
-
- // Skippable unknown.
- 191, 6, 1, 2, 3, 4, 5, 6,
- }, routingExtHdrID
- },
- shouldAccept: false,
- expectICMP: true,
- ICMPType: header.ICMPv6ParamProblem,
- ICMPCode: header.ICMPv6UnknownHeader,
- pointer: header.IPv6FixedHeaderSize,
- },
- {
- name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with skippable unknown)",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- // Hop By Hop extension header with skippable unknown option.
- routingExtHdrID, 0, 62, 4, 1, 2, 3, 4,
-
- // Routing extension header.
- fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5,
-
- // Fragment extension header.
- destinationExtHdrID, 0, 0, 0, 1, 2, 3, 4,
-
- // Destination extension header with skippable unknown option.
- nextHdr, 0, 63, 4, 1, 2, 3, 4,
- }, hopByHopExtHdrID
- },
- shouldAccept: true,
- },
- {
- name: "hopbyhop (with discard unknown) - routing - atomic fragment - destination (with skippable unknown)",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- // Hop By Hop extension header with discard action for unknown option.
- routingExtHdrID, 0, 65, 4, 1, 2, 3, 4,
-
- // Routing extension header.
- fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5,
-
- // Fragment extension header.
- destinationExtHdrID, 0, 0, 0, 1, 2, 3, 4,
-
- // Destination extension header with skippable unknown option.
- nextHdr, 0, 63, 4, 1, 2, 3, 4,
- }, hopByHopExtHdrID
- },
- shouldAccept: false,
- expectICMP: false,
- },
- {
- name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with discard unknown)",
- extHdr: func(nextHdr uint8) ([]byte, uint8) {
- return []byte{
- // Hop By Hop extension header with skippable unknown option.
- routingExtHdrID, 0, 62, 4, 1, 2, 3, 4,
-
- // Routing extension header.
- fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5,
-
- // Fragment extension header.
- destinationExtHdrID, 0, 0, 0, 1, 2, 3, 4,
-
- // Destination extension header with discard action for unknown
- // option.
- nextHdr, 0, 65, 4, 1, 2, 3, 4,
- }, hopByHopExtHdrID
- },
- shouldAccept: false,
- expectICMP: false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- })
- e := channel.New(1, header.IPv6MinimumMTU, linkAddr1)
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ProtocolNumber,
- AddressWithPrefix: addr2.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
-
- // Add a default route so that a return packet knows where to go.
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv6EmptySubnet,
- NIC: nicID,
- },
- })
-
- wq := waiter.Queue{}
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.WritableEvents)
- defer wq.EventUnregister(&we)
- defer close(ch)
- ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, ProtocolNumber, err)
- }
- defer ep.Close()
-
- bindAddr := tcpip.FullAddress{Addr: addr2, Port: 80}
- if err := ep.Bind(bindAddr); err != nil {
- t.Fatalf("Bind(%+v): %s", bindAddr, err)
- }
-
- udpPayload := []byte{1, 2, 3, 4, 5, 6, 7, 8}
- udpLength := header.UDPMinimumSize + len(udpPayload)
- extHdrBytes, ipv6NextHdr := test.extHdr(uint8(header.UDPProtocolNumber))
- extHdrLen := len(extHdrBytes)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + extHdrLen + udpLength)
-
- // Serialize UDP message.
- u := header.UDP(hdr.Prepend(udpLength))
- u.Encode(&header.UDPFields{
- SrcPort: 5555,
- DstPort: 80,
- Length: uint16(udpLength),
- })
- copy(u.Payload(), udpPayload)
-
- dstAddr := tcpip.Address(addr2)
- if test.multicast {
- dstAddr = header.IPv6AllNodesMulticastAddress
- }
-
- sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, dstAddr, uint16(udpLength))
- sum = header.Checksum(udpPayload, sum)
- u.SetChecksum(^u.CalculateChecksum(sum))
-
- // Copy extension header bytes between the UDP message and the IPv6
- // fixed header.
- copy(hdr.Prepend(extHdrLen), extHdrBytes)
-
- // Serialize IPv6 fixed header.
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- // We're lying about transport protocol here to be able to generate
- // raw extension headers from the test definitions.
- TransportProtocol: tcpip.TransportProtocolNumber(ipv6NextHdr),
- HopLimit: 255,
- SrcAddr: addr1,
- DstAddr: dstAddr,
- })
-
- stats := s.Stats()
- var counters []*tcpip.StatCounter
- // Make sure that the counters we expect to be incremented are initially
- // set to zero.
- if fn := test.countersToBeIncremented; fn != nil {
- counters = fn(&stats)
- }
- for i := range counters {
- if got := counters[i].Value(); got != 0 {
- t.Errorf("before writing packet: got test.countersToBeIncremented(&stats)[%d].Value() = %d, want = 0", i, got)
- }
- }
-
- e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
-
- for i := range counters {
- if got := counters[i].Value(); got != 1 {
- t.Errorf("after writing packet: got test.countersToBeIncremented(&stats)[%d].Value() = %d, want = 1", i, got)
- }
- }
-
- udpReceiveStat := stats.UDP.PacketsReceived
- if !test.shouldAccept {
- if got := udpReceiveStat.Value(); got != 0 {
- t.Errorf("got UDP Rx Packets = %d, want = 0", got)
- }
-
- if !test.expectICMP {
- if p, ok := e.Read(); ok {
- t.Fatalf("unexpected packet received: %#v", p)
- }
- return
- }
-
- // ICMP required.
- p, ok := e.Read()
- if !ok {
- t.Fatalf("expected packet wasn't written out")
- }
-
- // Pack the output packet into a single buffer.View as the checkers
- // assume that.
- vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
- pkt := vv.ToView()
- if got, want := len(pkt), header.IPv6FixedHeaderSize+header.ICMPv6MinimumSize+hdr.UsedLength(); got != want {
- t.Fatalf("got an ICMP packet of size = %d, want = %d", got, want)
- }
-
- ipHdr := header.IPv6(pkt)
- checker.IPv6(t, ipHdr, checker.ICMPv6(
- checker.ICMPv6Type(test.ICMPType),
- checker.ICMPv6Code(test.ICMPCode)))
-
- // We know we are looking at no extension headers in the error ICMP
- // packets.
- icmpPkt := header.ICMPv6(ipHdr.Payload())
- // We know we sent small packets that won't be truncated when reflected
- // back to us.
- originalPacket := icmpPkt.Payload()
- if got, want := icmpPkt.TypeSpecific(), test.pointer; got != want {
- t.Errorf("unexpected ICMPv6 pointer, got = %d, want = %d\n", got, want)
- }
- if diff := cmp.Diff(hdr.View(), buffer.View(originalPacket)); diff != "" {
- t.Errorf("ICMPv6 payload mismatch (-want +got):\n%s", diff)
- }
- return
- }
-
- // Expect a UDP packet.
- if got := udpReceiveStat.Value(); got != 1 {
- t.Errorf("got UDP Rx Packets = %d, want = 1", got)
- }
- var buf bytes.Buffer
- result, err := ep.Read(&buf, tcpip.ReadOptions{})
- if err != nil {
- t.Fatalf("Read: %s", err)
- }
- if diff := cmp.Diff(tcpip.ReadResult{
- Count: len(udpPayload),
- Total: len(udpPayload),
- }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
- t.Errorf("Read: unexpected result (-want +got):\n%s", diff)
- }
- if diff := cmp.Diff(udpPayload, buf.Bytes()); diff != "" {
- t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
- }
-
- // Should not have any more UDP packets.
- res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{})
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got Read = (%v, %v), want = (_, %s)", res, err, &tcpip.ErrWouldBlock{})
- }
- })
- }
-}
-
-// fragmentData holds the IPv6 payload for a fragmented IPv6 packet.
-type fragmentData struct {
- srcAddr tcpip.Address
- dstAddr tcpip.Address
- nextHdr uint8
- data buffer.VectorisedView
-}
-
-func TestReceiveIPv6Fragments(t *testing.T) {
- const (
- udpPayload1Length = 256
- udpPayload2Length = 128
- // Used to test cases where the fragment blocks are not a multiple of
- // the fragment block size of 8 (RFC 8200 section 4.5).
- udpPayload3Length = 127
- udpPayload4Length = header.IPv6MaximumPayloadSize - header.UDPMinimumSize
- udpMaximumSizeMinus15 = header.UDPMaximumSize - 15
- fragmentExtHdrLen = 8
- // Note, not all routing extension headers will be 8 bytes but this test
- // uses 8 byte routing extension headers for most sub tests.
- routingExtHdrLen = 8
- )
-
- udpGen := func(payload []byte, multiplier uint8, src, dst tcpip.Address) buffer.View {
- payloadLen := len(payload)
- for i := 0; i < payloadLen; i++ {
- payload[i] = uint8(i) * multiplier
- }
-
- udpLength := header.UDPMinimumSize + payloadLen
-
- hdr := buffer.NewPrependable(udpLength)
- u := header.UDP(hdr.Prepend(udpLength))
- u.Encode(&header.UDPFields{
- SrcPort: 5555,
- DstPort: 80,
- Length: uint16(udpLength),
- })
- copy(u.Payload(), payload)
- sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(udpLength))
- sum = header.Checksum(payload, sum)
- u.SetChecksum(^u.CalculateChecksum(sum))
- return hdr.View()
- }
-
- var udpPayload1Addr1ToAddr2Buf [udpPayload1Length]byte
- udpPayload1Addr1ToAddr2 := udpPayload1Addr1ToAddr2Buf[:]
- ipv6Payload1Addr1ToAddr2 := udpGen(udpPayload1Addr1ToAddr2, 1, addr1, addr2)
-
- var udpPayload1Addr3ToAddr2Buf [udpPayload1Length]byte
- udpPayload1Addr3ToAddr2 := udpPayload1Addr3ToAddr2Buf[:]
- ipv6Payload1Addr3ToAddr2 := udpGen(udpPayload1Addr3ToAddr2, 4, addr3, addr2)
-
- var udpPayload2Addr1ToAddr2Buf [udpPayload2Length]byte
- udpPayload2Addr1ToAddr2 := udpPayload2Addr1ToAddr2Buf[:]
- ipv6Payload2Addr1ToAddr2 := udpGen(udpPayload2Addr1ToAddr2, 2, addr1, addr2)
-
- var udpPayload3Addr1ToAddr2Buf [udpPayload3Length]byte
- udpPayload3Addr1ToAddr2 := udpPayload3Addr1ToAddr2Buf[:]
- ipv6Payload3Addr1ToAddr2 := udpGen(udpPayload3Addr1ToAddr2, 3, addr1, addr2)
-
- var udpPayload4Addr1ToAddr2Buf [udpPayload4Length]byte
- udpPayload4Addr1ToAddr2 := udpPayload4Addr1ToAddr2Buf[:]
- ipv6Payload4Addr1ToAddr2 := udpGen(udpPayload4Addr1ToAddr2, 4, addr1, addr2)
-
- tests := []struct {
- name string
- expectedPayload []byte
- fragments []fragmentData
- expectedPayloads [][]byte
- }{
- {
- name: "No fragmentation",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: uint8(header.UDPProtocolNumber),
- data: ipv6Payload1Addr1ToAddr2.ToVectorisedView(),
- },
- },
- expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
- },
- {
- name: "Atomic fragment",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2),
- []buffer.View{
- // Fragment extension header.
- []byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0},
-
- ipv6Payload1Addr1ToAddr2,
- },
- ),
- },
- },
- expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
- },
- {
- name: "Atomic fragment with size not a multiple of fragment block size",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2),
- []buffer.View{
- // Fragment extension header.
- []byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0},
-
- ipv6Payload3Addr1ToAddr2,
- },
- ),
- },
- },
- expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2},
- },
- {
- name: "Two fragments",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 0, More = true, ID = 1
- []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
-
- ipv6Payload1Addr1ToAddr2[:64],
- },
- ),
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 8, More = false, ID = 1
- []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1},
-
- ipv6Payload1Addr1ToAddr2[64:],
- },
- ),
- },
- },
- expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
- },
- {
- name: "Two fragments out of order",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 8, More = false, ID = 1
- []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1},
-
- ipv6Payload1Addr1ToAddr2[64:],
- },
- ),
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 0, More = true, ID = 1
- []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
-
- ipv6Payload1Addr1ToAddr2[:64],
- },
- ),
- },
- },
- expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
- },
- {
- name: "Two fragments with different Next Header values",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 0, More = true, ID = 1
- []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
-
- ipv6Payload1Addr1ToAddr2[:64],
- },
- ),
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 8, More = false, ID = 1
- // NextHeader value is different than the one in the first fragment, so
- // this NextHeader should be ignored.
- []byte{uint8(header.IPv6NoNextHeaderIdentifier), 0, 0, 64, 0, 0, 0, 1},
-
- ipv6Payload1Addr1ToAddr2[64:],
- },
- ),
- },
- },
- expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
- },
- {
- name: "Two fragments with last fragment size not a multiple of fragment block size",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 0, More = true, ID = 1
- []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
-
- ipv6Payload3Addr1ToAddr2[:64],
- },
- ),
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2)-64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 8, More = false, ID = 1
- []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1},
-
- ipv6Payload3Addr1ToAddr2[64:],
- },
- ),
- },
- },
- expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2},
- },
- {
- name: "Two fragments with first fragment size not a multiple of fragment block size",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+63,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 0, More = true, ID = 1
- []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
-
- ipv6Payload3Addr1ToAddr2[:63],
- },
- ),
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2)-63,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 8, More = false, ID = 1
- []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1},
-
- ipv6Payload3Addr1ToAddr2[63:],
- },
- ),
- },
- },
- expectedPayloads: nil,
- },
- {
- name: "Two fragments with different IDs",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 0, More = true, ID = 1
- []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
-
- ipv6Payload1Addr1ToAddr2[:64],
- },
- ),
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 8, More = false, ID = 2
- []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 2},
-
- ipv6Payload1Addr1ToAddr2[64:],
- },
- ),
- },
- },
- expectedPayloads: nil,
- },
- {
- name: "Two fragments reassembled into a maximum UDP packet",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+udpMaximumSizeMinus15,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 0, More = true, ID = 1
- []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
-
- ipv6Payload4Addr1ToAddr2[:udpMaximumSizeMinus15],
- },
- ),
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-udpMaximumSizeMinus15,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = udpMaximumSizeMinus15/8, More = false, ID = 1
- []byte{uint8(header.UDPProtocolNumber), 0,
- udpMaximumSizeMinus15 >> 8,
- udpMaximumSizeMinus15 & 0xff,
- 0, 0, 0, 1},
-
- ipv6Payload4Addr1ToAddr2[udpMaximumSizeMinus15:],
- },
- ),
- },
- },
- expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2},
- },
- {
- name: "Two fragments with MF flag reassembled into a maximum UDP packet",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+udpMaximumSizeMinus15,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 0, More = true, ID = 1
- []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
-
- ipv6Payload4Addr1ToAddr2[:udpMaximumSizeMinus15],
- },
- ),
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-udpMaximumSizeMinus15,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = udpMaximumSizeMinus15/8, More = true, ID = 1
- []byte{uint8(header.UDPProtocolNumber), 0,
- udpMaximumSizeMinus15 >> 8,
- (udpMaximumSizeMinus15 & 0xff) + 1,
- 0, 0, 0, 1},
-
- ipv6Payload4Addr1ToAddr2[udpMaximumSizeMinus15:],
- },
- ),
- },
- },
- expectedPayloads: nil,
- },
- {
- name: "Two fragments with per-fragment routing header with zero segments left",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: routingExtHdrID,
- data: buffer.NewVectorisedView(
- routingExtHdrLen+fragmentExtHdrLen+64,
- []buffer.View{
- // Routing extension header.
- //
- // Segments left = 0.
- []byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5},
-
- // Fragment extension header.
- //
- // Fragment offset = 0, More = true, ID = 1
- []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
-
- ipv6Payload1Addr1ToAddr2[:64],
- },
- ),
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: routingExtHdrID,
- data: buffer.NewVectorisedView(
- routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
- []buffer.View{
- // Routing extension header.
- //
- // Segments left = 0.
- []byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5},
-
- // Fragment extension header.
- //
- // Fragment offset = 8, More = false, ID = 1
- []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1},
-
- ipv6Payload1Addr1ToAddr2[64:],
- },
- ),
- },
- },
- expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
- },
- {
- name: "Two fragments with per-fragment routing header with non-zero segments left",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: routingExtHdrID,
- data: buffer.NewVectorisedView(
- routingExtHdrLen+fragmentExtHdrLen+64,
- []buffer.View{
- // Routing extension header.
- //
- // Segments left = 1.
- []byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5},
-
- // Fragment extension header.
- //
- // Fragment offset = 0, More = true, ID = 1
- []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
-
- ipv6Payload1Addr1ToAddr2[:64],
- },
- ),
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: routingExtHdrID,
- data: buffer.NewVectorisedView(
- routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
- []buffer.View{
- // Routing extension header.
- //
- // Segments left = 1.
- []byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5},
-
- // Fragment extension header.
- //
- // Fragment offset = 9, More = false, ID = 1
- []byte{uint8(header.UDPProtocolNumber), 0, 0, 72, 0, 0, 0, 1},
-
- ipv6Payload1Addr1ToAddr2[64:],
- },
- ),
- },
- },
- expectedPayloads: nil,
- },
- {
- name: "Two fragments with routing header with zero segments left",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- routingExtHdrLen+fragmentExtHdrLen+64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 0, More = true, ID = 1
- []byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1},
-
- // Routing extension header.
- //
- // Segments left = 0.
- []byte{uint8(header.UDPProtocolNumber), 0, 1, 0, 2, 3, 4, 5},
-
- ipv6Payload1Addr1ToAddr2[:64],
- },
- ),
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 9, More = false, ID = 1
- []byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1},
-
- ipv6Payload1Addr1ToAddr2[64:],
- },
- ),
- },
- },
- expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
- },
- {
- name: "Two fragments with routing header with non-zero segments left",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- routingExtHdrLen+fragmentExtHdrLen+64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 0, More = true, ID = 1
- []byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1},
-
- // Routing extension header.
- //
- // Segments left = 1.
- []byte{uint8(header.UDPProtocolNumber), 0, 1, 1, 2, 3, 4, 5},
-
- ipv6Payload1Addr1ToAddr2[:64],
- },
- ),
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 9, More = false, ID = 1
- []byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1},
-
- ipv6Payload1Addr1ToAddr2[64:],
- },
- ),
- },
- },
- expectedPayloads: nil,
- },
- {
- name: "Two fragments with routing header with zero segments left across fragments",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- // The length of this payload is fragmentExtHdrLen+8 because the
- // first 8 bytes of the 16 byte routing extension header is in
- // this fragment.
- fragmentExtHdrLen+8,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 0, More = true, ID = 1
- []byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1},
-
- // Routing extension header (part 1)
- //
- // Segments left = 0.
- []byte{uint8(header.UDPProtocolNumber), 1, 1, 0, 2, 3, 4, 5},
- },
- ),
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- // The length of this payload is
- // fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2) because the last 8 bytes of
- // the 16 byte routing extension header is in this fagment.
- fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2),
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 1, More = false, ID = 1
- []byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1},
-
- // Routing extension header (part 2)
- []byte{6, 7, 8, 9, 10, 11, 12, 13},
-
- ipv6Payload1Addr1ToAddr2,
- },
- ),
- },
- },
- expectedPayloads: nil,
- },
- {
- name: "Two fragments with routing header with non-zero segments left across fragments",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- // The length of this payload is fragmentExtHdrLen+8 because the
- // first 8 bytes of the 16 byte routing extension header is in
- // this fragment.
- fragmentExtHdrLen+8,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 0, More = true, ID = 1
- []byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1},
-
- // Routing extension header (part 1)
- //
- // Segments left = 1.
- []byte{uint8(header.UDPProtocolNumber), 1, 1, 1, 2, 3, 4, 5},
- },
- ),
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- // The length of this payload is
- // fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2) because the last 8 bytes of
- // the 16 byte routing extension header is in this fagment.
- fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2),
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 1, More = false, ID = 1
- []byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1},
-
- // Routing extension header (part 2)
- []byte{6, 7, 8, 9, 10, 11, 12, 13},
-
- ipv6Payload1Addr1ToAddr2,
- },
- ),
- },
- },
- expectedPayloads: nil,
- },
- // As per RFC 6946, IPv6 atomic fragments MUST NOT interfere with "normal"
- // fragmented traffic.
- {
- name: "Two fragments with atomic",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 0, More = true, ID = 1
- []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
-
- ipv6Payload1Addr1ToAddr2[:64],
- },
- ),
- },
- // This fragment has the same ID as the other fragments but is an atomic
- // fragment. It should not interfere with the other fragments.
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload2Addr1ToAddr2),
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 0, More = false, ID = 1
- []byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 1},
-
- ipv6Payload2Addr1ToAddr2,
- },
- ),
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 8, More = false, ID = 1
- []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1},
-
- ipv6Payload1Addr1ToAddr2[64:],
- },
- ),
- },
- },
- expectedPayloads: [][]byte{udpPayload2Addr1ToAddr2, udpPayload1Addr1ToAddr2},
- },
- {
- name: "Two interleaved fragmented packets",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 0, More = true, ID = 1
- []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
-
- ipv6Payload1Addr1ToAddr2[:64],
- },
- ),
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+32,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 0, More = true, ID = 2
- []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 2},
-
- ipv6Payload2Addr1ToAddr2[:32],
- },
- ),
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 8, More = false, ID = 1
- []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1},
-
- ipv6Payload1Addr1ToAddr2[64:],
- },
- ),
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload2Addr1ToAddr2)-32,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 4, More = false, ID = 2
- []byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 2},
-
- ipv6Payload2Addr1ToAddr2[32:],
- },
- ),
- },
- },
- expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload2Addr1ToAddr2},
- },
- {
- name: "Two interleaved fragmented packets from different sources but with same ID",
- fragments: []fragmentData{
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 0, More = true, ID = 1
- []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
-
- ipv6Payload1Addr1ToAddr2[:64],
- },
- ),
- },
- {
- srcAddr: addr3,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+32,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 0, More = true, ID = 1
- []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1},
-
- ipv6Payload1Addr3ToAddr2[:32],
- },
- ),
- },
- {
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 8, More = false, ID = 1
- []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1},
-
- ipv6Payload1Addr1ToAddr2[64:],
- },
- ),
- },
- {
- srcAddr: addr3,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-32,
- []buffer.View{
- // Fragment extension header.
- //
- // Fragment offset = 4, More = false, ID = 1
- []byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 1},
-
- ipv6Payload1Addr3ToAddr2[32:],
- },
- ),
- },
- },
- expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload1Addr3ToAddr2},
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- })
- e := channel.New(0, header.IPv6MinimumMTU, linkAddr1)
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ProtocolNumber,
- AddressWithPrefix: addr2.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
-
- wq := waiter.Queue{}
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
- defer close(ch)
- ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, ProtocolNumber, err)
- }
- defer ep.Close()
-
- bindAddr := tcpip.FullAddress{Addr: addr2, Port: 80}
- if err := ep.Bind(bindAddr); err != nil {
- t.Fatalf("Bind(%+v): %s", bindAddr, err)
- }
-
- for _, f := range test.fragments {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize)
-
- // Serialize IPv6 fixed header.
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(f.data.Size()),
- // We're lying about transport protocol here so that we can generate
- // raw extension headers for the tests.
- TransportProtocol: tcpip.TransportProtocolNumber(f.nextHdr),
- HopLimit: 255,
- SrcAddr: f.srcAddr,
- DstAddr: f.dstAddr,
- })
-
- vv := hdr.View().ToVectorisedView()
- vv.Append(f.data)
-
- e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- }))
- }
-
- if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want {
- t.Errorf("got UDP Rx Packets = %d, want = %d", got, want)
- }
-
- for i, p := range test.expectedPayloads {
- var buf bytes.Buffer
- _, err := ep.Read(&buf, tcpip.ReadOptions{})
- if err != nil {
- t.Fatalf("(i=%d) Read: %s", i, err)
- }
- if diff := cmp.Diff(p, buf.Bytes()); diff != "" {
- t.Errorf("(i=%d) got UDP payload mismatch (-want +got):\n%s", i, diff)
- }
- }
-
- res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{})
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("(last) got Read = (%v, %v), want = (_, %s)", res, err, &tcpip.ErrWouldBlock{})
- }
- })
- }
-}
-
-func TestInvalidIPv6Fragments(t *testing.T) {
- const (
- addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
- nicID = 1
- hoplimit = 255
- ident = 1
- data = "TEST_INVALID_IPV6_FRAGMENTS"
- )
-
- type fragmentData struct {
- ipv6Fields header.IPv6Fields
- ipv6FragmentFields header.IPv6SerializableFragmentExtHdr
- payload []byte
- }
-
- tests := []struct {
- name string
- fragments []fragmentData
- wantMalformedIPPackets uint64
- wantMalformedFragments uint64
- expectICMP bool
- expectICMPType header.ICMPv6Type
- expectICMPCode header.ICMPv6Code
- expectICMPTypeSpecific uint32
- }{
- {
- name: "fragment size is not a multiple of 8 and the M flag is true",
- fragments: []fragmentData{
- {
- ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 9,
- TransportProtocol: header.UDPProtocolNumber,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
- FragmentOffset: 0 >> 3,
- M: true,
- Identification: ident,
- },
- payload: []byte(data)[:9],
- },
- },
- wantMalformedIPPackets: 1,
- wantMalformedFragments: 1,
- expectICMP: true,
- expectICMPType: header.ICMPv6ParamProblem,
- expectICMPCode: header.ICMPv6ErroneousHeader,
- expectICMPTypeSpecific: header.IPv6PayloadLenOffset,
- },
- {
- name: "fragments reassembled into a payload exceeding the max IPv6 payload size",
- fragments: []fragmentData{
- {
- ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- TransportProtocol: header.UDPProtocolNumber,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
- FragmentOffset: ((header.IPv6MaximumPayloadSize + 1) - 16) >> 3,
- M: false,
- Identification: ident,
- },
- payload: []byte(data)[:16],
- },
- },
- wantMalformedIPPackets: 1,
- wantMalformedFragments: 1,
- expectICMP: true,
- expectICMPType: header.ICMPv6ParamProblem,
- expectICMPCode: header.ICMPv6ErroneousHeader,
- expectICMPTypeSpecific: header.IPv6MinimumSize + 2, /* offset for 'Fragment Offset' in the fragment header */
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{
- NewProtocol,
- },
- })
- e := channel.New(1, 1500, linkAddr1)
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ProtocolNumber,
- AddressWithPrefix: addr2.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
- s.SetRouteTable([]tcpip.Route{{
- Destination: header.IPv6EmptySubnet,
- NIC: nicID,
- }})
-
- var expectICMPPayload buffer.View
- for _, f := range test.fragments {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)
-
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize))
- encodeArgs := f.ipv6Fields
- encodeArgs.ExtensionHeaders = append(encodeArgs.ExtensionHeaders, &f.ipv6FragmentFields)
- ip.Encode(&encodeArgs)
-
- vv := hdr.View().ToVectorisedView()
- vv.AppendView(f.payload)
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- })
-
- if test.expectICMP {
- expectICMPPayload = stack.PayloadSince(pkt.NetworkHeader())
- }
-
- e.InjectInbound(ProtocolNumber, pkt)
- }
-
- if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), test.wantMalformedIPPackets; got != want {
- t.Errorf("got Stats.IP.MalformedPacketsReceived = %d, want = %d", got, want)
- }
- if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), test.wantMalformedFragments; got != want {
- t.Errorf("got Stats.IP.MalformedFragmentsReceived = %d, want = %d", got, want)
- }
-
- reply, ok := e.Read()
- if !test.expectICMP {
- if ok {
- t.Fatalf("unexpected ICMP error message received: %#v", reply)
- }
- return
- }
- if !ok {
- t.Fatal("expected ICMP error message missing")
- }
-
- checker.IPv6(t, stack.PayloadSince(reply.Pkt.NetworkHeader()),
- checker.SrcAddr(addr2),
- checker.DstAddr(addr1),
- checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+expectICMPPayload.Size())),
- checker.ICMPv6(
- checker.ICMPv6Type(test.expectICMPType),
- checker.ICMPv6Code(test.expectICMPCode),
- checker.ICMPv6TypeSpecific(test.expectICMPTypeSpecific),
- checker.ICMPv6Payload(expectICMPPayload),
- ),
- )
- })
- }
-}
-
-func TestFragmentReassemblyTimeout(t *testing.T) {
- const (
- addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
- nicID = 1
- hoplimit = 255
- ident = 1
- data = "TEST_FRAGMENT_REASSEMBLY_TIMEOUT"
- )
-
- type fragmentData struct {
- ipv6Fields header.IPv6Fields
- ipv6FragmentFields header.IPv6SerializableFragmentExtHdr
- payload []byte
- }
-
- tests := []struct {
- name string
- fragments []fragmentData
- expectICMP bool
- }{
- {
- name: "first fragment only",
- fragments: []fragmentData{
- {
- ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- TransportProtocol: header.UDPProtocolNumber,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
- FragmentOffset: 0,
- M: true,
- Identification: ident,
- },
- payload: []byte(data)[:16],
- },
- },
- expectICMP: true,
- },
- {
- name: "two first fragments",
- fragments: []fragmentData{
- {
- ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- TransportProtocol: header.UDPProtocolNumber,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
- FragmentOffset: 0,
- M: true,
- Identification: ident,
- },
- payload: []byte(data)[:16],
- },
- {
- ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- TransportProtocol: header.UDPProtocolNumber,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
- FragmentOffset: 0,
- M: true,
- Identification: ident,
- },
- payload: []byte(data)[:16],
- },
- },
- expectICMP: true,
- },
- {
- name: "second fragment only",
- fragments: []fragmentData{
- {
- ipv6Fields: header.IPv6Fields{
- PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
- TransportProtocol: header.UDPProtocolNumber,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
- FragmentOffset: 8,
- M: false,
- Identification: ident,
- },
- payload: []byte(data)[16:],
- },
- },
- expectICMP: false,
- },
- {
- name: "two fragments with a gap",
- fragments: []fragmentData{
- {
- ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- TransportProtocol: header.UDPProtocolNumber,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
- FragmentOffset: 0,
- M: true,
- Identification: ident,
- },
- payload: []byte(data)[:16],
- },
- {
- ipv6Fields: header.IPv6Fields{
- PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
- TransportProtocol: header.UDPProtocolNumber,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
- FragmentOffset: 8,
- M: false,
- Identification: ident,
- },
- payload: []byte(data)[16:],
- },
- },
- expectICMP: true,
- },
- {
- name: "two fragments with a gap in reverse order",
- fragments: []fragmentData{
- {
- ipv6Fields: header.IPv6Fields{
- PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
- TransportProtocol: header.UDPProtocolNumber,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
- FragmentOffset: 8,
- M: false,
- Identification: ident,
- },
- payload: []byte(data)[16:],
- },
- {
- ipv6Fields: header.IPv6Fields{
- PayloadLength: header.IPv6FragmentHeaderSize + 16,
- TransportProtocol: header.UDPProtocolNumber,
- HopLimit: hoplimit,
- SrcAddr: addr1,
- DstAddr: addr2,
- },
- ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{
- FragmentOffset: 0,
- M: true,
- Identification: ident,
- },
- payload: []byte(data)[:16],
- },
- },
- expectICMP: true,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{
- NewProtocol,
- },
- Clock: clock,
- })
-
- e := channel.New(1, 1500, linkAddr1)
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ProtocolNumber,
- AddressWithPrefix: addr2.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
- s.SetRouteTable([]tcpip.Route{{
- Destination: header.IPv6EmptySubnet,
- NIC: nicID,
- }})
-
- var firstFragmentSent buffer.View
- for _, f := range test.fragments {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)
-
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize))
- encodeArgs := f.ipv6Fields
- encodeArgs.ExtensionHeaders = append(encodeArgs.ExtensionHeaders, &f.ipv6FragmentFields)
- ip.Encode(&encodeArgs)
-
- fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:])
-
- vv := hdr.View().ToVectorisedView()
- vv.AppendView(f.payload)
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- })
-
- if firstFragmentSent == nil && fragHDR.FragmentOffset() == 0 {
- firstFragmentSent = stack.PayloadSince(pkt.NetworkHeader())
- }
-
- e.InjectInbound(ProtocolNumber, pkt)
- }
-
- clock.Advance(ReassembleTimeout)
-
- reply, ok := e.Read()
- if !test.expectICMP {
- if ok {
- t.Fatalf("unexpected ICMP error message received: %#v", reply)
- }
- return
- }
- if !ok {
- t.Fatal("expected ICMP error message missing")
- }
- if firstFragmentSent == nil {
- t.Fatalf("unexpected ICMP error message received: %#v", reply)
- }
-
- checker.IPv6(t, stack.PayloadSince(reply.Pkt.NetworkHeader()),
- checker.SrcAddr(addr2),
- checker.DstAddr(addr1),
- checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+firstFragmentSent.Size())),
- checker.ICMPv6(
- checker.ICMPv6Type(header.ICMPv6TimeExceeded),
- checker.ICMPv6Code(header.ICMPv6ReassemblyTimeout),
- checker.ICMPv6Payload(firstFragmentSent),
- ),
- )
- })
- }
-}
-
-func TestWriteStats(t *testing.T) {
- const nPackets = 3
- tests := []struct {
- name string
- setup func(*testing.T, *stack.Stack)
- allowPackets int
- expectSent int
- expectOutputDropped int
- expectPostroutingDropped int
- expectWritten int
- }{
- {
- name: "Accept all",
- // No setup needed, tables accept everything by default.
- setup: func(*testing.T, *stack.Stack) {},
- allowPackets: math.MaxInt32,
- expectSent: nPackets,
- expectOutputDropped: 0,
- expectPostroutingDropped: 0,
- expectWritten: nPackets,
- }, {
- name: "Accept all with error",
- // No setup needed, tables accept everything by default.
- setup: func(*testing.T, *stack.Stack) {},
- allowPackets: nPackets - 1,
- expectSent: nPackets - 1,
- expectOutputDropped: 0,
- expectPostroutingDropped: 0,
- expectWritten: nPackets - 1,
- }, {
- name: "Drop all with Output chain",
- setup: func(t *testing.T, stk *stack.Stack) {
- // Install Output DROP rule.
- ipt := stk.IPTables()
- filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
- ruleIdx := filter.BuiltinChains[stack.Output]
- filter.Rules[ruleIdx].Target = &stack.DropTarget{}
- if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
- t.Fatalf("failed to replace table: %v", err)
- }
- },
- allowPackets: math.MaxInt32,
- expectSent: 0,
- expectOutputDropped: nPackets,
- expectPostroutingDropped: 0,
- expectWritten: nPackets,
- }, {
- name: "Drop all with Postrouting chain",
- setup: func(t *testing.T, stk *stack.Stack) {
- // Install Output DROP rule.
- ipt := stk.IPTables()
- filter := ipt.GetTable(stack.NATID, true /* ipv6 */)
- ruleIdx := filter.BuiltinChains[stack.Postrouting]
- filter.Rules[ruleIdx].Target = &stack.DropTarget{}
- if err := ipt.ReplaceTable(stack.NATID, filter, true /* ipv6 */); err != nil {
- t.Fatalf("failed to replace table: %v", err)
- }
- },
- allowPackets: math.MaxInt32,
- expectSent: 0,
- expectOutputDropped: 0,
- expectPostroutingDropped: nPackets,
- expectWritten: nPackets,
- }, {
- name: "Drop some with Output chain",
- setup: func(t *testing.T, stk *stack.Stack) {
- // Install Output DROP rule that matches only 1
- // of the 3 packets.
- ipt := stk.IPTables()
- filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
- // We'll match and DROP the last packet.
- ruleIdx := filter.BuiltinChains[stack.Output]
- filter.Rules[ruleIdx].Target = &stack.DropTarget{}
- filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
- // Make sure the next rule is ACCEPT.
- filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
- if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
- t.Fatalf("failed to replace table: %v", err)
- }
- },
- allowPackets: math.MaxInt32,
- expectSent: nPackets - 1,
- expectOutputDropped: 1,
- expectPostroutingDropped: 0,
- expectWritten: nPackets,
- }, {
- name: "Drop some with Postrouting chain",
- setup: func(t *testing.T, stk *stack.Stack) {
- // Install Postrouting DROP rule that matches only 1
- // of the 3 packets.
- ipt := stk.IPTables()
- filter := ipt.GetTable(stack.NATID, true /* ipv6 */)
- // We'll match and DROP the last packet.
- ruleIdx := filter.BuiltinChains[stack.Postrouting]
- filter.Rules[ruleIdx].Target = &stack.DropTarget{}
- filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
- // Make sure the next rule is ACCEPT.
- filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
- if err := ipt.ReplaceTable(stack.NATID, filter, true /* ipv6 */); err != nil {
- t.Fatalf("failed to replace table: %v", err)
- }
- },
- allowPackets: math.MaxInt32,
- expectSent: nPackets - 1,
- expectOutputDropped: 0,
- expectPostroutingDropped: 1,
- expectWritten: nPackets,
- },
- }
-
- writers := []struct {
- name string
- writePackets func(*stack.Route, stack.PacketBufferList) (int, tcpip.Error)
- }{
- {
- name: "WritePacket",
- writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) {
- nWritten := 0
- for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- if err := rt.WritePacket(stack.NetworkHeaderParams{}, pkt); err != nil {
- return nWritten, err
- }
- nWritten++
- }
- return nWritten, nil
- },
- }, {
- name: "WritePackets",
- writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) {
- return rt.WritePackets(pkts, stack.NetworkHeaderParams{})
- },
- },
- }
-
- for _, writer := range writers {
- t.Run(writer.name, func(t *testing.T) {
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- ep := iptestutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets)
- rt := buildRoute(t, ep)
- var pkts stack.PacketBufferList
- for i := 0; i < nPackets; i++ {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.UDPMinimumSize + int(rt.MaxHeaderLength()),
- Data: buffer.NewView(0).ToVectorisedView(),
- })
- pkt.TransportHeader().Push(header.UDPMinimumSize)
- pkts.PushBack(pkt)
- }
-
- test.setup(t, rt.Stack())
-
- nWritten, _ := writer.writePackets(rt, pkts)
-
- if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent {
- t.Errorf("got rt.Stats().IP.PacketsSent.Value() = %d, want = %d", got, test.expectSent)
- }
- if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectOutputDropped {
- t.Errorf("got rt.Stats().IP.IPTablesOutputDropped.Value() = %d, want = %d", got, test.expectOutputDropped)
- }
- if got := int(rt.Stats().IP.IPTablesPostroutingDropped.Value()); got != test.expectPostroutingDropped {
- t.Errorf("got r.Stats().IP.IPTablesPostroutingDropped.Value() = %d, want = %d", got, test.expectPostroutingDropped)
- }
- if nWritten != test.expectWritten {
- t.Errorf("got nWritten = %d, want = %d", nWritten, test.expectWritten)
- }
- })
- }
- })
- }
-}
-
-func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- })
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatalf("CreateNIC(1, _) failed: %s", err)
- }
- const (
- src = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- dst = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- )
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ProtocolNumber,
- AddressWithPrefix: src.WithPrefix(),
- }
- if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
- }
- {
- mask := tcpip.AddressMask("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff")
- subnet, err := tcpip.NewSubnet(dst, mask)
- if err != nil {
- t.Fatalf("NewSubnet(%s, %s) failed: %v", dst, mask, err)
- }
- s.SetRouteTable([]tcpip.Route{{
- Destination: subnet,
- NIC: 1,
- }})
- }
- rt, err := s.FindRoute(1, src, dst, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(1, %s, %s, %d, false) = %s, want = nil", src, dst, ProtocolNumber, err)
- }
- return rt
-}
-
-// limitedMatcher is an iptables matcher that matches after a certain number of
-// packets are checked against it.
-type limitedMatcher struct {
- limit int
-}
-
-// Name implements Matcher.Name.
-func (*limitedMatcher) Name() string {
- return "limitedMatcher"
-}
-
-// Match implements Matcher.Match.
-func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string, string) (bool, bool) {
- if lm.limit == 0 {
- return true, false
- }
- lm.limit--
- return false, false
-}
-
-func knownNICIDs(proto *protocol) []tcpip.NICID {
- var nicIDs []tcpip.NICID
-
- for k := range proto.mu.eps {
- nicIDs = append(nicIDs, k)
- }
-
- return nicIDs
-}
-
-func TestClearEndpointFromProtocolOnClose(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- })
- proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
- var nic testInterface
- ep := proto.NewEndpoint(&nic, nil).(*endpoint)
- var nicIDs []tcpip.NICID
-
- proto.mu.Lock()
- foundEP, hasEndpointBeforeClose := proto.mu.eps[nic.ID()]
- nicIDs = knownNICIDs(proto)
- proto.mu.Unlock()
- if !hasEndpointBeforeClose {
- t.Fatalf("expected to find the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs)
- }
- if foundEP != ep {
- t.Fatalf("found an incorrect endpoint mapped to nic id %d", nic.ID())
- }
-
- ep.Close()
-
- proto.mu.Lock()
- _, hasEndpointAfterClose := proto.mu.eps[nic.ID()]
- nicIDs = knownNICIDs(proto)
- proto.mu.Unlock()
- if hasEndpointAfterClose {
- t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs)
- }
-}
-
-type fragmentInfo struct {
- offset uint16
- more bool
- payloadSize uint16
-}
-
-var fragmentationTests = []struct {
- description string
- mtu uint32
- transHdrLen int
- payloadSize int
- wantFragments []fragmentInfo
-}{
- {
- description: "No fragmentation",
- mtu: header.IPv6MinimumMTU,
- transHdrLen: 0,
- payloadSize: 1000,
- wantFragments: []fragmentInfo{
- {offset: 0, payloadSize: 1000, more: false},
- },
- },
- {
- description: "Fragmented",
- mtu: header.IPv6MinimumMTU,
- transHdrLen: 0,
- payloadSize: 2000,
- wantFragments: []fragmentInfo{
- {offset: 0, payloadSize: 1240, more: true},
- {offset: 154, payloadSize: 776, more: false},
- },
- },
- {
- description: "Fragmented with mtu not a multiple of 8",
- mtu: header.IPv6MinimumMTU + 1,
- transHdrLen: 0,
- payloadSize: 2000,
- wantFragments: []fragmentInfo{
- {offset: 0, payloadSize: 1240, more: true},
- {offset: 154, payloadSize: 776, more: false},
- },
- },
- {
- description: "No fragmentation with big header",
- mtu: 2000,
- transHdrLen: 100,
- payloadSize: 1000,
- wantFragments: []fragmentInfo{
- {offset: 0, payloadSize: 1100, more: false},
- },
- },
- {
- description: "Fragmented with big header",
- mtu: header.IPv6MinimumMTU,
- transHdrLen: 100,
- payloadSize: 1200,
- wantFragments: []fragmentInfo{
- {offset: 0, payloadSize: 1240, more: true},
- {offset: 154, payloadSize: 76, more: false},
- },
- },
-}
-
-func TestFragmentationWritePacket(t *testing.T) {
- const ttl = 42
-
- for _, ft := range fragmentationTests {
- t.Run(ft.description, func(t *testing.T) {
- pkt := iptestutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
- source := pkt.Clone()
- ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
- r := buildRoute(t, ep)
- err := r.WritePacket(stack.NetworkHeaderParams{
- Protocol: tcp.ProtocolNumber,
- TTL: ttl,
- TOS: stack.DefaultTOS,
- }, pkt)
- if err != nil {
- t.Fatalf("WritePacket(_, _, _): = %s", err)
- }
- if got := len(ep.WrittenPackets); got != len(ft.wantFragments) {
- t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, len(ft.wantFragments))
- }
- if got := int(r.Stats().IP.PacketsSent.Value()); got != len(ft.wantFragments) {
- t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, len(ft.wantFragments))
- }
- if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 {
- t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
- }
- if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil {
- t.Error(err)
- }
- })
- }
-}
-
-func TestFragmentationWritePackets(t *testing.T) {
- const ttl = 42
- tests := []struct {
- description string
- insertBefore int
- insertAfter int
- }{
- {
- description: "Single packet",
- insertBefore: 0,
- insertAfter: 0,
- },
- {
- description: "With packet before",
- insertBefore: 1,
- insertAfter: 0,
- },
- {
- description: "With packet after",
- insertBefore: 0,
- insertAfter: 1,
- },
- {
- description: "With packet before and after",
- insertBefore: 1,
- insertAfter: 1,
- },
- }
- tinyPacket := iptestutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber)
-
- for _, test := range tests {
- t.Run(test.description, func(t *testing.T) {
- for _, ft := range fragmentationTests {
- t.Run(ft.description, func(t *testing.T) {
- var pkts stack.PacketBufferList
- for i := 0; i < test.insertBefore; i++ {
- pkts.PushBack(tinyPacket.Clone())
- }
- pkt := iptestutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
- source := pkt
- pkts.PushBack(pkt.Clone())
- for i := 0; i < test.insertAfter; i++ {
- pkts.PushBack(tinyPacket.Clone())
- }
-
- ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
- r := buildRoute(t, ep)
-
- wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter
- n, err := r.WritePackets(pkts, stack.NetworkHeaderParams{
- Protocol: tcp.ProtocolNumber,
- TTL: ttl,
- TOS: stack.DefaultTOS,
- })
- if n != wantTotalPackets || err != nil {
- t.Errorf("got WritePackets(_, _, _) = (%d, %s), want = (%d, nil)", n, err, wantTotalPackets)
- }
- if got := len(ep.WrittenPackets); got != wantTotalPackets {
- t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, wantTotalPackets)
- }
- if got := int(r.Stats().IP.PacketsSent.Value()); got != wantTotalPackets {
- t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, wantTotalPackets)
- }
- if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 {
- t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
- }
-
- if wantTotalPackets == 0 {
- return
- }
-
- fragments := ep.WrittenPackets[test.insertBefore : len(ft.wantFragments)+test.insertBefore]
- if err := compareFragments(fragments, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil {
- t.Error(err)
- }
- })
- }
- })
- }
-}
-
-// TestFragmentationErrors checks that errors are returned from WritePacket
-// correctly.
-func TestFragmentationErrors(t *testing.T) {
- const ttl = 42
-
- tests := []struct {
- description string
- mtu uint32
- transHdrLen int
- payloadSize int
- allowPackets int
- outgoingErrors int
- mockError tcpip.Error
- wantError tcpip.Error
- }{
- {
- description: "No frag",
- mtu: 2000,
- payloadSize: 1000,
- transHdrLen: 0,
- allowPackets: 0,
- outgoingErrors: 1,
- mockError: &tcpip.ErrAborted{},
- wantError: &tcpip.ErrAborted{},
- },
- {
- description: "Error on first frag",
- mtu: 1300,
- payloadSize: 3000,
- transHdrLen: 0,
- allowPackets: 0,
- outgoingErrors: 3,
- mockError: &tcpip.ErrAborted{},
- wantError: &tcpip.ErrAborted{},
- },
- {
- description: "Error on second frag",
- mtu: 1500,
- payloadSize: 4000,
- transHdrLen: 0,
- allowPackets: 1,
- outgoingErrors: 2,
- mockError: &tcpip.ErrAborted{},
- wantError: &tcpip.ErrAborted{},
- },
- {
- description: "Error when MTU is smaller than transport header",
- mtu: header.IPv6MinimumMTU,
- transHdrLen: 1500,
- payloadSize: 500,
- allowPackets: 0,
- outgoingErrors: 1,
- mockError: nil,
- wantError: &tcpip.ErrMessageTooLong{},
- },
- {
- description: "Error when MTU is smaller than IPv6 minimum MTU",
- mtu: header.IPv6MinimumMTU - 1,
- transHdrLen: 0,
- payloadSize: 500,
- allowPackets: 0,
- outgoingErrors: 1,
- mockError: nil,
- wantError: &tcpip.ErrInvalidEndpointState{},
- },
- }
-
- for _, ft := range tests {
- t.Run(ft.description, func(t *testing.T) {
- pkt := iptestutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
- ep := iptestutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets)
- r := buildRoute(t, ep)
- err := r.WritePacket(stack.NetworkHeaderParams{
- Protocol: tcp.ProtocolNumber,
- TTL: ttl,
- TOS: stack.DefaultTOS,
- }, pkt)
- if diff := cmp.Diff(ft.wantError, err); diff != "" {
- t.Errorf("unexpected error from WritePacket(_, _, _), (-want, +got):\n%s", diff)
- }
- if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets {
- t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets)
- }
- if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != ft.outgoingErrors {
- t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, ft.outgoingErrors)
- }
- })
- }
-}
-
-func TestForwarding(t *testing.T) {
- const (
- incomingNICID = 1
- outgoingNICID = 2
- randomSequence = 123
- randomIdent = 42
- )
-
- incomingIPv6Addr := tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("10::1").To16()),
- PrefixLen: 64,
- }
- outgoingIPv6Addr := tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("11::1").To16()),
- PrefixLen: 64,
- }
- multicastIPv6Addr := tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("ff00::").To16()),
- PrefixLen: 64,
- }
-
- remoteIPv6Addr1 := tcpip.Address(net.ParseIP("10::2").To16())
- remoteIPv6Addr2 := tcpip.Address(net.ParseIP("11::2").To16())
- unreachableIPv6Addr := tcpip.Address(net.ParseIP("12::2").To16())
- linkLocalIPv6Addr := tcpip.Address(net.ParseIP("fe80::").To16())
-
- tests := []struct {
- name string
- extHdr func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker)
- TTL uint8
- expectErrorICMP bool
- expectPacketForwarded bool
- payloadLength int
- countUnrouteablePackets uint64
- sourceAddr tcpip.Address
- destAddr tcpip.Address
- icmpType header.ICMPv6Type
- icmpCode header.ICMPv6Code
- expectPacketUnrouteableError bool
- expectLinkLocalSourceError bool
- expectLinkLocalDestError bool
- expectExtensionHeaderError bool
- }{
- {
- name: "TTL of zero",
- TTL: 0,
- expectErrorICMP: true,
- sourceAddr: remoteIPv6Addr1,
- destAddr: remoteIPv6Addr2,
- icmpType: header.ICMPv6TimeExceeded,
- icmpCode: header.ICMPv6HopLimitExceeded,
- },
- {
- name: "TTL of one",
- TTL: 1,
- expectErrorICMP: true,
- sourceAddr: remoteIPv6Addr1,
- destAddr: remoteIPv6Addr2,
- icmpType: header.ICMPv6TimeExceeded,
- icmpCode: header.ICMPv6HopLimitExceeded,
- },
- {
- name: "TTL of two",
- TTL: 2,
- expectPacketForwarded: true,
- sourceAddr: remoteIPv6Addr1,
- destAddr: remoteIPv6Addr2,
- },
- {
- name: "TTL of three",
- TTL: 3,
- expectPacketForwarded: true,
- sourceAddr: remoteIPv6Addr1,
- destAddr: remoteIPv6Addr2,
- },
- {
- name: "Max TTL",
- TTL: math.MaxUint8,
- expectPacketForwarded: true,
- sourceAddr: remoteIPv6Addr1,
- destAddr: remoteIPv6Addr2,
- },
- {
- name: "Network unreachable",
- TTL: 2,
- expectErrorICMP: true,
- sourceAddr: remoteIPv6Addr1,
- destAddr: unreachableIPv6Addr,
- icmpType: header.ICMPv6DstUnreachable,
- icmpCode: header.ICMPv6NetworkUnreachable,
- expectPacketUnrouteableError: true,
- },
- {
- name: "Multicast destination",
- TTL: 2,
- countUnrouteablePackets: 1,
- sourceAddr: remoteIPv6Addr1,
- destAddr: multicastIPv6Addr.Address,
- expectPacketForwarded: true,
- },
- {
- name: "Link local destination",
- TTL: 2,
- sourceAddr: remoteIPv6Addr1,
- destAddr: linkLocalIPv6Addr,
- expectLinkLocalDestError: true,
- },
- {
- name: "Link local source",
- TTL: 2,
- sourceAddr: linkLocalIPv6Addr,
- destAddr: remoteIPv6Addr2,
- expectLinkLocalSourceError: true,
- },
- {
- name: "Hopbyhop with unknown option skippable action",
- TTL: 2,
- sourceAddr: remoteIPv6Addr1,
- destAddr: remoteIPv6Addr2,
- extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
- return []byte{
- nextHdr, 1,
-
- // Skippable unknown.
- 63, 4, 1, 2, 3, 4,
-
- // Skippable unknown.
- 62, 6, 1, 2, 3, 4, 5, 6,
- }, hopByHopExtHdrID, checker.IPv6ExtHdr(checker.IPv6HopByHopExtensionHeader(checker.IPv6UnknownOption(), checker.IPv6UnknownOption()))
- },
- expectPacketForwarded: true,
- },
- {
- name: "Hopbyhop with unknown option discard action",
- TTL: 2,
- sourceAddr: remoteIPv6Addr1,
- destAddr: remoteIPv6Addr2,
- extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
- return []byte{
- nextHdr, 1,
-
- // Skippable unknown.
- 63, 4, 1, 2, 3, 4,
-
- // Discard unknown.
- 127, 6, 1, 2, 3, 4, 5, 6,
- }, hopByHopExtHdrID, nil
- },
- expectExtensionHeaderError: true,
- },
- {
- name: "Hopbyhop with unknown option discard and send icmp action (unicast)",
- TTL: 2,
- sourceAddr: remoteIPv6Addr1,
- destAddr: remoteIPv6Addr2,
- extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
- return []byte{
- nextHdr, 1,
-
- // Skippable unknown.
- 63, 4, 1, 2, 3, 4,
-
- // Discard & send ICMP if option is unknown.
- 191, 6, 1, 2, 3, 4, 5, 6,
- }, hopByHopExtHdrID, nil
- },
- expectErrorICMP: true,
- icmpType: header.ICMPv6ParamProblem,
- icmpCode: header.ICMPv6UnknownOption,
- expectExtensionHeaderError: true,
- },
- {
- name: "Hopbyhop with unknown option discard and send icmp action (multicast)",
- TTL: 2,
- sourceAddr: remoteIPv6Addr1,
- destAddr: multicastIPv6Addr.Address,
- extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
- return []byte{
- nextHdr, 1,
-
- // Skippable unknown.
- 63, 4, 1, 2, 3, 4,
-
- // Discard & send ICMP if option is unknown.
- 191, 6, 1, 2, 3, 4, 5, 6,
- }, hopByHopExtHdrID, nil
- },
- expectErrorICMP: true,
- icmpType: header.ICMPv6ParamProblem,
- icmpCode: header.ICMPv6UnknownOption,
- expectExtensionHeaderError: true,
- },
- {
- name: "Hopbyhop with unknown option discard and send icmp action unless multicast dest (unicast)",
- TTL: 2,
- sourceAddr: remoteIPv6Addr1,
- destAddr: remoteIPv6Addr2,
- extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
- return []byte{
- nextHdr, 1,
-
- // Skippable unknown.
- 63, 4, 1, 2, 3, 4,
-
- // Discard & send ICMP unless packet is for multicast destination if
- // option is unknown.
- 255, 6, 1, 2, 3, 4, 5, 6,
- }, hopByHopExtHdrID, nil
- },
- expectErrorICMP: true,
- icmpType: header.ICMPv6ParamProblem,
- icmpCode: header.ICMPv6UnknownOption,
- expectExtensionHeaderError: true,
- },
- {
- name: "Hopbyhop with unknown option discard and send icmp action unless multicast dest (multicast)",
- TTL: 2,
- sourceAddr: remoteIPv6Addr1,
- destAddr: multicastIPv6Addr.Address,
- extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
- return []byte{
- nextHdr, 1,
-
- // Skippable unknown.
- 63, 4, 1, 2, 3, 4,
-
- // Discard & send ICMP unless packet is for multicast destination if
- // option is unknown.
- 255, 6, 1, 2, 3, 4, 5, 6,
- }, hopByHopExtHdrID, nil
- },
- expectExtensionHeaderError: true,
- },
- {
- name: "Hopbyhop with router alert option",
- TTL: 2,
- sourceAddr: remoteIPv6Addr1,
- destAddr: remoteIPv6Addr2,
- extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
- return []byte{
- nextHdr, 0,
-
- // Router Alert option.
- 5, 2, 0, 0, 0, 0,
- }, hopByHopExtHdrID, checker.IPv6ExtHdr(checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)))
- },
- expectPacketForwarded: true,
- },
- {
- name: "Hopbyhop with two router alert options",
- TTL: 2,
- sourceAddr: remoteIPv6Addr1,
- destAddr: remoteIPv6Addr2,
- extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) {
- return []byte{
- nextHdr, 1,
-
- // Router Alert option.
- 5, 2, 0, 0, 0, 0,
-
- // Router Alert option.
- 5, 2, 0, 0, 0, 0,
- }, hopByHopExtHdrID, nil
- },
- expectExtensionHeaderError: true,
- },
- {
- name: "Can't fragment",
- TTL: 2,
- payloadLength: header.IPv6MinimumMTU + 1,
- expectErrorICMP: true,
- sourceAddr: remoteIPv6Addr1,
- destAddr: remoteIPv6Addr2,
- icmpType: header.ICMPv6PacketTooBig,
- icmpCode: header.ICMPv6UnusedCode,
- },
- {
- name: "Can't fragment multicast",
- TTL: 2,
- payloadLength: header.IPv6MinimumMTU + 1,
- sourceAddr: remoteIPv6Addr1,
- destAddr: multicastIPv6Addr.Address,
- expectErrorICMP: true,
- icmpType: header.ICMPv6PacketTooBig,
- icmpCode: header.ICMPv6UnusedCode,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
- })
- // We expect at most a single packet in response to our ICMP Echo Request.
- incomingEndpoint := channel.New(1, header.IPv6MinimumMTU, "")
- if err := s.CreateNIC(incomingNICID, incomingEndpoint); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err)
- }
- incomingIPv6ProtoAddr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: incomingIPv6Addr}
- if err := s.AddProtocolAddress(incomingNICID, incomingIPv6ProtoAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", incomingNICID, incomingIPv6ProtoAddr, err)
- }
-
- outgoingEndpoint := channel.New(1, header.IPv6MinimumMTU, "")
- if err := s.CreateNIC(outgoingNICID, outgoingEndpoint); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err)
- }
- outgoingIPv6ProtoAddr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: outgoingIPv6Addr}
- if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv6ProtoAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", outgoingNICID, outgoingIPv6ProtoAddr, err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: incomingIPv6Addr.Subnet(),
- NIC: incomingNICID,
- },
- {
- Destination: outgoingIPv6Addr.Subnet(),
- NIC: outgoingNICID,
- },
- {
- Destination: multicastIPv6Addr.Subnet(),
- NIC: outgoingNICID,
- },
- })
-
- if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil {
- t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ProtocolNumber, err)
- }
-
- transportProtocol := header.ICMPv6ProtocolNumber
- var extHdrBytes []byte
- extHdrChecker := checker.IPv6ExtHdr()
- if test.extHdr != nil {
- nextHdrID := hopByHopExtHdrID
- extHdrBytes, nextHdrID, extHdrChecker = test.extHdr(uint8(header.ICMPv6ProtocolNumber))
- transportProtocol = tcpip.TransportProtocolNumber(nextHdrID)
- }
- extHdrLen := len(extHdrBytes)
-
- ipHeaderLength := header.IPv6MinimumSize
- icmpHeaderLength := header.ICMPv6MinimumSize
- totalLength := ipHeaderLength + icmpHeaderLength + test.payloadLength + extHdrLen
- hdr := buffer.NewPrependable(totalLength)
- hdr.Prepend(test.payloadLength)
- icmpH := header.ICMPv6(hdr.Prepend(icmpHeaderLength))
-
- icmpH.SetIdent(randomIdent)
- icmpH.SetSequence(randomSequence)
- icmpH.SetType(header.ICMPv6EchoRequest)
- icmpH.SetCode(header.ICMPv6UnusedCode)
- icmpH.SetChecksum(0)
- icmpH.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: icmpH,
- Src: test.sourceAddr,
- Dst: test.destAddr,
- }))
- copy(hdr.Prepend(extHdrLen), extHdrBytes)
- ip := header.IPv6(hdr.Prepend(ipHeaderLength))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(header.ICMPv6MinimumSize + test.payloadLength),
- TransportProtocol: transportProtocol,
- HopLimit: test.TTL,
- SrcAddr: test.sourceAddr,
- DstAddr: test.destAddr,
- })
- requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- })
- incomingEndpoint.InjectInbound(ProtocolNumber, requestPkt)
-
- reply, ok := incomingEndpoint.Read()
-
- if test.expectErrorICMP {
- if !ok {
- t.Fatalf("expected ICMP packet type %d through incoming NIC", test.icmpType)
- }
-
- // As per RFC 4443, page 9:
- //
- // The returned ICMP packet will contain as much of invoking packet
- // as possible without the ICMPv6 packet exceeding the minimum IPv6
- // MTU.
- expectedICMPPayloadLength := func() int {
- maxICMPPayloadLength := header.IPv6MinimumMTU - ipHeaderLength - icmpHeaderLength
- if len(hdr.View()) > maxICMPPayloadLength {
- return maxICMPPayloadLength
- }
- return len(hdr.View())
- }
-
- checker.IPv6(t, stack.PayloadSince(reply.Pkt.NetworkHeader()),
- checker.SrcAddr(incomingIPv6Addr.Address),
- checker.DstAddr(test.sourceAddr),
- checker.TTL(DefaultTTL),
- checker.ICMPv6(
- checker.ICMPv6Type(test.icmpType),
- checker.ICMPv6Code(test.icmpCode),
- checker.ICMPv6Payload(hdr.View()[:expectedICMPPayloadLength()]),
- ),
- )
-
- if n := outgoingEndpoint.Drain(); n != 0 {
- t.Fatalf("got e2.Drain() = %d, want = 0", n)
- }
- } else if ok {
- t.Fatalf("expected no ICMP packet through incoming NIC, instead found: %#v", reply)
- }
-
- reply, ok = outgoingEndpoint.Read()
- if test.expectPacketForwarded {
- if !ok {
- t.Fatal("expected ICMP Echo Request packet through outgoing NIC")
- }
-
- checker.IPv6WithExtHdr(t, stack.PayloadSince(reply.Pkt.NetworkHeader()),
- checker.SrcAddr(test.sourceAddr),
- checker.DstAddr(test.destAddr),
- checker.TTL(test.TTL-1),
- extHdrChecker,
- checker.ICMPv6(
- checker.ICMPv6Type(header.ICMPv6EchoRequest),
- checker.ICMPv6Code(header.ICMPv6UnusedCode),
- checker.ICMPv6Payload(nil),
- ),
- )
-
- if n := incomingEndpoint.Drain(); n != 0 {
- t.Fatalf("got e1.Drain() = %d, want = 0", n)
- }
- } else if ok {
- t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", reply)
- }
-
- boolToInt := func(val bool) uint64 {
- if val {
- return 1
- }
- return 0
- }
-
- if got, want := s.Stats().IP.Forwarding.LinkLocalSource.Value(), boolToInt(test.expectLinkLocalSourceError); got != want {
- t.Errorf("got s.Stats().IP.Forwarding.LinkLocalSource.Value() = %d, want = %d", got, want)
- }
-
- if got, want := s.Stats().IP.Forwarding.LinkLocalDestination.Value(), boolToInt(test.expectLinkLocalDestError); got != want {
- t.Errorf("got s.Stats().IP.Forwarding.LinkLocalDestination.Value() = %d, want = %d", got, want)
- }
-
- if got, want := s.Stats().IP.Forwarding.ExhaustedTTL.Value(), boolToInt(test.TTL <= 1); got != want {
- t.Errorf("got rt.Stats().IP.Forwarding.ExhaustedTTL.Value() = %d, want = %d", got, want)
- }
-
- if got, want := s.Stats().IP.Forwarding.Unrouteable.Value(), boolToInt(test.expectPacketUnrouteableError); got != want {
- t.Errorf("got s.Stats().IP.Forwarding.Unrouteable.Value() = %d, want = %d", got, want)
- }
-
- if got, want := s.Stats().IP.Forwarding.Errors.Value(), boolToInt(!test.expectPacketForwarded); got != want {
- t.Errorf("got s.Stats().IP.Forwarding.Errors.Value() = %d, want = %d", got, want)
- }
-
- if got, want := s.Stats().IP.Forwarding.ExtensionHeaderProblem.Value(), boolToInt(test.expectExtensionHeaderError); got != want {
- t.Errorf("got s.Stats().IP.Forwarding.ExtensionHeaderProblem.Value() = %d, want = %d", got, want)
- }
-
- if got, want := s.Stats().IP.Forwarding.PacketTooBig.Value(), boolToInt(test.icmpType == header.ICMPv6PacketTooBig); got != want {
- t.Errorf("got s.Stats().IP.Forwarding.PacketTooBig.Value() = %d, want = %d", got, want)
- }
- })
- }
-}
-
-func TestMultiCounterStatsInitialization(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- })
- proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
- var nic testInterface
- ep := proto.NewEndpoint(&nic, nil).(*endpoint)
- // At this point, the Stack's stats and the NetworkEndpoint's stats are
- // supposed to be bound.
- refStack := s.Stats()
- refEP := ep.stats.localStats
- if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.ip).Elem(), []reflect.Value{reflect.ValueOf(&refStack.IP).Elem(), reflect.ValueOf(&refEP.IP).Elem()}); err != nil {
- t.Error(err)
- }
- if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.icmp).Elem(), []reflect.Value{reflect.ValueOf(&refStack.ICMP.V6).Elem(), reflect.ValueOf(&refEP.ICMP).Elem()}); err != nil {
- t.Error(err)
- }
-}
diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go
deleted file mode 100644
index 3e5c438d3..000000000
--- a/pkg/tcpip/network/ipv6/mld_test.go
+++ /dev/null
@@ -1,620 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package ipv6_test
-
-import (
- "bytes"
- "math/rand"
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
-)
-
-var (
- linkLocalAddr = testutil.MustParse6("fe80::1")
- globalAddr = testutil.MustParse6("a80::1")
- globalMulticastAddr = testutil.MustParse6("ff05:100::2")
-
- linkLocalAddrSNMC = header.SolicitedNodeAddr(linkLocalAddr)
- globalAddrSNMC = header.SolicitedNodeAddr(globalAddr)
-)
-
-func validateMLDPacket(t *testing.T, p buffer.View, localAddress, remoteAddress tcpip.Address, mldType header.ICMPv6Type, groupAddress tcpip.Address) {
- t.Helper()
-
- checker.IPv6WithExtHdr(t, p,
- checker.IPv6ExtHdr(
- checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)),
- ),
- checker.SrcAddr(localAddress),
- checker.DstAddr(remoteAddress),
- checker.TTL(header.MLDHopLimit),
- checker.MLD(mldType, header.MLDMinimumSize,
- checker.MLDMaxRespDelay(0),
- checker.MLDMulticastAddress(groupAddress),
- ),
- )
-}
-
-func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) {
- const nicID = 1
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- MLD: ipv6.MLDOptions{
- Enabled: true,
- },
- })},
- })
- e := channel.New(1, header.IPv6MinimumMTU, "")
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
-
- // The stack will join an address's solicited node multicast address when
- // an address is added. An MLD report message should be sent for the
- // solicited-node group.
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: linkLocalAddr.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
- if p, ok := e.Read(); !ok {
- t.Fatal("expected a report message to be sent")
- } else {
- validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC)
- }
-
- // The stack will leave an address's solicited node multicast address when
- // an address is removed. An MLD done message should be sent for the
- // solicited-node group.
- if err := s.RemoveAddress(nicID, linkLocalAddr); err != nil {
- t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, linkLocalAddr, err)
- }
- if p, ok := e.Read(); !ok {
- t.Fatal("expected a done message to be sent")
- } else {
- validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, header.IPv6AllRoutersLinkLocalMulticastAddress, header.ICMPv6MulticastListenerDone, linkLocalAddrSNMC)
- }
-}
-
-func TestSendQueuedMLDReports(t *testing.T) {
- const (
- nicID = 1
- maxReports = 2
- )
-
- tests := []struct {
- name string
- dadTransmits uint8
- retransmitTimer time.Duration
- }{
- {
- name: "DAD Disabled",
- dadTransmits: 0,
- retransmitTimer: 0,
- },
- {
- name: "DAD Enabled",
- dadTransmits: 1,
- retransmitTimer: time.Second,
- },
- }
-
- nonce := [...]byte{
- 1, 2, 3, 4, 5, 6,
- }
-
- const maxNSMessages = 2
- secureRNGBytes := make([]byte, len(nonce)*maxNSMessages)
- for b := secureRNGBytes[:]; len(b) > 0; b = b[len(nonce):] {
- if n := copy(b, nonce[:]); n != len(nonce) {
- t.Fatalf("got copy(...) = %d, want = %d", n, len(nonce))
- }
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- dadResolutionTime := test.retransmitTimer * time.Duration(test.dadTransmits)
- clock := faketime.NewManualClock()
- var secureRNG bytes.Reader
- secureRNG.Reset(secureRNGBytes[:])
- s := stack.New(stack.Options{
- SecureRNG: &secureRNG,
- RandSource: rand.NewSource(time.Now().UnixNano()),
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- DADConfigs: stack.DADConfigurations{
- DupAddrDetectTransmits: test.dadTransmits,
- RetransmitTimer: test.retransmitTimer,
- },
- MLD: ipv6.MLDOptions{
- Enabled: true,
- },
- })},
- Clock: clock,
- })
-
- // Allow space for an extra packet so we can observe packets that were
- // unexpectedly sent.
- e := channel.New(maxReports+int(test.dadTransmits)+1 /* extra */, header.IPv6MinimumMTU, "")
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
-
- resolveDAD := func(addr, snmc tcpip.Address) {
- clock.Advance(dadResolutionTime)
- if p, ok := e.Read(); !ok {
- t.Fatal("expected DAD packet")
- } else {
- checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
- checker.SrcAddr(header.IPv6Any),
- checker.DstAddr(snmc),
- checker.TTL(header.NDPHopLimit),
- checker.NDPNS(
- checker.NDPNSTargetAddress(addr),
- checker.NDPNSOptions([]header.NDPOption{header.NDPNonceOption(nonce[:])}),
- ))
- }
- }
-
- var reportCounter uint64
- reportStat := s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
- if got := reportStat.Value(); got != reportCounter {
- t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
- }
- var doneCounter uint64
- doneStat := s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone
- if got := doneStat.Value(); got != doneCounter {
- t.Errorf("got doneStat.Value() = %d, want = %d", got, doneCounter)
- }
-
- // Joining a group without an assigned address should send an MLD report
- // with the unspecified address.
- if err := s.JoinGroup(ipv6.ProtocolNumber, nicID, globalMulticastAddr); err != nil {
- t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalMulticastAddr, err)
- }
- reportCounter++
- if got := reportStat.Value(); got != reportCounter {
- t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
- }
- if p, ok := e.Read(); !ok {
- t.Errorf("expected MLD report for %s", globalMulticastAddr)
- } else {
- validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, globalMulticastAddr, header.ICMPv6MulticastListenerReport, globalMulticastAddr)
- }
- clock.Advance(time.Hour)
- if p, ok := e.Read(); ok {
- t.Errorf("got unexpected packet = %#v", p)
- }
- if t.Failed() {
- t.FailNow()
- }
-
- // Adding a global address should not send reports for the already joined
- // group since we should only send queued reports when a link-local
- // address is assigned.
- //
- // Note, we will still expect to send a report for the global address's
- // solicited node address from the unspecified address as per RFC 3590
- // section 4.
- properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint}
- globalProtocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: globalAddr.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, globalProtocolAddr, properties); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, globalProtocolAddr, properties, err)
- }
- reportCounter++
- if got := reportStat.Value(); got != reportCounter {
- t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
- }
- if p, ok := e.Read(); !ok {
- t.Errorf("expected MLD report for %s", globalAddrSNMC)
- } else {
- validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, globalAddrSNMC, header.ICMPv6MulticastListenerReport, globalAddrSNMC)
- }
- if dadResolutionTime != 0 {
- // Reports should not be sent when the address resolves.
- resolveDAD(globalAddr, globalAddrSNMC)
- if got := reportStat.Value(); got != reportCounter {
- t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
- }
- }
- // Leave the group since we don't care about the global address's
- // solicited node multicast group membership.
- if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, globalAddrSNMC); err != nil {
- t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalAddrSNMC, err)
- }
- if got := doneStat.Value(); got != doneCounter {
- t.Errorf("got doneStat.Value() = %d, want = %d", got, doneCounter)
- }
- if p, ok := e.Read(); ok {
- t.Errorf("got unexpected packet = %#v", p)
- }
- if t.Failed() {
- t.FailNow()
- }
-
- // Adding a link-local address should send a report for its solicited node
- // address and globalMulticastAddr.
- linkLocalProtocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: linkLocalAddr.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, linkLocalProtocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, linkLocalProtocolAddr, err)
- }
- if dadResolutionTime != 0 {
- reportCounter++
- if got := reportStat.Value(); got != reportCounter {
- t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
- }
- if p, ok := e.Read(); !ok {
- t.Errorf("expected MLD report for %s", linkLocalAddrSNMC)
- } else {
- validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC)
- }
- resolveDAD(linkLocalAddr, linkLocalAddrSNMC)
- }
-
- // We expect two batches of reports to be sent (1 batch when the
- // link-local address is assigned, and another after the maximum
- // unsolicited report interval.
- for i := 0; i < 2; i++ {
- // We expect reports to be sent (one for globalMulticastAddr and another
- // for linkLocalAddrSNMC).
- reportCounter += maxReports
- if got := reportStat.Value(); got != reportCounter {
- t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter)
- }
-
- addrs := map[tcpip.Address]bool{
- globalMulticastAddr: false,
- linkLocalAddrSNMC: false,
- }
- for range addrs {
- p, ok := e.Read()
- if !ok {
- t.Fatalf("expected MLD report for %s and %s; addrs = %#v", globalMulticastAddr, linkLocalAddrSNMC, addrs)
- }
-
- addr := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())).DestinationAddress()
- if seen, ok := addrs[addr]; !ok {
- t.Fatalf("got unexpected packet destined to %s", addr)
- } else if seen {
- t.Fatalf("got another packet destined to %s", addr)
- }
-
- addrs[addr] = true
- validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, addr, header.ICMPv6MulticastListenerReport, addr)
-
- clock.Advance(ipv6.UnsolicitedReportIntervalMax)
- }
- }
-
- // Should not send any more reports.
- clock.Advance(time.Hour)
- if p, ok := e.Read(); ok {
- t.Errorf("got unexpected packet = %#v", p)
- }
- })
- }
-}
-
-// createAndInjectMLDPacket creates and injects an MLD packet with the
-// specified fields.
-func createAndInjectMLDPacket(e *channel.Endpoint, mldType header.ICMPv6Type, hopLimit uint8, srcAddress tcpip.Address, withRouterAlertOption bool, routerAlertValue header.IPv6RouterAlertValue) {
- var extensionHeaders header.IPv6ExtHdrSerializer
- if withRouterAlertOption {
- extensionHeaders = header.IPv6ExtHdrSerializer{
- header.IPv6SerializableHopByHopExtHdr{
- &header.IPv6RouterAlertOption{Value: routerAlertValue},
- },
- }
- }
-
- extensionHeadersLength := extensionHeaders.Length()
- payloadLength := extensionHeadersLength + header.ICMPv6HeaderSize + header.MLDMinimumSize
- buf := buffer.NewView(header.IPv6MinimumSize + payloadLength)
-
- ip := header.IPv6(buf)
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- HopLimit: hopLimit,
- TransportProtocol: header.ICMPv6ProtocolNumber,
- SrcAddr: srcAddress,
- DstAddr: header.IPv6AllNodesMulticastAddress,
- ExtensionHeaders: extensionHeaders,
- })
-
- icmp := header.ICMPv6(ip.Payload()[extensionHeadersLength:])
- icmp.SetType(mldType)
- mld := header.MLD(icmp.MessageBody())
- mld.SetMaximumResponseDelay(0)
- mld.SetMulticastAddress(header.IPv6Any)
- icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: icmp,
- Src: srcAddress,
- Dst: header.IPv6AllNodesMulticastAddress,
- }))
-
- e.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-}
-
-func TestMLDPacketValidation(t *testing.T) {
- const nicID = 1
- linkLocalAddr2 := testutil.MustParse6("fe80::2")
-
- tests := []struct {
- name string
- messageType header.ICMPv6Type
- srcAddr tcpip.Address
- includeRouterAlertOption bool
- routerAlertValue header.IPv6RouterAlertValue
- hopLimit uint8
- expectValidMLD bool
- getMessageTypeStatValue func(tcpip.Stats) uint64
- }{
- {
- name: "valid",
- messageType: header.ICMPv6MulticastListenerQuery,
- includeRouterAlertOption: true,
- routerAlertValue: header.IPv6RouterAlertMLD,
- srcAddr: linkLocalAddr2,
- hopLimit: header.MLDHopLimit,
- expectValidMLD: true,
- getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.ICMP.V6.PacketsReceived.MulticastListenerQuery.Value() },
- },
- {
- name: "bad hop limit",
- messageType: header.ICMPv6MulticastListenerReport,
- includeRouterAlertOption: true,
- routerAlertValue: header.IPv6RouterAlertMLD,
- srcAddr: linkLocalAddr2,
- hopLimit: header.MLDHopLimit + 1,
- expectValidMLD: false,
- getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.ICMP.V6.PacketsReceived.MulticastListenerReport.Value() },
- },
- {
- name: "src ip not link local",
- messageType: header.ICMPv6MulticastListenerReport,
- includeRouterAlertOption: true,
- routerAlertValue: header.IPv6RouterAlertMLD,
- srcAddr: globalAddr,
- hopLimit: header.MLDHopLimit,
- expectValidMLD: false,
- getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.ICMP.V6.PacketsReceived.MulticastListenerReport.Value() },
- },
- {
- name: "missing router alert ip option",
- messageType: header.ICMPv6MulticastListenerDone,
- includeRouterAlertOption: false,
- srcAddr: linkLocalAddr2,
- hopLimit: header.MLDHopLimit,
- expectValidMLD: false,
- getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.ICMP.V6.PacketsReceived.MulticastListenerDone.Value() },
- },
- {
- name: "incorrect router alert value",
- messageType: header.ICMPv6MulticastListenerDone,
- includeRouterAlertOption: true,
- routerAlertValue: header.IPv6RouterAlertRSVP,
- srcAddr: linkLocalAddr2,
- hopLimit: header.MLDHopLimit,
- expectValidMLD: false,
- getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.ICMP.V6.PacketsReceived.MulticastListenerDone.Value() },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- MLD: ipv6.MLDOptions{
- Enabled: true,
- },
- })},
- })
- e := channel.New(nicID, header.IPv6MinimumMTU, "")
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
- stats := s.Stats()
- // Verify that every relevant stats is zero'd before we send a packet.
- if got := test.getMessageTypeStatValue(s.Stats()); got != 0 {
- t.Errorf("got test.getMessageTypeStatValue(s.Stats()) = %d, want = 0", got)
- }
- if got := stats.ICMP.V6.PacketsReceived.Invalid.Value(); got != 0 {
- t.Errorf("got stats.ICMP.V6.PacketsReceived.Invalid.Value() = %d, want = 0", got)
- }
- if got := stats.IP.PacketsDelivered.Value(); got != 0 {
- t.Fatalf("got stats.IP.PacketsDelivered.Value() = %d, want = 0", got)
- }
- createAndInjectMLDPacket(e, test.messageType, test.hopLimit, test.srcAddr, test.includeRouterAlertOption, test.routerAlertValue)
- // We always expect the packet to pass IP validation.
- if got := stats.IP.PacketsDelivered.Value(); got != 1 {
- t.Fatalf("got stats.IP.PacketsDelivered.Value() = %d, want = 1", got)
- }
- // Even when the MLD-specific validation checks fail, we expect the
- // corresponding MLD counter to be incremented.
- if got := test.getMessageTypeStatValue(s.Stats()); got != 1 {
- t.Errorf("got test.getMessageTypeStatValue(s.Stats()) = %d, want = 1", got)
- }
- var expectedInvalidCount uint64
- if !test.expectValidMLD {
- expectedInvalidCount = 1
- }
- if got := stats.ICMP.V6.PacketsReceived.Invalid.Value(); got != expectedInvalidCount {
- t.Errorf("got stats.ICMP.V6.PacketsReceived.Invalid.Value() = %d, want = %d", got, expectedInvalidCount)
- }
- })
- }
-}
-
-func TestMLDSkipProtocol(t *testing.T) {
- const nicID = 1
-
- tests := []struct {
- name string
- group tcpip.Address
- expectReport bool
- }{
- {
- name: "Reserverd0",
- group: "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
- expectReport: false,
- },
- {
- name: "Interface Local",
- group: "\xff\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
- expectReport: false,
- },
- {
- name: "Link Local",
- group: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
- expectReport: true,
- },
- {
- name: "Realm Local",
- group: "\xff\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
- expectReport: true,
- },
- {
- name: "Admin Local",
- group: "\xff\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
- expectReport: true,
- },
- {
- name: "Site Local",
- group: "\xff\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
- expectReport: true,
- },
- {
- name: "Unassigned(6)",
- group: "\xff\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
- expectReport: true,
- },
- {
- name: "Unassigned(7)",
- group: "\xff\x07\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
- expectReport: true,
- },
- {
- name: "Organization Local",
- group: "\xff\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
- expectReport: true,
- },
- {
- name: "Unassigned(9)",
- group: "\xff\x09\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
- expectReport: true,
- },
- {
- name: "Unassigned(A)",
- group: "\xff\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
- expectReport: true,
- },
- {
- name: "Unassigned(B)",
- group: "\xff\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
- expectReport: true,
- },
- {
- name: "Unassigned(C)",
- group: "\xff\x0c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
- expectReport: true,
- },
- {
- name: "Unassigned(D)",
- group: "\xff\x0d\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
- expectReport: true,
- },
- {
- name: "Global",
- group: "\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
- expectReport: true,
- },
- {
- name: "ReservedF",
- group: "\xff\x0f\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
- expectReport: true,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- MLD: ipv6.MLDOptions{
- Enabled: true,
- },
- })},
- })
- e := channel.New(1, header.IPv6MinimumMTU, "")
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: linkLocalAddr.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
- if p, ok := e.Read(); !ok {
- t.Fatal("expected a report message to be sent")
- } else {
- validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC)
- }
-
- if err := s.JoinGroup(ipv6.ProtocolNumber, nicID, test.group); err != nil {
- t.Fatalf("s.JoinGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, test.group, err)
- }
- if isInGroup, err := s.IsInGroup(nicID, test.group); err != nil {
- t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.group, err)
- } else if !isInGroup {
- t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, test.group)
- }
-
- if !test.expectReport {
- if p, ok := e.Read(); ok {
- t.Fatalf("got e.Read() = (%#v, true), want = (_, false)", p)
- }
-
- return
- }
-
- if p, ok := e.Read(); !ok {
- t.Fatal("expected a report message to be sent")
- } else {
- validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, test.group, header.ICMPv6MulticastListenerReport, test.group)
- }
- })
- }
-}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
deleted file mode 100644
index 8297a7e10..000000000
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ /dev/null
@@ -1,1365 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package ipv6
-
-import (
- "bytes"
- "math/rand"
- "strings"
- "testing"
- "time"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
-)
-
-var _ NDPDispatcher = (*testNDPDispatcher)(nil)
-
-// testNDPDispatcher is an NDPDispatcher only allows default router discovery.
-type testNDPDispatcher struct {
- addr tcpip.Address
-}
-
-func (*testNDPDispatcher) OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) {
-}
-
-func (t *testNDPDispatcher) OnOffLinkRouteUpdated(_ tcpip.NICID, _ tcpip.Subnet, addr tcpip.Address, _ header.NDPRoutePreference) {
- t.addr = addr
-}
-
-func (t *testNDPDispatcher) OnOffLinkRouteInvalidated(_ tcpip.NICID, _ tcpip.Subnet, addr tcpip.Address) {
- t.addr = addr
-}
-
-func (*testNDPDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) {
-}
-
-func (*testNDPDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) {
-}
-
-func (*testNDPDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) {
-}
-
-func (*testNDPDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) {
-}
-
-func (*testNDPDispatcher) OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) {
-}
-
-func (*testNDPDispatcher) OnRecursiveDNSServerOption(tcpip.NICID, []tcpip.Address, time.Duration) {
-}
-
-func (*testNDPDispatcher) OnDNSSearchListOption(tcpip.NICID, []string, time.Duration) {
-}
-
-func (*testNDPDispatcher) OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA) {
-}
-
-func TestStackNDPEndpointInvalidateDefaultRouter(t *testing.T) {
- var ndpDisp testNDPDispatcher
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocolWithOptions(Options{
- NDPDisp: &ndpDisp,
- })},
- })
-
- if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
- t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
- }
-
- ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber)
- if err != nil {
- t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, ProtocolNumber, err)
- }
-
- ipv6EP := ep.(*endpoint)
- ipv6EP.mu.Lock()
- ipv6EP.mu.ndp.handleOffLinkRouteDiscovery(offLinkRoute{dest: header.IPv6EmptySubnet, router: lladdr1}, time.Hour, header.MediumRoutePreference)
- ipv6EP.mu.Unlock()
-
- if ndpDisp.addr != lladdr1 {
- t.Fatalf("got ndpDisp.addr = %s, want = %s", ndpDisp.addr, lladdr1)
- }
-
- ndpDisp.addr = ""
- ndpEP := ep.(stack.NDPEndpoint)
- ndpEP.InvalidateDefaultRouter(lladdr1)
- if ndpDisp.addr != lladdr1 {
- t.Fatalf("got ndpDisp.addr = %s, want = %s", ndpDisp.addr, lladdr1)
- }
-}
-
-// TestNeighborSolicitationWithSourceLinkLayerOption tests that receiving a
-// valid NDP NS message with the Source Link Layer Address option results in a
-// new entry in the link address cache for the sender of the message.
-func TestNeighborSolicitationWithSourceLinkLayerOption(t *testing.T) {
- const nicID = 1
-
- tests := []struct {
- name string
- optsBuf []byte
- expectedLinkAddr tcpip.LinkAddress
- }{
- {
- name: "Valid",
- optsBuf: []byte{1, 1, 2, 3, 4, 5, 6, 7},
- expectedLinkAddr: "\x02\x03\x04\x05\x06\x07",
- },
- {
- name: "Too Small",
- optsBuf: []byte{1, 1, 2, 3, 4, 5, 6},
- },
- {
- name: "Invalid Length",
- optsBuf: []byte{1, 2, 2, 3, 4, 5, 6, 7},
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- })
- e := channel.New(0, 1280, linkAddr0)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ProtocolNumber,
- AddressWithPrefix: lladdr0.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
-
- ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + len(test.optsBuf)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
- pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
- pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.MessageBody())
- ns.SetTargetAddress(lladdr0)
- opts := ns.Options()
- copy(opts, test.optsBuf)
- pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: pkt,
- Src: lladdr1,
- Dst: lladdr0,
- }))
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- TransportProtocol: header.ICMPv6ProtocolNumber,
- HopLimit: 255,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
- })
-
- invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
-
- // Invalid count should initially be 0.
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
-
- e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
-
- neighbors, err := s.Neighbors(nicID, ProtocolNumber)
- if err != nil {
- t.Fatalf("s.Neighbors(%d, %d): %s", nicID, ProtocolNumber, err)
- }
-
- neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry)
- for _, n := range neighbors {
- if existing, ok := neighborByAddr[n.Addr]; ok {
- if diff := cmp.Diff(existing, n); diff != "" {
- t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, ProtocolNumber, diff)
- }
- t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry: %#v", nicID, ProtocolNumber, existing)
- }
- neighborByAddr[n.Addr] = n
- }
-
- if neigh, ok := neighborByAddr[lladdr1]; len(test.expectedLinkAddr) != 0 {
- // Invalid count should not have increased.
- if got := invalid.Value(); got != 0 {
- t.Errorf("got invalid = %d, want = 0", got)
- }
-
- if !ok {
- t.Fatalf("expected a neighbor entry for %q", lladdr1)
- }
- if neigh.LinkAddr != test.expectedLinkAddr {
- t.Errorf("got link address = %s, want = %s", neigh.LinkAddr, test.expectedLinkAddr)
- }
- if neigh.State != stack.Stale {
- t.Errorf("got NUD state = %s, want = %s", neigh.State, stack.Stale)
- }
- } else {
- // Invalid count should have increased.
- if got := invalid.Value(); got != 1 {
- t.Errorf("got invalid = %d, want = 1", got)
- }
-
- if ok {
- t.Fatalf("unexpectedly got neighbor entry: %#v", neigh)
- }
- }
- })
- }
-}
-
-func TestNeighborSolicitationResponse(t *testing.T) {
- const nicID = 1
- nicAddr := lladdr0
- remoteAddr := lladdr1
- nicAddrSNMC := header.SolicitedNodeAddr(nicAddr)
- nicLinkAddr := linkAddr0
- remoteLinkAddr0 := linkAddr1
- remoteLinkAddr1 := linkAddr2
-
- tests := []struct {
- name string
- nsOpts header.NDPOptionsSerializer
- nsSrcLinkAddr tcpip.LinkAddress
- nsSrc tcpip.Address
- nsDst tcpip.Address
- nsInvalid bool
- naDstLinkAddr tcpip.LinkAddress
- naSolicited bool
- naSrc tcpip.Address
- naDst tcpip.Address
- performsLinkResolution bool
- }{
- {
- name: "Unspecified source to solicited-node multicast destination",
- nsOpts: nil,
- nsSrcLinkAddr: remoteLinkAddr0,
- nsSrc: header.IPv6Any,
- nsDst: nicAddrSNMC,
- nsInvalid: false,
- naDstLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllNodesMulticastAddress),
- naSolicited: false,
- naSrc: nicAddr,
- naDst: header.IPv6AllNodesMulticastAddress,
- },
- {
- name: "Unspecified source with source ll option to multicast destination",
- nsOpts: header.NDPOptionsSerializer{
- header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
- },
- nsSrcLinkAddr: remoteLinkAddr0,
- nsSrc: header.IPv6Any,
- nsDst: nicAddrSNMC,
- nsInvalid: true,
- },
- {
- name: "Unspecified source to unicast destination",
- nsOpts: nil,
- nsSrcLinkAddr: remoteLinkAddr0,
- nsSrc: header.IPv6Any,
- nsDst: nicAddr,
- nsInvalid: true,
- },
- {
- name: "Unspecified source with source ll option to unicast destination",
- nsOpts: header.NDPOptionsSerializer{
- header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
- },
- nsSrcLinkAddr: remoteLinkAddr0,
- nsSrc: header.IPv6Any,
- nsDst: nicAddr,
- nsInvalid: true,
- },
- {
- name: "Specified source with 1 source ll to multicast destination",
- nsOpts: header.NDPOptionsSerializer{
- header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
- },
- nsSrcLinkAddr: remoteLinkAddr0,
- nsSrc: remoteAddr,
- nsDst: nicAddrSNMC,
- nsInvalid: false,
- naDstLinkAddr: remoteLinkAddr0,
- naSolicited: true,
- naSrc: nicAddr,
- naDst: remoteAddr,
- },
- {
- name: "Specified source with 1 source ll different from route to multicast destination",
- nsOpts: header.NDPOptionsSerializer{
- header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]),
- },
- nsSrcLinkAddr: remoteLinkAddr0,
- nsSrc: remoteAddr,
- nsDst: nicAddrSNMC,
- nsInvalid: false,
- naDstLinkAddr: remoteLinkAddr1,
- naSolicited: true,
- naSrc: nicAddr,
- naDst: remoteAddr,
- },
- {
- name: "Specified source to multicast destination",
- nsOpts: nil,
- nsSrcLinkAddr: remoteLinkAddr0,
- nsSrc: remoteAddr,
- nsDst: nicAddrSNMC,
- nsInvalid: true,
- },
- {
- name: "Specified source with 2 source ll to multicast destination",
- nsOpts: header.NDPOptionsSerializer{
- header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
- header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]),
- },
- nsSrcLinkAddr: remoteLinkAddr0,
- nsSrc: remoteAddr,
- nsDst: nicAddrSNMC,
- nsInvalid: true,
- },
-
- {
- name: "Specified source to unicast destination",
- nsOpts: nil,
- nsSrcLinkAddr: remoteLinkAddr0,
- nsSrc: remoteAddr,
- nsDst: nicAddr,
- nsInvalid: false,
- naDstLinkAddr: remoteLinkAddr0,
- naSolicited: true,
- naSrc: nicAddr,
- naDst: remoteAddr,
- // Since we send a unicast solicitations to a node without an entry for
- // the remote, the node needs to perform neighbor discovery to get the
- // remote's link address to send the advertisement response.
- performsLinkResolution: true,
- },
- {
- name: "Specified source with 1 source ll to unicast destination",
- nsOpts: header.NDPOptionsSerializer{
- header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
- },
- nsSrcLinkAddr: remoteLinkAddr0,
- nsSrc: remoteAddr,
- nsDst: nicAddr,
- nsInvalid: false,
- naDstLinkAddr: remoteLinkAddr0,
- naSolicited: true,
- naSrc: nicAddr,
- naDst: remoteAddr,
- },
- {
- name: "Specified source with 1 source ll different from route to unicast destination",
- nsOpts: header.NDPOptionsSerializer{
- header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]),
- },
- nsSrcLinkAddr: remoteLinkAddr0,
- nsSrc: remoteAddr,
- nsDst: nicAddr,
- nsInvalid: false,
- naDstLinkAddr: remoteLinkAddr1,
- naSolicited: true,
- naSrc: nicAddr,
- naDst: remoteAddr,
- },
- {
- name: "Specified source with 2 source ll to unicast destination",
- nsOpts: header.NDPOptionsSerializer{
- header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
- header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]),
- },
- nsSrcLinkAddr: remoteLinkAddr0,
- nsSrc: remoteAddr,
- nsDst: nicAddr,
- nsInvalid: true,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- Clock: clock,
- })
- e := channel.New(1, 1280, nicLinkAddr)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ProtocolNumber,
- AddressWithPrefix: nicAddr.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv6EmptySubnet,
- NIC: 1,
- },
- })
-
- ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length()
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
- pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
- pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.MessageBody())
- ns.SetTargetAddress(nicAddr)
- opts := ns.Options()
- opts.Serialize(test.nsOpts)
- pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: pkt,
- Src: test.nsSrc,
- Dst: test.nsDst,
- }))
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- TransportProtocol: header.ICMPv6ProtocolNumber,
- HopLimit: 255,
- SrcAddr: test.nsSrc,
- DstAddr: test.nsDst,
- })
-
- invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
-
- // Invalid count should initially be 0.
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
-
- e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
-
- if test.nsInvalid {
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
-
- if p, got := e.Read(); got {
- t.Fatalf("unexpected response to an invalid NS = %+v", p.Pkt)
- }
-
- // If we expected the NS to be invalid, we have nothing else to check.
- return
- }
-
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
-
- if test.performsLinkResolution {
- clock.RunImmediatelyScheduledJobs()
- p, got := e.Read()
- if !got {
- t.Fatal("expected an NDP NS response")
- }
-
- respNSDst := header.SolicitedNodeAddr(test.nsSrc)
- var want stack.RouteInfo
- want.NetProto = ProtocolNumber
- want.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(respNSDst)
- if diff := cmp.Diff(want, p.Route, cmp.AllowUnexported(want)); diff != "" {
- t.Errorf("route info mismatch (-want +got):\n%s", diff)
- }
-
- checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
- checker.SrcAddr(nicAddr),
- checker.DstAddr(respNSDst),
- checker.TTL(header.NDPHopLimit),
- checker.NDPNS(
- checker.NDPNSTargetAddress(test.nsSrc),
- checker.NDPNSOptions([]header.NDPOption{
- header.NDPSourceLinkLayerAddressOption(nicLinkAddr),
- }),
- ))
-
- ser := header.NDPOptionsSerializer{
- header.NDPTargetLinkLayerAddressOption(linkAddr1),
- }
- ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + ser.Length()
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
- pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
- pkt.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(pkt.MessageBody())
- na.SetSolicitedFlag(true)
- na.SetOverrideFlag(true)
- na.SetTargetAddress(test.nsSrc)
- na.Options().Serialize(ser)
- pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: pkt,
- Src: test.nsSrc,
- Dst: nicAddr,
- }))
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- TransportProtocol: header.ICMPv6ProtocolNumber,
- HopLimit: header.NDPHopLimit,
- SrcAddr: test.nsSrc,
- DstAddr: nicAddr,
- })
- e.InjectLinkAddr(ProtocolNumber, "", stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
- }
-
- clock.RunImmediatelyScheduledJobs()
- p, got := e.Read()
- if !got {
- t.Fatal("expected an NDP NA response")
- }
-
- if p.Route.LocalAddress != test.naSrc {
- t.Errorf("got p.Route.LocalAddress = %s, want = %s", p.Route.LocalAddress, test.naSrc)
- }
- if p.Route.LocalLinkAddress != nicLinkAddr {
- t.Errorf("p.Route.LocalLinkAddress = %s, want = %s", p.Route.LocalLinkAddress, nicLinkAddr)
- }
- if p.Route.RemoteAddress != test.naDst {
- t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, test.naDst)
- }
- if p.Route.RemoteLinkAddress != test.naDstLinkAddr {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr)
- }
-
- checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
- checker.SrcAddr(test.naSrc),
- checker.DstAddr(test.naDst),
- checker.TTL(header.NDPHopLimit),
- checker.NDPNA(
- checker.NDPNASolicitedFlag(test.naSolicited),
- checker.NDPNATargetAddress(nicAddr),
- checker.NDPNAOptions([]header.NDPOption{
- header.NDPTargetLinkLayerAddressOption(nicLinkAddr[:]),
- }),
- ))
- })
- }
-}
-
-// TestNeighborAdvertisementWithTargetLinkLayerOption tests that receiving a
-// valid NDP NA message with the Target Link Layer Address option does not
-// result in a new entry in the neighbor cache for the target of the message.
-func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) {
- const nicID = 1
-
- tests := []struct {
- name string
- optsBuf []byte
- isValid bool
- }{
- {
- name: "Valid",
- optsBuf: []byte{2, 1, 2, 3, 4, 5, 6, 7},
- isValid: true,
- },
- {
- name: "Too Small",
- optsBuf: []byte{2, 1, 2, 3, 4, 5, 6},
- },
- {
- name: "Invalid Length",
- optsBuf: []byte{2, 2, 2, 3, 4, 5, 6, 7},
- },
- {
- name: "Multiple",
- optsBuf: []byte{
- 2, 1, 2, 3, 4, 5, 6, 7,
- 2, 1, 2, 3, 4, 5, 6, 8,
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- })
- e := channel.New(0, 1280, linkAddr0)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ProtocolNumber,
- AddressWithPrefix: lladdr0.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
-
- ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + len(test.optsBuf)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
- pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
- pkt.SetType(header.ICMPv6NeighborAdvert)
- ns := header.NDPNeighborAdvert(pkt.MessageBody())
- ns.SetTargetAddress(lladdr1)
- opts := ns.Options()
- copy(opts, test.optsBuf)
- pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: pkt,
- Src: lladdr1,
- Dst: lladdr0,
- }))
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- TransportProtocol: header.ICMPv6ProtocolNumber,
- HopLimit: 255,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
- })
-
- invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
-
- // Invalid count should initially be 0.
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
-
- e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
-
- neighbors, err := s.Neighbors(nicID, ProtocolNumber)
- if err != nil {
- t.Fatalf("s.Neighbors(%d, %d): %s", nicID, ProtocolNumber, err)
- }
-
- neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry)
- for _, n := range neighbors {
- if existing, ok := neighborByAddr[n.Addr]; ok {
- if diff := cmp.Diff(existing, n); diff != "" {
- t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, ProtocolNumber, diff)
- }
- t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry: %#v", nicID, ProtocolNumber, existing)
- }
- neighborByAddr[n.Addr] = n
- }
-
- if neigh, ok := neighborByAddr[lladdr1]; ok {
- t.Fatalf("unexpectedly got neighbor entry: %#v", neigh)
- }
-
- if test.isValid {
- // Invalid count should not have increased.
- if got := invalid.Value(); got != 0 {
- t.Errorf("got invalid = %d, want = 0", got)
- }
- } else {
- // Invalid count should have increased.
- if got := invalid.Value(); got != 1 {
- t.Errorf("got invalid = %d, want = 1", got)
- }
- }
- })
- }
-}
-
-func TestNDPValidation(t *testing.T) {
- const nicID = 1
-
- handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) {
- var extHdrs header.IPv6ExtHdrSerializer
- if atomicFragment {
- extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{})
- }
- extHdrsLen := extHdrs.Length()
-
- ip := buffer.NewView(header.IPv6MinimumSize + extHdrsLen)
- header.IPv6(ip).Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(payload) + extHdrsLen),
- TransportProtocol: header.ICMPv6ProtocolNumber,
- HopLimit: hopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
- ExtensionHeaders: extHdrs,
- })
- vv := ip.ToVectorisedView()
- vv.AppendView(payload)
- ep.HandlePacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: vv,
- }))
- }
-
- var tllData [header.NDPLinkLayerAddressSize]byte
- header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
- header.NDPTargetLinkLayerAddressOption(linkAddr1),
- })
-
- var sllData [header.NDPLinkLayerAddressSize]byte
- header.NDPOptions(sllData[:]).Serialize(header.NDPOptionsSerializer{
- header.NDPSourceLinkLayerAddressOption(linkAddr1),
- })
-
- types := []struct {
- name string
- typ header.ICMPv6Type
- size int
- extraData []byte
- statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
- routerOnly bool
- }{
- {
- name: "RouterSolicit",
- typ: header.ICMPv6RouterSolicit,
- size: header.ICMPv6MinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RouterSolicit
- },
- routerOnly: true,
- },
- {
- name: "RouterAdvert",
- typ: header.ICMPv6RouterAdvert,
- size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RouterAdvert
- },
- },
- {
- name: "NeighborSolicit",
- typ: header.ICMPv6NeighborSolicit,
- size: header.ICMPv6NeighborSolicitMinimumSize,
- extraData: sllData[:],
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.NeighborSolicit
- },
- },
- {
- name: "NeighborAdvert",
- typ: header.ICMPv6NeighborAdvert,
- size: header.ICMPv6NeighborAdvertMinimumSize,
- extraData: tllData[:],
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.NeighborAdvert
- },
- },
- {
- name: "RedirectMsg",
- typ: header.ICMPv6RedirectMsg,
- size: header.ICMPv6MinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RedirectMsg
- },
- },
- }
-
- subTests := []struct {
- name string
- atomicFragment bool
- hopLimit uint8
- code header.ICMPv6Code
- valid bool
- }{
- {
- name: "Valid",
- atomicFragment: false,
- hopLimit: header.NDPHopLimit,
- code: 0,
- valid: true,
- },
- {
- name: "Fragmented",
- atomicFragment: true,
- hopLimit: header.NDPHopLimit,
- code: 0,
- valid: false,
- },
- {
- name: "Invalid hop limit",
- atomicFragment: false,
- hopLimit: header.NDPHopLimit - 1,
- code: 0,
- valid: false,
- },
- {
- name: "Invalid ICMPv6 code",
- atomicFragment: false,
- hopLimit: header.NDPHopLimit,
- code: 1,
- valid: false,
- },
- }
-
- subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr0))))
- if err != nil {
- t.Fatal(err)
- }
-
- for _, typ := range types {
- for _, isRouter := range []bool{false, true} {
- name := typ.name
- if isRouter {
- name += " (Router)"
- }
-
- t.Run(name, func(t *testing.T) {
- for _, test := range subTests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
- })
-
- if isRouter {
- if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil {
- t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ProtocolNumber, err)
- }
- }
-
- if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
-
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ProtocolNumber,
- AddressWithPrefix: lladdr0.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
-
- ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber)
- if err != nil {
- t.Fatal("cannot find network endpoint instance for IPv6")
- }
-
- s.SetRouteTable([]tcpip.Route{{
- Destination: subnet,
- NIC: nicID,
- }})
-
- stats := s.Stats().ICMP.V6.PacketsReceived
- invalid := stats.Invalid
- routerOnly := stats.RouterOnlyPacketsDroppedByHost
- typStat := typ.statCounter(stats)
-
- icmpH := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
- copy(icmpH[typ.size:], typ.extraData)
- icmpH.SetType(typ.typ)
- icmpH.SetCode(test.code)
- icmpH.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: icmpH[:typ.size],
- Src: lladdr0,
- Dst: lladdr1,
- PayloadCsum: header.Checksum(typ.extraData /* initial */, 0),
- PayloadLen: len(typ.extraData),
- }))
-
- // Rx count of the NDP message should initially be 0.
- if got := typStat.Value(); got != 0 {
- t.Errorf("got %s = %d, want = 0", typ.name, got)
- }
-
- // Invalid count should initially be 0.
- if got := invalid.Value(); got != 0 {
- t.Errorf("got invalid.Value() = %d, want = 0", got)
- }
-
- // Should initially not have dropped any packets.
- if got := routerOnly.Value(); got != 0 {
- t.Errorf("got routerOnly.Value() = %d, want = 0", got)
- }
-
- if t.Failed() {
- t.FailNow()
- }
-
- handleIPv6Payload(buffer.View(icmpH), test.hopLimit, test.atomicFragment, ep)
-
- // Rx count of the NDP packet should have increased.
- if got := typStat.Value(); got != 1 {
- t.Errorf("got %s = %d, want = 1", typ.name, got)
- }
-
- want := uint64(0)
- if !test.valid {
- // Invalid count should have increased.
- want = 1
- }
- if got := invalid.Value(); got != want {
- t.Errorf("got invalid.Value() = %d, want = %d", got, want)
- }
-
- want = 0
- if test.valid && !isRouter && typ.routerOnly {
- // Router only packets are expected to be dropped when operating
- // as a host.
- want = 1
- }
- if got := routerOnly.Value(); got != want {
- t.Errorf("got routerOnly.Value() = %d, want = %d", got, want)
- }
- })
- }
- })
- }
- }
-}
-
-// TestNeighborAdvertisementValidation tests that the NIC validates received
-// Neighbor Advertisements.
-//
-// In particular, if the IP Destination Address is a multicast address, and the
-// Solicited flag is not zero, the Neighbor Advertisement is invalid and should
-// be discarded.
-func TestNeighborAdvertisementValidation(t *testing.T) {
- tests := []struct {
- name string
- ipDstAddr tcpip.Address
- solicitedFlag bool
- valid bool
- }{
- {
- name: "Multicast IP destination address with Solicited flag set",
- ipDstAddr: header.IPv6AllNodesMulticastAddress,
- solicitedFlag: true,
- valid: false,
- },
- {
- name: "Multicast IP destination address with Solicited flag unset",
- ipDstAddr: header.IPv6AllNodesMulticastAddress,
- solicitedFlag: false,
- valid: true,
- },
- {
- name: "Unicast IP destination address with Solicited flag set",
- ipDstAddr: lladdr0,
- solicitedFlag: true,
- valid: true,
- },
- {
- name: "Unicast IP destination address with Solicited flag unset",
- ipDstAddr: lladdr0,
- solicitedFlag: false,
- valid: true,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- })
- e := channel.New(0, header.IPv6MinimumMTU, linkAddr0)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ProtocolNumber,
- AddressWithPrefix: lladdr0.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
-
- ndpNASize := header.ICMPv6NeighborAdvertMinimumSize
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
- pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
- pkt.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(pkt.MessageBody())
- na.SetTargetAddress(lladdr1)
- na.SetSolicitedFlag(test.solicitedFlag)
- pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: pkt,
- Src: lladdr1,
- Dst: test.ipDstAddr,
- }))
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- TransportProtocol: header.ICMPv6ProtocolNumber,
- HopLimit: 255,
- SrcAddr: lladdr1,
- DstAddr: test.ipDstAddr,
- })
-
- stats := s.Stats().ICMP.V6.PacketsReceived
- invalid := stats.Invalid
- rxNA := stats.NeighborAdvert
-
- if got := rxNA.Value(); got != 0 {
- t.Fatalf("got rxNA = %d, want = 0", got)
- }
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
-
- e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
-
- if got := rxNA.Value(); got != 1 {
- t.Fatalf("got rxNA = %d, want = 1", got)
- }
- var wantInvalid uint64 = 1
- if test.valid {
- wantInvalid = 0
- }
- if got := invalid.Value(); got != wantInvalid {
- t.Fatalf("got invalid = %d, want = %d", got, wantInvalid)
- }
- // As per RFC 4861 section 7.2.5:
- // When a valid Neighbor Advertisement is received ...
- // If no entry exists, the advertisement SHOULD be silently discarded.
- // There is no need to create an entry if none exists, since the
- // recipient has apparently not initiated any communication with the
- // target.
- if neighbors, err := s.Neighbors(nicID, ProtocolNumber); err != nil {
- t.Fatalf("s.Neighbors(%d, %d): %s", nicID, ProtocolNumber, err)
- } else if len(neighbors) != 0 {
- t.Fatalf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors)
- }
- })
- }
-}
-
-// TestRouterAdvertValidation tests that when the NIC is configured to handle
-// NDP Router Advertisement packets, it validates the Router Advertisement
-// properly before handling them.
-func TestRouterAdvertValidation(t *testing.T) {
- tests := []struct {
- name string
- src tcpip.Address
- hopLimit uint8
- code header.ICMPv6Code
- ndpPayload []byte
- expectedSuccess bool
- }{
- {
- "OK",
- lladdr0,
- 255,
- 0,
- []byte{
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- },
- true,
- },
- {
- "NonLinkLocalSourceAddr",
- addr1,
- 255,
- 0,
- []byte{
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- },
- false,
- },
- {
- "HopLimitNot255",
- lladdr0,
- 254,
- 0,
- []byte{
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- },
- false,
- },
- {
- "NonZeroCode",
- lladdr0,
- 255,
- 1,
- []byte{
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- },
- false,
- },
- {
- "NDPPayloadTooSmall",
- lladdr0,
- 255,
- 0,
- []byte{
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- 0, 0, 0,
- },
- false,
- },
- {
- "OKWithOptions",
- lladdr0,
- 255,
- 0,
- []byte{
- // RA payload
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- 0, 0, 0, 0,
-
- // Option #1 (TargetLinkLayerAddress)
- 2, 1, 0, 0, 0, 0, 0, 0,
-
- // Option #2 (unrecognized)
- 255, 1, 0, 0, 0, 0, 0, 0,
-
- // Option #3 (PrefixInformation)
- 3, 4, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- },
- true,
- },
- {
- "OptionWithZeroLength",
- lladdr0,
- 255,
- 0,
- []byte{
- // RA payload
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- 0, 0, 0, 0,
-
- // Option #1 (TargetLinkLayerAddress)
- // Invalid as it has 0 length.
- 2, 0, 0, 0, 0, 0, 0, 0,
-
- // Option #2 (unrecognized)
- 255, 1, 0, 0, 0, 0, 0, 0,
-
- // Option #3 (PrefixInformation)
- 3, 4, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- },
- false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- e := channel.New(10, 1280, linkAddr1)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- })
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
- }
-
- icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
- pkt := header.ICMPv6(hdr.Prepend(icmpSize))
- pkt.SetType(header.ICMPv6RouterAdvert)
- pkt.SetCode(test.code)
- copy(pkt.MessageBody(), test.ndpPayload)
- payloadLength := hdr.UsedLength()
- pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: pkt,
- Src: test.src,
- Dst: header.IPv6AllNodesMulticastAddress,
- }))
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- TransportProtocol: icmp.ProtocolNumber6,
- HopLimit: test.hopLimit,
- SrcAddr: test.src,
- DstAddr: header.IPv6AllNodesMulticastAddress,
- })
-
- stats := s.Stats().ICMP.V6.PacketsReceived
- invalid := stats.Invalid
- rxRA := stats.RouterAdvert
-
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
- if got := rxRA.Value(); got != 0 {
- t.Fatalf("got rxRA = %d, want = 0", got)
- }
-
- e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
-
- if got := rxRA.Value(); got != 1 {
- t.Fatalf("got rxRA = %d, want = 1", got)
- }
-
- if test.expectedSuccess {
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
- } else {
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
- }
- })
- }
-}
-
-// TestCheckDuplicateAddress checks that calls to CheckDuplicateAddress and DAD
-// performed when adding new addresses do not interfere with each other.
-func TestCheckDuplicateAddress(t *testing.T) {
- const nicID = 1
-
- clock := faketime.NewManualClock()
- dadConfigs := stack.DADConfigurations{
- DupAddrDetectTransmits: 1,
- RetransmitTimer: time.Second,
- }
-
- nonces := [...][]byte{
- {1, 2, 3, 4, 5, 6},
- {7, 8, 9, 10, 11, 12},
- }
-
- var secureRNGBytes []byte
- for _, n := range nonces {
- secureRNGBytes = append(secureRNGBytes, n...)
- }
- var secureRNG bytes.Reader
- secureRNG.Reset(secureRNGBytes[:])
- s := stack.New(stack.Options{
- Clock: clock,
- RandSource: rand.NewSource(time.Now().UnixNano()),
- SecureRNG: &secureRNG,
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocolWithOptions(Options{
- DADConfigs: dadConfigs,
- })},
- })
- // This test is expected to send at max 2 DAD messages. We allow an extra
- // packet to be stored to catch unexpected packets.
- e := channel.New(3, header.IPv6MinimumMTU, linkAddr0)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- dadPacketsSent := 0
- snmc := header.SolicitedNodeAddr(lladdr0)
- remoteLinkAddr := header.EthernetAddressFromMulticastIPv6Address(snmc)
- checkDADMsg := func() {
- clock.RunImmediatelyScheduledJobs()
- p, ok := e.Read()
- if !ok {
- t.Fatalf("expected %d-th DAD message", dadPacketsSent)
- }
-
- if p.Proto != header.IPv6ProtocolNumber {
- t.Errorf("(i=%d) got p.Proto = %d, want = %d", dadPacketsSent, p.Proto, header.IPv6ProtocolNumber)
- }
-
- if p.Route.RemoteLinkAddress != remoteLinkAddr {
- t.Errorf("(i=%d) got p.Route.RemoteLinkAddress = %s, want = %s", dadPacketsSent, p.Route.RemoteLinkAddress, remoteLinkAddr)
- }
-
- checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
- checker.SrcAddr(header.IPv6Any),
- checker.DstAddr(snmc),
- checker.TTL(header.NDPHopLimit),
- checker.NDPNS(
- checker.NDPNSTargetAddress(lladdr0),
- checker.NDPNSOptions([]header.NDPOption{header.NDPNonceOption(nonces[dadPacketsSent])}),
- ))
- }
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ProtocolNumber,
- AddressWithPrefix: lladdr0.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
- checkDADMsg()
-
- // Start DAD for the address we just added.
- //
- // Even though the stack will perform DAD before the added address transitions
- // from tentative to assigned, this DAD request should be independent of that.
- ch := make(chan stack.DADResult, 3)
- dadRequestsMade := 1
- dadPacketsSent++
- if res, err := s.CheckDuplicateAddress(nicID, ProtocolNumber, lladdr0, func(r stack.DADResult) {
- ch <- r
- }); err != nil {
- t.Fatalf("s.CheckDuplicateAddress(%d, %d, %s, _): %s", nicID, ProtocolNumber, lladdr0, err)
- } else if res != stack.DADStarting {
- t.Fatalf("got s.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", nicID, ProtocolNumber, lladdr0, res, stack.DADStarting)
- }
- checkDADMsg()
-
- // Remove the address and make sure our DAD request was not stopped.
- if err := s.RemoveAddress(nicID, lladdr0); err != nil {
- t.Fatalf("RemoveAddress(%d, %s): %s", nicID, lladdr0, err)
- }
- // Should not restart DAD since we already requested DAD above - the handler
- // should be called when the original request completes so we should not send
- // an extra DAD message here.
- dadRequestsMade++
- if res, err := s.CheckDuplicateAddress(nicID, ProtocolNumber, lladdr0, func(r stack.DADResult) {
- ch <- r
- }); err != nil {
- t.Fatalf("s.CheckDuplicateAddress(%d, %d, %s, _): %s", nicID, ProtocolNumber, lladdr0, err)
- } else if res != stack.DADAlreadyRunning {
- t.Fatalf("got s.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", nicID, ProtocolNumber, lladdr0, res, stack.DADAlreadyRunning)
- }
-
- // Wait for DAD to complete.
- clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer)
- for i := 0; i < dadRequestsMade; i++ {
- if diff := cmp.Diff(&stack.DADSucceeded{}, <-ch); diff != "" {
- t.Errorf("(i=%d) DAD result mismatch (-want +got):\n%s", i, diff)
- }
- }
- // Should have no more results.
- select {
- case r := <-ch:
- t.Errorf("unexpectedly got an extra DAD result; r = %#v", r)
- default:
- }
-
- // Should have no more packets.
- if p, ok := e.Read(); ok {
- t.Errorf("got unexpected packet = %#v", p)
- }
-}
diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go
deleted file mode 100644
index 26640b7ee..000000000
--- a/pkg/tcpip/network/multicast_group_test.go
+++ /dev/null
@@ -1,1285 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package ip_test
-
-import (
- "fmt"
- "strings"
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
-)
-
-const (
- linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
-
- defaultIPv4PrefixLength = 24
-
- igmpMembershipQuery = uint8(header.IGMPMembershipQuery)
- igmpv1MembershipReport = uint8(header.IGMPv1MembershipReport)
- igmpv2MembershipReport = uint8(header.IGMPv2MembershipReport)
- igmpLeaveGroup = uint8(header.IGMPLeaveGroup)
- mldQuery = uint8(header.ICMPv6MulticastListenerQuery)
- mldReport = uint8(header.ICMPv6MulticastListenerReport)
- mldDone = uint8(header.ICMPv6MulticastListenerDone)
-
- maxUnsolicitedReports = 2
-)
-
-var (
- stackIPv4Addr = testutil.MustParse4("10.0.0.1")
- linkLocalIPv6Addr1 = testutil.MustParse6("fe80::1")
- linkLocalIPv6Addr2 = testutil.MustParse6("fe80::2")
-
- ipv4MulticastAddr1 = testutil.MustParse4("224.0.0.3")
- ipv4MulticastAddr2 = testutil.MustParse4("224.0.0.4")
- ipv4MulticastAddr3 = testutil.MustParse4("224.0.0.5")
- ipv6MulticastAddr1 = testutil.MustParse6("ff02::3")
- ipv6MulticastAddr2 = testutil.MustParse6("ff02::4")
- ipv6MulticastAddr3 = testutil.MustParse6("ff02::5")
-)
-
-var (
- // unsolicitedIGMPReportIntervalMaxTenthSec is the maximum amount of time the
- // NIC will wait before sending an unsolicited report after joining a
- // multicast group, in deciseconds.
- unsolicitedIGMPReportIntervalMaxTenthSec = func() uint8 {
- const decisecond = time.Second / 10
- if ipv4.UnsolicitedReportIntervalMax%decisecond != 0 {
- panic(fmt.Sprintf("UnsolicitedReportIntervalMax of %d is a lossy conversion to deciseconds", ipv4.UnsolicitedReportIntervalMax))
- }
- return uint8(ipv4.UnsolicitedReportIntervalMax / decisecond)
- }()
-
- ipv6AddrSNMC = header.SolicitedNodeAddr(linkLocalIPv6Addr1)
-)
-
-// validateMLDPacket checks that a passed PacketInfo is an IPv6 MLD packet
-// sent to the provided address with the passed fields set.
-func validateMLDPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, mldType uint8, maxRespTime byte, groupAddress tcpip.Address) {
- t.Helper()
-
- payload := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader()))
- checker.IPv6WithExtHdr(t, payload,
- checker.IPv6ExtHdr(
- checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)),
- ),
- checker.SrcAddr(linkLocalIPv6Addr1),
- checker.DstAddr(remoteAddress),
- // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3.
- checker.TTL(1),
- checker.MLD(header.ICMPv6Type(mldType), header.MLDMinimumSize,
- checker.MLDMaxRespDelay(time.Duration(maxRespTime)*time.Millisecond),
- checker.MLDMulticastAddress(groupAddress),
- ),
- )
-}
-
-// validateIGMPPacket checks that a passed PacketInfo is an IPv4 IGMP packet
-// sent to the provided address with the passed fields set.
-func validateIGMPPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, igmpType uint8, maxRespTime byte, groupAddress tcpip.Address) {
- t.Helper()
-
- payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader()))
- checker.IPv4(t, payload,
- checker.SrcAddr(stackIPv4Addr),
- checker.DstAddr(remoteAddress),
- // TTL for an IGMP message must be 1 as per RFC 2236 section 2.
- checker.TTL(1),
- checker.IPv4RouterAlert(),
- checker.IGMP(
- checker.IGMPType(header.IGMPType(igmpType)),
- checker.IGMPMaxRespTime(header.DecisecondToDuration(maxRespTime)),
- checker.IGMPGroupAddress(groupAddress),
- ),
- )
-}
-
-func createStack(t *testing.T, v4, mgpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) {
- t.Helper()
-
- e := channel.New(maxUnsolicitedReports, header.IPv6MinimumMTU, linkAddr)
- s, clock := createStackWithLinkEndpoint(t, v4, mgpEnabled, e)
- return e, s, clock
-}
-
-func createStackWithLinkEndpoint(t *testing.T, v4, mgpEnabled bool, e stack.LinkEndpoint) (*stack.Stack, *faketime.ManualClock) {
- t.Helper()
-
- igmpEnabled := v4 && mgpEnabled
- mldEnabled := !v4 && mgpEnabled
-
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{
- ipv4.NewProtocolWithOptions(ipv4.Options{
- IGMP: ipv4.IGMPOptions{
- Enabled: igmpEnabled,
- },
- }),
- ipv6.NewProtocolWithOptions(ipv6.Options{
- MLD: ipv6.MLDOptions{
- Enabled: mldEnabled,
- },
- }),
- },
- Clock: clock,
- })
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- addr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: stackIPv4Addr,
- PrefixLen: defaultIPv4PrefixLength,
- },
- }
- if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err)
- }
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: linkLocalIPv6Addr1.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
-
- return s, clock
-}
-
-// checkInitialIPv6Groups checks the initial IPv6 groups that a NIC will join
-// when it is created with an IPv6 address.
-//
-// To not interfere with tests, checkInitialIPv6Groups will leave the added
-// address's solicited node multicast group so that the tests can all assume
-// the NIC has not joined any IPv6 groups.
-func checkInitialIPv6Groups(t *testing.T, e *channel.Endpoint, s *stack.Stack, clock *faketime.ManualClock) (reportCounter uint64, leaveCounter uint64) {
- t.Helper()
-
- stats := s.Stats().ICMP.V6.PacketsSent
-
- reportCounter++
- if got := stats.MulticastListenerReport.Value(); got != reportCounter {
- t.Errorf("got stats.MulticastListenerReport.Value() = %d, want = %d", got, reportCounter)
- }
- if p, ok := e.Read(); !ok {
- t.Fatal("expected a report message to be sent")
- } else {
- validateMLDPacket(t, p, ipv6AddrSNMC, mldReport, 0, ipv6AddrSNMC)
- }
-
- // Leave the group to not affect the tests. This is fine since we are not
- // testing DAD or the solicited node address specifically.
- if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, ipv6AddrSNMC); err != nil {
- t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, ipv6AddrSNMC, err)
- }
- leaveCounter++
- if got := stats.MulticastListenerDone.Value(); got != leaveCounter {
- t.Errorf("got stats.MulticastListenerDone.Value() = %d, want = %d", got, leaveCounter)
- }
- if p, ok := e.Read(); !ok {
- t.Fatal("expected a report message to be sent")
- } else {
- validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, ipv6AddrSNMC)
- }
-
- // Should not send any more packets.
- clock.Advance(time.Hour)
- if p, ok := e.Read(); ok {
- t.Fatalf("sent unexpected packet = %#v", p)
- }
-
- return reportCounter, leaveCounter
-}
-
-// createAndInjectIGMPPacket creates and injects an IGMP packet with the
-// specified fields.
-func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType byte, maxRespTime byte, groupAddress tcpip.Address) {
- options := header.IPv4OptionsSerializer{
- &header.IPv4SerializableRouterAlertOption{},
- }
- buf := buffer.NewView(header.IPv4MinimumSize + int(options.Length()) + header.IGMPQueryMinimumSize)
- ip := header.IPv4(buf)
- ip.Encode(&header.IPv4Fields{
- TotalLength: uint16(len(buf)),
- TTL: header.IGMPTTL,
- Protocol: uint8(header.IGMPProtocolNumber),
- SrcAddr: remoteIPv4Addr,
- DstAddr: header.IPv4AllSystems,
- Options: options,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
-
- igmp := header.IGMP(ip.Payload())
- igmp.SetType(header.IGMPType(igmpType))
- igmp.SetMaxRespTime(maxRespTime)
- igmp.SetGroupAddress(groupAddress)
- igmp.SetChecksum(header.IGMPCalculateChecksum(igmp))
-
- e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-}
-
-// createAndInjectMLDPacket creates and injects an MLD packet with the
-// specified fields.
-func createAndInjectMLDPacket(e *channel.Endpoint, mldType uint8, maxRespDelay byte, groupAddress tcpip.Address) {
- extensionHeaders := header.IPv6ExtHdrSerializer{
- header.IPv6SerializableHopByHopExtHdr{
- &header.IPv6RouterAlertOption{Value: header.IPv6RouterAlertMLD},
- },
- }
-
- extensionHeadersLength := extensionHeaders.Length()
- payloadLength := extensionHeadersLength + header.ICMPv6HeaderSize + header.MLDMinimumSize
- buf := buffer.NewView(header.IPv6MinimumSize + payloadLength)
-
- ip := header.IPv6(buf)
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- HopLimit: header.MLDHopLimit,
- TransportProtocol: header.ICMPv6ProtocolNumber,
- SrcAddr: linkLocalIPv6Addr2,
- DstAddr: header.IPv6AllNodesMulticastAddress,
- ExtensionHeaders: extensionHeaders,
- })
-
- icmp := header.ICMPv6(ip.Payload()[extensionHeadersLength:])
- icmp.SetType(header.ICMPv6Type(mldType))
- mld := header.MLD(icmp.MessageBody())
- mld.SetMaximumResponseDelay(uint16(maxRespDelay))
- mld.SetMulticastAddress(groupAddress)
- icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: icmp,
- Src: linkLocalIPv6Addr2,
- Dst: header.IPv6AllNodesMulticastAddress,
- }))
-
- e.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-}
-
-// TestMGPDisabled tests that the multicast group protocol is not enabled by
-// default.
-func TestMGPDisabled(t *testing.T) {
- tests := []struct {
- name string
- protoNum tcpip.NetworkProtocolNumber
- multicastAddr tcpip.Address
- sentReportStat func(*stack.Stack) *tcpip.StatCounter
- receivedQueryStat func(*stack.Stack) *tcpip.StatCounter
- rxQuery func(*channel.Endpoint)
- }{
- {
- name: "IGMP",
- protoNum: ipv4.ProtocolNumber,
- multicastAddr: ipv4MulticastAddr1,
- sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().IGMP.PacketsSent.V2MembershipReport
- },
- receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().IGMP.PacketsReceived.MembershipQuery
- },
- rxQuery: func(e *channel.Endpoint) {
- createAndInjectIGMPPacket(e, igmpMembershipQuery, unsolicitedIGMPReportIntervalMaxTenthSec, header.IPv4Any)
- },
- },
- {
- name: "MLD",
- protoNum: ipv6.ProtocolNumber,
- multicastAddr: ipv6MulticastAddr1,
- sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
- },
- receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery
- },
- rxQuery: func(e *channel.Endpoint) {
- createAndInjectMLDPacket(e, mldQuery, 0, header.IPv6Any)
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, false /* mgpEnabled */)
-
- // This NIC may join multicast groups when it is enabled but since MGP is
- // disabled, no reports should be sent.
- sentReportStat := test.sentReportStat(s)
- if got := sentReportStat.Value(); got != 0 {
- t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
- }
- clock.Advance(time.Hour)
- if p, ok := e.Read(); ok {
- t.Fatalf("sent unexpected packet, stack with disabled MGP sent packet = %#v", p.Pkt)
- }
-
- // Test joining a specific group explicitly and verify that no reports are
- // sent.
- if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
- t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
- }
- if got := sentReportStat.Value(); got != 0 {
- t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
- }
- clock.Advance(time.Hour)
- if p, ok := e.Read(); ok {
- t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %#v", p.Pkt)
- }
-
- // Inject a general query message. This should only trigger a report to be
- // sent if the MGP was enabled.
- test.rxQuery(e)
- if got := test.receivedQueryStat(s).Value(); got != 1 {
- t.Fatalf("got receivedQueryStat(_).Value() = %d, want = 1", got)
- }
- clock.Advance(time.Hour)
- if p, ok := e.Read(); ok {
- t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %+v", p.Pkt)
- }
- })
- }
-}
-
-func TestMGPReceiveCounters(t *testing.T) {
- tests := []struct {
- name string
- headerType uint8
- maxRespTime byte
- groupAddress tcpip.Address
- statCounter func(*stack.Stack) *tcpip.StatCounter
- rxMGPkt func(*channel.Endpoint, byte, byte, tcpip.Address)
- }{
- {
- name: "IGMP Membership Query",
- headerType: igmpMembershipQuery,
- maxRespTime: unsolicitedIGMPReportIntervalMaxTenthSec,
- groupAddress: header.IPv4Any,
- statCounter: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().IGMP.PacketsReceived.MembershipQuery
- },
- rxMGPkt: createAndInjectIGMPPacket,
- },
- {
- name: "IGMPv1 Membership Report",
- headerType: igmpv1MembershipReport,
- maxRespTime: 0,
- groupAddress: header.IPv4AllSystems,
- statCounter: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().IGMP.PacketsReceived.V1MembershipReport
- },
- rxMGPkt: createAndInjectIGMPPacket,
- },
- {
- name: "IGMPv2 Membership Report",
- headerType: igmpv2MembershipReport,
- maxRespTime: 0,
- groupAddress: header.IPv4AllSystems,
- statCounter: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().IGMP.PacketsReceived.V2MembershipReport
- },
- rxMGPkt: createAndInjectIGMPPacket,
- },
- {
- name: "IGMP Leave Group",
- headerType: igmpLeaveGroup,
- maxRespTime: 0,
- groupAddress: header.IPv4AllRoutersGroup,
- statCounter: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().IGMP.PacketsReceived.LeaveGroup
- },
- rxMGPkt: createAndInjectIGMPPacket,
- },
- {
- name: "MLD Query",
- headerType: mldQuery,
- maxRespTime: 0,
- groupAddress: header.IPv6Any,
- statCounter: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery
- },
- rxMGPkt: createAndInjectMLDPacket,
- },
- {
- name: "MLD Report",
- headerType: mldReport,
- maxRespTime: 0,
- groupAddress: header.IPv6Any,
- statCounter: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerReport
- },
- rxMGPkt: createAndInjectMLDPacket,
- },
- {
- name: "MLD Done",
- headerType: mldDone,
- maxRespTime: 0,
- groupAddress: header.IPv6Any,
- statCounter: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerDone
- },
- rxMGPkt: createAndInjectMLDPacket,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- e, s, _ := createStack(t, len(test.groupAddress) == header.IPv4AddressSize /* v4 */, true /* mgpEnabled */)
-
- test.rxMGPkt(e, test.headerType, test.maxRespTime, test.groupAddress)
- if got := test.statCounter(s).Value(); got != 1 {
- t.Fatalf("got %s received = %d, want = 1", test.name, got)
- }
- })
- }
-}
-
-// TestMGPJoinGroup tests that when explicitly joining a multicast group, the
-// stack schedules and sends correct Membership Reports.
-func TestMGPJoinGroup(t *testing.T) {
- tests := []struct {
- name string
- protoNum tcpip.NetworkProtocolNumber
- multicastAddr tcpip.Address
- maxUnsolicitedResponseDelay time.Duration
- sentReportStat func(*stack.Stack) *tcpip.StatCounter
- receivedQueryStat func(*stack.Stack) *tcpip.StatCounter
- validateReport func(*testing.T, channel.PacketInfo)
- checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
- }{
- {
- name: "IGMP",
- protoNum: ipv4.ProtocolNumber,
- multicastAddr: ipv4MulticastAddr1,
- maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax,
- sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().IGMP.PacketsSent.V2MembershipReport
- },
- receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().IGMP.PacketsReceived.MembershipQuery
- },
- validateReport: func(t *testing.T, p channel.PacketInfo) {
- t.Helper()
-
- validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
- },
- },
- {
- name: "MLD",
- protoNum: ipv6.ProtocolNumber,
- multicastAddr: ipv6MulticastAddr1,
- maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax,
- sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
- },
- receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery
- },
- validateReport: func(t *testing.T, p channel.PacketInfo) {
- t.Helper()
-
- validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
- },
- checkInitialGroups: checkInitialIPv6Groups,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
-
- var reportCounter uint64
- if test.checkInitialGroups != nil {
- reportCounter, _ = test.checkInitialGroups(t, e, s, clock)
- }
-
- // Test joining a specific address explicitly and verify a Report is sent
- // immediately.
- if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
- t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
- }
- reportCounter++
- sentReportStat := test.sentReportStat(s)
- if got := sentReportStat.Value(); got != reportCounter {
- t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
- }
- if p, ok := e.Read(); !ok {
- t.Fatal("expected a report message to be sent")
- } else {
- test.validateReport(t, p)
- }
- if t.Failed() {
- t.FailNow()
- }
-
- // Verify the second report is sent by the maximum unsolicited response
- // interval.
- p, ok := e.Read()
- if ok {
- t.Fatalf("sent unexpected packet, expected report only after advancing the clock = %#v", p.Pkt)
- }
- clock.Advance(test.maxUnsolicitedResponseDelay)
- reportCounter++
- if got := sentReportStat.Value(); got != reportCounter {
- t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
- }
- if p, ok := e.Read(); !ok {
- t.Fatal("expected a report message to be sent")
- } else {
- test.validateReport(t, p)
- }
-
- // Should not send any more packets.
- clock.Advance(time.Hour)
- if p, ok := e.Read(); ok {
- t.Fatalf("sent unexpected packet = %#v", p)
- }
- })
- }
-}
-
-// TestMGPLeaveGroup tests that when leaving a previously joined multicast
-// group the stack sends a leave/done message.
-func TestMGPLeaveGroup(t *testing.T) {
- tests := []struct {
- name string
- protoNum tcpip.NetworkProtocolNumber
- multicastAddr tcpip.Address
- sentReportStat func(*stack.Stack) *tcpip.StatCounter
- sentLeaveStat func(*stack.Stack) *tcpip.StatCounter
- validateReport func(*testing.T, channel.PacketInfo)
- validateLeave func(*testing.T, channel.PacketInfo)
- checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
- }{
- {
- name: "IGMP",
- protoNum: ipv4.ProtocolNumber,
- multicastAddr: ipv4MulticastAddr1,
- sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().IGMP.PacketsSent.V2MembershipReport
- },
- sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().IGMP.PacketsSent.LeaveGroup
- },
- validateReport: func(t *testing.T, p channel.PacketInfo) {
- t.Helper()
-
- validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
- },
- validateLeave: func(t *testing.T, p channel.PacketInfo) {
- t.Helper()
-
- validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, ipv4MulticastAddr1)
- },
- },
- {
- name: "MLD",
- protoNum: ipv6.ProtocolNumber,
- multicastAddr: ipv6MulticastAddr1,
- sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
- },
- sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone
- },
- validateReport: func(t *testing.T, p channel.PacketInfo) {
- t.Helper()
-
- validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
- },
- validateLeave: func(t *testing.T, p channel.PacketInfo) {
- t.Helper()
-
- validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, ipv6MulticastAddr1)
- },
- checkInitialGroups: checkInitialIPv6Groups,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
-
- var reportCounter uint64
- var leaveCounter uint64
- if test.checkInitialGroups != nil {
- reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock)
- }
-
- if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
- t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
- }
- reportCounter++
- if got := test.sentReportStat(s).Value(); got != reportCounter {
- t.Errorf("got sentReportStat(_).Value() = %d, want = %d", got, reportCounter)
- }
- if p, ok := e.Read(); !ok {
- t.Fatal("expected a report message to be sent")
- } else {
- test.validateReport(t, p)
- }
- if t.Failed() {
- t.FailNow()
- }
-
- // Leaving the group should trigger an leave/done message to be sent.
- if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
- t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err)
- }
- leaveCounter++
- if got := test.sentLeaveStat(s).Value(); got != leaveCounter {
- t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter)
- }
- if p, ok := e.Read(); !ok {
- t.Fatal("expected a leave message to be sent")
- } else {
- test.validateLeave(t, p)
- }
-
- // Should not send any more packets.
- clock.Advance(time.Hour)
- if p, ok := e.Read(); ok {
- t.Fatalf("sent unexpected packet = %#v", p)
- }
- })
- }
-}
-
-// TestMGPQueryMessages tests that a report is sent in response to query
-// messages.
-func TestMGPQueryMessages(t *testing.T) {
- tests := []struct {
- name string
- protoNum tcpip.NetworkProtocolNumber
- multicastAddr tcpip.Address
- maxUnsolicitedResponseDelay time.Duration
- sentReportStat func(*stack.Stack) *tcpip.StatCounter
- receivedQueryStat func(*stack.Stack) *tcpip.StatCounter
- rxQuery func(*channel.Endpoint, uint8, tcpip.Address)
- validateReport func(*testing.T, channel.PacketInfo)
- maxRespTimeToDuration func(uint8) time.Duration
- checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
- }{
- {
- name: "IGMP",
- protoNum: ipv4.ProtocolNumber,
- multicastAddr: ipv4MulticastAddr1,
- maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax,
- sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().IGMP.PacketsSent.V2MembershipReport
- },
- receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().IGMP.PacketsReceived.MembershipQuery
- },
- rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) {
- createAndInjectIGMPPacket(e, igmpMembershipQuery, maxRespTime, groupAddress)
- },
- validateReport: func(t *testing.T, p channel.PacketInfo) {
- t.Helper()
-
- validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
- },
- maxRespTimeToDuration: header.DecisecondToDuration,
- },
- {
- name: "MLD",
- protoNum: ipv6.ProtocolNumber,
- multicastAddr: ipv6MulticastAddr1,
- maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax,
- sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
- },
- receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery
- },
- rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) {
- createAndInjectMLDPacket(e, mldQuery, maxRespTime, groupAddress)
- },
- validateReport: func(t *testing.T, p channel.PacketInfo) {
- t.Helper()
-
- validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
- },
- maxRespTimeToDuration: func(d uint8) time.Duration {
- return time.Duration(d) * time.Millisecond
- },
- checkInitialGroups: checkInitialIPv6Groups,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- subTests := []struct {
- name string
- multicastAddr tcpip.Address
- expectReport bool
- }{
- {
- name: "Unspecified",
- multicastAddr: tcpip.Address(strings.Repeat("\x00", len(test.multicastAddr))),
- expectReport: true,
- },
- {
- name: "Specified",
- multicastAddr: test.multicastAddr,
- expectReport: true,
- },
- {
- name: "Specified other address",
- multicastAddr: func() tcpip.Address {
- addrBytes := []byte(test.multicastAddr)
- addrBytes[len(addrBytes)-1]++
- return tcpip.Address(addrBytes)
- }(),
- expectReport: false,
- },
- }
-
- for _, subTest := range subTests {
- t.Run(subTest.name, func(t *testing.T) {
- e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
-
- var reportCounter uint64
- if test.checkInitialGroups != nil {
- reportCounter, _ = test.checkInitialGroups(t, e, s, clock)
- }
-
- if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
- t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
- }
- sentReportStat := test.sentReportStat(s)
- for i := 0; i < maxUnsolicitedReports; i++ {
- sentReportStat := test.sentReportStat(s)
- reportCounter++
- if got := sentReportStat.Value(); got != reportCounter {
- t.Errorf("(i=%d) got sentReportStat.Value() = %d, want = %d", i, got, reportCounter)
- }
- if p, ok := e.Read(); !ok {
- t.Fatalf("expected %d-th report message to be sent", i)
- } else {
- test.validateReport(t, p)
- }
- clock.Advance(test.maxUnsolicitedResponseDelay)
- }
- if t.Failed() {
- t.FailNow()
- }
-
- // Should not send any more packets until a query.
- clock.Advance(time.Hour)
- if p, ok := e.Read(); ok {
- t.Fatalf("sent unexpected packet = %#v", p)
- }
-
- // Receive a query message which should trigger a report to be sent at
- // some time before the maximum response time if the report is
- // targeted at the host.
- const maxRespTime = 100
- test.rxQuery(e, maxRespTime, subTest.multicastAddr)
- if p, ok := e.Read(); ok {
- t.Fatalf("sent unexpected packet = %#v", p.Pkt)
- }
-
- if subTest.expectReport {
- clock.Advance(test.maxRespTimeToDuration(maxRespTime))
- reportCounter++
- if got := sentReportStat.Value(); got != reportCounter {
- t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
- }
- if p, ok := e.Read(); !ok {
- t.Fatal("expected a report message to be sent")
- } else {
- test.validateReport(t, p)
- }
- }
-
- // Should not send any more packets.
- clock.Advance(time.Hour)
- if p, ok := e.Read(); ok {
- t.Fatalf("sent unexpected packet = %#v", p)
- }
- })
- }
- })
- }
-}
-
-// TestMGPQueryMessages tests that no further reports or leave/done messages
-// are sent after receiving a report.
-func TestMGPReportMessages(t *testing.T) {
- tests := []struct {
- name string
- protoNum tcpip.NetworkProtocolNumber
- multicastAddr tcpip.Address
- sentReportStat func(*stack.Stack) *tcpip.StatCounter
- sentLeaveStat func(*stack.Stack) *tcpip.StatCounter
- rxReport func(*channel.Endpoint)
- validateReport func(*testing.T, channel.PacketInfo)
- maxRespTimeToDuration func(uint8) time.Duration
- checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
- }{
- {
- name: "IGMP",
- protoNum: ipv4.ProtocolNumber,
- multicastAddr: ipv4MulticastAddr1,
- sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().IGMP.PacketsSent.V2MembershipReport
- },
- sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().IGMP.PacketsSent.LeaveGroup
- },
- rxReport: func(e *channel.Endpoint) {
- createAndInjectIGMPPacket(e, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
- },
- validateReport: func(t *testing.T, p channel.PacketInfo) {
- t.Helper()
-
- validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1)
- },
- maxRespTimeToDuration: header.DecisecondToDuration,
- },
- {
- name: "MLD",
- protoNum: ipv6.ProtocolNumber,
- multicastAddr: ipv6MulticastAddr1,
- sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
- },
- sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone
- },
- rxReport: func(e *channel.Endpoint) {
- createAndInjectMLDPacket(e, mldReport, 0, ipv6MulticastAddr1)
- },
- validateReport: func(t *testing.T, p channel.PacketInfo) {
- t.Helper()
-
- validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1)
- },
- maxRespTimeToDuration: func(d uint8) time.Duration {
- return time.Duration(d) * time.Millisecond
- },
- checkInitialGroups: checkInitialIPv6Groups,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
-
- var reportCounter uint64
- var leaveCounter uint64
- if test.checkInitialGroups != nil {
- reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock)
- }
-
- if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
- t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
- }
- sentReportStat := test.sentReportStat(s)
- reportCounter++
- if got := sentReportStat.Value(); got != reportCounter {
- t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
- }
- if p, ok := e.Read(); !ok {
- t.Fatal("expected a report message to be sent")
- } else {
- test.validateReport(t, p)
- }
- if t.Failed() {
- t.FailNow()
- }
-
- // Receiving a report for a group we joined should cancel any further
- // reports.
- test.rxReport(e)
- clock.Advance(time.Hour)
- if got := sentReportStat.Value(); got != reportCounter {
- t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
- }
- if p, ok := e.Read(); ok {
- t.Errorf("sent unexpected packet = %#v", p)
- }
- if t.Failed() {
- t.FailNow()
- }
-
- // Leaving a group after getting a report should not send a leave/done
- // message.
- if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
- t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err)
- }
- clock.Advance(time.Hour)
- if got := test.sentLeaveStat(s).Value(); got != leaveCounter {
- t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter)
- }
-
- // Should not send any more packets.
- clock.Advance(time.Hour)
- if p, ok := e.Read(); ok {
- t.Fatalf("sent unexpected packet = %#v", p)
- }
- })
- }
-}
-
-func TestMGPWithNICLifecycle(t *testing.T) {
- tests := []struct {
- name string
- protoNum tcpip.NetworkProtocolNumber
- multicastAddrs []tcpip.Address
- finalMulticastAddr tcpip.Address
- maxUnsolicitedResponseDelay time.Duration
- sentReportStat func(*stack.Stack) *tcpip.StatCounter
- sentLeaveStat func(*stack.Stack) *tcpip.StatCounter
- validateReport func(*testing.T, channel.PacketInfo, tcpip.Address)
- validateLeave func(*testing.T, channel.PacketInfo, tcpip.Address)
- getAndCheckGroupAddress func(*testing.T, map[tcpip.Address]bool, channel.PacketInfo) tcpip.Address
- checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64)
- }{
- {
- name: "IGMP",
- protoNum: ipv4.ProtocolNumber,
- multicastAddrs: []tcpip.Address{ipv4MulticastAddr1, ipv4MulticastAddr2},
- finalMulticastAddr: ipv4MulticastAddr3,
- maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax,
- sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().IGMP.PacketsSent.V2MembershipReport
- },
- sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().IGMP.PacketsSent.LeaveGroup
- },
- validateReport: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
- t.Helper()
-
- validateIGMPPacket(t, p, addr, igmpv2MembershipReport, 0, addr)
- },
- validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
- t.Helper()
-
- validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, addr)
- },
- getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address {
- t.Helper()
-
- ipv4 := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader()))
- if got := tcpip.TransportProtocolNumber(ipv4.Protocol()); got != header.IGMPProtocolNumber {
- t.Fatalf("got ipv4.Protocol() = %d, want = %d", got, header.IGMPProtocolNumber)
- }
- addr := header.IGMP(ipv4.Payload()).GroupAddress()
- s, ok := seen[addr]
- if !ok {
- t.Fatalf("unexpectedly got a packet for group %s", addr)
- }
- if s {
- t.Fatalf("already saw packet for group %s", addr)
- }
- seen[addr] = true
- return addr
- },
- },
- {
- name: "MLD",
- protoNum: ipv6.ProtocolNumber,
- multicastAddrs: []tcpip.Address{ipv6MulticastAddr1, ipv6MulticastAddr2},
- finalMulticastAddr: ipv6MulticastAddr3,
- maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax,
- sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
- },
- sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone
- },
- validateReport: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
- t.Helper()
-
- validateMLDPacket(t, p, addr, mldReport, 0, addr)
- },
- validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
- t.Helper()
-
- validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, addr)
- },
- getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address {
- t.Helper()
-
- ipv6 := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader()))
-
- ipv6HeaderIter := header.MakeIPv6PayloadIterator(
- header.IPv6ExtensionHeaderIdentifier(ipv6.NextHeader()),
- buffer.View(ipv6.Payload()).ToVectorisedView(),
- )
-
- var transport header.IPv6RawPayloadHeader
- for {
- h, done, err := ipv6HeaderIter.Next()
- if err != nil {
- t.Fatalf("ipv6HeaderIter.Next(): %s", err)
- }
- if done {
- t.Fatalf("ipv6HeaderIter.Next() = (%T, %t, _), want = (_, false, _)", h, done)
- }
- if t, ok := h.(header.IPv6RawPayloadHeader); ok {
- transport = t
- break
- }
- }
-
- if got := tcpip.TransportProtocolNumber(transport.Identifier); got != header.ICMPv6ProtocolNumber {
- t.Fatalf("got ipv6.NextHeader() = %d, want = %d", got, header.ICMPv6ProtocolNumber)
- }
- icmpv6 := header.ICMPv6(transport.Buf.ToView())
- if got := icmpv6.Type(); got != header.ICMPv6MulticastListenerReport && got != header.ICMPv6MulticastListenerDone {
- t.Fatalf("got icmpv6.Type() = %d, want = %d or %d", got, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone)
- }
- addr := header.MLD(icmpv6.MessageBody()).MulticastAddress()
- s, ok := seen[addr]
- if !ok {
- t.Fatalf("unexpectedly got a packet for group %s", addr)
- }
- if s {
- t.Fatalf("already saw packet for group %s", addr)
- }
- seen[addr] = true
- return addr
- },
- checkInitialGroups: checkInitialIPv6Groups,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */)
-
- var reportCounter uint64
- var leaveCounter uint64
- if test.checkInitialGroups != nil {
- reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock)
- }
-
- sentReportStat := test.sentReportStat(s)
- for _, a := range test.multicastAddrs {
- if err := s.JoinGroup(test.protoNum, nicID, a); err != nil {
- t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, a, err)
- }
- reportCounter++
- if got := sentReportStat.Value(); got != reportCounter {
- t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
- }
- if p, ok := e.Read(); !ok {
- t.Fatalf("expected a report message to be sent for %s", a)
- } else {
- test.validateReport(t, p, a)
- }
- }
- if t.Failed() {
- t.FailNow()
- }
-
- // Leave messages should be sent for the joined groups when the NIC is
- // disabled.
- if err := s.DisableNIC(nicID); err != nil {
- t.Fatalf("DisableNIC(%d): %s", nicID, err)
- }
- sentLeaveStat := test.sentLeaveStat(s)
- leaveCounter += uint64(len(test.multicastAddrs))
- if got := sentLeaveStat.Value(); got != leaveCounter {
- t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter)
- }
- {
- seen := make(map[tcpip.Address]bool)
- for _, a := range test.multicastAddrs {
- seen[a] = false
- }
-
- for i := range test.multicastAddrs {
- p, ok := e.Read()
- if !ok {
- t.Fatalf("expected (%d-th) leave message to be sent", i)
- }
-
- test.validateLeave(t, p, test.getAndCheckGroupAddress(t, seen, p))
- }
- }
- if t.Failed() {
- t.FailNow()
- }
-
- // Reports should be sent for the joined groups when the NIC is enabled.
- if err := s.EnableNIC(nicID); err != nil {
- t.Fatalf("EnableNIC(%d): %s", nicID, err)
- }
- reportCounter += uint64(len(test.multicastAddrs))
- if got := sentReportStat.Value(); got != reportCounter {
- t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
- }
- {
- seen := make(map[tcpip.Address]bool)
- for _, a := range test.multicastAddrs {
- seen[a] = false
- }
-
- for i := range test.multicastAddrs {
- p, ok := e.Read()
- if !ok {
- t.Fatalf("expected (%d-th) report message to be sent", i)
- }
-
- test.validateReport(t, p, test.getAndCheckGroupAddress(t, seen, p))
- }
- }
- if t.Failed() {
- t.FailNow()
- }
-
- // Joining/leaving a group while disabled should not send any messages.
- if err := s.DisableNIC(nicID); err != nil {
- t.Fatalf("DisableNIC(%d): %s", nicID, err)
- }
- leaveCounter += uint64(len(test.multicastAddrs))
- if got := sentLeaveStat.Value(); got != leaveCounter {
- t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter)
- }
- for i := range test.multicastAddrs {
- if _, ok := e.Read(); !ok {
- t.Fatalf("expected (%d-th) leave message to be sent", i)
- }
- }
- for _, a := range test.multicastAddrs {
- if err := s.LeaveGroup(test.protoNum, nicID, a); err != nil {
- t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, a, err)
- }
- if got := sentLeaveStat.Value(); got != leaveCounter {
- t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter)
- }
- if p, ok := e.Read(); ok {
- t.Fatalf("leaving group %s on disabled NIC sent unexpected packet = %#v", a, p.Pkt)
- }
- }
- if err := s.JoinGroup(test.protoNum, nicID, test.finalMulticastAddr); err != nil {
- t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.finalMulticastAddr, err)
- }
- if got := sentReportStat.Value(); got != reportCounter {
- t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
- }
- if p, ok := e.Read(); ok {
- t.Fatalf("joining group %s on disabled NIC sent unexpected packet = %#v", test.finalMulticastAddr, p.Pkt)
- }
-
- // A report should only be sent for the group we last joined after
- // enabling the NIC since the original groups were all left.
- if err := s.EnableNIC(nicID); err != nil {
- t.Fatalf("EnableNIC(%d): %s", nicID, err)
- }
- reportCounter++
- if got := sentReportStat.Value(); got != reportCounter {
- t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
- }
- if p, ok := e.Read(); !ok {
- t.Fatal("expected a report message to be sent")
- } else {
- test.validateReport(t, p, test.finalMulticastAddr)
- }
-
- clock.Advance(test.maxUnsolicitedResponseDelay)
- reportCounter++
- if got := sentReportStat.Value(); got != reportCounter {
- t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter)
- }
- if p, ok := e.Read(); !ok {
- t.Fatal("expected a report message to be sent")
- } else {
- test.validateReport(t, p, test.finalMulticastAddr)
- }
-
- // Should not send any more packets.
- clock.Advance(time.Hour)
- if p, ok := e.Read(); ok {
- t.Fatalf("sent unexpected packet = %#v", p)
- }
- })
- }
-}
-
-// TestMGPDisabledOnLoopback tests that the multicast group protocol is not
-// performed on loopback interfaces since they have no neighbours.
-func TestMGPDisabledOnLoopback(t *testing.T) {
- tests := []struct {
- name string
- protoNum tcpip.NetworkProtocolNumber
- multicastAddr tcpip.Address
- sentReportStat func(*stack.Stack) *tcpip.StatCounter
- }{
- {
- name: "IGMP",
- protoNum: ipv4.ProtocolNumber,
- multicastAddr: ipv4MulticastAddr1,
- sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().IGMP.PacketsSent.V2MembershipReport
- },
- },
- {
- name: "MLD",
- protoNum: ipv6.ProtocolNumber,
- multicastAddr: ipv6MulticastAddr1,
- sentReportStat: func(s *stack.Stack) *tcpip.StatCounter {
- return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s, clock := createStackWithLinkEndpoint(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */, loopback.New())
-
- sentReportStat := test.sentReportStat(s)
- if got := sentReportStat.Value(); got != 0 {
- t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
- }
- clock.Advance(time.Hour)
- if got := sentReportStat.Value(); got != 0 {
- t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
- }
-
- // Test joining a specific group explicitly and verify that no reports are
- // sent.
- if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil {
- t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err)
- }
- if got := sentReportStat.Value(); got != 0 {
- t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
- }
- clock.Advance(time.Hour)
- if got := sentReportStat.Value(); got != 0 {
- t.Fatalf("got sentReportStat.Value() = %d, want = 0", got)
- }
- })
- }
-}
diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD
deleted file mode 100644
index fe98a52af..000000000
--- a/pkg/tcpip/ports/BUILD
+++ /dev/null
@@ -1,28 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "ports",
- srcs = [
- "flags.go",
- "ports.go",
- ],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/header",
- ],
-)
-
-go_test(
- name = "ports_test",
- srcs = ["ports_test.go"],
- library = ":ports",
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/testutil",
- "@com_github_google_go_cmp//cmp:go_default_library",
- ],
-)
diff --git a/pkg/tcpip/ports/ports_state_autogen.go b/pkg/tcpip/ports/ports_state_autogen.go
new file mode 100644
index 000000000..2719f6c41
--- /dev/null
+++ b/pkg/tcpip/ports/ports_state_autogen.go
@@ -0,0 +1,42 @@
+// automatically generated by stateify.
+
+package ports
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+func (f *Flags) StateTypeName() string {
+ return "pkg/tcpip/ports.Flags"
+}
+
+func (f *Flags) StateFields() []string {
+ return []string{
+ "MostRecent",
+ "LoadBalanced",
+ "TupleOnly",
+ }
+}
+
+func (f *Flags) beforeSave() {}
+
+// +checklocksignore
+func (f *Flags) StateSave(stateSinkObject state.Sink) {
+ f.beforeSave()
+ stateSinkObject.Save(0, &f.MostRecent)
+ stateSinkObject.Save(1, &f.LoadBalanced)
+ stateSinkObject.Save(2, &f.TupleOnly)
+}
+
+func (f *Flags) afterLoad() {}
+
+// +checklocksignore
+func (f *Flags) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &f.MostRecent)
+ stateSourceObject.Load(1, &f.LoadBalanced)
+ stateSourceObject.Load(2, &f.TupleOnly)
+}
+
+func init() {
+ state.Register((*Flags)(nil))
+}
diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go
deleted file mode 100644
index a91b130df..000000000
--- a/pkg/tcpip/ports/ports_test.go
+++ /dev/null
@@ -1,525 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package ports
-
-import (
- "math"
- "math/rand"
- "testing"
- "time"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
-)
-
-const (
- fakeTransNumber tcpip.TransportProtocolNumber = 1
- fakeNetworkNumber tcpip.NetworkProtocolNumber = 2
-)
-
-var (
- fakeIPAddress = testutil.MustParse4("8.8.8.8")
- fakeIPAddress1 = testutil.MustParse4("8.8.8.9")
-)
-
-type portReserveTestAction struct {
- port uint16
- ip tcpip.Address
- want tcpip.Error
- flags Flags
- release bool
- device tcpip.NICID
- dest tcpip.FullAddress
-}
-
-func TestPortReservation(t *testing.T) {
- for _, test := range []struct {
- tname string
- actions []portReserveTestAction
- }{
- {
- tname: "bind to ip",
- actions: []portReserveTestAction{
- {port: 80, ip: fakeIPAddress, want: nil},
- {port: 80, ip: fakeIPAddress1, want: nil},
- /* N.B. Order of tests matters! */
- {port: 80, ip: anyIPAddress, want: &tcpip.ErrPortInUse{}},
- {port: 80, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}, flags: Flags{LoadBalanced: true}},
- },
- },
- {
- tname: "bind to inaddr any",
- actions: []portReserveTestAction{
- {port: 22, ip: anyIPAddress, want: nil},
- {port: 22, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}},
- /* release fakeIPAddress, but anyIPAddress is still inuse */
- {port: 22, ip: fakeIPAddress, release: true},
- {port: 22, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}},
- {port: 22, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}, flags: Flags{LoadBalanced: true}},
- /* Release port 22 from any IP address, then try to reserve fake IP address on 22 */
- {port: 22, ip: anyIPAddress, want: nil, release: true},
- {port: 22, ip: fakeIPAddress, want: nil},
- },
- }, {
- tname: "bind to zero port",
- actions: []portReserveTestAction{
- {port: 00, ip: fakeIPAddress, want: nil},
- {port: 00, ip: fakeIPAddress, want: nil},
- {port: 00, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
- },
- }, {
- tname: "bind to ip with reuseport",
- actions: []portReserveTestAction{
- {port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
-
- {port: 25, ip: fakeIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}},
- {port: 25, ip: anyIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}},
-
- {port: 25, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
- },
- }, {
- tname: "bind to inaddr any with reuseport",
- actions: []portReserveTestAction{
- {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
-
- {port: 24, ip: anyIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}},
- {port: 24, ip: fakeIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}},
-
- {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, release: true, want: nil},
-
- {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true},
- {port: 24, ip: anyIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}},
-
- {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true},
- {port: 24, ip: anyIPAddress, flags: Flags{}, want: nil},
- },
- }, {
- tname: "bind twice with device fails",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 3, want: nil},
- {port: 24, ip: fakeIPAddress, device: 3, want: &tcpip.ErrPortInUse{}},
- },
- }, {
- tname: "bind to device",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 1, want: nil},
- {port: 24, ip: fakeIPAddress, device: 2, want: nil},
- },
- }, {
- tname: "bind to device and then without device",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 123, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}},
- },
- }, {
- tname: "bind without device",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, want: nil},
- {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}},
- {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}},
- {port: 24, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}},
- {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}},
- },
- }, {
- tname: "bind with device",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 123, want: nil},
- {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}},
- {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}},
- {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}},
- {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}},
- {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 789, want: nil},
- {port: 24, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}},
- {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}},
- },
- }, {
- tname: "bind with reuseport",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}},
- {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}},
- {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil},
- },
- }, {
- tname: "binding with reuseport and device",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}},
- {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}},
- {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 999, want: &tcpip.ErrPortInUse{}},
- },
- }, {
- tname: "mixing reuseport and not reuseport by binding to device",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 456, want: nil},
- {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 999, want: nil},
- },
- }, {
- tname: "can't bind to 0 after mixing reuseport and not reuseport",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 456, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}},
- },
- }, {
- tname: "bind and release",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: &tcpip.ErrPortInUse{}},
- {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil},
-
- // Release the bind to device 0 and try again.
- {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil, release: true},
- {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: nil},
- },
- }, {
- tname: "bind twice with reuseport once",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}},
- },
- }, {
- tname: "release an unreserved device",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{}, want: nil},
- // The below don't exist.
- {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: nil, release: true},
- {port: 9999, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil, release: true},
- // Release all.
- {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil, release: true},
- {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{}, want: nil, release: true},
- },
- }, {
- tname: "bind with reuseaddr",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}},
- {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{MostRecent: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}},
- {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: nil},
- },
- }, {
- tname: "bind twice with reuseaddr once",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: &tcpip.ErrPortInUse{}},
- },
- }, {
- tname: "bind with reuseaddr and reuseport",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
- },
- }, {
- tname: "bind with reuseaddr and reuseport, and then reuseaddr",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}},
- },
- }, {
- tname: "bind with reuseaddr and reuseport, and then reuseport",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: &tcpip.ErrPortInUse{}},
- },
- }, {
- tname: "bind with reuseaddr and reuseport twice, and then reuseaddr",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil},
- },
- }, {
- tname: "bind with reuseaddr and reuseport twice, and then reuseport",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
- },
- }, {
- tname: "bind with reuseaddr, and then reuseaddr and reuseport",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}},
- },
- }, {
- tname: "bind with reuseport, and then reuseaddr and reuseport",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: &tcpip.ErrPortInUse{}},
- },
- }, {
- tname: "bind tuple with reuseaddr, and then wildcard with reuseaddr",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{}, want: nil},
- },
- }, {
- tname: "bind tuple with reuseaddr, and then wildcard",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil},
- {port: 24, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}},
- },
- }, {
- tname: "bind wildcard with reuseaddr, and then tuple with reuseaddr",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil},
- },
- }, {
- tname: "bind tuple with reuseaddr, and then wildcard",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: &tcpip.ErrPortInUse{}},
- },
- }, {
- tname: "bind two tuples with reuseaddr",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 25}, want: nil},
- },
- }, {
- tname: "bind two tuples",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil},
- {port: 24, ip: fakeIPAddress, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 25}, want: nil},
- },
- }, {
- tname: "bind wildcard, and then tuple with reuseaddr",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, dest: tcpip.FullAddress{}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: &tcpip.ErrPortInUse{}},
- },
- }, {
- tname: "bind wildcard twice with reuseaddr",
- actions: []portReserveTestAction{
- {port: 24, ip: anyIPAddress, flags: Flags{TupleOnly: true}, want: nil},
- {port: 24, ip: anyIPAddress, flags: Flags{TupleOnly: true}, want: nil},
- },
- },
- } {
- t.Run(test.tname, func(t *testing.T) {
- pm := NewPortManager()
- net := []tcpip.NetworkProtocolNumber{fakeNetworkNumber}
- rng := rand.New(rand.NewSource(time.Now().UnixNano()))
-
- for _, test := range test.actions {
- first, _ := pm.PortRange()
- if test.release {
- portRes := Reservation{
- Networks: net,
- Transport: fakeTransNumber,
- Addr: test.ip,
- Port: test.port,
- Flags: test.flags,
- BindToDevice: test.device,
- Dest: test.dest,
- }
- pm.ReleasePort(portRes)
- continue
- }
- portRes := Reservation{
- Networks: net,
- Transport: fakeTransNumber,
- Addr: test.ip,
- Port: test.port,
- Flags: test.flags,
- BindToDevice: test.device,
- Dest: test.dest,
- }
- gotPort, err := pm.ReservePort(rng, portRes, nil /* testPort */)
- if diff := cmp.Diff(test.want, err); diff != "" {
- t.Fatalf("unexpected error from ReservePort(%+v, _), (-want, +got):\n%s", portRes, diff)
- }
- if test.port == 0 && (gotPort == 0 || gotPort < first) {
- t.Fatalf("ReservePort(%+v, _) = %d, want port number >= %d to be picked", portRes, gotPort, first)
- }
- }
- })
- }
-}
-
-func TestPickEphemeralPort(t *testing.T) {
- const (
- firstEphemeral = 32000
- numEphemeralPorts = 1000
- )
-
- for _, test := range []struct {
- name string
- f func(port uint16) (bool, tcpip.Error)
- wantErr tcpip.Error
- wantPort uint16
- }{
- {
- name: "no-port-available",
- f: func(port uint16) (bool, tcpip.Error) {
- return false, nil
- },
- wantErr: &tcpip.ErrNoPortAvailable{},
- },
- {
- name: "port-tester-error",
- f: func(port uint16) (bool, tcpip.Error) {
- return false, &tcpip.ErrBadBuffer{}
- },
- wantErr: &tcpip.ErrBadBuffer{},
- },
- {
- name: "only-port-16042-available",
- f: func(port uint16) (bool, tcpip.Error) {
- if port == firstEphemeral+42 {
- return true, nil
- }
- return false, nil
- },
- wantPort: firstEphemeral + 42,
- },
- {
- name: "only-port-under-16000-available",
- f: func(port uint16) (bool, tcpip.Error) {
- if port < firstEphemeral {
- return true, nil
- }
- return false, nil
- },
- wantErr: &tcpip.ErrNoPortAvailable{},
- },
- } {
- t.Run(test.name, func(t *testing.T) {
- pm := NewPortManager()
- rng := rand.New(rand.NewSource(time.Now().UnixNano()))
- if err := pm.SetPortRange(firstEphemeral, firstEphemeral+numEphemeralPorts); err != nil {
- t.Fatalf("failed to set ephemeral port range: %s", err)
- }
- port, err := pm.PickEphemeralPort(rng, test.f)
- if diff := cmp.Diff(test.wantErr, err); diff != "" {
- t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff)
- }
- if port != test.wantPort {
- t.Errorf("got PickEphemeralPort(..) = (%d, nil); want (%d, nil)", port, test.wantPort)
- }
- })
- }
-}
-
-func TestPickEphemeralPortStable(t *testing.T) {
- const (
- firstEphemeral = 32000
- numEphemeralPorts = 1000
- )
-
- for _, test := range []struct {
- name string
- f func(port uint16) (bool, tcpip.Error)
- wantErr tcpip.Error
- wantPort uint16
- }{
- {
- name: "no-port-available",
- f: func(port uint16) (bool, tcpip.Error) {
- return false, nil
- },
- wantErr: &tcpip.ErrNoPortAvailable{},
- },
- {
- name: "port-tester-error",
- f: func(port uint16) (bool, tcpip.Error) {
- return false, &tcpip.ErrBadBuffer{}
- },
- wantErr: &tcpip.ErrBadBuffer{},
- },
- {
- name: "only-port-16042-available",
- f: func(port uint16) (bool, tcpip.Error) {
- if port == firstEphemeral+42 {
- return true, nil
- }
- return false, nil
- },
- wantPort: firstEphemeral + 42,
- },
- {
- name: "only-port-under-16000-available",
- f: func(port uint16) (bool, tcpip.Error) {
- if port < firstEphemeral {
- return true, nil
- }
- return false, nil
- },
- wantErr: &tcpip.ErrNoPortAvailable{},
- },
- } {
- t.Run(test.name, func(t *testing.T) {
- pm := NewPortManager()
- if err := pm.SetPortRange(firstEphemeral, firstEphemeral+numEphemeralPorts); err != nil {
- t.Fatalf("failed to set ephemeral port range: %s", err)
- }
- portOffset := uint32(rand.Int31n(int32(numEphemeralPorts)))
- port, err := pm.PickEphemeralPortStable(portOffset, test.f)
- if diff := cmp.Diff(test.wantErr, err); diff != "" {
- t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff)
- }
- if port != test.wantPort {
- t.Errorf("got PickEphemeralPort(..) = (%d, nil); want (%d, nil)", port, test.wantPort)
- }
- })
- }
-}
-
-// TestOverflow addresses b/183593432, wherein an overflowing uint16 causes a
-// port allocation failure.
-func TestOverflow(t *testing.T) {
- // Use a small range and start at offsets that will cause an overflow.
- count := uint16(50)
- for offset := uint32(math.MaxUint16 - count); offset < math.MaxUint16; offset++ {
- reservedPorts := make(map[uint16]struct{})
- // Ensure we can reserve everything in the allowed range.
- for i := uint16(0); i < count; i++ {
- port, err := pickEphemeralPort(offset, firstEphemeral, count, func(port uint16) (bool, tcpip.Error) {
- if _, ok := reservedPorts[port]; !ok {
- reservedPorts[port] = struct{}{}
- return true, nil
- }
- return false, nil
- })
- if err != nil {
- t.Fatalf("port picking failed at iteration %d, for offset %d, len(reserved): %+v", i, offset, len(reservedPorts))
- }
- if port < firstEphemeral || port > firstEphemeral+count {
- t.Fatalf("reserved port %d, which is not in range [%d, %d]", port, firstEphemeral, firstEphemeral+count-1)
- }
- }
- }
-}
diff --git a/pkg/tcpip/sample/tun_tcp_connect/BUILD b/pkg/tcpip/sample/tun_tcp_connect/BUILD
deleted file mode 100644
index db9b91815..000000000
--- a/pkg/tcpip/sample/tun_tcp_connect/BUILD
+++ /dev/null
@@ -1,21 +0,0 @@
-load("//tools:defs.bzl", "go_binary")
-
-package(licenses = ["notice"])
-
-go_binary(
- name = "tun_tcp_connect",
- srcs = ["main.go"],
- visibility = ["//:sandbox"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/fdbased",
- "//pkg/tcpip/link/rawfile",
- "//pkg/tcpip/link/sniffer",
- "//pkg/tcpip/link/tun",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/transport/tcp",
- "//pkg/waiter",
- ],
-)
diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go
deleted file mode 100644
index 05b879543..000000000
--- a/pkg/tcpip/sample/tun_tcp_connect/main.go
+++ /dev/null
@@ -1,224 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-//go:build linux
-// +build linux
-
-// This sample creates a stack with TCP and IPv4 protocols on top of a TUN
-// device, and connects to a peer. Similar to "nc <address> <port>". While the
-// sample is running, attempts to connect to its IPv4 address will result in
-// a RST segment.
-//
-// As an example of how to run it, a TUN device can be created and enabled on
-// a linux host as follows (this only needs to be done once per boot):
-//
-// [sudo] ip tuntap add user <username> mode tun <device-name>
-// [sudo] ip link set <device-name> up
-// [sudo] ip addr add <ipv4-address>/<mask-length> dev <device-name>
-//
-// A concrete example:
-//
-// $ sudo ip tuntap add user wedsonaf mode tun tun0
-// $ sudo ip link set tun0 up
-// $ sudo ip addr add 192.168.1.1/24 dev tun0
-//
-// Then one can run tun_tcp_connect as such:
-//
-// $ ./tun/tun_tcp_connect tun0 192.168.1.2 0 192.168.1.1 1234
-//
-// This will attempt to connect to the linux host's stack. One can run nc in
-// listen mode to accept a connect from tun_tcp_connect and exchange data.
-package main
-
-import (
- "bytes"
- "fmt"
- "log"
- "math/rand"
- "net"
- "os"
- "strconv"
- "time"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/fdbased"
- "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
- "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
- "gvisor.dev/gvisor/pkg/tcpip/link/tun"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-// writer reads from standard input and writes to the endpoint until standard
-// input is closed. It signals that it's done by closing the provided channel.
-func writer(ch chan struct{}, ep tcpip.Endpoint) {
- defer func() {
- ep.Shutdown(tcpip.ShutdownWrite)
- close(ch)
- }()
-
- var b bytes.Buffer
- if err := func() error {
- for {
- if _, err := b.ReadFrom(os.Stdin); err != nil {
- return fmt.Errorf("b.ReadFrom failed: %w", err)
- }
-
- for b.Len() != 0 {
- if _, err := ep.Write(&b, tcpip.WriteOptions{Atomic: true}); err != nil {
- return fmt.Errorf("ep.Write failed: %s", err)
- }
- }
- }
- }(); err != nil {
- fmt.Println(err)
- }
-}
-
-func main() {
- if len(os.Args) != 6 {
- log.Fatal("Usage: ", os.Args[0], " <tun-device> <local-ipv4-address> <local-port> <remote-ipv4-address> <remote-port>")
- }
-
- tunName := os.Args[1]
- addrName := os.Args[2]
- portName := os.Args[3]
- remoteAddrName := os.Args[4]
- remotePortName := os.Args[5]
-
- rand.Seed(time.Now().UnixNano())
-
- addr := tcpip.Address(net.ParseIP(addrName).To4())
- remote := tcpip.FullAddress{
- NIC: 1,
- Addr: tcpip.Address(net.ParseIP(remoteAddrName).To4()),
- }
-
- var localPort uint16
- if v, err := strconv.Atoi(portName); err != nil {
- log.Fatalf("Unable to convert port %v: %v", portName, err)
- } else {
- localPort = uint16(v)
- }
-
- if v, err := strconv.Atoi(remotePortName); err != nil {
- log.Fatalf("Unable to convert port %v: %v", remotePortName, err)
- } else {
- remote.Port = uint16(v)
- }
-
- // Create the stack with ipv4 and tcp protocols, then add a tun-based
- // NIC and ipv4 address.
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
- })
-
- mtu, err := rawfile.GetMTU(tunName)
- if err != nil {
- log.Fatal(err)
- }
-
- fd, err := tun.Open(tunName)
- if err != nil {
- log.Fatal(err)
- }
-
- linkEP, err := fdbased.New(&fdbased.Options{FDs: []int{fd}, MTU: mtu})
- if err != nil {
- log.Fatal(err)
- }
- if err := s.CreateNIC(1, sniffer.New(linkEP)); err != nil {
- log.Fatal(err)
- }
-
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: addr.WithPrefix(),
- }
- if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
- log.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
- }
-
- // Add default route.
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv4EmptySubnet,
- NIC: 1,
- },
- })
-
- // Create TCP endpoint.
- var wq waiter.Queue
- ep, e := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
- if e != nil {
- log.Fatal(e)
- }
-
- // Bind if a port is specified.
- if localPort != 0 {
- if err := ep.Bind(tcpip.FullAddress{0, "", localPort}); err != nil {
- log.Fatal("Bind failed: ", err)
- }
- }
-
- // Issue connect request and wait for it to complete.
- waitEntry, notifyCh := waiter.NewChannelEntry(nil)
- wq.EventRegister(&waitEntry, waiter.WritableEvents)
- terr := ep.Connect(remote)
- if _, ok := terr.(*tcpip.ErrConnectStarted); ok {
- fmt.Println("Connect is pending...")
- <-notifyCh
- terr = ep.LastError()
- }
- wq.EventUnregister(&waitEntry)
-
- if terr != nil {
- log.Fatal("Unable to connect: ", terr)
- }
-
- fmt.Println("Connected")
-
- // Start the writer in its own goroutine.
- writerCompletedCh := make(chan struct{})
- go writer(writerCompletedCh, ep) // S/R-SAFE: sample code.
-
- // Read data and write to standard output until the peer closes the
- // connection from its side.
- wq.EventRegister(&waitEntry, waiter.ReadableEvents)
- for {
- _, err := ep.Read(os.Stdout, tcpip.ReadOptions{})
- if err != nil {
- if _, ok := err.(*tcpip.ErrClosedForReceive); ok {
- break
- }
-
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
- <-notifyCh
- continue
- }
-
- log.Fatal("Read() failed:", err)
- }
- }
- wq.EventUnregister(&waitEntry)
-
- // The reader has completed. Now wait for the writer as well.
- <-writerCompletedCh
-
- ep.Close()
-}
diff --git a/pkg/tcpip/sample/tun_tcp_echo/BUILD b/pkg/tcpip/sample/tun_tcp_echo/BUILD
deleted file mode 100644
index 43264b76d..000000000
--- a/pkg/tcpip/sample/tun_tcp_echo/BUILD
+++ /dev/null
@@ -1,21 +0,0 @@
-load("//tools:defs.bzl", "go_binary")
-
-package(licenses = ["notice"])
-
-go_binary(
- name = "tun_tcp_echo",
- srcs = ["main.go"],
- visibility = ["//:sandbox"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/link/fdbased",
- "//pkg/tcpip/link/rawfile",
- "//pkg/tcpip/link/tun",
- "//pkg/tcpip/network/arp",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/transport/tcp",
- "//pkg/waiter",
- ],
-)
diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go
deleted file mode 100644
index a72afadda..000000000
--- a/pkg/tcpip/sample/tun_tcp_echo/main.go
+++ /dev/null
@@ -1,235 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-//go:build linux
-// +build linux
-
-// This sample creates a stack with TCP and IPv4 protocols on top of a TUN
-// device, and listens on a port. Data received by the server in the accepted
-// connections is echoed back to the clients.
-package main
-
-import (
- "bytes"
- "flag"
- "io"
- "log"
- "math/rand"
- "net"
- "os"
- "strconv"
- "strings"
- "time"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/link/fdbased"
- "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
- "gvisor.dev/gvisor/pkg/tcpip/link/tun"
- "gvisor.dev/gvisor/pkg/tcpip/network/arp"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-var tap = flag.Bool("tap", false, "use tap istead of tun")
-var mac = flag.String("mac", "aa:00:01:01:01:01", "mac address to use in tap device")
-
-type endpointWriter struct {
- ep tcpip.Endpoint
-}
-
-type tcpipError struct {
- inner tcpip.Error
-}
-
-func (e *tcpipError) Error() string {
- return e.inner.String()
-}
-
-func (e *endpointWriter) Write(p []byte) (int, error) {
- var r bytes.Reader
- r.Reset(p)
- n, err := e.ep.Write(&r, tcpip.WriteOptions{})
- if err != nil {
- return int(n), &tcpipError{
- inner: err,
- }
- }
- if n != int64(len(p)) {
- return int(n), io.ErrShortWrite
- }
- return int(n), nil
-}
-
-func echo(wq *waiter.Queue, ep tcpip.Endpoint) {
- defer ep.Close()
-
- // Create wait queue entry that notifies a channel.
- waitEntry, notifyCh := waiter.NewChannelEntry(nil)
-
- wq.EventRegister(&waitEntry, waiter.ReadableEvents)
- defer wq.EventUnregister(&waitEntry)
-
- w := endpointWriter{
- ep: ep,
- }
-
- for {
- _, err := ep.Read(&w, tcpip.ReadOptions{})
- if err != nil {
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
- <-notifyCh
- continue
- }
-
- return
- }
- }
-}
-
-func main() {
- flag.Parse()
- if len(flag.Args()) != 3 {
- log.Fatal("Usage: ", os.Args[0], " <tun-device> <local-address> <local-port>")
- }
-
- tunName := flag.Arg(0)
- addrName := flag.Arg(1)
- portName := flag.Arg(2)
-
- rand.Seed(time.Now().UnixNano())
-
- // Parse the mac address.
- maddr, err := net.ParseMAC(*mac)
- if err != nil {
- log.Fatalf("Bad MAC address: %v", *mac)
- }
-
- // Parse the IP address. Support both ipv4 and ipv6.
- parsedAddr := net.ParseIP(addrName)
- if parsedAddr == nil {
- log.Fatalf("Bad IP address: %v", addrName)
- }
-
- var addrWithPrefix tcpip.AddressWithPrefix
- var proto tcpip.NetworkProtocolNumber
- if parsedAddr.To4() != nil {
- addrWithPrefix = tcpip.Address(parsedAddr.To4()).WithPrefix()
- proto = ipv4.ProtocolNumber
- } else if parsedAddr.To16() != nil {
- addrWithPrefix = tcpip.Address(parsedAddr.To16()).WithPrefix()
- proto = ipv6.ProtocolNumber
- } else {
- log.Fatalf("Unknown IP type: %v", addrName)
- }
-
- localPort, err := strconv.Atoi(portName)
- if err != nil {
- log.Fatalf("Unable to convert port %v: %v", portName, err)
- }
-
- // Create the stack with ip and tcp protocols, then add a tun-based
- // NIC and address.
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol, arp.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
- })
-
- mtu, err := rawfile.GetMTU(tunName)
- if err != nil {
- log.Fatal(err)
- }
-
- var fd int
- if *tap {
- fd, err = tun.OpenTAP(tunName)
- } else {
- fd, err = tun.Open(tunName)
- }
- if err != nil {
- log.Fatal(err)
- }
-
- linkEP, err := fdbased.New(&fdbased.Options{
- FDs: []int{fd},
- MTU: mtu,
- EthernetHeader: *tap,
- Address: tcpip.LinkAddress(maddr),
- })
- if err != nil {
- log.Fatal(err)
- }
- if err := s.CreateNIC(1, linkEP); err != nil {
- log.Fatal(err)
- }
-
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: proto,
- AddressWithPrefix: addrWithPrefix,
- }
- if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
- log.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
- }
-
- subnet, err := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", len(addrWithPrefix.Address))), tcpip.AddressMask(strings.Repeat("\x00", len(addrWithPrefix.Address))))
- if err != nil {
- log.Fatal(err)
- }
-
- // Add default route.
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: subnet,
- NIC: 1,
- },
- })
-
- // Create TCP endpoint, bind it, then start listening.
- var wq waiter.Queue
- ep, e := s.NewEndpoint(tcp.ProtocolNumber, proto, &wq)
- if e != nil {
- log.Fatal(e)
- }
-
- defer ep.Close()
-
- if err := ep.Bind(tcpip.FullAddress{0, "", uint16(localPort)}); err != nil {
- log.Fatal("Bind failed: ", err)
- }
-
- if err := ep.Listen(10); err != nil {
- log.Fatal("Listen failed: ", err)
- }
-
- // Wait for connections to appear.
- waitEntry, notifyCh := waiter.NewChannelEntry(nil)
- wq.EventRegister(&waitEntry, waiter.ReadableEvents)
- defer wq.EventUnregister(&waitEntry)
-
- for {
- n, wq, err := ep.Accept(nil)
- if err != nil {
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
- <-notifyCh
- continue
- }
-
- log.Fatal("Accept() failed:", err)
- }
-
- go echo(wq, n) // S/R-SAFE: sample code.
- }
-}
diff --git a/pkg/tcpip/seqnum/BUILD b/pkg/tcpip/seqnum/BUILD
deleted file mode 100644
index 45f503845..000000000
--- a/pkg/tcpip/seqnum/BUILD
+++ /dev/null
@@ -1,9 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "seqnum",
- srcs = ["seqnum.go"],
- visibility = ["//visibility:public"],
-)
diff --git a/pkg/tcpip/seqnum/seqnum_state_autogen.go b/pkg/tcpip/seqnum/seqnum_state_autogen.go
new file mode 100644
index 000000000..23e79811d
--- /dev/null
+++ b/pkg/tcpip/seqnum/seqnum_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package seqnum
diff --git a/pkg/tcpip/sock_err_list.go b/pkg/tcpip/sock_err_list.go
new file mode 100644
index 000000000..0be1993af
--- /dev/null
+++ b/pkg/tcpip/sock_err_list.go
@@ -0,0 +1,221 @@
+package tcpip
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type sockErrorElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (sockErrorElementMapper) linkerFor(elem *SockError) *SockError { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type sockErrorList struct {
+ head *SockError
+ tail *SockError
+}
+
+// Reset resets list l to the empty state.
+func (l *sockErrorList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+//
+//go:nosplit
+func (l *sockErrorList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+//
+//go:nosplit
+func (l *sockErrorList) Front() *SockError {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+//
+//go:nosplit
+func (l *sockErrorList) Back() *SockError {
+ return l.tail
+}
+
+// Len returns the number of elements in the list.
+//
+// NOTE: This is an O(n) operation.
+//
+//go:nosplit
+func (l *sockErrorList) Len() (count int) {
+ for e := l.Front(); e != nil; e = (sockErrorElementMapper{}.linkerFor(e)).Next() {
+ count++
+ }
+ return count
+}
+
+// PushFront inserts the element e at the front of list l.
+//
+//go:nosplit
+func (l *sockErrorList) PushFront(e *SockError) {
+ linker := sockErrorElementMapper{}.linkerFor(e)
+ linker.SetNext(l.head)
+ linker.SetPrev(nil)
+ if l.head != nil {
+ sockErrorElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+//
+//go:nosplit
+func (l *sockErrorList) PushBack(e *SockError) {
+ linker := sockErrorElementMapper{}.linkerFor(e)
+ linker.SetNext(nil)
+ linker.SetPrev(l.tail)
+ if l.tail != nil {
+ sockErrorElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+//
+//go:nosplit
+func (l *sockErrorList) PushBackList(m *sockErrorList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ sockErrorElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ sockErrorElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+//
+//go:nosplit
+func (l *sockErrorList) InsertAfter(b, e *SockError) {
+ bLinker := sockErrorElementMapper{}.linkerFor(b)
+ eLinker := sockErrorElementMapper{}.linkerFor(e)
+
+ a := bLinker.Next()
+
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ bLinker.SetNext(e)
+
+ if a != nil {
+ sockErrorElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+//
+//go:nosplit
+func (l *sockErrorList) InsertBefore(a, e *SockError) {
+ aLinker := sockErrorElementMapper{}.linkerFor(a)
+ eLinker := sockErrorElementMapper{}.linkerFor(e)
+
+ b := aLinker.Prev()
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ aLinker.SetPrev(e)
+
+ if b != nil {
+ sockErrorElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+//
+//go:nosplit
+func (l *sockErrorList) Remove(e *SockError) {
+ linker := sockErrorElementMapper{}.linkerFor(e)
+ prev := linker.Prev()
+ next := linker.Next()
+
+ if prev != nil {
+ sockErrorElementMapper{}.linkerFor(prev).SetNext(next)
+ } else if l.head == e {
+ l.head = next
+ }
+
+ if next != nil {
+ sockErrorElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else if l.tail == e {
+ l.tail = prev
+ }
+
+ linker.SetNext(nil)
+ linker.SetPrev(nil)
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type sockErrorEntry struct {
+ next *SockError
+ prev *SockError
+}
+
+// Next returns the entry that follows e in the list.
+//
+//go:nosplit
+func (e *sockErrorEntry) Next() *SockError {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *sockErrorEntry) Prev() *SockError {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+//
+//go:nosplit
+func (e *sockErrorEntry) SetNext(elem *SockError) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *sockErrorEntry) SetPrev(elem *SockError) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
deleted file mode 100644
index 6c42ab29b..000000000
--- a/pkg/tcpip/stack/BUILD
+++ /dev/null
@@ -1,153 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test", "most_shards")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
-package(licenses = ["notice"])
-
-go_template_instance(
- name = "neighbor_entry_list",
- out = "neighbor_entry_list.go",
- package = "stack",
- prefix = "neighborEntry",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*neighborEntry",
- "Linker": "*neighborEntry",
- },
-)
-
-go_template_instance(
- name = "packet_buffer_list",
- out = "packet_buffer_list.go",
- package = "stack",
- prefix = "PacketBuffer",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*PacketBuffer",
- "Linker": "*PacketBuffer",
- },
-)
-
-go_template_instance(
- name = "tuple_list",
- out = "tuple_list.go",
- package = "stack",
- prefix = "tuple",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*tuple",
- "Linker": "*tuple",
- },
-)
-
-go_library(
- name = "stack",
- srcs = [
- "addressable_endpoint_state.go",
- "conntrack.go",
- "headertype_string.go",
- "hook_string.go",
- "icmp_rate_limit.go",
- "iptables.go",
- "iptables_state.go",
- "iptables_targets.go",
- "iptables_types.go",
- "neighbor_cache.go",
- "neighbor_entry.go",
- "neighbor_entry_list.go",
- "neighborstate_string.go",
- "nic.go",
- "nic_stats.go",
- "nud.go",
- "packet_buffer.go",
- "packet_buffer_list.go",
- "packet_buffer_unsafe.go",
- "pending_packets.go",
- "rand.go",
- "registration.go",
- "route.go",
- "stack.go",
- "stack_global_state.go",
- "stack_options.go",
- "tcp.go",
- "transport_demuxer.go",
- "tuple_list.go",
- ],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/atomicbitops",
- "//pkg/buffer",
- "//pkg/ilist",
- "//pkg/log",
- "//pkg/rand",
- "//pkg/sleep",
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/hash/jenkins",
- "//pkg/tcpip/header",
- "//pkg/tcpip/internal/tcp",
- "//pkg/tcpip/ports",
- "//pkg/tcpip/seqnum",
- "//pkg/tcpip/transport/tcpconntrack",
- "//pkg/waiter",
- "@org_golang_x_time//rate:go_default_library",
- ],
-)
-
-go_test(
- name = "stack_x_test",
- size = "small",
- srcs = [
- "addressable_endpoint_state_test.go",
- "ndp_test.go",
- "nud_test.go",
- "stack_test.go",
- "transport_demuxer_test.go",
- "transport_test.go",
- ],
- shard_count = most_shards,
- deps = [
- ":stack",
- "//pkg/rand",
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/checker",
- "//pkg/tcpip/faketime",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/link/loopback",
- "//pkg/tcpip/network/arp",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/ports",
- "//pkg/tcpip/testutil",
- "//pkg/tcpip/transport/icmp",
- "//pkg/tcpip/transport/udp",
- "//pkg/waiter",
- "@com_github_google_go_cmp//cmp:go_default_library",
- ],
-)
-
-go_test(
- name = "stack_test",
- size = "small",
- srcs = [
- "forwarding_test.go",
- "neighbor_cache_test.go",
- "neighbor_entry_test.go",
- "nic_test.go",
- "packet_buffer_test.go",
- ],
- library = ":stack",
- deps = [
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/faketime",
- "//pkg/tcpip/header",
- "//pkg/tcpip/testutil",
- "@com_github_google_go_cmp//cmp:go_default_library",
- "@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
- ],
-)
diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go
deleted file mode 100644
index c55f85743..000000000
--- a/pkg/tcpip/stack/addressable_endpoint_state_test.go
+++ /dev/null
@@ -1,61 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package stack_test
-
-import (
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
-)
-
-// TestAddressableEndpointStateCleanup tests that cleaning up an addressable
-// endpoint state removes permanent addresses and leaves groups.
-func TestAddressableEndpointStateCleanup(t *testing.T) {
- var ep fakeNetworkEndpoint
- if err := ep.Enable(); err != nil {
- t.Fatalf("ep.Enable(): %s", err)
- }
-
- var s stack.AddressableEndpointState
- s.Init(&ep)
-
- addr := tcpip.AddressWithPrefix{
- Address: "\x01",
- PrefixLen: 8,
- }
-
- {
- ep, err := s.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{PEB: stack.NeverPrimaryEndpoint})
- if err != nil {
- t.Fatalf("s.AddAndAcquirePermanentAddress(%s, AddressProperties{PEB: NeverPrimaryEndpoint}): %s", addr, err)
- }
- // We don't need the address endpoint.
- ep.DecRef()
- }
- {
- ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint)
- if ep == nil {
- t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = nil, want = non-nil", addr.Address)
- }
- ep.DecRef()
- }
-
- s.Cleanup()
- if ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint); ep != nil {
- ep.DecRef()
- t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix())
- }
-}
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go
deleted file mode 100644
index c2f1f4798..000000000
--- a/pkg/tcpip/stack/forwarding_test.go
+++ /dev/null
@@ -1,804 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package stack
-
-import (
- "encoding/binary"
- "math"
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/header"
-)
-
-const (
- fwdTestNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32
- fwdTestNetHeaderLen = 12
- fwdTestNetDefaultPrefixLen = 8
-
- // fwdTestNetDefaultMTU is the MTU, in bytes, used throughout the tests,
- // except where another value is explicitly used. It is chosen to match
- // the MTU of loopback interfaces on linux systems.
- fwdTestNetDefaultMTU = 65536
-
- dstAddrOffset = 0
- srcAddrOffset = 1
- protocolNumberOffset = 2
-)
-
-var _ LinkAddressResolver = (*fwdTestNetworkEndpoint)(nil)
-var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil)
-
-// fwdTestNetworkEndpoint is a network-layer protocol endpoint.
-// Headers of this protocol are fwdTestNetHeaderLen bytes, but we currently only
-// use the first three: destination address, source address, and transport
-// protocol. They're all one byte fields to simplify parsing.
-type fwdTestNetworkEndpoint struct {
- AddressableEndpointState
-
- nic NetworkInterface
- proto *fwdTestNetworkProtocol
- dispatcher TransportDispatcher
-
- mu struct {
- sync.RWMutex
- forwarding bool
- }
-}
-
-func (*fwdTestNetworkEndpoint) Enable() tcpip.Error {
- return nil
-}
-
-func (*fwdTestNetworkEndpoint) Enabled() bool {
- return true
-}
-
-func (*fwdTestNetworkEndpoint) Disable() {}
-
-func (f *fwdTestNetworkEndpoint) MTU() uint32 {
- return f.nic.MTU() - uint32(f.MaxHeaderLength())
-}
-
-func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 {
- return 123
-}
-
-func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) {
- if _, _, ok := f.proto.Parse(pkt); !ok {
- return
- }
-
- netHdr := pkt.NetworkHeader().View()
- _, dst := f.proto.ParseAddresses(netHdr)
-
- addressEndpoint := f.AcquireAssignedAddress(dst, f.nic.Promiscuous(), CanBePrimaryEndpoint)
- if addressEndpoint != nil {
- addressEndpoint.DecRef()
- // Dispatch the packet to the transport protocol.
- f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(netHdr[protocolNumberOffset]), pkt)
- return
- }
-
- r, err := f.proto.stack.FindRoute(0, "", dst, fwdTestNetNumber, false /* multicastLoop */)
- if err != nil {
- return
- }
- defer r.Release()
-
- vv := buffer.NewVectorisedView(pkt.Size(), pkt.Views())
- pkt = NewPacketBuffer(PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength()),
- Data: vv.ToView().ToVectorisedView(),
- })
- // TODO(gvisor.dev/issue/1085) Decrease the TTL field in forwarded packets.
- _ = r.WriteHeaderIncludedPacket(pkt)
-}
-
-func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 {
- return f.nic.MaxHeaderLength() + fwdTestNetHeaderLen
-}
-
-func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
- return f.proto.Number()
-}
-
-func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error {
- // Add the protocol's header to the packet and send it to the link
- // endpoint.
- b := pkt.NetworkHeader().Push(fwdTestNetHeaderLen)
- b[dstAddrOffset] = r.RemoteAddress()[0]
- b[srcAddrOffset] = r.LocalAddress()[0]
- b[protocolNumberOffset] = byte(params.Protocol)
-
- return f.nic.WritePacket(r, fwdTestNetNumber, pkt)
-}
-
-// WritePackets implements LinkEndpoint.WritePackets.
-func (*fwdTestNetworkEndpoint) WritePackets(*Route, PacketBufferList, NetworkHeaderParams) (int, tcpip.Error) {
- panic("not implemented")
-}
-
-func (f *fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) tcpip.Error {
- // The network header should not already be populated.
- if _, ok := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen); !ok {
- return &tcpip.ErrMalformedHeader{}
- }
-
- return f.nic.WritePacket(r, fwdTestNetNumber, pkt)
-}
-
-func (f *fwdTestNetworkEndpoint) Close() {
- f.AddressableEndpointState.Cleanup()
-}
-
-// Stats implements stack.NetworkEndpoint.
-func (*fwdTestNetworkEndpoint) Stats() NetworkEndpointStats {
- return &fwdTestNetworkEndpointStats{}
-}
-
-var _ NetworkEndpointStats = (*fwdTestNetworkEndpointStats)(nil)
-
-type fwdTestNetworkEndpointStats struct{}
-
-// IsNetworkEndpointStats implements stack.NetworkEndpointStats.
-func (*fwdTestNetworkEndpointStats) IsNetworkEndpointStats() {}
-
-var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil)
-
-// fwdTestNetworkProtocol is a network-layer protocol that implements Address
-// resolution.
-type fwdTestNetworkProtocol struct {
- stack *Stack
-
- neigh *neighborCache
- addrResolveDelay time.Duration
- onLinkAddressResolved func(*neighborCache, tcpip.Address, tcpip.LinkAddress)
- onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool)
-}
-
-func (*fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
- return fwdTestNetNumber
-}
-
-func (*fwdTestNetworkProtocol) MinimumPacketSize() int {
- return fwdTestNetHeaderLen
-}
-
-func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
- return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1])
-}
-
-func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) {
- netHeader, ok := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen)
- if !ok {
- return 0, false, false
- }
- return tcpip.TransportProtocolNumber(netHeader[protocolNumberOffset]), true, true
-}
-
-func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, dispatcher TransportDispatcher) NetworkEndpoint {
- e := &fwdTestNetworkEndpoint{
- nic: nic,
- proto: f,
- dispatcher: dispatcher,
- }
- e.AddressableEndpointState.Init(e)
- return e
-}
-
-func (*fwdTestNetworkProtocol) SetOption(tcpip.SettableNetworkProtocolOption) tcpip.Error {
- return &tcpip.ErrUnknownProtocolOption{}
-}
-
-func (*fwdTestNetworkProtocol) Option(tcpip.GettableNetworkProtocolOption) tcpip.Error {
- return &tcpip.ErrUnknownProtocolOption{}
-}
-
-func (*fwdTestNetworkProtocol) Close() {}
-
-func (*fwdTestNetworkProtocol) Wait() {}
-
-func (f *fwdTestNetworkEndpoint) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error {
- if fn := f.proto.onLinkAddressResolved; fn != nil {
- f.proto.stack.clock.AfterFunc(f.proto.addrResolveDelay, func() {
- fn(f.proto.neigh, addr, remoteLinkAddr)
- })
- }
- return nil
-}
-
-func (f *fwdTestNetworkEndpoint) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
- if fn := f.proto.onResolveStaticAddress; fn != nil {
- return fn(addr)
- }
- return "", false
-}
-
-func (*fwdTestNetworkEndpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
- return fwdTestNetNumber
-}
-
-// Forwarding implements stack.ForwardingNetworkEndpoint.
-func (f *fwdTestNetworkEndpoint) Forwarding() bool {
- f.mu.RLock()
- defer f.mu.RUnlock()
- return f.mu.forwarding
-
-}
-
-// SetForwarding implements stack.ForwardingNetworkEndpoint.
-func (f *fwdTestNetworkEndpoint) SetForwarding(v bool) {
- f.mu.Lock()
- defer f.mu.Unlock()
- f.mu.forwarding = v
-}
-
-// fwdTestPacketInfo holds all the information about an outbound packet.
-type fwdTestPacketInfo struct {
- RemoteLinkAddress tcpip.LinkAddress
- LocalLinkAddress tcpip.LinkAddress
- Pkt *PacketBuffer
-}
-
-var _ LinkEndpoint = (*fwdTestLinkEndpoint)(nil)
-
-type fwdTestLinkEndpoint struct {
- dispatcher NetworkDispatcher
- mtu uint32
- linkAddr tcpip.LinkAddress
-
- // C is where outbound packets are queued.
- C chan fwdTestPacketInfo
-}
-
-// InjectInbound injects an inbound packet.
-func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
- e.InjectLinkAddr(protocol, "", pkt)
-}
-
-// InjectLinkAddr injects an inbound packet with a remote link address.
-func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt *PacketBuffer) {
- e.dispatcher.DeliverNetworkPacket(remote, "" /* local */, protocol, pkt)
-}
-
-// Attach saves the stack network-layer dispatcher for use later when packets
-// are injected.
-func (e *fwdTestLinkEndpoint) Attach(dispatcher NetworkDispatcher) {
- e.dispatcher = dispatcher
-}
-
-// IsAttached implements stack.LinkEndpoint.IsAttached.
-func (e *fwdTestLinkEndpoint) IsAttached() bool {
- return e.dispatcher != nil
-}
-
-// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized
-// during construction.
-func (e *fwdTestLinkEndpoint) MTU() uint32 {
- return e.mtu
-}
-
-// Capabilities implements stack.LinkEndpoint.Capabilities.
-func (e fwdTestLinkEndpoint) Capabilities() LinkEndpointCapabilities {
- caps := LinkEndpointCapabilities(0)
- return caps | CapabilityResolutionRequired
-}
-
-// MaxHeaderLength returns the maximum size of the link layer header. Given it
-// doesn't have a header, it just returns 0.
-func (*fwdTestLinkEndpoint) MaxHeaderLength() uint16 {
- return 0
-}
-
-// LinkAddress returns the link address of this endpoint.
-func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress {
- return e.linkAddr
-}
-
-func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error {
- p := fwdTestPacketInfo{
- RemoteLinkAddress: r.RemoteLinkAddress,
- LocalLinkAddress: r.LocalLinkAddress,
- Pkt: pkt,
- }
-
- select {
- case e.C <- p:
- default:
- }
-
- return nil
-}
-
-// WritePackets stores outbound packets into the channel.
-func (e *fwdTestLinkEndpoint) WritePackets(r RouteInfo, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
- n := 0
- for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- e.WritePacket(r, protocol, pkt)
- n++
- }
-
- return n, nil
-}
-
-func (*fwdTestLinkEndpoint) WriteRawPacket(*PacketBuffer) tcpip.Error {
- return &tcpip.ErrNotSupported{}
-}
-
-// Wait implements stack.LinkEndpoint.Wait.
-func (*fwdTestLinkEndpoint) Wait() {}
-
-// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
-func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType {
- panic("not implemented")
-}
-
-// AddHeader implements stack.LinkEndpoint.AddHeader.
-func (e *fwdTestLinkEndpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *PacketBuffer) {
- panic("not implemented")
-}
-
-func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (*faketime.ManualClock, *fwdTestLinkEndpoint, *fwdTestLinkEndpoint) {
- clock := faketime.NewManualClock()
- // Create a stack with the network protocol and two NICs.
- s := New(Options{
- NetworkProtocols: []NetworkProtocolFactory{func(s *Stack) NetworkProtocol {
- proto.stack = s
- return proto
- }},
- Clock: clock,
- })
-
- protoNum := proto.Number()
- if err := s.SetForwardingDefaultAndAllNICs(protoNum, true); err != nil {
- t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", protoNum, err)
- }
-
- // NIC 1 has the link address "a", and added the network address 1.
- ep1 := &fwdTestLinkEndpoint{
- C: make(chan fwdTestPacketInfo, 300),
- mtu: fwdTestNetDefaultMTU,
- linkAddr: "a",
- }
- if err := s.CreateNIC(1, ep1); err != nil {
- t.Fatal("CreateNIC #1 failed:", err)
- }
- protocolAddr1 := tcpip.ProtocolAddress{
- Protocol: fwdTestNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: "\x01",
- PrefixLen: fwdTestNetDefaultPrefixLen,
- },
- }
- if err := s.AddProtocolAddress(1, protocolAddr1, AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err)
- }
-
- // NIC 2 has the link address "b", and added the network address 2.
- ep2 := &fwdTestLinkEndpoint{
- C: make(chan fwdTestPacketInfo, 300),
- mtu: fwdTestNetDefaultMTU,
- linkAddr: "b",
- }
- if err := s.CreateNIC(2, ep2); err != nil {
- t.Fatal("CreateNIC #2 failed:", err)
- }
- protocolAddr2 := tcpip.ProtocolAddress{
- Protocol: fwdTestNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: "\x02",
- PrefixLen: fwdTestNetDefaultPrefixLen,
- },
- }
- if err := s.AddProtocolAddress(2, protocolAddr2, AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr2, err)
- }
-
- nic, ok := s.nics[2]
- if !ok {
- t.Fatal("NIC 2 does not exist")
- }
-
- if l, ok := nic.linkAddrResolvers[fwdTestNetNumber]; ok {
- proto.neigh = &l.neigh
- }
-
- // Route all packets to NIC 2.
- {
- subnet, err := tcpip.NewSubnet("\x00", "\x00")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: 2}})
- }
-
- return clock, ep1, ep2
-}
-
-func TestForwardingWithStaticResolver(t *testing.T) {
- // Create a network protocol with a static resolver.
- proto := &fwdTestNetworkProtocol{
- onResolveStaticAddress:
- // The network address 3 is resolved to the link address "c".
- func(addr tcpip.Address) (tcpip.LinkAddress, bool) {
- if addr == "\x03" {
- return "c", true
- }
- return "", false
- },
- }
-
- clock, ep1, ep2 := fwdTestNetFactory(t, proto)
-
- // Inject an inbound packet to address 3 on NIC 1, and see if it is
- // forwarded to NIC 2.
- buf := buffer.NewView(30)
- buf[dstAddrOffset] = 3
- ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- var p fwdTestPacketInfo
-
- clock.Advance(proto.addrResolveDelay)
- select {
- case p = <-ep2.C:
- default:
- t.Fatal("packet not forwarded")
- }
-
- // Test that the static address resolution happened correctly.
- if p.RemoteLinkAddress != "c" {
- t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
- }
- if p.LocalLinkAddress != "b" {
- t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
- }
-}
-
-func TestForwardingWithFakeResolver(t *testing.T) {
- proto := fwdTestNetworkProtocol{
- addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
- t.Helper()
- if len(linkAddr) != 0 {
- t.Fatalf("got linkAddr=%q, want unspecified", linkAddr)
- }
- // Any address will be resolved to the link address "c".
- neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- },
- }
- clock, ep1, ep2 := fwdTestNetFactory(t, &proto)
-
- // Inject an inbound packet to address 3 on NIC 1, and see if it is
- // forwarded to NIC 2.
- buf := buffer.NewView(30)
- buf[dstAddrOffset] = 3
- ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- var p fwdTestPacketInfo
-
- clock.Advance(proto.addrResolveDelay)
- select {
- case p = <-ep2.C:
- default:
- t.Fatal("packet not forwarded")
- }
-
- // Test that the address resolution happened correctly.
- if p.RemoteLinkAddress != "c" {
- t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
- }
- if p.LocalLinkAddress != "b" {
- t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
- }
-}
-
-func TestForwardingWithNoResolver(t *testing.T) {
- // Create a network protocol without a resolver.
- proto := &fwdTestNetworkProtocol{}
-
- // Whether or not we use the neighbor cache here does not matter since
- // neither linkAddrCache nor neighborCache will be used.
- clock, ep1, ep2 := fwdTestNetFactory(t, proto)
-
- // inject an inbound packet to address 3 on NIC 1, and see if it is
- // forwarded to NIC 2.
- buf := buffer.NewView(30)
- buf[dstAddrOffset] = 3
- ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- clock.Advance(proto.addrResolveDelay)
- select {
- case <-ep2.C:
- t.Fatal("Packet should not be forwarded")
- default:
- }
-}
-
-func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) {
- proto := &fwdTestNetworkProtocol{
- addrResolveDelay: 50 * time.Millisecond,
- onLinkAddressResolved: func(*neighborCache, tcpip.Address, tcpip.LinkAddress) {
- // Don't resolve the link address.
- },
- }
-
- clock, ep1, ep2 := fwdTestNetFactory(t, proto)
-
- const numPackets int = 5
- // These packets will all be enqueued in the packet queue to wait for link
- // address resolution.
- for i := 0; i < numPackets; i++ {
- buf := buffer.NewView(30)
- buf[dstAddrOffset] = 3
- ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
- }
-
- // All packets should fail resolution.
- for i := 0; i < numPackets; i++ {
- clock.Advance(proto.addrResolveDelay)
- select {
- case got := <-ep2.C:
- t.Fatalf("got %#v; packets should have failed resolution and not been forwarded", got)
- default:
- }
- }
-}
-
-func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
- proto := fwdTestNetworkProtocol{
- addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
- t.Helper()
- if len(linkAddr) != 0 {
- t.Fatalf("got linkAddr=%q, want unspecified", linkAddr)
- }
- // Only packets to address 3 will be resolved to the
- // link address "c".
- if addr == "\x03" {
- neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- }
- },
- }
- clock, ep1, ep2 := fwdTestNetFactory(t, &proto)
-
- // Inject an inbound packet to address 4 on NIC 1. This packet should
- // not be forwarded.
- buf := buffer.NewView(30)
- buf[dstAddrOffset] = 4
- ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- // Inject an inbound packet to address 3 on NIC 1, and see if it is
- // forwarded to NIC 2.
- buf = buffer.NewView(30)
- buf[dstAddrOffset] = 3
- ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- var p fwdTestPacketInfo
-
- clock.Advance(proto.addrResolveDelay)
- select {
- case p = <-ep2.C:
- default:
- t.Fatal("packet not forwarded")
- }
-
- if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 {
- t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset])
- }
-
- // Test that the address resolution happened correctly.
- if p.RemoteLinkAddress != "c" {
- t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
- }
- if p.LocalLinkAddress != "b" {
- t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
- }
-}
-
-func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
- proto := fwdTestNetworkProtocol{
- addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
- t.Helper()
- if len(linkAddr) != 0 {
- t.Fatalf("got linkAddr=%q, want unspecified", linkAddr)
- }
- // Any packets will be resolved to the link address "c".
- neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- },
- }
- clock, ep1, ep2 := fwdTestNetFactory(t, &proto)
-
- // Inject two inbound packets to address 3 on NIC 1.
- for i := 0; i < 2; i++ {
- buf := buffer.NewView(30)
- buf[dstAddrOffset] = 3
- ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
- }
-
- for i := 0; i < 2; i++ {
- var p fwdTestPacketInfo
-
- clock.Advance(proto.addrResolveDelay)
- select {
- case p = <-ep2.C:
- default:
- t.Fatal("packet not forwarded")
- }
-
- if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 {
- t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset])
- }
-
- // Test that the address resolution happened correctly.
- if p.RemoteLinkAddress != "c" {
- t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
- }
- if p.LocalLinkAddress != "b" {
- t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
- }
- }
-}
-
-func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
- proto := fwdTestNetworkProtocol{
- addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
- t.Helper()
- if len(linkAddr) != 0 {
- t.Fatalf("got linkAddr=%q, want unspecified", linkAddr)
- }
- // Any packets will be resolved to the link address "c".
- neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- },
- }
- clock, ep1, ep2 := fwdTestNetFactory(t, &proto)
-
- for i := 0; i < maxPendingPacketsPerResolution+5; i++ {
- // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1.
- buf := buffer.NewView(30)
- buf[dstAddrOffset] = 3
- // Set the packet sequence number.
- binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i))
- ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
- }
-
- for i := 0; i < maxPendingPacketsPerResolution; i++ {
- var p fwdTestPacketInfo
-
- clock.Advance(proto.addrResolveDelay)
- select {
- case p = <-ep2.C:
- default:
- t.Fatal("packet not forwarded")
- }
-
- b := PayloadSince(p.Pkt.NetworkHeader())
- if b[dstAddrOffset] != 3 {
- t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset])
- }
- if len(b) < fwdTestNetHeaderLen+2 {
- t.Fatalf("packet is too short to hold a sequence number: len(b) = %d", b)
- }
- seqNumBuf := b[fwdTestNetHeaderLen:]
-
- // The first 5 packets should not be forwarded so the sequence number should
- // start with 5.
- want := uint16(i + 5)
- if n := binary.BigEndian.Uint16(seqNumBuf); n != want {
- t.Fatalf("got the packet #%d, want = #%d", n, want)
- }
-
- // Test that the address resolution happened correctly.
- if p.RemoteLinkAddress != "c" {
- t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
- }
- if p.LocalLinkAddress != "b" {
- t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
- }
- }
-}
-
-func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
- proto := fwdTestNetworkProtocol{
- addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
- t.Helper()
- if len(linkAddr) != 0 {
- t.Fatalf("got linkAddr=%q, want unspecified", linkAddr)
- }
- // Any packets will be resolved to the link address "c".
- neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- },
- }
- clock, ep1, ep2 := fwdTestNetFactory(t, &proto)
-
- for i := 0; i < maxPendingResolutions+5; i++ {
- // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1.
- // Each packet has a different destination address (3 to
- // maxPendingResolutions + 7).
- buf := buffer.NewView(30)
- buf[dstAddrOffset] = byte(3 + i)
- ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
- }
-
- for i := 0; i < maxPendingResolutions; i++ {
- var p fwdTestPacketInfo
-
- clock.Advance(proto.addrResolveDelay)
- select {
- case p = <-ep2.C:
- default:
- t.Fatal("packet not forwarded")
- }
-
- // The first 5 packets (address 3 to 7) should not be forwarded
- // because their address resolutions are interrupted.
- if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] < 8 {
- t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", nh[dstAddrOffset])
- }
-
- // Test that the address resolution happened correctly.
- if p.RemoteLinkAddress != "c" {
- t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
- }
- if p.LocalLinkAddress != "b" {
- t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
- }
- }
-}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
deleted file mode 100644
index 40b33b6b5..000000000
--- a/pkg/tcpip/stack/ndp_test.go
+++ /dev/null
@@ -1,5614 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package stack_test
-
-import (
- "bytes"
- "encoding/binary"
- "fmt"
- "math/rand"
- "testing"
- "time"
-
- "github.com/google/go-cmp/cmp"
- cryptorand "gvisor.dev/gvisor/pkg/rand"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-var (
- addr1 = testutil.MustParse6("a00::1")
- addr2 = testutil.MustParse6("a00::2")
- addr3 = testutil.MustParse6("a00::3")
-)
-
-const (
- linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
- linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07")
- linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08")
- linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09")
-
- defaultPrefixLen = 128
-)
-
-var (
- llAddr1 = header.LinkLocalAddr(linkAddr1)
- llAddr2 = header.LinkLocalAddr(linkAddr2)
- llAddr3 = header.LinkLocalAddr(linkAddr3)
- llAddr4 = header.LinkLocalAddr(linkAddr4)
- dstAddr = tcpip.FullAddress{
- Addr: "\x0a\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
- Port: 25,
- }
-)
-
-func addrForSubnet(subnet tcpip.Subnet, linkAddr tcpip.LinkAddress) tcpip.AddressWithPrefix {
- if !header.IsValidUnicastEthernetAddress(linkAddr) {
- return tcpip.AddressWithPrefix{}
- }
-
- addrBytes := []byte(subnet.ID())
- header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:])
- return tcpip.AddressWithPrefix{
- Address: tcpip.Address(addrBytes),
- PrefixLen: 64,
- }
-}
-
-// prefixSubnetAddr returns a prefix (Address + Length), the prefix's equivalent
-// tcpip.Subnet, and an address where the lower half of the address is composed
-// of the EUI-64 of linkAddr if it is a valid unicast ethernet address.
-func prefixSubnetAddr(offset uint8, linkAddr tcpip.LinkAddress) (tcpip.AddressWithPrefix, tcpip.Subnet, tcpip.AddressWithPrefix) {
- prefixBytes := []byte{1, 2, 3, 4, 5, 6, 7, 8 + offset, 0, 0, 0, 0, 0, 0, 0, 0}
- prefix := tcpip.AddressWithPrefix{
- Address: tcpip.Address(prefixBytes),
- PrefixLen: 64,
- }
-
- subnet := prefix.Subnet()
-
- return prefix, subnet, addrForSubnet(subnet, linkAddr)
-}
-
-// ndpDADEvent is a set of parameters that was passed to
-// ndpDispatcher.OnDuplicateAddressDetectionResult.
-type ndpDADEvent struct {
- nicID tcpip.NICID
- addr tcpip.Address
- res stack.DADResult
-}
-
-type ndpOffLinkRouteEvent struct {
- nicID tcpip.NICID
- subnet tcpip.Subnet
- router tcpip.Address
- prf header.NDPRoutePreference
- // true if route was updated, false if invalidated.
- updated bool
-}
-
-type ndpPrefixEvent struct {
- nicID tcpip.NICID
- prefix tcpip.Subnet
- // true if prefix was discovered, false if invalidated.
- discovered bool
-}
-
-type ndpAutoGenAddrEventType int
-
-const (
- newAddr ndpAutoGenAddrEventType = iota
- deprecatedAddr
- invalidatedAddr
-)
-
-type ndpAutoGenAddrEvent struct {
- nicID tcpip.NICID
- addr tcpip.AddressWithPrefix
- eventType ndpAutoGenAddrEventType
-}
-
-func (e ndpAutoGenAddrEvent) String() string {
- return fmt.Sprintf("%T{nicID=%d addr=%s eventType=%d}", e, e.nicID, e.addr, e.eventType)
-}
-
-type ndpRDNSS struct {
- addrs []tcpip.Address
- lifetime time.Duration
-}
-
-type ndpRDNSSEvent struct {
- nicID tcpip.NICID
- rdnss ndpRDNSS
-}
-
-type ndpDNSSLEvent struct {
- nicID tcpip.NICID
- domainNames []string
- lifetime time.Duration
-}
-
-type ndpDHCPv6Event struct {
- nicID tcpip.NICID
- configuration ipv6.DHCPv6ConfigurationFromNDPRA
-}
-
-var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil)
-
-// ndpDispatcher implements NDPDispatcher so tests can know when various NDP
-// related events happen for test purposes.
-type ndpDispatcher struct {
- dadC chan ndpDADEvent
- offLinkRouteC chan ndpOffLinkRouteEvent
- prefixC chan ndpPrefixEvent
- autoGenAddrC chan ndpAutoGenAddrEvent
- rdnssC chan ndpRDNSSEvent
- dnsslC chan ndpDNSSLEvent
- routeTable []tcpip.Route
- dhcpv6ConfigurationC chan ndpDHCPv6Event
-}
-
-// Implements ipv6.NDPDispatcher.OnDuplicateAddressDetectionResult.
-func (n *ndpDispatcher) OnDuplicateAddressDetectionResult(nicID tcpip.NICID, addr tcpip.Address, res stack.DADResult) {
- if n.dadC != nil {
- n.dadC <- ndpDADEvent{
- nicID,
- addr,
- res,
- }
- }
-}
-
-// Implements ipv6.NDPDispatcher.OnOffLinkRouteUpdated.
-func (n *ndpDispatcher) OnOffLinkRouteUpdated(nicID tcpip.NICID, subnet tcpip.Subnet, router tcpip.Address, prf header.NDPRoutePreference) {
- if c := n.offLinkRouteC; c != nil {
- c <- ndpOffLinkRouteEvent{
- nicID,
- subnet,
- router,
- prf,
- true,
- }
- }
-}
-
-// Implements ipv6.NDPDispatcher.OnOffLinkRouteInvalidated.
-func (n *ndpDispatcher) OnOffLinkRouteInvalidated(nicID tcpip.NICID, subnet tcpip.Subnet, router tcpip.Address) {
- if c := n.offLinkRouteC; c != nil {
- var prf header.NDPRoutePreference
- c <- ndpOffLinkRouteEvent{
- nicID,
- subnet,
- router,
- prf,
- false,
- }
- }
-}
-
-// Implements ipv6.NDPDispatcher.OnOnLinkPrefixDiscovered.
-func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) {
- if c := n.prefixC; c != nil {
- c <- ndpPrefixEvent{
- nicID,
- prefix,
- true,
- }
- }
-}
-
-// Implements ipv6.NDPDispatcher.OnOnLinkPrefixInvalidated.
-func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) {
- if c := n.prefixC; c != nil {
- c <- ndpPrefixEvent{
- nicID,
- prefix,
- false,
- }
- }
-}
-
-func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) {
- if c := n.autoGenAddrC; c != nil {
- c <- ndpAutoGenAddrEvent{
- nicID,
- addr,
- newAddr,
- }
- }
-}
-
-func (n *ndpDispatcher) OnAutoGenAddressDeprecated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) {
- if c := n.autoGenAddrC; c != nil {
- c <- ndpAutoGenAddrEvent{
- nicID,
- addr,
- deprecatedAddr,
- }
- }
-}
-
-func (n *ndpDispatcher) OnAutoGenAddressInvalidated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) {
- if c := n.autoGenAddrC; c != nil {
- c <- ndpAutoGenAddrEvent{
- nicID,
- addr,
- invalidatedAddr,
- }
- }
-}
-
-// Implements ipv6.NDPDispatcher.OnRecursiveDNSServerOption.
-func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) {
- if c := n.rdnssC; c != nil {
- c <- ndpRDNSSEvent{
- nicID,
- ndpRDNSS{
- addrs,
- lifetime,
- },
- }
- }
-}
-
-// Implements ipv6.NDPDispatcher.OnDNSSearchListOption.
-func (n *ndpDispatcher) OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration) {
- if n.dnsslC != nil {
- n.dnsslC <- ndpDNSSLEvent{
- nicID,
- domainNames,
- lifetime,
- }
- }
-}
-
-// Implements ipv6.NDPDispatcher.OnDHCPv6Configuration.
-func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration ipv6.DHCPv6ConfigurationFromNDPRA) {
- if c := n.dhcpv6ConfigurationC; c != nil {
- c <- ndpDHCPv6Event{
- nicID,
- configuration,
- }
- }
-}
-
-// channelLinkWithHeaderLength is a channel.Endpoint with a configurable
-// header length.
-type channelLinkWithHeaderLength struct {
- *channel.Endpoint
- headerLength uint16
-}
-
-func (l *channelLinkWithHeaderLength) MaxHeaderLength() uint16 {
- return l.headerLength
-}
-
-// Check e to make sure that the event is for addr on nic with ID 1, and the
-// resolved flag set to resolved with the specified err.
-func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, res stack.DADResult) string {
- return cmp.Diff(ndpDADEvent{nicID: nicID, addr: addr, res: res}, e, cmp.AllowUnexported(e))
-}
-
-// TestDADDisabled tests that an address successfully resolves immediately
-// when DAD is not enabled (the default for an empty stack.Options).
-func TestDADDisabled(t *testing.T) {
- const nicID = 1
- ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPDisp: &ndpDisp,
- })},
- })
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- addrWithPrefix := tcpip.AddressWithPrefix{
- Address: addr1,
- PrefixLen: defaultPrefixLen,
- }
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: addrWithPrefix,
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, protocolAddr, err)
- }
-
- // Should get the address immediately since we should not have performed
- // DAD on it.
- select {
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr1, &stack.DADSucceeded{}); diff != "" {
- t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected DAD event")
- }
- if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil {
- t.Fatal(err)
- }
-
- // We should not have sent any NDP NS messages.
- if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got != 0 {
- t.Fatalf("got NeighborSolicit = %d, want = 0", got)
- }
-}
-
-func TestDADResolveLoopback(t *testing.T) {
- const nicID = 1
- ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent, 1),
- }
-
- dadConfigs := stack.DADConfigurations{
- RetransmitTimer: time.Second,
- DupAddrDetectTransmits: 1,
- }
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- Clock: clock,
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPDisp: &ndpDisp,
- DADConfigs: dadConfigs,
- })},
- })
- if err := s.CreateNIC(nicID, loopback.New()); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: addr1,
- PrefixLen: defaultPrefixLen,
- },
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, protocolAddr, err)
- }
-
- // Address should not be considered bound to the NIC yet (DAD ongoing).
- if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
- t.Fatal(err)
- }
-
- // DAD should not resolve after the normal resolution time since our DAD
- // message was looped back - we should extend our DAD process.
- dadResolutionTime := time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer
- clock.Advance(dadResolutionTime)
- if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
- t.Error(err)
- }
-
- // Make sure the address does not resolve before the extended resolution time
- // has passed.
- const delta = time.Nanosecond
- // DAD will send extra NS probes if an NS message is looped back.
- const extraTransmits = 3
- clock.Advance(dadResolutionTime*extraTransmits - delta)
- if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
- t.Error(err)
- }
-
- // DAD should now resolve.
- clock.Advance(delta)
- if diff := checkDADEvent(<-ndpDisp.dadC, nicID, addr1, &stack.DADSucceeded{}); diff != "" {
- t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
- }
-}
-
-// TestDADResolve tests that an address successfully resolves after performing
-// DAD for various values of DupAddrDetectTransmits and RetransmitTimer.
-// Included in the subtests is a test to make sure that an invalid
-// RetransmitTimer (<1ms) values get fixed to the default RetransmitTimer of 1s.
-// This tests also validates the NDP NS packet that is transmitted.
-func TestDADResolve(t *testing.T) {
- const nicID = 1
-
- tests := []struct {
- name string
- linkHeaderLen uint16
- dupAddrDetectTransmits uint8
- retransTimer time.Duration
- expectedRetransmitTimer time.Duration
- }{
- {
- name: "1:1s:1s",
- dupAddrDetectTransmits: 1,
- retransTimer: time.Second,
- expectedRetransmitTimer: time.Second,
- },
- {
- name: "2:1s:1s",
- linkHeaderLen: 1,
- dupAddrDetectTransmits: 2,
- retransTimer: time.Second,
- expectedRetransmitTimer: time.Second,
- },
- {
- name: "1:2s:2s",
- linkHeaderLen: 2,
- dupAddrDetectTransmits: 1,
- retransTimer: 2 * time.Second,
- expectedRetransmitTimer: 2 * time.Second,
- },
- // 0s is an invalid RetransmitTimer timer and will be fixed to
- // the default RetransmitTimer value of 1s.
- {
- name: "1:0s:1s",
- linkHeaderLen: 3,
- dupAddrDetectTransmits: 1,
- retransTimer: 0,
- expectedRetransmitTimer: time.Second,
- },
- }
-
- nonces := [][]byte{
- {1, 2, 3, 4, 5, 6},
- {7, 8, 9, 10, 11, 12},
- }
-
- var secureRNGBytes []byte
- for _, n := range nonces {
- secureRNGBytes = append(secureRNGBytes, n...)
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent, 1),
- }
-
- e := channelLinkWithHeaderLength{
- Endpoint: channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1),
- headerLength: test.linkHeaderLen,
- }
- e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
-
- var secureRNG bytes.Reader
- secureRNG.Reset(secureRNGBytes)
-
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- Clock: clock,
- RandSource: rand.NewSource(time.Now().UnixNano()),
- SecureRNG: &secureRNG,
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPDisp: &ndpDisp,
- DADConfigs: stack.DADConfigurations{
- RetransmitTimer: test.retransTimer,
- DupAddrDetectTransmits: test.dupAddrDetectTransmits,
- },
- })},
- })
- if err := s.CreateNIC(nicID, &e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- // We add a default route so the call to FindRoute below will succeed
- // once we have an assigned address.
- s.SetRouteTable([]tcpip.Route{{
- Destination: header.IPv6EmptySubnet,
- Gateway: addr3,
- NIC: nicID,
- }})
-
- addrWithPrefix := tcpip.AddressWithPrefix{
- Address: addr1,
- PrefixLen: defaultPrefixLen,
- }
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: addrWithPrefix,
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, protocolAddr, err)
- }
-
- // Make sure the address does not resolve before the resolution time has
- // passed.
- const delta = time.Nanosecond
- clock.Advance(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - delta)
- if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
- t.Error(err)
- }
- // Should not get a route even if we specify the local address as the
- // tentative address.
- {
- r, err := s.FindRoute(nicID, "", addr2, header.IPv6ProtocolNumber, false)
- if _, ok := err.(*tcpip.ErrNoRoute); !ok {
- t.Errorf("got FindRoute(%d, '', %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr2, header.IPv6ProtocolNumber, r, err, &tcpip.ErrNoRoute{})
- }
- if r != nil {
- r.Release()
- }
- }
- {
- r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false)
- if _, ok := err.(*tcpip.ErrNoRoute); !ok {
- t.Errorf("got FindRoute(%d, %s, %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr1, addr2, header.IPv6ProtocolNumber, r, err, &tcpip.ErrNoRoute{})
- }
- if r != nil {
- r.Release()
- }
- }
-
- if t.Failed() {
- t.FailNow()
- }
-
- // Wait for DAD to resolve.
- clock.Advance(delta)
- select {
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr1, &stack.DADSucceeded{}); diff != "" {
- t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatalf("expected DAD event for %s on NIC(%d)", addr1, nicID)
- }
- if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil {
- t.Error(err)
- }
- // Should get a route using the address now that it is resolved.
- {
- r, err := s.FindRoute(nicID, "", addr2, header.IPv6ProtocolNumber, false)
- if err != nil {
- t.Errorf("got FindRoute(%d, '', %s, %d, false): %s", nicID, addr2, header.IPv6ProtocolNumber, err)
- } else if r.LocalAddress() != addr1 {
- t.Errorf("got r.LocalAddress() = %s, want = %s", r.LocalAddress(), addr1)
- }
- r.Release()
- }
- {
- r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false)
- if err != nil {
- t.Errorf("got FindRoute(%d, %s, %s, %d, false): %s", nicID, addr1, addr2, header.IPv6ProtocolNumber, err)
- } else if r.LocalAddress() != addr1 {
- t.Errorf("got r.LocalAddress() = %s, want = %s", r.LocalAddress(), addr1)
- }
- if r != nil {
- r.Release()
- }
- }
-
- if t.Failed() {
- t.FailNow()
- }
-
- // Should not have sent any more NS messages.
- if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got != uint64(test.dupAddrDetectTransmits) {
- t.Fatalf("got NeighborSolicit = %d, want = %d", got, test.dupAddrDetectTransmits)
- }
-
- // Validate the sent Neighbor Solicitation messages.
- for i := uint8(0); i < test.dupAddrDetectTransmits; i++ {
- p, ok := e.Read()
- if !ok {
- t.Fatal("packet didn't arrive")
- }
-
- // Make sure its an IPv6 packet.
- if p.Proto != header.IPv6ProtocolNumber {
- t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
- }
-
- // Make sure the right remote link address is used.
- snmc := header.SolicitedNodeAddr(addr1)
- if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want {
- t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
- }
-
- // Check NDP NS packet.
- //
- // As per RFC 4861 section 4.3, a possible option is the Source Link
- // Layer option, but this option MUST NOT be included when the source
- // address of the packet is the unspecified address.
- checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
- checker.SrcAddr(header.IPv6Any),
- checker.DstAddr(snmc),
- checker.TTL(header.NDPHopLimit),
- checker.NDPNS(
- checker.NDPNSTargetAddress(addr1),
- checker.NDPNSOptions([]header.NDPOption{header.NDPNonceOption(nonces[i])}),
- ))
-
- if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want {
- t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want)
- }
- }
- })
- }
-}
-
-func rxNDPSolicit(e *channel.Endpoint, tgt tcpip.Address) {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize))
- pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.MessageBody())
- ns.SetTargetAddress(tgt)
- snmc := header.SolicitedNodeAddr(tgt)
- pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: pkt,
- Src: header.IPv6Any,
- Dst: snmc,
- }))
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- TransportProtocol: icmp.ProtocolNumber6,
- HopLimit: 255,
- SrcAddr: header.IPv6Any,
- DstAddr: snmc,
- })
- e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()}))
-}
-
-// TestDADFail tests to make sure that the DAD process fails if another node is
-// detected to be performing DAD on the same address (receive an NS message from
-// a node doing DAD for the same address), or if another node is detected to own
-// the address already (receive an NA message for the tentative address).
-func TestDADFail(t *testing.T) {
- const nicID = 1
-
- tests := []struct {
- name string
- rxPkt func(e *channel.Endpoint, tgt tcpip.Address)
- getStat func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
- expectedHolderLinkAddress tcpip.LinkAddress
- }{
- {
- name: "RxSolicit",
- rxPkt: rxNDPSolicit,
- getStat: func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return s.NeighborSolicit
- },
- expectedHolderLinkAddress: "",
- },
- {
- name: "RxAdvert",
- rxPkt: func(e *channel.Endpoint, tgt tcpip.Address) {
- naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize)
- pkt := header.ICMPv6(hdr.Prepend(naSize))
- pkt.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(pkt.MessageBody())
- na.SetSolicitedFlag(true)
- na.SetOverrideFlag(true)
- na.SetTargetAddress(tgt)
- na.Options().Serialize(header.NDPOptionsSerializer{
- header.NDPTargetLinkLayerAddressOption(linkAddr1),
- })
- pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: pkt,
- Src: tgt,
- Dst: header.IPv6AllNodesMulticastAddress,
- }))
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- TransportProtocol: icmp.ProtocolNumber6,
- HopLimit: 255,
- SrcAddr: tgt,
- DstAddr: header.IPv6AllNodesMulticastAddress,
- })
- e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()}))
- },
- getStat: func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return s.NeighborAdvert
- },
- expectedHolderLinkAddress: linkAddr1,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent, 1),
- }
- dadConfigs := stack.DefaultDADConfigurations()
- dadConfigs.RetransmitTimer = time.Second * 2
-
- e := channel.New(0, 1280, linkAddr1)
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPDisp: &ndpDisp,
- DADConfigs: dadConfigs,
- })},
- Clock: clock,
- })
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: addr1.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
-
- // Address should not be considered bound to the NIC yet
- // (DAD ongoing).
- if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
- t.Fatal(err)
- }
-
- // Receive a packet to simulate an address conflict.
- test.rxPkt(e, addr1)
-
- stat := test.getStat(s.Stats().ICMP.V6.PacketsReceived)
- if got := stat.Value(); got != 1 {
- t.Fatalf("got stat = %d, want = 1", got)
- }
-
- // Wait for DAD to fail and make sure the address did
- // not get resolved.
- clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer)
- select {
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr1, &stack.DADDupAddrDetected{HolderLinkAddress: test.expectedHolderLinkAddress}); diff != "" {
- t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
- }
- default:
- // If we don't get a failure event after the
- // expected resolution time + extra 1s buffer,
- // something is wrong.
- t.Fatal("timed out waiting for DAD failure")
- }
- if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
- t.Fatal(err)
- }
-
- // Attempting to add the address again should not fail if the address's
- // state was cleaned up when DAD failed.
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
- })
- }
-}
-
-func TestDADStop(t *testing.T) {
- const nicID = 1
-
- tests := []struct {
- name string
- stopFn func(t *testing.T, s *stack.Stack)
- skipFinalAddrCheck bool
- }{
- // Tests to make sure that DAD stops when an address is removed.
- {
- name: "Remove address",
- stopFn: func(t *testing.T, s *stack.Stack) {
- if err := s.RemoveAddress(nicID, addr1); err != nil {
- t.Fatalf("RemoveAddress(%d, %s): %s", nicID, addr1, err)
- }
- },
- },
-
- // Tests to make sure that DAD stops when the NIC is disabled.
- {
- name: "Disable NIC",
- stopFn: func(t *testing.T, s *stack.Stack) {
- if err := s.DisableNIC(nicID); err != nil {
- t.Fatalf("DisableNIC(%d): %s", nicID, err)
- }
- },
- },
-
- // Tests to make sure that DAD stops when the NIC is removed.
- {
- name: "Remove NIC",
- stopFn: func(t *testing.T, s *stack.Stack) {
- if err := s.RemoveNIC(nicID); err != nil {
- t.Fatalf("RemoveNIC(%d): %s", nicID, err)
- }
- },
- // The NIC is removed so we can't check its addresses after calling
- // stopFn.
- skipFinalAddrCheck: true,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent, 1),
- }
-
- dadConfigs := stack.DADConfigurations{
- RetransmitTimer: time.Second,
- DupAddrDetectTransmits: 2,
- }
-
- e := channel.New(0, 1280, linkAddr1)
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPDisp: &ndpDisp,
- DADConfigs: dadConfigs,
- })},
- Clock: clock,
- })
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
-
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: addr1.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
-
- // Address should not be considered bound to the NIC yet (DAD ongoing).
- if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
- t.Fatal(err)
- }
-
- test.stopFn(t, s)
-
- // Wait for DAD to fail (since the address was removed during DAD).
- clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer)
- select {
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr1, &stack.DADAborted{}); diff != "" {
- t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
- }
- default:
- // If we don't get a failure event after the expected resolution
- // time + extra 1s buffer, something is wrong.
- t.Fatal("timed out waiting for DAD failure")
- }
-
- if !test.skipFinalAddrCheck {
- if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
- t.Fatal(err)
- }
- }
-
- // Should not have sent more than 1 NS message.
- if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got > 1 {
- t.Errorf("got NeighborSolicit = %d, want <= 1", got)
- }
- })
- }
-}
-
-// TestSetNDPConfigurations tests that we can update and use per-interface NDP
-// configurations without affecting the default NDP configurations or other
-// interfaces' configurations.
-func TestSetNDPConfigurations(t *testing.T) {
- const nicID1 = 1
- const nicID2 = 2
- const nicID3 = 3
-
- tests := []struct {
- name string
- dupAddrDetectTransmits uint8
- retransmitTimer time.Duration
- expectedRetransmitTimer time.Duration
- }{
- {
- "OK",
- 1,
- time.Second,
- time.Second,
- },
- {
- "Invalid Retransmit Timer",
- 1,
- 0,
- time.Second,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPDisp: &ndpDisp,
- })},
- Clock: clock,
- })
-
- expectDADEvent := func(nicID tcpip.NICID, addr tcpip.Address) {
- select {
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr, &stack.DADSucceeded{}); diff != "" {
- t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatalf("expected DAD event for %s", addr)
- }
- }
-
- // This NIC(1)'s NDP configurations will be updated to
- // be different from the default.
- if err := s.CreateNIC(nicID1, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err)
- }
-
- // Created before updating NIC(1)'s NDP configurations
- // but updating NIC(1)'s NDP configurations should not
- // affect other existing NICs.
- if err := s.CreateNIC(nicID2, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err)
- }
-
- // Update the configurations on NIC(1) to use DAD.
- if ipv6Ep, err := s.GetNetworkEndpoint(nicID1, header.IPv6ProtocolNumber); err != nil {
- t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, header.IPv6ProtocolNumber, err)
- } else {
- dad := ipv6Ep.(stack.DuplicateAddressDetector)
- dad.SetDADConfigurations(stack.DADConfigurations{
- DupAddrDetectTransmits: test.dupAddrDetectTransmits,
- RetransmitTimer: test.retransmitTimer,
- })
- }
-
- // Created after updating NIC(1)'s NDP configurations
- // but the stack's default NDP configurations should not
- // have been updated.
- if err := s.CreateNIC(nicID3, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID3, err)
- }
-
- // Add addresses for each NIC.
- addrWithPrefix1 := tcpip.AddressWithPrefix{Address: addr1, PrefixLen: defaultPrefixLen}
- protocolAddr1 := tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: addrWithPrefix1,
- }
- if err := s.AddProtocolAddress(nicID1, protocolAddr1, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID1, protocolAddr1, err)
- }
- addrWithPrefix2 := tcpip.AddressWithPrefix{Address: addr2, PrefixLen: defaultPrefixLen}
- protocolAddr2 := tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: addrWithPrefix2,
- }
- if err := s.AddProtocolAddress(nicID2, protocolAddr2, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID2, protocolAddr2, err)
- }
- expectDADEvent(nicID2, addr2)
- addrWithPrefix3 := tcpip.AddressWithPrefix{Address: addr3, PrefixLen: defaultPrefixLen}
- protocolAddr3 := tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: addrWithPrefix3,
- }
- if err := s.AddProtocolAddress(nicID3, protocolAddr3, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID3, protocolAddr3, err)
- }
- expectDADEvent(nicID3, addr3)
-
- // Address should not be considered bound to NIC(1) yet
- // (DAD ongoing).
- if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
- t.Fatal(err)
- }
-
- // Should get the address on NIC(2) and NIC(3)
- // immediately since we should not have performed DAD on
- // it as the stack was configured to not do DAD by
- // default and we only updated the NDP configurations on
- // NIC(1).
- if err := checkGetMainNICAddress(s, nicID2, header.IPv6ProtocolNumber, addrWithPrefix2); err != nil {
- t.Fatal(err)
- }
- if err := checkGetMainNICAddress(s, nicID3, header.IPv6ProtocolNumber, addrWithPrefix3); err != nil {
- t.Fatal(err)
- }
-
- // Sleep until right before resolution to make sure the address didn't
- // resolve on NIC(1) yet.
- const delta = 1
- clock.Advance(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta)
- if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
- t.Fatal(err)
- }
-
- // Wait for DAD to resolve.
- clock.Advance(delta)
- select {
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID1, addr1, &stack.DADSucceeded{}); diff != "" {
- t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("timed out waiting for DAD resolution")
- }
- if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil {
- t.Fatal(err)
- }
- })
- }
-}
-
-// raBuf returns a valid NDP Router Advertisement with options, router
-// preference and DHCPv6 configurations specified.
-func raBuf(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, prf header.NDPRoutePreference, optSer header.NDPOptionsSerializer) *stack.PacketBuffer {
- const flagsByte = 1
- const routerLifetimeOffset = 2
-
- icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + optSer.Length()
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
- pkt := header.ICMPv6(hdr.Prepend(icmpSize))
- pkt.SetType(header.ICMPv6RouterAdvert)
- pkt.SetCode(0)
- raPayload := pkt.MessageBody()
- ra := header.NDPRouterAdvert(raPayload)
- // Populate the Router Lifetime.
- binary.BigEndian.PutUint16(raPayload[routerLifetimeOffset:], rl)
- // Populate the Managed Address flag field.
- if managedAddress {
- // The Managed Addresses flag field is the 7th bit of the flags byte.
- raPayload[flagsByte] |= 1 << 7
- }
- // Populate the Other Configurations flag field.
- if otherConfigurations {
- // The Other Configurations flag field is the 6th bit of the flags byte.
- raPayload[flagsByte] |= 1 << 6
- }
- // The Prf field is held in the flags byte.
- raPayload[flagsByte] |= byte(prf) << 3
- opts := ra.Options()
- opts.Serialize(optSer)
- pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: pkt,
- Src: ip,
- Dst: header.IPv6AllNodesMulticastAddress,
- }))
- payloadLength := hdr.UsedLength()
- iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- iph.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- TransportProtocol: icmp.ProtocolNumber6,
- HopLimit: header.NDPHopLimit,
- SrcAddr: ip,
- DstAddr: header.IPv6AllNodesMulticastAddress,
- })
-
- return stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- })
-}
-
-// raBufWithOpts returns a valid NDP Router Advertisement with options.
-//
-// Note, raBufWithOpts does not populate any of the RA fields other than the
-// Router Lifetime.
-func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) *stack.PacketBuffer {
- return raBuf(ip, rl, false /* managedAddress */, false /* otherConfigurations */, 0 /* prf */, optSer)
-}
-
-// raBufWithDHCPv6 returns a valid NDP Router Advertisement with DHCPv6 related
-// fields set.
-//
-// Note, raBufWithDHCPv6 does not populate any of the RA fields other than the
-// DHCPv6 related ones.
-func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfigurations bool) *stack.PacketBuffer {
- return raBuf(ip, 0, managedAddresses, otherConfigurations, 0 /* prf */, header.NDPOptionsSerializer{})
-}
-
-// raBuf returns a valid NDP Router Advertisement.
-//
-// Note, raBuf does not populate any of the RA fields other than the
-// Router Lifetime.
-func raBufSimple(ip tcpip.Address, rl uint16) *stack.PacketBuffer {
- return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{})
-}
-
-// raBufWithPrf returns a valid NDP Router Advertisement with a preference.
-//
-// Note, raBufWithPrf does not populate any of the RA fields other than the
-// Router Lifetime and Default Router Preference fields.
-func raBufWithPrf(ip tcpip.Address, rl uint16, prf header.NDPRoutePreference) *stack.PacketBuffer {
- return raBuf(ip, rl, false /* managedAddress */, false /* otherConfigurations */, prf, header.NDPOptionsSerializer{})
-}
-
-// raBufWithPI returns a valid NDP Router Advertisement with a single Prefix
-// Information option.
-//
-// Note, raBufWithPI does not populate any of the RA fields other than the
-// Router Lifetime.
-func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) *stack.PacketBuffer {
- flags := uint8(0)
- if onLink {
- // The OnLink flag is the 7th bit in the flags byte.
- flags |= 1 << 7
- }
- if auto {
- // The Address Auto-Configuration flag is the 6th bit in the
- // flags byte.
- flags |= 1 << 6
- }
-
- // A valid header.NDPPrefixInformation must be 30 bytes.
- buf := [30]byte{}
- // The first byte in a header.NDPPrefixInformation is the Prefix Length
- // field.
- buf[0] = uint8(prefix.PrefixLen)
- // The 2nd byte within a header.NDPPrefixInformation is the Flags field.
- buf[1] = flags
- // The Valid Lifetime field starts after the 2nd byte within a
- // header.NDPPrefixInformation.
- binary.BigEndian.PutUint32(buf[2:], vl)
- // The Preferred Lifetime field starts after the 6th byte within a
- // header.NDPPrefixInformation.
- binary.BigEndian.PutUint32(buf[6:], pl)
- // The Prefix Address field starts after the 14th byte within a
- // header.NDPPrefixInformation.
- copy(buf[14:], prefix.Address)
- return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{
- header.NDPPrefixInformation(buf[:]),
- })
-}
-
-// raBufWithRIO returns a valid NDP Router Advertisement with a single Route
-// Information option.
-//
-// All fields in the RA will be zero except the RIO option.
-func raBufWithRIO(t *testing.T, ip tcpip.Address, prefix tcpip.AddressWithPrefix, lifetimeSeconds uint32, prf header.NDPRoutePreference) *stack.PacketBuffer {
- // buf will hold the route information option after the Type and Length
- // fields.
- //
- // 2.3. Route Information Option
- //
- // 0 1 2 3
- // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
- // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
- // | Type | Length | Prefix Length |Resvd|Prf|Resvd|
- // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
- // | Route Lifetime |
- // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
- // | Prefix (Variable Length) |
- // . .
- // . .
- // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
- var buf [22]byte
- buf[0] = uint8(prefix.PrefixLen)
- buf[1] = byte(prf) << 3
- binary.BigEndian.PutUint32(buf[2:], lifetimeSeconds)
- if n := copy(buf[6:], prefix.Address); n != len(prefix.Address) {
- t.Fatalf("got copy(...) = %d, want = %d", n, len(prefix.Address))
- }
- return raBufWithOpts(ip, 0 /* router lifetime */, header.NDPOptionsSerializer{
- header.NDPRouteInformation(buf[:]),
- })
-}
-
-func TestDynamicConfigurationsDisabled(t *testing.T) {
- const (
- nicID = 1
- maxRtrSolicitDelay = time.Second
- )
-
- prefix := tcpip.AddressWithPrefix{
- Address: testutil.MustParse6("102:304:506:708::"),
- PrefixLen: 64,
- }
-
- tests := []struct {
- name string
- config func(bool) ipv6.NDPConfigurations
- ra *stack.PacketBuffer
- }{
- {
- name: "No Router Discovery",
- config: func(enable bool) ipv6.NDPConfigurations {
- return ipv6.NDPConfigurations{DiscoverDefaultRouters: enable}
- },
- ra: raBufSimple(llAddr2, 1000),
- },
- {
- name: "No Prefix Discovery",
- config: func(enable bool) ipv6.NDPConfigurations {
- return ipv6.NDPConfigurations{DiscoverOnLinkPrefixes: enable}
- },
- ra: raBufWithPI(llAddr2, 0, prefix, true, false, 10, 0),
- },
- {
- name: "No Autogenerate Addresses",
- config: func(enable bool) ipv6.NDPConfigurations {
- return ipv6.NDPConfigurations{AutoGenGlobalAddresses: enable}
- },
- ra: raBufWithPI(llAddr2, 0, prefix, false, true, 10, 0),
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- // Being configured to discover routers/prefixes or auto-generate
- // addresses means RAs must be handled, and router/prefix discovery or
- // SLAAC must be enabled.
- //
- // This tests all possible combinations of the configurations where
- // router/prefix discovery or SLAAC are disabled.
- for i := 0; i < 7; i++ {
- handle := ipv6.HandlingRAsDisabled
- if i&1 != 0 {
- handle = ipv6.HandlingRAsEnabledWhenForwardingDisabled
- }
- enable := i&2 != 0
- forwarding := i&4 == 0
-
- t.Run(fmt.Sprintf("HandleRAs(%s), Forwarding(%t), Enabled(%t)", handle, forwarding, enable), func(t *testing.T) {
- ndpDisp := ndpDispatcher{
- offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1),
- prefixC: make(chan ndpPrefixEvent, 1),
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
- }
- ndpConfigs := test.config(enable)
- ndpConfigs.HandleRAs = handle
- ndpConfigs.MaxRtrSolicitations = 1
- ndpConfigs.RtrSolicitationInterval = maxRtrSolicitDelay
- ndpConfigs.MaxRtrSolicitationDelay = maxRtrSolicitDelay
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- Clock: clock,
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ndpConfigs,
- NDPDisp: &ndpDisp,
- })},
- })
- if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil {
- t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err)
- }
-
- e := channel.New(1, 1280, linkAddr1)
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
-
- handleRAsDisabled := handle == ipv6.HandlingRAsDisabled || forwarding
- ep, err := s.GetNetworkEndpoint(nicID, ipv6.ProtocolNumber)
- if err != nil {
- t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, ipv6.ProtocolNumber, err)
- }
- stats := ep.Stats()
- v6Stats, ok := stats.(*ipv6.Stats)
- if !ok {
- t.Fatalf("got v6Stats = %T, expected = %T", stats, v6Stats)
- }
-
- // Make sure that when handling RAs are enabled, we solicit routers.
- clock.Advance(maxRtrSolicitDelay)
- if got, want := v6Stats.ICMP.PacketsSent.RouterSolicit.Value(), boolToUint64(!handleRAsDisabled); got != want {
- t.Errorf("got v6Stats.ICMP.PacketsSent.RouterSolicit.Value() = %d, want = %d", got, want)
- }
- if handleRAsDisabled {
- if p, ok := e.Read(); ok {
- t.Errorf("unexpectedly got a packet = %#v", p)
- }
- } else if p, ok := e.Read(); !ok {
- t.Error("expected router solicitation packet")
- } else if p.Proto != header.IPv6ProtocolNumber {
- t.Errorf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
- } else {
- if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress); p.Route.RemoteLinkAddress != want {
- t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
- }
-
- checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
- checker.SrcAddr(header.IPv6Any),
- checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress),
- checker.TTL(header.NDPHopLimit),
- checker.NDPRS(checker.NDPRSOptions(nil)),
- )
- }
-
- // Make sure we do not discover any routers or prefixes, or perform
- // SLAAC on reception of an RA.
- e.InjectInbound(header.IPv6ProtocolNumber, test.ra.Clone())
- // Make sure that the unhandled RA stat is only incremented when
- // handling RAs is disabled.
- if got, want := v6Stats.UnhandledRouterAdvertisements.Value(), boolToUint64(handleRAsDisabled); got != want {
- t.Errorf("got v6Stats.UnhandledRouterAdvertisements.Value() = %d, want = %d", got, want)
- }
- select {
- case e := <-ndpDisp.offLinkRouteC:
- t.Errorf("unexpectedly updated an off-link route when configured not to: %#v", e)
- default:
- }
- select {
- case e := <-ndpDisp.prefixC:
- t.Errorf("unexpectedly discovered a prefix when configured not to: %#v", e)
- default:
- }
- select {
- case e := <-ndpDisp.autoGenAddrC:
- t.Errorf("unexpectedly auto-generated an address when configured not to: %#v", e)
- default:
- }
- })
- }
- })
- }
-}
-
-func boolToUint64(v bool) uint64 {
- if v {
- return 1
- }
- return 0
-}
-
-func checkOffLinkRouteEvent(e ndpOffLinkRouteEvent, nicID tcpip.NICID, subnet tcpip.Subnet, router tcpip.Address, prf header.NDPRoutePreference, updated bool) string {
- return cmp.Diff(ndpOffLinkRouteEvent{nicID: nicID, subnet: subnet, router: router, prf: prf, updated: updated}, e, cmp.AllowUnexported(e))
-}
-
-func testWithRAs(t *testing.T, f func(*testing.T, ipv6.HandleRAsConfiguration, bool)) {
- tests := [...]struct {
- name string
- handleRAs ipv6.HandleRAsConfiguration
- forwarding bool
- }{
- {
- name: "Handle RAs when forwarding disabled",
- handleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- forwarding: false,
- },
- {
- name: "Always Handle RAs with forwarding disabled",
- handleRAs: ipv6.HandlingRAsAlwaysEnabled,
- forwarding: false,
- },
- {
- name: "Always Handle RAs with forwarding enabled",
- handleRAs: ipv6.HandlingRAsAlwaysEnabled,
- forwarding: true,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- f(t, test.handleRAs, test.forwarding)
- })
- }
-}
-
-func TestOffLinkRouteDiscovery(t *testing.T) {
- const nicID = 1
-
- moreSpecificPrefix := tcpip.AddressWithPrefix{Address: testutil.MustParse6("a00::"), PrefixLen: 16}
- tests := []struct {
- name string
-
- discoverDefaultRouters bool
- discoverMoreSpecificRoutes bool
-
- dest tcpip.Subnet
- ra func(*testing.T, tcpip.Address, uint16, header.NDPRoutePreference) *stack.PacketBuffer
- }{
- {
- name: "Default router discovery",
- discoverDefaultRouters: true,
- discoverMoreSpecificRoutes: false,
- dest: header.IPv6EmptySubnet,
- ra: func(_ *testing.T, router tcpip.Address, lifetimeSeconds uint16, prf header.NDPRoutePreference) *stack.PacketBuffer {
- return raBufWithPrf(router, lifetimeSeconds, prf)
- },
- },
- {
- name: "More-specific route discovery",
- discoverDefaultRouters: false,
- discoverMoreSpecificRoutes: true,
- dest: moreSpecificPrefix.Subnet(),
- ra: func(t *testing.T, router tcpip.Address, lifetimeSeconds uint16, prf header.NDPRoutePreference) *stack.PacketBuffer {
- return raBufWithRIO(t, router, moreSpecificPrefix, uint32(lifetimeSeconds), prf)
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) {
- ndpDisp := ndpDispatcher{
- offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: handleRAs,
- DiscoverDefaultRouters: test.discoverDefaultRouters,
- DiscoverMoreSpecificRoutes: test.discoverMoreSpecificRoutes,
- },
- NDPDisp: &ndpDisp,
- })},
- Clock: clock,
- })
-
- expectOffLinkRouteEvent := func(addr tcpip.Address, prf header.NDPRoutePreference, updated bool) {
- t.Helper()
-
- select {
- case e := <-ndpDisp.offLinkRouteC:
- if diff := checkOffLinkRouteEvent(e, nicID, test.dest, addr, prf, updated); diff != "" {
- t.Errorf("off-link route event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected router discovery event")
- }
- }
-
- expectAsyncOffLinkRouteInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) {
- t.Helper()
-
- clock.Advance(timeout)
- select {
- case e := <-ndpDisp.offLinkRouteC:
- var prf header.NDPRoutePreference
- if diff := checkOffLinkRouteEvent(e, nicID, test.dest, addr, prf, false); diff != "" {
- t.Errorf("off-link route event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("timed out waiting for router discovery event")
- }
- }
-
- if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil {
- t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err)
- }
-
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
-
- // Rx an RA from lladdr2 with zero lifetime. It should not be
- // remembered.
- e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, 0, header.MediumRoutePreference))
- select {
- case <-ndpDisp.offLinkRouteC:
- t.Fatal("unexpectedly updated an off-link route with 0 lifetime")
- default:
- }
-
- // Discover an off-link route through llAddr2.
- e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, 1000, header.ReservedRoutePreference))
- if test.discoverMoreSpecificRoutes {
- // The reserved value is considered invalid with more-specific route
- // discovery so we inject the same packet but with the default
- // (medium) preference value.
- select {
- case <-ndpDisp.offLinkRouteC:
- t.Fatal("unexpectedly updated an off-link route with a reserved preference value")
- default:
- }
- e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, 1000, header.MediumRoutePreference))
- }
- expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, true)
-
- // Rx an RA from another router (lladdr3) with non-zero lifetime and
- // non-default preference value.
- const l3LifetimeSeconds = 6
- e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr3, l3LifetimeSeconds, header.HighRoutePreference))
- expectOffLinkRouteEvent(llAddr3, header.HighRoutePreference, true)
-
- // Rx an RA from lladdr2 with lesser lifetime and default (medium)
- // preference value.
- const l2LifetimeSeconds = 2
- e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, l2LifetimeSeconds, header.MediumRoutePreference))
- select {
- case <-ndpDisp.offLinkRouteC:
- t.Fatal("should not receive a off-link route event when updating lifetimes for known routers")
- default:
- }
-
- // Rx an RA from lladdr2 with a different preference.
- e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, l2LifetimeSeconds, header.LowRoutePreference))
- expectOffLinkRouteEvent(llAddr2, header.LowRoutePreference, true)
-
- // Wait for lladdr2's router invalidation job to execute. The lifetime
- // of the router should have been updated to the most recent (smaller)
- // lifetime.
- //
- // Wait for the normal lifetime plus an extra bit for the
- // router to get invalidated. If we don't get an invalidation
- // event after this time, then something is wrong.
- expectAsyncOffLinkRouteInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second)
-
- // Rx an RA from lladdr2 with huge lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, 1000, header.MediumRoutePreference))
- expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, true)
-
- // Rx an RA from lladdr2 with zero lifetime. It should be invalidated.
- e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, 0, header.MediumRoutePreference))
- expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, false)
-
- // Wait for lladdr3's router invalidation job to execute. The lifetime
- // of the router should have been updated to the most recent (smaller)
- // lifetime.
- //
- // Wait for the normal lifetime plus an extra bit for the
- // router to get invalidated. If we don't get an invalidation
- // event after this time, then something is wrong.
- expectAsyncOffLinkRouteInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second)
- })
- })
- }
-}
-
-// TestRouterDiscoveryMaxRouters tests that only
-// ipv6.MaxDiscoveredOffLinkRoutes discovered routers are remembered.
-func TestRouterDiscoveryMaxRouters(t *testing.T) {
- const nicID = 1
-
- ndpDisp := ndpDispatcher{
- offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- DiscoverDefaultRouters: true,
- },
- NDPDisp: &ndpDisp,
- })},
- })
-
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
-
- // Receive an RA from 2 more than the max number of discovered routers.
- for i := 1; i <= ipv6.MaxDiscoveredOffLinkRoutes+2; i++ {
- linkAddr := []byte{2, 2, 3, 4, 5, 0}
- linkAddr[5] = byte(i)
- llAddr := header.LinkLocalAddr(tcpip.LinkAddress(linkAddr))
-
- e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr, 5))
-
- if i <= ipv6.MaxDiscoveredOffLinkRoutes {
- select {
- case e := <-ndpDisp.offLinkRouteC:
- if diff := checkOffLinkRouteEvent(e, nicID, header.IPv6EmptySubnet, llAddr, header.MediumRoutePreference, true); diff != "" {
- t.Errorf("off-link route event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected router discovery event")
- }
-
- } else {
- select {
- case <-ndpDisp.offLinkRouteC:
- t.Fatal("should not have discovered a new router after we already discovered the max number of routers")
- default:
- }
- }
- }
-}
-
-// Check e to make sure that the event is for prefix on nic with ID 1, and the
-// discovered flag set to discovered.
-func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) string {
- return cmp.Diff(ndpPrefixEvent{nicID: 1, prefix: prefix, discovered: discovered}, e, cmp.AllowUnexported(e))
-}
-
-func TestPrefixDiscovery(t *testing.T) {
- prefix1, subnet1, _ := prefixSubnetAddr(0, "")
- prefix2, subnet2, _ := prefixSubnetAddr(1, "")
- prefix3, subnet3, _ := prefixSubnetAddr(2, "")
-
- testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) {
- ndpDisp := ndpDispatcher{
- prefixC: make(chan ndpPrefixEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: handleRAs,
- DiscoverOnLinkPrefixes: true,
- },
- NDPDisp: &ndpDisp,
- })},
- Clock: clock,
- })
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) {
- t.Helper()
-
- select {
- case e := <-ndpDisp.prefixC:
- if diff := checkPrefixEvent(e, prefix, discovered); diff != "" {
- t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected prefix discovery event")
- }
- }
-
- if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil {
- t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err)
- }
-
- // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
- // with zero valid lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0))
- select {
- case <-ndpDisp.prefixC:
- t.Fatal("unexpectedly discovered a prefix with 0 lifetime")
- default:
- }
-
- // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
- // with non-zero lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0))
- expectPrefixEvent(subnet1, true)
-
- // Receive an RA with prefix2 in a PI.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0))
- expectPrefixEvent(subnet2, true)
-
- // Receive an RA with prefix3 in a PI.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0))
- expectPrefixEvent(subnet3, true)
-
- // Receive an RA with prefix1 in a PI with lifetime = 0.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0))
- expectPrefixEvent(subnet1, false)
-
- // Receive an RA with prefix2 in a PI with lesser lifetime.
- lifetime := uint32(2)
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, lifetime, 0))
- select {
- case <-ndpDisp.prefixC:
- t.Fatal("unexpectedly received prefix event when updating lifetime")
- default:
- }
-
- // Wait for prefix2's most recent invalidation job plus some buffer to
- // expire.
- clock.Advance(time.Duration(lifetime) * time.Second)
- select {
- case e := <-ndpDisp.prefixC:
- if diff := checkPrefixEvent(e, subnet2, false); diff != "" {
- t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("timed out waiting for prefix discovery event")
- }
-
- // Receive RA to invalidate prefix3.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0))
- expectPrefixEvent(subnet3, false)
- })
-}
-
-func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
- prefix := tcpip.AddressWithPrefix{
- Address: testutil.MustParse6("102:304:506:708::"),
- PrefixLen: 64,
- }
- subnet := prefix.Subnet()
-
- ndpDisp := ndpDispatcher{
- prefixC: make(chan ndpPrefixEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- DiscoverOnLinkPrefixes: true,
- },
- NDPDisp: &ndpDisp,
- })},
- Clock: clock,
- })
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) {
- t.Helper()
-
- select {
- case e := <-ndpDisp.prefixC:
- if diff := checkPrefixEvent(e, prefix, discovered); diff != "" {
- t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected prefix discovery event")
- }
- }
-
- // Receive an RA with prefix in an NDP Prefix Information option (PI)
- // with infinite valid lifetime which should not get invalidated.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds, 0))
- expectPrefixEvent(subnet, true)
- clock.Advance(header.NDPInfiniteLifetime)
- select {
- case <-ndpDisp.prefixC:
- t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
- default:
- }
-
- // Receive an RA with finite lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds-1, 0))
- clock.Advance(header.NDPInfiniteLifetime - time.Second)
- select {
- case e := <-ndpDisp.prefixC:
- if diff := checkPrefixEvent(e, subnet, false); diff != "" {
- t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("timed out waiting for prefix discovery event")
- }
-
- // Receive an RA with finite lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds-1, 0))
- expectPrefixEvent(subnet, true)
-
- // Receive an RA with prefix with an infinite lifetime.
- // The prefix should not be invalidated.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds, 0))
- clock.Advance(header.NDPInfiniteLifetime)
- select {
- case <-ndpDisp.prefixC:
- t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
- default:
- }
-
- // Receive an RA with 0 lifetime.
- // The prefix should get invalidated.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 0, 0))
- expectPrefixEvent(subnet, false)
-}
-
-// TestPrefixDiscoveryMaxRouters tests that only
-// ipv6.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered.
-func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) {
- ndpDisp := ndpDispatcher{
- prefixC: make(chan ndpPrefixEvent, ipv6.MaxDiscoveredOnLinkPrefixes+3),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- DiscoverDefaultRouters: false,
- DiscoverOnLinkPrefixes: true,
- },
- NDPDisp: &ndpDisp,
- })},
- })
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- optSer := make(header.NDPOptionsSerializer, ipv6.MaxDiscoveredOnLinkPrefixes+2)
- prefixes := [ipv6.MaxDiscoveredOnLinkPrefixes + 2]tcpip.Subnet{}
-
- // Receive an RA with 2 more than the max number of discovered on-link
- // prefixes.
- for i := 0; i < ipv6.MaxDiscoveredOnLinkPrefixes+2; i++ {
- prefixAddr := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0}
- prefixAddr[7] = byte(i)
- prefix := tcpip.AddressWithPrefix{
- Address: tcpip.Address(prefixAddr[:]),
- PrefixLen: 64,
- }
- prefixes[i] = prefix.Subnet()
- buf := [30]byte{}
- buf[0] = uint8(prefix.PrefixLen)
- buf[1] = 128
- binary.BigEndian.PutUint32(buf[2:], 10)
- copy(buf[14:], prefix.Address)
-
- optSer[i] = header.NDPPrefixInformation(buf[:])
- }
-
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, optSer))
- for i := 0; i < ipv6.MaxDiscoveredOnLinkPrefixes+2; i++ {
- if i < ipv6.MaxDiscoveredOnLinkPrefixes {
- select {
- case e := <-ndpDisp.prefixC:
- if diff := checkPrefixEvent(e, prefixes[i], true); diff != "" {
- t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected prefix discovery event")
- }
- } else {
- select {
- case <-ndpDisp.prefixC:
- t.Fatal("should not have discovered a new prefix after we already discovered the max number of prefixes")
- default:
- }
- }
- }
-}
-
-// Checks to see if list contains an IPv6 address, item.
-func containsV6Addr(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) bool {
- protocolAddress := tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: item,
- }
-
- return containsAddr(list, protocolAddress)
-}
-
-// Check e to make sure that the event is for addr on nic with ID 1, and the
-// event type is set to eventType.
-func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) string {
- return cmp.Diff(ndpAutoGenAddrEvent{nicID: 1, addr: addr, eventType: eventType}, e, cmp.AllowUnexported(e))
-}
-
-const minVLSeconds = uint32(ipv6.MinPrefixInformationValidLifetimeForUpdate / time.Second)
-const infiniteLifetimeSeconds = uint32(header.NDPInfiniteLifetime / time.Second)
-
-// TestAutoGenAddr tests that an address is properly generated and invalidated
-// when configured to do so.
-func TestAutoGenAddr(t *testing.T) {
- prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
- prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
-
- testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) {
- ndpDisp := ndpDispatcher{
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: handleRAs,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- })},
- Clock: clock,
- })
-
- if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil {
- t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err)
- }
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
-
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
- }
- }
-
- // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
- // with zero valid lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly auto-generated an address with 0 lifetime")
- default:
- }
-
- // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
- // with non-zero lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
- expectAutoGenAddrEvent(addr1, newAddr)
- if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
- t.Fatalf("Should have %s in the list of addresses", addr1)
- }
-
- // Receive an RA with prefix2 in an NDP Prefix Information option (PI)
- // with preferred lifetime > valid lifetime
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime")
- default:
- }
-
- // Receive an RA with prefix2 in a PI with a valid lifetime that exceeds
- // the minimum.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, minVLSeconds+1, 0))
- expectAutoGenAddrEvent(addr2, newAddr)
- if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
- t.Fatalf("Should have %s in the list of addresses", addr1)
- }
- if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) {
- t.Fatalf("Should have %s in the list of addresses", addr2)
- }
-
- // Refresh valid lifetime for addr of prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix")
- default:
- }
-
- // Wait for addr of prefix1 to be invalidated.
- clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate)
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("timed out waiting for addr auto gen event")
- }
- if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
- t.Fatalf("Should not have %s in the list of addresses", addr1)
- }
- if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) {
- t.Fatalf("Should have %s in the list of addresses", addr2)
- }
- })
-}
-
-func addressCheck(addrs []tcpip.ProtocolAddress, containList, notContainList []tcpip.AddressWithPrefix) string {
- ret := ""
- for _, c := range containList {
- if !containsV6Addr(addrs, c) {
- ret += fmt.Sprintf("should have %s in the list of addresses\n", c)
- }
- }
- for _, c := range notContainList {
- if containsV6Addr(addrs, c) {
- ret += fmt.Sprintf("should not have %s in the list of addresses\n", c)
- }
- }
- return ret
-}
-
-// TestAutoGenTempAddr tests that temporary SLAAC addresses are generated when
-// configured to do so as part of IPv6 Privacy Extensions.
-func TestAutoGenTempAddr(t *testing.T) {
- const nicID = 1
-
- prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
- prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
-
- tests := []struct {
- name string
- dupAddrTransmits uint8
- retransmitTimer time.Duration
- }{
- {
- name: "DAD disabled",
- },
- {
- name: "DAD enabled",
- dupAddrTransmits: 1,
- retransmitTimer: time.Second,
- },
- }
-
- for i, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- seed := []byte{uint8(i)}
- var tempIIDHistory [header.IIDSize]byte
- header.InitialTempIID(tempIIDHistory[:], seed, nicID)
- newTempAddr := func(stableAddr tcpip.Address) tcpip.AddressWithPrefix {
- return header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableAddr)
- }
-
- ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent, 2),
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
- }
- e := channel.New(0, 1280, linkAddr1)
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- DADConfigs: stack.DADConfigurations{
- DupAddrDetectTransmits: test.dupAddrTransmits,
- RetransmitTimer: test.retransmitTimer,
- },
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- AutoGenGlobalAddresses: true,
- AutoGenTempGlobalAddresses: true,
- MaxTempAddrValidLifetime: 2 * ipv6.MinPrefixInformationValidLifetimeForUpdate,
- MaxTempAddrPreferredLifetime: 2 * ipv6.MinPrefixInformationValidLifetimeForUpdate,
- },
- NDPDisp: &ndpDisp,
- TempIIDSeed: seed,
- })},
- Clock: clock,
- })
-
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
-
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
- }
- }
-
- expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
-
- clock.RunImmediatelyScheduledJobs()
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("timed out waiting for addr auto gen event")
- }
- }
-
- expectDADEventAsync := func(addr tcpip.Address) {
- t.Helper()
-
- clock.Advance(time.Duration(test.dupAddrTransmits) * test.retransmitTimer)
- select {
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr, &stack.DADSucceeded{}); diff != "" {
- t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("timed out waiting for DAD event")
- }
- }
-
- // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
- // with zero valid lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0))
- select {
- case e := <-ndpDisp.autoGenAddrC:
- t.Fatalf("unexpectedly auto-generated an address with 0 lifetime; event = %+v", e)
- default:
- }
-
- // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
- // with non-zero valid lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
- expectAutoGenAddrEvent(addr1, newAddr)
- expectDADEventAsync(addr1.Address)
- select {
- case e := <-ndpDisp.autoGenAddrC:
- t.Fatalf("unexpectedly got an auto gen addr event = %+v", e)
- default:
- }
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1}, nil); mismatch != "" {
- t.Fatal(mismatch)
- }
-
- // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
- // with non-zero valid & preferred lifetimes.
- tempAddr1 := newTempAddr(addr1.Address)
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
- expectAutoGenAddrEvent(tempAddr1, newAddr)
- expectDADEventAsync(tempAddr1.Address)
- if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" {
- t.Fatal(mismatch)
- }
-
- // Receive an RA with prefix2 in an NDP Prefix Information option (PI)
- // with preferred lifetime > valid lifetime
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6))
- select {
- case e := <-ndpDisp.autoGenAddrC:
- t.Fatalf("unexpectedly auto-generated an address with preferred lifetime > valid lifetime; event = %+v", e)
- default:
- }
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" {
- t.Fatal(mismatch)
- }
-
- // Receive an RA with prefix2 in a PI with a valid lifetime that exceeds
- // the minimum and won't be reached in this test.
- tempAddr2 := newTempAddr(addr2.Address)
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 2*minVLSeconds, 2*minVLSeconds))
- expectAutoGenAddrEvent(addr2, newAddr)
- expectDADEventAsync(addr2.Address)
- expectAutoGenAddrEventAsync(tempAddr2, newAddr)
- expectDADEventAsync(tempAddr2.Address)
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
- t.Fatal(mismatch)
- }
-
- // Deprecate prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
- expectAutoGenAddrEvent(addr1, deprecatedAddr)
- expectAutoGenAddrEvent(tempAddr1, deprecatedAddr)
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
- t.Fatal(mismatch)
- }
-
- // Refresh lifetimes for prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
- t.Fatal(mismatch)
- }
-
- // Reduce valid lifetime and deprecate addresses of prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, 0))
- expectAutoGenAddrEvent(addr1, deprecatedAddr)
- expectAutoGenAddrEvent(tempAddr1, deprecatedAddr)
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
- t.Fatal(mismatch)
- }
-
- // Wait for addrs of prefix1 to be invalidated. They should be
- // invalidated at the same time.
- clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate)
- select {
- case e := <-ndpDisp.autoGenAddrC:
- var nextAddr tcpip.AddressWithPrefix
- if e.addr == addr1 {
- if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- nextAddr = tempAddr1
- } else {
- if diff := checkAutoGenAddrEvent(e, tempAddr1, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- nextAddr = addr1
- }
-
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, nextAddr, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("timed out waiting for addr auto gen event")
- }
- default:
- t.Fatal("timed out waiting for addr auto gen event")
- }
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" {
- t.Fatal(mismatch)
- }
-
- // Receive an RA with prefix2 in a PI w/ 0 lifetimes.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 0, 0))
- expectAutoGenAddrEvent(addr2, deprecatedAddr)
- expectAutoGenAddrEvent(tempAddr2, deprecatedAddr)
- select {
- case e := <-ndpDisp.autoGenAddrC:
- t.Errorf("got unexpected auto gen addr event = %+v", e)
- default:
- }
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" {
- t.Fatal(mismatch)
- }
- })
- }
-}
-
-// TestNoAutoGenTempAddrForLinkLocal test that temporary SLAAC addresses are not
-// generated for auto generated link-local addresses.
-func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) {
- const nicID = 1
-
- tests := []struct {
- name string
- dupAddrTransmits uint8
- retransmitTimer time.Duration
- }{
- {
- name: "DAD disabled",
- },
- {
- name: "DAD enabled",
- dupAddrTransmits: 1,
- retransmitTimer: time.Second,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent, 1),
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- AutoGenTempGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- AutoGenLinkLocal: true,
- })},
- Clock: clock,
- })
-
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- // The stable link-local address should auto-generate and resolve DAD.
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IIDOffsetInIPv6Address * 8}, newAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
- }
- clock.Advance(time.Duration(test.dupAddrTransmits) * test.retransmitTimer)
- select {
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, llAddr1, &stack.DADSucceeded{}); diff != "" {
- t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("timed out waiting for DAD event")
- }
-
- // No new addresses should be generated.
- select {
- case e := <-ndpDisp.autoGenAddrC:
- t.Errorf("got unxpected auto gen addr event = %+v", e)
- default:
- }
- })
- }
-}
-
-// TestNoAutoGenTempAddrWithoutStableAddr tests that a temporary SLAAC address
-// will not be generated until after DAD completes, even if a new Router
-// Advertisement is received to refresh lifetimes.
-func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
- const (
- nicID = 1
- dadTransmits = 1
- retransmitTimer = 2 * time.Second
- )
-
- prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
- var tempIIDHistory [header.IIDSize]byte
- header.InitialTempIID(tempIIDHistory[:], nil, nicID)
- tempAddr := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
-
- ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent, 1),
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- DADConfigs: stack.DADConfigurations{
- DupAddrDetectTransmits: dadTransmits,
- RetransmitTimer: retransmitTimer,
- },
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- AutoGenGlobalAddresses: true,
- AutoGenTempGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- })},
- Clock: clock,
- })
-
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- // Receive an RA to trigger SLAAC for prefix.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
- }
-
- // DAD on the stable address for prefix has not yet completed. Receiving a new
- // RA that would refresh lifetimes should not generate a temporary SLAAC
- // address for the prefix.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
- select {
- case e := <-ndpDisp.autoGenAddrC:
- t.Fatalf("unexpected auto gen addr event = %+v", e)
- default:
- }
-
- // Wait for DAD to complete for the stable address then expect the temporary
- // address to be generated.
- clock.Advance(dadTransmits * retransmitTimer)
- select {
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADSucceeded{}); diff != "" {
- t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("timed out waiting for DAD event")
- }
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, tempAddr, newAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("timed out waiting for addr auto gen event")
- }
-}
-
-// TestAutoGenTempAddrRegen tests that temporary SLAAC addresses are
-// regenerated.
-func TestAutoGenTempAddrRegen(t *testing.T) {
- const (
- nicID = 1
- regenAdv = 2 * time.Second
-
- numTempAddrs = 3
- maxTempAddrValidLifetime = numTempAddrs * ipv6.MinPrefixInformationValidLifetimeForUpdate
- )
-
- prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
- var tempIIDHistory [header.IIDSize]byte
- header.InitialTempIID(tempIIDHistory[:], nil, nicID)
- var tempAddrs [numTempAddrs]tcpip.AddressWithPrefix
- for i := 0; i < len(tempAddrs); i++ {
- tempAddrs[i] = header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
- }
-
- ndpDisp := ndpDispatcher{
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
- }
- e := channel.New(0, 1280, linkAddr1)
- ndpConfigs := ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- AutoGenGlobalAddresses: true,
- AutoGenTempGlobalAddresses: true,
- RegenAdvanceDuration: regenAdv,
- MaxTempAddrValidLifetime: maxTempAddrValidLifetime,
- MaxTempAddrPreferredLifetime: ipv6.MinPrefixInformationValidLifetimeForUpdate,
- }
- clock := faketime.NewManualClock()
- randSource := savingRandSource{
- s: rand.NewSource(time.Now().UnixNano()),
- }
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ndpConfigs,
- NDPDisp: &ndpDisp,
- })},
- Clock: clock,
- RandSource: &randSource,
- })
-
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
-
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
- }
- }
-
- expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) {
- t.Helper()
-
- clock.Advance(timeout)
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("timed out waiting for addr auto gen event")
- }
- }
-
- tempDesyncFactor := time.Duration(randSource.lastInt63) % ipv6.MaxDesyncFactor
- effectiveMaxTempAddrPL := ipv6.MinPrefixInformationValidLifetimeForUpdate - tempDesyncFactor
- // The time since the last regeneration before a new temporary address is
- // generated.
- tempAddrRegenenerationTime := effectiveMaxTempAddrPL - regenAdv
-
- // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
- // with non-zero valid & preferred lifetimes.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, minVLSeconds))
- expectAutoGenAddrEvent(addr, newAddr)
- expectAutoGenAddrEvent(tempAddrs[0], newAddr)
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddrs[0]}, nil); mismatch != "" {
- t.Fatal(mismatch)
- }
-
- // Wait for regeneration
- expectAutoGenAddrEventAsync(tempAddrs[1], newAddr, tempAddrRegenenerationTime)
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, minVLSeconds))
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddrs[0], tempAddrs[1]}, nil); mismatch != "" {
- t.Fatal(mismatch)
- }
- expectAutoGenAddrEventAsync(tempAddrs[0], deprecatedAddr, regenAdv)
-
- // Wait for regeneration
- expectAutoGenAddrEventAsync(tempAddrs[2], newAddr, tempAddrRegenenerationTime-regenAdv)
- expectAutoGenAddrEventAsync(tempAddrs[1], deprecatedAddr, regenAdv)
-
- // Stop generating temporary addresses
- ndpConfigs.AutoGenTempGlobalAddresses = false
- if ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
- } else {
- ndpEP := ipv6Ep.(ipv6.NDPEndpoint)
- ndpEP.SetNDPConfigurations(ndpConfigs)
- }
-
- // Refresh lifetimes and wait for the last temporary address to be deprecated.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, minVLSeconds))
- expectAutoGenAddrEventAsync(tempAddrs[2], deprecatedAddr, effectiveMaxTempAddrPL-regenAdv)
-
- // Refresh lifetimes such that the prefix is valid and preferred forever.
- //
- // This should not affect the lifetimes of temporary addresses because they
- // are capped by the maximum valid and preferred lifetimes for temporary
- // addresses.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, infiniteLifetimeSeconds, infiniteLifetimeSeconds))
-
- // Wait for all the temporary addresses to get invalidated.
- invalidateAfter := maxTempAddrValidLifetime - clock.NowMonotonic().Sub(tcpip.MonotonicTime{})
- for _, addr := range tempAddrs {
- expectAutoGenAddrEventAsync(addr, invalidatedAddr, invalidateAfter)
- invalidateAfter = tempAddrRegenenerationTime
- }
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr}, tempAddrs[:]); mismatch != "" {
- t.Fatal(mismatch)
- }
-}
-
-// TestAutoGenTempAddrRegenJobUpdates tests that a temporary address's
-// regeneration job gets updated when refreshing the address's lifetimes.
-func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) {
- const (
- nicID = 1
- regenAdv = 2 * time.Second
-
- numTempAddrs = 3
- maxTempAddrPreferredLifetime = ipv6.MinPrefixInformationValidLifetimeForUpdate
- maxTempAddrPreferredLifetimeSeconds = uint32(maxTempAddrPreferredLifetime / time.Second)
- )
-
- prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
- var tempIIDHistory [header.IIDSize]byte
- header.InitialTempIID(tempIIDHistory[:], nil, nicID)
- var tempAddrs [numTempAddrs]tcpip.AddressWithPrefix
- for i := 0; i < len(tempAddrs); i++ {
- tempAddrs[i] = header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
- }
-
- ndpDisp := ndpDispatcher{
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
- }
- e := channel.New(0, 1280, linkAddr1)
- ndpConfigs := ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- AutoGenGlobalAddresses: true,
- AutoGenTempGlobalAddresses: true,
- RegenAdvanceDuration: regenAdv,
- MaxTempAddrPreferredLifetime: maxTempAddrPreferredLifetime,
- MaxTempAddrValidLifetime: maxTempAddrPreferredLifetime * 2,
- }
- clock := faketime.NewManualClock()
- initialTime := clock.NowMonotonic()
- randSource := savingRandSource{
- s: rand.NewSource(time.Now().UnixNano()),
- }
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ndpConfigs,
- NDPDisp: &ndpDisp,
- })},
- Clock: clock,
- RandSource: &randSource,
- })
-
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- tempDesyncFactor := time.Duration(randSource.lastInt63) % ipv6.MaxDesyncFactor
-
- expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
-
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
- }
- }
-
- expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) {
- t.Helper()
-
- clock.Advance(timeout)
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("timed out waiting for addr auto gen event")
- }
- }
-
- // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
- // with non-zero valid & preferred lifetimes.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, maxTempAddrPreferredLifetimeSeconds, maxTempAddrPreferredLifetimeSeconds))
- expectAutoGenAddrEvent(addr, newAddr)
- expectAutoGenAddrEvent(tempAddrs[0], newAddr)
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddrs[0]}, nil); mismatch != "" {
- t.Fatal(mismatch)
- }
-
- // Deprecate the prefix.
- //
- // A new temporary address should be generated after the regeneration
- // time has passed since the prefix is deprecated.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, maxTempAddrPreferredLifetimeSeconds, 0))
- expectAutoGenAddrEvent(addr, deprecatedAddr)
- expectAutoGenAddrEvent(tempAddrs[0], deprecatedAddr)
- select {
- case e := <-ndpDisp.autoGenAddrC:
- t.Fatalf("unexpected auto gen addr event = %#v", e)
- default:
- }
-
- effectiveMaxTempAddrPL := maxTempAddrPreferredLifetime - tempDesyncFactor
- // The time since the last regeneration before a new temporary address is
- // generated.
- tempAddrRegenenerationTime := effectiveMaxTempAddrPL - regenAdv
-
- // Advance the clock by the regeneration time but don't expect a new temporary
- // address as the prefix is deprecated.
- clock.Advance(tempAddrRegenenerationTime)
- select {
- case e := <-ndpDisp.autoGenAddrC:
- t.Fatalf("unexpected auto gen addr event = %#v", e)
- default:
- }
-
- // Prefer the prefix again.
- //
- // A new temporary address should immediately be generated since the
- // regeneration time has already passed since the last address was generated
- // - this regeneration does not depend on a job.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, maxTempAddrPreferredLifetimeSeconds, maxTempAddrPreferredLifetimeSeconds))
- expectAutoGenAddrEvent(tempAddrs[1], newAddr)
- // Wait for the first temporary address to be deprecated.
- expectAutoGenAddrEventAsync(tempAddrs[0], deprecatedAddr, regenAdv)
- select {
- case e := <-ndpDisp.autoGenAddrC:
- t.Fatalf("unexpected auto gen addr event = %s", e)
- default:
- }
-
- // Increase the maximum lifetimes for temporary addresses to large values
- // then refresh the lifetimes of the prefix.
- //
- // A new address should not be generated after the regeneration time that was
- // expected for the previous check. This is because the preferred lifetime for
- // the temporary addresses has increased, so it will take more time to
- // regenerate a new temporary address. Note, new addresses are only
- // regenerated after the preferred lifetime - the regenerate advance duration
- // as paased.
- const largeLifetimeSeconds = minVLSeconds * 2
- const largeLifetime = time.Duration(largeLifetimeSeconds) * time.Second
- ndpConfigs.MaxTempAddrValidLifetime = 2 * largeLifetime
- ndpConfigs.MaxTempAddrPreferredLifetime = largeLifetime
- ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
- }
- ndpEP := ipv6Ep.(ipv6.NDPEndpoint)
- ndpEP.SetNDPConfigurations(ndpConfigs)
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
- timeSinceInitialTime := clock.NowMonotonic().Sub(initialTime)
- clock.Advance(largeLifetime - timeSinceInitialTime)
- expectAutoGenAddrEvent(tempAddrs[0], deprecatedAddr)
- // to offset the advement of time to test the first temporary address's
- // deprecation after the second was generated
- advLess := regenAdv
- expectAutoGenAddrEventAsync(tempAddrs[2], newAddr, timeSinceInitialTime-advLess-(tempDesyncFactor+regenAdv))
- expectAutoGenAddrEventAsync(tempAddrs[1], deprecatedAddr, regenAdv)
- select {
- case e := <-ndpDisp.autoGenAddrC:
- t.Fatalf("unexpected auto gen addr event = %+v", e)
- default:
- }
-}
-
-// TestMixedSLAACAddrConflictRegen tests SLAAC address regeneration in response
-// to a mix of DAD conflicts and NIC-local conflicts.
-func TestMixedSLAACAddrConflictRegen(t *testing.T) {
- const (
- nicID = 1
- nicName = "nic"
- lifetimeSeconds = 9999
- // From stack.maxSLAACAddrLocalRegenAttempts
- maxSLAACAddrLocalRegenAttempts = 10
- // We use 2 more addreses than the maximum local regeneration attempts
- // because we want to also trigger regeneration in response to a DAD
- // conflicts for this test.
- maxAddrs = maxSLAACAddrLocalRegenAttempts + 2
- dupAddrTransmits = 1
- retransmitTimer = time.Second
- )
-
- var tempIIDHistoryWithModifiedEUI64 [header.IIDSize]byte
- header.InitialTempIID(tempIIDHistoryWithModifiedEUI64[:], nil, nicID)
-
- var tempIIDHistoryWithOpaqueIID [header.IIDSize]byte
- header.InitialTempIID(tempIIDHistoryWithOpaqueIID[:], nil, nicID)
-
- prefix, subnet, stableAddrWithModifiedEUI64 := prefixSubnetAddr(0, linkAddr1)
- var stableAddrsWithOpaqueIID [maxAddrs]tcpip.AddressWithPrefix
- var tempAddrsWithOpaqueIID [maxAddrs]tcpip.AddressWithPrefix
- var tempAddrsWithModifiedEUI64 [maxAddrs]tcpip.AddressWithPrefix
- addrBytes := []byte(subnet.ID())
- for i := 0; i < maxAddrs; i++ {
- stableAddrsWithOpaqueIID[i] = tcpip.AddressWithPrefix{
- Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, uint8(i), nil)),
- PrefixLen: header.IIDOffsetInIPv6Address * 8,
- }
- // When generating temporary addresses, the resolved stable address for the
- // SLAAC prefix will be the first address stable address generated for the
- // prefix as we will not simulate address conflicts for the stable addresses
- // in tests involving temporary addresses. Address conflicts for stable
- // addresses will be done in their own tests.
- tempAddrsWithOpaqueIID[i] = header.GenerateTempIPv6SLAACAddr(tempIIDHistoryWithOpaqueIID[:], stableAddrsWithOpaqueIID[0].Address)
- tempAddrsWithModifiedEUI64[i] = header.GenerateTempIPv6SLAACAddr(tempIIDHistoryWithModifiedEUI64[:], stableAddrWithModifiedEUI64.Address)
- }
-
- tests := []struct {
- name string
- addrs []tcpip.AddressWithPrefix
- tempAddrs bool
- initialExpect tcpip.AddressWithPrefix
- maxAddrs int
- nicNameFromID func(tcpip.NICID, string) string
- }{
- {
- name: "Stable addresses with opaque IIDs",
- addrs: stableAddrsWithOpaqueIID[:],
- maxAddrs: 1,
- nicNameFromID: func(tcpip.NICID, string) string {
- return nicName
- },
- },
- {
- name: "Temporary addresses with opaque IIDs",
- addrs: tempAddrsWithOpaqueIID[:],
- tempAddrs: true,
- initialExpect: stableAddrsWithOpaqueIID[0],
- maxAddrs: 1 /* initial (stable) address */ + maxSLAACAddrLocalRegenAttempts,
- nicNameFromID: func(tcpip.NICID, string) string {
- return nicName
- },
- },
- {
- name: "Temporary addresses with modified EUI64",
- addrs: tempAddrsWithModifiedEUI64[:],
- tempAddrs: true,
- maxAddrs: 1 /* initial (stable) address */ + maxSLAACAddrLocalRegenAttempts,
- initialExpect: stableAddrWithModifiedEUI64,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- ndpDisp := ndpDispatcher{
- // We may receive a deprecated and invalidated event for each SLAAC
- // address that is assigned.
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAddrs*2),
- }
- e := channel.New(0, 1280, linkAddr1)
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- Clock: clock,
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- AutoGenGlobalAddresses: true,
- AutoGenTempGlobalAddresses: test.tempAddrs,
- AutoGenAddressConflictRetries: 1,
- },
- NDPDisp: &ndpDisp,
- OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{
- NICNameFromID: test.nicNameFromID,
- },
- })},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- })
-
- s.SetRouteTable([]tcpip.Route{{
- Destination: header.IPv6EmptySubnet,
- Gateway: llAddr2,
- NIC: nicID,
- }})
-
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- manuallyAssignedAddresses := make(map[tcpip.Address]struct{})
- for j := 0; j < len(test.addrs)-1; j++ {
- // The NIC will not attempt to generate an address in response to a
- // NIC-local conflict after some maximum number of attempts. We skip
- // creating a conflict for the address that would be generated as part
- // of the last attempt so we can simulate a DAD conflict for this
- // address and restart the NIC-local generation process.
- if j == maxSLAACAddrLocalRegenAttempts-1 {
- continue
- }
-
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: test.addrs[j].Address.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
-
- manuallyAssignedAddresses[test.addrs[j].Address] = struct{}{}
- }
-
- expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
-
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
- }
- }
-
- expectAutoGenAddrAsyncEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
-
- if diff := checkAutoGenAddrEvent(<-ndpDisp.autoGenAddrC, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- }
-
- expectDADEventAsync := func(addr tcpip.Address) {
- t.Helper()
-
- clock.Advance(dupAddrTransmits * retransmitTimer)
- if diff := checkDADEvent(<-ndpDisp.dadC, nicID, addr, &stack.DADSucceeded{}); diff != "" {
- t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
- }
- }
-
- // Enable DAD.
- ndpDisp.dadC = make(chan ndpDADEvent, 2)
- if ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
- } else {
- ndpEP := ipv6Ep.(stack.DuplicateAddressDetector)
- ndpEP.SetDADConfigurations(stack.DADConfigurations{
- DupAddrDetectTransmits: dupAddrTransmits,
- RetransmitTimer: retransmitTimer,
- })
- }
-
- // Do SLAAC for prefix.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds))
- if test.initialExpect != (tcpip.AddressWithPrefix{}) {
- expectAutoGenAddrEvent(test.initialExpect, newAddr)
- expectDADEventAsync(test.initialExpect.Address)
- }
-
- // The last local generation attempt should succeed, but we introduce a
- // DAD failure to restart the local generation process.
- addr := test.addrs[maxSLAACAddrLocalRegenAttempts-1]
- expectAutoGenAddrAsyncEvent(addr, newAddr)
- rxNDPSolicit(e, addr.Address)
- select {
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADDupAddrDetected{}); diff != "" {
- t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected DAD event")
- }
- expectAutoGenAddrEvent(addr, invalidatedAddr)
-
- // The last address generated should resolve DAD.
- addr = test.addrs[len(test.addrs)-1]
- expectAutoGenAddrAsyncEvent(addr, newAddr)
- expectDADEventAsync(addr.Address)
-
- select {
- case e := <-ndpDisp.autoGenAddrC:
- t.Fatalf("unexpected auto gen addr event = %+v", e)
- default:
- }
-
- // Wait for all the SLAAC addresses to be invalidated.
- clock.Advance(lifetimeSeconds * time.Second)
- gotAddresses := make(map[tcpip.Address]struct{})
- for _, a := range s.NICInfo()[nicID].ProtocolAddresses {
- gotAddresses[a.AddressWithPrefix.Address] = struct{}{}
- }
- if diff := cmp.Diff(manuallyAssignedAddresses, gotAddresses); diff != "" {
- t.Fatalf("assigned addresses mismatch (-want +got):\n%s", diff)
- }
- })
- }
-}
-
-// stackAndNdpDispatcherWithDefaultRoute returns an ndpDispatcher,
-// channel.Endpoint and stack.Stack.
-//
-// stack.Stack will have a default route through the router (llAddr3) installed
-// and a static link-address (linkAddr3) added to the link address cache for the
-// router.
-func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) {
- t.Helper()
- ndpDisp := &ndpDispatcher{
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: ndpDisp,
- })},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- Clock: clock,
- })
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- s.SetRouteTable([]tcpip.Route{{
- Destination: header.IPv6EmptySubnet,
- Gateway: llAddr3,
- NIC: nicID,
- }})
-
- if err := s.AddStaticNeighbor(nicID, ipv6.ProtocolNumber, llAddr3, linkAddr3); err != nil {
- t.Fatalf("s.AddStaticNeighbor(%d, %d, %s, %s): %s", nicID, ipv6.ProtocolNumber, llAddr3, linkAddr3, err)
- }
- return ndpDisp, e, s, clock
-}
-
-// addrForNewConnectionTo returns the local address used when creating a new
-// connection to addr.
-func addrForNewConnectionTo(t *testing.T, s *stack.Stack, addr tcpip.FullAddress) tcpip.Address {
- t.Helper()
-
- wq := waiter.Queue{}
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
- defer close(ch)
- ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
- }
- defer ep.Close()
- ep.SocketOptions().SetV6Only(true)
- if err := ep.Connect(addr); err != nil {
- t.Fatalf("ep.Connect(%+v): %s", addr, err)
- }
- got, err := ep.GetLocalAddress()
- if err != nil {
- t.Fatalf("ep.GetLocalAddress(): %s", err)
- }
- return got.Addr
-}
-
-// addrForNewConnection returns the local address used when creating a new
-// connection.
-func addrForNewConnection(t *testing.T, s *stack.Stack) tcpip.Address {
- t.Helper()
-
- return addrForNewConnectionTo(t, s, dstAddr)
-}
-
-// addrForNewConnectionWithAddr returns the local address used when creating a
-// new connection with a specific local address.
-func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullAddress) tcpip.Address {
- t.Helper()
-
- wq := waiter.Queue{}
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
- defer close(ch)
- ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
- }
- defer ep.Close()
- ep.SocketOptions().SetV6Only(true)
- if err := ep.Bind(addr); err != nil {
- t.Fatalf("ep.Bind(%+v): %s", addr, err)
- }
- if err := ep.Connect(dstAddr); err != nil {
- t.Fatalf("ep.Connect(%+v): %s", dstAddr, err)
- }
- got, err := ep.GetLocalAddress()
- if err != nil {
- t.Fatalf("ep.GetLocalAddress(): %s", err)
- }
- return got.Addr
-}
-
-// TestAutoGenAddrDeprecateFromPI tests deprecating a SLAAC address when
-// receiving a PI with 0 preferred lifetime.
-func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
- const nicID = 1
-
- prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
- prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
-
- ndpDisp, e, s, _ := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
-
- expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
-
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
- }
- }
-
- expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
- t.Helper()
-
- if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil {
- t.Fatal(err)
- }
-
- if got := addrForNewConnection(t, s); got != addr.Address {
- t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
- }
- }
-
- // Receive PI for prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
- expectAutoGenAddrEvent(addr1, newAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should have %s in the list of addresses", addr1)
- }
- expectPrimaryAddr(addr1)
-
- // Deprecate addr for prefix1 immedaitely.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
- expectAutoGenAddrEvent(addr1, deprecatedAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should have %s in the list of addresses", addr1)
- }
- // addr should still be the primary endpoint as there are no other addresses.
- expectPrimaryAddr(addr1)
-
- // Refresh lifetimes of addr generated from prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
- expectPrimaryAddr(addr1)
-
- // Receive PI for prefix2.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
- expectAutoGenAddrEvent(addr2, newAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- expectPrimaryAddr(addr2)
-
- // Deprecate addr for prefix2 immedaitely.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
- expectAutoGenAddrEvent(addr2, deprecatedAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- // addr1 should be the primary endpoint now since addr2 is deprecated but
- // addr1 is not.
- expectPrimaryAddr(addr1)
- // addr2 is deprecated but if explicitly requested, it should be used.
- fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID}
- if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address)
- }
-
- // Another PI w/ 0 preferred lifetime should not result in a deprecation
- // event.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
- expectPrimaryAddr(addr1)
- if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address)
- }
-
- // Refresh lifetimes of addr generated from prefix2.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
- expectPrimaryAddr(addr2)
-}
-
-// TestAutoGenAddrJobDeprecation tests that an address is properly deprecated
-// when its preferred lifetime expires.
-func TestAutoGenAddrJobDeprecation(t *testing.T) {
- const nicID = 1
-
- prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
- prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
-
- ndpDisp, e, s, clock := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
-
- expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
-
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
- }
- }
-
- expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) {
- t.Helper()
-
- clock.Advance(timeout)
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("timed out waiting for addr auto gen event")
- }
- }
-
- expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
- t.Helper()
-
- if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil {
- t.Fatal(err)
- }
-
- if got := addrForNewConnection(t, s); got != addr.Address {
- t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
- }
- }
-
- // Receive PI for prefix2.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, infiniteLifetimeSeconds, infiniteLifetimeSeconds))
- expectAutoGenAddrEvent(addr2, newAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- expectPrimaryAddr(addr2)
-
- // Receive a PI for prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90))
- expectAutoGenAddrEvent(addr1, newAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should have %s in the list of addresses", addr1)
- }
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- expectPrimaryAddr(addr1)
-
- // Refresh lifetime for addr of prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, minVLSeconds-1))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
- expectPrimaryAddr(addr1)
-
- // Wait for addr of prefix1 to be deprecated.
- expectAutoGenAddrEventAfter(addr1, deprecatedAddr, ipv6.MinPrefixInformationValidLifetimeForUpdate-time.Second)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should not have %s in the list of addresses", addr1)
- }
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- // addr2 should be the primary endpoint now since addr1 is deprecated but
- // addr2 is not.
- expectPrimaryAddr(addr2)
-
- // addr1 is deprecated but if explicitly requested, it should be used.
- fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID}
- if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
- }
-
- // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make
- // sure we do not get a deprecation event again.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, 0))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
- expectPrimaryAddr(addr2)
- if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
- }
-
- // Refresh lifetimes for addr of prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, minVLSeconds-1))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
- // addr1 is the primary endpoint again since it is non-deprecated now.
- expectPrimaryAddr(addr1)
-
- // Wait for addr of prefix1 to be deprecated.
- expectAutoGenAddrEventAfter(addr1, deprecatedAddr, ipv6.MinPrefixInformationValidLifetimeForUpdate-time.Second)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should not have %s in the list of addresses", addr1)
- }
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- // addr2 should be the primary endpoint now since it is not deprecated.
- expectPrimaryAddr(addr2)
- if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
- }
-
- // Wait for addr of prefix1 to be invalidated.
- expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second)
- if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should not have %s in the list of addresses", addr1)
- }
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- expectPrimaryAddr(addr2)
-
- // Refresh both lifetimes for addr of prefix2 to the same value.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, minVLSeconds, minVLSeconds))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
-
- // Wait for a deprecation then invalidation events, or just an invalidation
- // event. We need to cover both cases but cannot deterministically hit both
- // cases because the deprecation and invalidation handlers could be handled in
- // either deprecation then invalidation, or invalidation then deprecation
- // (which should be cancelled by the invalidation handler).
- //
- // Since we're about to cause both events to fire, we need the dispatcher
- // channel to be able to hold both.
- if got, want := len(ndpDisp.autoGenAddrC), 0; got != want {
- t.Fatalf("got len(ndpDisp.autoGenAddrC) = %d, want %d", got, want)
- }
- if got, want := cap(ndpDisp.autoGenAddrC), 1; got != want {
- t.Fatalf("got cap(ndpDisp.autoGenAddrC) = %d, want %d", got, want)
- }
- ndpDisp.autoGenAddrC = make(chan ndpAutoGenAddrEvent, 2)
- clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate)
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" {
- // If we get a deprecation event first, we should get an invalidation
- // event almost immediately after.
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("timed out waiting for addr auto gen event")
- }
- } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" {
- // If we get an invalidation event first, we should not get a deprecation
- // event after.
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
- } else {
- t.Fatalf("got unexpected auto-generated event")
- }
- default:
- t.Fatal("timed out waiting for addr auto gen event")
- }
- if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should not have %s in the list of addresses", addr1)
- }
- if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should not have %s in the list of addresses", addr2)
- }
- // Should not have any primary endpoints.
- if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
- t.Fatal(err)
- }
- wq := waiter.Queue{}
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
- defer close(ch)
- ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
- }
- defer ep.Close()
- ep.SocketOptions().SetV6Only(true)
-
- {
- err := ep.Connect(dstAddr)
- if _, ok := err.(*tcpip.ErrNoRoute); !ok {
- t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, &tcpip.ErrNoRoute{})
- }
- }
-}
-
-// Tests transitioning a SLAAC address's valid lifetime between finite and
-// infinite values.
-func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) {
- const infiniteVLSeconds = 2
-
- prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
-
- tests := []struct {
- name string
- infiniteVL uint32
- }{
- {
- name: "EqualToInfiniteVL",
- infiniteVL: infiniteVLSeconds,
- },
- // Our implementation supports changing header.NDPInfiniteLifetime for tests
- // such that a packet can be received where the lifetime field has a value
- // greater than header.NDPInfiniteLifetime. Because of this, we test to make
- // sure that receiving a value greater than header.NDPInfiniteLifetime is
- // handled the same as when receiving a value equal to
- // header.NDPInfiniteLifetime.
- {
- name: "MoreThanInfiniteVL",
- infiniteVL: infiniteVLSeconds + 1,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- ndpDisp := ndpDispatcher{
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- })},
- Clock: clock,
- })
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- // Receive an RA with finite prefix.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0))
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
-
- default:
- t.Fatal("expected addr auto gen event")
- }
-
- // Receive an new RA with prefix with infinite VL.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.infiniteVL, 0))
-
- // Receive a new RA with prefix with finite VL.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0))
-
- clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate)
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
-
- default:
- t.Fatal("timeout waiting for addr auto gen event")
- }
- })
- }
-}
-
-// TestAutoGenAddrValidLifetimeUpdates tests that the valid lifetime of an
-// auto-generated address only gets updated when required to, as specified in
-// RFC 4862 section 5.5.3.e.
-func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
- const infiniteVL = 4294967295
-
- prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
-
- tests := []struct {
- name string
- ovl uint32
- nvl uint32
- evl uint32
- }{
- // Should update the VL to the minimum VL for updating if the
- // new VL is less than minVLSeconds but was originally greater than
- // it.
- {
- "LargeVLToVLLessThanMinVLForUpdate",
- 9999,
- 1,
- minVLSeconds,
- },
- {
- "LargeVLTo0",
- 9999,
- 0,
- minVLSeconds,
- },
- {
- "InfiniteVLToVLLessThanMinVLForUpdate",
- infiniteVL,
- 1,
- minVLSeconds,
- },
- {
- "InfiniteVLTo0",
- infiniteVL,
- 0,
- minVLSeconds,
- },
-
- // Should not update VL if original VL was less than minVLSeconds
- // and the new VL is also less than minVLSeconds.
- {
- "ShouldNotUpdateWhenBothOldAndNewAreLessThanMinVLForUpdate",
- minVLSeconds - 1,
- minVLSeconds - 3,
- minVLSeconds - 1,
- },
-
- // Should take the new VL if the new VL is greater than the
- // remaining time or is greater than minVLSeconds.
- {
- "MorethanMinVLToLesserButStillMoreThanMinVLForUpdate",
- minVLSeconds + 5,
- minVLSeconds + 3,
- minVLSeconds + 3,
- },
- {
- "SmallVLToGreaterVLButStillLessThanMinVLForUpdate",
- minVLSeconds - 3,
- minVLSeconds - 1,
- minVLSeconds - 1,
- },
- {
- "SmallVLToGreaterVLThatIsMoreThaMinVLForUpdate",
- minVLSeconds - 3,
- minVLSeconds + 1,
- minVLSeconds + 1,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- ndpDisp := ndpDispatcher{
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10),
- }
- e := channel.New(10, 1280, linkAddr1)
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- })},
- Clock: clock,
- })
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- // Receive an RA with prefix with initial VL,
- // test.ovl.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0))
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
- }
-
- // Receive an new RA with prefix with new VL,
- // test.nvl.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0))
-
- //
- // Validate that the VL for the address got set
- // to test.evl.
- //
-
- // The address should not be invalidated until the effective valid
- // lifetime has passed.
- const delta = 1
- clock.Advance(time.Duration(test.evl)*time.Second - delta)
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly received an auto gen addr event")
- default:
- }
-
- // Wait for the invalidation event.
- clock.Advance(delta)
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("timeout waiting for addr auto gen event")
- }
- })
- }
-}
-
-// TestAutoGenAddrRemoval tests that when auto-generated addresses are removed
-// by the user, its resources will be cleaned up and an invalidation event will
-// be sent to the integrator.
-func TestAutoGenAddrRemoval(t *testing.T) {
- prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
-
- ndpDisp := ndpDispatcher{
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- })},
- Clock: clock,
- })
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
-
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
- }
- }
-
- // Receive a PI to auto-generate an address.
- const lifetimeSeconds = 1
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, 0))
- expectAutoGenAddrEvent(addr, newAddr)
-
- // Removing the address should result in an invalidation event
- // immediately.
- if err := s.RemoveAddress(1, addr.Address); err != nil {
- t.Fatalf("RemoveAddress(_, %s) = %s", addr.Address, err)
- }
- expectAutoGenAddrEvent(addr, invalidatedAddr)
-
- // Wait for the original valid lifetime to make sure the original job got
- // cancelled/cleaned up.
- clock.Advance(lifetimeSeconds * time.Second)
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly received an auto gen addr event")
- default:
- }
-}
-
-// TestAutoGenAddrAfterRemoval tests adding a SLAAC address that was previously
-// assigned to the NIC but is in the permanentExpired state.
-func TestAutoGenAddrAfterRemoval(t *testing.T) {
- const nicID = 1
-
- prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
- prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
- ndpDisp, e, s, _ := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
-
- expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
-
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
- }
- }
-
- expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
- t.Helper()
-
- if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil {
- t.Fatal(err)
- }
-
- if got := addrForNewConnection(t, s); got != addr.Address {
- t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
- }
- }
-
- // Receive a PI to auto-generate addr1 with a large valid and preferred
- // lifetime.
- const largeLifetimeSeconds = 999
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix1, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
- expectAutoGenAddrEvent(addr1, newAddr)
- expectPrimaryAddr(addr1)
-
- // Add addr2 as a static address.
- protoAddr2 := tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: addr2,
- }
- properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint}
- if err := s.AddProtocolAddress(nicID, protoAddr2, properties); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, %+v) = %s", nicID, protoAddr2, properties, err)
- }
- // addr2 should be more preferred now since it is at the front of the primary
- // list.
- expectPrimaryAddr(addr2)
-
- // Get a route using addr2 to increment its reference count then remove it
- // to leave it in the permanentExpired state.
- r, err := s.FindRoute(nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, false)
- if err != nil {
- t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, err)
- }
- defer r.Release()
- if err := s.RemoveAddress(nicID, addr2.Address); err != nil {
- t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, addr2.Address, err)
- }
- // addr1 should be preferred again since addr2 is in the expired state.
- expectPrimaryAddr(addr1)
-
- // Receive a PI to auto-generate addr2 as valid and preferred.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
- expectAutoGenAddrEvent(addr2, newAddr)
- // addr2 should be more preferred now that it is closer to the front of the
- // primary list and not deprecated.
- expectPrimaryAddr(addr2)
-
- // Removing the address should result in an invalidation event immediately.
- // It should still be in the permanentExpired state because r is still held.
- //
- // We remove addr2 here to make sure addr2 was marked as a SLAAC address
- // (it was previously marked as a static address).
- if err := s.RemoveAddress(1, addr2.Address); err != nil {
- t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err)
- }
- expectAutoGenAddrEvent(addr2, invalidatedAddr)
- // addr1 should be more preferred since addr2 is in the expired state.
- expectPrimaryAddr(addr1)
-
- // Receive a PI to auto-generate addr2 as valid and deprecated.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, 0))
- expectAutoGenAddrEvent(addr2, newAddr)
- // addr1 should still be more preferred since addr2 is deprecated, even though
- // it is closer to the front of the primary list.
- expectPrimaryAddr(addr1)
-
- // Receive a PI to refresh addr2's preferred lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto gen addr event")
- default:
- }
- // addr2 should be more preferred now that it is not deprecated.
- expectPrimaryAddr(addr2)
-
- if err := s.RemoveAddress(1, addr2.Address); err != nil {
- t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err)
- }
- expectAutoGenAddrEvent(addr2, invalidatedAddr)
- expectPrimaryAddr(addr1)
-}
-
-// TestAutoGenAddrStaticConflict tests that if SLAAC generates an address that
-// is already assigned to the NIC, the static address remains.
-func TestAutoGenAddrStaticConflict(t *testing.T) {
- prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
-
- ndpDisp := ndpDispatcher{
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- })},
- Clock: clock,
- })
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- // Add the address as a static address before SLAAC tries to add it.
- protocolAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}
- if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(1, %+v, {}) = %s", protocolAddr, err)
- }
- if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) {
- t.Fatalf("Should have %s in the list of addresses", addr1)
- }
-
- // Receive a PI where the generated address will be the same as the one
- // that we already have assigned statically.
- const lifetimeSeconds = 1
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, 0))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly received an auto gen addr event for an address we already have statically")
- default:
- }
- if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) {
- t.Fatalf("Should have %s in the list of addresses", addr1)
- }
-
- // Should not get an invalidation event after the PI's invalidation
- // time.
- clock.Advance(lifetimeSeconds * time.Second)
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly received an auto gen addr event")
- default:
- }
- if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) {
- t.Fatalf("Should have %s in the list of addresses", addr1)
- }
-}
-
-func makeSecretKey(t *testing.T) []byte {
- secretKey := make([]byte, header.OpaqueIIDSecretKeyMinBytes)
- n, err := cryptorand.Read(secretKey)
- if err != nil {
- t.Fatalf("cryptorand.Read(_): %s", err)
- }
- if l := len(secretKey); n != l {
- t.Fatalf("got cryptorand.Read(_) = (%d, nil), want = (%d, nil)", n, l)
- }
- return secretKey
-}
-
-// TestAutoGenAddrWithOpaqueIID tests that SLAAC generated addresses will use
-// opaque interface identifiers when configured to do so.
-func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
- const nicID = 1
- const nicName = "nic1"
-
- secretKey := makeSecretKey(t)
-
- prefix1, subnet1, _ := prefixSubnetAddr(0, linkAddr1)
- prefix2, subnet2, _ := prefixSubnetAddr(1, linkAddr1)
- // addr1 and addr2 are the addresses that are expected to be generated when
- // stack.Stack is configured to generate opaque interface identifiers as
- // defined by RFC 7217.
- addrBytes := []byte(subnet1.ID())
- addr1 := tcpip.AddressWithPrefix{
- Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet1, nicName, 0, secretKey)),
- PrefixLen: 64,
- }
- addrBytes = []byte(subnet2.ID())
- addr2 := tcpip.AddressWithPrefix{
- Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet2, nicName, 0, secretKey)),
- PrefixLen: 64,
- }
-
- ndpDisp := ndpDispatcher{
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{
- NICNameFromID: func(_ tcpip.NICID, nicName string) string {
- return nicName
- },
- SecretKey: secretKey,
- },
- })},
- Clock: clock,
- })
- opts := stack.NICOptions{Name: nicName}
- if err := s.CreateNICWithOptions(nicID, e, opts); err != nil {
- t.Fatalf("CreateNICWithOptions(%d, _, %+v, _) = %s", nicID, opts, err)
- }
-
- expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
-
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
- }
- }
-
- // Receive an RA with prefix1 in a PI.
- const validLifetimeSecondPrefix1 = 1
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, validLifetimeSecondPrefix1, 0))
- expectAutoGenAddrEvent(addr1, newAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should have %s in the list of addresses", addr1)
- }
-
- // Receive an RA with prefix2 in a PI with a large valid lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
- expectAutoGenAddrEvent(addr2, newAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should have %s in the list of addresses", addr1)
- }
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
-
- // Wait for addr of prefix1 to be invalidated.
- clock.Advance(validLifetimeSecondPrefix1 * time.Second)
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("timed out waiting for addr auto gen event")
- }
- if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should not have %s in the list of addresses", addr1)
- }
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
-}
-
-func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
- const nicID = 1
- const nicName = "nic"
- const dadTransmits = 1
- const retransmitTimer = time.Second
- const maxMaxRetries = 3
- const lifetimeSeconds = 10
-
- secretKey := makeSecretKey(t)
-
- prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1)
-
- addrForSubnet := func(subnet tcpip.Subnet, dadCounter uint8) tcpip.AddressWithPrefix {
- addrBytes := []byte(subnet.ID())
- return tcpip.AddressWithPrefix{
- Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, dadCounter, secretKey)),
- PrefixLen: 64,
- }
- }
-
- expectAutoGenAddrEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
-
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
- }
- }
-
- expectAutoGenAddrEventAsync := func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
-
- clock.RunImmediatelyScheduledJobs()
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("timed out waiting for addr auto gen event")
- }
- }
-
- expectDADEvent := func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) {
- t.Helper()
-
- clock.RunImmediatelyScheduledJobs()
- select {
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr, res); diff != "" {
- t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected DAD event")
- }
- }
-
- expectDADEventAsync := func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) {
- t.Helper()
-
- clock.Advance(dadTransmits * retransmitTimer)
- select {
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr, res); diff != "" {
- t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("timed out waiting for DAD event")
- }
- }
-
- stableAddrForTempAddrTest := addrForSubnet(subnet, 0)
-
- addrTypes := []struct {
- name string
- ndpConfigs ipv6.NDPConfigurations
- autoGenLinkLocal bool
- prepareFn func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix
- addrGenFn func(dadCounter uint8, tempIIDHistory []byte) tcpip.AddressWithPrefix
- }{
- {
- name: "Global address",
- ndpConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- AutoGenGlobalAddresses: true,
- },
- prepareFn: func(_ *testing.T, _ *faketime.ManualClock, _ *ndpDispatcher, e *channel.Endpoint, _ []byte) []tcpip.AddressWithPrefix {
- // Receive an RA with prefix1 in a PI.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds))
- return nil
-
- },
- addrGenFn: func(dadCounter uint8, _ []byte) tcpip.AddressWithPrefix {
- return addrForSubnet(subnet, dadCounter)
- },
- },
- {
- name: "LinkLocal address",
- ndpConfigs: ipv6.NDPConfigurations{},
- autoGenLinkLocal: true,
- prepareFn: func(*testing.T, *faketime.ManualClock, *ndpDispatcher, *channel.Endpoint, []byte) []tcpip.AddressWithPrefix {
- return nil
- },
- addrGenFn: func(dadCounter uint8, _ []byte) tcpip.AddressWithPrefix {
- return addrForSubnet(header.IPv6LinkLocalPrefix.Subnet(), dadCounter)
- },
- },
- {
- name: "Temporary address",
- ndpConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- AutoGenGlobalAddresses: true,
- AutoGenTempGlobalAddresses: true,
- },
- prepareFn: func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix {
- header.InitialTempIID(tempIIDHistory, nil, nicID)
-
- // Generate a stable SLAAC address so temporary addresses will be
- // generated.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
- expectAutoGenAddrEvent(t, ndpDisp, stableAddrForTempAddrTest, newAddr)
- expectDADEventAsync(t, clock, ndpDisp, stableAddrForTempAddrTest.Address, &stack.DADSucceeded{})
-
- // The stable address will be assigned throughout the test.
- return []tcpip.AddressWithPrefix{stableAddrForTempAddrTest}
- },
- addrGenFn: func(_ uint8, tempIIDHistory []byte) tcpip.AddressWithPrefix {
- return header.GenerateTempIPv6SLAACAddr(tempIIDHistory, stableAddrForTempAddrTest.Address)
- },
- },
- }
-
- for _, addrType := range addrTypes {
- t.Run(addrType.name, func(t *testing.T) {
- for maxRetries := uint8(0); maxRetries <= maxMaxRetries; maxRetries++ {
- for numFailures := uint8(0); numFailures <= maxRetries+1; numFailures++ {
- maxRetries := maxRetries
- numFailures := numFailures
- addrType := addrType
-
- t.Run(fmt.Sprintf("%d max retries and %d failures", maxRetries, numFailures), func(t *testing.T) {
- ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent, 1),
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
- }
- e := channel.New(0, 1280, linkAddr1)
- ndpConfigs := addrType.ndpConfigs
- ndpConfigs.AutoGenAddressConflictRetries = maxRetries
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- AutoGenLinkLocal: addrType.autoGenLinkLocal,
- DADConfigs: stack.DADConfigurations{
- DupAddrDetectTransmits: dadTransmits,
- RetransmitTimer: retransmitTimer,
- },
- NDPConfigs: ndpConfigs,
- NDPDisp: &ndpDisp,
- OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{
- NICNameFromID: func(_ tcpip.NICID, nicName string) string {
- return nicName
- },
- SecretKey: secretKey,
- },
- })},
- Clock: clock,
- })
- opts := stack.NICOptions{Name: nicName}
- if err := s.CreateNICWithOptions(nicID, e, opts); err != nil {
- t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err)
- }
-
- var tempIIDHistory [header.IIDSize]byte
- stableAddrs := addrType.prepareFn(t, clock, &ndpDisp, e, tempIIDHistory[:])
-
- // Simulate DAD conflicts so the address is regenerated.
- for i := uint8(0); i < numFailures; i++ {
- addr := addrType.addrGenFn(i, tempIIDHistory[:])
- expectAutoGenAddrEventAsync(t, clock, &ndpDisp, addr, newAddr)
-
- // Should not have any new addresses assigned to the NIC.
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, stableAddrs, nil); mismatch != "" {
- t.Fatal(mismatch)
- }
-
- // Simulate a DAD conflict.
- rxNDPSolicit(e, addr.Address)
- expectAutoGenAddrEvent(t, &ndpDisp, addr, invalidatedAddr)
- expectDADEvent(t, clock, &ndpDisp, addr.Address, &stack.DADDupAddrDetected{})
-
- // Attempting to add the address manually should not fail if the
- // address's state was cleaned up when DAD failed.
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: addr,
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
- if err := s.RemoveAddress(nicID, addr.Address); err != nil {
- t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err)
- }
- expectDADEvent(t, clock, &ndpDisp, addr.Address, &stack.DADAborted{})
- }
-
- // Should not have any new addresses assigned to the NIC.
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, stableAddrs, nil); mismatch != "" {
- t.Fatal(mismatch)
- }
-
- // If we had less failures than generation attempts, we should have
- // an address after DAD resolves.
- if maxRetries+1 > numFailures {
- addr := addrType.addrGenFn(numFailures, tempIIDHistory[:])
- expectAutoGenAddrEventAsync(t, clock, &ndpDisp, addr, newAddr)
- expectDADEventAsync(t, clock, &ndpDisp, addr.Address, &stack.DADSucceeded{})
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, append(stableAddrs, addr), nil); mismatch != "" {
- t.Fatal(mismatch)
- }
- }
-
- // Should not attempt address generation again.
- select {
- case e := <-ndpDisp.autoGenAddrC:
- t.Fatalf("unexpectedly got an auto-generated address event = %+v", e)
- default:
- }
- })
- }
- }
- })
- }
-}
-
-// TestAutoGenAddrWithEUI64IIDNoDADRetries tests that a regeneration attempt is
-// not made for SLAAC addresses generated with an IID based on the NIC's link
-// address.
-func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
- const nicID = 1
- const dadTransmits = 1
- const retransmitTimer = time.Second
- const maxRetries = 3
- const lifetimeSeconds = 10
-
- prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1)
-
- addrTypes := []struct {
- name string
- ndpConfigs ipv6.NDPConfigurations
- autoGenLinkLocal bool
- subnet tcpip.Subnet
- triggerSLAACFn func(e *channel.Endpoint)
- }{
- {
- name: "Global address",
- ndpConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- AutoGenGlobalAddresses: true,
- AutoGenAddressConflictRetries: maxRetries,
- },
- subnet: subnet,
- triggerSLAACFn: func(e *channel.Endpoint) {
- // Receive an RA with prefix1 in a PI.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds))
-
- },
- },
- {
- name: "LinkLocal address",
- ndpConfigs: ipv6.NDPConfigurations{
- AutoGenAddressConflictRetries: maxRetries,
- },
- autoGenLinkLocal: true,
- subnet: header.IPv6LinkLocalPrefix.Subnet(),
- triggerSLAACFn: func(e *channel.Endpoint) {},
- },
- }
-
- for _, addrType := range addrTypes {
- addrType := addrType
-
- t.Run(addrType.name, func(t *testing.T) {
- ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent, 1),
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
- }
- e := channel.New(0, 1280, linkAddr1)
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- AutoGenLinkLocal: addrType.autoGenLinkLocal,
- NDPConfigs: addrType.ndpConfigs,
- NDPDisp: &ndpDisp,
- DADConfigs: stack.DADConfigurations{
- DupAddrDetectTransmits: dadTransmits,
- RetransmitTimer: retransmitTimer,
- },
- })},
- Clock: clock,
- })
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
-
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
- }
- }
-
- addrType.triggerSLAACFn(e)
-
- addrBytes := []byte(addrType.subnet.ID())
- header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr1, addrBytes[header.IIDOffsetInIPv6Address:])
- addr := tcpip.AddressWithPrefix{
- Address: tcpip.Address(addrBytes),
- PrefixLen: 64,
- }
- expectAutoGenAddrEvent(addr, newAddr)
-
- // Simulate a DAD conflict.
- rxNDPSolicit(e, addr.Address)
- expectAutoGenAddrEvent(addr, invalidatedAddr)
- select {
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADDupAddrDetected{}); diff != "" {
- t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected DAD event")
- }
-
- // Should not attempt address regeneration.
- select {
- case e := <-ndpDisp.autoGenAddrC:
- t.Fatalf("unexpectedly got an auto-generated address event = %+v", e)
- default:
- }
- })
- }
-}
-
-// TestAutoGenAddrContinuesLifetimesAfterRetry tests that retrying address
-// generation in response to DAD conflicts does not refresh the lifetimes.
-func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
- const nicID = 1
- const nicName = "nic"
- const dadTransmits = 1
- const retransmitTimer = 2 * time.Second
- const failureTimer = time.Second
- const maxRetries = 1
- const lifetimeSeconds = 5
-
- secretKey := makeSecretKey(t)
-
- prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1)
-
- ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent, 1),
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
- }
- e := channel.New(0, 1280, linkAddr1)
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- DADConfigs: stack.DADConfigurations{
- DupAddrDetectTransmits: dadTransmits,
- RetransmitTimer: retransmitTimer,
- },
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- AutoGenGlobalAddresses: true,
- AutoGenAddressConflictRetries: maxRetries,
- },
- NDPDisp: &ndpDisp,
- OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{
- NICNameFromID: func(_ tcpip.NICID, nicName string) string {
- return nicName
- },
- SecretKey: secretKey,
- },
- })},
- Clock: clock,
- })
- opts := stack.NICOptions{Name: nicName}
- if err := s.CreateNICWithOptions(nicID, e, opts); err != nil {
- t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err)
- }
-
- expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
-
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
- }
- }
-
- // Receive an RA with prefix in a PI.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds))
-
- addrBytes := []byte(subnet.ID())
- addr := tcpip.AddressWithPrefix{
- Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, 0, secretKey)),
- PrefixLen: 64,
- }
- expectAutoGenAddrEvent(addr, newAddr)
-
- // Simulate a DAD conflict after some time has passed.
- clock.Advance(failureTimer)
- rxNDPSolicit(e, addr.Address)
- expectAutoGenAddrEvent(addr, invalidatedAddr)
- select {
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADDupAddrDetected{}); diff != "" {
- t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected DAD event")
- }
-
- // Let the next address resolve.
- addr.Address = tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, 1, secretKey))
- expectAutoGenAddrEvent(addr, newAddr)
- clock.Advance(dadTransmits * retransmitTimer)
- select {
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADSucceeded{}); diff != "" {
- t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("timed out waiting for DAD event")
- }
-
- // Address should be deprecated/invalidated after the lifetime expires.
- //
- // Note, the remaining lifetime is calculated from when the PI was first
- // processed. Since we wait for some time before simulating a DAD conflict
- // and more time for the new address to resolve, the new address is only
- // expected to be valid for the remaining time. The DAD conflict should
- // not have reset the lifetimes.
- //
- // We expect either just the invalidation event or the deprecation event
- // followed by the invalidation event.
- clock.Advance(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer)
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if e.eventType == deprecatedAddr {
- if diff := checkAutoGenAddrEvent(e, addr, deprecatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
-
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("timed out waiting for invalidated auto gen addr event after deprecation")
- }
- } else {
- if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- }
- default:
- t.Fatal("timed out waiting for auto gen addr event")
- }
-}
-
-// TestNDPRecursiveDNSServerDispatch tests that we properly dispatch an event
-// to the integrator when an RA is received with the NDP Recursive DNS Server
-// option with at least one valid address.
-func TestNDPRecursiveDNSServerDispatch(t *testing.T) {
- tests := []struct {
- name string
- opt header.NDPRecursiveDNSServer
- expected *ndpRDNSS
- }{
- {
- "Unspecified",
- header.NDPRecursiveDNSServer([]byte{
- 0, 0,
- 0, 0, 0, 2,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- }),
- nil,
- },
- {
- "Multicast",
- header.NDPRecursiveDNSServer([]byte{
- 0, 0,
- 0, 0, 0, 2,
- 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
- }),
- nil,
- },
- {
- "OptionTooSmall",
- header.NDPRecursiveDNSServer([]byte{
- 0, 0,
- 0, 0, 0, 2,
- 1, 2, 3, 4, 5, 6, 7, 8,
- }),
- nil,
- },
- {
- "0Addresses",
- header.NDPRecursiveDNSServer([]byte{
- 0, 0,
- 0, 0, 0, 2,
- }),
- nil,
- },
- {
- "Valid1Address",
- header.NDPRecursiveDNSServer([]byte{
- 0, 0,
- 0, 0, 0, 2,
- 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1,
- }),
- &ndpRDNSS{
- []tcpip.Address{
- "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01",
- },
- 2 * time.Second,
- },
- },
- {
- "Valid2Addresses",
- header.NDPRecursiveDNSServer([]byte{
- 0, 0,
- 0, 0, 0, 1,
- 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1,
- 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 2,
- }),
- &ndpRDNSS{
- []tcpip.Address{
- "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01",
- "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x02",
- },
- time.Second,
- },
- },
- {
- "Valid3Addresses",
- header.NDPRecursiveDNSServer([]byte{
- 0, 0,
- 0, 0, 0, 0,
- 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1,
- 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 2,
- 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 3,
- }),
- &ndpRDNSS{
- []tcpip.Address{
- "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01",
- "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x02",
- "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x03",
- },
- 0,
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- ndpDisp := ndpDispatcher{
- // We do not expect more than a single RDNSS
- // event at any time for this test.
- rdnssC: make(chan ndpRDNSSEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- },
- NDPDisp: &ndpDisp,
- })},
- })
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, header.NDPOptionsSerializer{test.opt}))
-
- if test.expected != nil {
- select {
- case e := <-ndpDisp.rdnssC:
- if e.nicID != 1 {
- t.Errorf("got rdnss nicID = %d, want = 1", e.nicID)
- }
- if diff := cmp.Diff(e.rdnss.addrs, test.expected.addrs); diff != "" {
- t.Errorf("rdnss addrs mismatch (-want +got):\n%s", diff)
- }
- if e.rdnss.lifetime != test.expected.lifetime {
- t.Errorf("got rdnss lifetime = %s, want = %s", e.rdnss.lifetime, test.expected.lifetime)
- }
- default:
- t.Fatal("expected an RDNSS option event")
- }
- }
-
- // Should have no more RDNSS options.
- select {
- case e := <-ndpDisp.rdnssC:
- t.Fatalf("unexpectedly got a new RDNSS option event: %+v", e)
- default:
- }
- })
- }
-}
-
-// TestNDPDNSSearchListDispatch tests that the integrator is informed when an
-// NDP DNS Search List option is received with at least one domain name in the
-// search list.
-func TestNDPDNSSearchListDispatch(t *testing.T) {
- const nicID = 1
-
- ndpDisp := ndpDispatcher{
- dnsslC: make(chan ndpDNSSLEvent, 3),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- },
- NDPDisp: &ndpDisp,
- })},
- })
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- optSer := header.NDPOptionsSerializer{
- header.NDPDNSSearchList([]byte{
- 0, 0,
- 0, 0, 0, 0,
- 2, 'h', 'i',
- 0,
- }),
- header.NDPDNSSearchList([]byte{
- 0, 0,
- 0, 0, 0, 1,
- 1, 'i',
- 0,
- 2, 'a', 'm',
- 2, 'm', 'e',
- 0,
- }),
- header.NDPDNSSearchList([]byte{
- 0, 0,
- 0, 0, 1, 0,
- 3, 'x', 'y', 'z',
- 0,
- 5, 'h', 'e', 'l', 'l', 'o',
- 5, 'w', 'o', 'r', 'l', 'd',
- 0,
- 4, 't', 'h', 'i', 's',
- 2, 'i', 's',
- 1, 'a',
- 4, 't', 'e', 's', 't',
- 0,
- }),
- }
- expected := []struct {
- domainNames []string
- lifetime time.Duration
- }{
- {
- domainNames: []string{
- "hi",
- },
- lifetime: 0,
- },
- {
- domainNames: []string{
- "i",
- "am.me",
- },
- lifetime: time.Second,
- },
- {
- domainNames: []string{
- "xyz",
- "hello.world",
- "this.is.a.test",
- },
- lifetime: 256 * time.Second,
- },
- }
-
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, optSer))
-
- for i, expected := range expected {
- select {
- case dnssl := <-ndpDisp.dnsslC:
- if dnssl.nicID != nicID {
- t.Errorf("got %d-th dnssl nicID = %d, want = %d", i, dnssl.nicID, nicID)
- }
- if diff := cmp.Diff(dnssl.domainNames, expected.domainNames); diff != "" {
- t.Errorf("%d-th dnssl domain names mismatch (-want +got):\n%s", i, diff)
- }
- if dnssl.lifetime != expected.lifetime {
- t.Errorf("got %d-th dnssl lifetime = %s, want = %s", i, dnssl.lifetime, expected.lifetime)
- }
- default:
- t.Fatal("expected a DNSSL event")
- }
- }
-
- // Should have no more DNSSL options.
- select {
- case <-ndpDisp.dnsslC:
- t.Fatal("unexpectedly got a DNSSL event")
- default:
- }
-}
-
-func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) {
- const (
- lifetimeSeconds = 999
- nicID = 1
- )
-
- ndpDisp := ndpDispatcher{
- offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1),
- prefixC: make(chan ndpPrefixEvent, 1),
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
- }
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- AutoGenLinkLocal: true,
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- DiscoverDefaultRouters: true,
- DiscoverOnLinkPrefixes: true,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- })},
- })
-
- e1 := channel.New(0, header.IPv6MinimumMTU, linkAddr1)
- if err := s.CreateNIC(nicID, e1); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- llAddr := tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, llAddr, newAddr); diff != "" {
- t.Errorf("auto-gen addr mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Errorf("expected auto-gen addr event for %s on NIC(%d)", llAddr, nicID)
- }
-
- prefix, subnet, addr := prefixSubnetAddr(0, linkAddr1)
- e1.InjectInbound(
- header.IPv6ProtocolNumber,
- raBufWithPI(
- llAddr3,
- lifetimeSeconds,
- prefix,
- true, /* onLink */
- true, /* auto */
- lifetimeSeconds,
- lifetimeSeconds,
- ),
- )
- select {
- case e := <-ndpDisp.offLinkRouteC:
- if diff := checkOffLinkRouteEvent(e, nicID, header.IPv6EmptySubnet, llAddr3, header.MediumRoutePreference, true /* discovered */); diff != "" {
- t.Errorf("off-link route event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID)
- }
- select {
- case e := <-ndpDisp.prefixC:
- if diff := checkPrefixEvent(e, subnet, true /* discovered */); diff != "" {
- t.Errorf("off-link route event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Errorf("expected prefix event for %s on NIC(%d)", prefix, nicID)
- }
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" {
- t.Errorf("auto-gen addr mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Errorf("expected auto-gen addr event for %s on NIC(%d)", addr, nicID)
- }
-
- // Enabling or disabling forwarding should not invalidate discovered prefixes
- // or routers, or auto-generated address.
- for _, forwarding := range [...]bool{true, false} {
- t.Run(fmt.Sprintf("Transition forwarding to %t", forwarding), func(t *testing.T) {
- if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil {
- t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err)
- }
- select {
- case e := <-ndpDisp.offLinkRouteC:
- t.Errorf("unexpected off-link route event = %#v", e)
- default:
- }
- select {
- case e := <-ndpDisp.prefixC:
- t.Errorf("unexpected prefix event = %#v", e)
- default:
- }
- select {
- case e := <-ndpDisp.autoGenAddrC:
- t.Errorf("unexpected auto-gen addr event = %#v", e)
- default:
- }
- })
- }
-}
-
-func TestCleanupNDPState(t *testing.T) {
- const (
- lifetimeSeconds = 5
- maxRouterAndPrefixEvents = 4
- nicID1 = 1
- nicID2 = 2
- )
-
- prefix1, subnet1, e1Addr1 := prefixSubnetAddr(0, linkAddr1)
- prefix2, subnet2, e1Addr2 := prefixSubnetAddr(1, linkAddr1)
- e2Addr1 := addrForSubnet(subnet1, linkAddr2)
- e2Addr2 := addrForSubnet(subnet2, linkAddr2)
- llAddrWithPrefix1 := tcpip.AddressWithPrefix{
- Address: llAddr1,
- PrefixLen: 64,
- }
- llAddrWithPrefix2 := tcpip.AddressWithPrefix{
- Address: llAddr2,
- PrefixLen: 64,
- }
-
- tests := []struct {
- name string
- cleanupFn func(t *testing.T, s *stack.Stack)
- keepAutoGenLinkLocal bool
- maxAutoGenAddrEvents int
- skipFinalAddrCheck bool
- }{
- // A NIC should cleanup all NDP state when it is disabled.
- {
- name: "Disable NIC",
- cleanupFn: func(t *testing.T, s *stack.Stack) {
- t.Helper()
-
- if err := s.DisableNIC(nicID1); err != nil {
- t.Fatalf("s.DisableNIC(%d): %s", nicID1, err)
- }
- if err := s.DisableNIC(nicID2); err != nil {
- t.Fatalf("s.DisableNIC(%d): %s", nicID2, err)
- }
- },
- keepAutoGenLinkLocal: false,
- maxAutoGenAddrEvents: 6,
- },
-
- // A NIC should cleanup all NDP state when it is removed.
- {
- name: "Remove NIC",
- cleanupFn: func(t *testing.T, s *stack.Stack) {
- t.Helper()
-
- if err := s.RemoveNIC(nicID1); err != nil {
- t.Fatalf("s.RemoveNIC(%d): %s", nicID1, err)
- }
- if err := s.RemoveNIC(nicID2); err != nil {
- t.Fatalf("s.RemoveNIC(%d): %s", nicID2, err)
- }
- },
- keepAutoGenLinkLocal: false,
- maxAutoGenAddrEvents: 6,
- // The NICs are removed so we can't check their addresses after calling
- // stopFn.
- skipFinalAddrCheck: true,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- ndpDisp := ndpDispatcher{
- offLinkRouteC: make(chan ndpOffLinkRouteEvent, maxRouterAndPrefixEvents),
- prefixC: make(chan ndpPrefixEvent, maxRouterAndPrefixEvents),
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents),
- }
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- AutoGenLinkLocal: true,
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- DiscoverDefaultRouters: true,
- DiscoverOnLinkPrefixes: true,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- })},
- Clock: clock,
- })
-
- expectOffLinkRouteEvent := func() (bool, ndpOffLinkRouteEvent) {
- select {
- case e := <-ndpDisp.offLinkRouteC:
- return true, e
- default:
- }
-
- return false, ndpOffLinkRouteEvent{}
- }
-
- expectPrefixEvent := func() (bool, ndpPrefixEvent) {
- select {
- case e := <-ndpDisp.prefixC:
- return true, e
- default:
- }
-
- return false, ndpPrefixEvent{}
- }
-
- expectAutoGenAddrEvent := func() (bool, ndpAutoGenAddrEvent) {
- select {
- case e := <-ndpDisp.autoGenAddrC:
- return true, e
- default:
- }
-
- return false, ndpAutoGenAddrEvent{}
- }
-
- e1 := channel.New(0, 1280, linkAddr1)
- if err := s.CreateNIC(nicID1, e1); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err)
- }
- // We have other tests that make sure we receive the *correct* events
- // on normal discovery of routers/prefixes, and auto-generated
- // addresses. Here we just make sure we get an event and let other tests
- // handle the correctness check.
- expectAutoGenAddrEvent()
-
- e2 := channel.New(0, 1280, linkAddr2)
- if err := s.CreateNIC(nicID2, e2); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err)
- }
- expectAutoGenAddrEvent()
-
- // Receive RAs on NIC(1) and NIC(2) from default routers (llAddr3 and
- // llAddr4) w/ PI (for prefix1 in RA from llAddr3 and prefix2 in RA from
- // llAddr4) to discover multiple routers and prefixes, and auto-gen
- // multiple addresses.
-
- e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectOffLinkRouteEvent(); !ok {
- t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID1)
- }
- if ok, _ := expectPrefixEvent(); !ok {
- t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1)
- }
- if ok, _ := expectAutoGenAddrEvent(); !ok {
- t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr1, nicID1)
- }
-
- e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectOffLinkRouteEvent(); !ok {
- t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr4, nicID1)
- }
- if ok, _ := expectPrefixEvent(); !ok {
- t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1)
- }
- if ok, _ := expectAutoGenAddrEvent(); !ok {
- t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID1)
- }
-
- e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectOffLinkRouteEvent(); !ok {
- t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID2)
- }
- if ok, _ := expectPrefixEvent(); !ok {
- t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2)
- }
- if ok, _ := expectAutoGenAddrEvent(); !ok {
- t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID2)
- }
-
- e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectOffLinkRouteEvent(); !ok {
- t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr4, nicID2)
- }
- if ok, _ := expectPrefixEvent(); !ok {
- t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2)
- }
- if ok, _ := expectAutoGenAddrEvent(); !ok {
- t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e2Addr2, nicID2)
- }
-
- // We should have the auto-generated addresses added.
- nicinfo := s.NICInfo()
- nic1Addrs := nicinfo[nicID1].ProtocolAddresses
- nic2Addrs := nicinfo[nicID2].ProtocolAddresses
- if !containsV6Addr(nic1Addrs, llAddrWithPrefix1) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
- }
- if !containsV6Addr(nic1Addrs, e1Addr1) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
- }
- if !containsV6Addr(nic1Addrs, e1Addr2) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
- }
- if !containsV6Addr(nic2Addrs, llAddrWithPrefix2) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
- }
- if !containsV6Addr(nic2Addrs, e2Addr1) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
- }
- if !containsV6Addr(nic2Addrs, e2Addr2) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
- }
-
- // We can't proceed any further if we already failed the test (missing
- // some discovery/auto-generated address events or addresses).
- if t.Failed() {
- t.FailNow()
- }
-
- test.cleanupFn(t, s)
-
- // Collect invalidation events after having NDP state cleaned up.
- gotOffLinkRouteEvents := make(map[ndpOffLinkRouteEvent]int)
- for i := 0; i < maxRouterAndPrefixEvents; i++ {
- ok, e := expectOffLinkRouteEvent()
- if !ok {
- t.Errorf("expected %d off-link route events after becoming a router; got = %d", maxRouterAndPrefixEvents, i)
- break
- }
- gotOffLinkRouteEvents[e]++
- }
- gotPrefixEvents := make(map[ndpPrefixEvent]int)
- for i := 0; i < maxRouterAndPrefixEvents; i++ {
- ok, e := expectPrefixEvent()
- if !ok {
- t.Errorf("expected %d prefix events after becoming a router; got = %d", maxRouterAndPrefixEvents, i)
- break
- }
- gotPrefixEvents[e]++
- }
- gotAutoGenAddrEvents := make(map[ndpAutoGenAddrEvent]int)
- for i := 0; i < test.maxAutoGenAddrEvents; i++ {
- ok, e := expectAutoGenAddrEvent()
- if !ok {
- t.Errorf("expected %d auto-generated address events after becoming a router; got = %d", test.maxAutoGenAddrEvents, i)
- break
- }
- gotAutoGenAddrEvents[e]++
- }
-
- // No need to proceed any further if we already failed the test (missing
- // some invalidation events).
- if t.Failed() {
- t.FailNow()
- }
-
- expectedOffLinkRouteEvents := map[ndpOffLinkRouteEvent]int{
- {nicID: nicID1, subnet: header.IPv6EmptySubnet, router: llAddr3, updated: false}: 1,
- {nicID: nicID1, subnet: header.IPv6EmptySubnet, router: llAddr4, updated: false}: 1,
- {nicID: nicID2, subnet: header.IPv6EmptySubnet, router: llAddr3, updated: false}: 1,
- {nicID: nicID2, subnet: header.IPv6EmptySubnet, router: llAddr4, updated: false}: 1,
- }
- if diff := cmp.Diff(expectedOffLinkRouteEvents, gotOffLinkRouteEvents); diff != "" {
- t.Errorf("off-link route events mismatch (-want +got):\n%s", diff)
- }
- expectedPrefixEvents := map[ndpPrefixEvent]int{
- {nicID: nicID1, prefix: subnet1, discovered: false}: 1,
- {nicID: nicID1, prefix: subnet2, discovered: false}: 1,
- {nicID: nicID2, prefix: subnet1, discovered: false}: 1,
- {nicID: nicID2, prefix: subnet2, discovered: false}: 1,
- }
- if diff := cmp.Diff(expectedPrefixEvents, gotPrefixEvents); diff != "" {
- t.Errorf("prefix events mismatch (-want +got):\n%s", diff)
- }
- expectedAutoGenAddrEvents := map[ndpAutoGenAddrEvent]int{
- {nicID: nicID1, addr: e1Addr1, eventType: invalidatedAddr}: 1,
- {nicID: nicID1, addr: e1Addr2, eventType: invalidatedAddr}: 1,
- {nicID: nicID2, addr: e2Addr1, eventType: invalidatedAddr}: 1,
- {nicID: nicID2, addr: e2Addr2, eventType: invalidatedAddr}: 1,
- }
-
- if !test.keepAutoGenLinkLocal {
- expectedAutoGenAddrEvents[ndpAutoGenAddrEvent{nicID: nicID1, addr: llAddrWithPrefix1, eventType: invalidatedAddr}] = 1
- expectedAutoGenAddrEvents[ndpAutoGenAddrEvent{nicID: nicID2, addr: llAddrWithPrefix2, eventType: invalidatedAddr}] = 1
- }
-
- if diff := cmp.Diff(expectedAutoGenAddrEvents, gotAutoGenAddrEvents); diff != "" {
- t.Errorf("auto-generated address events mismatch (-want +got):\n%s", diff)
- }
-
- if !test.skipFinalAddrCheck {
- // Make sure the auto-generated addresses got removed.
- nicinfo = s.NICInfo()
- nic1Addrs = nicinfo[nicID1].ProtocolAddresses
- nic2Addrs = nicinfo[nicID2].ProtocolAddresses
- if containsV6Addr(nic1Addrs, llAddrWithPrefix1) != test.keepAutoGenLinkLocal {
- if test.keepAutoGenLinkLocal {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
- } else {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
- }
- }
- if containsV6Addr(nic1Addrs, e1Addr1) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
- }
- if containsV6Addr(nic1Addrs, e1Addr2) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
- }
- if containsV6Addr(nic2Addrs, llAddrWithPrefix2) != test.keepAutoGenLinkLocal {
- if test.keepAutoGenLinkLocal {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
- } else {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
- }
- }
- if containsV6Addr(nic2Addrs, e2Addr1) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
- }
- if containsV6Addr(nic2Addrs, e2Addr2) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
- }
- }
-
- // Should not get any more events (invalidation timers should have been
- // cancelled when the NDP state was cleaned up).
- clock.Advance(lifetimeSeconds * time.Second)
- select {
- case <-ndpDisp.offLinkRouteC:
- t.Error("unexpected off-link route event")
- default:
- }
- select {
- case <-ndpDisp.prefixC:
- t.Error("unexpected prefix event")
- default:
- }
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Error("unexpected auto-generated address event")
- default:
- }
- })
- }
-}
-
-// TestDHCPv6ConfigurationFromNDPDA tests that the NDPDispatcher is properly
-// informed when new information about what configurations are available via
-// DHCPv6 is learned.
-func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
- const nicID = 1
-
- ndpDisp := ndpDispatcher{
- dhcpv6ConfigurationC: make(chan ndpDHCPv6Event, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- },
- NDPDisp: &ndpDisp,
- })},
- })
-
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- expectDHCPv6Event := func(configuration ipv6.DHCPv6ConfigurationFromNDPRA) {
- t.Helper()
- select {
- case e := <-ndpDisp.dhcpv6ConfigurationC:
- if diff := cmp.Diff(ndpDHCPv6Event{nicID: nicID, configuration: configuration}, e, cmp.AllowUnexported(e)); diff != "" {
- t.Errorf("dhcpv6 event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected DHCPv6 configuration event")
- }
- }
-
- expectNoDHCPv6Event := func() {
- t.Helper()
- select {
- case <-ndpDisp.dhcpv6ConfigurationC:
- t.Fatal("unexpected DHCPv6 configuration event")
- default:
- }
- }
-
- // Even if the first RA reports no DHCPv6 configurations are available, the
- // dispatcher should get an event.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false))
- expectDHCPv6Event(ipv6.DHCPv6NoConfiguration)
- // Receiving the same update again should not result in an event to the
- // dispatcher.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false))
- expectNoDHCPv6Event()
-
- // Receive an RA that updates the DHCPv6 configuration to Other
- // Configurations.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
- expectDHCPv6Event(ipv6.DHCPv6OtherConfigurations)
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
- expectNoDHCPv6Event()
-
- // Receive an RA that updates the DHCPv6 configuration to Managed Address.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false))
- expectDHCPv6Event(ipv6.DHCPv6ManagedAddress)
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false))
- expectNoDHCPv6Event()
-
- // Receive an RA that updates the DHCPv6 configuration to none.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false))
- expectDHCPv6Event(ipv6.DHCPv6NoConfiguration)
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false))
- expectNoDHCPv6Event()
-
- // Receive an RA that updates the DHCPv6 configuration to Managed Address.
- //
- // Note, when the M flag is set, the O flag is redundant.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true))
- expectDHCPv6Event(ipv6.DHCPv6ManagedAddress)
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true))
- expectNoDHCPv6Event()
- // Even though the DHCPv6 flags are different, the effective configuration is
- // the same so we should not receive a new event.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false))
- expectNoDHCPv6Event()
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true))
- expectNoDHCPv6Event()
-
- // Receive an RA that updates the DHCPv6 configuration to Other
- // Configurations.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
- expectDHCPv6Event(ipv6.DHCPv6OtherConfigurations)
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
- expectNoDHCPv6Event()
-
- // Cycling the NIC should cause the last DHCPv6 configuration to be cleared.
- if err := s.DisableNIC(nicID); err != nil {
- t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
- }
- if err := s.EnableNIC(nicID); err != nil {
- t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
- }
-
- // Receive an RA that updates the DHCPv6 configuration to Other
- // Configurations.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
- expectDHCPv6Event(ipv6.DHCPv6OtherConfigurations)
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
- expectNoDHCPv6Event()
-}
-
-var _ rand.Source = (*savingRandSource)(nil)
-
-type savingRandSource struct {
- s rand.Source
-
- lastInt63 int64
-}
-
-func (d *savingRandSource) Int63() int64 {
- i := d.s.Int63()
- d.lastInt63 = i
- return i
-}
-func (d *savingRandSource) Seed(seed int64) {
- d.s.Seed(seed)
-}
-
-// TestRouterSolicitation tests the initial Router Solicitations that are sent
-// when a NIC newly becomes enabled.
-func TestRouterSolicitation(t *testing.T) {
- const nicID = 1
-
- tests := []struct {
- name string
- linkHeaderLen uint16
- linkAddr tcpip.LinkAddress
- nicAddr tcpip.Address
- expectedSrcAddr tcpip.Address
- expectedNDPOpts []header.NDPOption
- maxRtrSolicit uint8
- rtrSolicitInt time.Duration
- effectiveRtrSolicitInt time.Duration
- maxRtrSolicitDelay time.Duration
- effectiveMaxRtrSolicitDelay time.Duration
- }{
- {
- name: "Single RS with 2s delay and interval",
- expectedSrcAddr: header.IPv6Any,
- maxRtrSolicit: 1,
- rtrSolicitInt: 2 * time.Second,
- effectiveRtrSolicitInt: 2 * time.Second,
- maxRtrSolicitDelay: 2 * time.Second,
- effectiveMaxRtrSolicitDelay: 2 * time.Second,
- },
- {
- name: "Single RS with 4s delay and interval",
- expectedSrcAddr: header.IPv6Any,
- maxRtrSolicit: 1,
- rtrSolicitInt: 4 * time.Second,
- effectiveRtrSolicitInt: 4 * time.Second,
- maxRtrSolicitDelay: 4 * time.Second,
- effectiveMaxRtrSolicitDelay: 4 * time.Second,
- },
- {
- name: "Two RS with delay",
- linkHeaderLen: 1,
- nicAddr: llAddr1,
- expectedSrcAddr: llAddr1,
- maxRtrSolicit: 2,
- rtrSolicitInt: 2 * time.Second,
- effectiveRtrSolicitInt: 2 * time.Second,
- maxRtrSolicitDelay: 500 * time.Millisecond,
- effectiveMaxRtrSolicitDelay: 500 * time.Millisecond,
- },
- {
- name: "Single RS without delay",
- linkHeaderLen: 2,
- linkAddr: linkAddr1,
- nicAddr: llAddr1,
- expectedSrcAddr: llAddr1,
- expectedNDPOpts: []header.NDPOption{
- header.NDPSourceLinkLayerAddressOption(linkAddr1),
- },
- maxRtrSolicit: 1,
- rtrSolicitInt: 2 * time.Second,
- effectiveRtrSolicitInt: 2 * time.Second,
- maxRtrSolicitDelay: 0,
- effectiveMaxRtrSolicitDelay: 0,
- },
- {
- name: "Two RS without delay and invalid zero interval",
- linkHeaderLen: 3,
- linkAddr: linkAddr1,
- expectedSrcAddr: header.IPv6Any,
- maxRtrSolicit: 2,
- rtrSolicitInt: 0,
- effectiveRtrSolicitInt: 4 * time.Second,
- maxRtrSolicitDelay: 0,
- effectiveMaxRtrSolicitDelay: 0,
- },
- {
- name: "Three RS without delay",
- linkAddr: linkAddr1,
- expectedSrcAddr: header.IPv6Any,
- maxRtrSolicit: 3,
- rtrSolicitInt: 500 * time.Millisecond,
- effectiveRtrSolicitInt: 500 * time.Millisecond,
- maxRtrSolicitDelay: 0,
- effectiveMaxRtrSolicitDelay: 0,
- },
- {
- name: "Two RS with invalid negative delay",
- linkAddr: linkAddr1,
- expectedSrcAddr: header.IPv6Any,
- maxRtrSolicit: 2,
- rtrSolicitInt: time.Second,
- effectiveRtrSolicitInt: time.Second,
- maxRtrSolicitDelay: -3 * time.Second,
- effectiveMaxRtrSolicitDelay: time.Second,
- },
- }
-
- subTests := []struct {
- name string
- handleRAs ipv6.HandleRAsConfiguration
- afterFirstRS func(*testing.T, *stack.Stack)
- }{
- {
- name: "Handle RAs when forwarding disabled",
- handleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- afterFirstRS: func(*testing.T, *stack.Stack) {},
- },
-
- // Enabling forwarding when RAs are always configured to be handled
- // should not stop router solicitations.
- {
- name: "Handle RAs always",
- handleRAs: ipv6.HandlingRAsAlwaysEnabled,
- afterFirstRS: func(t *testing.T, s *stack.Stack) {
- if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil {
- t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err)
- }
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- for _, subTest := range subTests {
- t.Run(subTest.name, func(t *testing.T) {
- clock := faketime.NewManualClock()
- e := channelLinkWithHeaderLength{
- Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr),
- headerLength: test.linkHeaderLen,
- }
- e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- waitForPkt := func(timeout time.Duration) {
- t.Helper()
-
- clock.Advance(timeout)
- p, ok := e.Read()
- if !ok {
- t.Fatal("expected router solicitation packet")
- }
-
- if p.Proto != header.IPv6ProtocolNumber {
- t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
- }
-
- // Make sure the right remote link address is used.
- if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress); p.Route.RemoteLinkAddress != want {
- t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
- }
-
- checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
- checker.SrcAddr(test.expectedSrcAddr),
- checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress),
- checker.TTL(header.NDPHopLimit),
- checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)),
- )
-
- if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want {
- t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want)
- }
- }
- waitForNothing := func(timeout time.Duration) {
- t.Helper()
-
- clock.Advance(timeout)
- if p, ok := e.Read(); ok {
- t.Fatalf("unexpectedly got a packet = %#v", p)
- }
- }
- randSource := savingRandSource{
- s: rand.NewSource(time.Now().UnixNano()),
- }
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: subTest.handleRAs,
- MaxRtrSolicitations: test.maxRtrSolicit,
- RtrSolicitationInterval: test.rtrSolicitInt,
- MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
- },
- })},
- Clock: clock,
- RandSource: &randSource,
- })
-
- opts := stack.NICOptions{Disabled: true}
- if err := s.CreateNICWithOptions(nicID, &e, opts); err != nil {
- t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, opts, err)
- }
-
- if addr := test.nicAddr; addr != "" {
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: addr.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
- }
-
- if err := s.EnableNIC(nicID); err != nil {
- t.Fatalf("EnableNIC(%d): %s", nicID, err)
- }
-
- // Make sure each RS is sent at the right time.
- remaining := test.maxRtrSolicit
- if remaining != 0 {
- maxRtrSolicitDelay := test.maxRtrSolicitDelay
- if maxRtrSolicitDelay < 0 {
- maxRtrSolicitDelay = ipv6.DefaultNDPConfigurations().MaxRtrSolicitationDelay
- }
- var actualRtrSolicitDelay time.Duration
- if maxRtrSolicitDelay != 0 {
- actualRtrSolicitDelay = time.Duration(randSource.lastInt63) % maxRtrSolicitDelay
- }
- waitForPkt(actualRtrSolicitDelay)
- remaining--
- }
-
- subTest.afterFirstRS(t, s)
-
- for ; remaining != 0; remaining-- {
- if test.effectiveRtrSolicitInt != 0 {
- waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond)
- waitForPkt(time.Nanosecond)
- } else {
- waitForPkt(0)
- }
- }
-
- // Make sure no more RS.
- if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay {
- waitForNothing(test.effectiveRtrSolicitInt)
- } else {
- waitForNothing(test.effectiveMaxRtrSolicitDelay)
- }
-
- if got, want := s.Stats().ICMP.V6.PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want {
- t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want)
- }
- })
- }
- })
- }
-}
-
-func TestStopStartSolicitingRouters(t *testing.T) {
- const nicID = 1
- const delay = 0
- const interval = 500 * time.Millisecond
- const maxRtrSolicitations = 3
-
- tests := []struct {
- name string
- startFn func(t *testing.T, s *stack.Stack)
- // first is used to tell stopFn that it is being called for the first time
- // after router solicitations were last enabled.
- stopFn func(t *testing.T, s *stack.Stack, first bool)
- }{
- // Tests that when forwarding is enabled or disabled, router solicitations
- // are stopped or started, respectively.
- {
- name: "Enable and disable forwarding",
- startFn: func(t *testing.T, s *stack.Stack) {
- t.Helper()
-
- if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, false); err != nil {
- t.Fatalf("SetForwardingDefaultAndAllNICs(%d, false): %s", ipv6.ProtocolNumber, err)
- }
- },
- stopFn: func(t *testing.T, s *stack.Stack, _ bool) {
- t.Helper()
-
- if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil {
- t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err)
- }
- },
- },
-
- // Tests that when a NIC is enabled or disabled, router solicitations
- // are started or stopped, respectively.
- {
- name: "Enable and disable NIC",
- startFn: func(t *testing.T, s *stack.Stack) {
- t.Helper()
-
- if err := s.EnableNIC(nicID); err != nil {
- t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
- }
- },
- stopFn: func(t *testing.T, s *stack.Stack, _ bool) {
- t.Helper()
-
- if err := s.DisableNIC(nicID); err != nil {
- t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
- }
- },
- },
-
- // Tests that when a NIC is removed, router solicitations are stopped. We
- // cannot start router solications on a removed NIC.
- {
- name: "Remove NIC",
- stopFn: func(t *testing.T, s *stack.Stack, first bool) {
- t.Helper()
-
- // Only try to remove the NIC the first time stopFn is called since it's
- // impossible to remove an already removed NIC.
- if !first {
- return
- }
-
- if err := s.RemoveNIC(nicID); err != nil {
- t.Fatalf("s.RemoveNIC(%d): %s", nicID, err)
- }
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- e := channel.New(maxRtrSolicitations, 1280, linkAddr1)
- waitForPkt := func(clock *faketime.ManualClock, timeout time.Duration) {
- t.Helper()
-
- clock.Advance(timeout)
- p, ok := e.Read()
- if !ok {
- t.Fatal("timed out waiting for packet")
- }
-
- if p.Proto != header.IPv6ProtocolNumber {
- t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
- }
- checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
- checker.SrcAddr(header.IPv6Any),
- checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress),
- checker.TTL(header.NDPHopLimit),
- checker.NDPRS())
- }
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- MaxRtrSolicitations: maxRtrSolicitations,
- RtrSolicitationInterval: interval,
- MaxRtrSolicitationDelay: delay,
- },
- })},
- Clock: clock,
- })
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- // Stop soliciting routers.
- test.stopFn(t, s, true /* first */)
- clock.Advance(delay)
- if _, ok := e.Read(); ok {
- // A single RS may have been sent before solicitations were stopped.
- clock.Advance(interval)
- if _, ok = e.Read(); ok {
- t.Fatal("should not have sent more than one RS message")
- }
- }
-
- // Stopping router solicitations after it has already been stopped should
- // do nothing.
- test.stopFn(t, s, false /* first */)
- clock.Advance(delay)
- if _, ok := e.Read(); ok {
- t.Fatal("unexpectedly got a packet after router solicitation has been stopepd")
- }
-
- // If test.startFn is nil, there is no way to restart router solications.
- if test.startFn == nil {
- return
- }
-
- // Start soliciting routers.
- test.startFn(t, s)
- waitForPkt(clock, delay)
- waitForPkt(clock, interval)
- waitForPkt(clock, interval)
- clock.Advance(interval)
- if _, ok := e.Read(); ok {
- t.Fatal("unexpectedly got an extra packet after sending out the expected RSs")
- }
-
- // Starting router solicitations after it has already completed should do
- // nothing.
- test.startFn(t, s)
- clock.Advance(interval)
- if _, ok := e.Read(); ok {
- t.Fatal("unexpectedly got a packet after finishing router solicitations")
- }
- })
- }
-}
diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go
deleted file mode 100644
index 7de25fe37..000000000
--- a/pkg/tcpip/stack/neighbor_cache_test.go
+++ /dev/null
@@ -1,1584 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package stack
-
-import (
- "fmt"
- "math"
- "math/rand"
- "strings"
- "sync"
- "sync/atomic"
- "testing"
- "time"
-
- "github.com/google/go-cmp/cmp"
- "github.com/google/go-cmp/cmp/cmpopts"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
-)
-
-const (
- // entryStoreSize is the default number of entries that will be generated and
- // added to the entry store. This number needs to be larger than the size of
- // the neighbor cache to give ample opportunity for verifying behavior during
- // cache overflows. Four times the size of the neighbor cache allows for
- // three complete cache overflows.
- entryStoreSize = 4 * neighborCacheSize
-
- // typicalLatency is the typical latency for an ARP or NDP packet to travel
- // to a router and back.
- typicalLatency = time.Millisecond
-
- // testEntryBroadcastAddr is a special address that indicates a packet should
- // be sent to all nodes.
- testEntryBroadcastAddr = tcpip.Address("broadcast")
-
- // testEntryBroadcastLinkAddr is a special link address sent back to
- // multicast neighbor probes.
- testEntryBroadcastLinkAddr = tcpip.LinkAddress("mac_broadcast")
-
- // infiniteDuration indicates that a task will not occur in our lifetime.
- infiniteDuration = time.Duration(math.MaxInt64)
-)
-
-// unorderedEventsDiffOpts returns options passed to cmp.Diff to sort slices of
-// events for cases where ordering must be ignored.
-func unorderedEventsDiffOpts() []cmp.Option {
- return []cmp.Option{
- cmpopts.SortSlices(func(a, b testEntryEventInfo) bool {
- return strings.Compare(string(a.Entry.Addr), string(b.Entry.Addr)) < 0
- }),
- }
-}
-
-// unorderedEntriesDiffOpts returns options passed to cmp.Diff to sort slices of
-// entries for cases where ordering must be ignored.
-func unorderedEntriesDiffOpts() []cmp.Option {
- return []cmp.Option{
- cmpopts.SortSlices(func(a, b NeighborEntry) bool {
- return strings.Compare(string(a.Addr), string(b.Addr)) < 0
- }),
- }
-}
-
-func newTestNeighborResolver(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *testNeighborResolver {
- config.resetInvalidFields()
- rng := rand.New(rand.NewSource(time.Now().UnixNano()))
- linkRes := &testNeighborResolver{
- clock: clock,
- entries: newTestEntryStore(),
- delay: typicalLatency,
- }
- linkRes.neigh.init(&nic{
- stack: &Stack{
- clock: clock,
- nudDisp: nudDisp,
- nudConfigs: config,
- randomGenerator: rng,
- },
- id: 1,
- stats: makeNICStats(tcpip.NICStats{}.FillIn()),
- }, linkRes)
- return linkRes
-}
-
-// testEntryStore contains a set of IP to NeighborEntry mappings.
-type testEntryStore struct {
- mu sync.RWMutex
- entriesMap map[tcpip.Address]NeighborEntry
-}
-
-func toAddress(i uint16) tcpip.Address {
- return tcpip.Address([]byte{
- 1,
- 0,
- byte(i >> 8),
- byte(i),
- })
-}
-
-func toLinkAddress(i uint16) tcpip.LinkAddress {
- return tcpip.LinkAddress([]byte{
- 1,
- 0,
- 0,
- 0,
- byte(i >> 8),
- byte(i),
- })
-}
-
-// newTestEntryStore returns a testEntryStore pre-populated with entries.
-func newTestEntryStore() *testEntryStore {
- store := &testEntryStore{
- entriesMap: make(map[tcpip.Address]NeighborEntry),
- }
- for i := uint16(0); i < entryStoreSize; i++ {
- addr := toAddress(i)
- linkAddr := toLinkAddress(i)
-
- store.entriesMap[addr] = NeighborEntry{
- Addr: addr,
- LinkAddr: linkAddr,
- }
- }
- return store
-}
-
-// size returns the number of entries in the store.
-func (s *testEntryStore) size() uint16 {
- s.mu.RLock()
- defer s.mu.RUnlock()
- return uint16(len(s.entriesMap))
-}
-
-// entry returns the entry at index i. Returns an empty entry and false if i is
-// out of bounds.
-func (s *testEntryStore) entry(i uint16) (NeighborEntry, bool) {
- return s.entryByAddr(toAddress(i))
-}
-
-// entryByAddr returns the entry matching addr for situations when the index is
-// not available. Returns an empty entry and false if no entries match addr.
-func (s *testEntryStore) entryByAddr(addr tcpip.Address) (NeighborEntry, bool) {
- s.mu.RLock()
- defer s.mu.RUnlock()
- entry, ok := s.entriesMap[addr]
- return entry, ok
-}
-
-// entries returns all entries in the store.
-func (s *testEntryStore) entries() []NeighborEntry {
- entries := make([]NeighborEntry, 0, len(s.entriesMap))
- s.mu.RLock()
- defer s.mu.RUnlock()
- for i := uint16(0); i < entryStoreSize; i++ {
- addr := toAddress(i)
- if entry, ok := s.entriesMap[addr]; ok {
- entries = append(entries, entry)
- }
- }
- return entries
-}
-
-// set modifies the link addresses of an entry.
-func (s *testEntryStore) set(i uint16, linkAddr tcpip.LinkAddress) {
- addr := toAddress(i)
- s.mu.Lock()
- defer s.mu.Unlock()
- if entry, ok := s.entriesMap[addr]; ok {
- entry.LinkAddr = linkAddr
- s.entriesMap[addr] = entry
- }
-}
-
-// testNeighborResolver implements LinkAddressResolver to emulate sending a
-// neighbor probe.
-type testNeighborResolver struct {
- clock tcpip.Clock
- neigh neighborCache
- entries *testEntryStore
- delay time.Duration
- onLinkAddressRequest func()
- dropReplies bool
-}
-
-var _ LinkAddressResolver = (*testNeighborResolver)(nil)
-
-func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress) tcpip.Error {
- if !r.dropReplies {
- // Delay handling the request to emulate network latency.
- r.clock.AfterFunc(r.delay, func() {
- r.fakeRequest(targetAddr)
- })
- }
-
- // Execute post address resolution action, if available.
- if f := r.onLinkAddressRequest; f != nil {
- f()
- }
- return nil
-}
-
-// fakeRequest emulates handling a response for a link address request.
-func (r *testNeighborResolver) fakeRequest(addr tcpip.Address) {
- if entry, ok := r.entries.entryByAddr(addr); ok {
- r.neigh.handleConfirmation(addr, entry.LinkAddr, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- }
-}
-
-func (*testNeighborResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
- if addr == testEntryBroadcastAddr {
- return testEntryBroadcastLinkAddr, true
- }
- return "", false
-}
-
-func (*testNeighborResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
- return 0
-}
-
-func TestNeighborCacheGetConfig(t *testing.T) {
- nudDisp := testNUDDispatcher{}
- c := DefaultNUDConfigurations()
- clock := faketime.NewManualClock()
- linkRes := newTestNeighborResolver(&nudDisp, c, clock)
-
- if got, want := linkRes.neigh.config(), c; got != want {
- t.Errorf("got linkRes.neigh.config() = %+v, want = %+v", got, want)
- }
-
- // No events should have been dispatched.
- nudDisp.mu.Lock()
- defer nudDisp.mu.Unlock()
- if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.mu.events); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
-}
-
-func TestNeighborCacheSetConfig(t *testing.T) {
- nudDisp := testNUDDispatcher{}
- c := DefaultNUDConfigurations()
- clock := faketime.NewManualClock()
- linkRes := newTestNeighborResolver(&nudDisp, c, clock)
-
- c.MinRandomFactor = 1
- c.MaxRandomFactor = 1
- linkRes.neigh.setConfig(c)
-
- if got, want := linkRes.neigh.config(), c; got != want {
- t.Errorf("got linkRes.neigh.config() = %+v, want = %+v", got, want)
- }
-
- // No events should have been dispatched.
- nudDisp.mu.Lock()
- defer nudDisp.mu.Unlock()
- if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.mu.events); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
-}
-
-func addReachableEntryWithRemoved(nudDisp *testNUDDispatcher, clock *faketime.ManualClock, linkRes *testNeighborResolver, entry NeighborEntry, removed []NeighborEntry) error {
- var gotLinkResolutionResult LinkResolutionResult
-
- _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) {
- gotLinkResolutionResult = r
- })
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- return fmt.Errorf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
- }
-
- {
- var wantEvents []testEntryEventInfo
-
- for _, removedEntry := range removed {
- wantEvents = append(wantEvents, testEntryEventInfo{
- EventType: entryTestRemoved,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: removedEntry.Addr,
- LinkAddr: removedEntry.LinkAddr,
- State: Reachable,
- UpdatedAt: clock.Now(),
- },
- })
- }
-
- wantEvents = append(wantEvents, testEntryEventInfo{
- EventType: entryTestAdded,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: "",
- State: Incomplete,
- UpdatedAt: clock.Now(),
- },
- })
-
- nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, nudDisp.mu.events)
- nudDisp.mu.events = nil
- nudDisp.mu.Unlock()
- if diff != "" {
- return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- }
-
- clock.Advance(typicalLatency)
-
- select {
- case <-ch:
- default:
- return fmt.Errorf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _)", entry.Addr)
- }
- wantLinkResolutionResult := LinkResolutionResult{LinkAddress: entry.LinkAddr, Err: nil}
- if diff := cmp.Diff(wantLinkResolutionResult, gotLinkResolutionResult); diff != "" {
- return fmt.Errorf("got link resolution result mismatch (-want +got):\n%s", diff)
- }
-
- {
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestChanged,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- UpdatedAt: clock.Now(),
- },
- },
- }
- nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, nudDisp.mu.events)
- nudDisp.mu.events = nil
- nudDisp.mu.Unlock()
- if diff != "" {
- return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- }
-
- return nil
-}
-
-func addReachableEntry(nudDisp *testNUDDispatcher, clock *faketime.ManualClock, linkRes *testNeighborResolver, entry NeighborEntry) error {
- return addReachableEntryWithRemoved(nudDisp, clock, linkRes, entry, nil /* removed */)
-}
-
-func TestNeighborCacheEntry(t *testing.T) {
- c := DefaultNUDConfigurations()
- nudDisp := testNUDDispatcher{}
- clock := faketime.NewManualClock()
- linkRes := newTestNeighborResolver(&nudDisp, c, clock)
-
- entry, ok := linkRes.entries.entry(0)
- if !ok {
- t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ")
- }
- if err := addReachableEntry(&nudDisp, clock, linkRes, entry); err != nil {
- t.Fatalf("addReachableEntry(...) = %s", err)
- }
-
- if _, _, err := linkRes.neigh.entry(entry.Addr, "", nil); err != nil {
- t.Fatalf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err)
- }
-
- // No more events should have been dispatched.
- nudDisp.mu.Lock()
- defer nudDisp.mu.Unlock()
- if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.mu.events); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
-}
-
-func TestNeighborCacheRemoveEntry(t *testing.T) {
- config := DefaultNUDConfigurations()
-
- nudDisp := testNUDDispatcher{}
- clock := faketime.NewManualClock()
- linkRes := newTestNeighborResolver(&nudDisp, config, clock)
-
- entry, ok := linkRes.entries.entry(0)
- if !ok {
- t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ")
- }
- if err := addReachableEntry(&nudDisp, clock, linkRes, entry); err != nil {
- t.Fatalf("addReachableEntry(...) = %s", err)
- }
-
- linkRes.neigh.removeEntry(entry.Addr)
-
- {
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestRemoved,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- UpdatedAt: clock.Now(),
- },
- },
- }
- nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, nudDisp.mu.events)
- nudDisp.mu.Unlock()
- if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- }
-
- {
- _, _, err := linkRes.neigh.entry(entry.Addr, "", nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
- }
- }
-}
-
-type testContext struct {
- clock *faketime.ManualClock
- linkRes *testNeighborResolver
- nudDisp *testNUDDispatcher
-}
-
-func newTestContext(c NUDConfigurations) testContext {
- nudDisp := &testNUDDispatcher{}
- clock := faketime.NewManualClock()
- linkRes := newTestNeighborResolver(nudDisp, c, clock)
-
- return testContext{
- clock: clock,
- linkRes: linkRes,
- nudDisp: nudDisp,
- }
-}
-
-type overflowOptions struct {
- startAtEntryIndex uint16
- wantStaticEntries []NeighborEntry
-}
-
-func (c *testContext) overflowCache(opts overflowOptions) error {
- // Fill the neighbor cache to capacity to verify the LRU eviction strategy is
- // working properly after the entry removal.
- for i := opts.startAtEntryIndex; i < c.linkRes.entries.size(); i++ {
- var removedEntries []NeighborEntry
-
- // When beyond the full capacity, the cache will evict an entry as per the
- // LRU eviction strategy. Note that the number of static entries should not
- // affect the total number of dynamic entries that can be added.
- if i >= neighborCacheSize+opts.startAtEntryIndex {
- removedEntry, ok := c.linkRes.entries.entry(i - neighborCacheSize)
- if !ok {
- return fmt.Errorf("got linkRes.entries.entry(%d) = _, false, want = true", i-neighborCacheSize)
- }
- removedEntries = append(removedEntries, removedEntry)
- }
-
- entry, ok := c.linkRes.entries.entry(i)
- if !ok {
- return fmt.Errorf("got c.linkRes.entries.entry(%d) = _, false, want = true", i)
- }
- if err := addReachableEntryWithRemoved(c.nudDisp, c.clock, c.linkRes, entry, removedEntries); err != nil {
- return fmt.Errorf("addReachableEntryWithRemoved(...) = %s", err)
- }
- }
-
- // Expect to find only the most recent entries. The order of entries reported
- // by entries() is nondeterministic, so entries have to be sorted before
- // comparison.
- wantUnorderedEntries := opts.wantStaticEntries
- for i := c.linkRes.entries.size() - neighborCacheSize; i < c.linkRes.entries.size(); i++ {
- entry, ok := c.linkRes.entries.entry(i)
- if !ok {
- return fmt.Errorf("got c.linkRes.entries.entry(%d) = _, false, want = true", i)
- }
- durationReachableNanos := time.Duration(c.linkRes.entries.size()-i-1) * typicalLatency
- wantEntry := NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- UpdatedAt: c.clock.Now().Add(-durationReachableNanos),
- }
- wantUnorderedEntries = append(wantUnorderedEntries, wantEntry)
- }
-
- if diff := cmp.Diff(wantUnorderedEntries, c.linkRes.neigh.entries(), unorderedEntriesDiffOpts()...); diff != "" {
- return fmt.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff)
- }
-
- // No more events should have been dispatched.
- c.nudDisp.mu.Lock()
- defer c.nudDisp.mu.Unlock()
- if diff := cmp.Diff([]testEntryEventInfo(nil), c.nudDisp.mu.events); diff != "" {
- return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
-
- return nil
-}
-
-// TestNeighborCacheOverflow verifies that the LRU cache eviction strategy
-// respects the dynamic entry count.
-func TestNeighborCacheOverflow(t *testing.T) {
- config := DefaultNUDConfigurations()
- // Stay in Reachable so the cache can overflow
- config.BaseReachableTime = infiniteDuration
- config.MinRandomFactor = 1
- config.MaxRandomFactor = 1
-
- c := newTestContext(config)
- opts := overflowOptions{
- startAtEntryIndex: 0,
- }
- if err := c.overflowCache(opts); err != nil {
- t.Errorf("c.overflowCache(%+v): %s", opts, err)
- }
-}
-
-// TestNeighborCacheRemoveEntryThenOverflow verifies that the LRU cache
-// eviction strategy respects the dynamic entry count when an entry is removed.
-func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) {
- config := DefaultNUDConfigurations()
- // Stay in Reachable so the cache can overflow
- config.BaseReachableTime = infiniteDuration
- config.MinRandomFactor = 1
- config.MaxRandomFactor = 1
-
- c := newTestContext(config)
-
- // Add a dynamic entry
- entry, ok := c.linkRes.entries.entry(0)
- if !ok {
- t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ")
- }
- if err := addReachableEntry(c.nudDisp, c.clock, c.linkRes, entry); err != nil {
- t.Fatalf("addReachableEntry(...) = %s", err)
- }
-
- // Remove the entry
- c.linkRes.neigh.removeEntry(entry.Addr)
-
- {
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestRemoved,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- UpdatedAt: c.clock.Now(),
- },
- },
- }
- c.nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, c.nudDisp.mu.events)
- c.nudDisp.mu.events = nil
- c.nudDisp.mu.Unlock()
- if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- }
-
- opts := overflowOptions{
- startAtEntryIndex: 0,
- }
- if err := c.overflowCache(opts); err != nil {
- t.Errorf("c.overflowCache(%+v): %s", opts, err)
- }
-}
-
-// TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress verifies that
-// adding a duplicate static entry with the same link address does not dispatch
-// any events.
-func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) {
- config := DefaultNUDConfigurations()
- c := newTestContext(config)
-
- // Add a static entry
- entry, ok := c.linkRes.entries.entry(0)
- if !ok {
- t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ")
- }
- staticLinkAddr := entry.LinkAddr + "static"
- c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
-
- {
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestAdded,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
- UpdatedAt: c.clock.Now(),
- },
- },
- }
- c.nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, c.nudDisp.mu.events)
- c.nudDisp.mu.events = nil
- c.nudDisp.mu.Unlock()
- if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- }
-
- // Add a duplicate static entry with the same link address.
- c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
-
- c.nudDisp.mu.Lock()
- defer c.nudDisp.mu.Unlock()
- if diff := cmp.Diff([]testEntryEventInfo(nil), c.nudDisp.mu.events); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
-}
-
-// TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress verifies that
-// adding a duplicate static entry with a different link address dispatches a
-// change event.
-func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) {
- config := DefaultNUDConfigurations()
- c := newTestContext(config)
-
- // Add a static entry
- entry, ok := c.linkRes.entries.entry(0)
- if !ok {
- t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ")
- }
- staticLinkAddr := entry.LinkAddr + "static"
- c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
-
- {
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestAdded,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
- UpdatedAt: c.clock.Now(),
- },
- },
- }
- c.nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, c.nudDisp.mu.events)
- c.nudDisp.mu.events = nil
- c.nudDisp.mu.Unlock()
- if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- }
-
- // Add a duplicate entry with a different link address
- staticLinkAddr += "duplicate"
- c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
-
- {
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestChanged,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
- UpdatedAt: c.clock.Now(),
- },
- },
- }
- c.nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, c.nudDisp.mu.events)
- c.nudDisp.mu.events = nil
- c.nudDisp.mu.Unlock()
- if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- }
-}
-
-// TestNeighborCacheRemoveStaticEntryThenOverflow verifies that the LRU cache
-// eviction strategy respects the dynamic entry count when a static entry is
-// added then removed. In this case, the dynamic entry count shouldn't have
-// been touched.
-func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) {
- config := DefaultNUDConfigurations()
- // Stay in Reachable so the cache can overflow
- config.BaseReachableTime = infiniteDuration
- config.MinRandomFactor = 1
- config.MaxRandomFactor = 1
-
- c := newTestContext(config)
-
- // Add a static entry
- entry, ok := c.linkRes.entries.entry(0)
- if !ok {
- t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ")
- }
- staticLinkAddr := entry.LinkAddr + "static"
- c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
-
- {
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestAdded,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
- UpdatedAt: c.clock.Now(),
- },
- },
- }
- c.nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, c.nudDisp.mu.events)
- c.nudDisp.mu.events = nil
- c.nudDisp.mu.Unlock()
- if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- }
-
- // Remove the static entry that was just added
- c.linkRes.neigh.removeEntry(entry.Addr)
-
- {
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestRemoved,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
- UpdatedAt: c.clock.Now(),
- },
- },
- }
- c.nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, c.nudDisp.mu.events)
- c.nudDisp.mu.events = nil
- c.nudDisp.mu.Unlock()
- if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- }
-
- opts := overflowOptions{
- startAtEntryIndex: 0,
- }
- if err := c.overflowCache(opts); err != nil {
- t.Errorf("c.overflowCache(%+v): %s", opts, err)
- }
-}
-
-// TestNeighborCacheOverwriteWithStaticEntryThenOverflow verifies that the LRU
-// cache eviction strategy keeps count of the dynamic entry count when an entry
-// is overwritten by a static entry. Static entries should not count towards
-// the size of the LRU cache.
-func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
- config := DefaultNUDConfigurations()
- // Stay in Reachable so the cache can overflow
- config.BaseReachableTime = infiniteDuration
- config.MinRandomFactor = 1
- config.MaxRandomFactor = 1
-
- c := newTestContext(config)
-
- // Add a dynamic entry
- entry, ok := c.linkRes.entries.entry(0)
- if !ok {
- t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ")
- }
- if err := addReachableEntry(c.nudDisp, c.clock, c.linkRes, entry); err != nil {
- t.Fatalf("addReachableEntry(...) = %s", err)
- }
-
- // Override the entry with a static one using the same address
- staticLinkAddr := entry.LinkAddr + "static"
- c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
-
- {
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestRemoved,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- UpdatedAt: c.clock.Now(),
- },
- },
- {
- EventType: entryTestAdded,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
- UpdatedAt: c.clock.Now(),
- },
- },
- }
- c.nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, c.nudDisp.mu.events)
- c.nudDisp.mu.events = nil
- c.nudDisp.mu.Unlock()
- if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- }
-
- opts := overflowOptions{
- startAtEntryIndex: 1,
- wantStaticEntries: []NeighborEntry{
- {
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
- UpdatedAt: c.clock.Now(),
- },
- },
- }
- if err := c.overflowCache(opts); err != nil {
- t.Errorf("c.overflowCache(%+v): %s", opts, err)
- }
-}
-
-func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
- config := DefaultNUDConfigurations()
- // Stay in Reachable so the cache can overflow
- config.BaseReachableTime = infiniteDuration
- config.MinRandomFactor = 1
- config.MaxRandomFactor = 1
-
- c := newTestContext(config)
-
- entry, ok := c.linkRes.entries.entry(0)
- if !ok {
- t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ")
- }
- c.linkRes.neigh.addStaticEntry(entry.Addr, entry.LinkAddr)
- e, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil)
- if err != nil {
- t.Errorf("unexpected error from c.linkRes.neigh.entry(%s, \"\", nil): %s", entry.Addr, err)
- }
- want := NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Static,
- UpdatedAt: c.clock.Now(),
- }
- if diff := cmp.Diff(want, e); diff != "" {
- t.Errorf("c.linkRes.neigh.entry(%s, \"\", nil) mismatch (-want, +got):\n%s", entry.Addr, diff)
- }
-
- {
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestAdded,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Static,
- UpdatedAt: c.clock.Now(),
- },
- },
- }
- c.nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, c.nudDisp.mu.events)
- c.nudDisp.mu.events = nil
- c.nudDisp.mu.Unlock()
- if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- }
-
- opts := overflowOptions{
- startAtEntryIndex: 1,
- wantStaticEntries: []NeighborEntry{
- {
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Static,
- UpdatedAt: c.clock.Now(),
- },
- },
- }
- if err := c.overflowCache(opts); err != nil {
- t.Errorf("c.overflowCache(%+v): %s", opts, err)
- }
-}
-
-func TestNeighborCacheClear(t *testing.T) {
- config := DefaultNUDConfigurations()
-
- nudDisp := testNUDDispatcher{}
- clock := faketime.NewManualClock()
- linkRes := newTestNeighborResolver(&nudDisp, config, clock)
-
- // Add a dynamic entry.
- entry, ok := linkRes.entries.entry(0)
- if !ok {
- t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ")
- }
- if err := addReachableEntry(&nudDisp, clock, linkRes, entry); err != nil {
- t.Fatalf("addReachableEntry(...) = %s", err)
- }
-
- // Add a static entry.
- linkRes.neigh.addStaticEntry(entryTestAddr1, entryTestLinkAddr1)
-
- {
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestAdded,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Static,
- UpdatedAt: clock.Now(),
- },
- },
- }
- nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, nudDisp.mu.events)
- nudDisp.mu.events = nil
- nudDisp.mu.Unlock()
- if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- }
-
- // Clear should remove both dynamic and static entries.
- linkRes.neigh.clear()
-
- // Remove events dispatched from clear() have no deterministic order so they
- // need to be sorted before comparison.
- wantUnorderedEvents := []testEntryEventInfo{
- {
- EventType: entryTestRemoved,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- UpdatedAt: clock.Now(),
- },
- },
- {
- EventType: entryTestRemoved,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Static,
- UpdatedAt: clock.Now(),
- },
- },
- }
- nudDisp.mu.Lock()
- defer nudDisp.mu.Unlock()
- if diff := cmp.Diff(wantUnorderedEvents, nudDisp.mu.events, unorderedEventsDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
-}
-
-// TestNeighborCacheClearThenOverflow verifies that the LRU cache eviction
-// strategy keeps count of the dynamic entry count when all entries are
-// cleared.
-func TestNeighborCacheClearThenOverflow(t *testing.T) {
- config := DefaultNUDConfigurations()
- // Stay in Reachable so the cache can overflow
- config.BaseReachableTime = infiniteDuration
- config.MinRandomFactor = 1
- config.MaxRandomFactor = 1
-
- c := newTestContext(config)
-
- // Add a dynamic entry
- entry, ok := c.linkRes.entries.entry(0)
- if !ok {
- t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ")
- }
- if err := addReachableEntry(c.nudDisp, c.clock, c.linkRes, entry); err != nil {
- t.Fatalf("addReachableEntry(...) = %s", err)
- }
-
- // Clear the cache.
- c.linkRes.neigh.clear()
-
- {
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestRemoved,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- UpdatedAt: c.clock.Now(),
- },
- },
- }
- c.nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, c.nudDisp.mu.events)
- c.nudDisp.mu.events = nil
- c.nudDisp.mu.Unlock()
- if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- }
-
- opts := overflowOptions{
- startAtEntryIndex: 0,
- }
- if err := c.overflowCache(opts); err != nil {
- t.Errorf("c.overflowCache(%+v): %s", opts, err)
- }
-}
-
-func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
- config := DefaultNUDConfigurations()
- // Stay in Reachable so the cache can overflow
- config.BaseReachableTime = infiniteDuration
- config.MinRandomFactor = 1
- config.MaxRandomFactor = 1
-
- nudDisp := testNUDDispatcher{}
- clock := faketime.NewManualClock()
- linkRes := newTestNeighborResolver(&nudDisp, config, clock)
-
- startedAt := clock.Now()
-
- // The following logic is very similar to overflowCache, but
- // periodically refreshes the frequently used entry.
-
- // Fill the neighbor cache to capacity
- for i := uint16(0); i < neighborCacheSize; i++ {
- entry, ok := linkRes.entries.entry(i)
- if !ok {
- t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i)
- }
- if err := addReachableEntry(&nudDisp, clock, linkRes, entry); err != nil {
- t.Fatalf("addReachableEntry(...) = %s", err)
- }
- }
-
- frequentlyUsedEntry, ok := linkRes.entries.entry(0)
- if !ok {
- t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ")
- }
-
- // Keep adding more entries
- for i := uint16(neighborCacheSize); i < linkRes.entries.size(); i++ {
- // Periodically refresh the frequently used entry
- if i%(neighborCacheSize/2) == 0 {
- if _, _, err := linkRes.neigh.entry(frequentlyUsedEntry.Addr, "", nil); err != nil {
- t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", frequentlyUsedEntry.Addr, err)
- }
- }
-
- entry, ok := linkRes.entries.entry(i)
- if !ok {
- t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i)
- }
-
- // An entry should have been removed, as per the LRU eviction strategy
- removedEntry, ok := linkRes.entries.entry(i - neighborCacheSize + 1)
- if !ok {
- t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i-neighborCacheSize+1)
- }
-
- if err := addReachableEntryWithRemoved(&nudDisp, clock, linkRes, entry, []NeighborEntry{removedEntry}); err != nil {
- t.Fatalf("addReachableEntryWithRemoved(...) = %s", err)
- }
- }
-
- // Expect to find only the frequently used entry and the most recent entries.
- // The order of entries reported by entries() is nondeterministic, so entries
- // have to be sorted before comparison.
- wantUnsortedEntries := []NeighborEntry{
- {
- Addr: frequentlyUsedEntry.Addr,
- LinkAddr: frequentlyUsedEntry.LinkAddr,
- State: Reachable,
- // Can be inferred since the frequently used entry is the first to
- // be created and transitioned to Reachable.
- UpdatedAt: startedAt.Add(typicalLatency),
- },
- }
-
- for i := linkRes.entries.size() - neighborCacheSize + 1; i < linkRes.entries.size(); i++ {
- entry, ok := linkRes.entries.entry(i)
- if !ok {
- t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i)
- }
- durationReachableNanos := time.Duration(linkRes.entries.size()-i-1) * typicalLatency
- wantUnsortedEntries = append(wantUnsortedEntries, NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- UpdatedAt: clock.Now().Add(-durationReachableNanos),
- })
- }
-
- if diff := cmp.Diff(wantUnsortedEntries, linkRes.neigh.entries(), unorderedEntriesDiffOpts()...); diff != "" {
- t.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff)
- }
-
- // No more events should have been dispatched.
- nudDisp.mu.Lock()
- defer nudDisp.mu.Unlock()
- if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.mu.events); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
-}
-
-func TestNeighborCacheConcurrent(t *testing.T) {
- const concurrentProcesses = 16
-
- config := DefaultNUDConfigurations()
-
- nudDisp := testNUDDispatcher{}
- clock := faketime.NewManualClock()
- linkRes := newTestNeighborResolver(&nudDisp, config, clock)
-
- storeEntries := linkRes.entries.entries()
- for _, entry := range storeEntries {
- var wg sync.WaitGroup
- for r := 0; r < concurrentProcesses; r++ {
- wg.Add(1)
- go func(entry NeighborEntry) {
- defer wg.Done()
- switch e, _, err := linkRes.neigh.entry(entry.Addr, "", nil); err.(type) {
- case nil, *tcpip.ErrWouldBlock:
- default:
- t.Errorf("got linkRes.neigh.entry(%s, '', nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, &tcpip.ErrWouldBlock{})
- }
- }(entry)
- }
-
- // Wait for all goroutines to send a request
- wg.Wait()
-
- // Process all the requests for a single entry concurrently
- clock.Advance(typicalLatency)
- }
-
- // All goroutines add in the same order and add more values than can fit in
- // the cache. Our eviction strategy requires that the last entries are
- // present, up to the size of the neighbor cache, and the rest are missing.
- // The order of entries reported by entries() is nondeterministic, so entries
- // have to be sorted before comparison.
- var wantUnsortedEntries []NeighborEntry
- for i := linkRes.entries.size() - neighborCacheSize; i < linkRes.entries.size(); i++ {
- entry, ok := linkRes.entries.entry(i)
- if !ok {
- t.Errorf("got linkRes.entries.entry(%d) = _, false, want = true", i)
- }
- durationReachableNanos := time.Duration(linkRes.entries.size()-i-1) * typicalLatency
- wantUnsortedEntries = append(wantUnsortedEntries, NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- UpdatedAt: clock.Now().Add(-durationReachableNanos),
- })
- }
-
- if diff := cmp.Diff(wantUnsortedEntries, linkRes.neigh.entries(), unorderedEntriesDiffOpts()...); diff != "" {
- t.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff)
- }
-}
-
-func TestNeighborCacheReplace(t *testing.T) {
- config := DefaultNUDConfigurations()
-
- nudDisp := testNUDDispatcher{}
- clock := faketime.NewManualClock()
- linkRes := newTestNeighborResolver(&nudDisp, config, clock)
-
- entry, ok := linkRes.entries.entry(0)
- if !ok {
- t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ")
- }
- if err := addReachableEntry(&nudDisp, clock, linkRes, entry); err != nil {
- t.Fatalf("addReachableEntry(...) = %s", err)
- }
-
- // Notify of a link address change
- var updatedLinkAddr tcpip.LinkAddress
- {
- entry, ok := linkRes.entries.entry(1)
- if !ok {
- t.Fatal("got linkRes.entries.entry(1) = _, false, want = true")
- }
- updatedLinkAddr = entry.LinkAddr
- }
- linkRes.entries.set(0, updatedLinkAddr)
- linkRes.neigh.handleConfirmation(entry.Addr, updatedLinkAddr, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: true,
- IsRouter: false,
- })
-
- // Requesting the entry again should start neighbor reachability confirmation.
- //
- // Verify the entry's new link address and the new state.
- {
- e, _, err := linkRes.neigh.entry(entry.Addr, "", nil)
- if err != nil {
- t.Fatalf("linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err)
- }
- want := NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: updatedLinkAddr,
- State: Delay,
- UpdatedAt: clock.Now(),
- }
- if diff := cmp.Diff(want, e); diff != "" {
- t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff)
- }
- }
-
- clock.Advance(config.DelayFirstProbeTime + typicalLatency)
-
- // Verify that the neighbor is now reachable.
- {
- e, _, err := linkRes.neigh.entry(entry.Addr, "", nil)
- if err != nil {
- t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err)
- }
- want := NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: updatedLinkAddr,
- State: Reachable,
- UpdatedAt: clock.Now(),
- }
- if diff := cmp.Diff(want, e); diff != "" {
- t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff)
- }
- }
-}
-
-func TestNeighborCacheResolutionFailed(t *testing.T) {
- config := DefaultNUDConfigurations()
-
- nudDisp := testNUDDispatcher{}
- clock := faketime.NewManualClock()
- linkRes := newTestNeighborResolver(&nudDisp, config, clock)
-
- var requestCount uint32
- linkRes.onLinkAddressRequest = func() {
- atomic.AddUint32(&requestCount, 1)
- }
-
- entry, ok := linkRes.entries.entry(0)
- if !ok {
- t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ")
- }
-
- // First, sanity check that resolution is working
- if err := addReachableEntry(&nudDisp, clock, linkRes, entry); err != nil {
- t.Fatalf("addReachableEntry(...) = %s", err)
- }
-
- got, _, err := linkRes.neigh.entry(entry.Addr, "", nil)
- if err != nil {
- t.Fatalf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err)
- }
- want := NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- UpdatedAt: clock.Now(),
- }
- if diff := cmp.Diff(want, got); diff != "" {
- t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff)
- }
-
- // Verify address resolution fails for an unknown address.
- before := atomic.LoadUint32(&requestCount)
-
- entry.Addr += "2"
- {
- _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) {
- if diff := cmp.Diff(LinkResolutionResult{Err: &tcpip.ErrTimeout{}}, r); diff != "" {
- t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
- }
- })
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
- }
- waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes)
- clock.Advance(waitFor)
- select {
- case <-ch:
- default:
- t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _)", entry.Addr)
- }
- }
-
- maxAttempts := linkRes.neigh.config().MaxUnicastProbes
- if got, want := atomic.LoadUint32(&requestCount)-before, maxAttempts; got != want {
- t.Errorf("got link address request count = %d, want = %d", got, want)
- }
-}
-
-// TestNeighborCacheResolutionTimeout simulates sending MaxMulticastProbes
-// probes and not retrieving a confirmation before the duration defined by
-// MaxMulticastProbes * RetransmitTimer.
-func TestNeighborCacheResolutionTimeout(t *testing.T) {
- config := DefaultNUDConfigurations()
- config.RetransmitTimer = time.Millisecond // small enough to cause timeout
-
- clock := faketime.NewManualClock()
- linkRes := newTestNeighborResolver(nil, config, clock)
- // large enough to cause timeout
- linkRes.delay = time.Minute
-
- entry, ok := linkRes.entries.entry(0)
- if !ok {
- t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ")
- }
-
- _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) {
- if diff := cmp.Diff(LinkResolutionResult{Err: &tcpip.ErrTimeout{}}, r); diff != "" {
- t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
- }
- })
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
- }
- waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes)
- clock.Advance(waitFor)
-
- select {
- case <-ch:
- default:
- t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _)", entry.Addr)
- }
-}
-
-// TestNeighborCacheRetryResolution simulates retrying communication after
-// failing to perform address resolution.
-func TestNeighborCacheRetryResolution(t *testing.T) {
- config := DefaultNUDConfigurations()
- nudDisp := testNUDDispatcher{}
- clock := faketime.NewManualClock()
- linkRes := newTestNeighborResolver(&nudDisp, config, clock)
- // Simulate a faulty link.
- linkRes.dropReplies = true
-
- entry, ok := linkRes.entries.entry(0)
- if !ok {
- t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ")
- }
-
- // Perform address resolution with a faulty link, which will fail.
- {
- _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) {
- if diff := cmp.Diff(LinkResolutionResult{Err: &tcpip.ErrTimeout{}}, r); diff != "" {
- t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
- }
- })
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
- }
-
- {
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestAdded,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: "",
- State: Incomplete,
- UpdatedAt: clock.Now(),
- },
- },
- }
- nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, nudDisp.mu.events)
- nudDisp.mu.events = nil
- nudDisp.mu.Unlock()
- if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- }
-
- waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes)
- clock.Advance(waitFor)
-
- select {
- case <-ch:
- default:
- t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _)", entry.Addr)
- }
-
- {
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestChanged,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: "",
- State: Unreachable,
- UpdatedAt: clock.Now(),
- },
- },
- }
- nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, nudDisp.mu.events)
- nudDisp.mu.events = nil
- nudDisp.mu.Unlock()
- if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- }
-
- {
- wantEntries := []NeighborEntry{
- {
- Addr: entry.Addr,
- LinkAddr: "",
- State: Unreachable,
- UpdatedAt: clock.Now(),
- },
- }
- if diff := cmp.Diff(linkRes.neigh.entries(), wantEntries, unorderedEntriesDiffOpts()...); diff != "" {
- t.Fatalf("neighbor entries mismatch (-got, +want):\n%s", diff)
- }
- }
- }
-
- // Retry address resolution with a working link.
- linkRes.dropReplies = false
- {
- incompleteEntry, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) {
- if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Err: nil}, r); diff != "" {
- t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
- }
- })
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
- }
- if incompleteEntry.State != Incomplete {
- t.Fatalf("got entry.State = %s, want = %s", incompleteEntry.State, Incomplete)
- }
-
- {
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestChanged,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: "",
- State: Incomplete,
- UpdatedAt: clock.Now(),
- },
- },
- }
- nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, nudDisp.mu.events)
- nudDisp.mu.events = nil
- nudDisp.mu.Unlock()
- if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- }
-
- clock.Advance(typicalLatency)
-
- select {
- case <-ch:
- default:
- t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _)", entry.Addr)
- }
-
- {
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestChanged,
- NICID: 1,
- Entry: NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- UpdatedAt: clock.Now(),
- },
- },
- }
- nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, nudDisp.mu.events)
- nudDisp.mu.events = nil
- nudDisp.mu.Unlock()
- if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- }
-
- {
- gotEntry, _, err := linkRes.neigh.entry(entry.Addr, "", nil)
- if err != nil {
- t.Fatalf("linkRes.neigh.entry(%s, '', _): %s", entry.Addr, err)
- }
-
- wantEntry := NeighborEntry{
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- UpdatedAt: clock.Now(),
- }
- if diff := cmp.Diff(gotEntry, wantEntry); diff != "" {
- t.Fatalf("neighbor entry mismatch (-got, +want):\n%s", diff)
- }
- }
- }
-}
-
-func BenchmarkCacheClear(b *testing.B) {
- b.StopTimer()
- config := DefaultNUDConfigurations()
- clock := tcpip.NewStdClock()
- linkRes := newTestNeighborResolver(nil, config, clock)
- linkRes.delay = 0
-
- // Clear for every possible size of the cache
- for cacheSize := uint16(0); cacheSize < neighborCacheSize; cacheSize++ {
- // Fill the neighbor cache to capacity.
- for i := uint16(0); i < cacheSize; i++ {
- entry, ok := linkRes.entries.entry(i)
- if !ok {
- b.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i)
- }
-
- _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) {
- if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Err: nil}, r); diff != "" {
- b.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
- }
- })
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- b.Fatalf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
- }
-
- select {
- case <-ch:
- default:
- b.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _)", entry.Addr)
- }
- }
-
- b.StartTimer()
- linkRes.neigh.clear()
- b.StopTimer()
- }
-}
diff --git a/pkg/tcpip/stack/neighbor_entry_list.go b/pkg/tcpip/stack/neighbor_entry_list.go
new file mode 100644
index 000000000..d78430080
--- /dev/null
+++ b/pkg/tcpip/stack/neighbor_entry_list.go
@@ -0,0 +1,221 @@
+package stack
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type neighborEntryElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (neighborEntryElementMapper) linkerFor(elem *neighborEntry) *neighborEntry { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type neighborEntryList struct {
+ head *neighborEntry
+ tail *neighborEntry
+}
+
+// Reset resets list l to the empty state.
+func (l *neighborEntryList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+//
+//go:nosplit
+func (l *neighborEntryList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+//
+//go:nosplit
+func (l *neighborEntryList) Front() *neighborEntry {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+//
+//go:nosplit
+func (l *neighborEntryList) Back() *neighborEntry {
+ return l.tail
+}
+
+// Len returns the number of elements in the list.
+//
+// NOTE: This is an O(n) operation.
+//
+//go:nosplit
+func (l *neighborEntryList) Len() (count int) {
+ for e := l.Front(); e != nil; e = (neighborEntryElementMapper{}.linkerFor(e)).Next() {
+ count++
+ }
+ return count
+}
+
+// PushFront inserts the element e at the front of list l.
+//
+//go:nosplit
+func (l *neighborEntryList) PushFront(e *neighborEntry) {
+ linker := neighborEntryElementMapper{}.linkerFor(e)
+ linker.SetNext(l.head)
+ linker.SetPrev(nil)
+ if l.head != nil {
+ neighborEntryElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+//
+//go:nosplit
+func (l *neighborEntryList) PushBack(e *neighborEntry) {
+ linker := neighborEntryElementMapper{}.linkerFor(e)
+ linker.SetNext(nil)
+ linker.SetPrev(l.tail)
+ if l.tail != nil {
+ neighborEntryElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+//
+//go:nosplit
+func (l *neighborEntryList) PushBackList(m *neighborEntryList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ neighborEntryElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ neighborEntryElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+//
+//go:nosplit
+func (l *neighborEntryList) InsertAfter(b, e *neighborEntry) {
+ bLinker := neighborEntryElementMapper{}.linkerFor(b)
+ eLinker := neighborEntryElementMapper{}.linkerFor(e)
+
+ a := bLinker.Next()
+
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ bLinker.SetNext(e)
+
+ if a != nil {
+ neighborEntryElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+//
+//go:nosplit
+func (l *neighborEntryList) InsertBefore(a, e *neighborEntry) {
+ aLinker := neighborEntryElementMapper{}.linkerFor(a)
+ eLinker := neighborEntryElementMapper{}.linkerFor(e)
+
+ b := aLinker.Prev()
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ aLinker.SetPrev(e)
+
+ if b != nil {
+ neighborEntryElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+//
+//go:nosplit
+func (l *neighborEntryList) Remove(e *neighborEntry) {
+ linker := neighborEntryElementMapper{}.linkerFor(e)
+ prev := linker.Prev()
+ next := linker.Next()
+
+ if prev != nil {
+ neighborEntryElementMapper{}.linkerFor(prev).SetNext(next)
+ } else if l.head == e {
+ l.head = next
+ }
+
+ if next != nil {
+ neighborEntryElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else if l.tail == e {
+ l.tail = prev
+ }
+
+ linker.SetNext(nil)
+ linker.SetPrev(nil)
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type neighborEntryEntry struct {
+ next *neighborEntry
+ prev *neighborEntry
+}
+
+// Next returns the entry that follows e in the list.
+//
+//go:nosplit
+func (e *neighborEntryEntry) Next() *neighborEntry {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *neighborEntryEntry) Prev() *neighborEntry {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+//
+//go:nosplit
+func (e *neighborEntryEntry) SetNext(elem *neighborEntry) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *neighborEntryEntry) SetPrev(elem *neighborEntry) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
deleted file mode 100644
index 59d86d6d4..000000000
--- a/pkg/tcpip/stack/neighbor_entry_test.go
+++ /dev/null
@@ -1,2269 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package stack
-
-import (
- "fmt"
- "math"
- "math/rand"
- "sync"
- "testing"
- "time"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
-)
-
-const (
- entryTestNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32
-
- entryTestNICID tcpip.NICID = 1
-
- entryTestLinkAddr1 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x01")
- entryTestLinkAddr2 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x02")
-)
-
-var (
- entryTestAddr1 = testutil.MustParse6("a::1")
- entryTestAddr2 = testutil.MustParse6("a::2")
-)
-
-// runImmediatelyScheduledJobs runs all jobs scheduled to run at the current
-// time.
-func runImmediatelyScheduledJobs(clock *faketime.ManualClock) {
- clock.Advance(immediateDuration)
-}
-
-// The following unit tests exercise every state transition and verify its
-// behavior with RFC 4681 and RFC 7048.
-//
-// | From | To | Cause | Update | Action | Event |
-// | =========== | =========== | ========================================== | ======== | ===========| ======= |
-// | Unknown | Unknown | Confirmation w/ unknown address | | | Added |
-// | Unknown | Incomplete | Packet queued to unknown address | | Send probe | Added |
-// | Unknown | Stale | Probe | | | Added |
-// | Incomplete | Incomplete | Retransmit timer expired | | Send probe | Changed |
-// | Incomplete | Reachable | Solicited confirmation | LinkAddr | Notify | Changed |
-// | Incomplete | Stale | Unsolicited confirmation | LinkAddr | Notify | Changed |
-// | Incomplete | Stale | Probe | LinkAddr | Notify | Changed |
-// | Incomplete | Unreachable | Max probes sent without reply | | Notify | Changed |
-// | Reachable | Reachable | Confirmation w/ different isRouter flag | IsRouter | | |
-// | Reachable | Stale | Reachable timer expired | | | Changed |
-// | Reachable | Stale | Probe or confirmation w/ different address | | | Changed |
-// | Stale | Reachable | Solicited override confirmation | LinkAddr | | Changed |
-// | Stale | Reachable | Solicited confirmation w/o address | | Notify | Changed |
-// | Stale | Stale | Override confirmation | LinkAddr | | Changed |
-// | Stale | Stale | Probe w/ different address | LinkAddr | | Changed |
-// | Stale | Delay | Packet sent | | | Changed |
-// | Delay | Reachable | Upper-layer confirmation | | | Changed |
-// | Delay | Reachable | Solicited override confirmation | LinkAddr | | Changed |
-// | Delay | Reachable | Solicited confirmation w/o address | | Notify | Changed |
-// | Delay | Stale | Probe or confirmation w/ different address | | | Changed |
-// | Delay | Probe | Delay timer expired | | Send probe | Changed |
-// | Probe | Reachable | Solicited override confirmation | LinkAddr | | Changed |
-// | Probe | Reachable | Solicited confirmation w/ same address | | Notify | Changed |
-// | Probe | Reachable | Solicited confirmation w/o address | | Notify | Changed |
-// | Probe | Stale | Probe or confirmation w/ different address | | | Changed |
-// | Probe | Probe | Retransmit timer expired | | | Changed |
-// | Probe | Unreachable | Max probes sent without reply | | Notify | Changed |
-// | Unreachable | Incomplete | Packet queued | | Send probe | Changed |
-// | Unreachable | Stale | Probe w/ different address | LinkAddr | | Changed |
-
-type testEntryEventType uint8
-
-const (
- entryTestAdded testEntryEventType = iota
- entryTestChanged
- entryTestRemoved
-)
-
-func (t testEntryEventType) String() string {
- switch t {
- case entryTestAdded:
- return "add"
- case entryTestChanged:
- return "change"
- case entryTestRemoved:
- return "remove"
- default:
- return fmt.Sprintf("unknown (%d)", t)
- }
-}
-
-// Fields are exported for use with cmp.Diff.
-type testEntryEventInfo struct {
- EventType testEntryEventType
- NICID tcpip.NICID
- Entry NeighborEntry
-}
-
-func (e testEntryEventInfo) String() string {
- return fmt.Sprintf("%s event for NIC #%d, %#v", e.EventType, e.NICID, e.Entry)
-}
-
-// testNUDDispatcher implements NUDDispatcher to validate the dispatching of
-// events upon certain NUD state machine events.
-type testNUDDispatcher struct {
- mu struct {
- sync.Mutex
- events []testEntryEventInfo
- }
-}
-
-var _ NUDDispatcher = (*testNUDDispatcher)(nil)
-
-func (d *testNUDDispatcher) queueEvent(e testEntryEventInfo) {
- d.mu.Lock()
- defer d.mu.Unlock()
- d.mu.events = append(d.mu.events, e)
-}
-
-func (d *testNUDDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry NeighborEntry) {
- d.queueEvent(testEntryEventInfo{
- EventType: entryTestAdded,
- NICID: nicID,
- Entry: entry,
- })
-}
-
-func (d *testNUDDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry NeighborEntry) {
- d.queueEvent(testEntryEventInfo{
- EventType: entryTestChanged,
- NICID: nicID,
- Entry: entry,
- })
-}
-
-func (d *testNUDDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry NeighborEntry) {
- d.queueEvent(testEntryEventInfo{
- EventType: entryTestRemoved,
- NICID: nicID,
- Entry: entry,
- })
-}
-
-type entryTestLinkResolver struct {
- mu struct {
- sync.Mutex
- probes []entryTestProbeInfo
- }
-}
-
-var _ LinkAddressResolver = (*entryTestLinkResolver)(nil)
-
-type entryTestProbeInfo struct {
- RemoteAddress tcpip.Address
- RemoteLinkAddress tcpip.LinkAddress
- LocalAddress tcpip.Address
-}
-
-func (p entryTestProbeInfo) String() string {
- return fmt.Sprintf("probe with RemoteAddress=%q, RemoteLinkAddress=%q, LocalAddress=%q", p.RemoteAddress, p.RemoteLinkAddress, p.LocalAddress)
-}
-
-// LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts
-// to the local network if linkAddr is the zero value.
-func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error {
- r.mu.Lock()
- defer r.mu.Unlock()
- r.mu.probes = append(r.mu.probes, entryTestProbeInfo{
- RemoteAddress: targetAddr,
- RemoteLinkAddress: linkAddr,
- LocalAddress: localAddr,
- })
- return nil
-}
-
-// ResolveStaticAddress attempts to resolve address without sending requests.
-// It either resolves the name immediately or returns the empty LinkAddress.
-func (*entryTestLinkResolver) ResolveStaticAddress(tcpip.Address) (tcpip.LinkAddress, bool) {
- return "", false
-}
-
-// LinkAddressProtocol returns the network protocol of the addresses this
-// resolver can resolve.
-func (*entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
- return entryTestNetNumber
-}
-
-func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *entryTestLinkResolver, *faketime.ManualClock) {
- clock := faketime.NewManualClock()
- disp := testNUDDispatcher{}
- nic := nic{
- LinkEndpoint: nil, // entryTestLinkResolver doesn't use a LinkEndpoint
-
- id: entryTestNICID,
- stack: &Stack{
- clock: clock,
- nudDisp: &disp,
- nudConfigs: c,
- randomGenerator: rand.New(rand.NewSource(time.Now().UnixNano())),
- },
- stats: makeNICStats(tcpip.NICStats{}.FillIn()),
- }
- netEP := (&testIPv6Protocol{}).NewEndpoint(&nic, nil)
- nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{
- header.IPv6ProtocolNumber: netEP,
- }
-
- var linkRes entryTestLinkResolver
- // Stub out the neighbor cache to verify deletion from the cache.
- l := &linkResolver{
- resolver: &linkRes,
- }
- l.neigh.init(&nic, &linkRes)
-
- entry := newNeighborEntry(&l.neigh, entryTestAddr1 /* remoteAddr */, l.neigh.state)
- l.neigh.mu.Lock()
- l.neigh.mu.cache[entryTestAddr1] = entry
- l.neigh.mu.Unlock()
- nic.linkAddrResolvers = map[tcpip.NetworkProtocolNumber]*linkResolver{
- header.IPv6ProtocolNumber: l,
- }
-
- return entry, &disp, &linkRes, clock
-}
-
-// TestEntryInitiallyUnknown verifies that the state of a newly created
-// neighborEntry is Unknown.
-func TestEntryInitiallyUnknown(t *testing.T) {
- c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- e.mu.Lock()
- if e.mu.neigh.State != Unknown {
- t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unknown)
- }
- e.mu.Unlock()
-
- clock.Advance(c.RetransmitTimer)
-
- // No probes should have been sent.
- linkRes.mu.Lock()
- diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
- }
-
- // No events should have been dispatched.
- nudDisp.mu.Lock()
- if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.mu.events); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- nudDisp.mu.Unlock()
-}
-
-func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) {
- c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- e.mu.Lock()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if e.mu.neigh.State != Unknown {
- t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unknown)
- }
- e.mu.Unlock()
-
- clock.Advance(time.Hour)
-
- // No probes should have been sent.
- linkRes.mu.Lock()
- diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
- }
-
- // No events should have been dispatched.
- nudDisp.mu.Lock()
- if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.mu.events); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- nudDisp.mu.Unlock()
-}
-
-func TestEntryUnknownToIncomplete(t *testing.T) {
- c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("unknownToIncomplete(...) = %s", err)
- }
-}
-
-func unknownToIncomplete(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock) error {
- if err := func() error {
- e.mu.Lock()
- defer e.mu.Unlock()
-
- if e.mu.neigh.State != Unknown {
- return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unknown)
- }
- e.handlePacketQueuedLocked(entryTestAddr2)
- if e.mu.neigh.State != Incomplete {
- return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete)
- }
- return nil
- }(); err != nil {
- return err
- }
-
- runImmediatelyScheduledJobs(clock)
- wantProbes := []entryTestProbeInfo{
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(wantProbes, linkRes.mu.probes)
- linkRes.mu.probes = nil
- linkRes.mu.Unlock()
- if diff != "" {
- return fmt.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff)
- }
-
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestAdded,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
- UpdatedAt: clock.Now(),
- },
- },
- }
- {
- nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, nudDisp.mu.events)
- nudDisp.mu.events = nil
- nudDisp.mu.Unlock()
- if diff != "" {
- return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- }
- return nil
-}
-
-func TestEntryUnknownToStale(t *testing.T) {
- c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("unknownToStale(...) = %s", err)
- }
-}
-
-func unknownToStale(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock) error {
- if err := func() error {
- e.mu.Lock()
- defer e.mu.Unlock()
-
- if e.mu.neigh.State != Unknown {
- return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unknown)
- }
- e.handleProbeLocked(entryTestLinkAddr1)
- if e.mu.neigh.State != Stale {
- return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale)
- }
- return nil
- }(); err != nil {
- return err
- }
-
- // No probes should have been sent.
- runImmediatelyScheduledJobs(clock)
- {
- linkRes.mu.Lock()
- diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes)
- linkRes.mu.Unlock()
- if diff != "" {
- return fmt.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff)
- }
- }
-
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestAdded,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
- UpdatedAt: clock.Now(),
- },
- },
- }
- {
- nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, nudDisp.mu.events)
- nudDisp.mu.events = nil
- nudDisp.mu.Unlock()
- if diff != "" {
- return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- }
-
- return nil
-}
-
-func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
- c := DefaultNUDConfigurations()
- c.MaxMulticastProbes = 3
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("unknownToIncomplete(...) = %s", err)
- }
-
- // UpdatedAt should remain the same during address resolution.
- e.mu.Lock()
- startedAt := e.mu.neigh.UpdatedAt
- e.mu.Unlock()
-
- // Wait for the rest of the reachability probe transmissions, signifying
- // Incomplete to Incomplete transitions.
- for i := uint32(1); i < c.MaxMulticastProbes; i++ {
- clock.Advance(c.RetransmitTimer)
-
- wantProbes := []entryTestProbeInfo{
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(wantProbes, linkRes.mu.probes)
- linkRes.mu.probes = nil
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
- }
-
- e.mu.Lock()
- if got, want := e.mu.neigh.UpdatedAt, startedAt; got != want {
- t.Errorf("got e.mu.neigh.UpdatedAt = %q, want = %q", got, want)
- }
- e.mu.Unlock()
- }
-
- // UpdatedAt should change after failing address resolution. Timing out after
- // sending the last probe transitions the entry to Unreachable.
- clock.Advance(c.RetransmitTimer)
-
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Unreachable,
- UpdatedAt: clock.Now(),
- },
- },
- }
- nudDisp.mu.Lock()
- if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- nudDisp.mu.Unlock()
-}
-
-func TestEntryIncompleteToReachable(t *testing.T) {
- c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("unknownToIncomplete(...) = %s", err)
- }
- if err := incompleteToReachable(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("incompleteToReachable(...) = %s", err)
- }
-}
-
-func incompleteToReachableWithFlags(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock, flags ReachabilityConfirmationFlags) error {
- if err := func() error {
- e.mu.Lock()
- defer e.mu.Unlock()
-
- if e.mu.neigh.State != Incomplete {
- return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete)
- }
- e.handleConfirmationLocked(entryTestLinkAddr1, flags)
- if e.mu.neigh.State != Reachable {
- return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable)
- }
- if e.mu.neigh.LinkAddr != entryTestLinkAddr1 {
- return fmt.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr1)
- }
- return nil
- }(); err != nil {
- return err
- }
-
- // No probes should have been sent.
- runImmediatelyScheduledJobs(clock)
- {
- linkRes.mu.Lock()
- diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes)
- linkRes.mu.Unlock()
- if diff != "" {
- return fmt.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff)
- }
- }
-
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
- UpdatedAt: clock.Now(),
- },
- },
- }
- {
- nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, nudDisp.mu.events)
- nudDisp.mu.events = nil
- nudDisp.mu.Unlock()
- if diff != "" {
- return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- }
-
- return nil
-}
-
-func incompleteToReachable(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock) error {
- if err := incompleteToReachableWithFlags(e, nudDisp, linkRes, clock, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- }); err != nil {
- return err
- }
-
- e.mu.Lock()
- isRouter := e.mu.isRouter
- e.mu.Unlock()
- if isRouter {
- return fmt.Errorf("got e.mu.isRouter = %t, want = false", isRouter)
- }
-
- return nil
-}
-
-func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
- c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("unknownToIncomplete(...) = %s", err)
- }
- if err := incompleteToReachableWithRouterFlag(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("incompleteToReachableWithRouterFlag(...) = %s", err)
- }
-}
-
-func incompleteToReachableWithRouterFlag(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock) error {
- if err := incompleteToReachableWithFlags(e, nudDisp, linkRes, clock, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: true,
- }); err != nil {
- return err
- }
-
- e.mu.Lock()
- isRouter := e.mu.isRouter
- e.mu.Unlock()
- if !isRouter {
- return fmt.Errorf("got e.mu.isRouter = %t, want = true", isRouter)
- }
-
- return nil
-}
-
-func TestEntryIncompleteToStaleWhenUnsolicitedConfirmation(t *testing.T) {
- c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("unknownToIncomplete(...) = %s", err)
- }
-
- e.mu.Lock()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if e.mu.neigh.State != Stale {
- t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale)
- }
- if e.mu.isRouter {
- t.Errorf("got e.mu.isRouter = %t, want = false", e.mu.isRouter)
- }
- e.mu.Unlock()
-
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
- UpdatedAt: clock.Now(),
- },
- },
- }
- nudDisp.mu.Lock()
- if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- nudDisp.mu.Unlock()
-}
-
-func TestEntryIncompleteToStaleWhenProbe(t *testing.T) {
- c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("unknownToIncomplete(...) = %s", err)
- }
-
- e.mu.Lock()
- e.handleProbeLocked(entryTestLinkAddr1)
- if e.mu.neigh.State != Stale {
- t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale)
- }
- e.mu.Unlock()
-
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
- UpdatedAt: clock.Now(),
- },
- },
- }
- nudDisp.mu.Lock()
- if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- nudDisp.mu.Unlock()
-}
-
-func TestEntryIncompleteToUnreachable(t *testing.T) {
- c := DefaultNUDConfigurations()
- c.MaxMulticastProbes = 3
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("unknownToIncomplete(...) = %s", err)
- }
- if err := incompleteToUnreachable(c, e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("incompleteToUnreachable(...) = %s", err)
- }
-}
-
-func incompleteToUnreachable(c NUDConfigurations, e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock) error {
- {
- e.mu.Lock()
- state := e.mu.neigh.State
- e.mu.Unlock()
- if state != Incomplete {
- return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", state, Incomplete)
- }
- }
-
- // The first probe was sent in the transition from Unknown to Incomplete.
- clock.Advance(c.RetransmitTimer)
-
- // Observe each subsequent multicast probe transmitted.
- for i := uint32(1); i < c.MaxMulticastProbes; i++ {
- wantProbes := []entryTestProbeInfo{{
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: "",
- LocalAddress: entryTestAddr2,
- }}
- linkRes.mu.Lock()
- diff := cmp.Diff(wantProbes, linkRes.mu.probes)
- linkRes.mu.probes = nil
- linkRes.mu.Unlock()
- if diff != "" {
- return fmt.Errorf("link address resolver probe #%d mismatch (-want, +got):\n%s", i+1, diff)
- }
-
- e.mu.Lock()
- state := e.mu.neigh.State
- e.mu.Unlock()
- if state != Incomplete {
- return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", state, Incomplete)
- }
-
- clock.Advance(c.RetransmitTimer)
- }
-
- {
- e.mu.Lock()
- state := e.mu.neigh.State
- e.mu.Unlock()
- if state != Unreachable {
- return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", state, Unreachable)
- }
- }
-
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Unreachable,
- UpdatedAt: clock.Now(),
- },
- },
- }
- nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, nudDisp.mu.events)
- nudDisp.mu.events = nil
- nudDisp.mu.Unlock()
- if diff != "" {
- return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
-
- return nil
-}
-
-type testLocker struct{}
-
-var _ sync.Locker = (*testLocker)(nil)
-
-func (*testLocker) Lock() {}
-func (*testLocker) Unlock() {}
-
-func TestEntryReachableToReachableClearsRouterWhenConfirmationWithoutRouter(t *testing.T) {
- c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("unknownToIncomplete(...) = %s", err)
- }
- if err := incompleteToReachableWithRouterFlag(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("incompleteToReachableWithRouterFlag(...) = %s", err)
- }
-
- e.mu.Lock()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if e.mu.neigh.State != Reachable {
- t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable)
- }
- if got, want := e.mu.isRouter, false; got != want {
- t.Errorf("got e.mu.isRouter = %t, want = %t", got, want)
- }
- ipv6EP := e.cache.nic.networkEndpoints[header.IPv6ProtocolNumber].(*testIPv6Endpoint)
- if ipv6EP.invalidatedRtr != e.mu.neigh.Addr {
- t.Errorf("got ipv6EP.invalidatedRtr = %s, want = %s", ipv6EP.invalidatedRtr, e.mu.neigh.Addr)
- }
- e.mu.Unlock()
-
- // No probes should have been sent.
- runImmediatelyScheduledJobs(clock)
- {
- linkRes.mu.Lock()
- diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff)
- }
- }
-
- // No events should have been dispatched.
- nudDisp.mu.Lock()
- diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.mu.events)
- nudDisp.mu.Unlock()
- if diff != "" {
- t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
-}
-
-func TestEntryReachableToReachableWhenProbeWithSameAddress(t *testing.T) {
- c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("unknownToIncomplete(...) = %s", err)
- }
- if err := incompleteToReachable(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("incompleteToReachable(...) = %s", err)
- }
-
- e.mu.Lock()
- e.handleProbeLocked(entryTestLinkAddr1)
- if e.mu.neigh.State != Reachable {
- t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable)
- }
- if e.mu.neigh.LinkAddr != entryTestLinkAddr1 {
- t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr1)
- }
- e.mu.Unlock()
-
- // No probes should have been sent.
- runImmediatelyScheduledJobs(clock)
- {
- linkRes.mu.Lock()
- diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff)
- }
- }
-
- // No events should have been dispatched.
- nudDisp.mu.Lock()
- diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.mu.events)
- nudDisp.mu.Unlock()
- if diff != "" {
- t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
-}
-
-func TestEntryReachableToStaleWhenTimeout(t *testing.T) {
- c := DefaultNUDConfigurations()
- // Eliminate random factors from ReachableTime computation so the transition
- // from Stale to Reachable will only take BaseReachableTime duration.
- c.MinRandomFactor = 1
- c.MaxRandomFactor = 1
-
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("unknownToIncomplete(...) = %s", err)
- }
- if err := incompleteToReachable(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("incompleteToReachable(...) = %s", err)
- }
- if err := reachableToStale(c, e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("reachableToStale(...) = %s", err)
- }
-}
-
-// reachableToStale transitions a neighborEntry in Reachable state to Stale
-// state. Depends on the elimination of random factors in the ReachableTime
-// computation.
-//
-// c.MinRandomFactor = 1
-// c.MaxRandomFactor = 1
-func reachableToStale(c NUDConfigurations, e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock) error {
- // Ensure there are no random factors in the ReachableTime computation.
- if c.MinRandomFactor != 1 {
- return fmt.Errorf("got c.MinRandomFactor = %f, want = 1", c.MinRandomFactor)
- }
- if c.MaxRandomFactor != 1 {
- return fmt.Errorf("got c.MaxRandomFactor = %f, want = 1", c.MaxRandomFactor)
- }
-
- {
- e.mu.Lock()
- state := e.mu.neigh.State
- e.mu.Unlock()
- if state != Reachable {
- return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", state, Reachable)
- }
- }
-
- clock.Advance(c.BaseReachableTime)
-
- {
- e.mu.Lock()
- state := e.mu.neigh.State
- e.mu.Unlock()
- if state != Stale {
- return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", state, Stale)
- }
- }
-
- // No probes should have been sent.
- runImmediatelyScheduledJobs(clock)
- {
- linkRes.mu.Lock()
- diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes)
- linkRes.mu.Unlock()
- if diff != "" {
- return fmt.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff)
- }
- }
-
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
- UpdatedAt: clock.Now(),
- },
- },
- }
- {
-
- nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, nudDisp.mu.events)
- nudDisp.mu.events = nil
- nudDisp.mu.Unlock()
- if diff != "" {
- return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- }
-
- return nil
-}
-
-func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) {
- c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("unknownToIncomplete(...) = %s", err)
- }
- if err := incompleteToReachable(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("incompleteToReachable(...) = %s", err)
- }
-
- e.mu.Lock()
- e.handleProbeLocked(entryTestLinkAddr2)
- if e.mu.neigh.State != Stale {
- t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale)
- }
- e.mu.Unlock()
-
- // No probes should have been sent.
- runImmediatelyScheduledJobs(clock)
- {
- linkRes.mu.Lock()
- diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff)
- }
- }
-
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
- UpdatedAt: clock.Now(),
- },
- },
- }
- nudDisp.mu.Lock()
- if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- nudDisp.mu.Unlock()
-}
-
-func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
- c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("unknownToIncomplete(...) = %s", err)
- }
- if err := incompleteToReachable(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("incompleteToReachable(...) = %s", err)
- }
-
- e.mu.Lock()
- e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if e.mu.neigh.State != Stale {
- t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale)
- }
- e.mu.Unlock()
-
- // No probes should have been sent.
- runImmediatelyScheduledJobs(clock)
- {
- linkRes.mu.Lock()
- diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff)
- }
- }
-
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
- UpdatedAt: clock.Now(),
- },
- },
- }
- nudDisp.mu.Lock()
- if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- nudDisp.mu.Unlock()
-}
-
-func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t *testing.T) {
- c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("unknownToIncomplete(...) = %s", err)
- }
- if err := incompleteToReachable(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("incompleteToReachable(...) = %s", err)
- }
-
- e.mu.Lock()
- e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: true,
- IsRouter: false,
- })
- if e.mu.neigh.State != Stale {
- t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale)
- }
- e.mu.Unlock()
-
- // No probes should have been sent.
- runImmediatelyScheduledJobs(clock)
- {
- linkRes.mu.Lock()
- diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff)
- }
- }
-
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
- UpdatedAt: clock.Now(),
- },
- },
- }
- nudDisp.mu.Lock()
- if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- nudDisp.mu.Unlock()
-}
-
-func TestEntryStaleToStaleWhenProbeWithSameAddress(t *testing.T) {
- c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("unknownToStale(...) = %s", err)
- }
-
- e.mu.Lock()
- e.handleProbeLocked(entryTestLinkAddr1)
- if e.mu.neigh.State != Stale {
- t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale)
- }
- if e.mu.neigh.LinkAddr != entryTestLinkAddr1 {
- t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr1)
- }
- e.mu.Unlock()
-
- // No probes should have been sent.
- runImmediatelyScheduledJobs(clock)
- {
- linkRes.mu.Lock()
- diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff)
- }
- }
-
- // No events should have been dispatched.
- nudDisp.mu.Lock()
- if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.mu.events); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- nudDisp.mu.Unlock()
-}
-
-func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
- c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("unknownToStale(...) = %s", err)
- }
-
- e.mu.Lock()
- e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: true,
- IsRouter: false,
- })
- if e.mu.neigh.State != Reachable {
- t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable)
- }
- if e.mu.neigh.LinkAddr != entryTestLinkAddr2 {
- t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr2)
- }
- e.mu.Unlock()
-
- // No probes should have been sent.
- runImmediatelyScheduledJobs(clock)
- {
- linkRes.mu.Lock()
- diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff)
- }
- }
-
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Reachable,
- UpdatedAt: clock.Now(),
- },
- },
- }
- nudDisp.mu.Lock()
- if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- nudDisp.mu.Unlock()
-}
-
-func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing.T) {
- c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("unknownToStale(...) = %s", err)
- }
-
- e.mu.Lock()
- e.handleConfirmationLocked("" /* linkAddr */, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- if e.mu.neigh.State != Reachable {
- t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable)
- }
- if e.mu.neigh.LinkAddr != entryTestLinkAddr1 {
- t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr1)
- }
- e.mu.Unlock()
-
- // No probes should have been sent.
- runImmediatelyScheduledJobs(clock)
- {
- linkRes.mu.Lock()
- diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff)
- }
- }
-
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
- UpdatedAt: clock.Now(),
- },
- },
- }
- nudDisp.mu.Lock()
- if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- nudDisp.mu.Unlock()
-}
-
-func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) {
- c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("unknownToStale(...) = %s", err)
- }
-
- e.mu.Lock()
- e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: true,
- IsRouter: false,
- })
- if e.mu.neigh.State != Stale {
- t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale)
- }
- if e.mu.neigh.LinkAddr != entryTestLinkAddr2 {
- t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr2)
- }
- e.mu.Unlock()
-
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
- UpdatedAt: clock.Now(),
- },
- },
- }
- nudDisp.mu.Lock()
- if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- nudDisp.mu.Unlock()
-}
-
-func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) {
- c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("unknownToStale(...) = %s", err)
- }
-
- e.mu.Lock()
- e.handleProbeLocked(entryTestLinkAddr2)
- if e.mu.neigh.State != Stale {
- t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale)
- }
- if e.mu.neigh.LinkAddr != entryTestLinkAddr2 {
- t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr2)
- }
- e.mu.Unlock()
-
- // No probes should have been sent.
- runImmediatelyScheduledJobs(clock)
- {
- linkRes.mu.Lock()
- diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff)
- }
- }
-
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
- UpdatedAt: clock.Now(),
- },
- },
- }
- nudDisp.mu.Lock()
- if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- nudDisp.mu.Unlock()
-}
-
-func TestEntryStaleToDelay(t *testing.T) {
- c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("unknownToStale(...) = %s", err)
- }
- if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("staleToDelay(...) = %s", err)
- }
-}
-
-func staleToDelay(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock) error {
- if err := func() error {
- e.mu.Lock()
- defer e.mu.Unlock()
-
- if e.mu.neigh.State != Stale {
- return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale)
- }
- e.handlePacketQueuedLocked(entryTestAddr2)
- if e.mu.neigh.State != Delay {
- return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Delay)
- }
- return nil
- }(); err != nil {
- return err
- }
-
- // No probes should have been sent.
- runImmediatelyScheduledJobs(clock)
- {
- linkRes.mu.Lock()
- diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes)
- linkRes.mu.Unlock()
- if diff != "" {
- return fmt.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff)
- }
- }
-
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
- UpdatedAt: clock.Now(),
- },
- },
- }
- nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, nudDisp.mu.events)
- nudDisp.mu.events = nil
- nudDisp.mu.Unlock()
- if diff != "" {
- return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
-
- return nil
-}
-
-func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) {
- c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("unknownToStale(...) = %s", err)
- }
- if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("staleToDelay(...) = %s", err)
- }
-
- e.mu.Lock()
- e.handleUpperLevelConfirmationLocked()
- if e.mu.neigh.State != Reachable {
- t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable)
- }
- e.mu.Unlock()
-
- // No probes should have been sent.
- runImmediatelyScheduledJobs(clock)
- {
- linkRes.mu.Lock()
- diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff)
- }
- }
-
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
- UpdatedAt: clock.Now(),
- },
- },
- }
- nudDisp.mu.Lock()
- if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- nudDisp.mu.Unlock()
-}
-
-func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
- c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("unknownToStale(...) = %s", err)
- }
- if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("staleToDelay(...) = %s", err)
- }
-
- e.mu.Lock()
- e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: true,
- IsRouter: false,
- })
- if e.mu.neigh.State != Reachable {
- t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable)
- }
- if e.mu.neigh.LinkAddr != entryTestLinkAddr2 {
- t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr2)
- }
- e.mu.Unlock()
-
- // No probes should have been sent.
- runImmediatelyScheduledJobs(clock)
- {
- linkRes.mu.Lock()
- diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff)
- }
- }
-
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Reachable,
- UpdatedAt: clock.Now(),
- },
- },
- }
- nudDisp.mu.Lock()
- if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- nudDisp.mu.Unlock()
-}
-
-func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing.T) {
- c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("unknownToStale(...) = %s", err)
- }
- if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("staleToDelay(...) = %s", err)
- }
-
- e.mu.Lock()
- e.handleConfirmationLocked("" /* linkAddr */, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- if e.mu.neigh.State != Reachable {
- t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable)
- }
- if e.mu.neigh.LinkAddr != entryTestLinkAddr1 {
- t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr1)
- }
- e.mu.Unlock()
-
- // No probes should have been sent.
- runImmediatelyScheduledJobs(clock)
- {
- linkRes.mu.Lock()
- diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff)
- }
- }
-
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
- UpdatedAt: clock.Now(),
- },
- },
- }
- nudDisp.mu.Lock()
- if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- nudDisp.mu.Unlock()
-}
-
-func TestEntryDelayToDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) {
- c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("unknownToStale(...) = %s", err)
- }
- if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("staleToDelay(...) = %s", err)
- }
-
- e.mu.Lock()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: true,
- IsRouter: false,
- })
- if e.mu.neigh.State != Delay {
- t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Delay)
- }
- if e.mu.neigh.LinkAddr != entryTestLinkAddr1 {
- t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr1)
- }
- e.mu.Unlock()
-
- // No probes should have been sent.
- runImmediatelyScheduledJobs(clock)
- {
- linkRes.mu.Lock()
- diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff)
- }
- }
-
- // No events should have been dispatched.
- nudDisp.mu.Lock()
- if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.mu.events); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- nudDisp.mu.Unlock()
-}
-
-func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) {
- c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("unknownToStale(...) = %s", err)
- }
- if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("staleToDelay(...) = %s", err)
- }
-
- e.mu.Lock()
- e.handleProbeLocked(entryTestLinkAddr2)
- if e.mu.neigh.State != Stale {
- t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale)
- }
- e.mu.Unlock()
-
- // No probes should have been sent.
- runImmediatelyScheduledJobs(clock)
- {
- linkRes.mu.Lock()
- diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff)
- }
- }
-
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
- UpdatedAt: clock.Now(),
- },
- },
- }
- nudDisp.mu.Lock()
- if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- nudDisp.mu.Unlock()
-}
-
-func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
- c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("unknownToStale(...) = %s", err)
- }
- if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("staleToDelay(...) = %s", err)
- }
-
- e.mu.Lock()
- e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: true,
- IsRouter: false,
- })
- if e.mu.neigh.State != Stale {
- t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale)
- }
- e.mu.Unlock()
-
- // No probes should have been sent.
- runImmediatelyScheduledJobs(clock)
- {
- linkRes.mu.Lock()
- diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff)
- }
- }
-
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
- UpdatedAt: clock.Now(),
- },
- },
- }
- nudDisp.mu.Lock()
- if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- nudDisp.mu.Unlock()
-}
-
-func TestEntryDelayToProbe(t *testing.T) {
- c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("unknownToStale(...) = %s", err)
- }
- if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("staleToDelay(...) = %s", err)
- }
- if err := delayToProbe(c, e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("delayToProbe(...) = %s", err)
- }
-}
-
-func delayToProbe(c NUDConfigurations, e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock) error {
- {
- e.mu.Lock()
- state := e.mu.neigh.State
- e.mu.Unlock()
- if state != Delay {
- return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", state, Delay)
- }
- }
-
- // Wait for the first unicast probe to be transmitted, marking the
- // transition from Delay to Probe.
- clock.Advance(c.DelayFirstProbeTime)
-
- {
- e.mu.Lock()
- state := e.mu.neigh.State
- e.mu.Unlock()
- if state != Probe {
- return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", state, Probe)
- }
- }
-
- wantProbes := []entryTestProbeInfo{
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- }
- {
- linkRes.mu.Lock()
- diff := cmp.Diff(wantProbes, linkRes.mu.probes)
- linkRes.mu.probes = nil
- linkRes.mu.Unlock()
- if diff != "" {
- return fmt.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff)
- }
- }
-
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
- UpdatedAt: clock.Now(),
- },
- },
- }
- {
- nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, nudDisp.mu.events)
- nudDisp.mu.events = nil
- nudDisp.mu.Unlock()
- if diff != "" {
- return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- }
-
- return nil
-}
-
-func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
- c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("unknownToStale(...) = %s", err)
- }
- if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("staleToDelay(...) = %s", err)
- }
- if err := delayToProbe(c, e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("delayToProbe(...) = %s", err)
- }
-
- e.mu.Lock()
- e.handleProbeLocked(entryTestLinkAddr2)
- if e.mu.neigh.State != Stale {
- t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale)
- }
- e.mu.Unlock()
-
- // No probes should have been sent.
- runImmediatelyScheduledJobs(clock)
- {
- linkRes.mu.Lock()
- diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff)
- }
- }
-
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
- UpdatedAt: clock.Now(),
- },
- },
- }
- nudDisp.mu.Lock()
- if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- nudDisp.mu.Unlock()
-}
-
-func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
- c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("unknownToStale(...) = %s", err)
- }
- if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("staleToDelay(...) = %s", err)
- }
- if err := delayToProbe(c, e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("delayToProbe(...) = %s", err)
- }
-
- e.mu.Lock()
- e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: true,
- IsRouter: false,
- })
- if e.mu.neigh.State != Stale {
- t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale)
- }
- e.mu.Unlock()
-
- // No probes should have been sent.
- runImmediatelyScheduledJobs(clock)
- {
- linkRes.mu.Lock()
- diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff)
- }
- }
-
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
- UpdatedAt: clock.Now(),
- },
- },
- }
- nudDisp.mu.Lock()
- if diff := cmp.Diff(wantEvents, nudDisp.mu.events); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- nudDisp.mu.Unlock()
-}
-
-func TestEntryProbeToProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) {
- c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("unknownToStale(...) = %s", err)
- }
- if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("staleToDelay(...) = %s", err)
- }
- if err := delayToProbe(c, e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("delayToProbe(...) = %s", err)
- }
-
- e.mu.Lock()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: true,
- IsRouter: false,
- })
- if e.mu.neigh.State != Probe {
- t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe)
- }
- if got, want := e.mu.neigh.LinkAddr, entryTestLinkAddr1; got != want {
- t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", got, want)
- }
- e.mu.Unlock()
-
- // No probes should have been sent.
- runImmediatelyScheduledJobs(clock)
- {
- linkRes.mu.Lock()
- diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff)
- }
- }
-
- // No events should have been dispatched.
- nudDisp.mu.Lock()
- diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.mu.events)
- nudDisp.mu.Unlock()
- if diff != "" {
- t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
-}
-
-func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
- c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("unknownToStale(...) = %s", err)
- }
- if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("staleToDelay(...) = %s", err)
- }
- if err := delayToProbe(c, e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("delayToProbe(...) = %s", err)
- }
- if err := probeToReachableWithOverride(e, nudDisp, linkRes, clock, entryTestLinkAddr2); err != nil {
- t.Fatalf("probeToReachableWithOverride(...) = %s", err)
- }
-}
-
-func probeToReachableWithFlags(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) error {
- if err := func() error {
- e.mu.Lock()
- defer e.mu.Unlock()
-
- prevLinkAddr := e.mu.neigh.LinkAddr
- if e.mu.neigh.State != Probe {
- return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe)
- }
- e.handleConfirmationLocked(linkAddr, flags)
-
- if e.mu.neigh.State != Reachable {
- return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable)
- }
- if linkAddr == "" {
- linkAddr = prevLinkAddr
- }
- if e.mu.neigh.LinkAddr != linkAddr {
- return fmt.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, linkAddr)
- }
- return nil
- }(); err != nil {
- return err
- }
-
- // No probes should have been sent.
- runImmediatelyScheduledJobs(clock)
- {
- linkRes.mu.Lock()
- diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.mu.probes)
- linkRes.mu.Unlock()
- if diff != "" {
- return fmt.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff)
- }
- }
-
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: linkAddr,
- State: Reachable,
- UpdatedAt: clock.Now(),
- },
- },
- }
- {
- nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, nudDisp.mu.events)
- nudDisp.mu.events = nil
- nudDisp.mu.Unlock()
- if diff != "" {
- return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- }
-
- return nil
-}
-
-func probeToReachableWithOverride(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock, linkAddr tcpip.LinkAddress) error {
- return probeToReachableWithFlags(e, nudDisp, linkRes, clock, linkAddr, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: true,
- IsRouter: false,
- })
-}
-
-func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testing.T) {
- c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("unknownToStale(...) = %s", err)
- }
- if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("staleToDelay(...) = %s", err)
- }
- if err := delayToProbe(c, e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("delayToProbe(...) = %s", err)
- }
- if err := probeToReachable(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("probeToReachable(...) = %s", err)
- }
-}
-
-func probeToReachable(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock) error {
- return probeToReachableWithFlags(e, nudDisp, linkRes, clock, entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
-}
-
-func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing.T) {
- c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("unknownToStale(...) = %s", err)
- }
- if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("staleToDelay(...) = %s", err)
- }
- if err := delayToProbe(c, e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("delayToProbe(...) = %s", err)
- }
- if err := probeToReachableWithoutAddress(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("probeToReachableWithoutAddress(...) = %s", err)
- }
-}
-
-func probeToReachableWithoutAddress(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock) error {
- return probeToReachableWithFlags(e, nudDisp, linkRes, clock, "" /* linkAddr */, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
-}
-
-func TestEntryProbeToUnreachable(t *testing.T) {
- c := DefaultNUDConfigurations()
- c.MaxMulticastProbes = 3
- c.MaxUnicastProbes = 3
- c.DelayFirstProbeTime = c.RetransmitTimer
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- if err := unknownToStale(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("unknownToStale(...) = %s", err)
- }
- if err := staleToDelay(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("staleToDelay(...) = %s", err)
- }
- if err := delayToProbe(c, e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("delayToProbe(...) = %s", err)
- }
- if err := probeToUnreachable(c, e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("probeToUnreachable(...) = %s", err)
- }
-}
-
-func probeToUnreachable(c NUDConfigurations, e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock) error {
- {
- e.mu.Lock()
- state := e.mu.neigh.State
- e.mu.Unlock()
- if state != Probe {
- return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", state, Probe)
- }
- }
-
- // The first probe was sent in the transition from Delay to Probe.
- clock.Advance(c.RetransmitTimer)
-
- // Observe each subsequent unicast probe transmitted.
- for i := uint32(1); i < c.MaxUnicastProbes; i++ {
- wantProbes := []entryTestProbeInfo{{
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- }}
- linkRes.mu.Lock()
- diff := cmp.Diff(wantProbes, linkRes.mu.probes)
- linkRes.mu.probes = nil
- linkRes.mu.Unlock()
- if diff != "" {
- return fmt.Errorf("link address resolver probe #%d mismatch (-want, +got):\n%s", i+1, diff)
- }
-
- e.mu.Lock()
- state := e.mu.neigh.State
- e.mu.Unlock()
- if state != Probe {
- return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", state, Probe)
- }
-
- clock.Advance(c.RetransmitTimer)
- }
-
- {
- e.mu.Lock()
- state := e.mu.neigh.State
- e.mu.Unlock()
- if state != Unreachable {
- return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", state, Unreachable)
- }
- }
-
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Unreachable,
- UpdatedAt: clock.Now(),
- },
- },
- }
- nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, nudDisp.mu.events)
- nudDisp.mu.events = nil
- nudDisp.mu.Unlock()
- if diff != "" {
- return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
-
- return nil
-}
-
-func TestEntryUnreachableToIncomplete(t *testing.T) {
- c := DefaultNUDConfigurations()
- c.MaxMulticastProbes = 3
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("unknownToIncomplete(...) = %s", err)
- }
- if err := incompleteToUnreachable(c, e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("incompleteToUnreachable(...) = %s", err)
- }
- if err := unreachableToIncomplete(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("unreachableToIncomplete(...) = %s", err)
- }
-}
-
-func unreachableToIncomplete(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTestLinkResolver, clock *faketime.ManualClock) error {
- if err := func() error {
- e.mu.Lock()
- defer e.mu.Unlock()
-
- if e.mu.neigh.State != Unreachable {
- return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unreachable)
- }
- e.handlePacketQueuedLocked(entryTestAddr2)
- if e.mu.neigh.State != Incomplete {
- return fmt.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete)
- }
- return nil
- }(); err != nil {
- return err
- }
-
- runImmediatelyScheduledJobs(clock)
- wantProbes := []entryTestProbeInfo{
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(wantProbes, linkRes.mu.probes)
- linkRes.mu.probes = nil
- linkRes.mu.Unlock()
- if diff != "" {
- return fmt.Errorf("link address resolver probes mismatch (-want, +got):\n%s", diff)
- }
-
- wantEvents := []testEntryEventInfo{
- {
- EventType: entryTestChanged,
- NICID: entryTestNICID,
- Entry: NeighborEntry{
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
- UpdatedAt: clock.Now(),
- },
- },
- }
- {
- nudDisp.mu.Lock()
- diff := cmp.Diff(wantEvents, nudDisp.mu.events)
- nudDisp.mu.events = nil
- nudDisp.mu.Unlock()
- if diff != "" {
- return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
- }
- }
- return nil
-}
-
-func TestEntryUnreachableToStale(t *testing.T) {
- c := DefaultNUDConfigurations()
- c.MaxMulticastProbes = 3
- // Eliminate random factors from ReachableTime computation so the transition
- // from Stale to Reachable will only take BaseReachableTime duration.
- c.MinRandomFactor = 1
- c.MaxRandomFactor = 1
-
- e, nudDisp, linkRes, clock := entryTestSetup(c)
-
- if err := unknownToIncomplete(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("unknownToIncomplete(...) = %s", err)
- }
- if err := incompleteToUnreachable(c, e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("incompleteToUnreachable(...) = %s", err)
- }
- if err := unreachableToIncomplete(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("unreachableToIncomplete(...) = %s", err)
- }
- if err := incompleteToReachable(e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("incompleteToReachable(...) = %s", err)
- }
- if err := reachableToStale(c, e, nudDisp, linkRes, clock); err != nil {
- t.Fatalf("reachableToStale(...) = %s", err)
- }
-}
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
deleted file mode 100644
index c8ad93f29..000000000
--- a/pkg/tcpip/stack/nic_test.go
+++ /dev/null
@@ -1,219 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package stack
-
-import (
- "reflect"
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
-)
-
-var _ AddressableEndpoint = (*testIPv6Endpoint)(nil)
-var _ NetworkEndpoint = (*testIPv6Endpoint)(nil)
-var _ NDPEndpoint = (*testIPv6Endpoint)(nil)
-
-// An IPv6 NetworkEndpoint that throws away outgoing packets.
-//
-// We use this instead of ipv6.endpoint because the ipv6 package depends on
-// the stack package which this test lives in, causing a cyclic dependency.
-type testIPv6Endpoint struct {
- AddressableEndpointState
-
- nic NetworkInterface
- protocol *testIPv6Protocol
-
- invalidatedRtr tcpip.Address
-}
-
-func (*testIPv6Endpoint) Enable() tcpip.Error {
- return nil
-}
-
-func (*testIPv6Endpoint) Enabled() bool {
- return true
-}
-
-func (*testIPv6Endpoint) Disable() {}
-
-// DefaultTTL implements NetworkEndpoint.DefaultTTL.
-func (*testIPv6Endpoint) DefaultTTL() uint8 {
- return 0
-}
-
-// MTU implements NetworkEndpoint.MTU.
-func (e *testIPv6Endpoint) MTU() uint32 {
- return e.nic.MTU() - header.IPv6MinimumSize
-}
-
-// MaxHeaderLength implements NetworkEndpoint.MaxHeaderLength.
-func (e *testIPv6Endpoint) MaxHeaderLength() uint16 {
- return e.nic.MaxHeaderLength() + header.IPv6MinimumSize
-}
-
-// WritePacket implements NetworkEndpoint.WritePacket.
-func (*testIPv6Endpoint) WritePacket(*Route, NetworkHeaderParams, *PacketBuffer) tcpip.Error {
- return nil
-}
-
-// WritePackets implements NetworkEndpoint.WritePackets.
-func (*testIPv6Endpoint) WritePackets(*Route, PacketBufferList, NetworkHeaderParams) (int, tcpip.Error) {
- // Our tests don't use this so we don't support it.
- return 0, &tcpip.ErrNotSupported{}
-}
-
-// WriteHeaderIncludedPacket implements
-// NetworkEndpoint.WriteHeaderIncludedPacket.
-func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) tcpip.Error {
- // Our tests don't use this so we don't support it.
- return &tcpip.ErrNotSupported{}
-}
-
-// HandlePacket implements NetworkEndpoint.HandlePacket.
-func (*testIPv6Endpoint) HandlePacket(*PacketBuffer) {}
-
-// Close implements NetworkEndpoint.Close.
-func (e *testIPv6Endpoint) Close() {
- e.AddressableEndpointState.Cleanup()
-}
-
-// NetworkProtocolNumber implements NetworkEndpoint.NetworkProtocolNumber.
-func (*testIPv6Endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
- return header.IPv6ProtocolNumber
-}
-
-func (e *testIPv6Endpoint) InvalidateDefaultRouter(rtr tcpip.Address) {
- e.invalidatedRtr = rtr
-}
-
-// Stats implements NetworkEndpoint.
-func (*testIPv6Endpoint) Stats() NetworkEndpointStats {
- return &testIPv6EndpointStats{}
-}
-
-var _ NetworkEndpointStats = (*testIPv6EndpointStats)(nil)
-
-type testIPv6EndpointStats struct{}
-
-// IsNetworkEndpointStats implements stack.NetworkEndpointStats.
-func (*testIPv6EndpointStats) IsNetworkEndpointStats() {}
-
-// We use this instead of ipv6.protocol because the ipv6 package depends on
-// the stack package which this test lives in, causing a cyclic dependency.
-type testIPv6Protocol struct{}
-
-// Number implements NetworkProtocol.Number.
-func (*testIPv6Protocol) Number() tcpip.NetworkProtocolNumber {
- return header.IPv6ProtocolNumber
-}
-
-// MinimumPacketSize implements NetworkProtocol.MinimumPacketSize.
-func (*testIPv6Protocol) MinimumPacketSize() int {
- return header.IPv6MinimumSize
-}
-
-// ParseAddresses implements NetworkProtocol.ParseAddresses.
-func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
- h := header.IPv6(v)
- return h.SourceAddress(), h.DestinationAddress()
-}
-
-// NewEndpoint implements NetworkProtocol.NewEndpoint.
-func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ TransportDispatcher) NetworkEndpoint {
- e := &testIPv6Endpoint{
- nic: nic,
- protocol: p,
- }
- e.AddressableEndpointState.Init(e)
- return e
-}
-
-// SetOption implements NetworkProtocol.SetOption.
-func (*testIPv6Protocol) SetOption(tcpip.SettableNetworkProtocolOption) tcpip.Error {
- return nil
-}
-
-// Option implements NetworkProtocol.Option.
-func (*testIPv6Protocol) Option(tcpip.GettableNetworkProtocolOption) tcpip.Error {
- return nil
-}
-
-// Close implements NetworkProtocol.Close.
-func (*testIPv6Protocol) Close() {}
-
-// Wait implements NetworkProtocol.Wait.
-func (*testIPv6Protocol) Wait() {}
-
-// Parse implements NetworkProtocol.Parse.
-func (*testIPv6Protocol) Parse(*PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) {
- return 0, false, false
-}
-
-func TestDisabledRxStatsWhenNICDisabled(t *testing.T) {
- // When the NIC is disabled, the only field that matters is the stats field.
- // This test is limited to stats counter checks.
- nic := nic{
- stats: makeNICStats(tcpip.NICStats{}.FillIn()),
- }
-
- if got := nic.stats.local.DisabledRx.Packets.Value(); got != 0 {
- t.Errorf("got DisabledRx.Packets = %d, want = 0", got)
- }
- if got := nic.stats.local.DisabledRx.Bytes.Value(); got != 0 {
- t.Errorf("got DisabledRx.Bytes = %d, want = 0", got)
- }
- if got := nic.stats.local.Rx.Packets.Value(); got != 0 {
- t.Errorf("got Rx.Packets = %d, want = 0", got)
- }
- if got := nic.stats.local.Rx.Bytes.Value(); got != 0 {
- t.Errorf("got Rx.Bytes = %d, want = 0", got)
- }
-
- if t.Failed() {
- t.FailNow()
- }
-
- nic.DeliverNetworkPacket("", "", 0, NewPacketBuffer(PacketBufferOptions{
- Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView(),
- }))
-
- if got := nic.stats.local.DisabledRx.Packets.Value(); got != 1 {
- t.Errorf("got DisabledRx.Packets = %d, want = 1", got)
- }
- if got := nic.stats.local.DisabledRx.Bytes.Value(); got != 4 {
- t.Errorf("got DisabledRx.Bytes = %d, want = 4", got)
- }
- if got := nic.stats.local.Rx.Packets.Value(); got != 0 {
- t.Errorf("got Rx.Packets = %d, want = 0", got)
- }
- if got := nic.stats.local.Rx.Bytes.Value(); got != 0 {
- t.Errorf("got Rx.Bytes = %d, want = 0", got)
- }
-}
-
-func TestMultiCounterStatsInitialization(t *testing.T) {
- global := tcpip.NICStats{}.FillIn()
- nic := nic{
- stats: makeNICStats(global),
- }
- multi := nic.stats.multiCounterNICStats
- local := nic.stats.local
- if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&multi).Elem(), []reflect.Value{reflect.ValueOf(&local).Elem(), reflect.ValueOf(&global).Elem()}); err != nil {
- t.Error(err)
- }
-}
diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go
deleted file mode 100644
index 1aeb2f8a5..000000000
--- a/pkg/tcpip/stack/nud_test.go
+++ /dev/null
@@ -1,816 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package stack_test
-
-import (
- "math"
- "math/rand"
- "testing"
- "time"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
-)
-
-const (
- defaultBaseReachableTime = 30 * time.Second
- minimumBaseReachableTime = time.Millisecond
- defaultMinRandomFactor = 0.5
- defaultMaxRandomFactor = 1.5
- defaultRetransmitTimer = time.Second
- minimumRetransmitTimer = time.Millisecond
- defaultDelayFirstProbeTime = 5 * time.Second
- defaultMaxMulticastProbes = 3
- defaultMaxUnicastProbes = 3
-
- defaultFakeRandomNum = 0.5
-)
-
-// fakeRand is a deterministic random number generator.
-type fakeRand struct {
- num float32
-}
-
-var _ rand.Source = (*fakeRand)(nil)
-
-func (f *fakeRand) Int63() int64 {
- return int64(f.num * float32(1<<63))
-}
-
-func (*fakeRand) Seed(int64) {}
-
-func TestNUDFunctions(t *testing.T) {
- const nicID = 1
-
- tests := []struct {
- name string
- nicID tcpip.NICID
- netProtoFactory []stack.NetworkProtocolFactory
- extraLinkCapabilities stack.LinkEndpointCapabilities
- expectedErr tcpip.Error
- }{
- {
- name: "Invalid NICID",
- nicID: nicID + 1,
- netProtoFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- extraLinkCapabilities: stack.CapabilityResolutionRequired,
- expectedErr: &tcpip.ErrUnknownNICID{},
- },
- {
- name: "No network protocol",
- nicID: nicID,
- expectedErr: &tcpip.ErrNotSupported{},
- },
- {
- name: "With IPv6",
- nicID: nicID,
- netProtoFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- expectedErr: &tcpip.ErrNotSupported{},
- },
- {
- name: "With resolution capability",
- nicID: nicID,
- extraLinkCapabilities: stack.CapabilityResolutionRequired,
- expectedErr: &tcpip.ErrNotSupported{},
- },
- {
- name: "With IPv6 and resolution capability",
- nicID: nicID,
- netProtoFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- extraLinkCapabilities: stack.CapabilityResolutionRequired,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NUDConfigs: stack.DefaultNUDConfigurations(),
- NetworkProtocols: test.netProtoFactory,
- Clock: clock,
- })
-
- e := channel.New(0, 0, linkAddr1)
- e.LinkEPCapabilities &^= stack.CapabilityResolutionRequired
- e.LinkEPCapabilities |= test.extraLinkCapabilities
-
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- configs := stack.DefaultNUDConfigurations()
- configs.BaseReachableTime = time.Hour
-
- {
- err := s.SetNUDConfigurations(test.nicID, ipv6.ProtocolNumber, configs)
- if diff := cmp.Diff(test.expectedErr, err); diff != "" {
- t.Errorf("s.SetNUDConfigurations(%d, %d, _) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff)
- }
- }
-
- {
- gotConfigs, err := s.NUDConfigurations(test.nicID, ipv6.ProtocolNumber)
- if diff := cmp.Diff(test.expectedErr, err); diff != "" {
- t.Errorf("s.NUDConfigurations(%d, %d) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff)
- } else if test.expectedErr == nil {
- if diff := cmp.Diff(configs, gotConfigs); diff != "" {
- t.Errorf("got configs mismatch (-want +got):\n%s", diff)
- }
- }
- }
-
- for _, addr := range []tcpip.Address{llAddr1, llAddr2} {
- {
- err := s.AddStaticNeighbor(test.nicID, ipv6.ProtocolNumber, addr, linkAddr1)
- if diff := cmp.Diff(test.expectedErr, err); diff != "" {
- t.Errorf("s.AddStaticNeighbor(%d, %d, %s, %s) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, addr, linkAddr1, diff)
- }
- }
- }
-
- {
- wantErr := test.expectedErr
- for i := 0; i < 2; i++ {
- {
- err := s.RemoveNeighbor(test.nicID, ipv6.ProtocolNumber, llAddr1)
- if diff := cmp.Diff(wantErr, err); diff != "" {
- t.Errorf("s.RemoveNeighbor(%d, %d, '') error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff)
- }
- }
-
- if test.expectedErr != nil {
- break
- }
-
- // Removing a neighbor that does not exist should give us a bad address
- // error.
- wantErr = &tcpip.ErrBadAddress{}
- }
- }
-
- {
- neighbors, err := s.Neighbors(test.nicID, ipv6.ProtocolNumber)
- if diff := cmp.Diff(test.expectedErr, err); diff != "" {
- t.Errorf("s.Neigbors(%d, %d) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff)
- } else if test.expectedErr == nil {
- if diff := cmp.Diff(
- []stack.NeighborEntry{{Addr: llAddr2, LinkAddr: linkAddr1, State: stack.Static, UpdatedAt: clock.Now()}},
- neighbors,
- ); diff != "" {
- t.Errorf("neighbors mismatch (-want +got):\n%s", diff)
- }
- }
- }
-
- {
- err := s.ClearNeighbors(test.nicID, ipv6.ProtocolNumber)
- if diff := cmp.Diff(test.expectedErr, err); diff != "" {
- t.Errorf("s.ClearNeigbors(%d, %d) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff)
- } else if test.expectedErr == nil {
- if neighbors, err := s.Neighbors(test.nicID, ipv6.ProtocolNumber); err != nil {
- t.Errorf("s.Neighbors(%d, %d): %s", test.nicID, ipv6.ProtocolNumber, err)
- } else if len(neighbors) != 0 {
- t.Errorf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors)
- }
- }
- }
- })
- }
-}
-
-func TestDefaultNUDConfigurations(t *testing.T) {
- const nicID = 1
-
- e := channel.New(0, 1280, linkAddr1)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
-
- s := stack.New(stack.Options{
- // A neighbor cache is required to store NUDConfigurations. The networking
- // stack will only allocate neighbor caches if a protocol providing link
- // address resolution is specified (e.g. ARP or IPv6).
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NUDConfigs: stack.DefaultNUDConfigurations(),
- })
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- c, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err)
- }
- if got, want := c, stack.DefaultNUDConfigurations(); got != want {
- t.Errorf("got stack.NUDConfigurations(%d, %d) = %+v, want = %+v", nicID, ipv6.ProtocolNumber, got, want)
- }
-}
-
-func TestNUDConfigurationsBaseReachableTime(t *testing.T) {
- tests := []struct {
- name string
- baseReachableTime time.Duration
- want time.Duration
- }{
- // Invalid cases
- {
- name: "EqualToZero",
- baseReachableTime: 0,
- want: defaultBaseReachableTime,
- },
- // Valid cases
- {
- name: "MoreThanZero",
- baseReachableTime: time.Millisecond,
- want: time.Millisecond,
- },
- {
- name: "MoreThanDefaultBaseReachableTime",
- baseReachableTime: 2 * defaultBaseReachableTime,
- want: 2 * defaultBaseReachableTime,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- const nicID = 1
-
- c := stack.DefaultNUDConfigurations()
- c.BaseReachableTime = test.baseReachableTime
-
- e := channel.New(0, 1280, linkAddr1)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
-
- s := stack.New(stack.Options{
- // A neighbor cache is required to store NUDConfigurations. The
- // networking stack will only allocate neighbor caches if a protocol
- // providing link address resolution is specified (e.g. ARP or IPv6).
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NUDConfigs: c,
- })
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err)
- }
- if got := sc.BaseReachableTime; got != test.want {
- t.Errorf("got BaseReachableTime = %q, want = %q", got, test.want)
- }
- })
- }
-}
-
-func TestNUDConfigurationsMinRandomFactor(t *testing.T) {
- tests := []struct {
- name string
- minRandomFactor float32
- want float32
- }{
- // Invalid cases
- {
- name: "LessThanZero",
- minRandomFactor: -1,
- want: defaultMinRandomFactor,
- },
- {
- name: "EqualToZero",
- minRandomFactor: 0,
- want: defaultMinRandomFactor,
- },
- // Valid cases
- {
- name: "MoreThanZero",
- minRandomFactor: 1,
- want: 1,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- const nicID = 1
-
- c := stack.DefaultNUDConfigurations()
- c.MinRandomFactor = test.minRandomFactor
-
- e := channel.New(0, 1280, linkAddr1)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
-
- s := stack.New(stack.Options{
- // A neighbor cache is required to store NUDConfigurations. The
- // networking stack will only allocate neighbor caches if a protocol
- // providing link address resolution is specified (e.g. ARP or IPv6).
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NUDConfigs: c,
- })
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err)
- }
- if got := sc.MinRandomFactor; got != test.want {
- t.Errorf("got MinRandomFactor = %f, want = %f", got, test.want)
- }
- })
- }
-}
-
-func TestNUDConfigurationsMaxRandomFactor(t *testing.T) {
- tests := []struct {
- name string
- minRandomFactor float32
- maxRandomFactor float32
- want float32
- }{
- // Invalid cases
- {
- name: "LessThanZero",
- minRandomFactor: defaultMinRandomFactor,
- maxRandomFactor: -1,
- want: defaultMaxRandomFactor,
- },
- {
- name: "EqualToZero",
- minRandomFactor: defaultMinRandomFactor,
- maxRandomFactor: 0,
- want: defaultMaxRandomFactor,
- },
- {
- name: "LessThanMinRandomFactor",
- minRandomFactor: defaultMinRandomFactor,
- maxRandomFactor: defaultMinRandomFactor * 0.99,
- want: defaultMaxRandomFactor,
- },
- {
- name: "MoreThanMinRandomFactorWhenMinRandomFactorIsLargerThanMaxRandomFactorDefault",
- minRandomFactor: defaultMaxRandomFactor * 2,
- maxRandomFactor: defaultMaxRandomFactor,
- want: defaultMaxRandomFactor * 6,
- },
- // Valid cases
- {
- name: "EqualToMinRandomFactor",
- minRandomFactor: defaultMinRandomFactor,
- maxRandomFactor: defaultMinRandomFactor,
- want: defaultMinRandomFactor,
- },
- {
- name: "MoreThanMinRandomFactor",
- minRandomFactor: defaultMinRandomFactor,
- maxRandomFactor: defaultMinRandomFactor * 1.1,
- want: defaultMinRandomFactor * 1.1,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- const nicID = 1
-
- c := stack.DefaultNUDConfigurations()
- c.MinRandomFactor = test.minRandomFactor
- c.MaxRandomFactor = test.maxRandomFactor
-
- e := channel.New(0, 1280, linkAddr1)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
-
- s := stack.New(stack.Options{
- // A neighbor cache is required to store NUDConfigurations. The
- // networking stack will only allocate neighbor caches if a protocol
- // providing link address resolution is specified (e.g. ARP or IPv6).
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NUDConfigs: c,
- })
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err)
- }
- if got := sc.MaxRandomFactor; got != test.want {
- t.Errorf("got MaxRandomFactor = %f, want = %f", got, test.want)
- }
- })
- }
-}
-
-func TestNUDConfigurationsRetransmitTimer(t *testing.T) {
- tests := []struct {
- name string
- retransmitTimer time.Duration
- want time.Duration
- }{
- // Invalid cases
- {
- name: "EqualToZero",
- retransmitTimer: 0,
- want: defaultRetransmitTimer,
- },
- {
- name: "LessThanMinimumRetransmitTimer",
- retransmitTimer: minimumRetransmitTimer - time.Nanosecond,
- want: defaultRetransmitTimer,
- },
- // Valid cases
- {
- name: "EqualToMinimumRetransmitTimer",
- retransmitTimer: minimumRetransmitTimer,
- want: minimumBaseReachableTime,
- },
- {
- name: "LargetThanMinimumRetransmitTimer",
- retransmitTimer: 2 * minimumBaseReachableTime,
- want: 2 * minimumBaseReachableTime,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- const nicID = 1
-
- c := stack.DefaultNUDConfigurations()
- c.RetransmitTimer = test.retransmitTimer
-
- e := channel.New(0, 1280, linkAddr1)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
-
- s := stack.New(stack.Options{
- // A neighbor cache is required to store NUDConfigurations. The
- // networking stack will only allocate neighbor caches if a protocol
- // providing link address resolution is specified (e.g. ARP or IPv6).
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NUDConfigs: c,
- })
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err)
- }
- if got := sc.RetransmitTimer; got != test.want {
- t.Errorf("got RetransmitTimer = %q, want = %q", got, test.want)
- }
- })
- }
-}
-
-func TestNUDConfigurationsDelayFirstProbeTime(t *testing.T) {
- tests := []struct {
- name string
- delayFirstProbeTime time.Duration
- want time.Duration
- }{
- // Invalid cases
- {
- name: "EqualToZero",
- delayFirstProbeTime: 0,
- want: defaultDelayFirstProbeTime,
- },
- // Valid cases
- {
- name: "MoreThanZero",
- delayFirstProbeTime: time.Millisecond,
- want: time.Millisecond,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- const nicID = 1
-
- c := stack.DefaultNUDConfigurations()
- c.DelayFirstProbeTime = test.delayFirstProbeTime
-
- e := channel.New(0, 1280, linkAddr1)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
-
- s := stack.New(stack.Options{
- // A neighbor cache is required to store NUDConfigurations. The
- // networking stack will only allocate neighbor caches if a protocol
- // providing link address resolution is specified (e.g. ARP or IPv6).
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NUDConfigs: c,
- })
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err)
- }
- if got := sc.DelayFirstProbeTime; got != test.want {
- t.Errorf("got DelayFirstProbeTime = %q, want = %q", got, test.want)
- }
- })
- }
-}
-
-func TestNUDConfigurationsMaxMulticastProbes(t *testing.T) {
- tests := []struct {
- name string
- maxMulticastProbes uint32
- want uint32
- }{
- // Invalid cases
- {
- name: "EqualToZero",
- maxMulticastProbes: 0,
- want: defaultMaxMulticastProbes,
- },
- // Valid cases
- {
- name: "MoreThanZero",
- maxMulticastProbes: 1,
- want: 1,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- const nicID = 1
-
- c := stack.DefaultNUDConfigurations()
- c.MaxMulticastProbes = test.maxMulticastProbes
-
- e := channel.New(0, 1280, linkAddr1)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
-
- s := stack.New(stack.Options{
- // A neighbor cache is required to store NUDConfigurations. The
- // networking stack will only allocate neighbor caches if a protocol
- // providing link address resolution is specified (e.g. ARP or IPv6).
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NUDConfigs: c,
- })
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err)
- }
- if got := sc.MaxMulticastProbes; got != test.want {
- t.Errorf("got MaxMulticastProbes = %q, want = %q", got, test.want)
- }
- })
- }
-}
-
-func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) {
- tests := []struct {
- name string
- maxUnicastProbes uint32
- want uint32
- }{
- // Invalid cases
- {
- name: "EqualToZero",
- maxUnicastProbes: 0,
- want: defaultMaxUnicastProbes,
- },
- // Valid cases
- {
- name: "MoreThanZero",
- maxUnicastProbes: 1,
- want: 1,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- const nicID = 1
-
- c := stack.DefaultNUDConfigurations()
- c.MaxUnicastProbes = test.maxUnicastProbes
-
- e := channel.New(0, 1280, linkAddr1)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
-
- s := stack.New(stack.Options{
- // A neighbor cache is required to store NUDConfigurations. The
- // networking stack will only allocate neighbor caches if a protocol
- // providing link address resolution is specified (e.g. ARP or IPv6).
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- NUDConfigs: c,
- })
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err)
- }
- if got := sc.MaxUnicastProbes; got != test.want {
- t.Errorf("got MaxUnicastProbes = %q, want = %q", got, test.want)
- }
- })
- }
-}
-
-// TestNUDStateReachableTime verifies the correctness of the ReachableTime
-// computation.
-func TestNUDStateReachableTime(t *testing.T) {
- tests := []struct {
- name string
- baseReachableTime time.Duration
- minRandomFactor float32
- maxRandomFactor float32
- want time.Duration
- }{
- {
- name: "AllZeros",
- baseReachableTime: 0,
- minRandomFactor: 0,
- maxRandomFactor: 0,
- want: 0,
- },
- {
- name: "ZeroMaxRandomFactor",
- baseReachableTime: time.Second,
- minRandomFactor: 0,
- maxRandomFactor: 0,
- want: 0,
- },
- {
- name: "ZeroMinRandomFactor",
- baseReachableTime: time.Second,
- minRandomFactor: 0,
- maxRandomFactor: 1,
- want: time.Duration(defaultFakeRandomNum * float32(time.Second)),
- },
- {
- name: "FractionalRandomFactor",
- baseReachableTime: time.Duration(math.MaxInt64),
- minRandomFactor: 0.001,
- maxRandomFactor: 0.002,
- want: time.Duration((0.001 + (0.001 * defaultFakeRandomNum)) * float32(math.MaxInt64)),
- },
- {
- name: "MinAndMaxRandomFactorsEqual",
- baseReachableTime: time.Second,
- minRandomFactor: 1,
- maxRandomFactor: 1,
- want: time.Second,
- },
- {
- name: "MinAndMaxRandomFactorsDifferent",
- baseReachableTime: time.Second,
- minRandomFactor: 1,
- maxRandomFactor: 2,
- want: time.Duration((1.0 + defaultFakeRandomNum) * float32(time.Second)),
- },
- {
- name: "MaxInt64",
- baseReachableTime: time.Duration(math.MaxInt64),
- minRandomFactor: 1,
- maxRandomFactor: 1,
- want: time.Duration(math.MaxInt64),
- },
- {
- name: "Overflow",
- baseReachableTime: time.Duration(math.MaxInt64),
- minRandomFactor: 1.5,
- maxRandomFactor: 1.5,
- want: time.Duration(math.MaxInt64),
- },
- {
- name: "DoubleOverflow",
- baseReachableTime: time.Duration(math.MaxInt64),
- minRandomFactor: 2.5,
- maxRandomFactor: 2.5,
- want: time.Duration(math.MaxInt64),
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- c := stack.NUDConfigurations{
- BaseReachableTime: test.baseReachableTime,
- MinRandomFactor: test.minRandomFactor,
- MaxRandomFactor: test.maxRandomFactor,
- }
- // A fake random number generator is used to ensure deterministic
- // results.
- rng := fakeRand{
- num: defaultFakeRandomNum,
- }
- var clock faketime.NullClock
- s := stack.NewNUDState(c, &clock, rand.New(&rng))
- if got, want := s.ReachableTime(), test.want; got != want {
- t.Errorf("got ReachableTime = %q, want = %q", got, want)
- }
- })
- }
-}
-
-// TestNUDStateRecomputeReachableTime exercises the ReachableTime function
-// twice to verify recomputation of reachable time when the min random factor,
-// max random factor, or base reachable time changes.
-func TestNUDStateRecomputeReachableTime(t *testing.T) {
- const defaultBase = time.Second
- const defaultMin = 2.0 * defaultMaxRandomFactor
- const defaultMax = 3.0 * defaultMaxRandomFactor
-
- tests := []struct {
- name string
- baseReachableTime time.Duration
- minRandomFactor float32
- maxRandomFactor float32
- want time.Duration
- }{
- {
- name: "BaseReachableTime",
- baseReachableTime: 2 * defaultBase,
- minRandomFactor: defaultMin,
- maxRandomFactor: defaultMax,
- want: time.Duration((defaultMin + (defaultMax-defaultMin)*defaultFakeRandomNum) * float32(2*defaultBase)),
- },
- {
- name: "MinRandomFactor",
- baseReachableTime: defaultBase,
- minRandomFactor: defaultMax,
- maxRandomFactor: defaultMax,
- want: time.Duration(defaultMax * float32(defaultBase)),
- },
- {
- name: "MaxRandomFactor",
- baseReachableTime: defaultBase,
- minRandomFactor: defaultMin,
- maxRandomFactor: defaultMin,
- want: time.Duration(defaultMin * float32(defaultBase)),
- },
- {
- name: "BothRandomFactor",
- baseReachableTime: defaultBase,
- minRandomFactor: 2 * defaultMin,
- maxRandomFactor: 2 * defaultMax,
- want: time.Duration((2*defaultMin + (2*defaultMax-2*defaultMin)*defaultFakeRandomNum) * float32(defaultBase)),
- },
- {
- name: "BaseReachableTimeAndBothRandomFactors",
- baseReachableTime: 2 * defaultBase,
- minRandomFactor: 2 * defaultMin,
- maxRandomFactor: 2 * defaultMax,
- want: time.Duration((2*defaultMin + (2*defaultMax-2*defaultMin)*defaultFakeRandomNum) * float32(2*defaultBase)),
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- c := stack.DefaultNUDConfigurations()
- c.BaseReachableTime = defaultBase
- c.MinRandomFactor = defaultMin
- c.MaxRandomFactor = defaultMax
-
- // A fake random number generator is used to ensure deterministic
- // results.
- rng := fakeRand{
- num: defaultFakeRandomNum,
- }
- var clock faketime.NullClock
- s := stack.NewNUDState(c, &clock, rand.New(&rng))
- old := s.ReachableTime()
-
- if got, want := s.ReachableTime(), old; got != want {
- t.Errorf("got ReachableTime = %q, want = %q", got, want)
- }
-
- // Check for recomputation when changing the min random factor, the max
- // random factor, the base reachability time, or any permutation of those
- // three options.
- c.BaseReachableTime = test.baseReachableTime
- c.MinRandomFactor = test.minRandomFactor
- c.MaxRandomFactor = test.maxRandomFactor
- s.SetConfig(c)
-
- if got, want := s.ReachableTime(), test.want; got != want {
- t.Errorf("got ReachableTime = %q, want = %q", got, want)
- }
-
- // Verify that ReachableTime isn't recomputed when none of the
- // configuration options change. The random factor is changed so that if
- // a recompution were to occur, ReachableTime would change.
- rng.num = defaultFakeRandomNum / 2.0
- if got, want := s.ReachableTime(), test.want; got != want {
- t.Errorf("got ReachableTime = %q, want = %q", got, want)
- }
- })
- }
-}
diff --git a/pkg/tcpip/stack/packet_buffer_list.go b/pkg/tcpip/stack/packet_buffer_list.go
new file mode 100644
index 000000000..ce7057d6b
--- /dev/null
+++ b/pkg/tcpip/stack/packet_buffer_list.go
@@ -0,0 +1,221 @@
+package stack
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type PacketBufferElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (PacketBufferElementMapper) linkerFor(elem *PacketBuffer) *PacketBuffer { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type PacketBufferList struct {
+ head *PacketBuffer
+ tail *PacketBuffer
+}
+
+// Reset resets list l to the empty state.
+func (l *PacketBufferList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+//
+//go:nosplit
+func (l *PacketBufferList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+//
+//go:nosplit
+func (l *PacketBufferList) Front() *PacketBuffer {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+//
+//go:nosplit
+func (l *PacketBufferList) Back() *PacketBuffer {
+ return l.tail
+}
+
+// Len returns the number of elements in the list.
+//
+// NOTE: This is an O(n) operation.
+//
+//go:nosplit
+func (l *PacketBufferList) Len() (count int) {
+ for e := l.Front(); e != nil; e = (PacketBufferElementMapper{}.linkerFor(e)).Next() {
+ count++
+ }
+ return count
+}
+
+// PushFront inserts the element e at the front of list l.
+//
+//go:nosplit
+func (l *PacketBufferList) PushFront(e *PacketBuffer) {
+ linker := PacketBufferElementMapper{}.linkerFor(e)
+ linker.SetNext(l.head)
+ linker.SetPrev(nil)
+ if l.head != nil {
+ PacketBufferElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+//
+//go:nosplit
+func (l *PacketBufferList) PushBack(e *PacketBuffer) {
+ linker := PacketBufferElementMapper{}.linkerFor(e)
+ linker.SetNext(nil)
+ linker.SetPrev(l.tail)
+ if l.tail != nil {
+ PacketBufferElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+//
+//go:nosplit
+func (l *PacketBufferList) PushBackList(m *PacketBufferList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ PacketBufferElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ PacketBufferElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+//
+//go:nosplit
+func (l *PacketBufferList) InsertAfter(b, e *PacketBuffer) {
+ bLinker := PacketBufferElementMapper{}.linkerFor(b)
+ eLinker := PacketBufferElementMapper{}.linkerFor(e)
+
+ a := bLinker.Next()
+
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ bLinker.SetNext(e)
+
+ if a != nil {
+ PacketBufferElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+//
+//go:nosplit
+func (l *PacketBufferList) InsertBefore(a, e *PacketBuffer) {
+ aLinker := PacketBufferElementMapper{}.linkerFor(a)
+ eLinker := PacketBufferElementMapper{}.linkerFor(e)
+
+ b := aLinker.Prev()
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ aLinker.SetPrev(e)
+
+ if b != nil {
+ PacketBufferElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+//
+//go:nosplit
+func (l *PacketBufferList) Remove(e *PacketBuffer) {
+ linker := PacketBufferElementMapper{}.linkerFor(e)
+ prev := linker.Prev()
+ next := linker.Next()
+
+ if prev != nil {
+ PacketBufferElementMapper{}.linkerFor(prev).SetNext(next)
+ } else if l.head == e {
+ l.head = next
+ }
+
+ if next != nil {
+ PacketBufferElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else if l.tail == e {
+ l.tail = prev
+ }
+
+ linker.SetNext(nil)
+ linker.SetPrev(nil)
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type PacketBufferEntry struct {
+ next *PacketBuffer
+ prev *PacketBuffer
+}
+
+// Next returns the entry that follows e in the list.
+//
+//go:nosplit
+func (e *PacketBufferEntry) Next() *PacketBuffer {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *PacketBufferEntry) Prev() *PacketBuffer {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+//
+//go:nosplit
+func (e *PacketBufferEntry) SetNext(elem *PacketBuffer) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *PacketBufferEntry) SetPrev(elem *PacketBuffer) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go
deleted file mode 100644
index 87b023445..000000000
--- a/pkg/tcpip/stack/packet_buffer_test.go
+++ /dev/null
@@ -1,675 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at //
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package stack
-
-import (
- "bytes"
- "fmt"
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
-)
-
-func TestPacketHeaderPush(t *testing.T) {
- for _, test := range []struct {
- name string
- reserved int
- link []byte
- network []byte
- transport []byte
- data []byte
- }{
- {
- name: "construct empty packet",
- },
- {
- name: "construct link header only packet",
- reserved: 60,
- link: makeView(10),
- },
- {
- name: "construct link and network header only packet",
- reserved: 60,
- link: makeView(10),
- network: makeView(20),
- },
- {
- name: "construct header only packet",
- reserved: 60,
- link: makeView(10),
- network: makeView(20),
- transport: makeView(30),
- },
- {
- name: "construct data only packet",
- data: makeView(40),
- },
- {
- name: "construct L3 packet",
- reserved: 60,
- network: makeView(20),
- transport: makeView(30),
- data: makeView(40),
- },
- {
- name: "construct L2 packet",
- reserved: 60,
- link: makeView(10),
- network: makeView(20),
- transport: makeView(30),
- data: makeView(40),
- },
- } {
- t.Run(test.name, func(t *testing.T) {
- pk := NewPacketBuffer(PacketBufferOptions{
- ReserveHeaderBytes: test.reserved,
- // Make a copy of data to make sure our truth data won't be taint by
- // PacketBuffer.
- Data: buffer.NewViewFromBytes(test.data).ToVectorisedView(),
- })
-
- allHdrSize := len(test.link) + len(test.network) + len(test.transport)
-
- // Check the initial values for packet.
- checkInitialPacketBuffer(t, pk, PacketBufferOptions{
- ReserveHeaderBytes: test.reserved,
- Data: buffer.View(test.data).ToVectorisedView(),
- })
-
- // Push headers.
- if v := test.transport; len(v) > 0 {
- copy(pk.TransportHeader().Push(len(v)), v)
- }
- if v := test.network; len(v) > 0 {
- copy(pk.NetworkHeader().Push(len(v)), v)
- }
- if v := test.link; len(v) > 0 {
- copy(pk.LinkHeader().Push(len(v)), v)
- }
-
- // Check the after values for packet.
- if got, want := pk.ReservedHeaderBytes(), test.reserved; got != want {
- t.Errorf("After pk.ReservedHeaderBytes() = %d, want %d", got, want)
- }
- if got, want := pk.AvailableHeaderBytes(), test.reserved-allHdrSize; got != want {
- t.Errorf("After pk.AvailableHeaderBytes() = %d, want %d", got, want)
- }
- if got, want := pk.HeaderSize(), allHdrSize; got != want {
- t.Errorf("After pk.HeaderSize() = %d, want %d", got, want)
- }
- if got, want := pk.Size(), allHdrSize+len(test.data); got != want {
- t.Errorf("After pk.Size() = %d, want %d", got, want)
- }
- // Check the after state.
- checkPacketContents(t, "After ", pk, packetContents{
- link: test.link,
- network: test.network,
- transport: test.transport,
- data: test.data,
- })
- })
- }
-}
-
-func TestPacketBufferClone(t *testing.T) {
- data := concatViews(makeView(20), makeView(30), makeView(40))
- pk := NewPacketBuffer(PacketBufferOptions{
- // Make a copy of data to make sure our truth data won't be taint by
- // PacketBuffer.
- Data: buffer.NewViewFromBytes(data).ToVectorisedView(),
- })
-
- bytesToDelete := 30
- originalSize := data.Size()
-
- clonedPks := []*PacketBuffer{
- pk.Clone(),
- pk.CloneToInbound(),
- }
- pk.Data().DeleteFront(bytesToDelete)
- if got, want := pk.Data().Size(), originalSize-bytesToDelete; got != want {
- t.Errorf("original packet was not changed: size expected = %d, got = %d", want, got)
- }
- for _, clonedPk := range clonedPks {
- if got := clonedPk.Data().Size(); got != originalSize {
- t.Errorf("cloned packet should not be modified: expected size = %d, got = %d", originalSize, got)
- }
- }
-}
-
-func TestPacketHeaderConsume(t *testing.T) {
- for _, test := range []struct {
- name string
- data []byte
- link int
- network int
- transport int
- }{
- {
- name: "parse L2 packet",
- data: concatViews(makeView(10), makeView(20), makeView(30), makeView(40)),
- link: 10,
- network: 20,
- transport: 30,
- },
- {
- name: "parse L3 packet",
- data: concatViews(makeView(20), makeView(30), makeView(40)),
- network: 20,
- transport: 30,
- },
- } {
- t.Run(test.name, func(t *testing.T) {
- pk := NewPacketBuffer(PacketBufferOptions{
- // Make a copy of data to make sure our truth data won't be taint by
- // PacketBuffer.
- Data: buffer.NewViewFromBytes(test.data).ToVectorisedView(),
- })
-
- // Check the initial values for packet.
- checkInitialPacketBuffer(t, pk, PacketBufferOptions{
- Data: buffer.View(test.data).ToVectorisedView(),
- })
-
- // Consume headers.
- if size := test.link; size > 0 {
- if _, ok := pk.LinkHeader().Consume(size); !ok {
- t.Fatalf("pk.LinkHeader().Consume() = false, want true")
- }
- }
- if size := test.network; size > 0 {
- if _, ok := pk.NetworkHeader().Consume(size); !ok {
- t.Fatalf("pk.NetworkHeader().Consume() = false, want true")
- }
- }
- if size := test.transport; size > 0 {
- if _, ok := pk.TransportHeader().Consume(size); !ok {
- t.Fatalf("pk.TransportHeader().Consume() = false, want true")
- }
- }
-
- allHdrSize := test.link + test.network + test.transport
-
- // Check the after values for packet.
- if got, want := pk.ReservedHeaderBytes(), 0; got != want {
- t.Errorf("After pk.ReservedHeaderBytes() = %d, want %d", got, want)
- }
- if got, want := pk.AvailableHeaderBytes(), 0; got != want {
- t.Errorf("After pk.AvailableHeaderBytes() = %d, want %d", got, want)
- }
- if got, want := pk.HeaderSize(), allHdrSize; got != want {
- t.Errorf("After pk.HeaderSize() = %d, want %d", got, want)
- }
- if got, want := pk.Size(), len(test.data); got != want {
- t.Errorf("After pk.Size() = %d, want %d", got, want)
- }
- // Check the after state of pk.
- checkPacketContents(t, "After ", pk, packetContents{
- link: test.data[:test.link],
- network: test.data[test.link:][:test.network],
- transport: test.data[test.link+test.network:][:test.transport],
- data: test.data[allHdrSize:],
- })
- })
- }
-}
-
-func TestPacketHeaderConsumeDataTooShort(t *testing.T) {
- data := makeView(10)
-
- pk := NewPacketBuffer(PacketBufferOptions{
- // Make a copy of data to make sure our truth data won't be taint by
- // PacketBuffer.
- Data: buffer.NewViewFromBytes(data).ToVectorisedView(),
- })
-
- // Consume should fail if pkt.Data is too short.
- if _, ok := pk.LinkHeader().Consume(11); ok {
- t.Fatalf("pk.LinkHeader().Consume() = _, true; want _, false")
- }
- if _, ok := pk.NetworkHeader().Consume(11); ok {
- t.Fatalf("pk.NetworkHeader().Consume() = _, true; want _, false")
- }
- if _, ok := pk.TransportHeader().Consume(11); ok {
- t.Fatalf("pk.TransportHeader().Consume() = _, true; want _, false")
- }
-
- // Check packet should look the same as initial packet.
- checkInitialPacketBuffer(t, pk, PacketBufferOptions{
- Data: buffer.View(data).ToVectorisedView(),
- })
-}
-
-// This is a very obscure use-case seen in the code that verifies packets
-// before sending them out. It tries to parse the headers to verify.
-// PacketHeader was initially not designed to mix Push() and Consume(), but it
-// works and it's been relied upon. Include a test here.
-func TestPacketHeaderPushConsumeMixed(t *testing.T) {
- link := makeView(10)
- network := makeView(20)
- data := makeView(30)
-
- initData := append([]byte(nil), network...)
- initData = append(initData, data...)
- pk := NewPacketBuffer(PacketBufferOptions{
- ReserveHeaderBytes: len(link),
- Data: buffer.NewViewFromBytes(initData).ToVectorisedView(),
- })
-
- // 1. Consume network header
- gotNetwork, ok := pk.NetworkHeader().Consume(len(network))
- if !ok {
- t.Fatalf("pk.NetworkHeader().Consume(%d) = _, false; want _, true", len(network))
- }
- checkViewEqual(t, "gotNetwork", gotNetwork, network)
-
- // 2. Push link header
- copy(pk.LinkHeader().Push(len(link)), link)
-
- checkPacketContents(t, "" /* prefix */, pk, packetContents{
- link: link,
- network: network,
- data: data,
- })
-}
-
-func TestPacketHeaderPushConsumeMixedTooLong(t *testing.T) {
- link := makeView(10)
- network := makeView(20)
- data := makeView(30)
-
- initData := concatViews(network, data)
- pk := NewPacketBuffer(PacketBufferOptions{
- ReserveHeaderBytes: len(link),
- Data: buffer.NewViewFromBytes(initData).ToVectorisedView(),
- })
-
- // 1. Push link header
- copy(pk.LinkHeader().Push(len(link)), link)
-
- checkPacketContents(t, "" /* prefix */, pk, packetContents{
- link: link,
- data: initData,
- })
-
- // 2. Consume network header, with a number of bytes too large.
- gotNetwork, ok := pk.NetworkHeader().Consume(len(initData) + 1)
- if ok {
- t.Fatalf("pk.NetworkHeader().Consume(%d) = %q, true; want _, false", len(initData)+1, gotNetwork)
- }
-
- checkPacketContents(t, "" /* prefix */, pk, packetContents{
- link: link,
- data: initData,
- })
-}
-
-func TestPacketHeaderPushCalledAtMostOnce(t *testing.T) {
- const headerSize = 10
-
- pk := NewPacketBuffer(PacketBufferOptions{
- ReserveHeaderBytes: headerSize * int(numHeaderType),
- })
-
- for _, h := range []PacketHeader{
- pk.TransportHeader(),
- pk.NetworkHeader(),
- pk.LinkHeader(),
- } {
- t.Run("PushedTwice/"+h.typ.String(), func(t *testing.T) {
- h.Push(headerSize)
-
- defer func() { recover() }()
- h.Push(headerSize)
- t.Fatal("Second push should have panicked")
- })
- }
-}
-
-func TestPacketHeaderConsumeCalledAtMostOnce(t *testing.T) {
- const headerSize = 10
-
- pk := NewPacketBuffer(PacketBufferOptions{
- Data: makeView(headerSize * int(numHeaderType)).ToVectorisedView(),
- })
-
- for _, h := range []PacketHeader{
- pk.LinkHeader(),
- pk.NetworkHeader(),
- pk.TransportHeader(),
- } {
- t.Run("ConsumedTwice/"+h.typ.String(), func(t *testing.T) {
- if _, ok := h.Consume(headerSize); !ok {
- t.Fatal("First consume should succeed")
- }
-
- defer func() { recover() }()
- h.Consume(headerSize)
- t.Fatal("Second consume should have panicked")
- })
- }
-}
-
-func TestPacketHeaderPushThenConsumePanics(t *testing.T) {
- const headerSize = 10
-
- pk := NewPacketBuffer(PacketBufferOptions{
- ReserveHeaderBytes: headerSize * int(numHeaderType),
- })
-
- for _, h := range []PacketHeader{
- pk.TransportHeader(),
- pk.NetworkHeader(),
- pk.LinkHeader(),
- } {
- t.Run(h.typ.String(), func(t *testing.T) {
- h.Push(headerSize)
-
- defer func() { recover() }()
- h.Consume(headerSize)
- t.Fatal("Consume should have panicked")
- })
- }
-}
-
-func TestPacketHeaderConsumeThenPushPanics(t *testing.T) {
- const headerSize = 10
-
- pk := NewPacketBuffer(PacketBufferOptions{
- Data: makeView(headerSize * int(numHeaderType)).ToVectorisedView(),
- })
-
- for _, h := range []PacketHeader{
- pk.LinkHeader(),
- pk.NetworkHeader(),
- pk.TransportHeader(),
- } {
- t.Run(h.typ.String(), func(t *testing.T) {
- h.Consume(headerSize)
-
- defer func() { recover() }()
- h.Push(headerSize)
- t.Fatal("Push should have panicked")
- })
- }
-}
-
-func TestPacketBufferData(t *testing.T) {
- for _, tc := range []struct {
- name string
- makePkt func(*testing.T) *PacketBuffer
- data string
- }{
- {
- name: "inbound packet",
- makePkt: func(*testing.T) *PacketBuffer {
- pkt := NewPacketBuffer(PacketBufferOptions{
- Data: vv("aabbbbccccccDATA"),
- })
- pkt.LinkHeader().Consume(2)
- pkt.NetworkHeader().Consume(4)
- pkt.TransportHeader().Consume(6)
- return pkt
- },
- data: "DATA",
- },
- {
- name: "outbound packet",
- makePkt: func(*testing.T) *PacketBuffer {
- pkt := NewPacketBuffer(PacketBufferOptions{
- ReserveHeaderBytes: 12,
- Data: vv("DATA"),
- })
- copy(pkt.TransportHeader().Push(6), []byte("cccccc"))
- copy(pkt.NetworkHeader().Push(4), []byte("bbbb"))
- copy(pkt.LinkHeader().Push(2), []byte("aa"))
- return pkt
- },
- data: "DATA",
- },
- } {
- t.Run(tc.name, func(t *testing.T) {
- // PullUp
- for _, n := range []int{1, len(tc.data)} {
- t.Run(fmt.Sprintf("PullUp%d", n), func(t *testing.T) {
- pkt := tc.makePkt(t)
- v, ok := pkt.Data().PullUp(n)
- wantV := []byte(tc.data)[:n]
- if !ok || !bytes.Equal(v, wantV) {
- t.Errorf("pkt.Data().PullUp(%d) = %q, %t; want %q, true", n, v, ok, wantV)
- }
- })
- }
- t.Run("PullUpOutOfBounds", func(t *testing.T) {
- n := len(tc.data) + 1
- pkt := tc.makePkt(t)
- v, ok := pkt.Data().PullUp(n)
- if ok || v != nil {
- t.Errorf("pkt.Data().PullUp(%d) = %q, %t; want nil, false", n, v, ok)
- }
- })
-
- // DeleteFront
- for _, n := range []int{1, len(tc.data)} {
- t.Run(fmt.Sprintf("DeleteFront%d", n), func(t *testing.T) {
- pkt := tc.makePkt(t)
- pkt.Data().DeleteFront(n)
-
- checkData(t, pkt, []byte(tc.data)[n:])
- })
- }
-
- // CapLength
- for _, n := range []int{0, 1, len(tc.data)} {
- t.Run(fmt.Sprintf("CapLength%d", n), func(t *testing.T) {
- pkt := tc.makePkt(t)
- pkt.Data().CapLength(n)
-
- want := []byte(tc.data)
- if n < len(want) {
- want = want[:n]
- }
- checkData(t, pkt, want)
- })
- }
-
- // Views
- t.Run("Views", func(t *testing.T) {
- pkt := tc.makePkt(t)
- checkData(t, pkt, []byte(tc.data))
- })
-
- // AppendView
- t.Run("AppendView", func(t *testing.T) {
- s := "APPEND"
-
- pkt := tc.makePkt(t)
- pkt.Data().AppendView(buffer.View(s))
-
- checkData(t, pkt, []byte(tc.data+s))
- })
-
- // ReadFromVV
- for _, n := range []int{0, 1, 2, 7, 10, 14, 20} {
- t.Run(fmt.Sprintf("ReadFromVV%d", n), func(t *testing.T) {
- s := "TO READ"
- srcVV := vv(s, s)
- s += s
-
- pkt := tc.makePkt(t)
- pkt.Data().ReadFromVV(&srcVV, n)
-
- if n < len(s) {
- s = s[:n]
- }
- checkData(t, pkt, []byte(tc.data+s))
- })
- }
-
- // ExtractVV
- t.Run("ExtractVV", func(t *testing.T) {
- pkt := tc.makePkt(t)
- extractedVV := pkt.Data().ExtractVV()
-
- got := extractedVV.ToOwnedView()
- want := []byte(tc.data)
- if !bytes.Equal(got, want) {
- t.Errorf("pkt.Data().ExtractVV().ToOwnedView() = %q, want %q", got, want)
- }
- })
- })
- }
-}
-
-type packetContents struct {
- link buffer.View
- network buffer.View
- transport buffer.View
- data buffer.View
-}
-
-func checkPacketContents(t *testing.T, prefix string, pk *PacketBuffer, want packetContents) {
- t.Helper()
- // Headers.
- checkPacketHeader(t, prefix+"pk.LinkHeader", pk.LinkHeader(), want.link)
- checkPacketHeader(t, prefix+"pk.NetworkHeader", pk.NetworkHeader(), want.network)
- checkPacketHeader(t, prefix+"pk.TransportHeader", pk.TransportHeader(), want.transport)
- // Data.
- checkData(t, pk, want.data)
- // Whole packet.
- checkViewEqual(t, prefix+"pk.Views()",
- concatViews(pk.Views()...),
- concatViews(want.link, want.network, want.transport, want.data))
- // PayloadSince.
- checkViewEqual(t, prefix+"PayloadSince(LinkHeader)",
- PayloadSince(pk.LinkHeader()),
- concatViews(want.link, want.network, want.transport, want.data))
- checkViewEqual(t, prefix+"PayloadSince(NetworkHeader)",
- PayloadSince(pk.NetworkHeader()),
- concatViews(want.network, want.transport, want.data))
- checkViewEqual(t, prefix+"PayloadSince(TransportHeader)",
- PayloadSince(pk.TransportHeader()),
- concatViews(want.transport, want.data))
-}
-
-func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferOptions) {
- t.Helper()
- reserved := opts.ReserveHeaderBytes
- if got, want := pk.ReservedHeaderBytes(), reserved; got != want {
- t.Errorf("Initial pk.ReservedHeaderBytes() = %d, want %d", got, want)
- }
- if got, want := pk.AvailableHeaderBytes(), reserved; got != want {
- t.Errorf("Initial pk.AvailableHeaderBytes() = %d, want %d", got, want)
- }
- if got, want := pk.HeaderSize(), 0; got != want {
- t.Errorf("Initial pk.HeaderSize() = %d, want %d", got, want)
- }
- data := opts.Data.ToView()
- if got, want := pk.Size(), len(data); got != want {
- t.Errorf("Initial pk.Size() = %d, want %d", got, want)
- }
- checkPacketContents(t, "Initial ", pk, packetContents{
- data: data,
- })
-}
-
-func checkPacketHeader(t *testing.T, name string, h PacketHeader, want []byte) {
- t.Helper()
- checkViewEqual(t, name+".View()", h.View(), want)
-}
-
-func checkViewEqual(t *testing.T, what string, got, want buffer.View) {
- t.Helper()
- if !bytes.Equal(got, want) {
- t.Errorf("%s = %x, want %x", what, got, want)
- }
-}
-
-func checkData(t *testing.T, pkt *PacketBuffer, want []byte) {
- t.Helper()
- if got := concatViews(pkt.Data().Views()...); !bytes.Equal(got, want) {
- t.Errorf("pkt.Data().Views() = 0x%x, want 0x%x", got, want)
- }
- if got := pkt.Data().Size(); got != len(want) {
- t.Errorf("pkt.Data().Size() = %d, want %d", got, len(want))
- }
-
- t.Run("AsRange", func(t *testing.T) {
- // Full range
- checkRange(t, pkt.Data().AsRange(), want)
-
- // SubRange
- for _, off := range []int{0, 1, len(want), len(want) + 1} {
- t.Run(fmt.Sprintf("SubRange%d", off), func(t *testing.T) {
- // Empty when off is greater than the size of range.
- var sub []byte
- if off < len(want) {
- sub = want[off:]
- }
- checkRange(t, pkt.Data().AsRange().SubRange(off), sub)
- })
- }
-
- // Capped
- for _, n := range []int{0, 1, len(want), len(want) + 1} {
- t.Run(fmt.Sprintf("Capped%d", n), func(t *testing.T) {
- sub := want
- if n < len(sub) {
- sub = sub[:n]
- }
- checkRange(t, pkt.Data().AsRange().Capped(n), sub)
- })
- }
- })
-}
-
-func checkRange(t *testing.T, r Range, data []byte) {
- if got, want := r.Size(), len(data); got != want {
- t.Errorf("r.Size() = %d, want %d", got, want)
- }
- if got := r.AsView(); !bytes.Equal(got, data) {
- t.Errorf("r.AsView() = %x, want %x", got, data)
- }
- if got := r.ToOwnedView(); !bytes.Equal(got, data) {
- t.Errorf("r.ToOwnedView() = %x, want %x", got, data)
- }
- if got, want := r.Checksum(), header.Checksum(data, 0 /* initial */); got != want {
- t.Errorf("r.Checksum() = %x, want %x", got, want)
- }
-}
-
-func vv(pieces ...string) buffer.VectorisedView {
- var views []buffer.View
- var size int
- for _, p := range pieces {
- v := buffer.View([]byte(p))
- size += len(v)
- views = append(views, v)
- }
- return buffer.NewVectorisedView(size, views)
-}
-
-func makeView(size int) buffer.View {
- b := byte(size)
- return bytes.Repeat([]byte{b}, size)
-}
-
-func concatViews(views ...buffer.View) buffer.View {
- var all buffer.View
- for _, v := range views {
- all = append(all, v...)
- }
- return all
-}
diff --git a/pkg/tcpip/stack/stack_state_autogen.go b/pkg/tcpip/stack/stack_state_autogen.go
new file mode 100644
index 000000000..cfcb6488b
--- /dev/null
+++ b/pkg/tcpip/stack/stack_state_autogen.go
@@ -0,0 +1,1288 @@
+// automatically generated by stateify.
+
+package stack
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+func (t *tuple) StateTypeName() string {
+ return "pkg/tcpip/stack.tuple"
+}
+
+func (t *tuple) StateFields() []string {
+ return []string{
+ "tupleEntry",
+ "tupleID",
+ "conn",
+ "direction",
+ }
+}
+
+func (t *tuple) beforeSave() {}
+
+// +checklocksignore
+func (t *tuple) StateSave(stateSinkObject state.Sink) {
+ t.beforeSave()
+ stateSinkObject.Save(0, &t.tupleEntry)
+ stateSinkObject.Save(1, &t.tupleID)
+ stateSinkObject.Save(2, &t.conn)
+ stateSinkObject.Save(3, &t.direction)
+}
+
+func (t *tuple) afterLoad() {}
+
+// +checklocksignore
+func (t *tuple) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &t.tupleEntry)
+ stateSourceObject.Load(1, &t.tupleID)
+ stateSourceObject.Load(2, &t.conn)
+ stateSourceObject.Load(3, &t.direction)
+}
+
+func (ti *tupleID) StateTypeName() string {
+ return "pkg/tcpip/stack.tupleID"
+}
+
+func (ti *tupleID) StateFields() []string {
+ return []string{
+ "srcAddr",
+ "srcPort",
+ "dstAddr",
+ "dstPort",
+ "transProto",
+ "netProto",
+ }
+}
+
+func (ti *tupleID) beforeSave() {}
+
+// +checklocksignore
+func (ti *tupleID) StateSave(stateSinkObject state.Sink) {
+ ti.beforeSave()
+ stateSinkObject.Save(0, &ti.srcAddr)
+ stateSinkObject.Save(1, &ti.srcPort)
+ stateSinkObject.Save(2, &ti.dstAddr)
+ stateSinkObject.Save(3, &ti.dstPort)
+ stateSinkObject.Save(4, &ti.transProto)
+ stateSinkObject.Save(5, &ti.netProto)
+}
+
+func (ti *tupleID) afterLoad() {}
+
+// +checklocksignore
+func (ti *tupleID) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &ti.srcAddr)
+ stateSourceObject.Load(1, &ti.srcPort)
+ stateSourceObject.Load(2, &ti.dstAddr)
+ stateSourceObject.Load(3, &ti.dstPort)
+ stateSourceObject.Load(4, &ti.transProto)
+ stateSourceObject.Load(5, &ti.netProto)
+}
+
+func (cn *conn) StateTypeName() string {
+ return "pkg/tcpip/stack.conn"
+}
+
+func (cn *conn) StateFields() []string {
+ return []string{
+ "original",
+ "reply",
+ "manip",
+ "tcbHook",
+ "tcb",
+ "lastUsed",
+ }
+}
+
+func (cn *conn) beforeSave() {}
+
+// +checklocksignore
+func (cn *conn) StateSave(stateSinkObject state.Sink) {
+ cn.beforeSave()
+ var lastUsedValue unixTime
+ lastUsedValue = cn.saveLastUsed()
+ stateSinkObject.SaveValue(5, lastUsedValue)
+ stateSinkObject.Save(0, &cn.original)
+ stateSinkObject.Save(1, &cn.reply)
+ stateSinkObject.Save(2, &cn.manip)
+ stateSinkObject.Save(3, &cn.tcbHook)
+ stateSinkObject.Save(4, &cn.tcb)
+}
+
+func (cn *conn) afterLoad() {}
+
+// +checklocksignore
+func (cn *conn) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &cn.original)
+ stateSourceObject.Load(1, &cn.reply)
+ stateSourceObject.Load(2, &cn.manip)
+ stateSourceObject.Load(3, &cn.tcbHook)
+ stateSourceObject.Load(4, &cn.tcb)
+ stateSourceObject.LoadValue(5, new(unixTime), func(y interface{}) { cn.loadLastUsed(y.(unixTime)) })
+}
+
+func (ct *ConnTrack) StateTypeName() string {
+ return "pkg/tcpip/stack.ConnTrack"
+}
+
+func (ct *ConnTrack) StateFields() []string {
+ return []string{
+ "seed",
+ "buckets",
+ }
+}
+
+// +checklocksignore
+func (ct *ConnTrack) StateSave(stateSinkObject state.Sink) {
+ ct.beforeSave()
+ stateSinkObject.Save(0, &ct.seed)
+ stateSinkObject.Save(1, &ct.buckets)
+}
+
+func (ct *ConnTrack) afterLoad() {}
+
+// +checklocksignore
+func (ct *ConnTrack) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &ct.seed)
+ stateSourceObject.Load(1, &ct.buckets)
+}
+
+func (b *bucket) StateTypeName() string {
+ return "pkg/tcpip/stack.bucket"
+}
+
+func (b *bucket) StateFields() []string {
+ return []string{
+ "tuples",
+ }
+}
+
+func (b *bucket) beforeSave() {}
+
+// +checklocksignore
+func (b *bucket) StateSave(stateSinkObject state.Sink) {
+ b.beforeSave()
+ stateSinkObject.Save(0, &b.tuples)
+}
+
+func (b *bucket) afterLoad() {}
+
+// +checklocksignore
+func (b *bucket) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &b.tuples)
+}
+
+func (u *unixTime) StateTypeName() string {
+ return "pkg/tcpip/stack.unixTime"
+}
+
+func (u *unixTime) StateFields() []string {
+ return []string{
+ "second",
+ "nano",
+ }
+}
+
+func (u *unixTime) beforeSave() {}
+
+// +checklocksignore
+func (u *unixTime) StateSave(stateSinkObject state.Sink) {
+ u.beforeSave()
+ stateSinkObject.Save(0, &u.second)
+ stateSinkObject.Save(1, &u.nano)
+}
+
+func (u *unixTime) afterLoad() {}
+
+// +checklocksignore
+func (u *unixTime) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &u.second)
+ stateSourceObject.Load(1, &u.nano)
+}
+
+func (it *IPTables) StateTypeName() string {
+ return "pkg/tcpip/stack.IPTables"
+}
+
+func (it *IPTables) StateFields() []string {
+ return []string{
+ "mu",
+ "v4Tables",
+ "v6Tables",
+ "modified",
+ "priorities",
+ "connections",
+ "reaperDone",
+ }
+}
+
+// +checklocksignore
+func (it *IPTables) StateSave(stateSinkObject state.Sink) {
+ it.beforeSave()
+ stateSinkObject.Save(0, &it.mu)
+ stateSinkObject.Save(1, &it.v4Tables)
+ stateSinkObject.Save(2, &it.v6Tables)
+ stateSinkObject.Save(3, &it.modified)
+ stateSinkObject.Save(4, &it.priorities)
+ stateSinkObject.Save(5, &it.connections)
+ stateSinkObject.Save(6, &it.reaperDone)
+}
+
+// +checklocksignore
+func (it *IPTables) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &it.mu)
+ stateSourceObject.Load(1, &it.v4Tables)
+ stateSourceObject.Load(2, &it.v6Tables)
+ stateSourceObject.Load(3, &it.modified)
+ stateSourceObject.Load(4, &it.priorities)
+ stateSourceObject.Load(5, &it.connections)
+ stateSourceObject.Load(6, &it.reaperDone)
+ stateSourceObject.AfterLoad(it.afterLoad)
+}
+
+func (table *Table) StateTypeName() string {
+ return "pkg/tcpip/stack.Table"
+}
+
+func (table *Table) StateFields() []string {
+ return []string{
+ "Rules",
+ "BuiltinChains",
+ "Underflows",
+ }
+}
+
+func (table *Table) beforeSave() {}
+
+// +checklocksignore
+func (table *Table) StateSave(stateSinkObject state.Sink) {
+ table.beforeSave()
+ stateSinkObject.Save(0, &table.Rules)
+ stateSinkObject.Save(1, &table.BuiltinChains)
+ stateSinkObject.Save(2, &table.Underflows)
+}
+
+func (table *Table) afterLoad() {}
+
+// +checklocksignore
+func (table *Table) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &table.Rules)
+ stateSourceObject.Load(1, &table.BuiltinChains)
+ stateSourceObject.Load(2, &table.Underflows)
+}
+
+func (r *Rule) StateTypeName() string {
+ return "pkg/tcpip/stack.Rule"
+}
+
+func (r *Rule) StateFields() []string {
+ return []string{
+ "Filter",
+ "Matchers",
+ "Target",
+ }
+}
+
+func (r *Rule) beforeSave() {}
+
+// +checklocksignore
+func (r *Rule) StateSave(stateSinkObject state.Sink) {
+ r.beforeSave()
+ stateSinkObject.Save(0, &r.Filter)
+ stateSinkObject.Save(1, &r.Matchers)
+ stateSinkObject.Save(2, &r.Target)
+}
+
+func (r *Rule) afterLoad() {}
+
+// +checklocksignore
+func (r *Rule) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &r.Filter)
+ stateSourceObject.Load(1, &r.Matchers)
+ stateSourceObject.Load(2, &r.Target)
+}
+
+func (fl *IPHeaderFilter) StateTypeName() string {
+ return "pkg/tcpip/stack.IPHeaderFilter"
+}
+
+func (fl *IPHeaderFilter) StateFields() []string {
+ return []string{
+ "Protocol",
+ "CheckProtocol",
+ "Dst",
+ "DstMask",
+ "DstInvert",
+ "Src",
+ "SrcMask",
+ "SrcInvert",
+ "InputInterface",
+ "InputInterfaceMask",
+ "InputInterfaceInvert",
+ "OutputInterface",
+ "OutputInterfaceMask",
+ "OutputInterfaceInvert",
+ }
+}
+
+func (fl *IPHeaderFilter) beforeSave() {}
+
+// +checklocksignore
+func (fl *IPHeaderFilter) StateSave(stateSinkObject state.Sink) {
+ fl.beforeSave()
+ stateSinkObject.Save(0, &fl.Protocol)
+ stateSinkObject.Save(1, &fl.CheckProtocol)
+ stateSinkObject.Save(2, &fl.Dst)
+ stateSinkObject.Save(3, &fl.DstMask)
+ stateSinkObject.Save(4, &fl.DstInvert)
+ stateSinkObject.Save(5, &fl.Src)
+ stateSinkObject.Save(6, &fl.SrcMask)
+ stateSinkObject.Save(7, &fl.SrcInvert)
+ stateSinkObject.Save(8, &fl.InputInterface)
+ stateSinkObject.Save(9, &fl.InputInterfaceMask)
+ stateSinkObject.Save(10, &fl.InputInterfaceInvert)
+ stateSinkObject.Save(11, &fl.OutputInterface)
+ stateSinkObject.Save(12, &fl.OutputInterfaceMask)
+ stateSinkObject.Save(13, &fl.OutputInterfaceInvert)
+}
+
+func (fl *IPHeaderFilter) afterLoad() {}
+
+// +checklocksignore
+func (fl *IPHeaderFilter) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &fl.Protocol)
+ stateSourceObject.Load(1, &fl.CheckProtocol)
+ stateSourceObject.Load(2, &fl.Dst)
+ stateSourceObject.Load(3, &fl.DstMask)
+ stateSourceObject.Load(4, &fl.DstInvert)
+ stateSourceObject.Load(5, &fl.Src)
+ stateSourceObject.Load(6, &fl.SrcMask)
+ stateSourceObject.Load(7, &fl.SrcInvert)
+ stateSourceObject.Load(8, &fl.InputInterface)
+ stateSourceObject.Load(9, &fl.InputInterfaceMask)
+ stateSourceObject.Load(10, &fl.InputInterfaceInvert)
+ stateSourceObject.Load(11, &fl.OutputInterface)
+ stateSourceObject.Load(12, &fl.OutputInterfaceMask)
+ stateSourceObject.Load(13, &fl.OutputInterfaceInvert)
+}
+
+func (l *neighborEntryList) StateTypeName() string {
+ return "pkg/tcpip/stack.neighborEntryList"
+}
+
+func (l *neighborEntryList) StateFields() []string {
+ return []string{
+ "head",
+ "tail",
+ }
+}
+
+func (l *neighborEntryList) beforeSave() {}
+
+// +checklocksignore
+func (l *neighborEntryList) StateSave(stateSinkObject state.Sink) {
+ l.beforeSave()
+ stateSinkObject.Save(0, &l.head)
+ stateSinkObject.Save(1, &l.tail)
+}
+
+func (l *neighborEntryList) afterLoad() {}
+
+// +checklocksignore
+func (l *neighborEntryList) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &l.head)
+ stateSourceObject.Load(1, &l.tail)
+}
+
+func (e *neighborEntryEntry) StateTypeName() string {
+ return "pkg/tcpip/stack.neighborEntryEntry"
+}
+
+func (e *neighborEntryEntry) StateFields() []string {
+ return []string{
+ "next",
+ "prev",
+ }
+}
+
+func (e *neighborEntryEntry) beforeSave() {}
+
+// +checklocksignore
+func (e *neighborEntryEntry) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+ stateSinkObject.Save(0, &e.next)
+ stateSinkObject.Save(1, &e.prev)
+}
+
+func (e *neighborEntryEntry) afterLoad() {}
+
+// +checklocksignore
+func (e *neighborEntryEntry) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &e.next)
+ stateSourceObject.Load(1, &e.prev)
+}
+
+func (p *PacketBufferList) StateTypeName() string {
+ return "pkg/tcpip/stack.PacketBufferList"
+}
+
+func (p *PacketBufferList) StateFields() []string {
+ return []string{
+ "head",
+ "tail",
+ }
+}
+
+func (p *PacketBufferList) beforeSave() {}
+
+// +checklocksignore
+func (p *PacketBufferList) StateSave(stateSinkObject state.Sink) {
+ p.beforeSave()
+ stateSinkObject.Save(0, &p.head)
+ stateSinkObject.Save(1, &p.tail)
+}
+
+func (p *PacketBufferList) afterLoad() {}
+
+// +checklocksignore
+func (p *PacketBufferList) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &p.head)
+ stateSourceObject.Load(1, &p.tail)
+}
+
+func (e *PacketBufferEntry) StateTypeName() string {
+ return "pkg/tcpip/stack.PacketBufferEntry"
+}
+
+func (e *PacketBufferEntry) StateFields() []string {
+ return []string{
+ "next",
+ "prev",
+ }
+}
+
+func (e *PacketBufferEntry) beforeSave() {}
+
+// +checklocksignore
+func (e *PacketBufferEntry) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+ stateSinkObject.Save(0, &e.next)
+ stateSinkObject.Save(1, &e.prev)
+}
+
+func (e *PacketBufferEntry) afterLoad() {}
+
+// +checklocksignore
+func (e *PacketBufferEntry) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &e.next)
+ stateSourceObject.Load(1, &e.prev)
+}
+
+func (t *TransportEndpointID) StateTypeName() string {
+ return "pkg/tcpip/stack.TransportEndpointID"
+}
+
+func (t *TransportEndpointID) StateFields() []string {
+ return []string{
+ "LocalPort",
+ "LocalAddress",
+ "RemotePort",
+ "RemoteAddress",
+ }
+}
+
+func (t *TransportEndpointID) beforeSave() {}
+
+// +checklocksignore
+func (t *TransportEndpointID) StateSave(stateSinkObject state.Sink) {
+ t.beforeSave()
+ stateSinkObject.Save(0, &t.LocalPort)
+ stateSinkObject.Save(1, &t.LocalAddress)
+ stateSinkObject.Save(2, &t.RemotePort)
+ stateSinkObject.Save(3, &t.RemoteAddress)
+}
+
+func (t *TransportEndpointID) afterLoad() {}
+
+// +checklocksignore
+func (t *TransportEndpointID) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &t.LocalPort)
+ stateSourceObject.Load(1, &t.LocalAddress)
+ stateSourceObject.Load(2, &t.RemotePort)
+ stateSourceObject.Load(3, &t.RemoteAddress)
+}
+
+func (g *GSOType) StateTypeName() string {
+ return "pkg/tcpip/stack.GSOType"
+}
+
+func (g *GSOType) StateFields() []string {
+ return nil
+}
+
+func (g *GSO) StateTypeName() string {
+ return "pkg/tcpip/stack.GSO"
+}
+
+func (g *GSO) StateFields() []string {
+ return []string{
+ "Type",
+ "NeedsCsum",
+ "CsumOffset",
+ "MSS",
+ "L3HdrLen",
+ "MaxSize",
+ }
+}
+
+func (g *GSO) beforeSave() {}
+
+// +checklocksignore
+func (g *GSO) StateSave(stateSinkObject state.Sink) {
+ g.beforeSave()
+ stateSinkObject.Save(0, &g.Type)
+ stateSinkObject.Save(1, &g.NeedsCsum)
+ stateSinkObject.Save(2, &g.CsumOffset)
+ stateSinkObject.Save(3, &g.MSS)
+ stateSinkObject.Save(4, &g.L3HdrLen)
+ stateSinkObject.Save(5, &g.MaxSize)
+}
+
+func (g *GSO) afterLoad() {}
+
+// +checklocksignore
+func (g *GSO) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &g.Type)
+ stateSourceObject.Load(1, &g.NeedsCsum)
+ stateSourceObject.Load(2, &g.CsumOffset)
+ stateSourceObject.Load(3, &g.MSS)
+ stateSourceObject.Load(4, &g.L3HdrLen)
+ stateSourceObject.Load(5, &g.MaxSize)
+}
+
+func (t *TransportEndpointInfo) StateTypeName() string {
+ return "pkg/tcpip/stack.TransportEndpointInfo"
+}
+
+func (t *TransportEndpointInfo) StateFields() []string {
+ return []string{
+ "NetProto",
+ "TransProto",
+ "ID",
+ "BindNICID",
+ "BindAddr",
+ "RegisterNICID",
+ }
+}
+
+func (t *TransportEndpointInfo) beforeSave() {}
+
+// +checklocksignore
+func (t *TransportEndpointInfo) StateSave(stateSinkObject state.Sink) {
+ t.beforeSave()
+ stateSinkObject.Save(0, &t.NetProto)
+ stateSinkObject.Save(1, &t.TransProto)
+ stateSinkObject.Save(2, &t.ID)
+ stateSinkObject.Save(3, &t.BindNICID)
+ stateSinkObject.Save(4, &t.BindAddr)
+ stateSinkObject.Save(5, &t.RegisterNICID)
+}
+
+func (t *TransportEndpointInfo) afterLoad() {}
+
+// +checklocksignore
+func (t *TransportEndpointInfo) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &t.NetProto)
+ stateSourceObject.Load(1, &t.TransProto)
+ stateSourceObject.Load(2, &t.ID)
+ stateSourceObject.Load(3, &t.BindNICID)
+ stateSourceObject.Load(4, &t.BindAddr)
+ stateSourceObject.Load(5, &t.RegisterNICID)
+}
+
+func (t *TCPCubicState) StateTypeName() string {
+ return "pkg/tcpip/stack.TCPCubicState"
+}
+
+func (t *TCPCubicState) StateFields() []string {
+ return []string{
+ "WLastMax",
+ "WMax",
+ "T",
+ "TimeSinceLastCongestion",
+ "C",
+ "K",
+ "Beta",
+ "WC",
+ "WEst",
+ }
+}
+
+func (t *TCPCubicState) beforeSave() {}
+
+// +checklocksignore
+func (t *TCPCubicState) StateSave(stateSinkObject state.Sink) {
+ t.beforeSave()
+ stateSinkObject.Save(0, &t.WLastMax)
+ stateSinkObject.Save(1, &t.WMax)
+ stateSinkObject.Save(2, &t.T)
+ stateSinkObject.Save(3, &t.TimeSinceLastCongestion)
+ stateSinkObject.Save(4, &t.C)
+ stateSinkObject.Save(5, &t.K)
+ stateSinkObject.Save(6, &t.Beta)
+ stateSinkObject.Save(7, &t.WC)
+ stateSinkObject.Save(8, &t.WEst)
+}
+
+func (t *TCPCubicState) afterLoad() {}
+
+// +checklocksignore
+func (t *TCPCubicState) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &t.WLastMax)
+ stateSourceObject.Load(1, &t.WMax)
+ stateSourceObject.Load(2, &t.T)
+ stateSourceObject.Load(3, &t.TimeSinceLastCongestion)
+ stateSourceObject.Load(4, &t.C)
+ stateSourceObject.Load(5, &t.K)
+ stateSourceObject.Load(6, &t.Beta)
+ stateSourceObject.Load(7, &t.WC)
+ stateSourceObject.Load(8, &t.WEst)
+}
+
+func (t *TCPRACKState) StateTypeName() string {
+ return "pkg/tcpip/stack.TCPRACKState"
+}
+
+func (t *TCPRACKState) StateFields() []string {
+ return []string{
+ "XmitTime",
+ "EndSequence",
+ "FACK",
+ "RTT",
+ "Reord",
+ "DSACKSeen",
+ "ReoWnd",
+ "ReoWndIncr",
+ "ReoWndPersist",
+ "RTTSeq",
+ }
+}
+
+func (t *TCPRACKState) beforeSave() {}
+
+// +checklocksignore
+func (t *TCPRACKState) StateSave(stateSinkObject state.Sink) {
+ t.beforeSave()
+ stateSinkObject.Save(0, &t.XmitTime)
+ stateSinkObject.Save(1, &t.EndSequence)
+ stateSinkObject.Save(2, &t.FACK)
+ stateSinkObject.Save(3, &t.RTT)
+ stateSinkObject.Save(4, &t.Reord)
+ stateSinkObject.Save(5, &t.DSACKSeen)
+ stateSinkObject.Save(6, &t.ReoWnd)
+ stateSinkObject.Save(7, &t.ReoWndIncr)
+ stateSinkObject.Save(8, &t.ReoWndPersist)
+ stateSinkObject.Save(9, &t.RTTSeq)
+}
+
+func (t *TCPRACKState) afterLoad() {}
+
+// +checklocksignore
+func (t *TCPRACKState) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &t.XmitTime)
+ stateSourceObject.Load(1, &t.EndSequence)
+ stateSourceObject.Load(2, &t.FACK)
+ stateSourceObject.Load(3, &t.RTT)
+ stateSourceObject.Load(4, &t.Reord)
+ stateSourceObject.Load(5, &t.DSACKSeen)
+ stateSourceObject.Load(6, &t.ReoWnd)
+ stateSourceObject.Load(7, &t.ReoWndIncr)
+ stateSourceObject.Load(8, &t.ReoWndPersist)
+ stateSourceObject.Load(9, &t.RTTSeq)
+}
+
+func (t *TCPEndpointID) StateTypeName() string {
+ return "pkg/tcpip/stack.TCPEndpointID"
+}
+
+func (t *TCPEndpointID) StateFields() []string {
+ return []string{
+ "LocalPort",
+ "LocalAddress",
+ "RemotePort",
+ "RemoteAddress",
+ }
+}
+
+func (t *TCPEndpointID) beforeSave() {}
+
+// +checklocksignore
+func (t *TCPEndpointID) StateSave(stateSinkObject state.Sink) {
+ t.beforeSave()
+ stateSinkObject.Save(0, &t.LocalPort)
+ stateSinkObject.Save(1, &t.LocalAddress)
+ stateSinkObject.Save(2, &t.RemotePort)
+ stateSinkObject.Save(3, &t.RemoteAddress)
+}
+
+func (t *TCPEndpointID) afterLoad() {}
+
+// +checklocksignore
+func (t *TCPEndpointID) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &t.LocalPort)
+ stateSourceObject.Load(1, &t.LocalAddress)
+ stateSourceObject.Load(2, &t.RemotePort)
+ stateSourceObject.Load(3, &t.RemoteAddress)
+}
+
+func (t *TCPFastRecoveryState) StateTypeName() string {
+ return "pkg/tcpip/stack.TCPFastRecoveryState"
+}
+
+func (t *TCPFastRecoveryState) StateFields() []string {
+ return []string{
+ "Active",
+ "First",
+ "Last",
+ "MaxCwnd",
+ "HighRxt",
+ "RescueRxt",
+ }
+}
+
+func (t *TCPFastRecoveryState) beforeSave() {}
+
+// +checklocksignore
+func (t *TCPFastRecoveryState) StateSave(stateSinkObject state.Sink) {
+ t.beforeSave()
+ stateSinkObject.Save(0, &t.Active)
+ stateSinkObject.Save(1, &t.First)
+ stateSinkObject.Save(2, &t.Last)
+ stateSinkObject.Save(3, &t.MaxCwnd)
+ stateSinkObject.Save(4, &t.HighRxt)
+ stateSinkObject.Save(5, &t.RescueRxt)
+}
+
+func (t *TCPFastRecoveryState) afterLoad() {}
+
+// +checklocksignore
+func (t *TCPFastRecoveryState) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &t.Active)
+ stateSourceObject.Load(1, &t.First)
+ stateSourceObject.Load(2, &t.Last)
+ stateSourceObject.Load(3, &t.MaxCwnd)
+ stateSourceObject.Load(4, &t.HighRxt)
+ stateSourceObject.Load(5, &t.RescueRxt)
+}
+
+func (t *TCPReceiverState) StateTypeName() string {
+ return "pkg/tcpip/stack.TCPReceiverState"
+}
+
+func (t *TCPReceiverState) StateFields() []string {
+ return []string{
+ "RcvNxt",
+ "RcvAcc",
+ "RcvWndScale",
+ "PendingBufUsed",
+ }
+}
+
+func (t *TCPReceiverState) beforeSave() {}
+
+// +checklocksignore
+func (t *TCPReceiverState) StateSave(stateSinkObject state.Sink) {
+ t.beforeSave()
+ stateSinkObject.Save(0, &t.RcvNxt)
+ stateSinkObject.Save(1, &t.RcvAcc)
+ stateSinkObject.Save(2, &t.RcvWndScale)
+ stateSinkObject.Save(3, &t.PendingBufUsed)
+}
+
+func (t *TCPReceiverState) afterLoad() {}
+
+// +checklocksignore
+func (t *TCPReceiverState) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &t.RcvNxt)
+ stateSourceObject.Load(1, &t.RcvAcc)
+ stateSourceObject.Load(2, &t.RcvWndScale)
+ stateSourceObject.Load(3, &t.PendingBufUsed)
+}
+
+func (t *TCPRTTState) StateTypeName() string {
+ return "pkg/tcpip/stack.TCPRTTState"
+}
+
+func (t *TCPRTTState) StateFields() []string {
+ return []string{
+ "SRTT",
+ "RTTVar",
+ "SRTTInited",
+ }
+}
+
+func (t *TCPRTTState) beforeSave() {}
+
+// +checklocksignore
+func (t *TCPRTTState) StateSave(stateSinkObject state.Sink) {
+ t.beforeSave()
+ stateSinkObject.Save(0, &t.SRTT)
+ stateSinkObject.Save(1, &t.RTTVar)
+ stateSinkObject.Save(2, &t.SRTTInited)
+}
+
+func (t *TCPRTTState) afterLoad() {}
+
+// +checklocksignore
+func (t *TCPRTTState) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &t.SRTT)
+ stateSourceObject.Load(1, &t.RTTVar)
+ stateSourceObject.Load(2, &t.SRTTInited)
+}
+
+func (t *TCPSenderState) StateTypeName() string {
+ return "pkg/tcpip/stack.TCPSenderState"
+}
+
+func (t *TCPSenderState) StateFields() []string {
+ return []string{
+ "LastSendTime",
+ "DupAckCount",
+ "SndCwnd",
+ "Ssthresh",
+ "SndCAAckCount",
+ "Outstanding",
+ "SackedOut",
+ "SndWnd",
+ "SndUna",
+ "SndNxt",
+ "RTTMeasureSeqNum",
+ "RTTMeasureTime",
+ "Closed",
+ "RTO",
+ "RTTState",
+ "MaxPayloadSize",
+ "SndWndScale",
+ "MaxSentAck",
+ "FastRecovery",
+ "Cubic",
+ "RACKState",
+ }
+}
+
+func (t *TCPSenderState) beforeSave() {}
+
+// +checklocksignore
+func (t *TCPSenderState) StateSave(stateSinkObject state.Sink) {
+ t.beforeSave()
+ stateSinkObject.Save(0, &t.LastSendTime)
+ stateSinkObject.Save(1, &t.DupAckCount)
+ stateSinkObject.Save(2, &t.SndCwnd)
+ stateSinkObject.Save(3, &t.Ssthresh)
+ stateSinkObject.Save(4, &t.SndCAAckCount)
+ stateSinkObject.Save(5, &t.Outstanding)
+ stateSinkObject.Save(6, &t.SackedOut)
+ stateSinkObject.Save(7, &t.SndWnd)
+ stateSinkObject.Save(8, &t.SndUna)
+ stateSinkObject.Save(9, &t.SndNxt)
+ stateSinkObject.Save(10, &t.RTTMeasureSeqNum)
+ stateSinkObject.Save(11, &t.RTTMeasureTime)
+ stateSinkObject.Save(12, &t.Closed)
+ stateSinkObject.Save(13, &t.RTO)
+ stateSinkObject.Save(14, &t.RTTState)
+ stateSinkObject.Save(15, &t.MaxPayloadSize)
+ stateSinkObject.Save(16, &t.SndWndScale)
+ stateSinkObject.Save(17, &t.MaxSentAck)
+ stateSinkObject.Save(18, &t.FastRecovery)
+ stateSinkObject.Save(19, &t.Cubic)
+ stateSinkObject.Save(20, &t.RACKState)
+}
+
+func (t *TCPSenderState) afterLoad() {}
+
+// +checklocksignore
+func (t *TCPSenderState) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &t.LastSendTime)
+ stateSourceObject.Load(1, &t.DupAckCount)
+ stateSourceObject.Load(2, &t.SndCwnd)
+ stateSourceObject.Load(3, &t.Ssthresh)
+ stateSourceObject.Load(4, &t.SndCAAckCount)
+ stateSourceObject.Load(5, &t.Outstanding)
+ stateSourceObject.Load(6, &t.SackedOut)
+ stateSourceObject.Load(7, &t.SndWnd)
+ stateSourceObject.Load(8, &t.SndUna)
+ stateSourceObject.Load(9, &t.SndNxt)
+ stateSourceObject.Load(10, &t.RTTMeasureSeqNum)
+ stateSourceObject.Load(11, &t.RTTMeasureTime)
+ stateSourceObject.Load(12, &t.Closed)
+ stateSourceObject.Load(13, &t.RTO)
+ stateSourceObject.Load(14, &t.RTTState)
+ stateSourceObject.Load(15, &t.MaxPayloadSize)
+ stateSourceObject.Load(16, &t.SndWndScale)
+ stateSourceObject.Load(17, &t.MaxSentAck)
+ stateSourceObject.Load(18, &t.FastRecovery)
+ stateSourceObject.Load(19, &t.Cubic)
+ stateSourceObject.Load(20, &t.RACKState)
+}
+
+func (t *TCPSACKInfo) StateTypeName() string {
+ return "pkg/tcpip/stack.TCPSACKInfo"
+}
+
+func (t *TCPSACKInfo) StateFields() []string {
+ return []string{
+ "Blocks",
+ "ReceivedBlocks",
+ "MaxSACKED",
+ }
+}
+
+func (t *TCPSACKInfo) beforeSave() {}
+
+// +checklocksignore
+func (t *TCPSACKInfo) StateSave(stateSinkObject state.Sink) {
+ t.beforeSave()
+ stateSinkObject.Save(0, &t.Blocks)
+ stateSinkObject.Save(1, &t.ReceivedBlocks)
+ stateSinkObject.Save(2, &t.MaxSACKED)
+}
+
+func (t *TCPSACKInfo) afterLoad() {}
+
+// +checklocksignore
+func (t *TCPSACKInfo) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &t.Blocks)
+ stateSourceObject.Load(1, &t.ReceivedBlocks)
+ stateSourceObject.Load(2, &t.MaxSACKED)
+}
+
+func (r *RcvBufAutoTuneParams) StateTypeName() string {
+ return "pkg/tcpip/stack.RcvBufAutoTuneParams"
+}
+
+func (r *RcvBufAutoTuneParams) StateFields() []string {
+ return []string{
+ "MeasureTime",
+ "CopiedBytes",
+ "PrevCopiedBytes",
+ "RcvBufSize",
+ "RTT",
+ "RTTVar",
+ "RTTMeasureSeqNumber",
+ "RTTMeasureTime",
+ "Disabled",
+ }
+}
+
+func (r *RcvBufAutoTuneParams) beforeSave() {}
+
+// +checklocksignore
+func (r *RcvBufAutoTuneParams) StateSave(stateSinkObject state.Sink) {
+ r.beforeSave()
+ stateSinkObject.Save(0, &r.MeasureTime)
+ stateSinkObject.Save(1, &r.CopiedBytes)
+ stateSinkObject.Save(2, &r.PrevCopiedBytes)
+ stateSinkObject.Save(3, &r.RcvBufSize)
+ stateSinkObject.Save(4, &r.RTT)
+ stateSinkObject.Save(5, &r.RTTVar)
+ stateSinkObject.Save(6, &r.RTTMeasureSeqNumber)
+ stateSinkObject.Save(7, &r.RTTMeasureTime)
+ stateSinkObject.Save(8, &r.Disabled)
+}
+
+func (r *RcvBufAutoTuneParams) afterLoad() {}
+
+// +checklocksignore
+func (r *RcvBufAutoTuneParams) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &r.MeasureTime)
+ stateSourceObject.Load(1, &r.CopiedBytes)
+ stateSourceObject.Load(2, &r.PrevCopiedBytes)
+ stateSourceObject.Load(3, &r.RcvBufSize)
+ stateSourceObject.Load(4, &r.RTT)
+ stateSourceObject.Load(5, &r.RTTVar)
+ stateSourceObject.Load(6, &r.RTTMeasureSeqNumber)
+ stateSourceObject.Load(7, &r.RTTMeasureTime)
+ stateSourceObject.Load(8, &r.Disabled)
+}
+
+func (t *TCPRcvBufState) StateTypeName() string {
+ return "pkg/tcpip/stack.TCPRcvBufState"
+}
+
+func (t *TCPRcvBufState) StateFields() []string {
+ return []string{
+ "RcvBufUsed",
+ "RcvAutoParams",
+ "RcvClosed",
+ }
+}
+
+func (t *TCPRcvBufState) beforeSave() {}
+
+// +checklocksignore
+func (t *TCPRcvBufState) StateSave(stateSinkObject state.Sink) {
+ t.beforeSave()
+ stateSinkObject.Save(0, &t.RcvBufUsed)
+ stateSinkObject.Save(1, &t.RcvAutoParams)
+ stateSinkObject.Save(2, &t.RcvClosed)
+}
+
+func (t *TCPRcvBufState) afterLoad() {}
+
+// +checklocksignore
+func (t *TCPRcvBufState) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &t.RcvBufUsed)
+ stateSourceObject.Load(1, &t.RcvAutoParams)
+ stateSourceObject.Load(2, &t.RcvClosed)
+}
+
+func (t *TCPSndBufState) StateTypeName() string {
+ return "pkg/tcpip/stack.TCPSndBufState"
+}
+
+func (t *TCPSndBufState) StateFields() []string {
+ return []string{
+ "SndBufSize",
+ "SndBufUsed",
+ "SndClosed",
+ "PacketTooBigCount",
+ "SndMTU",
+ "AutoTuneSndBufDisabled",
+ }
+}
+
+func (t *TCPSndBufState) beforeSave() {}
+
+// +checklocksignore
+func (t *TCPSndBufState) StateSave(stateSinkObject state.Sink) {
+ t.beforeSave()
+ stateSinkObject.Save(0, &t.SndBufSize)
+ stateSinkObject.Save(1, &t.SndBufUsed)
+ stateSinkObject.Save(2, &t.SndClosed)
+ stateSinkObject.Save(3, &t.PacketTooBigCount)
+ stateSinkObject.Save(4, &t.SndMTU)
+ stateSinkObject.Save(5, &t.AutoTuneSndBufDisabled)
+}
+
+func (t *TCPSndBufState) afterLoad() {}
+
+// +checklocksignore
+func (t *TCPSndBufState) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &t.SndBufSize)
+ stateSourceObject.Load(1, &t.SndBufUsed)
+ stateSourceObject.Load(2, &t.SndClosed)
+ stateSourceObject.Load(3, &t.PacketTooBigCount)
+ stateSourceObject.Load(4, &t.SndMTU)
+ stateSourceObject.Load(5, &t.AutoTuneSndBufDisabled)
+}
+
+func (t *TCPEndpointStateInner) StateTypeName() string {
+ return "pkg/tcpip/stack.TCPEndpointStateInner"
+}
+
+func (t *TCPEndpointStateInner) StateFields() []string {
+ return []string{
+ "TSOffset",
+ "SACKPermitted",
+ "SendTSOk",
+ "RecentTS",
+ }
+}
+
+func (t *TCPEndpointStateInner) beforeSave() {}
+
+// +checklocksignore
+func (t *TCPEndpointStateInner) StateSave(stateSinkObject state.Sink) {
+ t.beforeSave()
+ stateSinkObject.Save(0, &t.TSOffset)
+ stateSinkObject.Save(1, &t.SACKPermitted)
+ stateSinkObject.Save(2, &t.SendTSOk)
+ stateSinkObject.Save(3, &t.RecentTS)
+}
+
+func (t *TCPEndpointStateInner) afterLoad() {}
+
+// +checklocksignore
+func (t *TCPEndpointStateInner) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &t.TSOffset)
+ stateSourceObject.Load(1, &t.SACKPermitted)
+ stateSourceObject.Load(2, &t.SendTSOk)
+ stateSourceObject.Load(3, &t.RecentTS)
+}
+
+func (t *TCPEndpointState) StateTypeName() string {
+ return "pkg/tcpip/stack.TCPEndpointState"
+}
+
+func (t *TCPEndpointState) StateFields() []string {
+ return []string{
+ "TCPEndpointStateInner",
+ "ID",
+ "SegTime",
+ "RcvBufState",
+ "SndBufState",
+ "SACK",
+ "Receiver",
+ "Sender",
+ }
+}
+
+func (t *TCPEndpointState) beforeSave() {}
+
+// +checklocksignore
+func (t *TCPEndpointState) StateSave(stateSinkObject state.Sink) {
+ t.beforeSave()
+ stateSinkObject.Save(0, &t.TCPEndpointStateInner)
+ stateSinkObject.Save(1, &t.ID)
+ stateSinkObject.Save(2, &t.SegTime)
+ stateSinkObject.Save(3, &t.RcvBufState)
+ stateSinkObject.Save(4, &t.SndBufState)
+ stateSinkObject.Save(5, &t.SACK)
+ stateSinkObject.Save(6, &t.Receiver)
+ stateSinkObject.Save(7, &t.Sender)
+}
+
+func (t *TCPEndpointState) afterLoad() {}
+
+// +checklocksignore
+func (t *TCPEndpointState) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &t.TCPEndpointStateInner)
+ stateSourceObject.Load(1, &t.ID)
+ stateSourceObject.Load(2, &t.SegTime)
+ stateSourceObject.Load(3, &t.RcvBufState)
+ stateSourceObject.Load(4, &t.SndBufState)
+ stateSourceObject.Load(5, &t.SACK)
+ stateSourceObject.Load(6, &t.Receiver)
+ stateSourceObject.Load(7, &t.Sender)
+}
+
+func (ep *multiPortEndpoint) StateTypeName() string {
+ return "pkg/tcpip/stack.multiPortEndpoint"
+}
+
+func (ep *multiPortEndpoint) StateFields() []string {
+ return []string{
+ "demux",
+ "netProto",
+ "transProto",
+ "flags",
+ "endpoints",
+ }
+}
+
+func (ep *multiPortEndpoint) beforeSave() {}
+
+// +checklocksignore
+func (ep *multiPortEndpoint) StateSave(stateSinkObject state.Sink) {
+ ep.beforeSave()
+ stateSinkObject.Save(0, &ep.demux)
+ stateSinkObject.Save(1, &ep.netProto)
+ stateSinkObject.Save(2, &ep.transProto)
+ stateSinkObject.Save(3, &ep.flags)
+ stateSinkObject.Save(4, &ep.endpoints)
+}
+
+func (ep *multiPortEndpoint) afterLoad() {}
+
+// +checklocksignore
+func (ep *multiPortEndpoint) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &ep.demux)
+ stateSourceObject.Load(1, &ep.netProto)
+ stateSourceObject.Load(2, &ep.transProto)
+ stateSourceObject.Load(3, &ep.flags)
+ stateSourceObject.Load(4, &ep.endpoints)
+}
+
+func (l *tupleList) StateTypeName() string {
+ return "pkg/tcpip/stack.tupleList"
+}
+
+func (l *tupleList) StateFields() []string {
+ return []string{
+ "head",
+ "tail",
+ }
+}
+
+func (l *tupleList) beforeSave() {}
+
+// +checklocksignore
+func (l *tupleList) StateSave(stateSinkObject state.Sink) {
+ l.beforeSave()
+ stateSinkObject.Save(0, &l.head)
+ stateSinkObject.Save(1, &l.tail)
+}
+
+func (l *tupleList) afterLoad() {}
+
+// +checklocksignore
+func (l *tupleList) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &l.head)
+ stateSourceObject.Load(1, &l.tail)
+}
+
+func (e *tupleEntry) StateTypeName() string {
+ return "pkg/tcpip/stack.tupleEntry"
+}
+
+func (e *tupleEntry) StateFields() []string {
+ return []string{
+ "next",
+ "prev",
+ }
+}
+
+func (e *tupleEntry) beforeSave() {}
+
+// +checklocksignore
+func (e *tupleEntry) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+ stateSinkObject.Save(0, &e.next)
+ stateSinkObject.Save(1, &e.prev)
+}
+
+func (e *tupleEntry) afterLoad() {}
+
+// +checklocksignore
+func (e *tupleEntry) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &e.next)
+ stateSourceObject.Load(1, &e.prev)
+}
+
+func init() {
+ state.Register((*tuple)(nil))
+ state.Register((*tupleID)(nil))
+ state.Register((*conn)(nil))
+ state.Register((*ConnTrack)(nil))
+ state.Register((*bucket)(nil))
+ state.Register((*unixTime)(nil))
+ state.Register((*IPTables)(nil))
+ state.Register((*Table)(nil))
+ state.Register((*Rule)(nil))
+ state.Register((*IPHeaderFilter)(nil))
+ state.Register((*neighborEntryList)(nil))
+ state.Register((*neighborEntryEntry)(nil))
+ state.Register((*PacketBufferList)(nil))
+ state.Register((*PacketBufferEntry)(nil))
+ state.Register((*TransportEndpointID)(nil))
+ state.Register((*GSOType)(nil))
+ state.Register((*GSO)(nil))
+ state.Register((*TransportEndpointInfo)(nil))
+ state.Register((*TCPCubicState)(nil))
+ state.Register((*TCPRACKState)(nil))
+ state.Register((*TCPEndpointID)(nil))
+ state.Register((*TCPFastRecoveryState)(nil))
+ state.Register((*TCPReceiverState)(nil))
+ state.Register((*TCPRTTState)(nil))
+ state.Register((*TCPSenderState)(nil))
+ state.Register((*TCPSACKInfo)(nil))
+ state.Register((*RcvBufAutoTuneParams)(nil))
+ state.Register((*TCPRcvBufState)(nil))
+ state.Register((*TCPSndBufState)(nil))
+ state.Register((*TCPEndpointStateInner)(nil))
+ state.Register((*TCPEndpointState)(nil))
+ state.Register((*multiPortEndpoint)(nil))
+ state.Register((*tupleList)(nil))
+ state.Register((*tupleEntry)(nil))
+}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
deleted file mode 100644
index cd4137794..000000000
--- a/pkg/tcpip/stack/stack_test.go
+++ /dev/null
@@ -1,4671 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package stack_test contains tests for the stack. It is in its own package so
-// that the tests can also validate that all definitions needed to implement
-// transport and network protocols are properly exported by the stack package.
-package stack_test
-
-import (
- "bytes"
- "fmt"
- "math"
- "net"
- "sort"
- "testing"
- "time"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/rand"
- "gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
- "gvisor.dev/gvisor/pkg/tcpip/network/arp"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
-)
-
-const (
- fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32
- fakeNetHeaderLen = 12
- fakeDefaultPrefixLen = 8
-
- // fakeControlProtocol is used for control packets that represent
- // destination port unreachable.
- fakeControlProtocol tcpip.TransportProtocolNumber = 2
-
- // defaultMTU is the MTU, in bytes, used throughout the tests, except
- // where another value is explicitly used. It is chosen to match the MTU
- // of loopback interfaces on linux systems.
- defaultMTU = 65536
-
- dstAddrOffset = 0
- srcAddrOffset = 1
- protocolNumberOffset = 2
-)
-
-func checkGetMainNICAddress(s *stack.Stack, nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber, want tcpip.AddressWithPrefix) error {
- if addr, err := s.GetMainNICAddress(nicID, proto); err != nil {
- return fmt.Errorf("stack.GetMainNICAddress(%d, %d): %s", nicID, proto, err)
- } else if addr != want {
- return fmt.Errorf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, proto, addr, want)
- }
- return nil
-}
-
-// fakeNetworkEndpoint is a network-layer protocol endpoint. It counts sent and
-// received packets; the counts of all endpoints are aggregated in the protocol
-// descriptor.
-//
-// Headers of this protocol are fakeNetHeaderLen bytes, but we currently only
-// use the first three: destination address, source address, and transport
-// protocol. They're all one byte fields to simplify parsing.
-type fakeNetworkEndpoint struct {
- stack.AddressableEndpointState
-
- mu struct {
- sync.RWMutex
-
- enabled bool
- forwarding bool
- }
-
- nic stack.NetworkInterface
- proto *fakeNetworkProtocol
- dispatcher stack.TransportDispatcher
-}
-
-func (f *fakeNetworkEndpoint) Enable() tcpip.Error {
- f.mu.Lock()
- defer f.mu.Unlock()
- f.mu.enabled = true
- return nil
-}
-
-func (f *fakeNetworkEndpoint) Enabled() bool {
- f.mu.RLock()
- defer f.mu.RUnlock()
- return f.mu.enabled
-}
-
-func (f *fakeNetworkEndpoint) Disable() {
- f.mu.Lock()
- defer f.mu.Unlock()
- f.mu.enabled = false
-}
-
-func (f *fakeNetworkEndpoint) MTU() uint32 {
- return f.nic.MTU() - uint32(f.MaxHeaderLength())
-}
-
-func (*fakeNetworkEndpoint) DefaultTTL() uint8 {
- return 123
-}
-
-func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) {
- if _, _, ok := f.proto.Parse(pkt); !ok {
- return
- }
-
- // Increment the received packet count in the protocol descriptor.
- netHdr := pkt.NetworkHeader().View()
-
- dst := tcpip.Address(netHdr[dstAddrOffset:][:1])
- addressEndpoint := f.AcquireAssignedAddress(dst, f.nic.Promiscuous(), stack.CanBePrimaryEndpoint)
- if addressEndpoint == nil {
- return
- }
- addressEndpoint.DecRef()
-
- f.proto.packetCount[int(dst[0])%len(f.proto.packetCount)]++
-
- // Handle control packets.
- if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) {
- hdr, ok := pkt.Data().PullUp(fakeNetHeaderLen)
- if !ok {
- return
- }
- // DeleteFront invalidates slices. Make a copy before trimming.
- nb := append([]byte(nil), hdr...)
- pkt.Data().DeleteFront(fakeNetHeaderLen)
- f.dispatcher.DeliverTransportError(
- tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]),
- tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]),
- fakeNetNumber,
- tcpip.TransportProtocolNumber(nb[protocolNumberOffset]),
- // Nothing checks the error.
- nil, /* transport error */
- pkt,
- )
- return
- }
-
- // Dispatch the packet to the transport protocol.
- f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt)
-}
-
-func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
- return f.nic.MaxHeaderLength() + fakeNetHeaderLen
-}
-
-func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
- return f.proto.Number()
-}
-
-func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error {
- // Increment the sent packet count in the protocol descriptor.
- f.proto.sendPacketCount[int(r.RemoteAddress()[0])%len(f.proto.sendPacketCount)]++
-
- // Add the protocol's header to the packet and send it to the link
- // endpoint.
- hdr := pkt.NetworkHeader().Push(fakeNetHeaderLen)
- pkt.NetworkProtocolNumber = fakeNetNumber
- hdr[dstAddrOffset] = r.RemoteAddress()[0]
- hdr[srcAddrOffset] = r.LocalAddress()[0]
- hdr[protocolNumberOffset] = byte(params.Protocol)
-
- if r.Loop()&stack.PacketLoop != 0 {
- f.HandlePacket(pkt.Clone())
- }
- if r.Loop()&stack.PacketOut == 0 {
- return nil
- }
-
- return f.nic.WritePacket(r, fakeNetNumber, pkt)
-}
-
-// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (*fakeNetworkEndpoint) WritePackets(*stack.Route, stack.PacketBufferList, stack.NetworkHeaderParams) (int, tcpip.Error) {
- panic("not implemented")
-}
-
-func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(*stack.Route, *stack.PacketBuffer) tcpip.Error {
- return &tcpip.ErrNotSupported{}
-}
-
-func (f *fakeNetworkEndpoint) Close() {
- f.AddressableEndpointState.Cleanup()
-}
-
-// Stats implements NetworkEndpoint.
-func (*fakeNetworkEndpoint) Stats() stack.NetworkEndpointStats {
- return &fakeNetworkEndpointStats{}
-}
-
-var _ stack.NetworkEndpointStats = (*fakeNetworkEndpointStats)(nil)
-
-type fakeNetworkEndpointStats struct{}
-
-// IsNetworkEndpointStats implements stack.NetworkEndpointStats.
-func (*fakeNetworkEndpointStats) IsNetworkEndpointStats() {}
-
-// fakeNetworkProtocol is a network-layer protocol descriptor. It aggregates the
-// number of packets sent and received via endpoints of this protocol. The index
-// where packets are added is given by the packet's destination address MOD 10.
-type fakeNetworkProtocol struct {
- packetCount [10]int
- sendPacketCount [10]int
- defaultTTL uint8
-}
-
-func (*fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
- return fakeNetNumber
-}
-
-func (*fakeNetworkProtocol) MinimumPacketSize() int {
- return fakeNetHeaderLen
-}
-
-func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int {
- return f.packetCount[int(intfAddr)%len(f.packetCount)]
-}
-
-func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
- return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1])
-}
-
-func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
- e := &fakeNetworkEndpoint{
- nic: nic,
- proto: f,
- dispatcher: dispatcher,
- }
- e.AddressableEndpointState.Init(e)
- return e
-}
-
-func (f *fakeNetworkProtocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.Error {
- switch v := option.(type) {
- case *tcpip.DefaultTTLOption:
- f.defaultTTL = uint8(*v)
- return nil
- default:
- return &tcpip.ErrUnknownProtocolOption{}
- }
-}
-
-func (f *fakeNetworkProtocol) Option(option tcpip.GettableNetworkProtocolOption) tcpip.Error {
- switch v := option.(type) {
- case *tcpip.DefaultTTLOption:
- *v = tcpip.DefaultTTLOption(f.defaultTTL)
- return nil
- default:
- return &tcpip.ErrUnknownProtocolOption{}
- }
-}
-
-// Close implements NetworkProtocol.Close.
-func (*fakeNetworkProtocol) Close() {}
-
-// Wait implements NetworkProtocol.Wait.
-func (*fakeNetworkProtocol) Wait() {}
-
-// Parse implements NetworkProtocol.Parse.
-func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) {
- hdr, ok := pkt.NetworkHeader().Consume(fakeNetHeaderLen)
- if !ok {
- return 0, false, false
- }
- pkt.NetworkProtocolNumber = fakeNetNumber
- return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true
-}
-
-// Forwarding implements stack.ForwardingNetworkEndpoint.
-func (f *fakeNetworkEndpoint) Forwarding() bool {
- f.mu.RLock()
- defer f.mu.RUnlock()
- return f.mu.forwarding
-}
-
-// SetForwarding implements stack.ForwardingNetworkEndpoint.
-func (f *fakeNetworkEndpoint) SetForwarding(v bool) {
- f.mu.Lock()
- defer f.mu.Unlock()
- f.mu.forwarding = v
-}
-
-func fakeNetFactory(*stack.Stack) stack.NetworkProtocol {
- return &fakeNetworkProtocol{}
-}
-
-// linkEPWithMockedAttach is a stack.LinkEndpoint that tests can use to verify
-// that LinkEndpoint.Attach was called.
-type linkEPWithMockedAttach struct {
- stack.LinkEndpoint
- attached bool
-}
-
-// Attach implements stack.LinkEndpoint.Attach.
-func (l *linkEPWithMockedAttach) Attach(d stack.NetworkDispatcher) {
- l.LinkEndpoint.Attach(d)
- l.attached = d != nil
-}
-
-func (l *linkEPWithMockedAttach) isAttached() bool {
- return l.attached
-}
-
-// Checks to see if list contains an address.
-func containsAddr(list []tcpip.ProtocolAddress, item tcpip.ProtocolAddress) bool {
- for _, i := range list {
- if i == item {
- return true
- }
- }
-
- return false
-}
-
-func TestNetworkReceive(t *testing.T) {
- // Create a stack with the fake network protocol, one nic, and two
- // addresses attached to it: 1 & 2.
- ep := channel.New(10, defaultMTU, "")
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- protocolAddr1 := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: "\x01",
- PrefixLen: fakeDefaultPrefixLen,
- },
- }
- if err := s.AddProtocolAddress(1, protocolAddr1, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err)
- }
-
- protocolAddr2 := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: "\x02",
- PrefixLen: fakeDefaultPrefixLen,
- },
- }
- if err := s.AddProtocolAddress(1, protocolAddr2, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr2, err)
- }
-
- fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
-
- buf := buffer.NewView(30)
-
- // Make sure packet with wrong address is not delivered.
- buf[dstAddrOffset] = 3
- ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
- if fakeNet.packetCount[1] != 0 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0)
- }
- if fakeNet.packetCount[2] != 0 {
- t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 0)
- }
-
- // Make sure packet is delivered to first endpoint.
- buf[dstAddrOffset] = 1
- ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
- if fakeNet.packetCount[2] != 0 {
- t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 0)
- }
-
- // Make sure packet is delivered to second endpoint.
- buf[dstAddrOffset] = 2
- ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
- if fakeNet.packetCount[2] != 1 {
- t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1)
- }
-
- // Make sure packet is not delivered if protocol number is wrong.
- ep.InjectInbound(fakeNetNumber-1, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
- if fakeNet.packetCount[2] != 1 {
- t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1)
- }
-
- // Make sure packet that is too small is dropped.
- buf.CapLength(2)
- ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
- if fakeNet.packetCount[2] != 1 {
- t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1)
- }
-}
-
-func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) tcpip.Error {
- r, err := s.FindRoute(0, "", addr, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- return err
- }
- defer r.Release()
- return send(r, payload)
-}
-
-func send(r *stack.Route, payload buffer.View) tcpip.Error {
- return r.WritePacket(stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength()),
- Data: payload.ToVectorisedView(),
- }))
-}
-
-func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) {
- t.Helper()
- ep.Drain()
- if err := sendTo(s, addr, payload); err != nil {
- t.Error("sendTo failed:", err)
- }
- if got, want := ep.Drain(), 1; got != want {
- t.Errorf("sendTo packet count: got = %d, want %d", got, want)
- }
-}
-
-func testSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View) {
- t.Helper()
- ep.Drain()
- if err := send(r, payload); err != nil {
- t.Error("send failed:", err)
- }
- if got, want := ep.Drain(), 1; got != want {
- t.Errorf("send packet count: got = %d, want %d", got, want)
- }
-}
-
-func testFailingSend(t *testing.T, r *stack.Route, payload buffer.View, wantErr tcpip.Error) {
- t.Helper()
- if gotErr := send(r, payload); gotErr != wantErr {
- t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr)
- }
-}
-
-func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, payload buffer.View, wantErr tcpip.Error) {
- t.Helper()
- if gotErr := sendTo(s, addr, payload); gotErr != wantErr {
- t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr)
- }
-}
-
-func testRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) {
- t.Helper()
- // testRecvInternal injects one packet, and we expect to receive it.
- want := fakeNet.PacketCount(localAddrByte) + 1
- testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want)
-}
-
-func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) {
- t.Helper()
- // testRecvInternal injects one packet, and we do NOT expect to receive it.
- want := fakeNet.PacketCount(localAddrByte)
- testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want)
-}
-
-func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) {
- t.Helper()
- ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
- if got := fakeNet.PacketCount(localAddrByte); got != want {
- t.Errorf("receive packet count: got = %d, want %d", got, want)
- }
-}
-
-func TestNetworkSend(t *testing.T) {
- // Create a stack with the fake network protocol, one nic, and one
- // address: 1. The route table sends all packets through the only
- // existing nic.
- ep := channel.New(10, defaultMTU, "")
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatal("NewNIC failed:", err)
- }
-
- {
- subnet, err := tcpip.NewSubnet("\x00", "\x00")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
- }
-
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: "\x01",
- PrefixLen: fakeDefaultPrefixLen,
- },
- }
- if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
- }
-
- // Make sure that the link-layer endpoint received the outbound packet.
- testSendTo(t, s, "\x03", ep, nil)
-}
-
-func TestNetworkSendMultiRoute(t *testing.T) {
- // Create a stack with the fake network protocol, two nics, and two
- // addresses per nic, the first nic has odd address, the second one has
- // even addresses.
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
-
- ep1 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep1); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- protocolAddr1 := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: "\x01",
- PrefixLen: fakeDefaultPrefixLen,
- },
- }
- if err := s.AddProtocolAddress(1, protocolAddr1, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err)
- }
-
- protocolAddr3 := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: "\x03",
- PrefixLen: fakeDefaultPrefixLen,
- },
- }
- if err := s.AddProtocolAddress(1, protocolAddr3, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr3, err)
- }
-
- ep2 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(2, ep2); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- protocolAddr2 := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: "\x02",
- PrefixLen: fakeDefaultPrefixLen,
- },
- }
- if err := s.AddProtocolAddress(2, protocolAddr2, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr2, err)
- }
-
- protocolAddr4 := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: "\x04",
- PrefixLen: fakeDefaultPrefixLen,
- },
- }
- if err := s.AddProtocolAddress(2, protocolAddr4, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr4, err)
- }
-
- // Set a route table that sends all packets with odd destination
- // addresses through the first NIC, and all even destination address
- // through the second one.
- {
- subnet0, err := tcpip.NewSubnet("\x00", "\x01")
- if err != nil {
- t.Fatal(err)
- }
- subnet1, err := tcpip.NewSubnet("\x01", "\x01")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{
- {Destination: subnet1, Gateway: "\x00", NIC: 1},
- {Destination: subnet0, Gateway: "\x00", NIC: 2},
- })
- }
-
- // Send a packet to an odd destination.
- testSendTo(t, s, "\x05", ep1, nil)
-
- // Send a packet to an even destination.
- testSendTo(t, s, "\x06", ep2, nil)
-}
-
-func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) {
- r, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatal("FindRoute failed:", err)
- }
-
- defer r.Release()
-
- if r.LocalAddress() != expectedSrcAddr {
- t.Fatalf("got Route.LocalAddress() = %s, want = %s", expectedSrcAddr, r.LocalAddress())
- }
-
- if r.RemoteAddress() != dstAddr {
- t.Fatalf("got Route.RemoteAddress() = %s, want = %s", dstAddr, r.RemoteAddress())
- }
-}
-
-func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr tcpip.Address) {
- _, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
- if _, ok := err.(*tcpip.ErrNoRoute); !ok {
- t.Fatalf("FindRoute returned unexpected error, got = %v, want = %s", err, &tcpip.ErrNoRoute{})
- }
-}
-
-// TestAttachToLinkEndpointImmediately tests that a LinkEndpoint is attached to
-// a NetworkDispatcher when the NIC is created.
-func TestAttachToLinkEndpointImmediately(t *testing.T) {
- const nicID = 1
-
- tests := []struct {
- name string
- nicOpts stack.NICOptions
- }{
- {
- name: "Create enabled NIC",
- nicOpts: stack.NICOptions{Disabled: false},
- },
- {
- name: "Create disabled NIC",
- nicOpts: stack.NICOptions{Disabled: true},
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
-
- e := linkEPWithMockedAttach{
- LinkEndpoint: loopback.New(),
- }
-
- if err := s.CreateNICWithOptions(nicID, &e, test.nicOpts); err != nil {
- t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, test.nicOpts, err)
- }
- if !e.isAttached() {
- t.Fatal("link endpoint not attached to a network dispatcher")
- }
- })
- }
-}
-
-func TestDisableUnknownNIC(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
-
- err := s.DisableNIC(1)
- if _, ok := err.(*tcpip.ErrUnknownNICID); !ok {
- t.Fatalf("got s.DisableNIC(1) = %v, want = %s", err, &tcpip.ErrUnknownNICID{})
- }
-}
-
-func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) {
- const nicID = 1
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
-
- e := loopback.New()
- nicOpts := stack.NICOptions{Disabled: true}
- if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
- t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err)
- }
-
- checkNIC := func(enabled bool) {
- t.Helper()
-
- allNICInfo := s.NICInfo()
- nicInfo, ok := allNICInfo[nicID]
- if !ok {
- t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo)
- } else if nicInfo.Flags.Running != enabled {
- t.Errorf("got nicInfo.Flags.Running = %t, want = %t", nicInfo.Flags.Running, enabled)
- }
-
- if got := s.CheckNIC(nicID); got != enabled {
- t.Errorf("got s.CheckNIC(%d) = %t, want = %t", nicID, got, enabled)
- }
- }
-
- // NIC should initially report itself as disabled.
- checkNIC(false)
-
- if err := s.EnableNIC(nicID); err != nil {
- t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
- }
- checkNIC(true)
-
- // If the NIC is not reporting a correct enabled status, we cannot trust the
- // next check so end the test here.
- if t.Failed() {
- t.FailNow()
- }
-
- if err := s.DisableNIC(nicID); err != nil {
- t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
- }
- checkNIC(false)
-}
-
-func TestRemoveUnknownNIC(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
-
- err := s.RemoveNIC(1)
- if _, ok := err.(*tcpip.ErrUnknownNICID); !ok {
- t.Fatalf("got s.RemoveNIC(1) = %v, want = %s", err, &tcpip.ErrUnknownNICID{})
- }
-}
-
-func TestRemoveNIC(t *testing.T) {
- for _, tt := range []struct {
- name string
- linkep stack.LinkEndpoint
- expectErr tcpip.Error
- }{
- {
- name: "loopback",
- linkep: loopback.New(),
- expectErr: &tcpip.ErrNotSupported{},
- },
- {
- name: "channel",
- linkep: channel.New(0, defaultMTU, ""),
- expectErr: nil,
- },
- } {
- t.Run(tt.name, func(t *testing.T) {
- const nicID = 1
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
-
- e := linkEPWithMockedAttach{
- LinkEndpoint: tt.linkep,
- }
- if err := s.CreateNIC(nicID, &e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- // NIC should be present in NICInfo and attached to a NetworkDispatcher.
- allNICInfo := s.NICInfo()
- if _, ok := allNICInfo[nicID]; !ok {
- t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo)
- }
- if !e.isAttached() {
- t.Fatal("link endpoint not attached to a network dispatcher")
- }
-
- // Removing a NIC should remove it from NICInfo and e should be detached from
- // the NetworkDispatcher.
- if got, want := s.RemoveNIC(nicID), tt.expectErr; got != want {
- t.Fatalf("got s.RemoveNIC(%d) = %s, want %s", nicID, got, want)
- }
- if tt.expectErr == nil {
- if nicInfo, ok := s.NICInfo()[nicID]; ok {
- t.Errorf("got unexpected NICInfo entry for deleted NIC %d = %+v", nicID, nicInfo)
- }
- if e.isAttached() {
- t.Error("link endpoint for removed NIC still attached to a network dispatcher")
- }
- }
- })
- }
-}
-
-func TestRouteWithDownNIC(t *testing.T) {
- tests := []struct {
- name string
- downFn func(s *stack.Stack, nicID tcpip.NICID) tcpip.Error
- upFn func(s *stack.Stack, nicID tcpip.NICID) tcpip.Error
- }{
- {
- name: "Disabled NIC",
- downFn: (*stack.Stack).DisableNIC,
- upFn: (*stack.Stack).EnableNIC,
- },
-
- // Once a NIC is removed, it cannot be brought up.
- {
- name: "Removed NIC",
- downFn: (*stack.Stack).RemoveNIC,
- },
- }
-
- const unspecifiedNIC = 0
- const nicID1 = 1
- const nicID2 = 2
- const addr1 = tcpip.Address("\x01")
- const addr2 = tcpip.Address("\x02")
- const nic1Dst = tcpip.Address("\x05")
- const nic2Dst = tcpip.Address("\x06")
-
- setup := func(t *testing.T) (*stack.Stack, *channel.Endpoint, *channel.Endpoint) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
-
- ep1 := channel.New(1, defaultMTU, "")
- if err := s.CreateNIC(nicID1, ep1); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
- }
-
- protocolAddr1 := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: addr1,
- PrefixLen: fakeDefaultPrefixLen,
- },
- }
- if err := s.AddProtocolAddress(nicID1, protocolAddr1, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, protocolAddr1, err)
- }
-
- ep2 := channel.New(1, defaultMTU, "")
- if err := s.CreateNIC(nicID2, ep2); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
- }
-
- protocolAddr2 := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: addr2,
- PrefixLen: fakeDefaultPrefixLen,
- },
- }
- if err := s.AddProtocolAddress(nicID2, protocolAddr2, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddr2, err)
- }
-
- // Set a route table that sends all packets with odd destination
- // addresses through the first NIC, and all even destination address
- // through the second one.
- {
- subnet0, err := tcpip.NewSubnet("\x00", "\x01")
- if err != nil {
- t.Fatal(err)
- }
- subnet1, err := tcpip.NewSubnet("\x01", "\x01")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{
- {Destination: subnet1, Gateway: "\x00", NIC: nicID1},
- {Destination: subnet0, Gateway: "\x00", NIC: nicID2},
- })
- }
-
- return s, ep1, ep2
- }
-
- // Tests that routes through a down NIC are not used when looking up a route
- // for a destination.
- t.Run("Find", func(t *testing.T) {
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s, _, _ := setup(t)
-
- // Test routes to odd address.
- testRoute(t, s, unspecifiedNIC, "", "\x05", addr1)
- testRoute(t, s, unspecifiedNIC, addr1, "\x05", addr1)
- testRoute(t, s, nicID1, addr1, "\x05", addr1)
-
- // Test routes to even address.
- testRoute(t, s, unspecifiedNIC, "", "\x06", addr2)
- testRoute(t, s, unspecifiedNIC, addr2, "\x06", addr2)
- testRoute(t, s, nicID2, addr2, "\x06", addr2)
-
- // Bringing NIC1 down should result in no routes to odd addresses. Routes to
- // even addresses should continue to be available as NIC2 is still up.
- if err := test.downFn(s, nicID1); err != nil {
- t.Fatalf("test.downFn(_, %d): %s", nicID1, err)
- }
- testNoRoute(t, s, unspecifiedNIC, "", nic1Dst)
- testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst)
- testNoRoute(t, s, nicID1, addr1, nic1Dst)
- testRoute(t, s, unspecifiedNIC, "", nic2Dst, addr2)
- testRoute(t, s, unspecifiedNIC, addr2, nic2Dst, addr2)
- testRoute(t, s, nicID2, addr2, nic2Dst, addr2)
-
- // Bringing NIC2 down should result in no routes to even addresses. No
- // route should be available to any address as routes to odd addresses
- // were made unavailable by bringing NIC1 down above.
- if err := test.downFn(s, nicID2); err != nil {
- t.Fatalf("test.downFn(_, %d): %s", nicID2, err)
- }
- testNoRoute(t, s, unspecifiedNIC, "", nic1Dst)
- testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst)
- testNoRoute(t, s, nicID1, addr1, nic1Dst)
- testNoRoute(t, s, unspecifiedNIC, "", nic2Dst)
- testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst)
- testNoRoute(t, s, nicID2, addr2, nic2Dst)
-
- if upFn := test.upFn; upFn != nil {
- // Bringing NIC1 up should make routes to odd addresses available
- // again. Routes to even addresses should continue to be unavailable
- // as NIC2 is still down.
- if err := upFn(s, nicID1); err != nil {
- t.Fatalf("test.upFn(_, %d): %s", nicID1, err)
- }
- testRoute(t, s, unspecifiedNIC, "", nic1Dst, addr1)
- testRoute(t, s, unspecifiedNIC, addr1, nic1Dst, addr1)
- testRoute(t, s, nicID1, addr1, nic1Dst, addr1)
- testNoRoute(t, s, unspecifiedNIC, "", nic2Dst)
- testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst)
- testNoRoute(t, s, nicID2, addr2, nic2Dst)
- }
- })
- }
- })
-
- // Tests that writing a packet using a Route through a down NIC fails.
- t.Run("WritePacket", func(t *testing.T) {
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s, ep1, ep2 := setup(t)
-
- r1, err := s.FindRoute(nicID1, addr1, nic1Dst, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID1, addr1, nic1Dst, fakeNetNumber, err)
- }
- defer r1.Release()
-
- r2, err := s.FindRoute(nicID2, addr2, nic2Dst, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID2, addr2, nic2Dst, fakeNetNumber, err)
- }
- defer r2.Release()
-
- // If we failed to get routes r1 or r2, we cannot proceed with the test.
- if t.Failed() {
- t.FailNow()
- }
-
- buf := buffer.View([]byte{1})
- testSend(t, r1, ep1, buf)
- testSend(t, r2, ep2, buf)
-
- // Writes with Routes that use NIC1 after being brought down should fail.
- if err := test.downFn(s, nicID1); err != nil {
- t.Fatalf("test.downFn(_, %d): %s", nicID1, err)
- }
- testFailingSend(t, r1, buf, &tcpip.ErrInvalidEndpointState{})
- testSend(t, r2, ep2, buf)
-
- // Writes with Routes that use NIC2 after being brought down should fail.
- if err := test.downFn(s, nicID2); err != nil {
- t.Fatalf("test.downFn(_, %d): %s", nicID2, err)
- }
- testFailingSend(t, r1, buf, &tcpip.ErrInvalidEndpointState{})
- testFailingSend(t, r2, buf, &tcpip.ErrInvalidEndpointState{})
-
- if upFn := test.upFn; upFn != nil {
- // Writes with Routes that use NIC1 after being brought up should
- // succeed.
- //
- // TODO(gvisor.dev/issue/1491): Should we instead completely
- // invalidate all Routes that were bound to a NIC that was brought
- // down at some point?
- if err := upFn(s, nicID1); err != nil {
- t.Fatalf("test.upFn(_, %d): %s", nicID1, err)
- }
- testSend(t, r1, ep1, buf)
- testFailingSend(t, r2, buf, &tcpip.ErrInvalidEndpointState{})
- }
- })
- }
- })
-}
-
-func TestRoutes(t *testing.T) {
- // Create a stack with the fake network protocol, two nics, and two
- // addresses per nic, the first nic has odd address, the second one has
- // even addresses.
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
-
- ep1 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep1); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- protocolAddr1 := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: "\x01",
- PrefixLen: fakeDefaultPrefixLen,
- },
- }
- if err := s.AddProtocolAddress(1, protocolAddr1, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err)
- }
-
- protocolAddr3 := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: "\x03",
- PrefixLen: fakeDefaultPrefixLen,
- },
- }
- if err := s.AddProtocolAddress(1, protocolAddr3, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr3, err)
- }
-
- ep2 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(2, ep2); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- protocolAddr2 := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: "\x02",
- PrefixLen: fakeDefaultPrefixLen,
- },
- }
- if err := s.AddProtocolAddress(2, protocolAddr2, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr2, err)
- }
-
- protocolAddr4 := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: "\x04",
- PrefixLen: fakeDefaultPrefixLen,
- },
- }
- if err := s.AddProtocolAddress(2, protocolAddr4, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr4, err)
- }
-
- // Set a route table that sends all packets with odd destination
- // addresses through the first NIC, and all even destination address
- // through the second one.
- {
- subnet0, err := tcpip.NewSubnet("\x00", "\x01")
- if err != nil {
- t.Fatal(err)
- }
- subnet1, err := tcpip.NewSubnet("\x01", "\x01")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{
- {Destination: subnet1, Gateway: "\x00", NIC: 1},
- {Destination: subnet0, Gateway: "\x00", NIC: 2},
- })
- }
-
- // Test routes to odd address.
- testRoute(t, s, 0, "", "\x05", "\x01")
- testRoute(t, s, 0, "\x01", "\x05", "\x01")
- testRoute(t, s, 1, "\x01", "\x05", "\x01")
- testRoute(t, s, 0, "\x03", "\x05", "\x03")
- testRoute(t, s, 1, "\x03", "\x05", "\x03")
-
- // Test routes to even address.
- testRoute(t, s, 0, "", "\x06", "\x02")
- testRoute(t, s, 0, "\x02", "\x06", "\x02")
- testRoute(t, s, 2, "\x02", "\x06", "\x02")
- testRoute(t, s, 0, "\x04", "\x06", "\x04")
- testRoute(t, s, 2, "\x04", "\x06", "\x04")
-
- // Try to send to odd numbered address from even numbered ones, then
- // vice-versa.
- testNoRoute(t, s, 0, "\x02", "\x05")
- testNoRoute(t, s, 2, "\x02", "\x05")
- testNoRoute(t, s, 0, "\x04", "\x05")
- testNoRoute(t, s, 2, "\x04", "\x05")
-
- testNoRoute(t, s, 0, "\x01", "\x06")
- testNoRoute(t, s, 1, "\x01", "\x06")
- testNoRoute(t, s, 0, "\x03", "\x06")
- testNoRoute(t, s, 1, "\x03", "\x06")
-}
-
-func TestAddressRemoval(t *testing.T) {
- const localAddrByte byte = 0x01
- localAddr := tcpip.Address([]byte{localAddrByte})
- remoteAddr := tcpip.Address("\x02")
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
-
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: localAddr,
- PrefixLen: fakeDefaultPrefixLen,
- },
- }
- if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
- }
- {
- subnet, err := tcpip.NewSubnet("\x00", "\x00")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
- }
-
- fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
-
- buf := buffer.NewView(30)
-
- // Send and receive packets, and verify they are received.
- buf[dstAddrOffset] = localAddrByte
- testRecv(t, fakeNet, localAddrByte, ep, buf)
- testSendTo(t, s, remoteAddr, ep, nil)
-
- // Remove the address, then check that send/receive doesn't work anymore.
- if err := s.RemoveAddress(1, localAddr); err != nil {
- t.Fatal("RemoveAddress failed:", err)
- }
- testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
- testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{})
-
- // Check that removing the same address fails.
- err := s.RemoveAddress(1, localAddr)
- if _, ok := err.(*tcpip.ErrBadLocalAddress); !ok {
- t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, &tcpip.ErrBadLocalAddress{})
- }
-}
-
-func TestAddressRemovalWithRouteHeld(t *testing.T) {
- const localAddrByte byte = 0x01
- localAddr := tcpip.Address([]byte{localAddrByte})
- remoteAddr := tcpip.Address("\x02")
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
-
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
- fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
- buf := buffer.NewView(30)
-
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: localAddr,
- PrefixLen: fakeDefaultPrefixLen,
- },
- }
- if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
- }
- {
- subnet, err := tcpip.NewSubnet("\x00", "\x00")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
- }
-
- r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatal("FindRoute failed:", err)
- }
-
- // Send and receive packets, and verify they are received.
- buf[dstAddrOffset] = localAddrByte
- testRecv(t, fakeNet, localAddrByte, ep, buf)
- testSend(t, r, ep, nil)
- testSendTo(t, s, remoteAddr, ep, nil)
-
- // Remove the address, then check that send/receive doesn't work anymore.
- if err := s.RemoveAddress(1, localAddr); err != nil {
- t.Fatal("RemoveAddress failed:", err)
- }
- testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
- testFailingSend(t, r, nil, &tcpip.ErrInvalidEndpointState{})
- testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{})
-
- // Check that removing the same address fails.
- {
- err := s.RemoveAddress(1, localAddr)
- if _, ok := err.(*tcpip.ErrBadLocalAddress); !ok {
- t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, &tcpip.ErrBadLocalAddress{})
- }
- }
-}
-
-func verifyAddress(t *testing.T, s *stack.Stack, nicID tcpip.NICID, addr tcpip.Address) {
- t.Helper()
- info, ok := s.NICInfo()[nicID]
- if !ok {
- t.Fatalf("NICInfo() failed to find nicID=%d", nicID)
- }
- if len(addr) == 0 {
- // No address given, verify that there is no address assigned to the NIC.
- for _, a := range info.ProtocolAddresses {
- if a.Protocol == fakeNetNumber && a.AddressWithPrefix != (tcpip.AddressWithPrefix{}) {
- t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, tcpip.AddressWithPrefix{})
- }
- }
- return
- }
- // Address given, verify the address is assigned to the NIC and no other
- // address is.
- found := false
- for _, a := range info.ProtocolAddresses {
- if a.Protocol == fakeNetNumber {
- if a.AddressWithPrefix.Address == addr {
- found = true
- } else {
- t.Errorf("verify address: got = %s, want = %s", a.AddressWithPrefix.Address, addr)
- }
- }
- }
- if !found {
- t.Errorf("verify address: couldn't find %s on the NIC", addr)
- }
-}
-
-func TestEndpointExpiration(t *testing.T) {
- const (
- localAddrByte byte = 0x01
- remoteAddr tcpip.Address = "\x03"
- noAddr tcpip.Address = ""
- nicID tcpip.NICID = 1
- )
- localAddr := tcpip.Address([]byte{localAddrByte})
-
- for _, promiscuous := range []bool{true, false} {
- for _, spoofing := range []bool{true, false} {
- t.Run(fmt.Sprintf("promiscuous=%t spoofing=%t", promiscuous, spoofing), func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
-
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- {
- subnet, err := tcpip.NewSubnet("\x00", "\x00")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
- }
-
- fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
- buf := buffer.NewView(30)
- buf[dstAddrOffset] = localAddrByte
-
- if promiscuous {
- if err := s.SetPromiscuousMode(nicID, true); err != nil {
- t.Fatal("SetPromiscuousMode failed:", err)
- }
- }
-
- if spoofing {
- if err := s.SetSpoofing(nicID, true); err != nil {
- t.Fatal("SetSpoofing failed:", err)
- }
- }
-
- // 1. No Address yet, send should only work for spoofing, receive for
- // promiscuous mode.
- //-----------------------
- verifyAddress(t, s, nicID, noAddr)
- if promiscuous {
- testRecv(t, fakeNet, localAddrByte, ep, buf)
- } else {
- testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
- }
- if spoofing {
- // FIXME(b/139841518):Spoofing doesn't work if there is no primary address.
- // testSendTo(t, s, remoteAddr, ep, nil)
- } else {
- testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{})
- }
-
- // 2. Add Address, everything should work.
- //-----------------------
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: localAddr,
- PrefixLen: fakeDefaultPrefixLen,
- },
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
- verifyAddress(t, s, nicID, localAddr)
- testRecv(t, fakeNet, localAddrByte, ep, buf)
- testSendTo(t, s, remoteAddr, ep, nil)
-
- // 3. Remove the address, send should only work for spoofing, receive
- // for promiscuous mode.
- //-----------------------
- if err := s.RemoveAddress(nicID, localAddr); err != nil {
- t.Fatal("RemoveAddress failed:", err)
- }
- verifyAddress(t, s, nicID, noAddr)
- if promiscuous {
- testRecv(t, fakeNet, localAddrByte, ep, buf)
- } else {
- testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
- }
- if spoofing {
- // FIXME(b/139841518):Spoofing doesn't work if there is no primary address.
- // testSendTo(t, s, remoteAddr, ep, nil)
- } else {
- testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{})
- }
-
- // 4. Add Address back, everything should work again.
- //-----------------------
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
- verifyAddress(t, s, nicID, localAddr)
- testRecv(t, fakeNet, localAddrByte, ep, buf)
- testSendTo(t, s, remoteAddr, ep, nil)
-
- // 5. Take a reference to the endpoint by getting a route. Verify that
- // we can still send/receive, including sending using the route.
- //-----------------------
- r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatal("FindRoute failed:", err)
- }
- testRecv(t, fakeNet, localAddrByte, ep, buf)
- testSendTo(t, s, remoteAddr, ep, nil)
- testSend(t, r, ep, nil)
-
- // 6. Remove the address. Send should only work for spoofing, receive
- // for promiscuous mode.
- //-----------------------
- if err := s.RemoveAddress(nicID, localAddr); err != nil {
- t.Fatal("RemoveAddress failed:", err)
- }
- verifyAddress(t, s, nicID, noAddr)
- if promiscuous {
- testRecv(t, fakeNet, localAddrByte, ep, buf)
- } else {
- testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
- }
- if spoofing {
- testSend(t, r, ep, nil)
- testSendTo(t, s, remoteAddr, ep, nil)
- } else {
- testFailingSend(t, r, nil, &tcpip.ErrInvalidEndpointState{})
- testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{})
- }
-
- // 7. Add Address back, everything should work again.
- //-----------------------
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
- verifyAddress(t, s, nicID, localAddr)
- testRecv(t, fakeNet, localAddrByte, ep, buf)
- testSendTo(t, s, remoteAddr, ep, nil)
- testSend(t, r, ep, nil)
-
- // 8. Remove the route, sendTo/recv should still work.
- //-----------------------
- r.Release()
- verifyAddress(t, s, nicID, localAddr)
- testRecv(t, fakeNet, localAddrByte, ep, buf)
- testSendTo(t, s, remoteAddr, ep, nil)
-
- // 9. Remove the address. Send should only work for spoofing, receive
- // for promiscuous mode.
- //-----------------------
- if err := s.RemoveAddress(nicID, localAddr); err != nil {
- t.Fatal("RemoveAddress failed:", err)
- }
- verifyAddress(t, s, nicID, noAddr)
- if promiscuous {
- testRecv(t, fakeNet, localAddrByte, ep, buf)
- } else {
- testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
- }
- if spoofing {
- // FIXME(b/139841518):Spoofing doesn't work if there is no primary address.
- // testSendTo(t, s, remoteAddr, ep, nil)
- } else {
- testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{})
- }
- })
- }
- }
-}
-
-func TestPromiscuousMode(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
-
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- {
- subnet, err := tcpip.NewSubnet("\x00", "\x00")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
- }
-
- fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
-
- buf := buffer.NewView(30)
-
- // Write a packet, and check that it doesn't get delivered as we don't
- // have a matching endpoint.
- const localAddrByte byte = 0x01
- buf[dstAddrOffset] = localAddrByte
- testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
-
- // Set promiscuous mode, then check that packet is delivered.
- if err := s.SetPromiscuousMode(1, true); err != nil {
- t.Fatal("SetPromiscuousMode failed:", err)
- }
- testRecv(t, fakeNet, localAddrByte, ep, buf)
-
- // Check that we can't get a route as there is no local address.
- _, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */)
- if _, ok := err.(*tcpip.ErrNoRoute); !ok {
- t.Fatalf("FindRoute returned unexpected error: got = %v, want = %s", err, &tcpip.ErrNoRoute{})
- }
-
- // Set promiscuous mode to false, then check that packet can't be
- // delivered anymore.
- if err := s.SetPromiscuousMode(1, false); err != nil {
- t.Fatal("SetPromiscuousMode failed:", err)
- }
- testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
-}
-
-// TestExternalSendWithHandleLocal tests that the stack creates a non-local
-// route when spoofing or promiscuous mode are enabled.
-//
-// This test makes sure that packets are transmitted from the stack.
-func TestExternalSendWithHandleLocal(t *testing.T) {
- const (
- unspecifiedNICID = 0
- nicID = 1
-
- localAddr = tcpip.Address("\x01")
- dstAddr = tcpip.Address("\x03")
- )
-
- subnet, err := tcpip.NewSubnet("\x00", "\x00")
- if err != nil {
- t.Fatal(err)
- }
-
- tests := []struct {
- name string
- configureStack func(*testing.T, *stack.Stack)
- }{
- {
- name: "Default",
- configureStack: func(*testing.T, *stack.Stack) {},
- },
- {
- name: "Spoofing",
- configureStack: func(t *testing.T, s *stack.Stack) {
- if err := s.SetSpoofing(nicID, true); err != nil {
- t.Fatalf("s.SetSpoofing(%d, true): %s", nicID, err)
- }
- },
- },
- {
- name: "Promiscuous",
- configureStack: func(t *testing.T, s *stack.Stack) {
- if err := s.SetPromiscuousMode(nicID, true); err != nil {
- t.Fatalf("s.SetPromiscuousMode(%d, true): %s", nicID, err)
- }
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- for _, handleLocal := range []bool{true, false} {
- t.Run(fmt.Sprintf("HandleLocal=%t", handleLocal), func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- HandleLocal: handleLocal,
- })
-
- ep := channel.New(1, defaultMTU, "")
- if err := s.CreateNIC(nicID, ep); err != nil {
- t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
- }
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: localAddr,
- PrefixLen: fakeDefaultPrefixLen,
- },
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
-
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: nicID}})
-
- test.configureStack(t, s)
-
- r, err := s.FindRoute(unspecifiedNICID, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", unspecifiedNICID, localAddr, dstAddr, fakeNetNumber, err)
- }
- defer r.Release()
-
- if r.LocalAddress() != localAddr {
- t.Errorf("got r.LocalAddress() = %s, want = %s", r.LocalAddress(), localAddr)
- }
- if r.RemoteAddress() != dstAddr {
- t.Errorf("got r.RemoteAddress() = %s, want = %s", r.RemoteAddress(), dstAddr)
- }
-
- if n := ep.Drain(); n != 0 {
- t.Fatalf("got ep.Drain() = %d, want = 0", n)
- }
- if err := r.WritePacket(stack.NetworkHeaderParams{
- Protocol: fakeTransNumber,
- TTL: 123,
- TOS: stack.DefaultTOS,
- }, stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength()),
- Data: buffer.NewView(10).ToVectorisedView(),
- })); err != nil {
- t.Fatalf("r.WritePacket(nil, _, _): %s", err)
- }
- if n := ep.Drain(); n != 1 {
- t.Fatalf("got ep.Drain() = %d, want = 1", n)
- }
- })
- }
- })
- }
-}
-
-func TestSpoofingWithAddress(t *testing.T) {
- localAddr := tcpip.Address("\x01")
- nonExistentLocalAddr := tcpip.Address("\x02")
- dstAddr := tcpip.Address("\x03")
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
-
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: localAddr,
- PrefixLen: fakeDefaultPrefixLen,
- },
- }
- if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
- }
-
- {
- subnet, err := tcpip.NewSubnet("\x00", "\x00")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
- }
-
- // With address spoofing disabled, FindRoute does not permit an address
- // that was not added to the NIC to be used as the source.
- r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
- if err == nil {
- t.Errorf("FindRoute succeeded with route %+v when it should have failed", r)
- }
-
- // With address spoofing enabled, FindRoute permits any address to be used
- // as the source.
- if err := s.SetSpoofing(1, true); err != nil {
- t.Fatal("SetSpoofing failed:", err)
- }
- r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatal("FindRoute failed:", err)
- }
- if r.LocalAddress() != nonExistentLocalAddr {
- t.Errorf("got Route.LocalAddress() = %s, want = %s", r.LocalAddress(), nonExistentLocalAddr)
- }
- if r.RemoteAddress() != dstAddr {
- t.Errorf("got Route.RemoteAddress() = %s, want = %s", r.RemoteAddress(), dstAddr)
- }
- // Sending a packet works.
- testSendTo(t, s, dstAddr, ep, nil)
- testSend(t, r, ep, nil)
-
- // FindRoute should also work with a local address that exists on the NIC.
- r, err = s.FindRoute(0, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatal("FindRoute failed:", err)
- }
- if r.LocalAddress() != localAddr {
- t.Errorf("got Route.LocalAddress() = %s, want = %s", r.LocalAddress(), nonExistentLocalAddr)
- }
- if r.RemoteAddress() != dstAddr {
- t.Errorf("got Route.RemoteAddress() = %s, want = %s", r.RemoteAddress(), dstAddr)
- }
- // Sending a packet using the route works.
- testSend(t, r, ep, nil)
-}
-
-func TestSpoofingNoAddress(t *testing.T) {
- nonExistentLocalAddr := tcpip.Address("\x01")
- dstAddr := tcpip.Address("\x02")
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
-
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- {
- subnet, err := tcpip.NewSubnet("\x00", "\x00")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
- }
-
- // With address spoofing disabled, FindRoute does not permit an address
- // that was not added to the NIC to be used as the source.
- r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
- if err == nil {
- t.Errorf("FindRoute succeeded with route %+v when it should have failed", r)
- }
- // Sending a packet fails.
- testFailingSendTo(t, s, dstAddr, nil, &tcpip.ErrNoRoute{})
-
- // With address spoofing enabled, FindRoute permits any address to be used
- // as the source.
- if err := s.SetSpoofing(1, true); err != nil {
- t.Fatal("SetSpoofing failed:", err)
- }
- r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatal("FindRoute failed:", err)
- }
- if r.LocalAddress() != nonExistentLocalAddr {
- t.Errorf("got Route.LocalAddress() = %s, want = %s", r.LocalAddress(), nonExistentLocalAddr)
- }
- if r.RemoteAddress() != dstAddr {
- t.Errorf("got Route.RemoteAddress() = %s, want = %s", r.RemoteAddress(), dstAddr)
- }
- // Sending a packet works.
- // FIXME(b/139841518):Spoofing doesn't work if there is no primary address.
- // testSendTo(t, s, remoteAddr, ep, nil)
-}
-
-func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
-
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
- s.SetRouteTable([]tcpip.Route{})
-
- // If there is no endpoint, it won't work.
- {
- _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
- if _, ok := err.(*tcpip.ErrNetworkUnreachable); !ok {
- t.Fatalf("got FindRoute(1, %s, %s, %d) = %s, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, &tcpip.ErrNetworkUnreachable{})
- }
- }
-
- protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{Address: header.IPv4Any}}
- if err := s.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(1, %+v, {}) failed: %s", protoAddr, err)
- }
- r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
- }
- if r.LocalAddress() != header.IPv4Any {
- t.Errorf("got Route.LocalAddress() = %s, want = %s", r.LocalAddress(), header.IPv4Any)
- }
-
- if r.RemoteAddress() != header.IPv4Broadcast {
- t.Errorf("got Route.RemoteAddress() = %s, want = %s", r.RemoteAddress(), header.IPv4Broadcast)
- }
-
- // If the NIC doesn't exist, it won't work.
- {
- _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
- if _, ok := err.(*tcpip.ErrNetworkUnreachable); !ok {
- t.Fatalf("got FindRoute(2, %v, %v, %d) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, &tcpip.ErrNetworkUnreachable{})
- }
- }
-}
-
-func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
- defaultAddr := tcpip.AddressWithPrefix{Address: header.IPv4Any}
- // Local subnet on NIC1: 192.168.1.58/24, gateway 192.168.1.1.
- nic1Addr := tcpip.AddressWithPrefix{Address: "\xc0\xa8\x01\x3a", PrefixLen: 24}
- nic1Gateway := testutil.MustParse4("192.168.1.1")
- // Local subnet on NIC2: 10.10.10.5/24, gateway 10.10.10.1.
- nic2Addr := tcpip.AddressWithPrefix{Address: "\x0a\x0a\x0a\x05", PrefixLen: 24}
- nic2Gateway := testutil.MustParse4("10.10.10.1")
-
- // Create a new stack with two NICs.
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatalf("CreateNIC failed: %s", err)
- }
- if err := s.CreateNIC(2, ep); err != nil {
- t.Fatalf("CreateNIC failed: %s", err)
- }
- nic1ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic1Addr}
- if err := s.AddProtocolAddress(1, nic1ProtoAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(1, %+v, {}) failed: %s", nic1ProtoAddr, err)
- }
-
- nic2ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic2Addr}
- if err := s.AddProtocolAddress(2, nic2ProtoAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(2, %+v, {}) failed: %s", nic2ProtoAddr, err)
- }
-
- // Set the initial route table.
- rt := []tcpip.Route{
- {Destination: nic1Addr.Subnet(), NIC: 1},
- {Destination: nic2Addr.Subnet(), NIC: 2},
- {Destination: defaultAddr.Subnet(), Gateway: nic2Gateway, NIC: 2},
- {Destination: defaultAddr.Subnet(), Gateway: nic1Gateway, NIC: 1},
- }
- s.SetRouteTable(rt)
-
- // When an interface is given, the route for a broadcast goes through it.
- r, err := s.FindRoute(1, nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
- }
- if r.LocalAddress() != nic1Addr.Address {
- t.Errorf("got Route.LocalAddress() = %s, want = %s", r.LocalAddress(), nic1Addr.Address)
- }
-
- if r.RemoteAddress() != header.IPv4Broadcast {
- t.Errorf("got Route.RemoteAddress() = %s, want = %s", r.RemoteAddress(), header.IPv4Broadcast)
- }
-
- // When an interface is not given, it consults the route table.
- // 1. Case: Using the default route.
- r, err = s.FindRoute(0, "", header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err)
- }
- if r.LocalAddress() != nic2Addr.Address {
- t.Errorf("got Route.LocalAddress() = %s, want = %s", r.LocalAddress(), nic2Addr.Address)
- }
-
- if r.RemoteAddress() != header.IPv4Broadcast {
- t.Errorf("got Route.RemoteAddress() = %s, want = %s", r.RemoteAddress(), header.IPv4Broadcast)
- }
-
- // 2. Case: Having an explicit route for broadcast will select that one.
- rt = append(
- []tcpip.Route{
- {Destination: header.IPv4Broadcast.WithPrefix().Subnet(), NIC: 1},
- },
- rt...,
- )
- s.SetRouteTable(rt)
- r, err = s.FindRoute(0, "", header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err)
- }
- if r.LocalAddress() != nic1Addr.Address {
- t.Errorf("got Route.LocalAddress() = %s, want = %s", r.LocalAddress(), nic1Addr.Address)
- }
-
- if r.RemoteAddress() != header.IPv4Broadcast {
- t.Errorf("got Route.RemoteAddress() = %s, want = %s", r.RemoteAddress(), header.IPv4Broadcast)
- }
-}
-
-func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
- for _, tc := range []struct {
- name string
- routeNeeded bool
- address tcpip.Address
- }{
- // IPv4 multicast address range: 224.0.0.0 - 239.255.255.255
- // <=> 0xe0.0x00.0x00.0x00 - 0xef.0xff.0xff.0xff
- {"IPv4 Multicast 1", false, "\xe0\x00\x00\x00"},
- {"IPv4 Multicast 2", false, "\xef\xff\xff\xff"},
- {"IPv4 Unicast 1", true, "\xdf\xff\xff\xff"},
- {"IPv4 Unicast 2", true, "\xf0\x00\x00\x00"},
- {"IPv4 Unicast 3", true, "\x00\x00\x00\x00"},
-
- // IPv6 multicast address is 0xff[8] + flags[4] + scope[4] + groupId[112]
- {"IPv6 Multicast 1", false, "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
- {"IPv6 Multicast 2", false, "\xff\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
- {"IPv6 Multicast 3", false, "\xff\x0f\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"},
-
- // IPv6 link-local address starts with fe80::/10.
- {"IPv6 Unicast Link-Local 1", false, "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
- {"IPv6 Unicast Link-Local 2", false, "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"},
- {"IPv6 Unicast Link-Local 3", false, "\xfe\x80\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff"},
- {"IPv6 Unicast Link-Local 4", false, "\xfe\xbf\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
- {"IPv6 Unicast Link-Local 5", false, "\xfe\xbf\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"},
-
- // IPv6 addresses that are neither multicast nor link-local.
- {"IPv6 Unicast Not Link-Local 1", true, "\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
- {"IPv6 Unicast Not Link-Local 2", true, "\xf0\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"},
- {"IPv6 Unicast Not Link-local 3", true, "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
- {"IPv6 Unicast Not Link-Local 4", true, "\xfe\xc0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
- {"IPv6 Unicast Not Link-Local 5", true, "\xfe\xdf\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
- {"IPv6 Unicast Not Link-Local 6", true, "\xfd\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
- {"IPv6 Unicast Not Link-Local 7", true, "\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
- } {
- t.Run(tc.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
-
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- s.SetRouteTable([]tcpip.Route{})
-
- var anyAddr tcpip.Address
- if len(tc.address) == header.IPv4AddressSize {
- anyAddr = header.IPv4Any
- } else {
- anyAddr = header.IPv6Any
- }
-
- var want tcpip.Error = &tcpip.ErrNetworkUnreachable{}
- if tc.routeNeeded {
- want = &tcpip.ErrNoRoute{}
- }
-
- // If there is no endpoint, it won't work.
- if _, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); err != want {
- t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, want)
- }
-
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: anyAddr,
- PrefixLen: fakeDefaultPrefixLen,
- },
- }
- if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
- }
-
- if r, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); tc.routeNeeded {
- // Route table is empty but we need a route, this should cause an error.
- if _, ok := err.(*tcpip.ErrNoRoute); !ok {
- t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, &tcpip.ErrNoRoute{})
- }
- } else {
- if err != nil {
- t.Fatalf("FindRoute(1, %v, %v, %v) failed: %v", anyAddr, tc.address, fakeNetNumber, err)
- }
- if r.LocalAddress() != anyAddr {
- t.Errorf("Bad local address: got %v, want = %v", r.LocalAddress(), anyAddr)
- }
- if r.RemoteAddress() != tc.address {
- t.Errorf("Bad remote address: got %v, want = %v", r.RemoteAddress(), tc.address)
- }
- }
- // If the NIC doesn't exist, it won't work.
- if _, err := s.FindRoute(2, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); err != want {
- t.Fatalf("got FindRoute(2, %v, %v, %v) = %v want = %v", anyAddr, tc.address, fakeNetNumber, err, want)
- }
- })
- }
-}
-
-func TestNetworkOption(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- TransportProtocols: []stack.TransportProtocolFactory{},
- })
-
- opt := tcpip.DefaultTTLOption(5)
- if err := s.SetNetworkProtocolOption(fakeNetNumber, &opt); err != nil {
- t.Fatalf("s.SetNetworkProtocolOption(%d, &%T(%d)): %s", fakeNetNumber, opt, opt, err)
- }
-
- var optGot tcpip.DefaultTTLOption
- if err := s.NetworkProtocolOption(fakeNetNumber, &optGot); err != nil {
- t.Fatalf("s.NetworkProtocolOption(%d, &%T): %s", fakeNetNumber, optGot, err)
- }
-
- if opt != optGot {
- t.Errorf("got optGot = %d, want = %d", optGot, opt)
- }
-}
-
-func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
- const nicID = 1
-
- for _, addrLen := range []int{4, 16} {
- t.Run(fmt.Sprintf("addrLen=%d", addrLen), func(t *testing.T) {
- for canBe := 0; canBe < 3; canBe++ {
- t.Run(fmt.Sprintf("canBe=%d", canBe), func(t *testing.T) {
- for never := 0; never < 3; never++ {
- t.Run(fmt.Sprintf("never=%d", never), func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID, ep); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
- // Insert <canBe> primary and <never> never-primary addresses.
- // Each one will add a network endpoint to the NIC.
- primaryAddrAdded := make(map[tcpip.AddressWithPrefix]struct{})
- for i := 0; i < canBe+never; i++ {
- var behavior stack.PrimaryEndpointBehavior
- if i < canBe {
- behavior = stack.CanBePrimaryEndpoint
- } else {
- behavior = stack.NeverPrimaryEndpoint
- }
- // Add an address and in case of a primary one include a
- // prefixLen.
- address := tcpip.Address(bytes.Repeat([]byte{byte(i)}, addrLen))
- properties := stack.AddressProperties{PEB: behavior}
- if behavior == stack.CanBePrimaryEndpoint {
- protocolAddress := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: address.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddress, properties); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddress, properties, err)
- }
- // Remember the address/prefix.
- primaryAddrAdded[protocolAddress.AddressWithPrefix] = struct{}{}
- } else {
- protocolAddress := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: address,
- PrefixLen: fakeDefaultPrefixLen,
- },
- }
- if err := s.AddProtocolAddress(nicID, protocolAddress, properties); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddress, properties, err)
- }
- }
- }
- // Check that GetMainNICAddress returns an address if at least
- // one primary address was added. In that case make sure the
- // address/prefixLen matches what we added.
- gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber)
- if err != nil {
- t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err)
- }
- if len(primaryAddrAdded) == 0 {
- // No primary addresses present.
- if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr {
- t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, wantAddr)
- }
- } else {
- // At least one primary address was added, verify the returned
- // address is in the list of primary addresses we added.
- if _, ok := primaryAddrAdded[gotAddr]; !ok {
- t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, primaryAddrAdded)
- }
- }
- })
- }
- })
- }
- })
- }
-}
-
-func TestGetMainNICAddressErrors(t *testing.T) {
- const nicID = 1
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol},
- })
- if err := s.CreateNIC(nicID, loopback.New()); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
-
- // Sanity check with a successful call.
- if addr, err := s.GetMainNICAddress(nicID, ipv4.ProtocolNumber); err != nil {
- t.Errorf("s.GetMainNICAddress(%d, %d): %s", nicID, ipv4.ProtocolNumber, err)
- } else if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, ipv4.ProtocolNumber, addr, want)
- }
-
- const unknownNICID = nicID + 1
- switch addr, err := s.GetMainNICAddress(unknownNICID, ipv4.ProtocolNumber); err.(type) {
- case *tcpip.ErrUnknownNICID:
- default:
- t.Errorf("got s.GetMainNICAddress(%d, %d) = (%s, %T), want = (_, tcpip.ErrUnknownNICID)", unknownNICID, ipv4.ProtocolNumber, addr, err)
- }
-
- // ARP is not an addressable network endpoint.
- switch addr, err := s.GetMainNICAddress(nicID, arp.ProtocolNumber); err.(type) {
- case *tcpip.ErrNotSupported:
- default:
- t.Errorf("got s.GetMainNICAddress(%d, %d) = (%s, %T), want = (_, tcpip.ErrNotSupported)", nicID, arp.ProtocolNumber, addr, err)
- }
-
- const unknownProtocolNumber = 1234
- switch addr, err := s.GetMainNICAddress(nicID, unknownProtocolNumber); err.(type) {
- case *tcpip.ErrUnknownProtocol:
- default:
- t.Errorf("got s.GetMainNICAddress(%d, %d) = (%s, %T), want = (_, tcpip.ErrUnknownProtocol)", nicID, unknownProtocolNumber, addr, err)
- }
-}
-
-func TestGetMainNICAddressAddRemove(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- for _, tc := range []struct {
- name string
- address tcpip.Address
- prefixLen int
- }{
- {"IPv4", "\x01\x01\x01\x01", 24},
- {"IPv6", "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", 116},
- } {
- t.Run(tc.name, func(t *testing.T) {
- protocolAddress := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: tc.address,
- PrefixLen: tc.prefixLen,
- },
- }
- if err := s.AddProtocolAddress(1, protocolAddress, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(1, %+v, {}): %s", protocolAddress, err)
- }
-
- // Check that we get the right initial address and prefix length.
- if err := checkGetMainNICAddress(s, 1, fakeNetNumber, protocolAddress.AddressWithPrefix); err != nil {
- t.Fatal(err)
- }
-
- if err := s.RemoveAddress(1, protocolAddress.AddressWithPrefix.Address); err != nil {
- t.Fatal("RemoveAddress failed:", err)
- }
-
- // Check that we get no address after removal.
- if err := checkGetMainNICAddress(s, 1, fakeNetNumber, tcpip.AddressWithPrefix{}); err != nil {
- t.Fatal(err)
- }
- })
- }
-}
-
-// Simple network address generator. Good for 255 addresses.
-type addressGenerator struct{ cnt byte }
-
-func (g *addressGenerator) next(addrLen int) tcpip.Address {
- g.cnt++
- return tcpip.Address(bytes.Repeat([]byte{g.cnt}, addrLen))
-}
-
-func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.ProtocolAddress) {
- t.Helper()
-
- if len(gotAddresses) != len(expectedAddresses) {
- t.Fatalf("got len(addresses) = %d, want = %d", len(gotAddresses), len(expectedAddresses))
- }
-
- sort.Slice(gotAddresses, func(i, j int) bool {
- return gotAddresses[i].AddressWithPrefix.Address < gotAddresses[j].AddressWithPrefix.Address
- })
- sort.Slice(expectedAddresses, func(i, j int) bool {
- return expectedAddresses[i].AddressWithPrefix.Address < expectedAddresses[j].AddressWithPrefix.Address
- })
-
- for i, gotAddr := range gotAddresses {
- expectedAddr := expectedAddresses[i]
- if gotAddr != expectedAddr {
- t.Errorf("got address = %+v, wanted = %+v", gotAddr, expectedAddr)
- }
- }
-}
-
-func TestAddProtocolAddress(t *testing.T) {
- const nicID = 1
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- addrLenRange := []int{4, 16}
- behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint}
- configTypeRange := []stack.AddressConfigType{stack.AddressConfigStatic, stack.AddressConfigSlaac, stack.AddressConfigSlaacTemp}
- deprecatedRange := []bool{false, true}
- wantAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(behaviorRange)*len(configTypeRange)*len(deprecatedRange))
- var addrGen addressGenerator
- for _, addrLen := range addrLenRange {
- for _, behavior := range behaviorRange {
- for _, configType := range configTypeRange {
- for _, deprecated := range deprecatedRange {
- address := addrGen.next(addrLen)
- properties := stack.AddressProperties{
- PEB: behavior,
- ConfigType: configType,
- Deprecated: deprecated,
- }
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: address,
- PrefixLen: fakeDefaultPrefixLen,
- },
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, properties); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, %+v) failed: %s", nicID, protocolAddr, properties, err)
- }
- wantAddresses = append(wantAddresses, tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen},
- })
- }
- }
- }
- }
-
- gotAddresses := s.AllAddresses()[nicID]
- verifyAddresses(t, wantAddresses, gotAddresses)
-}
-
-func TestCreateNICWithOptions(t *testing.T) {
- type callArgsAndExpect struct {
- nicID tcpip.NICID
- opts stack.NICOptions
- err tcpip.Error
- }
-
- tests := []struct {
- desc string
- calls []callArgsAndExpect
- }{
- {
- desc: "DuplicateNICID",
- calls: []callArgsAndExpect{
- {
- nicID: tcpip.NICID(1),
- opts: stack.NICOptions{Name: "eth1"},
- err: nil,
- },
- {
- nicID: tcpip.NICID(1),
- opts: stack.NICOptions{Name: "eth2"},
- err: &tcpip.ErrDuplicateNICID{},
- },
- },
- },
- {
- desc: "DuplicateName",
- calls: []callArgsAndExpect{
- {
- nicID: tcpip.NICID(1),
- opts: stack.NICOptions{Name: "lo"},
- err: nil,
- },
- {
- nicID: tcpip.NICID(2),
- opts: stack.NICOptions{Name: "lo"},
- err: &tcpip.ErrDuplicateNICID{},
- },
- },
- },
- {
- desc: "Unnamed",
- calls: []callArgsAndExpect{
- {
- nicID: tcpip.NICID(1),
- opts: stack.NICOptions{},
- err: nil,
- },
- {
- nicID: tcpip.NICID(2),
- opts: stack.NICOptions{},
- err: nil,
- },
- },
- },
- {
- desc: "UnnamedDuplicateNICID",
- calls: []callArgsAndExpect{
- {
- nicID: tcpip.NICID(1),
- opts: stack.NICOptions{},
- err: nil,
- },
- {
- nicID: tcpip.NICID(1),
- opts: stack.NICOptions{},
- err: &tcpip.ErrDuplicateNICID{},
- },
- },
- },
- }
- for _, test := range tests {
- t.Run(test.desc, func(t *testing.T) {
- s := stack.New(stack.Options{})
- ep := channel.New(0, 0, "\x00\x00\x00\x00\x00\x00")
- for _, call := range test.calls {
- if got, want := s.CreateNICWithOptions(call.nicID, ep, call.opts), call.err; got != want {
- t.Fatalf("CreateNICWithOptions(%v, _, %+v) = %v, want %v", call.nicID, call.opts, got, want)
- }
- }
- })
- }
-}
-
-func TestNICStats(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
-
- nics := []struct {
- addr tcpip.Address
- txByteCount int
- rxByteCount int
- }{
- {
- addr: "\x01",
- txByteCount: 30,
- rxByteCount: 10,
- },
- {
- addr: "\x02",
- txByteCount: 50,
- rxByteCount: 20,
- },
- }
-
- var txBytesTotal, rxBytesTotal, txPacketsTotal, rxPacketsTotal int
- for i, nic := range nics {
- nicid := tcpip.NICID(i)
- ep := channel.New(1, defaultMTU, "")
- if err := s.CreateNIC(nicid, ep); err != nil {
- t.Fatal("CreateNIC failed: ", err)
- }
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: nic.addr,
- PrefixLen: fakeDefaultPrefixLen,
- },
- }
- if err := s.AddProtocolAddress(nicid, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicid, protocolAddr, err)
- }
-
- {
- subnet, err := tcpip.NewSubnet(nic.addr, "\xff")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicid}})
- }
-
- nicStats := s.NICInfo()[nicid].Stats
-
- // Inbound packet.
- rxBuffer := buffer.NewView(nic.rxByteCount)
- ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: rxBuffer.ToVectorisedView(),
- }))
- if got, want := nicStats.Rx.Packets.Value(), uint64(1); got != want {
- t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want)
- }
- if got, want := nicStats.Rx.Bytes.Value(), uint64(nic.rxByteCount); got != want {
- t.Errorf("got Rx.Bytes.Value() = %d, want = %d", got, want)
- }
- rxPacketsTotal++
- rxBytesTotal += nic.rxByteCount
-
- // Outbound packet.
- txBuffer := buffer.NewView(nic.txByteCount)
- actualTxLength := nic.txByteCount + fakeNetHeaderLen
- if err := sendTo(s, nic.addr, txBuffer); err != nil {
- t.Fatal("sendTo failed: ", err)
- }
- want := ep.Drain()
- if got := nicStats.Tx.Packets.Value(); got != uint64(want) {
- t.Errorf("got Tx.Packets.Value() = %d, ep.Drain() = %d", got, want)
- }
- if got, want := nicStats.Tx.Bytes.Value(), uint64(actualTxLength); got != want {
- t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
- }
- txPacketsTotal += want
- txBytesTotal += actualTxLength
- }
-
- // Now verify that each NIC stats was correctly aggregated at the stack level.
- if got, want := s.Stats().NICs.Rx.Packets.Value(), uint64(rxPacketsTotal); got != want {
- t.Errorf("got s.Stats().NIC.Rx.Packets.Value() = %d, want = %d", got, want)
- }
- if got, want := s.Stats().NICs.Rx.Bytes.Value(), uint64(rxBytesTotal); got != want {
- t.Errorf("got s.Stats().Rx.Bytes.Value() = %d, want = %d", got, want)
- }
- if got, want := s.Stats().NICs.Tx.Packets.Value(), uint64(txPacketsTotal); got != want {
- t.Errorf("got Tx.Packets.Value() = %d, ep.Drain() = %d", got, want)
- }
- if got, want := s.Stats().NICs.Tx.Bytes.Value(), uint64(txBytesTotal); got != want {
- t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
- }
-}
-
-// TestNICContextPreservation tests that you can read out via stack.NICInfo the
-// Context data you pass via NICContext.Context in stack.CreateNICWithOptions.
-func TestNICContextPreservation(t *testing.T) {
- var ctx *int
- tests := []struct {
- name string
- opts stack.NICOptions
- want stack.NICContext
- }{
- {
- "context_set",
- stack.NICOptions{Context: ctx},
- ctx,
- },
- {
- "context_not_set",
- stack.NICOptions{},
- nil,
- },
- }
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{})
- id := tcpip.NICID(1)
- ep := channel.New(0, 0, "\x00\x00\x00\x00\x00\x00")
- if err := s.CreateNICWithOptions(id, ep, test.opts); err != nil {
- t.Fatalf("got stack.CreateNICWithOptions(%d, %+v, %+v) = %s, want nil", id, ep, test.opts, err)
- }
- nicinfos := s.NICInfo()
- nicinfo, ok := nicinfos[id]
- if !ok {
- t.Fatalf("got nicinfos[%d] = _, %t, want _, true; nicinfos = %+v", id, ok, nicinfos)
- }
- if got, want := nicinfo.Context == test.want, true; got != want {
- t.Fatalf("got nicinfo.Context == ctx = %t, want %t; nicinfo.Context = %p, ctx = %p", got, want, nicinfo.Context, test.want)
- }
- })
- }
-}
-
-// TestNICAutoGenLinkLocalAddr tests the auto-generation of IPv6 link-local
-// addresses.
-func TestNICAutoGenLinkLocalAddr(t *testing.T) {
- const nicID = 1
-
- var secretKey [header.OpaqueIIDSecretKeyMinBytes]byte
- n, err := rand.Read(secretKey[:])
- if err != nil {
- t.Fatalf("rand.Read(_): %s", err)
- }
- if n != header.OpaqueIIDSecretKeyMinBytes {
- t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", header.OpaqueIIDSecretKeyMinBytes, n)
- }
-
- nicNameFunc := func(_ tcpip.NICID, name string) string {
- return name
- }
-
- tests := []struct {
- name string
- nicName string
- autoGen bool
- linkAddr tcpip.LinkAddress
- iidOpts ipv6.OpaqueInterfaceIdentifierOptions
- shouldGen bool
- expectedAddr tcpip.Address
- }{
- {
- name: "Disabled",
- nicName: "nic1",
- autoGen: false,
- linkAddr: linkAddr1,
- shouldGen: false,
- },
- {
- name: "Disabled without OIID options",
- nicName: "nic1",
- autoGen: false,
- linkAddr: linkAddr1,
- iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{
- NICNameFromID: nicNameFunc,
- SecretKey: secretKey[:],
- },
- shouldGen: false,
- },
-
- // Tests for EUI64 based addresses.
- {
- name: "EUI64 Enabled",
- autoGen: true,
- linkAddr: linkAddr1,
- shouldGen: true,
- expectedAddr: header.LinkLocalAddr(linkAddr1),
- },
- {
- name: "EUI64 Empty MAC",
- autoGen: true,
- shouldGen: false,
- },
- {
- name: "EUI64 Invalid MAC",
- autoGen: true,
- linkAddr: "\x01\x02\x03",
- shouldGen: false,
- },
- {
- name: "EUI64 Multicast MAC",
- autoGen: true,
- linkAddr: "\x01\x02\x03\x04\x05\x06",
- shouldGen: false,
- },
- {
- name: "EUI64 Unspecified MAC",
- autoGen: true,
- linkAddr: "\x00\x00\x00\x00\x00\x00",
- shouldGen: false,
- },
-
- // Tests for Opaque IID based addresses.
- {
- name: "OIID Enabled",
- nicName: "nic1",
- autoGen: true,
- linkAddr: linkAddr1,
- iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{
- NICNameFromID: nicNameFunc,
- SecretKey: secretKey[:],
- },
- shouldGen: true,
- expectedAddr: header.LinkLocalAddrWithOpaqueIID("nic1", 0, secretKey[:]),
- },
- // These are all cases where we would not have generated a
- // link-local address if opaque IIDs were disabled.
- {
- name: "OIID Empty MAC and empty nicName",
- autoGen: true,
- iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{
- NICNameFromID: nicNameFunc,
- SecretKey: secretKey[:1],
- },
- shouldGen: true,
- expectedAddr: header.LinkLocalAddrWithOpaqueIID("", 0, secretKey[:1]),
- },
- {
- name: "OIID Invalid MAC",
- nicName: "test",
- autoGen: true,
- linkAddr: "\x01\x02\x03",
- iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{
- NICNameFromID: nicNameFunc,
- SecretKey: secretKey[:2],
- },
- shouldGen: true,
- expectedAddr: header.LinkLocalAddrWithOpaqueIID("test", 0, secretKey[:2]),
- },
- {
- name: "OIID Multicast MAC",
- nicName: "test2",
- autoGen: true,
- linkAddr: "\x01\x02\x03\x04\x05\x06",
- iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{
- NICNameFromID: nicNameFunc,
- SecretKey: secretKey[:3],
- },
- shouldGen: true,
- expectedAddr: header.LinkLocalAddrWithOpaqueIID("test2", 0, secretKey[:3]),
- },
- {
- name: "OIID Unspecified MAC and nil SecretKey",
- nicName: "test3",
- autoGen: true,
- linkAddr: "\x00\x00\x00\x00\x00\x00",
- iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{
- NICNameFromID: nicNameFunc,
- },
- shouldGen: true,
- expectedAddr: header.LinkLocalAddrWithOpaqueIID("test3", 0, nil),
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- ndpDisp := ndpDispatcher{
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
- }
- opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- AutoGenLinkLocal: test.autoGen,
- NDPDisp: &ndpDisp,
- OpaqueIIDOpts: test.iidOpts,
- })},
- }
-
- e := channel.New(0, 1280, test.linkAddr)
- s := stack.New(opts)
- nicOpts := stack.NICOptions{Name: test.nicName, Disabled: true}
- if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
- t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err)
- }
-
- // A new disabled NIC should not have any address, even if auto generation
- // was enabled.
- allStackAddrs := s.AllAddresses()
- allNICAddrs, ok := allStackAddrs[nicID]
- if !ok {
- t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
- }
- if l := len(allNICAddrs); l != 0 {
- t.Fatalf("got len(allNICAddrs) = %d, want = 0", l)
- }
-
- // Enabling the NIC should attempt auto-generation of a link-local
- // address.
- if err := s.EnableNIC(nicID); err != nil {
- t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
- }
-
- var expectedMainAddr tcpip.AddressWithPrefix
- if test.shouldGen {
- expectedMainAddr = tcpip.AddressWithPrefix{
- Address: test.expectedAddr,
- PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen,
- }
-
- // Should have auto-generated an address and resolved immediately (DAD
- // is disabled).
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, expectedMainAddr, newAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
- }
- } else {
- // Should not have auto-generated an address.
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly auto-generated an address")
- default:
- }
- }
-
- if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, expectedMainAddr); err != nil {
- t.Fatal(err)
- }
-
- // Disabling the NIC should remove the auto-generated address.
- if err := s.DisableNIC(nicID); err != nil {
- t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
- }
- if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
- t.Fatal(err)
- }
- })
- }
-}
-
-// TestNoLinkLocalAutoGenForLoopbackNIC tests that IPv6 link-local addresses are
-// not auto-generated for loopback NICs.
-func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) {
- const nicID = 1
- const nicName = "nicName"
-
- tests := []struct {
- name string
- opaqueIIDOpts ipv6.OpaqueInterfaceIdentifierOptions
- }{
- {
- name: "IID From MAC",
- opaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{},
- },
- {
- name: "Opaque IID",
- opaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{
- NICNameFromID: func(_ tcpip.NICID, nicName string) string {
- return nicName
- },
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- AutoGenLinkLocal: true,
- OpaqueIIDOpts: test.opaqueIIDOpts,
- })},
- }
-
- e := loopback.New()
- s := stack.New(opts)
- nicOpts := stack.NICOptions{Name: nicName}
- if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
- t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err)
- }
-
- if err := checkGetMainNICAddress(s, 1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
- t.Fatal(err)
- }
- })
- }
-}
-
-// TestNICAutoGenAddrDoesDAD tests that the successful auto-generation of IPv6
-// link-local addresses will only be assigned after the DAD process resolves.
-func TestNICAutoGenAddrDoesDAD(t *testing.T) {
- const nicID = 1
-
- ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent, 1),
- }
- dadConfigs := stack.DefaultDADConfigurations()
- clock := faketime.NewManualClock()
- opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- AutoGenLinkLocal: true,
- NDPDisp: &ndpDisp,
- DADConfigs: dadConfigs,
- })},
- Clock: clock,
- }
-
- e := channel.New(int(dadConfigs.DupAddrDetectTransmits), 1280, linkAddr1)
- s := stack.New(opts)
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- // Address should not be considered bound to the
- // NIC yet (DAD ongoing).
- if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
- t.Fatal(err)
- }
-
- linkLocalAddr := header.LinkLocalAddr(linkAddr1)
-
- // Wait for DAD to resolve.
- clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer)
- select {
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, linkLocalAddr, &stack.DADSucceeded{}); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
- }
- default:
- // We should get a resolution event after 1s (default time to
- // resolve as per default NDP configurations). Waiting for that
- // resolution time + an extra 1s without a resolution event
- // means something is wrong.
- t.Fatal("timed out waiting for DAD resolution")
- }
- if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); err != nil {
- t.Fatal(err)
- }
-}
-
-// TestNewPEB tests that a new PrimaryEndpointBehavior value (peb) is respected
-// when an address's kind gets "promoted" to permanent from permanentExpired.
-func TestNewPEBOnPromotionToPermanent(t *testing.T) {
- const nicID = 1
-
- pebs := []stack.PrimaryEndpointBehavior{
- stack.NeverPrimaryEndpoint,
- stack.CanBePrimaryEndpoint,
- stack.FirstPrimaryEndpoint,
- }
-
- for _, pi := range pebs {
- for _, ps := range pebs {
- t.Run(fmt.Sprintf("%d-to-%d", pi, ps), func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
- ep1 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID, ep1); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
-
- // Add a permanent address with initial
- // PrimaryEndpointBehavior (peb), pi. If pi is
- // NeverPrimaryEndpoint, the address should not
- // be returned by a call to GetMainNICAddress;
- // else, it should.
- const address1 = tcpip.Address("\x01")
- properties := stack.AddressProperties{PEB: pi}
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: address1,
- PrefixLen: fakeDefaultPrefixLen,
- },
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, properties); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddr, properties, err)
- }
- addr, err := s.GetMainNICAddress(nicID, fakeNetNumber)
- if err != nil {
- t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err)
- }
- if pi == stack.NeverPrimaryEndpoint {
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, addr, want)
-
- }
- } else if addr.Address != address1 {
- t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, addr.Address, address1)
- }
-
- {
- subnet, err := tcpip.NewSubnet("\x00", "\x00")
- if err != nil {
- t.Fatalf("NewSubnet failed: %v", err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
- }
-
- // Take a route through the address so its ref
- // count gets incremented and does not actually
- // get deleted when RemoveAddress is called
- // below. This is because we want to test that a
- // new peb is respected when an address gets
- // "promoted" to permanent from a
- // permanentExpired kind.
- const address2 = tcpip.Address("\x02")
- r, err := s.FindRoute(nicID, address1, address2, fakeNetNumber, false)
- if err != nil {
- t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, address1, address2, fakeNetNumber, err)
- }
- defer r.Release()
- if err := s.RemoveAddress(nicID, address1); err != nil {
- t.Fatalf("RemoveAddress(%d, %s): %s", nicID, address1, err)
- }
-
- //
- // At this point, the address should still be
- // known by the NIC, but have its
- // kind = permanentExpired.
- //
-
- // Add some other address with peb set to
- // FirstPrimaryEndpoint.
- const address3 = tcpip.Address("\x03")
- protocolAddr3 := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: address3,
- PrefixLen: fakeDefaultPrefixLen,
- },
- }
- properties = stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint}
- if err := s.AddProtocolAddress(nicID, protocolAddr3, properties); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddr3, properties, err)
- }
-
- // Add back the address we removed earlier and
- // make sure the new peb was respected.
- // (The address should just be promoted now).
- protocolAddr1 := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: address1,
- PrefixLen: fakeDefaultPrefixLen,
- },
- }
- properties = stack.AddressProperties{PEB: ps}
- if err := s.AddProtocolAddress(nicID, protocolAddr1, properties); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddr1, properties, err)
- }
- var primaryAddrs []tcpip.Address
- for _, pa := range s.NICInfo()[nicID].ProtocolAddresses {
- primaryAddrs = append(primaryAddrs, pa.AddressWithPrefix.Address)
- }
- var expectedList []tcpip.Address
- switch ps {
- case stack.FirstPrimaryEndpoint:
- expectedList = []tcpip.Address{
- "\x01",
- "\x03",
- }
- case stack.CanBePrimaryEndpoint:
- expectedList = []tcpip.Address{
- "\x03",
- "\x01",
- }
- case stack.NeverPrimaryEndpoint:
- expectedList = []tcpip.Address{
- "\x03",
- }
- }
- if !cmp.Equal(primaryAddrs, expectedList) {
- t.Fatalf("got NIC's primary addresses = %v, want = %v", primaryAddrs, expectedList)
- }
-
- // Once we remove the other address, if the new
- // peb, ps, was NeverPrimaryEndpoint, no address
- // should be returned by a call to
- // GetMainNICAddress; else, our original address
- // should be returned.
- if err := s.RemoveAddress(nicID, address3); err != nil {
- t.Fatalf("RemoveAddress(%d, %s): %s", nicID, address3, err)
- }
- addr, err = s.GetMainNICAddress(nicID, fakeNetNumber)
- if err != nil {
- t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err)
- }
- if ps == stack.NeverPrimaryEndpoint {
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, addr, want)
- }
- } else {
- if addr.Address != address1 {
- t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, addr.Address, address1)
- }
- }
- })
- }
- }
-}
-
-func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
- const (
- nicID = 1
- lifetimeSeconds = 9999
- )
-
- var (
- linkLocalAddr1 = testutil.MustParse6("fe80::1")
- linkLocalAddr2 = testutil.MustParse6("fe80::2")
- linkLocalMulticastAddr = testutil.MustParse6("ff02::1")
- uniqueLocalAddr1 = testutil.MustParse6("fc00::1")
- uniqueLocalAddr2 = testutil.MustParse6("fd00::2")
- globalAddr1 = testutil.MustParse6("a000::1")
- globalAddr2 = testutil.MustParse6("a000::2")
- globalAddr3 = testutil.MustParse6("a000::3")
- ipv4MappedIPv6Addr1 = testutil.MustParse6("::ffff:0.0.0.1")
- ipv4MappedIPv6Addr2 = testutil.MustParse6("::ffff:0.0.0.2")
- toredoAddr1 = testutil.MustParse6("2001::1")
- toredoAddr2 = testutil.MustParse6("2001::2")
- ipv6ToIPv4Addr1 = testutil.MustParse6("2002::1")
- ipv6ToIPv4Addr2 = testutil.MustParse6("2002::2")
- )
-
- prefix1, _, stableGlobalAddr1 := prefixSubnetAddr(0, linkAddr1)
- prefix2, _, stableGlobalAddr2 := prefixSubnetAddr(1, linkAddr1)
-
- var tempIIDHistory [header.IIDSize]byte
- header.InitialTempIID(tempIIDHistory[:], nil, nicID)
- tempGlobalAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableGlobalAddr1.Address).Address
- tempGlobalAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableGlobalAddr2.Address).Address
-
- // Rule 3 is not tested here, and is instead tested by NDP's AutoGenAddr test.
- tests := []struct {
- name string
- slaacPrefixForTempAddrBeforeNICAddrAdd tcpip.AddressWithPrefix
- nicAddrs []tcpip.Address
- slaacPrefixForTempAddrAfterNICAddrAdd tcpip.AddressWithPrefix
- remoteAddr tcpip.Address
- expectedLocalAddr tcpip.Address
- }{
- // Test Rule 1 of RFC 6724 section 5 (prefer same address).
- {
- name: "Same Global most preferred (last address)",
- nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1},
- remoteAddr: globalAddr1,
- expectedLocalAddr: globalAddr1,
- },
- {
- name: "Same Global most preferred (first address)",
- nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1},
- remoteAddr: globalAddr1,
- expectedLocalAddr: globalAddr1,
- },
- {
- name: "Same Link Local most preferred (last address)",
- nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1},
- remoteAddr: linkLocalAddr1,
- expectedLocalAddr: linkLocalAddr1,
- },
- {
- name: "Same Link Local most preferred (first address)",
- nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1},
- remoteAddr: linkLocalAddr1,
- expectedLocalAddr: linkLocalAddr1,
- },
- {
- name: "Same Unique Local most preferred (last address)",
- nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1},
- remoteAddr: uniqueLocalAddr1,
- expectedLocalAddr: uniqueLocalAddr1,
- },
- {
- name: "Same Unique Local most preferred (first address)",
- nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1},
- remoteAddr: uniqueLocalAddr1,
- expectedLocalAddr: uniqueLocalAddr1,
- },
-
- // Test Rule 2 of RFC 6724 section 5 (prefer appropriate scope).
- {
- name: "Global most preferred (last address)",
- nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1},
- remoteAddr: globalAddr2,
- expectedLocalAddr: globalAddr1,
- },
- {
- name: "Global most preferred (first address)",
- nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1},
- remoteAddr: globalAddr2,
- expectedLocalAddr: globalAddr1,
- },
- {
- name: "Link Local most preferred (last address)",
- nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1},
- remoteAddr: linkLocalAddr2,
- expectedLocalAddr: linkLocalAddr1,
- },
- {
- name: "Link Local most preferred (first address)",
- nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1},
- remoteAddr: linkLocalAddr2,
- expectedLocalAddr: linkLocalAddr1,
- },
- {
- name: "Link Local most preferred for link local multicast (last address)",
- nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1},
- remoteAddr: linkLocalMulticastAddr,
- expectedLocalAddr: linkLocalAddr1,
- },
- {
- name: "Link Local most preferred for link local multicast (first address)",
- nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1},
- remoteAddr: linkLocalMulticastAddr,
- expectedLocalAddr: linkLocalAddr1,
- },
-
- // Test Rule 6 of 6724 section 5 (prefer matching label).
- {
- name: "Unique Local most preferred (last address)",
- nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, ipv4MappedIPv6Addr1, toredoAddr1, ipv6ToIPv4Addr1},
- remoteAddr: uniqueLocalAddr2,
- expectedLocalAddr: uniqueLocalAddr1,
- },
- {
- name: "Unique Local most preferred (first address)",
- nicAddrs: []tcpip.Address{globalAddr1, ipv4MappedIPv6Addr1, toredoAddr1, ipv6ToIPv4Addr1, uniqueLocalAddr1},
- remoteAddr: uniqueLocalAddr2,
- expectedLocalAddr: uniqueLocalAddr1,
- },
- {
- name: "Toredo most preferred (first address)",
- nicAddrs: []tcpip.Address{toredoAddr1, uniqueLocalAddr1, globalAddr1, ipv4MappedIPv6Addr1, ipv6ToIPv4Addr1},
- remoteAddr: toredoAddr2,
- expectedLocalAddr: toredoAddr1,
- },
- {
- name: "Toredo most preferred (last address)",
- nicAddrs: []tcpip.Address{globalAddr1, ipv4MappedIPv6Addr1, ipv6ToIPv4Addr1, uniqueLocalAddr1, toredoAddr1},
- remoteAddr: toredoAddr2,
- expectedLocalAddr: toredoAddr1,
- },
- {
- name: "6To4 most preferred (first address)",
- nicAddrs: []tcpip.Address{ipv6ToIPv4Addr1, toredoAddr1, uniqueLocalAddr1, globalAddr1, ipv4MappedIPv6Addr1},
- remoteAddr: ipv6ToIPv4Addr2,
- expectedLocalAddr: ipv6ToIPv4Addr1,
- },
- {
- name: "6To4 most preferred (last address)",
- nicAddrs: []tcpip.Address{globalAddr1, ipv4MappedIPv6Addr1, uniqueLocalAddr1, toredoAddr1, ipv6ToIPv4Addr1},
- remoteAddr: ipv6ToIPv4Addr2,
- expectedLocalAddr: ipv6ToIPv4Addr1,
- },
- {
- name: "IPv4 mapped IPv6 most preferred (first address)",
- nicAddrs: []tcpip.Address{ipv4MappedIPv6Addr1, ipv6ToIPv4Addr1, toredoAddr1, uniqueLocalAddr1, globalAddr1},
- remoteAddr: ipv4MappedIPv6Addr2,
- expectedLocalAddr: ipv4MappedIPv6Addr1,
- },
- {
- name: "IPv4 mapped IPv6 most preferred (last address)",
- nicAddrs: []tcpip.Address{globalAddr1, ipv6ToIPv4Addr1, uniqueLocalAddr1, toredoAddr1, ipv4MappedIPv6Addr1},
- remoteAddr: ipv4MappedIPv6Addr2,
- expectedLocalAddr: ipv4MappedIPv6Addr1,
- },
-
- // Test Rule 7 of RFC 6724 section 5 (prefer temporary addresses).
- {
- name: "Temp Global most preferred (last address)",
- slaacPrefixForTempAddrBeforeNICAddrAdd: prefix1,
- nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
- remoteAddr: globalAddr2,
- expectedLocalAddr: tempGlobalAddr1,
- },
- {
- name: "Temp Global most preferred (first address)",
- nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
- slaacPrefixForTempAddrAfterNICAddrAdd: prefix1,
- remoteAddr: globalAddr2,
- expectedLocalAddr: tempGlobalAddr1,
- },
-
- // Test Rule 8 of RFC 6724 section 5 (use longest matching prefix).
- {
- name: "Longest prefix matched most preferred (first address)",
- nicAddrs: []tcpip.Address{globalAddr2, globalAddr1},
- remoteAddr: globalAddr3,
- expectedLocalAddr: globalAddr2,
- },
- {
- name: "Longest prefix matched most preferred (last address)",
- nicAddrs: []tcpip.Address{globalAddr1, globalAddr2},
- remoteAddr: globalAddr3,
- expectedLocalAddr: globalAddr2,
- },
-
- // Test returning the endpoint that is closest to the front when
- // candidate addresses are "equal" from the perspective of RFC 6724
- // section 5.
- {
- name: "Unique Local for Global",
- nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, uniqueLocalAddr2},
- remoteAddr: globalAddr2,
- expectedLocalAddr: uniqueLocalAddr1,
- },
- {
- name: "Link Local for Global",
- nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2},
- remoteAddr: globalAddr2,
- expectedLocalAddr: linkLocalAddr1,
- },
- {
- name: "Link Local for Unique Local",
- nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2},
- remoteAddr: uniqueLocalAddr2,
- expectedLocalAddr: linkLocalAddr1,
- },
- {
- name: "Temp Global for Global",
- slaacPrefixForTempAddrBeforeNICAddrAdd: prefix1,
- slaacPrefixForTempAddrAfterNICAddrAdd: prefix2,
- remoteAddr: globalAddr1,
- expectedLocalAddr: tempGlobalAddr2,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- AutoGenGlobalAddresses: true,
- AutoGenTempGlobalAddresses: true,
- },
- NDPDisp: &ndpDispatcher{},
- })},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- })
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- if test.slaacPrefixForTempAddrBeforeNICAddrAdd != (tcpip.AddressWithPrefix{}) {
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, test.slaacPrefixForTempAddrBeforeNICAddrAdd, true, true, lifetimeSeconds, lifetimeSeconds))
- }
-
- for _, a := range test.nicAddrs {
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: a.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
- }
-
- if test.slaacPrefixForTempAddrAfterNICAddrAdd != (tcpip.AddressWithPrefix{}) {
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, test.slaacPrefixForTempAddrAfterNICAddrAdd, true, true, lifetimeSeconds, lifetimeSeconds))
- }
-
- if t.Failed() {
- t.FailNow()
- }
-
- netEP, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
- }
-
- addressableEndpoint, ok := netEP.(stack.AddressableEndpoint)
- if !ok {
- t.Fatal("network endpoint is not addressable")
- }
-
- addressEP := addressableEndpoint.AcquireOutgoingPrimaryAddress(test.remoteAddr, false /* allowExpired */)
- if addressEP == nil {
- t.Fatal("expected a non-nil address endpoint")
- }
- defer addressEP.DecRef()
-
- if got := addressEP.AddressWithPrefix().Address; got != test.expectedLocalAddr {
- t.Errorf("got local address = %s, want = %s", got, test.expectedLocalAddr)
- }
- })
- }
-}
-
-func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) {
- const nicID = 1
- broadcastAddr := tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: header.IPv4Broadcast,
- PrefixLen: 32,
- },
- }
-
- e := loopback.New()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- })
- nicOpts := stack.NICOptions{Disabled: true}
- if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
- t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err)
- }
-
- {
- allStackAddrs := s.AllAddresses()
- if allNICAddrs, ok := allStackAddrs[nicID]; !ok {
- t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
- } else if containsAddr(allNICAddrs, broadcastAddr) {
- t.Fatalf("got allNICAddrs = %+v, don't want = %+v", allNICAddrs, broadcastAddr)
- }
- }
-
- // Enabling the NIC should add the IPv4 broadcast address.
- if err := s.EnableNIC(nicID); err != nil {
- t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
- }
-
- {
- allStackAddrs := s.AllAddresses()
- if allNICAddrs, ok := allStackAddrs[nicID]; !ok {
- t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
- } else if !containsAddr(allNICAddrs, broadcastAddr) {
- t.Fatalf("got allNICAddrs = %+v, want = %+v", allNICAddrs, broadcastAddr)
- }
- }
-
- // Disabling the NIC should remove the IPv4 broadcast address.
- if err := s.DisableNIC(nicID); err != nil {
- t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
- }
-
- {
- allStackAddrs := s.AllAddresses()
- if allNICAddrs, ok := allStackAddrs[nicID]; !ok {
- t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
- } else if containsAddr(allNICAddrs, broadcastAddr) {
- t.Fatalf("got allNICAddrs = %+v, don't want = %+v", allNICAddrs, broadcastAddr)
- }
- }
-}
-
-// TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval tests that removing an IPv6
-// address after leaving its solicited node multicast address does not result in
-// an error.
-func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) {
- const nicID = 1
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- })
- e := channel.New(10, 1280, linkAddr1)
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
-
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: addr1.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
-
- // The NIC should have joined addr1's solicited node multicast address.
- snmc := header.SolicitedNodeAddr(addr1)
- in, err := s.IsInGroup(nicID, snmc)
- if err != nil {
- t.Fatalf("IsInGroup(%d, %s): %s", nicID, snmc, err)
- }
- if !in {
- t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, snmc)
- }
-
- if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, snmc); err != nil {
- t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, snmc, err)
- }
- in, err = s.IsInGroup(nicID, snmc)
- if err != nil {
- t.Fatalf("IsInGroup(%d, %s): %s", nicID, snmc, err)
- }
- if in {
- t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, snmc)
- }
-
- if err := s.RemoveAddress(nicID, addr1); err != nil {
- t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr1, err)
- }
-}
-
-func TestJoinLeaveMulticastOnNICEnableDisable(t *testing.T) {
- const nicID = 1
-
- tests := []struct {
- name string
- proto tcpip.NetworkProtocolNumber
- addr tcpip.Address
- }{
- {
- name: "IPv6 All-Nodes",
- proto: header.IPv6ProtocolNumber,
- addr: header.IPv6AllNodesMulticastAddress,
- },
- {
- name: "IPv4 All-Systems",
- proto: header.IPv4ProtocolNumber,
- addr: header.IPv4AllSystems,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- e := loopback.New()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- })
- nicOpts := stack.NICOptions{Disabled: true}
- if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
- t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err)
- }
-
- // Should not be in the multicast group yet because the NIC has not been
- // enabled yet.
- if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil {
- t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err)
- } else if isInGroup {
- t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr)
- }
-
- // The all-nodes multicast group should be joined when the NIC is enabled.
- if err := s.EnableNIC(nicID); err != nil {
- t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
- }
-
- if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil {
- t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err)
- } else if !isInGroup {
- t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, test.addr)
- }
-
- // The multicast group should be left when the NIC is disabled.
- if err := s.DisableNIC(nicID); err != nil {
- t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
- }
-
- if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil {
- t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err)
- } else if isInGroup {
- t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr)
- }
-
- // The all-nodes multicast group should be joined when the NIC is enabled.
- if err := s.EnableNIC(nicID); err != nil {
- t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
- }
-
- if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil {
- t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err)
- } else if !isInGroup {
- t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, test.addr)
- }
-
- // Leaving the group before disabling the NIC should not cause an error.
- if err := s.LeaveGroup(test.proto, nicID, test.addr); err != nil {
- t.Fatalf("s.LeaveGroup(%d, %d, %s): %s", test.proto, nicID, test.addr, err)
- }
-
- if err := s.DisableNIC(nicID); err != nil {
- t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
- }
-
- if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil {
- t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err)
- } else if isInGroup {
- t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr)
- }
- })
- }
-}
-
-// TestDoDADWhenNICEnabled tests that IPv6 endpoints that were added while a NIC
-// was disabled have DAD performed on them when the NIC is enabled.
-func TestDoDADWhenNICEnabled(t *testing.T) {
- const dadTransmits = 1
- const retransmitTimer = time.Second
- const nicID = 1
-
- ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent, 1),
- }
- clock := faketime.NewManualClock()
- opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- DADConfigs: stack.DADConfigurations{
- DupAddrDetectTransmits: dadTransmits,
- RetransmitTimer: retransmitTimer,
- },
- NDPDisp: &ndpDisp,
- })},
- Clock: clock,
- }
-
- e := channel.New(dadTransmits, 1280, linkAddr1)
- s := stack.New(opts)
- nicOpts := stack.NICOptions{Disabled: true}
- if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
- t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err)
- }
-
- addr := tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: llAddr1,
- PrefixLen: 128,
- },
- }
- if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err)
- }
-
- // Address should be in the list of all addresses.
- if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) {
- t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr)
- }
-
- // Address should be tentative so it should not be a main address.
- if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
- t.Fatal(err)
- }
-
- // Enabling the NIC should start DAD for the address.
- if err := s.EnableNIC(nicID); err != nil {
- t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
- }
- if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) {
- t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr)
- }
-
- // Address should not be considered bound to the NIC yet (DAD ongoing).
- if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
- t.Fatal(err)
- }
-
- // Wait for DAD to resolve.
- clock.Advance(dadTransmits * retransmitTimer)
- select {
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr.AddressWithPrefix.Address, &stack.DADSucceeded{}); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("timed out waiting for DAD resolution")
- }
- if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) {
- t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr)
- }
- if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr.AddressWithPrefix); err != nil {
- t.Fatal(err)
- }
-
- // Enabling the NIC again should be a no-op.
- if err := s.EnableNIC(nicID); err != nil {
- t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
- }
- if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) {
- t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr)
- }
- if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr.AddressWithPrefix); err != nil {
- t.Fatal(err)
- }
-}
-
-func TestStackReceiveBufferSizeOption(t *testing.T) {
- const sMin = stack.MinBufferSize
- testCases := []struct {
- name string
- rs tcpip.ReceiveBufferSizeOption
- err tcpip.Error
- }{
- // Invalid configurations.
- {"min_below_zero", tcpip.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}},
- {"min_zero", tcpip.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}},
- {"default_below_min", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}},
- {"default_above_max", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, &tcpip.ErrInvalidOptionValue{}},
- {"max_below_min", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}},
-
- // Valid Configurations
- {"in_ascending_order", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil},
- {"all_equal", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil},
- {"min_default_equal", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil},
- {"default_max_equal", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil},
- }
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- s := stack.New(stack.Options{})
- defer s.Close()
- if err := s.SetOption(tc.rs); err != tc.err {
- t.Fatalf("s.SetOption(%#v) = %v, want: %v", tc.rs, err, tc.err)
- }
- var rs tcpip.ReceiveBufferSizeOption
- if tc.err == nil {
- if err := s.Option(&rs); err != nil {
- t.Fatalf("s.Option(%#v) = %v, want: nil", rs, err)
- }
- if got, want := rs, tc.rs; got != want {
- t.Fatalf("s.Option(..) returned unexpected value got: %#v, want: %#v", got, want)
- }
- }
- })
- }
-}
-
-func TestStackSendBufferSizeOption(t *testing.T) {
- const sMin = stack.MinBufferSize
- testCases := []struct {
- name string
- ss tcpip.SendBufferSizeOption
- err tcpip.Error
- }{
- // Invalid configurations.
- {"min_below_zero", tcpip.SendBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}},
- {"min_zero", tcpip.SendBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}},
- {"default_below_min", tcpip.SendBufferSizeOption{Min: 0, Default: sMin - 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}},
- {"default_above_max", tcpip.SendBufferSizeOption{Min: 0, Default: sMin + 1, Max: sMin}, &tcpip.ErrInvalidOptionValue{}},
- {"max_below_min", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}},
-
- // Valid Configurations
- {"in_ascending_order", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil},
- {"all_equal", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil},
- {"min_default_equal", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil},
- {"default_max_equal", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil},
- }
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- s := stack.New(stack.Options{})
- defer s.Close()
- err := s.SetOption(tc.ss)
- if diff := cmp.Diff(tc.err, err); diff != "" {
- t.Fatalf("unexpected error from s.SetOption(%+v), (-want, +got):\n%s", tc.ss, diff)
- }
- if tc.err == nil {
- var ss tcpip.SendBufferSizeOption
- if err := s.Option(&ss); err != nil {
- t.Fatalf("s.Option(%+v) = %v, want: nil", ss, err)
- }
- if got, want := ss, tc.ss; got != want {
- t.Fatalf("s.Option(..) returned unexpected value got: %#v, want: %#v", got, want)
- }
- }
- })
- }
-}
-
-func TestOutgoingSubnetBroadcast(t *testing.T) {
- const (
- unspecifiedNICID = 0
- nicID1 = 1
- )
-
- defaultAddr := tcpip.AddressWithPrefix{
- Address: header.IPv4Any,
- PrefixLen: 0,
- }
- defaultSubnet := defaultAddr.Subnet()
- ipv4Addr := tcpip.AddressWithPrefix{
- Address: "\xc0\xa8\x01\x3a",
- PrefixLen: 24,
- }
- ipv4Subnet := ipv4Addr.Subnet()
- ipv4SubnetBcast := ipv4Subnet.Broadcast()
- ipv4Gateway := testutil.MustParse4("192.168.1.1")
- ipv4AddrPrefix31 := tcpip.AddressWithPrefix{
- Address: "\xc0\xa8\x01\x3a",
- PrefixLen: 31,
- }
- ipv4Subnet31 := ipv4AddrPrefix31.Subnet()
- ipv4Subnet31Bcast := ipv4Subnet31.Broadcast()
- ipv4AddrPrefix32 := tcpip.AddressWithPrefix{
- Address: "\xc0\xa8\x01\x3a",
- PrefixLen: 32,
- }
- ipv4Subnet32 := ipv4AddrPrefix32.Subnet()
- ipv4Subnet32Bcast := ipv4Subnet32.Broadcast()
- ipv6Addr := tcpip.AddressWithPrefix{
- Address: "\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
- PrefixLen: 64,
- }
- ipv6Subnet := ipv6Addr.Subnet()
- ipv6SubnetBcast := ipv6Subnet.Broadcast()
- remNetAddr := tcpip.AddressWithPrefix{
- Address: "\x64\x0a\x7b\x18",
- PrefixLen: 24,
- }
- remNetSubnet := remNetAddr.Subnet()
- remNetSubnetBcast := remNetSubnet.Broadcast()
-
- tests := []struct {
- name string
- nicAddr tcpip.ProtocolAddress
- routes []tcpip.Route
- remoteAddr tcpip.Address
- expectedLocalAddress tcpip.Address
- expectedRemoteAddress tcpip.Address
- expectedRemoteLinkAddress tcpip.LinkAddress
- expectedNextHop tcpip.Address
- expectedNetProto tcpip.NetworkProtocolNumber
- expectedLoop stack.PacketLooping
- }{
- // Broadcast to a locally attached subnet populates the broadcast MAC.
- {
- name: "IPv4 Broadcast to local subnet",
- nicAddr: tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: ipv4Addr,
- },
- routes: []tcpip.Route{
- {
- Destination: ipv4Subnet,
- NIC: nicID1,
- },
- },
- remoteAddr: ipv4SubnetBcast,
- expectedLocalAddress: ipv4Addr.Address,
- expectedRemoteAddress: ipv4SubnetBcast,
- expectedRemoteLinkAddress: header.EthernetBroadcastAddress,
- expectedNetProto: header.IPv4ProtocolNumber,
- expectedLoop: stack.PacketOut | stack.PacketLoop,
- },
- // Broadcast to a locally attached /31 subnet does not populate the
- // broadcast MAC.
- {
- name: "IPv4 Broadcast to local /31 subnet",
- nicAddr: tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: ipv4AddrPrefix31,
- },
- routes: []tcpip.Route{
- {
- Destination: ipv4Subnet31,
- NIC: nicID1,
- },
- },
- remoteAddr: ipv4Subnet31Bcast,
- expectedLocalAddress: ipv4AddrPrefix31.Address,
- expectedRemoteAddress: ipv4Subnet31Bcast,
- expectedNetProto: header.IPv4ProtocolNumber,
- expectedLoop: stack.PacketOut,
- },
- // Broadcast to a locally attached /32 subnet does not populate the
- // broadcast MAC.
- {
- name: "IPv4 Broadcast to local /32 subnet",
- nicAddr: tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: ipv4AddrPrefix32,
- },
- routes: []tcpip.Route{
- {
- Destination: ipv4Subnet32,
- NIC: nicID1,
- },
- },
- remoteAddr: ipv4Subnet32Bcast,
- expectedLocalAddress: ipv4AddrPrefix32.Address,
- expectedRemoteAddress: ipv4Subnet32Bcast,
- expectedNetProto: header.IPv4ProtocolNumber,
- expectedLoop: stack.PacketOut,
- },
- // IPv6 has no notion of a broadcast.
- {
- name: "IPv6 'Broadcast' to local subnet",
- nicAddr: tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: ipv6Addr,
- },
- routes: []tcpip.Route{
- {
- Destination: ipv6Subnet,
- NIC: nicID1,
- },
- },
- remoteAddr: ipv6SubnetBcast,
- expectedLocalAddress: ipv6Addr.Address,
- expectedRemoteAddress: ipv6SubnetBcast,
- expectedNetProto: header.IPv6ProtocolNumber,
- expectedLoop: stack.PacketOut,
- },
- // Broadcast to a remote subnet in the route table is send to the next-hop
- // gateway.
- {
- name: "IPv4 Broadcast to remote subnet",
- nicAddr: tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: ipv4Addr,
- },
- routes: []tcpip.Route{
- {
- Destination: remNetSubnet,
- Gateway: ipv4Gateway,
- NIC: nicID1,
- },
- },
- remoteAddr: remNetSubnetBcast,
- expectedLocalAddress: ipv4Addr.Address,
- expectedRemoteAddress: remNetSubnetBcast,
- expectedNextHop: ipv4Gateway,
- expectedNetProto: header.IPv4ProtocolNumber,
- expectedLoop: stack.PacketOut,
- },
- // Broadcast to an unknown subnet follows the default route. Note that this
- // is essentially just routing an unknown destination IP, because w/o any
- // subnet prefix information a subnet broadcast address is just a normal IP.
- {
- name: "IPv4 Broadcast to unknown subnet",
- nicAddr: tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: ipv4Addr,
- },
- routes: []tcpip.Route{
- {
- Destination: defaultSubnet,
- Gateway: ipv4Gateway,
- NIC: nicID1,
- },
- },
- remoteAddr: remNetSubnetBcast,
- expectedLocalAddress: ipv4Addr.Address,
- expectedRemoteAddress: remNetSubnetBcast,
- expectedNextHop: ipv4Gateway,
- expectedNetProto: header.IPv4ProtocolNumber,
- expectedLoop: stack.PacketOut,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
- })
- ep := channel.New(0, defaultMTU, "")
- ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- if err := s.CreateNIC(nicID1, ep); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
- }
- if err := s.AddProtocolAddress(nicID1, test.nicAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, test.nicAddr, err)
- }
-
- s.SetRouteTable(test.routes)
-
- var netProto tcpip.NetworkProtocolNumber
- switch l := len(test.remoteAddr); l {
- case header.IPv4AddressSize:
- netProto = header.IPv4ProtocolNumber
- case header.IPv6AddressSize:
- netProto = header.IPv6ProtocolNumber
- default:
- t.Fatalf("got unexpected address length = %d bytes", l)
- }
-
- r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, test.remoteAddr, netProto, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, test.remoteAddr, netProto, err)
- }
- if r.LocalAddress() != test.expectedLocalAddress {
- t.Errorf("got r.LocalAddress() = %s, want = %s", r.LocalAddress(), test.expectedLocalAddress)
- }
- if r.RemoteAddress() != test.expectedRemoteAddress {
- t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress(), test.expectedRemoteAddress)
- }
- if got := r.RemoteLinkAddress(); got != test.expectedRemoteLinkAddress {
- t.Errorf("got r.RemoteLinkAddress() = %s, want = %s", got, test.expectedRemoteLinkAddress)
- }
- if r.NextHop() != test.expectedNextHop {
- t.Errorf("got r.NextHop() = %s, want = %s", r.NextHop(), test.expectedNextHop)
- }
- if r.NetProto() != test.expectedNetProto {
- t.Errorf("got r.NetProto() = %d, want = %d", r.NetProto(), test.expectedNetProto)
- }
- if r.Loop() != test.expectedLoop {
- t.Errorf("got r.Loop() = %x, want = %x", r.Loop(), test.expectedLoop)
- }
- })
- }
-}
-
-func TestResolveWith(t *testing.T) {
- const (
- unspecifiedNICID = 0
- nicID = 1
- )
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol},
- })
- ep := channel.New(0, defaultMTU, "")
- ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- if err := s.CreateNIC(nicID, ep); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
- addr := tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: tcpip.Address([]byte{192, 168, 1, 58}),
- PrefixLen: 24,
- },
- }
- if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err)
- }
-
- s.SetRouteTable([]tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}})
-
- remoteAddr := tcpip.Address([]byte{192, 168, 1, 59})
- r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, remoteAddr, header.IPv4ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, remoteAddr, header.IPv4ProtocolNumber, err)
- }
- defer r.Release()
-
- // Should initially require resolution.
- if !r.IsResolutionRequired() {
- t.Fatal("got r.IsResolutionRequired() = false, want = true")
- }
-
- // Manually resolving the route should no longer require resolution.
- r.ResolveWith("\x01")
- if r.IsResolutionRequired() {
- t.Fatal("got r.IsResolutionRequired() = true, want = false")
- }
-}
-
-// TestRouteReleaseAfterAddrRemoval tests that releasing a Route after its
-// associated address is removed should not cause a panic.
-func TestRouteReleaseAfterAddrRemoval(t *testing.T) {
- const (
- nicID = 1
- localAddr = tcpip.Address("\x01")
- remoteAddr = tcpip.Address("\x02")
- )
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
-
- ep := channel.New(0, defaultMTU, "")
- if err := s.CreateNIC(nicID, ep); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: localAddr,
- PrefixLen: fakeDefaultPrefixLen,
- },
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
- {
- subnet, err := tcpip.NewSubnet("\x00", "\x00")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
- }
-
- r, err := s.FindRoute(nicID, localAddr, remoteAddr, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, localAddr, remoteAddr, fakeNetNumber, err)
- }
- // Should not panic.
- defer r.Release()
-
- // Check that removing the same address fails.
- if err := s.RemoveAddress(nicID, localAddr); err != nil {
- t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, localAddr, err)
- }
-}
-
-func TestGetNetworkEndpoint(t *testing.T) {
- const nicID = 1
-
- tests := []struct {
- name string
- protoFactory stack.NetworkProtocolFactory
- protoNum tcpip.NetworkProtocolNumber
- }{
- {
- name: "IPv4",
- protoFactory: ipv4.NewProtocol,
- protoNum: ipv4.ProtocolNumber,
- },
- {
- name: "IPv6",
- protoFactory: ipv6.NewProtocol,
- protoNum: ipv6.ProtocolNumber,
- },
- }
-
- factories := make([]stack.NetworkProtocolFactory, 0, len(tests))
- for _, test := range tests {
- factories = append(factories, test.protoFactory)
- }
-
- s := stack.New(stack.Options{
- NetworkProtocols: factories,
- })
-
- if err := s.CreateNIC(nicID, channel.New(0, defaultMTU, "")); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- ep, err := s.GetNetworkEndpoint(nicID, test.protoNum)
- if err != nil {
- t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, test.protoNum, err)
- }
-
- if got := ep.NetworkProtocolNumber(); got != test.protoNum {
- t.Fatalf("got ep.NetworkProtocolNumber() = %d, want = %d", got, test.protoNum)
- }
- })
- }
-}
-
-func TestGetMainNICAddressWhenNICDisabled(t *testing.T) {
- const nicID = 1
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
-
- if err := s.CreateNIC(nicID, channel.New(0, defaultMTU, "")); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
-
- protocolAddress := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: "\x01",
- PrefixLen: 8,
- },
- }
- if err := s.AddProtocolAddress(nicID, protocolAddress, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddress, err)
- }
-
- // Check that we get the right initial address and prefix length.
- if err := checkGetMainNICAddress(s, nicID, fakeNetNumber, protocolAddress.AddressWithPrefix); err != nil {
- t.Fatal(err)
- }
-
- // Should still get the address when the NIC is diabled.
- if err := s.DisableNIC(nicID); err != nil {
- t.Fatalf("DisableNIC(%d): %s", nicID, err)
- }
- if err := checkGetMainNICAddress(s, nicID, fakeNetNumber, protocolAddress.AddressWithPrefix); err != nil {
- t.Fatal(err)
- }
-}
-
-// TestAddRoute tests Stack.AddRoute
-func TestAddRoute(t *testing.T) {
- s := stack.New(stack.Options{})
-
- subnet1, err := tcpip.NewSubnet("\x00", "\x00")
- if err != nil {
- t.Fatal(err)
- }
-
- subnet2, err := tcpip.NewSubnet("\x01", "\x01")
- if err != nil {
- t.Fatal(err)
- }
-
- expected := []tcpip.Route{
- {Destination: subnet1, Gateway: "\x00", NIC: 1},
- {Destination: subnet2, Gateway: "\x00", NIC: 1},
- }
-
- // Initialize the route table with one route.
- s.SetRouteTable([]tcpip.Route{expected[0]})
-
- // Add another route.
- s.AddRoute(expected[1])
-
- rt := s.GetRouteTable()
- if got, want := len(rt), len(expected); got != want {
- t.Fatalf("Unexpected route table length got = %d, want = %d", got, want)
- }
- for i, route := range rt {
- if got, want := route, expected[i]; got != want {
- t.Fatalf("Unexpected route got = %#v, want = %#v", got, want)
- }
- }
-}
-
-// TestRemoveRoutes tests Stack.RemoveRoutes
-func TestRemoveRoutes(t *testing.T) {
- s := stack.New(stack.Options{})
-
- addressToRemove := tcpip.Address("\x01")
- subnet1, err := tcpip.NewSubnet(addressToRemove, "\x01")
- if err != nil {
- t.Fatal(err)
- }
-
- subnet2, err := tcpip.NewSubnet(addressToRemove, "\x01")
- if err != nil {
- t.Fatal(err)
- }
-
- subnet3, err := tcpip.NewSubnet("\x02", "\x02")
- if err != nil {
- t.Fatal(err)
- }
-
- // Initialize the route table with three routes.
- s.SetRouteTable([]tcpip.Route{
- {Destination: subnet1, Gateway: "\x00", NIC: 1},
- {Destination: subnet2, Gateway: "\x00", NIC: 1},
- {Destination: subnet3, Gateway: "\x00", NIC: 1},
- })
-
- // Remove routes with the specific address.
- s.RemoveRoutes(func(r tcpip.Route) bool {
- return r.Destination.ID() == addressToRemove
- })
-
- expected := []tcpip.Route{{Destination: subnet3, Gateway: "\x00", NIC: 1}}
- rt := s.GetRouteTable()
- if got, want := len(rt), len(expected); got != want {
- t.Fatalf("Unexpected route table length got = %d, want = %d", got, want)
- }
- for i, route := range rt {
- if got, want := route, expected[i]; got != want {
- t.Fatalf("Unexpected route got = %#v, want = %#v", got, want)
- }
- }
-}
-
-func TestFindRouteWithForwarding(t *testing.T) {
- const (
- nicID1 = 1
- nicID2 = 2
-
- nic1Addr = tcpip.Address("\x01")
- nic2Addr = tcpip.Address("\x02")
- remoteAddr = tcpip.Address("\x03")
- )
-
- type netCfg struct {
- proto tcpip.NetworkProtocolNumber
- factory stack.NetworkProtocolFactory
- nic1AddrWithPrefix tcpip.AddressWithPrefix
- nic2AddrWithPrefix tcpip.AddressWithPrefix
- remoteAddr tcpip.Address
- }
-
- fakeNetCfg := netCfg{
- proto: fakeNetNumber,
- factory: fakeNetFactory,
- nic1AddrWithPrefix: tcpip.AddressWithPrefix{Address: nic1Addr, PrefixLen: fakeDefaultPrefixLen},
- nic2AddrWithPrefix: tcpip.AddressWithPrefix{Address: nic2Addr, PrefixLen: fakeDefaultPrefixLen},
- remoteAddr: remoteAddr,
- }
-
- globalIPv6Addr1 := tcpip.Address(net.ParseIP("a::1").To16())
- globalIPv6Addr2 := tcpip.Address(net.ParseIP("a::2").To16())
-
- ipv6LinkLocalNIC1WithGlobalRemote := netCfg{
- proto: ipv6.ProtocolNumber,
- factory: ipv6.NewProtocol,
- nic1AddrWithPrefix: llAddr1.WithPrefix(),
- nic2AddrWithPrefix: globalIPv6Addr2.WithPrefix(),
- remoteAddr: globalIPv6Addr1,
- }
- ipv6GlobalNIC1WithLinkLocalRemote := netCfg{
- proto: ipv6.ProtocolNumber,
- factory: ipv6.NewProtocol,
- nic1AddrWithPrefix: globalIPv6Addr1.WithPrefix(),
- nic2AddrWithPrefix: llAddr1.WithPrefix(),
- remoteAddr: llAddr2,
- }
- ipv6GlobalNIC1WithLinkLocalMulticastRemote := netCfg{
- proto: ipv6.ProtocolNumber,
- factory: ipv6.NewProtocol,
- nic1AddrWithPrefix: globalIPv6Addr1.WithPrefix(),
- nic2AddrWithPrefix: globalIPv6Addr2.WithPrefix(),
- remoteAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
- }
-
- tests := []struct {
- name string
-
- netCfg netCfg
- forwardingEnabled bool
-
- addrNIC tcpip.NICID
- localAddrWithPrefix tcpip.AddressWithPrefix
-
- findRouteErr tcpip.Error
- dependentOnForwarding bool
- }{
- {
- name: "forwarding disabled and localAddr not on specified NIC but route from different NIC",
- netCfg: fakeNetCfg,
- forwardingEnabled: false,
- addrNIC: nicID1,
- localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
- findRouteErr: &tcpip.ErrNoRoute{},
- dependentOnForwarding: false,
- },
- {
- name: "forwarding enabled and localAddr not on specified NIC but route from different NIC",
- netCfg: fakeNetCfg,
- forwardingEnabled: true,
- addrNIC: nicID1,
- localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
- findRouteErr: &tcpip.ErrNoRoute{},
- dependentOnForwarding: false,
- },
- {
- name: "forwarding disabled and localAddr on specified NIC but route from different NIC",
- netCfg: fakeNetCfg,
- forwardingEnabled: false,
- addrNIC: nicID1,
- localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
- findRouteErr: &tcpip.ErrNoRoute{},
- dependentOnForwarding: false,
- },
- {
- name: "forwarding enabled and localAddr on specified NIC but route from different NIC",
- netCfg: fakeNetCfg,
- forwardingEnabled: true,
- addrNIC: nicID1,
- localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
- findRouteErr: nil,
- dependentOnForwarding: true,
- },
- {
- name: "forwarding disabled and localAddr on specified NIC and route from same NIC",
- netCfg: fakeNetCfg,
- forwardingEnabled: false,
- addrNIC: nicID2,
- localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
- findRouteErr: nil,
- dependentOnForwarding: false,
- },
- {
- name: "forwarding enabled and localAddr on specified NIC and route from same NIC",
- netCfg: fakeNetCfg,
- forwardingEnabled: true,
- addrNIC: nicID2,
- localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
- findRouteErr: nil,
- dependentOnForwarding: false,
- },
- {
- name: "forwarding disabled and localAddr not on specified NIC but route from same NIC",
- netCfg: fakeNetCfg,
- forwardingEnabled: false,
- addrNIC: nicID2,
- localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
- findRouteErr: &tcpip.ErrNoRoute{},
- dependentOnForwarding: false,
- },
- {
- name: "forwarding enabled and localAddr not on specified NIC but route from same NIC",
- netCfg: fakeNetCfg,
- forwardingEnabled: true,
- addrNIC: nicID2,
- localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
- findRouteErr: &tcpip.ErrNoRoute{},
- dependentOnForwarding: false,
- },
- {
- name: "forwarding disabled and localAddr on same NIC as route",
- netCfg: fakeNetCfg,
- forwardingEnabled: false,
- localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
- findRouteErr: nil,
- dependentOnForwarding: false,
- },
- {
- name: "forwarding enabled and localAddr on same NIC as route",
- netCfg: fakeNetCfg,
- forwardingEnabled: false,
- localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
- findRouteErr: nil,
- dependentOnForwarding: false,
- },
- {
- name: "forwarding disabled and localAddr on different NIC as route",
- netCfg: fakeNetCfg,
- forwardingEnabled: false,
- localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
- findRouteErr: &tcpip.ErrNoRoute{},
- dependentOnForwarding: false,
- },
- {
- name: "forwarding enabled and localAddr on different NIC as route",
- netCfg: fakeNetCfg,
- forwardingEnabled: true,
- localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
- findRouteErr: nil,
- dependentOnForwarding: true,
- },
- {
- name: "forwarding disabled and specified NIC only has link-local addr with route on different NIC",
- netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
- forwardingEnabled: false,
- addrNIC: nicID1,
- findRouteErr: &tcpip.ErrNoRoute{},
- dependentOnForwarding: false,
- },
- {
- name: "forwarding enabled and specified NIC only has link-local addr with route on different NIC",
- netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
- forwardingEnabled: true,
- addrNIC: nicID1,
- findRouteErr: &tcpip.ErrNoRoute{},
- dependentOnForwarding: false,
- },
- {
- name: "forwarding disabled and link-local local addr with route on different NIC",
- netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
- forwardingEnabled: false,
- localAddrWithPrefix: ipv6LinkLocalNIC1WithGlobalRemote.nic1AddrWithPrefix,
- findRouteErr: &tcpip.ErrNoRoute{},
- dependentOnForwarding: false,
- },
- {
- name: "forwarding enabled and link-local local addr with route on same NIC",
- netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
- forwardingEnabled: true,
- localAddrWithPrefix: ipv6LinkLocalNIC1WithGlobalRemote.nic1AddrWithPrefix,
- findRouteErr: &tcpip.ErrNoRoute{},
- dependentOnForwarding: false,
- },
- {
- name: "forwarding disabled and global local addr with route on same NIC",
- netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
- forwardingEnabled: true,
- localAddrWithPrefix: ipv6LinkLocalNIC1WithGlobalRemote.nic2AddrWithPrefix,
- findRouteErr: nil,
- dependentOnForwarding: false,
- },
- {
- name: "forwarding disabled and link-local local addr with route on same NIC",
- netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
- forwardingEnabled: false,
- localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic2AddrWithPrefix,
- findRouteErr: nil,
- dependentOnForwarding: false,
- },
- {
- name: "forwarding enabled and link-local local addr with route on same NIC",
- netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
- forwardingEnabled: true,
- localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic2AddrWithPrefix,
- findRouteErr: nil,
- dependentOnForwarding: false,
- },
- {
- name: "forwarding disabled and global local addr with link-local remote on different NIC",
- netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
- forwardingEnabled: false,
- localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic1AddrWithPrefix,
- findRouteErr: &tcpip.ErrNetworkUnreachable{},
- dependentOnForwarding: false,
- },
- {
- name: "forwarding enabled and global local addr with link-local remote on different NIC",
- netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
- forwardingEnabled: true,
- localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic1AddrWithPrefix,
- findRouteErr: &tcpip.ErrNetworkUnreachable{},
- dependentOnForwarding: false,
- },
- {
- name: "forwarding disabled and global local addr with link-local multicast remote on different NIC",
- netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
- forwardingEnabled: false,
- localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1AddrWithPrefix,
- findRouteErr: &tcpip.ErrNetworkUnreachable{},
- dependentOnForwarding: false,
- },
- {
- name: "forwarding enabled and global local addr with link-local multicast remote on different NIC",
- netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
- forwardingEnabled: true,
- localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1AddrWithPrefix,
- findRouteErr: &tcpip.ErrNetworkUnreachable{},
- dependentOnForwarding: false,
- },
- {
- name: "forwarding disabled and global local addr with link-local multicast remote on same NIC",
- netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
- forwardingEnabled: false,
- localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2AddrWithPrefix,
- findRouteErr: nil,
- dependentOnForwarding: false,
- },
- {
- name: "forwarding enabled and global local addr with link-local multicast remote on same NIC",
- netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
- forwardingEnabled: true,
- localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2AddrWithPrefix,
- findRouteErr: nil,
- dependentOnForwarding: false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{test.netCfg.factory},
- })
-
- ep1 := channel.New(1, defaultMTU, "")
- if err := s.CreateNIC(nicID1, ep1); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s:", nicID1, err)
- }
-
- ep2 := channel.New(1, defaultMTU, "")
- if err := s.CreateNIC(nicID2, ep2); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s:", nicID2, err)
- }
-
- protocolAddr1 := tcpip.ProtocolAddress{
- Protocol: test.netCfg.proto,
- AddressWithPrefix: test.netCfg.nic1AddrWithPrefix,
- }
- if err := s.AddProtocolAddress(nicID1, protocolAddr1, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, protocolAddr1, err)
- }
-
- protocolAddr2 := tcpip.ProtocolAddress{
- Protocol: test.netCfg.proto,
- AddressWithPrefix: test.netCfg.nic2AddrWithPrefix,
- }
- if err := s.AddProtocolAddress(nicID2, protocolAddr2, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddr2, err)
- }
-
- if err := s.SetForwardingDefaultAndAllNICs(test.netCfg.proto, test.forwardingEnabled); err != nil {
- t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", test.netCfg.proto, test.forwardingEnabled, err)
- }
-
- s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}})
-
- r, err := s.FindRoute(test.addrNIC, test.localAddrWithPrefix.Address, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */)
- if err == nil {
- defer r.Release()
- }
- if diff := cmp.Diff(test.findRouteErr, err); diff != "" {
- t.Fatalf("unexpected error from FindRoute(%d, %s, %s, %d, false), (-want, +got):\n%s", test.addrNIC, test.localAddrWithPrefix.Address, test.netCfg.remoteAddr, test.netCfg.proto, diff)
- }
-
- if test.findRouteErr != nil {
- return
- }
-
- if r.LocalAddress() != test.localAddrWithPrefix.Address {
- t.Errorf("got r.LocalAddress() = %s, want = %s", r.LocalAddress(), test.localAddrWithPrefix.Address)
- }
- if r.RemoteAddress() != test.netCfg.remoteAddr {
- t.Errorf("got r.RemoteAddress() = %s, want = %s", r.RemoteAddress(), test.netCfg.remoteAddr)
- }
-
- if t.Failed() {
- t.FailNow()
- }
-
- // Sending a packet should always go through NIC2 since we only install a
- // route to test.netCfg.remoteAddr through NIC2.
- data := buffer.View([]byte{1, 2, 3, 4})
- if err := send(r, data); err != nil {
- t.Fatalf("send(_, _): %s", err)
- }
- if n := ep1.Drain(); n != 0 {
- t.Errorf("got %d unexpected packets from ep1", n)
- }
- pkt, ok := ep2.Read()
- if !ok {
- t.Fatal("packet not sent through ep2")
- }
- if pkt.Route.LocalAddress != test.localAddrWithPrefix.Address {
- t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.localAddrWithPrefix.Address)
- }
- if pkt.Route.RemoteAddress != test.netCfg.remoteAddr {
- t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.netCfg.remoteAddr)
- }
-
- if !test.forwardingEnabled || !test.dependentOnForwarding {
- return
- }
-
- // Disabling forwarding when the route is dependent on forwarding being
- // enabled should make the route invalid.
- if err := s.SetForwardingDefaultAndAllNICs(test.netCfg.proto, false); err != nil {
- t.Fatalf("SetForwardingDefaultAndAllNICs(%d, false): %s", test.netCfg.proto, err)
- }
- {
- err := send(r, data)
- if _, ok := err.(*tcpip.ErrInvalidEndpointState); !ok {
- t.Fatalf("got send(_, _) = %s, want = %s", err, &tcpip.ErrInvalidEndpointState{})
- }
- }
- if n := ep1.Drain(); n != 0 {
- t.Errorf("got %d unexpected packets from ep1", n)
- }
- if n := ep2.Drain(); n != 0 {
- t.Errorf("got %d unexpected packets from ep2", n)
- }
- })
- }
-}
-
-func TestWritePacketToRemote(t *testing.T) {
- const nicID = 1
- const MTU = 1280
- e := channel.New(1, MTU, linkAddr1)
- s := stack.New(stack.Options{})
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- if err := s.EnableNIC(nicID); err != nil {
- t.Fatalf("CreateNIC(%d) = %s", nicID, err)
- }
- tests := []struct {
- name string
- protocol tcpip.NetworkProtocolNumber
- payload []byte
- }{
- {
- name: "SuccessIPv4",
- protocol: header.IPv4ProtocolNumber,
- payload: []byte{1, 2, 3, 4},
- },
- {
- name: "SuccessIPv6",
- protocol: header.IPv6ProtocolNumber,
- payload: []byte{5, 6, 7, 8},
- },
- }
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- if err := s.WritePacketToRemote(nicID, linkAddr2, test.protocol, buffer.View(test.payload).ToVectorisedView()); err != nil {
- t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s", err)
- }
-
- pkt, ok := e.Read()
- if got, want := ok, true; got != want {
- t.Fatalf("e.Read() = %t, want %t", got, want)
- }
- if got, want := pkt.Proto, test.protocol; got != want {
- t.Fatalf("pkt.Proto = %d, want %d", got, want)
- }
- if pkt.Route.RemoteLinkAddress != linkAddr2 {
- t.Fatalf("pkt.Route.RemoteAddress = %s, want %s", pkt.Route.RemoteLinkAddress, linkAddr2)
- }
- if diff := cmp.Diff(pkt.Pkt.Data().AsRange().ToOwnedView(), buffer.View(test.payload)); diff != "" {
- t.Errorf("pkt.Pkt.Data mismatch (-want +got):\n%s", diff)
- }
- })
- }
-
- t.Run("InvalidNICID", func(t *testing.T) {
- err := s.WritePacketToRemote(234, linkAddr2, header.IPv4ProtocolNumber, buffer.View([]byte{1}).ToVectorisedView())
- if _, ok := err.(*tcpip.ErrUnknownDevice); !ok {
- t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s, want = %s", err, &tcpip.ErrUnknownDevice{})
- }
- pkt, ok := e.Read()
- if got, want := ok, false; got != want {
- t.Fatalf("e.Read() = %t, %v; want %t", got, pkt, want)
- }
- })
-}
-
-func TestClearNeighborCacheOnNICDisable(t *testing.T) {
- const (
- nicID = 1
- linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
- )
-
- var (
- ipv4Addr = testutil.MustParse4("1.2.3.4")
- ipv6Addr = testutil.MustParse6("102:304:102:304:102:304:102:304")
- )
-
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
- Clock: clock,
- })
- e := channel.New(0, 0, "")
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- addrs := []struct {
- proto tcpip.NetworkProtocolNumber
- addr tcpip.Address
- }{
- {
- proto: ipv4.ProtocolNumber,
- addr: ipv4Addr,
- },
- {
- proto: ipv6.ProtocolNumber,
- addr: ipv6Addr,
- },
- }
- for _, addr := range addrs {
- if err := s.AddStaticNeighbor(nicID, addr.proto, addr.addr, linkAddr); err != nil {
- t.Fatalf("s.AddStaticNeighbor(%d, %d, %s, %s): %s", nicID, addr.proto, addr.addr, linkAddr, err)
- }
-
- if neighbors, err := s.Neighbors(nicID, addr.proto); err != nil {
- t.Fatalf("s.Neighbors(%d, %d): %s", nicID, addr.proto, err)
- } else if diff := cmp.Diff(
- []stack.NeighborEntry{{Addr: addr.addr, LinkAddr: linkAddr, State: stack.Static, UpdatedAt: clock.Now()}},
- neighbors,
- ); diff != "" {
- t.Fatalf("proto=%d neighbors mismatch (-want +got):\n%s", addr.proto, diff)
- }
- }
-
- // Disabling the NIC should clear the neighbor table.
- if err := s.DisableNIC(nicID); err != nil {
- t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
- }
- for _, addr := range addrs {
- if neighbors, err := s.Neighbors(nicID, addr.proto); err != nil {
- t.Fatalf("s.Neighbors(%d, %d): %s", nicID, addr.proto, err)
- } else if len(neighbors) != 0 {
- t.Fatalf("got proto=%d len(neighbors) = %d, want = 0; neighbors = %#v", addr.proto, len(neighbors), neighbors)
- }
- }
-
- // Enabling the NIC should have an empty neighbor table.
- if err := s.EnableNIC(nicID); err != nil {
- t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
- }
- for _, addr := range addrs {
- if neighbors, err := s.Neighbors(nicID, addr.proto); err != nil {
- t.Fatalf("s.Neighbors(%d, %d): %s", nicID, addr.proto, err)
- } else if len(neighbors) != 0 {
- t.Fatalf("got proto=%d len(neighbors) = %d, want = 0; neighbors = %#v", addr.proto, len(neighbors), neighbors)
- }
- }
-}
-
-func TestGetLinkAddressErrors(t *testing.T) {
- const (
- nicID = 1
- unknownNICID = nicID + 1
- )
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- })
- if err := s.CreateNIC(nicID, channel.New(0, 0, "")); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- {
- err := s.GetLinkAddress(unknownNICID, "", "", ipv4.ProtocolNumber, nil)
- if _, ok := err.(*tcpip.ErrUnknownNICID); !ok {
- t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = %s, want = %s", unknownNICID, ipv4.ProtocolNumber, err, &tcpip.ErrUnknownNICID{})
- }
- }
- {
- err := s.GetLinkAddress(nicID, "", "", ipv4.ProtocolNumber, nil)
- if _, ok := err.(*tcpip.ErrNotSupported); !ok {
- t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = %s, want = %s", unknownNICID, ipv4.ProtocolNumber, err, &tcpip.ErrNotSupported{})
- }
- }
-}
-
-func TestStaticGetLinkAddress(t *testing.T) {
- const (
- nicID = 1
- )
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
- })
- e := channel.New(0, 0, "")
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- tests := []struct {
- name string
- proto tcpip.NetworkProtocolNumber
- addr tcpip.Address
- expectedLinkAddr tcpip.LinkAddress
- }{
- {
- name: "IPv4",
- proto: ipv4.ProtocolNumber,
- addr: header.IPv4Broadcast,
- expectedLinkAddr: header.EthernetBroadcastAddress,
- },
- {
- name: "IPv6",
- proto: ipv6.ProtocolNumber,
- addr: header.IPv6AllNodesMulticastAddress,
- expectedLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllNodesMulticastAddress),
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- ch := make(chan stack.LinkResolutionResult, 1)
- if err := s.GetLinkAddress(nicID, test.addr, "", test.proto, func(r stack.LinkResolutionResult) {
- ch <- r
- }); err != nil {
- t.Fatalf("s.GetLinkAddress(%d, %s, '', %d, _): %s", nicID, test.addr, test.proto, err)
- }
-
- if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Err: nil}, <-ch); diff != "" {
- t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff)
- }
- })
- }
-}
diff --git a/pkg/tcpip/stack/stack_unsafe_state_autogen.go b/pkg/tcpip/stack/stack_unsafe_state_autogen.go
new file mode 100644
index 000000000..758ab3457
--- /dev/null
+++ b/pkg/tcpip/stack/stack_unsafe_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package stack
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
deleted file mode 100644
index cd3a8c25a..000000000
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ /dev/null
@@ -1,454 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package stack_test
-
-import (
- "io/ioutil"
- "math"
- "math/rand"
- "strconv"
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/ports"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-const (
- testSrcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- testDstAddrV6 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
-
- testSrcAddrV4 = "\x0a\x00\x00\x01"
- testDstAddrV4 = "\x0a\x00\x00\x02"
-
- testDstPort = 1234
- testSrcPort = 4096
-)
-
-type testContext struct {
- linkEps map[tcpip.NICID]*channel.Endpoint
- s *stack.Stack
- wq waiter.Queue
-}
-
-// newDualTestContextMultiNIC creates the testing context and also linkEpIDs NICs.
-func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- })
- linkEps := make(map[tcpip.NICID]*channel.Endpoint)
- for _, linkEpID := range linkEpIDs {
- channelEp := channel.New(256, mtu, "")
- if err := s.CreateNIC(linkEpID, channelEp); err != nil {
- t.Fatalf("CreateNIC failed: %s", err)
- }
- linkEps[linkEpID] = channelEp
-
- protocolAddrV4 := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: tcpip.Address(testDstAddrV4).WithPrefix(),
- }
- if err := s.AddProtocolAddress(linkEpID, protocolAddrV4, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", linkEpID, protocolAddrV4, err)
- }
-
- protocolAddrV6 := tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: testDstAddrV6.WithPrefix(),
- }
- if err := s.AddProtocolAddress(linkEpID, protocolAddrV6, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", linkEpID, protocolAddrV6, err)
- }
- }
-
- s.SetRouteTable([]tcpip.Route{
- {Destination: header.IPv4EmptySubnet, NIC: 1},
- {Destination: header.IPv6EmptySubnet, NIC: 1},
- })
-
- return &testContext{
- s: s,
- linkEps: linkEps,
- }
-}
-
-type headers struct {
- srcPort uint16
- dstPort uint16
-}
-
-func newPayload() []byte {
- b := make([]byte, 30+rand.Intn(100))
- for i := range b {
- b[i] = byte(rand.Intn(256))
- }
- return b
-}
-
-func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NICID) {
- buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload))
- payloadStart := len(buf) - len(payload)
- copy(buf[payloadStart:], payload)
-
- // Initialize the IP header.
- ip := header.IPv4(buf)
- ip.Encode(&header.IPv4Fields{
- TOS: 0x80,
- TotalLength: uint16(len(buf)),
- TTL: 65,
- Protocol: uint8(udp.ProtocolNumber),
- SrcAddr: testSrcAddrV4,
- DstAddr: testDstAddrV4,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
-
- // Initialize the UDP header.
- u := header.UDP(buf[header.IPv4MinimumSize:])
- u.Encode(&header.UDPFields{
- SrcPort: h.srcPort,
- DstPort: h.dstPort,
- Length: uint16(header.UDPMinimumSize + len(payload)),
- })
-
- // Calculate the UDP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV4, testDstAddrV4, uint16(len(u)))
-
- // Calculate the UDP checksum and set it.
- xsum = header.Checksum(payload, xsum)
- u.SetChecksum(^u.CalculateChecksum(xsum))
-
- // Inject packet.
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- })
- c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, pkt)
-}
-
-func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) {
- // Allocate a buffer for data and headers.
- buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
- copy(buf[len(buf)-len(payload):], payload)
-
- // Initialize the IP header.
- ip := header.IPv6(buf)
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
- TransportProtocol: udp.ProtocolNumber,
- HopLimit: 65,
- SrcAddr: testSrcAddrV6,
- DstAddr: testDstAddrV6,
- })
-
- // Initialize the UDP header.
- u := header.UDP(buf[header.IPv6MinimumSize:])
- u.Encode(&header.UDPFields{
- SrcPort: h.srcPort,
- DstPort: h.dstPort,
- Length: uint16(header.UDPMinimumSize + len(payload)),
- })
-
- // Calculate the UDP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV6, testDstAddrV6, uint16(len(u)))
-
- // Calculate the UDP checksum and set it.
- xsum = header.Checksum(payload, xsum)
- u.SetChecksum(^u.CalculateChecksum(xsum))
-
- // Inject packet.
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- })
- c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, pkt)
-}
-
-func TestTransportDemuxerRegister(t *testing.T) {
- for _, test := range []struct {
- name string
- proto tcpip.NetworkProtocolNumber
- want tcpip.Error
- }{
- {"failure", ipv6.ProtocolNumber, &tcpip.ErrUnknownProtocol{}},
- {"success", ipv4.ProtocolNumber, nil},
- } {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- })
- var wq waiter.Queue
- ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
- if err != nil {
- t.Fatal(err)
- }
- tEP, ok := ep.(stack.TransportEndpoint)
- if !ok {
- t.Fatalf("%T does not implement stack.TransportEndpoint", ep)
- }
- if got, want := s.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, ports.Flags{}, 0), test.want; got != want {
- t.Fatalf("s.RegisterTransportEndpoint(...) = %s, want %s", got, want)
- }
- })
- }
-}
-
-func TestTransportDemuxerRegisterMultiple(t *testing.T) {
- type test struct {
- flags ports.Flags
- want tcpip.Error
- }
- for _, subtest := range []struct {
- name string
- tests []test
- }{
- {"zeroFlags", []test{
- {ports.Flags{}, nil},
- {ports.Flags{}, &tcpip.ErrPortInUse{}},
- }},
- {"multibindFlags", []test{
- // Allow multiple registrations same TransportEndpointID with multibind flags.
- {ports.Flags{LoadBalanced: true, MostRecent: true}, nil},
- {ports.Flags{LoadBalanced: true, MostRecent: true}, nil},
- // Disallow registration w/same ID for a non-multibindflag.
- {ports.Flags{TupleOnly: true}, &tcpip.ErrPortInUse{}},
- }},
- } {
- t.Run(subtest.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- })
- var eps []tcpip.Endpoint
- for idx, test := range subtest.tests {
- var wq waiter.Queue
- ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
- if err != nil {
- t.Fatal(err)
- }
- eps = append(eps, ep)
- tEP, ok := ep.(stack.TransportEndpoint)
- if !ok {
- t.Fatalf("%T does not implement stack.TransportEndpoint", ep)
- }
- id := stack.TransportEndpointID{LocalPort: 1}
- if got, want := s.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{ipv4.ProtocolNumber}, udp.ProtocolNumber, id, tEP, test.flags, 0), test.want; got != want {
- t.Fatalf("test index: %d, s.RegisterTransportEndpoint(ipv4.ProtocolNumber, udp.ProtocolNumber, _, _, %+v, 0) = %s, want %s", idx, test.flags, got, want)
- }
- }
- for _, ep := range eps {
- ep.Close()
- }
- })
- }
-}
-
-// TestBindToDeviceDistribution injects varied packets on input devices and checks that
-// the distribution of packets received matches expectations.
-func TestBindToDeviceDistribution(t *testing.T) {
- type endpointSockopts struct {
- reuse bool
- bindToDevice tcpip.NICID
- }
- tcs := []struct {
- name string
- // endpoints will received the inject packets.
- endpoints []endpointSockopts
- // wantDistributions is the want ratio of packets received on each
- // endpoint for each NIC on which packets are injected.
- wantDistributions map[tcpip.NICID][]float64
- }{
- {
- name: "BindPortReuse",
- // 5 endpoints that all have reuse set.
- endpoints: []endpointSockopts{
- {reuse: true, bindToDevice: 0},
- {reuse: true, bindToDevice: 0},
- {reuse: true, bindToDevice: 0},
- {reuse: true, bindToDevice: 0},
- {reuse: true, bindToDevice: 0},
- },
- wantDistributions: map[tcpip.NICID][]float64{
- // Injected packets on dev0 get distributed evenly.
- 1: {0.2, 0.2, 0.2, 0.2, 0.2},
- },
- },
- {
- name: "BindToDevice",
- // 3 endpoints with various bindings.
- endpoints: []endpointSockopts{
- {reuse: false, bindToDevice: 1},
- {reuse: false, bindToDevice: 2},
- {reuse: false, bindToDevice: 3},
- },
- wantDistributions: map[tcpip.NICID][]float64{
- // Injected packets on dev0 go only to the endpoint bound to dev0.
- 1: {1, 0, 0},
- // Injected packets on dev1 go only to the endpoint bound to dev1.
- 2: {0, 1, 0},
- // Injected packets on dev2 go only to the endpoint bound to dev2.
- 3: {0, 0, 1},
- },
- },
- {
- name: "ReuseAndBindToDevice",
- // 6 endpoints with various bindings.
- endpoints: []endpointSockopts{
- {reuse: true, bindToDevice: 1},
- {reuse: true, bindToDevice: 1},
- {reuse: true, bindToDevice: 2},
- {reuse: true, bindToDevice: 2},
- {reuse: true, bindToDevice: 2},
- {reuse: true, bindToDevice: 0},
- },
- wantDistributions: map[tcpip.NICID][]float64{
- // Injected packets on dev0 get distributed among endpoints bound to
- // dev0.
- 1: {0.5, 0.5, 0, 0, 0, 0},
- // Injected packets on dev1 get distributed among endpoints bound to
- // dev1 or unbound.
- 2: {0, 0, 1. / 3, 1. / 3, 1. / 3, 0},
- // Injected packets on dev999 go only to the unbound.
- 1000: {0, 0, 0, 0, 0, 1},
- },
- },
- }
- protos := map[string]tcpip.NetworkProtocolNumber{
- "IPv4": ipv4.ProtocolNumber,
- "IPv6": ipv6.ProtocolNumber,
- }
-
- for _, test := range tcs {
- for protoName, protoNum := range protos {
- for device, wantDistribution := range test.wantDistributions {
- t.Run(test.name+protoName+"-"+strconv.Itoa(int(device)), func(t *testing.T) {
- // Create the NICs.
- var devices []tcpip.NICID
- for d := range test.wantDistributions {
- devices = append(devices, d)
- }
- c := newDualTestContextMultiNIC(t, defaultMTU, devices)
-
- // Create endpoints and bind each to a NIC, sometimes reusing ports.
- eps := make(map[tcpip.Endpoint]int)
- pollChannel := make(chan tcpip.Endpoint)
- for i, endpoint := range test.endpoints {
- // Try to receive the data.
- wq := waiter.Queue{}
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- t.Cleanup(func() {
- wq.EventUnregister(&we)
- close(ch)
- })
-
- var err tcpip.Error
- ep, err := c.s.NewEndpoint(udp.ProtocolNumber, protoNum, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- t.Cleanup(ep.Close)
- eps[ep] = i
-
- go func(ep tcpip.Endpoint) {
- for range ch {
- pollChannel <- ep
- }
- }(ep)
-
- ep.SocketOptions().SetReusePort(endpoint.reuse)
- if err := ep.SocketOptions().SetBindToDevice(int32(endpoint.bindToDevice)); err != nil {
- t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", endpoint.bindToDevice, endpoint.bindToDevice, i, err)
- }
-
- var dstAddr tcpip.Address
- switch protoNum {
- case ipv4.ProtocolNumber:
- dstAddr = testDstAddrV4
- case ipv6.ProtocolNumber:
- dstAddr = testDstAddrV6
- default:
- t.Fatalf("unexpected protocol number: %d", protoNum)
- }
- if err := ep.Bind(tcpip.FullAddress{Addr: dstAddr, Port: testDstPort}); err != nil {
- t.Fatalf("ep.Bind(...) on endpoint %d failed: %s", i, err)
- }
- }
-
- // Send packets across a range of ports, checking that packets from
- // the same source port are always demultiplexed to the same
- // destination endpoint.
- npackets := 10_000
- nports := 1_000
- if got, want := len(test.endpoints), len(wantDistribution); got != want {
- t.Fatalf("got len(test.endpoints) = %d, want %d", got, want)
- }
- endpoints := make(map[uint16]tcpip.Endpoint)
- stats := make(map[tcpip.Endpoint]int)
- for i := 0; i < npackets; i++ {
- // Send a packet.
- port := uint16(i % nports)
- payload := newPayload()
- hdrs := &headers{
- srcPort: testSrcPort + port,
- dstPort: testDstPort,
- }
- switch protoNum {
- case ipv4.ProtocolNumber:
- c.sendV4Packet(payload, hdrs, device)
- case ipv6.ProtocolNumber:
- c.sendV6Packet(payload, hdrs, device)
- default:
- t.Fatalf("unexpected protocol number: %d", protoNum)
- }
-
- ep := <-pollChannel
- if _, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != nil {
- t.Fatalf("Read on endpoint %d failed: %s", eps[ep], err)
- }
- stats[ep]++
- if i < nports {
- endpoints[uint16(i)] = ep
- } else {
- // Check that all packets from one client are handled by the same
- // socket.
- if want, got := endpoints[port], ep; want != got {
- t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got])
- }
- }
- }
-
- // Check that a packet distribution is as expected.
- for ep, i := range eps {
- wantRatio := wantDistribution[i]
- wantRecv := wantRatio * float64(npackets)
- actualRecv := stats[ep]
- actualRatio := float64(stats[ep]) / float64(npackets)
- // The deviation is less than 10%.
- if math.Abs(actualRatio-wantRatio) > 0.05 {
- t.Errorf("want about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantRatio*100, wantRecv, npackets, i, actualRatio*100, actualRecv, npackets)
- }
- }
- })
- }
- }
- }
-}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
deleted file mode 100644
index 655931715..000000000
--- a/pkg/tcpip/stack/transport_test.go
+++ /dev/null
@@ -1,576 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package stack_test
-
-import (
- "bytes"
- "io"
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/ports"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-const (
- fakeTransNumber tcpip.TransportProtocolNumber = 1
- fakeTransHeaderLen int = 3
-)
-
-// fakeTransportEndpoint is a transport-layer protocol endpoint. It counts
-// received packets; the counts of all endpoints are aggregated in the protocol
-// descriptor.
-//
-// Headers of this protocol are fakeTransHeaderLen bytes, but we currently don't
-// use it.
-type fakeTransportEndpoint struct {
- stack.TransportEndpointInfo
- tcpip.DefaultSocketOptionsHandler
-
- proto *fakeTransportProtocol
- peerAddr tcpip.Address
- route *stack.Route
- uniqueID uint64
-
- // acceptQueue is non-nil iff bound.
- acceptQueue []*fakeTransportEndpoint
-
- // ops is used to set and get socket options.
- ops tcpip.SocketOptions
-}
-
-func (f *fakeTransportEndpoint) Info() tcpip.EndpointInfo {
- return &f.TransportEndpointInfo
-}
-
-func (*fakeTransportEndpoint) Stats() tcpip.EndpointStats {
- return nil
-}
-
-func (*fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {}
-
-func (f *fakeTransportEndpoint) SocketOptions() *tcpip.SocketOptions {
- return &f.ops
-}
-
-func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, s *stack.Stack) tcpip.Endpoint {
- ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: s.UniqueID()}
- ep.ops.InitHandler(ep, s, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
- return ep
-}
-
-func (f *fakeTransportEndpoint) Abort() {
- f.Close()
-}
-
-func (f *fakeTransportEndpoint) Close() {
- // TODO(gvisor.dev/issue/5153): Consider retaining the route.
- f.route.Release()
-}
-
-func (*fakeTransportEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
- return mask
-}
-
-func (*fakeTransportEndpoint) Read(io.Writer, tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) {
- return tcpip.ReadResult{}, nil
-}
-
-func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
- if len(f.route.RemoteAddress()) == 0 {
- return 0, &tcpip.ErrNoRoute{}
- }
-
- v := make([]byte, p.Len())
- if _, err := io.ReadFull(p, v); err != nil {
- return 0, &tcpip.ErrBadBuffer{}
- }
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(f.route.MaxHeaderLength()) + fakeTransHeaderLen,
- Data: buffer.View(v).ToVectorisedView(),
- })
- _ = pkt.TransportHeader().Push(fakeTransHeaderLen)
- if err := f.route.WritePacket(stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, pkt); err != nil {
- return 0, err
- }
-
- return int64(len(v)), nil
-}
-
-// SetSockOpt sets a socket option. Currently not supported.
-func (*fakeTransportEndpoint) SetSockOpt(tcpip.SettableSocketOption) tcpip.Error {
- return &tcpip.ErrInvalidEndpointState{}
-}
-
-// SetSockOptInt sets a socket option. Currently not supported.
-func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error {
- return &tcpip.ErrInvalidEndpointState{}
-}
-
-// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
- return -1, &tcpip.ErrUnknownProtocolOption{}
-}
-
-// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (*fakeTransportEndpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error {
- return &tcpip.ErrInvalidEndpointState{}
-}
-
-// Disconnect implements tcpip.Endpoint.Disconnect.
-func (*fakeTransportEndpoint) Disconnect() tcpip.Error {
- return &tcpip.ErrNotSupported{}
-}
-
-func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
- f.peerAddr = addr.Addr
-
- // Find the route.
- r, err := f.proto.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- return &tcpip.ErrNoRoute{}
- }
-
- // Try to register so that we can start receiving packets.
- f.ID.RemoteAddress = addr.Addr
- err = f.proto.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */)
- if err != nil {
- r.Release()
- return err
- }
-
- f.route = r
-
- return nil
-}
-
-func (f *fakeTransportEndpoint) UniqueID() uint64 {
- return f.uniqueID
-}
-
-func (*fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) tcpip.Error {
- return nil
-}
-
-func (*fakeTransportEndpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error {
- return nil
-}
-
-func (*fakeTransportEndpoint) Reset() {
-}
-
-func (*fakeTransportEndpoint) Listen(int) tcpip.Error {
- return nil
-}
-
-func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) {
- if len(f.acceptQueue) == 0 {
- return nil, nil, nil
- }
- a := f.acceptQueue[0]
- f.acceptQueue = f.acceptQueue[1:]
- return a, nil, nil
-}
-
-func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) tcpip.Error {
- if err := f.proto.stack.RegisterTransportEndpoint(
- []tcpip.NetworkProtocolNumber{fakeNetNumber},
- fakeTransNumber,
- stack.TransportEndpointID{LocalAddress: a.Addr},
- f,
- ports.Flags{},
- 0, /* bindtoDevice */
- ); err != nil {
- return err
- }
- f.acceptQueue = []*fakeTransportEndpoint{}
- return nil
-}
-
-func (*fakeTransportEndpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
- return tcpip.FullAddress{}, nil
-}
-
-func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
- return tcpip.FullAddress{}, nil
-}
-
-func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
- // Increment the number of received packets.
- f.proto.packetCount++
- if f.acceptQueue == nil {
- return
- }
-
- netHdr := pkt.NetworkHeader().View()
- route, err := f.proto.stack.FindRoute(pkt.NICID, tcpip.Address(netHdr[dstAddrOffset]), tcpip.Address(netHdr[srcAddrOffset]), pkt.NetworkProtocolNumber, false /* multicastLoop */)
- if err != nil {
- return
- }
-
- ep := &fakeTransportEndpoint{
- TransportEndpointInfo: stack.TransportEndpointInfo{
- ID: f.ID,
- NetProto: f.NetProto,
- },
- proto: f.proto,
- peerAddr: route.RemoteAddress(),
- route: route,
- }
- ep.ops.InitHandler(ep, f.proto.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
- f.acceptQueue = append(f.acceptQueue, ep)
-}
-
-func (f *fakeTransportEndpoint) HandleError(stack.TransportError, *stack.PacketBuffer) {
- // Increment the number of received control packets.
- f.proto.controlCount++
-}
-
-func (*fakeTransportEndpoint) State() uint32 {
- return 0
-}
-
-func (*fakeTransportEndpoint) ModerateRecvBuf(copied int) {}
-
-func (*fakeTransportEndpoint) Resume(*stack.Stack) {}
-
-func (*fakeTransportEndpoint) Wait() {}
-
-func (*fakeTransportEndpoint) LastError() tcpip.Error {
- return nil
-}
-
-type fakeTransportGoodOption bool
-
-type fakeTransportBadOption bool
-
-type fakeTransportInvalidValueOption int
-
-type fakeTransportProtocolOptions struct {
- good bool
-}
-
-// fakeTransportProtocol is a transport-layer protocol descriptor. It
-// aggregates the number of packets received via endpoints of this protocol.
-type fakeTransportProtocol struct {
- stack *stack.Stack
-
- packetCount int
- controlCount int
- opts fakeTransportProtocolOptions
-}
-
-func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber {
- return fakeTransNumber
-}
-
-func (f *fakeTransportProtocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
- return newFakeTransportEndpoint(f, netProto, f.stack), nil
-}
-
-func (*fakeTransportProtocol) NewRawEndpoint(tcpip.NetworkProtocolNumber, *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
- return nil, &tcpip.ErrUnknownProtocol{}
-}
-
-func (*fakeTransportProtocol) MinimumPacketSize() int {
- return fakeTransHeaderLen
-}
-
-func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err tcpip.Error) {
- return 0, 0, nil
-}
-
-func (*fakeTransportProtocol) HandleUnknownDestinationPacket(stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
- return stack.UnknownDestinationPacketHandled
-}
-
-func (f *fakeTransportProtocol) SetOption(option tcpip.SettableTransportProtocolOption) tcpip.Error {
- switch v := option.(type) {
- case *tcpip.TCPModerateReceiveBufferOption:
- f.opts.good = bool(*v)
- return nil
- default:
- return &tcpip.ErrUnknownProtocolOption{}
- }
-}
-
-func (f *fakeTransportProtocol) Option(option tcpip.GettableTransportProtocolOption) tcpip.Error {
- switch v := option.(type) {
- case *tcpip.TCPModerateReceiveBufferOption:
- *v = tcpip.TCPModerateReceiveBufferOption(f.opts.good)
- return nil
- default:
- return &tcpip.ErrUnknownProtocolOption{}
- }
-}
-
-// Abort implements TransportProtocol.Abort.
-func (*fakeTransportProtocol) Abort() {}
-
-// Close implements tcpip.Endpoint.Close.
-func (*fakeTransportProtocol) Close() {}
-
-// Wait implements TransportProtocol.Wait.
-func (*fakeTransportProtocol) Wait() {}
-
-// Parse implements TransportProtocol.Parse.
-func (*fakeTransportProtocol) Parse(pkt *stack.PacketBuffer) bool {
- _, ok := pkt.TransportHeader().Consume(fakeTransHeaderLen)
- return ok
-}
-
-func fakeTransFactory(s *stack.Stack) stack.TransportProtocol {
- return &fakeTransportProtocol{stack: s}
-}
-
-func TestTransportReceive(t *testing.T) {
- linkEP := channel.New(10, defaultMTU, "")
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory},
- })
- if err := s.CreateNIC(1, linkEP); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
-
- {
- subnet, err := tcpip.NewSubnet("\x00", "\x00")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
- }
-
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: "\x01",
- PrefixLen: fakeDefaultPrefixLen,
- },
- }
- if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
- }
-
- // Create endpoint and connect to remote address.
- wq := waiter.Queue{}
- ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil {
- t.Fatalf("Connect failed: %v", err)
- }
-
- fakeTrans := s.TransportProtocolInstance(fakeTransNumber).(*fakeTransportProtocol)
-
- // Create buffer that will hold the packet.
- buf := buffer.NewView(30)
-
- // Make sure packet with wrong protocol is not delivered.
- buf[0] = 1
- buf[2] = 0
- linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
- if fakeTrans.packetCount != 0 {
- t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0)
- }
-
- // Make sure packet from the wrong source is not delivered.
- buf[0] = 1
- buf[1] = 3
- buf[2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
- if fakeTrans.packetCount != 0 {
- t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0)
- }
-
- // Make sure packet is delivered.
- buf[0] = 1
- buf[1] = 2
- buf[2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
- if fakeTrans.packetCount != 1 {
- t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 1)
- }
-}
-
-func TestTransportControlReceive(t *testing.T) {
- linkEP := channel.New(10, defaultMTU, "")
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory},
- })
- if err := s.CreateNIC(1, linkEP); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
-
- {
- subnet, err := tcpip.NewSubnet("\x00", "\x00")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
- }
-
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: "\x01",
- PrefixLen: fakeDefaultPrefixLen,
- },
- }
- if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
- }
-
- // Create endpoint and connect to remote address.
- wq := waiter.Queue{}
- ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil {
- t.Fatalf("Connect failed: %v", err)
- }
-
- fakeTrans := s.TransportProtocolInstance(fakeTransNumber).(*fakeTransportProtocol)
-
- // Create buffer that will hold the control packet.
- buf := buffer.NewView(2*fakeNetHeaderLen + 30)
-
- // Outer packet contains the control protocol number.
- buf[0] = 1
- buf[1] = 0xfe
- buf[2] = uint8(fakeControlProtocol)
-
- // Make sure packet with wrong protocol is not delivered.
- buf[fakeNetHeaderLen+0] = 0
- buf[fakeNetHeaderLen+1] = 1
- buf[fakeNetHeaderLen+2] = 0
- linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
- if fakeTrans.controlCount != 0 {
- t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0)
- }
-
- // Make sure packet from the wrong source is not delivered.
- buf[fakeNetHeaderLen+0] = 3
- buf[fakeNetHeaderLen+1] = 1
- buf[fakeNetHeaderLen+2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
- if fakeTrans.controlCount != 0 {
- t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0)
- }
-
- // Make sure packet is delivered.
- buf[fakeNetHeaderLen+0] = 2
- buf[fakeNetHeaderLen+1] = 1
- buf[fakeNetHeaderLen+2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
- if fakeTrans.controlCount != 1 {
- t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 1)
- }
-}
-
-func TestTransportSend(t *testing.T) {
- linkEP := channel.New(10, defaultMTU, "")
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory},
- })
- if err := s.CreateNIC(1, linkEP); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
-
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: "\x01",
- PrefixLen: fakeDefaultPrefixLen,
- },
- }
- if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
- }
-
- {
- subnet, err := tcpip.NewSubnet("\x00", "\x00")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
- }
-
- // Create endpoint and bind it.
- wq := waiter.Queue{}
- ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil {
- t.Fatalf("Connect failed: %v", err)
- }
-
- // Create buffer that will hold the payload.
- b := make([]byte, 30)
- var r bytes.Reader
- r.Reset(b)
- if _, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("write failed: %v", err)
- }
-
- fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
-
- if fakeNet.sendPacketCount[2] != 1 {
- t.Errorf("sendPacketCount = %d, want %d", fakeNet.sendPacketCount[2], 1)
- }
-}
-
-func TestTransportOptions(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory},
- })
-
- v := tcpip.TCPModerateReceiveBufferOption(true)
- if err := s.SetTransportProtocolOption(fakeTransNumber, &v); err != nil {
- t.Errorf("s.SetTransportProtocolOption(fakeTrans, &%T(%t)): %s", v, v, err)
- }
- v = false
- if err := s.TransportProtocolOption(fakeTransNumber, &v); err != nil {
- t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &%T): %s", v, err)
- }
- if !v {
- t.Fatalf("got tcpip.TCPModerateReceiveBufferOption = false, want = true")
- }
-}
diff --git a/pkg/tcpip/stack/tuple_list.go b/pkg/tcpip/stack/tuple_list.go
new file mode 100644
index 000000000..31d0feefa
--- /dev/null
+++ b/pkg/tcpip/stack/tuple_list.go
@@ -0,0 +1,221 @@
+package stack
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type tupleElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (tupleElementMapper) linkerFor(elem *tuple) *tuple { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type tupleList struct {
+ head *tuple
+ tail *tuple
+}
+
+// Reset resets list l to the empty state.
+func (l *tupleList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+//
+//go:nosplit
+func (l *tupleList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+//
+//go:nosplit
+func (l *tupleList) Front() *tuple {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+//
+//go:nosplit
+func (l *tupleList) Back() *tuple {
+ return l.tail
+}
+
+// Len returns the number of elements in the list.
+//
+// NOTE: This is an O(n) operation.
+//
+//go:nosplit
+func (l *tupleList) Len() (count int) {
+ for e := l.Front(); e != nil; e = (tupleElementMapper{}.linkerFor(e)).Next() {
+ count++
+ }
+ return count
+}
+
+// PushFront inserts the element e at the front of list l.
+//
+//go:nosplit
+func (l *tupleList) PushFront(e *tuple) {
+ linker := tupleElementMapper{}.linkerFor(e)
+ linker.SetNext(l.head)
+ linker.SetPrev(nil)
+ if l.head != nil {
+ tupleElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+//
+//go:nosplit
+func (l *tupleList) PushBack(e *tuple) {
+ linker := tupleElementMapper{}.linkerFor(e)
+ linker.SetNext(nil)
+ linker.SetPrev(l.tail)
+ if l.tail != nil {
+ tupleElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+//
+//go:nosplit
+func (l *tupleList) PushBackList(m *tupleList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ tupleElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ tupleElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+//
+//go:nosplit
+func (l *tupleList) InsertAfter(b, e *tuple) {
+ bLinker := tupleElementMapper{}.linkerFor(b)
+ eLinker := tupleElementMapper{}.linkerFor(e)
+
+ a := bLinker.Next()
+
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ bLinker.SetNext(e)
+
+ if a != nil {
+ tupleElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+//
+//go:nosplit
+func (l *tupleList) InsertBefore(a, e *tuple) {
+ aLinker := tupleElementMapper{}.linkerFor(a)
+ eLinker := tupleElementMapper{}.linkerFor(e)
+
+ b := aLinker.Prev()
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ aLinker.SetPrev(e)
+
+ if b != nil {
+ tupleElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+//
+//go:nosplit
+func (l *tupleList) Remove(e *tuple) {
+ linker := tupleElementMapper{}.linkerFor(e)
+ prev := linker.Prev()
+ next := linker.Next()
+
+ if prev != nil {
+ tupleElementMapper{}.linkerFor(prev).SetNext(next)
+ } else if l.head == e {
+ l.head = next
+ }
+
+ if next != nil {
+ tupleElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else if l.tail == e {
+ l.tail = prev
+ }
+
+ linker.SetNext(nil)
+ linker.SetPrev(nil)
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type tupleEntry struct {
+ next *tuple
+ prev *tuple
+}
+
+// Next returns the entry that follows e in the list.
+//
+//go:nosplit
+func (e *tupleEntry) Next() *tuple {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *tupleEntry) Prev() *tuple {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+//
+//go:nosplit
+func (e *tupleEntry) SetNext(elem *tuple) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *tupleEntry) SetPrev(elem *tuple) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/tcpip_state_autogen.go b/pkg/tcpip/tcpip_state_autogen.go
new file mode 100644
index 000000000..c3df463eb
--- /dev/null
+++ b/pkg/tcpip/tcpip_state_autogen.go
@@ -0,0 +1,1361 @@
+// automatically generated by stateify.
+
+package tcpip
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+func (e *ErrAborted) StateTypeName() string {
+ return "pkg/tcpip.ErrAborted"
+}
+
+func (e *ErrAborted) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrAborted) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrAborted) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrAborted) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrAborted) StateLoad(stateSourceObject state.Source) {
+}
+
+func (e *ErrAddressFamilyNotSupported) StateTypeName() string {
+ return "pkg/tcpip.ErrAddressFamilyNotSupported"
+}
+
+func (e *ErrAddressFamilyNotSupported) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrAddressFamilyNotSupported) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrAddressFamilyNotSupported) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrAddressFamilyNotSupported) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrAddressFamilyNotSupported) StateLoad(stateSourceObject state.Source) {
+}
+
+func (e *ErrAlreadyBound) StateTypeName() string {
+ return "pkg/tcpip.ErrAlreadyBound"
+}
+
+func (e *ErrAlreadyBound) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrAlreadyBound) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrAlreadyBound) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrAlreadyBound) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrAlreadyBound) StateLoad(stateSourceObject state.Source) {
+}
+
+func (e *ErrAlreadyConnected) StateTypeName() string {
+ return "pkg/tcpip.ErrAlreadyConnected"
+}
+
+func (e *ErrAlreadyConnected) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrAlreadyConnected) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrAlreadyConnected) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrAlreadyConnected) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrAlreadyConnected) StateLoad(stateSourceObject state.Source) {
+}
+
+func (e *ErrAlreadyConnecting) StateTypeName() string {
+ return "pkg/tcpip.ErrAlreadyConnecting"
+}
+
+func (e *ErrAlreadyConnecting) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrAlreadyConnecting) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrAlreadyConnecting) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrAlreadyConnecting) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrAlreadyConnecting) StateLoad(stateSourceObject state.Source) {
+}
+
+func (e *ErrBadAddress) StateTypeName() string {
+ return "pkg/tcpip.ErrBadAddress"
+}
+
+func (e *ErrBadAddress) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrBadAddress) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrBadAddress) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrBadAddress) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrBadAddress) StateLoad(stateSourceObject state.Source) {
+}
+
+func (e *ErrBadBuffer) StateTypeName() string {
+ return "pkg/tcpip.ErrBadBuffer"
+}
+
+func (e *ErrBadBuffer) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrBadBuffer) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrBadBuffer) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrBadBuffer) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrBadBuffer) StateLoad(stateSourceObject state.Source) {
+}
+
+func (e *ErrBadLocalAddress) StateTypeName() string {
+ return "pkg/tcpip.ErrBadLocalAddress"
+}
+
+func (e *ErrBadLocalAddress) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrBadLocalAddress) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrBadLocalAddress) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrBadLocalAddress) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrBadLocalAddress) StateLoad(stateSourceObject state.Source) {
+}
+
+func (e *ErrBroadcastDisabled) StateTypeName() string {
+ return "pkg/tcpip.ErrBroadcastDisabled"
+}
+
+func (e *ErrBroadcastDisabled) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrBroadcastDisabled) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrBroadcastDisabled) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrBroadcastDisabled) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrBroadcastDisabled) StateLoad(stateSourceObject state.Source) {
+}
+
+func (e *ErrClosedForReceive) StateTypeName() string {
+ return "pkg/tcpip.ErrClosedForReceive"
+}
+
+func (e *ErrClosedForReceive) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrClosedForReceive) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrClosedForReceive) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrClosedForReceive) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrClosedForReceive) StateLoad(stateSourceObject state.Source) {
+}
+
+func (e *ErrClosedForSend) StateTypeName() string {
+ return "pkg/tcpip.ErrClosedForSend"
+}
+
+func (e *ErrClosedForSend) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrClosedForSend) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrClosedForSend) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrClosedForSend) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrClosedForSend) StateLoad(stateSourceObject state.Source) {
+}
+
+func (e *ErrConnectStarted) StateTypeName() string {
+ return "pkg/tcpip.ErrConnectStarted"
+}
+
+func (e *ErrConnectStarted) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrConnectStarted) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrConnectStarted) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrConnectStarted) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrConnectStarted) StateLoad(stateSourceObject state.Source) {
+}
+
+func (e *ErrConnectionAborted) StateTypeName() string {
+ return "pkg/tcpip.ErrConnectionAborted"
+}
+
+func (e *ErrConnectionAborted) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrConnectionAborted) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrConnectionAborted) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrConnectionAborted) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrConnectionAborted) StateLoad(stateSourceObject state.Source) {
+}
+
+func (e *ErrConnectionRefused) StateTypeName() string {
+ return "pkg/tcpip.ErrConnectionRefused"
+}
+
+func (e *ErrConnectionRefused) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrConnectionRefused) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrConnectionRefused) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrConnectionRefused) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrConnectionRefused) StateLoad(stateSourceObject state.Source) {
+}
+
+func (e *ErrConnectionReset) StateTypeName() string {
+ return "pkg/tcpip.ErrConnectionReset"
+}
+
+func (e *ErrConnectionReset) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrConnectionReset) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrConnectionReset) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrConnectionReset) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrConnectionReset) StateLoad(stateSourceObject state.Source) {
+}
+
+func (e *ErrDestinationRequired) StateTypeName() string {
+ return "pkg/tcpip.ErrDestinationRequired"
+}
+
+func (e *ErrDestinationRequired) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrDestinationRequired) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrDestinationRequired) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrDestinationRequired) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrDestinationRequired) StateLoad(stateSourceObject state.Source) {
+}
+
+func (e *ErrDuplicateAddress) StateTypeName() string {
+ return "pkg/tcpip.ErrDuplicateAddress"
+}
+
+func (e *ErrDuplicateAddress) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrDuplicateAddress) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrDuplicateAddress) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrDuplicateAddress) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrDuplicateAddress) StateLoad(stateSourceObject state.Source) {
+}
+
+func (e *ErrDuplicateNICID) StateTypeName() string {
+ return "pkg/tcpip.ErrDuplicateNICID"
+}
+
+func (e *ErrDuplicateNICID) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrDuplicateNICID) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrDuplicateNICID) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrDuplicateNICID) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrDuplicateNICID) StateLoad(stateSourceObject state.Source) {
+}
+
+func (e *ErrInvalidEndpointState) StateTypeName() string {
+ return "pkg/tcpip.ErrInvalidEndpointState"
+}
+
+func (e *ErrInvalidEndpointState) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrInvalidEndpointState) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrInvalidEndpointState) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrInvalidEndpointState) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrInvalidEndpointState) StateLoad(stateSourceObject state.Source) {
+}
+
+func (e *ErrInvalidOptionValue) StateTypeName() string {
+ return "pkg/tcpip.ErrInvalidOptionValue"
+}
+
+func (e *ErrInvalidOptionValue) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrInvalidOptionValue) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrInvalidOptionValue) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrInvalidOptionValue) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrInvalidOptionValue) StateLoad(stateSourceObject state.Source) {
+}
+
+func (e *ErrInvalidPortRange) StateTypeName() string {
+ return "pkg/tcpip.ErrInvalidPortRange"
+}
+
+func (e *ErrInvalidPortRange) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrInvalidPortRange) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrInvalidPortRange) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrInvalidPortRange) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrInvalidPortRange) StateLoad(stateSourceObject state.Source) {
+}
+
+func (e *ErrMalformedHeader) StateTypeName() string {
+ return "pkg/tcpip.ErrMalformedHeader"
+}
+
+func (e *ErrMalformedHeader) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrMalformedHeader) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrMalformedHeader) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrMalformedHeader) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrMalformedHeader) StateLoad(stateSourceObject state.Source) {
+}
+
+func (e *ErrMessageTooLong) StateTypeName() string {
+ return "pkg/tcpip.ErrMessageTooLong"
+}
+
+func (e *ErrMessageTooLong) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrMessageTooLong) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrMessageTooLong) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrMessageTooLong) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrMessageTooLong) StateLoad(stateSourceObject state.Source) {
+}
+
+func (e *ErrNetworkUnreachable) StateTypeName() string {
+ return "pkg/tcpip.ErrNetworkUnreachable"
+}
+
+func (e *ErrNetworkUnreachable) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrNetworkUnreachable) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrNetworkUnreachable) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrNetworkUnreachable) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrNetworkUnreachable) StateLoad(stateSourceObject state.Source) {
+}
+
+func (e *ErrNoBufferSpace) StateTypeName() string {
+ return "pkg/tcpip.ErrNoBufferSpace"
+}
+
+func (e *ErrNoBufferSpace) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrNoBufferSpace) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrNoBufferSpace) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrNoBufferSpace) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrNoBufferSpace) StateLoad(stateSourceObject state.Source) {
+}
+
+func (e *ErrNoPortAvailable) StateTypeName() string {
+ return "pkg/tcpip.ErrNoPortAvailable"
+}
+
+func (e *ErrNoPortAvailable) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrNoPortAvailable) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrNoPortAvailable) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrNoPortAvailable) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrNoPortAvailable) StateLoad(stateSourceObject state.Source) {
+}
+
+func (e *ErrNoRoute) StateTypeName() string {
+ return "pkg/tcpip.ErrNoRoute"
+}
+
+func (e *ErrNoRoute) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrNoRoute) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrNoRoute) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrNoRoute) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrNoRoute) StateLoad(stateSourceObject state.Source) {
+}
+
+func (e *ErrNoSuchFile) StateTypeName() string {
+ return "pkg/tcpip.ErrNoSuchFile"
+}
+
+func (e *ErrNoSuchFile) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrNoSuchFile) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrNoSuchFile) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrNoSuchFile) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrNoSuchFile) StateLoad(stateSourceObject state.Source) {
+}
+
+func (e *ErrNotConnected) StateTypeName() string {
+ return "pkg/tcpip.ErrNotConnected"
+}
+
+func (e *ErrNotConnected) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrNotConnected) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrNotConnected) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrNotConnected) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrNotConnected) StateLoad(stateSourceObject state.Source) {
+}
+
+func (e *ErrNotPermitted) StateTypeName() string {
+ return "pkg/tcpip.ErrNotPermitted"
+}
+
+func (e *ErrNotPermitted) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrNotPermitted) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrNotPermitted) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrNotPermitted) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrNotPermitted) StateLoad(stateSourceObject state.Source) {
+}
+
+func (e *ErrNotSupported) StateTypeName() string {
+ return "pkg/tcpip.ErrNotSupported"
+}
+
+func (e *ErrNotSupported) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrNotSupported) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrNotSupported) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrNotSupported) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrNotSupported) StateLoad(stateSourceObject state.Source) {
+}
+
+func (e *ErrPortInUse) StateTypeName() string {
+ return "pkg/tcpip.ErrPortInUse"
+}
+
+func (e *ErrPortInUse) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrPortInUse) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrPortInUse) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrPortInUse) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrPortInUse) StateLoad(stateSourceObject state.Source) {
+}
+
+func (e *ErrQueueSizeNotSupported) StateTypeName() string {
+ return "pkg/tcpip.ErrQueueSizeNotSupported"
+}
+
+func (e *ErrQueueSizeNotSupported) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrQueueSizeNotSupported) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrQueueSizeNotSupported) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrQueueSizeNotSupported) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrQueueSizeNotSupported) StateLoad(stateSourceObject state.Source) {
+}
+
+func (e *ErrTimeout) StateTypeName() string {
+ return "pkg/tcpip.ErrTimeout"
+}
+
+func (e *ErrTimeout) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrTimeout) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrTimeout) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrTimeout) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrTimeout) StateLoad(stateSourceObject state.Source) {
+}
+
+func (e *ErrUnknownDevice) StateTypeName() string {
+ return "pkg/tcpip.ErrUnknownDevice"
+}
+
+func (e *ErrUnknownDevice) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrUnknownDevice) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrUnknownDevice) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrUnknownDevice) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrUnknownDevice) StateLoad(stateSourceObject state.Source) {
+}
+
+func (e *ErrUnknownNICID) StateTypeName() string {
+ return "pkg/tcpip.ErrUnknownNICID"
+}
+
+func (e *ErrUnknownNICID) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrUnknownNICID) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrUnknownNICID) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrUnknownNICID) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrUnknownNICID) StateLoad(stateSourceObject state.Source) {
+}
+
+func (e *ErrUnknownProtocol) StateTypeName() string {
+ return "pkg/tcpip.ErrUnknownProtocol"
+}
+
+func (e *ErrUnknownProtocol) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrUnknownProtocol) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrUnknownProtocol) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrUnknownProtocol) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrUnknownProtocol) StateLoad(stateSourceObject state.Source) {
+}
+
+func (e *ErrUnknownProtocolOption) StateTypeName() string {
+ return "pkg/tcpip.ErrUnknownProtocolOption"
+}
+
+func (e *ErrUnknownProtocolOption) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrUnknownProtocolOption) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrUnknownProtocolOption) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrUnknownProtocolOption) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrUnknownProtocolOption) StateLoad(stateSourceObject state.Source) {
+}
+
+func (e *ErrWouldBlock) StateTypeName() string {
+ return "pkg/tcpip.ErrWouldBlock"
+}
+
+func (e *ErrWouldBlock) StateFields() []string {
+ return []string{}
+}
+
+func (e *ErrWouldBlock) beforeSave() {}
+
+// +checklocksignore
+func (e *ErrWouldBlock) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+}
+
+func (e *ErrWouldBlock) afterLoad() {}
+
+// +checklocksignore
+func (e *ErrWouldBlock) StateLoad(stateSourceObject state.Source) {
+}
+
+func (l *sockErrorList) StateTypeName() string {
+ return "pkg/tcpip.sockErrorList"
+}
+
+func (l *sockErrorList) StateFields() []string {
+ return []string{
+ "head",
+ "tail",
+ }
+}
+
+func (l *sockErrorList) beforeSave() {}
+
+// +checklocksignore
+func (l *sockErrorList) StateSave(stateSinkObject state.Sink) {
+ l.beforeSave()
+ stateSinkObject.Save(0, &l.head)
+ stateSinkObject.Save(1, &l.tail)
+}
+
+func (l *sockErrorList) afterLoad() {}
+
+// +checklocksignore
+func (l *sockErrorList) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &l.head)
+ stateSourceObject.Load(1, &l.tail)
+}
+
+func (e *sockErrorEntry) StateTypeName() string {
+ return "pkg/tcpip.sockErrorEntry"
+}
+
+func (e *sockErrorEntry) StateFields() []string {
+ return []string{
+ "next",
+ "prev",
+ }
+}
+
+func (e *sockErrorEntry) beforeSave() {}
+
+// +checklocksignore
+func (e *sockErrorEntry) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+ stateSinkObject.Save(0, &e.next)
+ stateSinkObject.Save(1, &e.prev)
+}
+
+func (e *sockErrorEntry) afterLoad() {}
+
+// +checklocksignore
+func (e *sockErrorEntry) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &e.next)
+ stateSourceObject.Load(1, &e.prev)
+}
+
+func (so *SocketOptions) StateTypeName() string {
+ return "pkg/tcpip.SocketOptions"
+}
+
+func (so *SocketOptions) StateFields() []string {
+ return []string{
+ "handler",
+ "broadcastEnabled",
+ "passCredEnabled",
+ "noChecksumEnabled",
+ "reuseAddressEnabled",
+ "reusePortEnabled",
+ "keepAliveEnabled",
+ "multicastLoopEnabled",
+ "receiveTOSEnabled",
+ "receiveTClassEnabled",
+ "receivePacketInfoEnabled",
+ "receiveIPv6PacketInfoEnabled",
+ "hdrIncludedEnabled",
+ "v6OnlyEnabled",
+ "quickAckEnabled",
+ "delayOptionEnabled",
+ "corkOptionEnabled",
+ "receiveOriginalDstAddress",
+ "recvErrEnabled",
+ "errQueue",
+ "bindToDevice",
+ "sendBufferSize",
+ "receiveBufferSize",
+ "linger",
+ }
+}
+
+func (so *SocketOptions) beforeSave() {}
+
+// +checklocksignore
+func (so *SocketOptions) StateSave(stateSinkObject state.Sink) {
+ so.beforeSave()
+ stateSinkObject.Save(0, &so.handler)
+ stateSinkObject.Save(1, &so.broadcastEnabled)
+ stateSinkObject.Save(2, &so.passCredEnabled)
+ stateSinkObject.Save(3, &so.noChecksumEnabled)
+ stateSinkObject.Save(4, &so.reuseAddressEnabled)
+ stateSinkObject.Save(5, &so.reusePortEnabled)
+ stateSinkObject.Save(6, &so.keepAliveEnabled)
+ stateSinkObject.Save(7, &so.multicastLoopEnabled)
+ stateSinkObject.Save(8, &so.receiveTOSEnabled)
+ stateSinkObject.Save(9, &so.receiveTClassEnabled)
+ stateSinkObject.Save(10, &so.receivePacketInfoEnabled)
+ stateSinkObject.Save(11, &so.receiveIPv6PacketInfoEnabled)
+ stateSinkObject.Save(12, &so.hdrIncludedEnabled)
+ stateSinkObject.Save(13, &so.v6OnlyEnabled)
+ stateSinkObject.Save(14, &so.quickAckEnabled)
+ stateSinkObject.Save(15, &so.delayOptionEnabled)
+ stateSinkObject.Save(16, &so.corkOptionEnabled)
+ stateSinkObject.Save(17, &so.receiveOriginalDstAddress)
+ stateSinkObject.Save(18, &so.recvErrEnabled)
+ stateSinkObject.Save(19, &so.errQueue)
+ stateSinkObject.Save(20, &so.bindToDevice)
+ stateSinkObject.Save(21, &so.sendBufferSize)
+ stateSinkObject.Save(22, &so.receiveBufferSize)
+ stateSinkObject.Save(23, &so.linger)
+}
+
+func (so *SocketOptions) afterLoad() {}
+
+// +checklocksignore
+func (so *SocketOptions) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &so.handler)
+ stateSourceObject.Load(1, &so.broadcastEnabled)
+ stateSourceObject.Load(2, &so.passCredEnabled)
+ stateSourceObject.Load(3, &so.noChecksumEnabled)
+ stateSourceObject.Load(4, &so.reuseAddressEnabled)
+ stateSourceObject.Load(5, &so.reusePortEnabled)
+ stateSourceObject.Load(6, &so.keepAliveEnabled)
+ stateSourceObject.Load(7, &so.multicastLoopEnabled)
+ stateSourceObject.Load(8, &so.receiveTOSEnabled)
+ stateSourceObject.Load(9, &so.receiveTClassEnabled)
+ stateSourceObject.Load(10, &so.receivePacketInfoEnabled)
+ stateSourceObject.Load(11, &so.receiveIPv6PacketInfoEnabled)
+ stateSourceObject.Load(12, &so.hdrIncludedEnabled)
+ stateSourceObject.Load(13, &so.v6OnlyEnabled)
+ stateSourceObject.Load(14, &so.quickAckEnabled)
+ stateSourceObject.Load(15, &so.delayOptionEnabled)
+ stateSourceObject.Load(16, &so.corkOptionEnabled)
+ stateSourceObject.Load(17, &so.receiveOriginalDstAddress)
+ stateSourceObject.Load(18, &so.recvErrEnabled)
+ stateSourceObject.Load(19, &so.errQueue)
+ stateSourceObject.Load(20, &so.bindToDevice)
+ stateSourceObject.Load(21, &so.sendBufferSize)
+ stateSourceObject.Load(22, &so.receiveBufferSize)
+ stateSourceObject.Load(23, &so.linger)
+}
+
+func (l *LocalSockError) StateTypeName() string {
+ return "pkg/tcpip.LocalSockError"
+}
+
+func (l *LocalSockError) StateFields() []string {
+ return []string{
+ "info",
+ }
+}
+
+func (l *LocalSockError) beforeSave() {}
+
+// +checklocksignore
+func (l *LocalSockError) StateSave(stateSinkObject state.Sink) {
+ l.beforeSave()
+ stateSinkObject.Save(0, &l.info)
+}
+
+func (l *LocalSockError) afterLoad() {}
+
+// +checklocksignore
+func (l *LocalSockError) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &l.info)
+}
+
+func (s *SockError) StateTypeName() string {
+ return "pkg/tcpip.SockError"
+}
+
+func (s *SockError) StateFields() []string {
+ return []string{
+ "sockErrorEntry",
+ "Err",
+ "Cause",
+ "Payload",
+ "Dst",
+ "Offender",
+ "NetProto",
+ }
+}
+
+func (s *SockError) beforeSave() {}
+
+// +checklocksignore
+func (s *SockError) StateSave(stateSinkObject state.Sink) {
+ s.beforeSave()
+ stateSinkObject.Save(0, &s.sockErrorEntry)
+ stateSinkObject.Save(1, &s.Err)
+ stateSinkObject.Save(2, &s.Cause)
+ stateSinkObject.Save(3, &s.Payload)
+ stateSinkObject.Save(4, &s.Dst)
+ stateSinkObject.Save(5, &s.Offender)
+ stateSinkObject.Save(6, &s.NetProto)
+}
+
+func (s *SockError) afterLoad() {}
+
+// +checklocksignore
+func (s *SockError) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &s.sockErrorEntry)
+ stateSourceObject.Load(1, &s.Err)
+ stateSourceObject.Load(2, &s.Cause)
+ stateSourceObject.Load(3, &s.Payload)
+ stateSourceObject.Load(4, &s.Dst)
+ stateSourceObject.Load(5, &s.Offender)
+ stateSourceObject.Load(6, &s.NetProto)
+}
+
+func (s *stdClock) StateTypeName() string {
+ return "pkg/tcpip.stdClock"
+}
+
+func (s *stdClock) StateFields() []string {
+ return []string{
+ "maxMonotonic",
+ }
+}
+
+func (s *stdClock) beforeSave() {}
+
+// +checklocksignore
+func (s *stdClock) StateSave(stateSinkObject state.Sink) {
+ s.beforeSave()
+ stateSinkObject.Save(0, &s.maxMonotonic)
+}
+
+// +checklocksignore
+func (s *stdClock) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &s.maxMonotonic)
+ stateSourceObject.AfterLoad(s.afterLoad)
+}
+
+func (mt *MonotonicTime) StateTypeName() string {
+ return "pkg/tcpip.MonotonicTime"
+}
+
+func (mt *MonotonicTime) StateFields() []string {
+ return []string{
+ "nanoseconds",
+ }
+}
+
+func (mt *MonotonicTime) beforeSave() {}
+
+// +checklocksignore
+func (mt *MonotonicTime) StateSave(stateSinkObject state.Sink) {
+ mt.beforeSave()
+ stateSinkObject.Save(0, &mt.nanoseconds)
+}
+
+func (mt *MonotonicTime) afterLoad() {}
+
+// +checklocksignore
+func (mt *MonotonicTime) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &mt.nanoseconds)
+}
+
+func (f *FullAddress) StateTypeName() string {
+ return "pkg/tcpip.FullAddress"
+}
+
+func (f *FullAddress) StateFields() []string {
+ return []string{
+ "NIC",
+ "Addr",
+ "Port",
+ }
+}
+
+func (f *FullAddress) beforeSave() {}
+
+// +checklocksignore
+func (f *FullAddress) StateSave(stateSinkObject state.Sink) {
+ f.beforeSave()
+ stateSinkObject.Save(0, &f.NIC)
+ stateSinkObject.Save(1, &f.Addr)
+ stateSinkObject.Save(2, &f.Port)
+}
+
+func (f *FullAddress) afterLoad() {}
+
+// +checklocksignore
+func (f *FullAddress) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &f.NIC)
+ stateSourceObject.Load(1, &f.Addr)
+ stateSourceObject.Load(2, &f.Port)
+}
+
+func (c *ControlMessages) StateTypeName() string {
+ return "pkg/tcpip.ControlMessages"
+}
+
+func (c *ControlMessages) StateFields() []string {
+ return []string{
+ "HasTimestamp",
+ "Timestamp",
+ "HasInq",
+ "Inq",
+ "HasTOS",
+ "TOS",
+ "HasTClass",
+ "TClass",
+ "HasIPPacketInfo",
+ "PacketInfo",
+ "HasIPv6PacketInfo",
+ "IPv6PacketInfo",
+ "HasOriginalDstAddress",
+ "OriginalDstAddress",
+ "SockErr",
+ }
+}
+
+func (c *ControlMessages) beforeSave() {}
+
+// +checklocksignore
+func (c *ControlMessages) StateSave(stateSinkObject state.Sink) {
+ c.beforeSave()
+ stateSinkObject.Save(0, &c.HasTimestamp)
+ stateSinkObject.Save(1, &c.Timestamp)
+ stateSinkObject.Save(2, &c.HasInq)
+ stateSinkObject.Save(3, &c.Inq)
+ stateSinkObject.Save(4, &c.HasTOS)
+ stateSinkObject.Save(5, &c.TOS)
+ stateSinkObject.Save(6, &c.HasTClass)
+ stateSinkObject.Save(7, &c.TClass)
+ stateSinkObject.Save(8, &c.HasIPPacketInfo)
+ stateSinkObject.Save(9, &c.PacketInfo)
+ stateSinkObject.Save(10, &c.HasIPv6PacketInfo)
+ stateSinkObject.Save(11, &c.IPv6PacketInfo)
+ stateSinkObject.Save(12, &c.HasOriginalDstAddress)
+ stateSinkObject.Save(13, &c.OriginalDstAddress)
+ stateSinkObject.Save(14, &c.SockErr)
+}
+
+func (c *ControlMessages) afterLoad() {}
+
+// +checklocksignore
+func (c *ControlMessages) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &c.HasTimestamp)
+ stateSourceObject.Load(1, &c.Timestamp)
+ stateSourceObject.Load(2, &c.HasInq)
+ stateSourceObject.Load(3, &c.Inq)
+ stateSourceObject.Load(4, &c.HasTOS)
+ stateSourceObject.Load(5, &c.TOS)
+ stateSourceObject.Load(6, &c.HasTClass)
+ stateSourceObject.Load(7, &c.TClass)
+ stateSourceObject.Load(8, &c.HasIPPacketInfo)
+ stateSourceObject.Load(9, &c.PacketInfo)
+ stateSourceObject.Load(10, &c.HasIPv6PacketInfo)
+ stateSourceObject.Load(11, &c.IPv6PacketInfo)
+ stateSourceObject.Load(12, &c.HasOriginalDstAddress)
+ stateSourceObject.Load(13, &c.OriginalDstAddress)
+ stateSourceObject.Load(14, &c.SockErr)
+}
+
+func (l *LinkPacketInfo) StateTypeName() string {
+ return "pkg/tcpip.LinkPacketInfo"
+}
+
+func (l *LinkPacketInfo) StateFields() []string {
+ return []string{
+ "Protocol",
+ "PktType",
+ }
+}
+
+func (l *LinkPacketInfo) beforeSave() {}
+
+// +checklocksignore
+func (l *LinkPacketInfo) StateSave(stateSinkObject state.Sink) {
+ l.beforeSave()
+ stateSinkObject.Save(0, &l.Protocol)
+ stateSinkObject.Save(1, &l.PktType)
+}
+
+func (l *LinkPacketInfo) afterLoad() {}
+
+// +checklocksignore
+func (l *LinkPacketInfo) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &l.Protocol)
+ stateSourceObject.Load(1, &l.PktType)
+}
+
+func (l *LingerOption) StateTypeName() string {
+ return "pkg/tcpip.LingerOption"
+}
+
+func (l *LingerOption) StateFields() []string {
+ return []string{
+ "Enabled",
+ "Timeout",
+ }
+}
+
+func (l *LingerOption) beforeSave() {}
+
+// +checklocksignore
+func (l *LingerOption) StateSave(stateSinkObject state.Sink) {
+ l.beforeSave()
+ stateSinkObject.Save(0, &l.Enabled)
+ stateSinkObject.Save(1, &l.Timeout)
+}
+
+func (l *LingerOption) afterLoad() {}
+
+// +checklocksignore
+func (l *LingerOption) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &l.Enabled)
+ stateSourceObject.Load(1, &l.Timeout)
+}
+
+func (i *IPPacketInfo) StateTypeName() string {
+ return "pkg/tcpip.IPPacketInfo"
+}
+
+func (i *IPPacketInfo) StateFields() []string {
+ return []string{
+ "NIC",
+ "LocalAddr",
+ "DestinationAddr",
+ }
+}
+
+func (i *IPPacketInfo) beforeSave() {}
+
+// +checklocksignore
+func (i *IPPacketInfo) StateSave(stateSinkObject state.Sink) {
+ i.beforeSave()
+ stateSinkObject.Save(0, &i.NIC)
+ stateSinkObject.Save(1, &i.LocalAddr)
+ stateSinkObject.Save(2, &i.DestinationAddr)
+}
+
+func (i *IPPacketInfo) afterLoad() {}
+
+// +checklocksignore
+func (i *IPPacketInfo) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &i.NIC)
+ stateSourceObject.Load(1, &i.LocalAddr)
+ stateSourceObject.Load(2, &i.DestinationAddr)
+}
+
+func (i *IPv6PacketInfo) StateTypeName() string {
+ return "pkg/tcpip.IPv6PacketInfo"
+}
+
+func (i *IPv6PacketInfo) StateFields() []string {
+ return []string{
+ "Addr",
+ "NIC",
+ }
+}
+
+func (i *IPv6PacketInfo) beforeSave() {}
+
+// +checklocksignore
+func (i *IPv6PacketInfo) StateSave(stateSinkObject state.Sink) {
+ i.beforeSave()
+ stateSinkObject.Save(0, &i.Addr)
+ stateSinkObject.Save(1, &i.NIC)
+}
+
+func (i *IPv6PacketInfo) afterLoad() {}
+
+// +checklocksignore
+func (i *IPv6PacketInfo) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &i.Addr)
+ stateSourceObject.Load(1, &i.NIC)
+}
+
+func init() {
+ state.Register((*ErrAborted)(nil))
+ state.Register((*ErrAddressFamilyNotSupported)(nil))
+ state.Register((*ErrAlreadyBound)(nil))
+ state.Register((*ErrAlreadyConnected)(nil))
+ state.Register((*ErrAlreadyConnecting)(nil))
+ state.Register((*ErrBadAddress)(nil))
+ state.Register((*ErrBadBuffer)(nil))
+ state.Register((*ErrBadLocalAddress)(nil))
+ state.Register((*ErrBroadcastDisabled)(nil))
+ state.Register((*ErrClosedForReceive)(nil))
+ state.Register((*ErrClosedForSend)(nil))
+ state.Register((*ErrConnectStarted)(nil))
+ state.Register((*ErrConnectionAborted)(nil))
+ state.Register((*ErrConnectionRefused)(nil))
+ state.Register((*ErrConnectionReset)(nil))
+ state.Register((*ErrDestinationRequired)(nil))
+ state.Register((*ErrDuplicateAddress)(nil))
+ state.Register((*ErrDuplicateNICID)(nil))
+ state.Register((*ErrInvalidEndpointState)(nil))
+ state.Register((*ErrInvalidOptionValue)(nil))
+ state.Register((*ErrInvalidPortRange)(nil))
+ state.Register((*ErrMalformedHeader)(nil))
+ state.Register((*ErrMessageTooLong)(nil))
+ state.Register((*ErrNetworkUnreachable)(nil))
+ state.Register((*ErrNoBufferSpace)(nil))
+ state.Register((*ErrNoPortAvailable)(nil))
+ state.Register((*ErrNoRoute)(nil))
+ state.Register((*ErrNoSuchFile)(nil))
+ state.Register((*ErrNotConnected)(nil))
+ state.Register((*ErrNotPermitted)(nil))
+ state.Register((*ErrNotSupported)(nil))
+ state.Register((*ErrPortInUse)(nil))
+ state.Register((*ErrQueueSizeNotSupported)(nil))
+ state.Register((*ErrTimeout)(nil))
+ state.Register((*ErrUnknownDevice)(nil))
+ state.Register((*ErrUnknownNICID)(nil))
+ state.Register((*ErrUnknownProtocol)(nil))
+ state.Register((*ErrUnknownProtocolOption)(nil))
+ state.Register((*ErrWouldBlock)(nil))
+ state.Register((*sockErrorList)(nil))
+ state.Register((*sockErrorEntry)(nil))
+ state.Register((*SocketOptions)(nil))
+ state.Register((*LocalSockError)(nil))
+ state.Register((*SockError)(nil))
+ state.Register((*stdClock)(nil))
+ state.Register((*MonotonicTime)(nil))
+ state.Register((*FullAddress)(nil))
+ state.Register((*ControlMessages)(nil))
+ state.Register((*LinkPacketInfo)(nil))
+ state.Register((*LingerOption)(nil))
+ state.Register((*IPPacketInfo)(nil))
+ state.Register((*IPv6PacketInfo)(nil))
+}
diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go
deleted file mode 100644
index c96ae2f02..000000000
--- a/pkg/tcpip/tcpip_test.go
+++ /dev/null
@@ -1,325 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tcpip
-
-import (
- "bytes"
- "fmt"
- "io"
- "net"
- "testing"
-
- "github.com/google/go-cmp/cmp"
-)
-
-func TestLimitedWriter_Write(t *testing.T) {
- var b bytes.Buffer
- l := LimitedWriter{
- W: &b,
- N: 5,
- }
- if n, err := l.Write([]byte{0, 1, 2}); err != nil {
- t.Errorf("got l.Write(3/5) = (_, %s), want nil", err)
- } else if n != 3 {
- t.Errorf("got l.Write(3/5) = (%d, _), want 3", n)
- }
- if n, err := l.Write([]byte{3, 4, 5}); err != io.ErrShortWrite {
- t.Errorf("got l.Write(3/2) = (_, %s), want io.ErrShortWrite", err)
- } else if n != 2 {
- t.Errorf("got l.Write(3/2) = (%d, _), want 2", n)
- }
- if l.N != 0 {
- t.Errorf("got l.N = %d, want 0", l.N)
- }
- l.N = 1
- if n, err := l.Write([]byte{5}); err != nil {
- t.Errorf("got l.Write(1/1) = (_, %s), want nil", err)
- } else if n != 1 {
- t.Errorf("got l.Write(1/1) = (%d, _), want 1", n)
- }
- if diff := cmp.Diff(b.Bytes(), []byte{0, 1, 2, 3, 4, 5}); diff != "" {
- t.Errorf("%T wrote incorrect data: (-want +got):\n%s", l, diff)
- }
-}
-
-func TestSubnetContains(t *testing.T) {
- tests := []struct {
- s Address
- m AddressMask
- a Address
- want bool
- }{
- {"\xa0", "\xf0", "\x90", false},
- {"\xa0", "\xf0", "\xa0", true},
- {"\xa0", "\xf0", "\xa5", true},
- {"\xa0", "\xf0", "\xaf", true},
- {"\xa0", "\xf0", "\xb0", false},
- {"\xa0", "\xf0", "", false},
- {"\xa0", "\xf0", "\xa0\x00", false},
- {"\xc2\x80", "\xff\xf0", "\xc2\x80", true},
- {"\xc2\x80", "\xff\xf0", "\xc2\x00", false},
- {"\xc2\x00", "\xff\xf0", "\xc2\x00", true},
- {"\xc2\x00", "\xff\xf0", "\xc2\x80", false},
- }
- for _, tt := range tests {
- s, err := NewSubnet(tt.s, tt.m)
- if err != nil {
- t.Errorf("NewSubnet(%v, %v) = %v", tt.s, tt.m, err)
- continue
- }
- if got := s.Contains(tt.a); got != tt.want {
- t.Errorf("Subnet(%v).Contains(%v) = %v, want %v", s, tt.a, got, tt.want)
- }
- }
-}
-
-func TestSubnetBits(t *testing.T) {
- tests := []struct {
- a AddressMask
- want1 int
- want0 int
- }{
- {"\x00", 0, 8},
- {"\x00\x00", 0, 16},
- {"\x36", 0, 8},
- {"\x5c", 0, 8},
- {"\x5c\x5c", 0, 16},
- {"\x5c\x36", 0, 16},
- {"\x36\x5c", 0, 16},
- {"\x36\x36", 0, 16},
- {"\xff", 8, 0},
- {"\xff\xff", 16, 0},
- }
- for _, tt := range tests {
- s := &Subnet{mask: tt.a}
- got1, got0 := s.Bits()
- if got1 != tt.want1 || got0 != tt.want0 {
- t.Errorf("Subnet{mask: %x}.Bits() = %d, %d, want %d, %d", tt.a, got1, got0, tt.want1, tt.want0)
- }
- }
-}
-
-func TestSubnetPrefix(t *testing.T) {
- tests := []struct {
- a AddressMask
- want int
- }{
- {"\x00", 0},
- {"\x00\x00", 0},
- {"\x36", 0},
- {"\x86", 1},
- {"\xc5", 2},
- {"\xff\x00", 8},
- {"\xff\x36", 8},
- {"\xff\x8c", 9},
- {"\xff\xc8", 10},
- {"\xff", 8},
- {"\xff\xff", 16},
- }
- for _, tt := range tests {
- s := &Subnet{mask: tt.a}
- got := s.Prefix()
- if got != tt.want {
- t.Errorf("Subnet{mask: %x}.Bits() = %d want %d", tt.a, got, tt.want)
- }
- }
-}
-
-func TestSubnetCreation(t *testing.T) {
- tests := []struct {
- a Address
- m AddressMask
- want error
- }{
- {"\xa0", "\xf0", nil},
- {"\xa0\xa0", "\xf0", errSubnetLengthMismatch},
- {"\xaa", "\xf0", errSubnetAddressMasked},
- {"", "", nil},
- }
- for _, tt := range tests {
- if _, err := NewSubnet(tt.a, tt.m); err != tt.want {
- t.Errorf("NewSubnet(%v, %v) = %v, want %v", tt.a, tt.m, err, tt.want)
- }
- }
-}
-
-func TestAddressString(t *testing.T) {
- for _, want := range []string{
- // Taken from stdlib.
- "2001:db8::123:12:1",
- "2001:db8::1",
- "2001:db8:0:1:0:1:0:1",
- "2001:db8:1:0:1:0:1:0",
- "2001::1:0:0:1",
- "2001:db8:0:0:1::",
- "2001:db8::1:0:0:1",
- "2001:db8::a:b:c:d",
-
- // Leading zeros.
- "::1",
- // Trailing zeros.
- "8::",
- // No zeros.
- "1:1:1:1:1:1:1:1",
- // Longer sequence is after other zeros, but not at the end.
- "1:0:0:1::1",
- // Longer sequence is at the beginning, shorter sequence is at
- // the end.
- "::1:1:1:0:0",
- // Longer sequence is not at the beginning, shorter sequence is
- // at the end.
- "1::1:1:0:0",
- // Longer sequence is at the beginning, shorter sequence is not
- // at the end.
- "::1:1:0:0:1",
- // Neither sequence is at an end, longer is after shorter.
- "1:0:0:1::1",
- // Shorter sequence is at the beginning, longer sequence is not
- // at the end.
- "0:0:1:1::1",
- // Shorter sequence is at the beginning, longer sequence is at
- // the end.
- "0:0:1:1:1::",
- // Short sequences at both ends, longer one in the middle.
- "0:1:1::1:1:0",
- // Short sequences at both ends, longer one in the middle.
- "0:1::1:0:0",
- // Short sequences at both ends, longer one in the middle.
- "0:0:1::1:0",
- // Longer sequence surrounded by shorter sequences, but none at
- // the end.
- "1:0:1::1:0:1",
- } {
- addr := Address(net.ParseIP(want))
- if got := addr.String(); got != want {
- t.Errorf("Address(%x).String() = '%s', want = '%s'", addr, got, want)
- }
- }
-}
-
-func TestAddressWithPrefixSubnet(t *testing.T) {
- tests := []struct {
- addr Address
- prefixLen int
- subnetAddr Address
- subnetMask AddressMask
- }{
- {"\xaa\x55\x33\x42", -1, "\x00\x00\x00\x00", "\x00\x00\x00\x00"},
- {"\xaa\x55\x33\x42", 0, "\x00\x00\x00\x00", "\x00\x00\x00\x00"},
- {"\xaa\x55\x33\x42", 1, "\x80\x00\x00\x00", "\x80\x00\x00\x00"},
- {"\xaa\x55\x33\x42", 7, "\xaa\x00\x00\x00", "\xfe\x00\x00\x00"},
- {"\xaa\x55\x33\x42", 8, "\xaa\x00\x00\x00", "\xff\x00\x00\x00"},
- {"\xaa\x55\x33\x42", 24, "\xaa\x55\x33\x00", "\xff\xff\xff\x00"},
- {"\xaa\x55\x33\x42", 31, "\xaa\x55\x33\x42", "\xff\xff\xff\xfe"},
- {"\xaa\x55\x33\x42", 32, "\xaa\x55\x33\x42", "\xff\xff\xff\xff"},
- {"\xaa\x55\x33\x42", 33, "\xaa\x55\x33\x42", "\xff\xff\xff\xff"},
- }
- for _, tt := range tests {
- ap := AddressWithPrefix{Address: tt.addr, PrefixLen: tt.prefixLen}
- gotSubnet := ap.Subnet()
- wantSubnet, err := NewSubnet(tt.subnetAddr, tt.subnetMask)
- if err != nil {
- t.Errorf("NewSubnet(%q, %q) failed: %s", tt.subnetAddr, tt.subnetMask, err)
- continue
- }
- if gotSubnet != wantSubnet {
- t.Errorf("got subnet = %q, want = %q", gotSubnet, wantSubnet)
- }
- }
-}
-
-func TestAddressUnspecified(t *testing.T) {
- tests := []struct {
- addr Address
- unspecified bool
- }{
- {
- addr: "",
- unspecified: true,
- },
- {
- addr: "\x00",
- unspecified: true,
- },
- {
- addr: "\x01",
- unspecified: false,
- },
- {
- addr: "\x00\x00",
- unspecified: true,
- },
- {
- addr: "\x01\x00",
- unspecified: false,
- },
- {
- addr: "\x00\x01",
- unspecified: false,
- },
- {
- addr: "\x01\x01",
- unspecified: false,
- },
- }
-
- for _, test := range tests {
- t.Run(fmt.Sprintf("addr=%s", test.addr), func(t *testing.T) {
- if got := test.addr.Unspecified(); got != test.unspecified {
- t.Fatalf("got addr.Unspecified() = %t, want = %t", got, test.unspecified)
- }
- })
- }
-}
-
-func TestAddressMatchingPrefix(t *testing.T) {
- tests := []struct {
- addrA Address
- addrB Address
- prefix uint8
- }{
- {
- addrA: "\x01\x01",
- addrB: "\x01\x01",
- prefix: 16,
- },
- {
- addrA: "\x01\x01",
- addrB: "\x01\x00",
- prefix: 15,
- },
- {
- addrA: "\x01\x01",
- addrB: "\x81\x00",
- prefix: 0,
- },
- {
- addrA: "\x01\x01",
- addrB: "\x01\x80",
- prefix: 8,
- },
- {
- addrA: "\x01\x01",
- addrB: "\x02\x80",
- prefix: 6,
- },
- }
-
- for _, test := range tests {
- if got := test.addrA.MatchingPrefix(test.addrB); got != test.prefix {
- t.Errorf("got (%s).MatchingPrefix(%s) = %d, want = %d", test.addrA, test.addrB, got, test.prefix)
- }
- }
-}
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
deleted file mode 100644
index 181ef799e..000000000
--- a/pkg/tcpip/tests/integration/BUILD
+++ /dev/null
@@ -1,141 +0,0 @@
-load("//tools:defs.bzl", "go_test")
-
-package(licenses = ["notice"])
-
-go_test(
- name = "forward_test",
- size = "small",
- srcs = ["forward_test.go"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/checker",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/network/arp",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/tests/utils",
- "//pkg/tcpip/testutil",
- "//pkg/tcpip/transport/tcp",
- "//pkg/tcpip/transport/udp",
- "//pkg/waiter",
- "@com_github_google_go_cmp//cmp:go_default_library",
- ],
-)
-
-go_test(
- name = "iptables_test",
- size = "small",
- srcs = ["iptables_test.go"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/checker",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/tests/utils",
- "//pkg/tcpip/testutil",
- "//pkg/tcpip/transport/udp",
- ],
-)
-
-go_test(
- name = "link_resolution_test",
- size = "small",
- srcs = ["link_resolution_test.go"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/checker",
- "//pkg/tcpip/faketime",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/link/pipe",
- "//pkg/tcpip/network/arp",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/tests/utils",
- "//pkg/tcpip/testutil",
- "//pkg/tcpip/transport/icmp",
- "//pkg/tcpip/transport/tcp",
- "//pkg/tcpip/transport/udp",
- "//pkg/waiter",
- "@com_github_google_go_cmp//cmp:go_default_library",
- "@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
- ],
-)
-
-go_test(
- name = "loopback_test",
- size = "small",
- srcs = ["loopback_test.go"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/checker",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/link/loopback",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/tests/utils",
- "//pkg/tcpip/testutil",
- "//pkg/tcpip/transport/icmp",
- "//pkg/tcpip/transport/tcp",
- "//pkg/tcpip/transport/udp",
- "//pkg/waiter",
- "@com_github_google_go_cmp//cmp:go_default_library",
- ],
-)
-
-go_test(
- name = "multicast_broadcast_test",
- size = "small",
- srcs = ["multicast_broadcast_test.go"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/checker",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/link/loopback",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/tests/utils",
- "//pkg/tcpip/testutil",
- "//pkg/tcpip/transport/icmp",
- "//pkg/tcpip/transport/udp",
- "//pkg/waiter",
- "@com_github_google_go_cmp//cmp:go_default_library",
- ],
-)
-
-go_test(
- name = "route_test",
- size = "small",
- srcs = ["route_test.go"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/checker",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/link/loopback",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/tests/utils",
- "//pkg/tcpip/testutil",
- "//pkg/tcpip/transport/icmp",
- "//pkg/tcpip/transport/udp",
- "//pkg/waiter",
- "@com_github_google_go_cmp//cmp:go_default_library",
- ],
-)
diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go
deleted file mode 100644
index 6e1d4720d..000000000
--- a/pkg/tcpip/tests/integration/forward_test.go
+++ /dev/null
@@ -1,698 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package forward_test
-
-import (
- "bytes"
- "fmt"
- "testing"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/network/arp"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/tests/utils"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-const ttl = 64
-
-var (
- ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10")
- ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a")
-)
-
-func rxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) {
- utils.RxICMPv4EchoRequest(e, src, dst, ttl)
-}
-
-func rxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) {
- utils.RxICMPv6EchoRequest(e, src, dst, ttl)
-}
-
-func forwardedICMPv4EchoRequestChecker(t *testing.T, b []byte, src, dst tcpip.Address) {
- checker.IPv4(t, b,
- checker.SrcAddr(src),
- checker.DstAddr(dst),
- checker.TTL(ttl-1),
- checker.ICMPv4(
- checker.ICMPv4Type(header.ICMPv4Echo)))
-}
-
-func forwardedICMPv6EchoRequestChecker(t *testing.T, b []byte, src, dst tcpip.Address) {
- checker.IPv6(t, b,
- checker.SrcAddr(src),
- checker.DstAddr(dst),
- checker.TTL(ttl-1),
- checker.ICMPv6(
- checker.ICMPv6Type(header.ICMPv6EchoRequest)))
-}
-
-func TestForwarding(t *testing.T) {
- const listenPort = 8080
-
- type endpointAndAddresses struct {
- serverEP tcpip.Endpoint
- serverAddr tcpip.Address
- serverReadableCH chan struct{}
-
- clientEP tcpip.Endpoint
- clientAddr tcpip.Address
- clientReadableCH chan struct{}
- }
-
- newEP := func(t *testing.T, s *stack.Stack, transProto tcpip.TransportProtocolNumber, netProto tcpip.NetworkProtocolNumber) (tcpip.Endpoint, chan struct{}) {
- t.Helper()
- var wq waiter.Queue
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- ep, err := s.NewEndpoint(transProto, netProto, &wq)
- if err != nil {
- t.Fatalf("s.NewEndpoint(%d, %d, _): %s", transProto, netProto, err)
- }
-
- t.Cleanup(func() {
- wq.EventUnregister(&we)
- })
-
- return ep, ch
- }
-
- tests := []struct {
- name string
- epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses
- }{
- {
- name: "IPv4 host1 server with host2 client",
- epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
- ep1, ep1WECH := newEP(t, host1Stack, proto, ipv4.ProtocolNumber)
- ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber)
- return endpointAndAddresses{
- serverEP: ep1,
- serverAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
- serverReadableCH: ep1WECH,
-
- clientEP: ep2,
- clientAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address,
- clientReadableCH: ep2WECH,
- }
- },
- },
- {
- name: "IPv6 host2 server with host1 client",
- epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
- ep1, ep1WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber)
- ep2, ep2WECH := newEP(t, host1Stack, proto, ipv6.ProtocolNumber)
- return endpointAndAddresses{
- serverEP: ep1,
- serverAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address,
- serverReadableCH: ep1WECH,
-
- clientEP: ep2,
- clientAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
- clientReadableCH: ep2WECH,
- }
- },
- },
- {
- name: "IPv4 host2 server with routerNIC1 client",
- epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
- ep1, ep1WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber)
- ep2, ep2WECH := newEP(t, routerStack, proto, ipv4.ProtocolNumber)
- return endpointAndAddresses{
- serverEP: ep1,
- serverAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address,
- serverReadableCH: ep1WECH,
-
- clientEP: ep2,
- clientAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
- clientReadableCH: ep2WECH,
- }
- },
- },
- {
- name: "IPv6 routerNIC2 server with host1 client",
- epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
- ep1, ep1WECH := newEP(t, routerStack, proto, ipv6.ProtocolNumber)
- ep2, ep2WECH := newEP(t, host1Stack, proto, ipv6.ProtocolNumber)
- return endpointAndAddresses{
- serverEP: ep1,
- serverAddr: utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address,
- serverReadableCH: ep1WECH,
-
- clientEP: ep2,
- clientAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
- clientReadableCH: ep2WECH,
- }
- },
- },
- }
-
- subTests := []struct {
- name string
- proto tcpip.TransportProtocolNumber
- expectedConnectErr tcpip.Error
- setupServer func(t *testing.T, ep tcpip.Endpoint)
- setupServerConn func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{})
- needRemoteAddr bool
- }{
- {
- name: "UDP",
- proto: udp.ProtocolNumber,
- expectedConnectErr: nil,
- setupServerConn: func(t *testing.T, ep tcpip.Endpoint, _ <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
- t.Helper()
-
- if err := ep.Connect(clientAddr); err != nil {
- t.Fatalf("ep.Connect(%#v): %s", clientAddr, err)
- }
- return nil, nil
- },
- needRemoteAddr: true,
- },
- {
- name: "TCP",
- proto: tcp.ProtocolNumber,
- expectedConnectErr: &tcpip.ErrConnectStarted{},
- setupServer: func(t *testing.T, ep tcpip.Endpoint) {
- t.Helper()
-
- if err := ep.Listen(1); err != nil {
- t.Fatalf("ep.Listen(1): %s", err)
- }
- },
- setupServerConn: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
- t.Helper()
-
- var addr tcpip.FullAddress
- for {
- newEP, wq, err := ep.Accept(&addr)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
- <-ch
- continue
- }
- if err != nil {
- t.Fatalf("ep.Accept(_): %s", err)
- }
- if diff := cmp.Diff(clientAddr, addr, checker.IgnoreCmpPath(
- "NIC",
- )); diff != "" {
- t.Errorf("accepted address mismatch (-want +got):\n%s", diff)
- }
-
- we, newCH := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- return newEP, newCH
- }
- },
- needRemoteAddr: false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- for _, subTest := range subTests {
- t.Run(subTest.name, func(t *testing.T) {
- stackOpts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
- }
-
- host1Stack := stack.New(stackOpts)
- routerStack := stack.New(stackOpts)
- host2Stack := stack.New(stackOpts)
- utils.SetupRoutedStacks(t, host1Stack, routerStack, host2Stack)
-
- epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto)
- defer epsAndAddrs.serverEP.Close()
- defer epsAndAddrs.clientEP.Close()
-
- serverAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverAddr, Port: listenPort}
- if err := epsAndAddrs.serverEP.Bind(serverAddr); err != nil {
- t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", serverAddr, err)
- }
- clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr}
- if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil {
- t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err)
- }
-
- if subTest.setupServer != nil {
- subTest.setupServer(t, epsAndAddrs.serverEP)
- }
- {
- err := epsAndAddrs.clientEP.Connect(serverAddr)
- if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" {
- t.Fatalf("unexpected error from epsAndAddrs.clientEP.Connect(%#v), (-want, +got):\n%s", serverAddr, diff)
- }
- }
- if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil {
- t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err)
- } else {
- clientAddr = addr
- clientAddr.NIC = 0
- }
-
- serverEP := epsAndAddrs.serverEP
- serverCH := epsAndAddrs.serverReadableCH
- if ep, ch := subTest.setupServerConn(t, serverEP, serverCH, clientAddr); ep != nil {
- defer ep.Close()
- serverEP = ep
- serverCH = ch
- }
-
- write := func(ep tcpip.Endpoint, data []byte) {
- t.Helper()
-
- var r bytes.Reader
- r.Reset(data)
- var wOpts tcpip.WriteOptions
- n, err := ep.Write(&r, wOpts)
- if err != nil {
- t.Fatalf("ep.Write(_, %#v): %s", wOpts, err)
- }
- if want := int64(len(data)); n != want {
- t.Fatalf("got ep.Write(_, %#v) = (%d, _), want = (%d, _)", wOpts, n, want)
- }
- }
-
- data := []byte{1, 2, 3, 4}
- write(epsAndAddrs.clientEP, data)
-
- read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.FullAddress) {
- t.Helper()
-
- var buf bytes.Buffer
- var res tcpip.ReadResult
- for {
- var err tcpip.Error
- opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr}
- res, err = ep.Read(&buf, opts)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
- <-ch
- continue
- }
- if err != nil {
- t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err)
- }
- break
- }
-
- readResult := tcpip.ReadResult{
- Count: len(data),
- Total: len(data),
- }
- if subTest.needRemoteAddr {
- readResult.RemoteAddr = expectedFrom
- }
- if diff := cmp.Diff(readResult, res, checker.IgnoreCmpPath(
- "ControlMessages",
- "RemoteAddr.NIC",
- )); diff != "" {
- t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
- }
- if diff := cmp.Diff(buf.Bytes(), data); diff != "" {
- t.Errorf("received data mismatch (-want +got):\n%s", diff)
- }
-
- if t.Failed() {
- t.FailNow()
- }
- }
-
- read(serverCH, serverEP, data, clientAddr)
-
- data = []byte{5, 6, 7, 8, 9, 10, 11, 12}
- write(serverEP, data)
- read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, serverAddr)
- })
- }
- })
- }
-}
-
-func TestMulticastForwarding(t *testing.T) {
- const (
- nicID1 = 1
- nicID2 = 2
- )
-
- var (
- ipv4LinkLocalUnicastAddr = testutil.MustParse4("169.254.0.10")
- ipv4LinkLocalMulticastAddr = testutil.MustParse4("224.0.0.10")
-
- ipv6LinkLocalUnicastAddr = testutil.MustParse6("fe80::a")
- ipv6LinkLocalMulticastAddr = testutil.MustParse6("ff02::a")
- )
-
- tests := []struct {
- name string
- srcAddr, dstAddr tcpip.Address
- rx func(*channel.Endpoint, tcpip.Address, tcpip.Address)
- expectForward bool
- checker func(*testing.T, []byte)
- }{
- {
- name: "IPv4 link-local multicast destination",
- srcAddr: utils.RemoteIPv4Addr,
- dstAddr: ipv4LinkLocalMulticastAddr,
- rx: rxICMPv4EchoRequest,
- expectForward: false,
- },
- {
- name: "IPv4 link-local source",
- srcAddr: ipv4LinkLocalUnicastAddr,
- dstAddr: utils.RemoteIPv4Addr,
- rx: rxICMPv4EchoRequest,
- expectForward: false,
- },
- {
- name: "IPv4 link-local destination",
- srcAddr: utils.RemoteIPv4Addr,
- dstAddr: ipv4LinkLocalUnicastAddr,
- rx: rxICMPv4EchoRequest,
- expectForward: false,
- },
- {
- name: "IPv4 non-link-local unicast",
- srcAddr: utils.RemoteIPv4Addr,
- dstAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
- rx: rxICMPv4EchoRequest,
- expectForward: true,
- checker: func(t *testing.T, b []byte) {
- forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address)
- },
- },
- {
- name: "IPv4 non-link-local multicast",
- srcAddr: utils.RemoteIPv4Addr,
- dstAddr: ipv4GlobalMulticastAddr,
- rx: rxICMPv4EchoRequest,
- expectForward: true,
- checker: func(t *testing.T, b []byte) {
- forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr)
- },
- },
-
- {
- name: "IPv6 link-local multicast destination",
- srcAddr: utils.RemoteIPv6Addr,
- dstAddr: ipv6LinkLocalMulticastAddr,
- rx: rxICMPv6EchoRequest,
- expectForward: false,
- },
- {
- name: "IPv6 link-local source",
- srcAddr: ipv6LinkLocalUnicastAddr,
- dstAddr: utils.RemoteIPv6Addr,
- rx: rxICMPv6EchoRequest,
- expectForward: false,
- },
- {
- name: "IPv6 link-local destination",
- srcAddr: utils.RemoteIPv6Addr,
- dstAddr: ipv6LinkLocalUnicastAddr,
- rx: rxICMPv6EchoRequest,
- expectForward: false,
- },
- {
- name: "IPv6 non-link-local unicast",
- srcAddr: utils.RemoteIPv6Addr,
- dstAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
- rx: rxICMPv6EchoRequest,
- expectForward: true,
- checker: func(t *testing.T, b []byte) {
- forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address)
- },
- },
- {
- name: "IPv6 non-link-local multicast",
- srcAddr: utils.RemoteIPv6Addr,
- dstAddr: ipv6GlobalMulticastAddr,
- rx: rxICMPv6EchoRequest,
- expectForward: true,
- checker: func(t *testing.T, b []byte) {
- forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr)
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- })
-
- e1 := channel.New(1, header.IPv6MinimumMTU, "")
- if err := s.CreateNIC(nicID1, e1); err != nil {
- t.Fatalf("s.CreateNIC(%d, _): %s", nicID1, err)
- }
-
- e2 := channel.New(1, header.IPv6MinimumMTU, "")
- if err := s.CreateNIC(nicID2, e2); err != nil {
- t.Fatalf("s.CreateNIC(%d, _): %s", nicID2, err)
- }
-
- protocolAddrV4 := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: utils.Ipv4Addr,
- }
- if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err)
- }
- protocolAddrV6 := tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: utils.Ipv6Addr,
- }
- if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err)
- }
-
- if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
- t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err)
- }
- if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil {
- t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv4EmptySubnet,
- NIC: nicID2,
- },
- {
- Destination: header.IPv6EmptySubnet,
- NIC: nicID2,
- },
- })
-
- test.rx(e1, test.srcAddr, test.dstAddr)
-
- p, ok := e2.Read()
- if ok != test.expectForward {
- t.Fatalf("got e2.Read() = (%#v, %t), want = (_, %t)", p, ok, test.expectForward)
- }
-
- if test.expectForward {
- test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader()))
- }
- })
- }
-}
-
-func TestPerInterfaceForwarding(t *testing.T) {
- const (
- nicID1 = 1
- nicID2 = 2
- )
-
- tests := []struct {
- name string
- srcAddr, dstAddr tcpip.Address
- rx func(*channel.Endpoint, tcpip.Address, tcpip.Address)
- checker func(*testing.T, []byte)
- }{
- {
- name: "IPv4 unicast",
- srcAddr: utils.RemoteIPv4Addr,
- dstAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
- rx: rxICMPv4EchoRequest,
- checker: func(t *testing.T, b []byte) {
- forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address)
- },
- },
- {
- name: "IPv4 multicast",
- srcAddr: utils.RemoteIPv4Addr,
- dstAddr: ipv4GlobalMulticastAddr,
- rx: rxICMPv4EchoRequest,
- checker: func(t *testing.T, b []byte) {
- forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr)
- },
- },
-
- {
- name: "IPv6 unicast",
- srcAddr: utils.RemoteIPv6Addr,
- dstAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
- rx: rxICMPv6EchoRequest,
- checker: func(t *testing.T, b []byte) {
- forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address)
- },
- },
- {
- name: "IPv6 multicast",
- srcAddr: utils.RemoteIPv6Addr,
- dstAddr: ipv6GlobalMulticastAddr,
- rx: rxICMPv6EchoRequest,
- checker: func(t *testing.T, b []byte) {
- forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr)
- },
- },
- }
-
- netProtos := [...]tcpip.NetworkProtocolNumber{ipv4.ProtocolNumber, ipv6.ProtocolNumber}
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{
- // ARP is not used in this test but it is a network protocol that does
- // not support forwarding. We install the protocol to make sure that
- // forwarding information for a NIC is only reported for network
- // protocols that support forwarding.
- arp.NewProtocol,
-
- ipv4.NewProtocol,
- ipv6.NewProtocol,
- },
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- })
-
- e1 := channel.New(1, header.IPv6MinimumMTU, "")
- if err := s.CreateNIC(nicID1, e1); err != nil {
- t.Fatalf("s.CreateNIC(%d, _): %s", nicID1, err)
- }
-
- e2 := channel.New(1, header.IPv6MinimumMTU, "")
- if err := s.CreateNIC(nicID2, e2); err != nil {
- t.Fatalf("s.CreateNIC(%d, _): %s", nicID2, err)
- }
-
- for _, add := range [...]struct {
- nicID tcpip.NICID
- addr tcpip.ProtocolAddress
- }{
- {
- nicID: nicID1,
- addr: utils.RouterNIC1IPv4Addr,
- },
- {
- nicID: nicID1,
- addr: utils.RouterNIC1IPv6Addr,
- },
- {
- nicID: nicID2,
- addr: utils.RouterNIC2IPv4Addr,
- },
- {
- nicID: nicID2,
- addr: utils.RouterNIC2IPv6Addr,
- },
- } {
- if err := s.AddProtocolAddress(add.nicID, add.addr, stack.AddressProperties{}); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", add.nicID, add.addr, err)
- }
- }
-
- // Only enable forwarding on NIC1 and make sure that only packets arriving
- // on NIC1 are forwarded.
- for _, netProto := range netProtos {
- if err := s.SetNICForwarding(nicID1, netProto, true); err != nil {
- t.Fatalf("s.SetNICForwarding(%d, %d, true): %s", nicID1, netProtos, err)
- }
- }
-
- nicsInfo := s.NICInfo()
- for _, subTest := range [...]struct {
- nicID tcpip.NICID
- nicEP *channel.Endpoint
- otherNICID tcpip.NICID
- otherNICEP *channel.Endpoint
- expectForwarding bool
- }{
- {
- nicID: nicID1,
- nicEP: e1,
- otherNICID: nicID2,
- otherNICEP: e2,
- expectForwarding: true,
- },
- {
- nicID: nicID2,
- nicEP: e2,
- otherNICID: nicID2,
- otherNICEP: e1,
- expectForwarding: false,
- },
- } {
- t.Run(fmt.Sprintf("Packet arriving at NIC%d", subTest.nicID), func(t *testing.T) {
- nicInfo, ok := nicsInfo[subTest.nicID]
- if !ok {
- t.Errorf("expected NIC info for NIC %d; got = %#v", subTest.nicID, nicsInfo)
- } else {
- forwarding := make(map[tcpip.NetworkProtocolNumber]bool)
- for _, netProto := range netProtos {
- forwarding[netProto] = subTest.expectForwarding
- }
-
- if diff := cmp.Diff(forwarding, nicInfo.Forwarding); diff != "" {
- t.Errorf("nicsInfo[%d].Forwarding mismatch (-want +got):\n%s", subTest.nicID, diff)
- }
- }
-
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv4EmptySubnet,
- NIC: subTest.otherNICID,
- },
- {
- Destination: header.IPv6EmptySubnet,
- NIC: subTest.otherNICID,
- },
- })
-
- test.rx(subTest.nicEP, test.srcAddr, test.dstAddr)
- if p, ok := subTest.nicEP.Read(); ok {
- t.Errorf("unexpectedly got a response from the interface the packet arrived on: %#v", p)
- }
- if p, ok := subTest.otherNICEP.Read(); ok != subTest.expectForwarding {
- t.Errorf("got otherNICEP.Read() = (%#v, %t), want = (_, %t)", p, ok, subTest.expectForwarding)
- } else if subTest.expectForwarding {
- test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader()))
- }
- })
- }
- })
- }
-}
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go
deleted file mode 100644
index 28b49c6be..000000000
--- a/pkg/tcpip/tests/integration/iptables_test.go
+++ /dev/null
@@ -1,1158 +0,0 @@
-// Copyright 2021 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package iptables_test
-
-import (
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/tests/utils"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
-)
-
-type inputIfNameMatcher struct {
- name string
-}
-
-var _ stack.Matcher = (*inputIfNameMatcher)(nil)
-
-func (*inputIfNameMatcher) Name() string {
- return "inputIfNameMatcher"
-}
-
-func (im *inputIfNameMatcher) Match(hook stack.Hook, _ *stack.PacketBuffer, inNicName, _ string) (bool, bool) {
- return (hook == stack.Input && im.name != "" && im.name == inNicName), false
-}
-
-const (
- nicID = 1
- nicName = "nic1"
- anotherNicName = "nic2"
- linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
- srcAddrV4 = tcpip.Address("\x0a\x00\x00\x01")
- dstAddrV4 = tcpip.Address("\x0a\x00\x00\x02")
- srcAddrV6 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- dstAddrV6 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- payloadSize = 20
-)
-
-func genStackV6(t *testing.T) (*stack.Stack, *channel.Endpoint) {
- t.Helper()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- })
- e := channel.New(0, header.IPv6MinimumMTU, linkAddr)
- nicOpts := stack.NICOptions{Name: nicName}
- if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
- t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err)
- }
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: dstAddrV6.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
- return s, e
-}
-
-func genStackV4(t *testing.T) (*stack.Stack, *channel.Endpoint) {
- t.Helper()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- })
- e := channel.New(0, header.IPv4MinimumMTU, linkAddr)
- nicOpts := stack.NICOptions{Name: nicName}
- if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
- t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err)
- }
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: dstAddrV4.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
- return s, e
-}
-
-func genPacketV6() *stack.PacketBuffer {
- pktSize := header.IPv6MinimumSize + payloadSize
- hdr := buffer.NewPrependable(pktSize)
- ip := header.IPv6(hdr.Prepend(pktSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: payloadSize,
- TransportProtocol: 99,
- HopLimit: 255,
- SrcAddr: srcAddrV6,
- DstAddr: dstAddrV6,
- })
- vv := hdr.View().ToVectorisedView()
- return stack.NewPacketBuffer(stack.PacketBufferOptions{Data: vv})
-}
-
-func genPacketV4() *stack.PacketBuffer {
- pktSize := header.IPv4MinimumSize + payloadSize
- hdr := buffer.NewPrependable(pktSize)
- ip := header.IPv4(hdr.Prepend(pktSize))
- ip.Encode(&header.IPv4Fields{
- TOS: 0,
- TotalLength: uint16(pktSize),
- ID: 1,
- Flags: 0,
- FragmentOffset: 16,
- TTL: 48,
- Protocol: 99,
- SrcAddr: srcAddrV4,
- DstAddr: dstAddrV4,
- })
- ip.SetChecksum(0)
- ip.SetChecksum(^ip.CalculateChecksum())
- vv := hdr.View().ToVectorisedView()
- return stack.NewPacketBuffer(stack.PacketBufferOptions{Data: vv})
-}
-
-func TestIPTablesStatsForInput(t *testing.T) {
- tests := []struct {
- name string
- setupStack func(*testing.T) (*stack.Stack, *channel.Endpoint)
- setupFilter func(*testing.T, *stack.Stack)
- genPacket func() *stack.PacketBuffer
- proto tcpip.NetworkProtocolNumber
- expectReceived int
- expectInputDropped int
- }{
- {
- name: "IPv6 Accept",
- setupStack: genStackV6,
- setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ },
- genPacket: genPacketV6,
- proto: header.IPv6ProtocolNumber,
- expectReceived: 1,
- expectInputDropped: 0,
- },
- {
- name: "IPv4 Accept",
- setupStack: genStackV4,
- setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ },
- genPacket: genPacketV4,
- proto: header.IPv4ProtocolNumber,
- expectReceived: 1,
- expectInputDropped: 0,
- },
- {
- name: "IPv6 Drop (input interface matches)",
- setupStack: genStackV6,
- setupFilter: func(t *testing.T, s *stack.Stack) {
- t.Helper()
- ipt := s.IPTables()
- filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
- ruleIdx := filter.BuiltinChains[stack.Input]
- filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: nicName}
- filter.Rules[ruleIdx].Target = &stack.DropTarget{}
- filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}}
- // Make sure the packet is not dropped by the next rule.
- filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
- if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
- t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err)
- }
- },
- genPacket: genPacketV6,
- proto: header.IPv6ProtocolNumber,
- expectReceived: 1,
- expectInputDropped: 1,
- },
- {
- name: "IPv4 Drop (input interface matches)",
- setupStack: genStackV4,
- setupFilter: func(t *testing.T, s *stack.Stack) {
- t.Helper()
- ipt := s.IPTables()
- filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
- ruleIdx := filter.BuiltinChains[stack.Input]
- filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: nicName}
- filter.Rules[ruleIdx].Target = &stack.DropTarget{}
- filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}}
- filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
- if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
- t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err)
- }
- },
- genPacket: genPacketV4,
- proto: header.IPv4ProtocolNumber,
- expectReceived: 1,
- expectInputDropped: 1,
- },
- {
- name: "IPv6 Accept (input interface does not match)",
- setupStack: genStackV6,
- setupFilter: func(t *testing.T, s *stack.Stack) {
- t.Helper()
- ipt := s.IPTables()
- filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
- ruleIdx := filter.BuiltinChains[stack.Input]
- filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: anotherNicName}
- filter.Rules[ruleIdx].Target = &stack.DropTarget{}
- filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
- if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
- t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err)
- }
- },
- genPacket: genPacketV6,
- proto: header.IPv6ProtocolNumber,
- expectReceived: 1,
- expectInputDropped: 0,
- },
- {
- name: "IPv4 Accept (input interface does not match)",
- setupStack: genStackV4,
- setupFilter: func(t *testing.T, s *stack.Stack) {
- t.Helper()
- ipt := s.IPTables()
- filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
- ruleIdx := filter.BuiltinChains[stack.Input]
- filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: anotherNicName}
- filter.Rules[ruleIdx].Target = &stack.DropTarget{}
- filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
- if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
- t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err)
- }
- },
- genPacket: genPacketV4,
- proto: header.IPv4ProtocolNumber,
- expectReceived: 1,
- expectInputDropped: 0,
- },
- {
- name: "IPv6 Drop (input interface does not match but invert is true)",
- setupStack: genStackV6,
- setupFilter: func(t *testing.T, s *stack.Stack) {
- t.Helper()
- ipt := s.IPTables()
- filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
- ruleIdx := filter.BuiltinChains[stack.Input]
- filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{
- InputInterface: anotherNicName,
- InputInterfaceInvert: true,
- }
- filter.Rules[ruleIdx].Target = &stack.DropTarget{}
- filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
- if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
- t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err)
- }
- },
- genPacket: genPacketV6,
- proto: header.IPv6ProtocolNumber,
- expectReceived: 1,
- expectInputDropped: 1,
- },
- {
- name: "IPv4 Drop (input interface does not match but invert is true)",
- setupStack: genStackV4,
- setupFilter: func(t *testing.T, s *stack.Stack) {
- t.Helper()
- ipt := s.IPTables()
- filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
- ruleIdx := filter.BuiltinChains[stack.Input]
- filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{
- InputInterface: anotherNicName,
- InputInterfaceInvert: true,
- }
- filter.Rules[ruleIdx].Target = &stack.DropTarget{}
- filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
- if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
- t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err)
- }
- },
- genPacket: genPacketV4,
- proto: header.IPv4ProtocolNumber,
- expectReceived: 1,
- expectInputDropped: 1,
- },
- {
- name: "IPv6 Accept (input interface does not match using a matcher)",
- setupStack: genStackV6,
- setupFilter: func(t *testing.T, s *stack.Stack) {
- t.Helper()
- ipt := s.IPTables()
- filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
- ruleIdx := filter.BuiltinChains[stack.Input]
- filter.Rules[ruleIdx].Target = &stack.DropTarget{}
- filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}}
- filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
- if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
- t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err)
- }
- },
- genPacket: genPacketV6,
- proto: header.IPv6ProtocolNumber,
- expectReceived: 1,
- expectInputDropped: 0,
- },
- {
- name: "IPv4 Accept (input interface does not match using a matcher)",
- setupStack: genStackV4,
- setupFilter: func(t *testing.T, s *stack.Stack) {
- t.Helper()
- ipt := s.IPTables()
- filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
- ruleIdx := filter.BuiltinChains[stack.Input]
- filter.Rules[ruleIdx].Target = &stack.DropTarget{}
- filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}}
- filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
- if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
- t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err)
- }
- },
- genPacket: genPacketV4,
- proto: header.IPv4ProtocolNumber,
- expectReceived: 1,
- expectInputDropped: 0,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s, e := test.setupStack(t)
- test.setupFilter(t, s)
- e.InjectInbound(test.proto, test.genPacket())
-
- if got := int(s.Stats().IP.PacketsReceived.Value()); got != test.expectReceived {
- t.Errorf("got PacketReceived = %d, want = %d", got, test.expectReceived)
- }
- if got := int(s.Stats().IP.IPTablesInputDropped.Value()); got != test.expectInputDropped {
- t.Errorf("got IPTablesInputDropped = %d, want = %d", got, test.expectInputDropped)
- }
- })
- }
-}
-
-var _ stack.LinkEndpoint = (*channelEndpointWithoutWritePacket)(nil)
-
-// channelEndpointWithoutWritePacket is a channel endpoint that does not support
-// stack.LinkEndpoint.WritePacket.
-type channelEndpointWithoutWritePacket struct {
- *channel.Endpoint
-
- t *testing.T
-}
-
-func (c *channelEndpointWithoutWritePacket) WritePacket(stack.RouteInfo, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error {
- c.t.Error("unexpectedly called WritePacket; all writes should go through WritePackets")
- return &tcpip.ErrNotSupported{}
-}
-
-var _ stack.Matcher = (*udpSourcePortMatcher)(nil)
-
-type udpSourcePortMatcher struct {
- port uint16
-}
-
-func (*udpSourcePortMatcher) Name() string {
- return "udpSourcePortMatcher"
-}
-
-func (m *udpSourcePortMatcher) Match(_ stack.Hook, pkt *stack.PacketBuffer, _, _ string) (matches, hotdrop bool) {
- udp := header.UDP(pkt.TransportHeader().View())
- if len(udp) < header.UDPMinimumSize {
- // Drop immediately as the packet is invalid.
- return false, true
- }
-
- return udp.SourcePort() == m.port, false
-}
-
-func TestIPTableWritePackets(t *testing.T) {
- const (
- nicID = 1
-
- dropLocalPort = utils.LocalPort - 1
- acceptPackets = 2
- dropPackets = 3
- )
-
- udpHdr := func(hdr buffer.View, srcAddr, dstAddr tcpip.Address, srcPort, dstPort uint16) {
- u := header.UDP(hdr)
- u.Encode(&header.UDPFields{
- SrcPort: srcPort,
- DstPort: dstPort,
- Length: header.UDPMinimumSize,
- })
- sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, srcAddr, dstAddr, header.UDPMinimumSize)
- sum = header.Checksum(hdr, sum)
- u.SetChecksum(^u.CalculateChecksum(sum))
- }
-
- tests := []struct {
- name string
- setupFilter func(*testing.T, *stack.Stack)
- genPacket func(*stack.Route) stack.PacketBufferList
- proto tcpip.NetworkProtocolNumber
- remoteAddr tcpip.Address
- expectSent uint64
- expectOutputDropped uint64
- }{
- {
- name: "IPv4 Accept",
- setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ },
- genPacket: func(r *stack.Route) stack.PacketBufferList {
- var pkts stack.PacketBufferList
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
- })
- hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
- udpHdr(hdr, r.LocalAddress(), r.RemoteAddress(), utils.LocalPort, utils.RemotePort)
- pkts.PushFront(pkt)
-
- return pkts
- },
- proto: header.IPv4ProtocolNumber,
- remoteAddr: dstAddrV4,
- expectSent: 1,
- expectOutputDropped: 0,
- },
- {
- name: "IPv4 Drop Other Port",
- setupFilter: func(t *testing.T, s *stack.Stack) {
- t.Helper()
-
- table := stack.Table{
- Rules: []stack.Rule{
- {
- Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber},
- },
- {
- Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber},
- },
- {
- Matchers: []stack.Matcher{&udpSourcePortMatcher{port: dropLocalPort}},
- Target: &stack.DropTarget{NetworkProtocol: header.IPv4ProtocolNumber},
- },
- {
- Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber},
- },
- {
- Target: &stack.ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber},
- },
- },
- BuiltinChains: [stack.NumHooks]int{
- stack.Prerouting: stack.HookUnset,
- stack.Input: 0,
- stack.Forward: 1,
- stack.Output: 2,
- stack.Postrouting: stack.HookUnset,
- },
- Underflows: [stack.NumHooks]int{
- stack.Prerouting: stack.HookUnset,
- stack.Input: 0,
- stack.Forward: 1,
- stack.Output: 2,
- stack.Postrouting: stack.HookUnset,
- },
- }
-
- if err := s.IPTables().ReplaceTable(stack.FilterID, table, false /* ipv4 */); err != nil {
- t.Fatalf("ReplaceTable(%d, _, false): %s", stack.FilterID, err)
- }
- },
- genPacket: func(r *stack.Route) stack.PacketBufferList {
- var pkts stack.PacketBufferList
-
- for i := 0; i < acceptPackets; i++ {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
- })
- hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
- udpHdr(hdr, r.LocalAddress(), r.RemoteAddress(), utils.LocalPort, utils.RemotePort)
- pkts.PushFront(pkt)
- }
- for i := 0; i < dropPackets; i++ {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
- })
- hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
- udpHdr(hdr, r.LocalAddress(), r.RemoteAddress(), dropLocalPort, utils.RemotePort)
- pkts.PushFront(pkt)
- }
-
- return pkts
- },
- proto: header.IPv4ProtocolNumber,
- remoteAddr: dstAddrV4,
- expectSent: acceptPackets,
- expectOutputDropped: dropPackets,
- },
- {
- name: "IPv6 Accept",
- setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ },
- genPacket: func(r *stack.Route) stack.PacketBufferList {
- var pkts stack.PacketBufferList
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
- })
- hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
- udpHdr(hdr, r.LocalAddress(), r.RemoteAddress(), utils.LocalPort, utils.RemotePort)
- pkts.PushFront(pkt)
-
- return pkts
- },
- proto: header.IPv6ProtocolNumber,
- remoteAddr: dstAddrV6,
- expectSent: 1,
- expectOutputDropped: 0,
- },
- {
- name: "IPv6 Drop Other Port",
- setupFilter: func(t *testing.T, s *stack.Stack) {
- t.Helper()
-
- table := stack.Table{
- Rules: []stack.Rule{
- {
- Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber},
- },
- {
- Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber},
- },
- {
- Matchers: []stack.Matcher{&udpSourcePortMatcher{port: dropLocalPort}},
- Target: &stack.DropTarget{NetworkProtocol: header.IPv6ProtocolNumber},
- },
- {
- Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber},
- },
- {
- Target: &stack.ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber},
- },
- },
- BuiltinChains: [stack.NumHooks]int{
- stack.Prerouting: stack.HookUnset,
- stack.Input: 0,
- stack.Forward: 1,
- stack.Output: 2,
- stack.Postrouting: stack.HookUnset,
- },
- Underflows: [stack.NumHooks]int{
- stack.Prerouting: stack.HookUnset,
- stack.Input: 0,
- stack.Forward: 1,
- stack.Output: 2,
- stack.Postrouting: stack.HookUnset,
- },
- }
-
- if err := s.IPTables().ReplaceTable(stack.FilterID, table, true /* ipv6 */); err != nil {
- t.Fatalf("ReplaceTable(%d, _, true): %s", stack.FilterID, err)
- }
- },
- genPacket: func(r *stack.Route) stack.PacketBufferList {
- var pkts stack.PacketBufferList
-
- for i := 0; i < acceptPackets; i++ {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
- })
- hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
- udpHdr(hdr, r.LocalAddress(), r.RemoteAddress(), utils.LocalPort, utils.RemotePort)
- pkts.PushFront(pkt)
- }
- for i := 0; i < dropPackets; i++ {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
- })
- hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
- udpHdr(hdr, r.LocalAddress(), r.RemoteAddress(), dropLocalPort, utils.RemotePort)
- pkts.PushFront(pkt)
- }
-
- return pkts
- },
- proto: header.IPv6ProtocolNumber,
- remoteAddr: dstAddrV6,
- expectSent: acceptPackets,
- expectOutputDropped: dropPackets,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- })
- e := channelEndpointWithoutWritePacket{
- Endpoint: channel.New(4, header.IPv6MinimumMTU, linkAddr),
- t: t,
- }
- if err := s.CreateNIC(nicID, &e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- protocolAddrV6 := tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: srcAddrV6.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddrV6, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV6, err)
- }
- protocolAddrV4 := tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: srcAddrV4.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddrV4, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV4, err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv4EmptySubnet,
- NIC: nicID,
- },
- {
- Destination: header.IPv6EmptySubnet,
- NIC: nicID,
- },
- })
-
- test.setupFilter(t, s)
-
- r, err := s.FindRoute(nicID, "", test.remoteAddr, test.proto, false)
- if err != nil {
- t.Fatalf("FindRoute(%d, '', %s, %d, false): %s", nicID, test.remoteAddr, test.proto, err)
- }
- defer r.Release()
-
- pkts := test.genPacket(r)
- pktsLen := pkts.Len()
- if n, err := r.WritePackets(pkts, stack.NetworkHeaderParams{
- Protocol: header.UDPProtocolNumber,
- TTL: 64,
- }); err != nil {
- t.Fatalf("WritePackets(...): %s", err)
- } else if n != pktsLen {
- t.Fatalf("got WritePackets(...) = %d, want = %d", n, pktsLen)
- }
-
- if got := s.Stats().IP.PacketsSent.Value(); got != test.expectSent {
- t.Errorf("got PacketSent = %d, want = %d", got, test.expectSent)
- }
- if got := s.Stats().IP.IPTablesOutputDropped.Value(); got != test.expectOutputDropped {
- t.Errorf("got IPTablesOutputDropped = %d, want = %d", got, test.expectOutputDropped)
- }
- })
- }
-}
-
-const ttl = 64
-
-var (
- ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10")
- ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a")
-)
-
-func rxICMPv4EchoReply(e *channel.Endpoint, src, dst tcpip.Address) {
- utils.RxICMPv4EchoReply(e, src, dst, ttl)
-}
-
-func rxICMPv6EchoReply(e *channel.Endpoint, src, dst tcpip.Address) {
- utils.RxICMPv6EchoReply(e, src, dst, ttl)
-}
-
-func forwardedICMPv4EchoReplyChecker(t *testing.T, b []byte, src, dst tcpip.Address) {
- checker.IPv4(t, b,
- checker.SrcAddr(src),
- checker.DstAddr(dst),
- checker.TTL(ttl-1),
- checker.ICMPv4(
- checker.ICMPv4Type(header.ICMPv4EchoReply)))
-}
-
-func forwardedICMPv6EchoReplyChecker(t *testing.T, b []byte, src, dst tcpip.Address) {
- checker.IPv6(t, b,
- checker.SrcAddr(src),
- checker.DstAddr(dst),
- checker.TTL(ttl-1),
- checker.ICMPv6(
- checker.ICMPv6Type(header.ICMPv6EchoReply)))
-}
-
-func boolToInt(v bool) uint64 {
- if v {
- return 1
- }
- return 0
-}
-
-func setupDropFilter(hook stack.Hook, f stack.IPHeaderFilter) func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) {
- return func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber) {
- t.Helper()
-
- ipv6 := netProto == ipv6.ProtocolNumber
-
- ipt := s.IPTables()
- filter := ipt.GetTable(stack.FilterID, ipv6)
- ruleIdx := filter.BuiltinChains[hook]
- filter.Rules[ruleIdx].Filter = f
- filter.Rules[ruleIdx].Target = &stack.DropTarget{NetworkProtocol: netProto}
- // Make sure the packet is not dropped by the next rule.
- filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{NetworkProtocol: netProto}
- if err := ipt.ReplaceTable(stack.FilterID, filter, ipv6); err != nil {
- t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, ipv6, err)
- }
- }
-}
-
-func TestForwardingHook(t *testing.T) {
- const (
- nicID1 = 1
- nicID2 = 2
-
- nic1Name = "nic1"
- nic2Name = "nic2"
-
- otherNICName = "otherNIC"
- )
-
- tests := []struct {
- name string
- netProto tcpip.NetworkProtocolNumber
- local bool
- srcAddr, dstAddr tcpip.Address
- rx func(*channel.Endpoint, tcpip.Address, tcpip.Address)
- checker func(*testing.T, []byte)
- }{
- {
- name: "IPv4 remote",
- netProto: ipv4.ProtocolNumber,
- local: false,
- srcAddr: utils.RemoteIPv4Addr,
- dstAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
- rx: rxICMPv4EchoReply,
- checker: func(t *testing.T, b []byte) {
- forwardedICMPv4EchoReplyChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address)
- },
- },
- {
- name: "IPv4 local",
- netProto: ipv4.ProtocolNumber,
- local: true,
- srcAddr: utils.RemoteIPv4Addr,
- dstAddr: utils.Ipv4Addr.Address,
- rx: rxICMPv4EchoReply,
- },
- {
- name: "IPv6 remote",
- netProto: ipv6.ProtocolNumber,
- local: false,
- srcAddr: utils.RemoteIPv6Addr,
- dstAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
- rx: rxICMPv6EchoReply,
- checker: func(t *testing.T, b []byte) {
- forwardedICMPv6EchoReplyChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address)
- },
- },
- {
- name: "IPv6 local",
- netProto: ipv6.ProtocolNumber,
- local: true,
- srcAddr: utils.RemoteIPv6Addr,
- dstAddr: utils.Ipv6Addr.Address,
- rx: rxICMPv6EchoReply,
- },
- }
-
- subTests := []struct {
- name string
- setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber)
- expectForward bool
- }{
- {
- name: "Accept",
- setupFilter: func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { /* no filter */ },
- expectForward: true,
- },
-
- {
- name: "Drop",
- setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{}),
- expectForward: false,
- },
- {
- name: "Drop with input NIC filtering",
- setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name}),
- expectForward: false,
- },
- {
- name: "Drop with output NIC filtering",
- setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: nic2Name}),
- expectForward: false,
- },
- {
- name: "Drop with input and output NIC filtering",
- setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: nic2Name}),
- expectForward: false,
- },
-
- {
- name: "Drop with other input NIC filtering",
- setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName}),
- expectForward: true,
- },
- {
- name: "Drop with other output NIC filtering",
- setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: otherNICName}),
- expectForward: true,
- },
- {
- name: "Drop with other input and output NIC filtering",
- setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: nic2Name}),
- expectForward: true,
- },
- {
- name: "Drop with input and other output NIC filtering",
- setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: otherNICName}),
- expectForward: true,
- },
- {
- name: "Drop with other input and other output NIC filtering",
- setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: otherNICName}),
- expectForward: true,
- },
-
- {
- name: "Drop with inverted input NIC filtering",
- setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, InputInterfaceInvert: true}),
- expectForward: true,
- },
- {
- name: "Drop with inverted output NIC filtering",
- setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: nic2Name, OutputInterfaceInvert: true}),
- expectForward: true,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- for _, subTest := range subTests {
- t.Run(subTest.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- })
-
- subTest.setupFilter(t, s, test.netProto)
-
- e1 := channel.New(1, header.IPv6MinimumMTU, "")
- if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil {
- t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err)
- }
-
- e2 := channel.New(1, header.IPv6MinimumMTU, "")
- if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil {
- t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err)
- }
-
- protocolAddrV4 := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: utils.Ipv4Addr.Address.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err)
- }
- protocolAddrV6 := tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: utils.Ipv6Addr.Address.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err)
- }
-
- if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
- t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err)
- }
- if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil {
- t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv4EmptySubnet,
- NIC: nicID2,
- },
- {
- Destination: header.IPv6EmptySubnet,
- NIC: nicID2,
- },
- })
-
- test.rx(e1, test.srcAddr, test.dstAddr)
-
- expectTransmitPacket := subTest.expectForward && !test.local
-
- ep1, err := s.GetNetworkEndpoint(nicID1, test.netProto)
- if err != nil {
- t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, test.netProto, err)
- }
- ep1Stats := ep1.Stats()
- ipEP1Stats, ok := ep1Stats.(stack.IPNetworkEndpointStats)
- if !ok {
- t.Fatalf("got ep1Stats = %T, want = stack.IPNetworkEndpointStats", ep1Stats)
- }
- ip1Stats := ipEP1Stats.IPStats()
-
- if got := ip1Stats.PacketsReceived.Value(); got != 1 {
- t.Errorf("got ip1Stats.PacketsReceived.Value() = %d, want = 1", got)
- }
- if got := ip1Stats.ValidPacketsReceived.Value(); got != 1 {
- t.Errorf("got ip1Stats.ValidPacketsReceived.Value() = %d, want = 1", got)
- }
- if got, want := ip1Stats.IPTablesForwardDropped.Value(), boolToInt(!subTest.expectForward); got != want {
- t.Errorf("got ip1Stats.IPTablesForwardDropped.Value() = %d, want = %d", got, want)
- }
- if got := ip1Stats.PacketsSent.Value(); got != 0 {
- t.Errorf("got ip1Stats.PacketsSent.Value() = %d, want = 0", got)
- }
-
- ep2, err := s.GetNetworkEndpoint(nicID2, test.netProto)
- if err != nil {
- t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID2, test.netProto, err)
- }
- ep2Stats := ep2.Stats()
- ipEP2Stats, ok := ep2Stats.(stack.IPNetworkEndpointStats)
- if !ok {
- t.Fatalf("got ep2Stats = %T, want = stack.IPNetworkEndpointStats", ep2Stats)
- }
- ip2Stats := ipEP2Stats.IPStats()
- if got := ip2Stats.PacketsReceived.Value(); got != 0 {
- t.Errorf("got ip2Stats.PacketsReceived.Value() = %d, want = 0", got)
- }
- if got, want := ip2Stats.ValidPacketsReceived.Value(), boolToInt(subTest.expectForward && test.local); got != want {
- t.Errorf("got ip2Stats.ValidPacketsReceived.Value() = %d, want = %d", got, want)
- }
- if got, want := ip2Stats.PacketsSent.Value(), boolToInt(expectTransmitPacket); got != want {
- t.Errorf("got ip2Stats.PacketsSent.Value() = %d, want = %d", got, want)
- }
-
- p, ok := e2.Read()
- if ok != expectTransmitPacket {
- t.Fatalf("got e2.Read() = (%#v, %t), want = (_, %t)", p, ok, expectTransmitPacket)
- }
- if expectTransmitPacket {
- test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader()))
- }
- })
- }
- })
- }
-}
-
-func TestInputHookWithLocalForwarding(t *testing.T) {
- const (
- nicID1 = 1
- nicID2 = 2
-
- nic1Name = "nic1"
- nic2Name = "nic2"
-
- otherNICName = "otherNIC"
- )
-
- tests := []struct {
- name string
- netProto tcpip.NetworkProtocolNumber
- rx func(*channel.Endpoint)
- checker func(*testing.T, []byte)
- }{
- {
- name: "IPv4",
- netProto: ipv4.ProtocolNumber,
- rx: func(e *channel.Endpoint) {
- utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address, ttl)
- },
- checker: func(t *testing.T, b []byte) {
- checker.IPv4(t, b,
- checker.SrcAddr(utils.Ipv4Addr2.AddressWithPrefix.Address),
- checker.DstAddr(utils.RemoteIPv4Addr),
- checker.ICMPv4(
- checker.ICMPv4Type(header.ICMPv4EchoReply)))
- },
- },
- {
- name: "IPv6",
- netProto: ipv6.ProtocolNumber,
- rx: func(e *channel.Endpoint) {
- utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address, ttl)
- },
- checker: func(t *testing.T, b []byte) {
- checker.IPv6(t, b,
- checker.SrcAddr(utils.Ipv6Addr2.AddressWithPrefix.Address),
- checker.DstAddr(utils.RemoteIPv6Addr),
- checker.ICMPv6(
- checker.ICMPv6Type(header.ICMPv6EchoReply)))
- },
- },
- }
-
- subTests := []struct {
- name string
- setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber)
- expectDrop bool
- }{
- {
- name: "Accept",
- setupFilter: func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { /* no filter */ },
- expectDrop: false,
- },
-
- {
- name: "Drop",
- setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{}),
- expectDrop: true,
- },
- {
- name: "Drop with input NIC filtering on arrival NIC",
- setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: nic1Name}),
- expectDrop: true,
- },
- {
- name: "Drop with input NIC filtering on delivered NIC",
- setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: nic2Name}),
- expectDrop: false,
- },
-
- {
- name: "Drop with input NIC filtering on other NIC",
- setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: otherNICName}),
- expectDrop: false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- for _, subTest := range subTests {
- t.Run(subTest.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- })
-
- subTest.setupFilter(t, s, test.netProto)
-
- e1 := channel.New(1, header.IPv6MinimumMTU, "")
- if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil {
- t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err)
- }
- if err := s.AddProtocolAddress(nicID1, utils.Ipv4Addr1, stack.AddressProperties{}); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID1, utils.Ipv4Addr1, err)
- }
- if err := s.AddProtocolAddress(nicID1, utils.Ipv6Addr1, stack.AddressProperties{}); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID1, utils.Ipv6Addr1, err)
- }
-
- e2 := channel.New(1, header.IPv6MinimumMTU, "")
- if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil {
- t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err)
- }
- if err := s.AddProtocolAddress(nicID2, utils.Ipv4Addr2, stack.AddressProperties{}); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID2, utils.Ipv4Addr2, err)
- }
- if err := s.AddProtocolAddress(nicID2, utils.Ipv6Addr2, stack.AddressProperties{}); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID2, utils.Ipv6Addr2, err)
- }
-
- if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
- t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err)
- }
- if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil {
- t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv4EmptySubnet,
- NIC: nicID1,
- },
- {
- Destination: header.IPv6EmptySubnet,
- NIC: nicID1,
- },
- })
-
- test.rx(e1)
-
- ep1, err := s.GetNetworkEndpoint(nicID1, test.netProto)
- if err != nil {
- t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, test.netProto, err)
- }
- ep1Stats := ep1.Stats()
- ipEP1Stats, ok := ep1Stats.(stack.IPNetworkEndpointStats)
- if !ok {
- t.Fatalf("got ep1Stats = %T, want = stack.IPNetworkEndpointStats", ep1Stats)
- }
- ip1Stats := ipEP1Stats.IPStats()
-
- if got := ip1Stats.PacketsReceived.Value(); got != 1 {
- t.Errorf("got ip1Stats.PacketsReceived.Value() = %d, want = 1", got)
- }
- if got := ip1Stats.ValidPacketsReceived.Value(); got != 1 {
- t.Errorf("got ip1Stats.ValidPacketsReceived.Value() = %d, want = 1", got)
- }
- if got, want := ip1Stats.PacketsSent.Value(), boolToInt(!subTest.expectDrop); got != want {
- t.Errorf("got ip1Stats.PacketsSent.Value() = %d, want = %d", got, want)
- }
-
- ep2, err := s.GetNetworkEndpoint(nicID2, test.netProto)
- if err != nil {
- t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID2, test.netProto, err)
- }
- ep2Stats := ep2.Stats()
- ipEP2Stats, ok := ep2Stats.(stack.IPNetworkEndpointStats)
- if !ok {
- t.Fatalf("got ep2Stats = %T, want = stack.IPNetworkEndpointStats", ep2Stats)
- }
- ip2Stats := ipEP2Stats.IPStats()
- if got := ip2Stats.PacketsReceived.Value(); got != 0 {
- t.Errorf("got ip2Stats.PacketsReceived.Value() = %d, want = 0", got)
- }
- if got := ip2Stats.ValidPacketsReceived.Value(); got != 1 {
- t.Errorf("got ip2Stats.ValidPacketsReceived.Value() = %d, want = 1", got)
- }
- if got, want := ip2Stats.IPTablesInputDropped.Value(), boolToInt(subTest.expectDrop); got != want {
- t.Errorf("got ip2Stats.IPTablesInputDropped.Value() = %d, want = %d", got, want)
- }
- if got := ip2Stats.PacketsSent.Value(); got != 0 {
- t.Errorf("got ip2Stats.PacketsSent.Value() = %d, want = 0", got)
- }
-
- if p, ok := e1.Read(); ok == subTest.expectDrop {
- t.Errorf("got e1.Read() = (%#v, %t), want = (_, %t)", p, ok, !subTest.expectDrop)
- } else if !subTest.expectDrop {
- test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader()))
- }
- if p, ok := e2.Read(); ok {
- t.Errorf("got e1.Read() = (%#v, true), want = (_, false)", p)
- }
- })
- }
- })
- }
-}
diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go
deleted file mode 100644
index 95ddd8ec3..000000000
--- a/pkg/tcpip/tests/integration/link_resolution_test.go
+++ /dev/null
@@ -1,1640 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package link_resolution_test
-
-import (
- "bytes"
- "fmt"
- "net"
- "runtime"
- "testing"
- "time"
-
- "github.com/google/go-cmp/cmp"
- "github.com/google/go-cmp/cmp/cmpopts"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/pipe"
- "gvisor.dev/gvisor/pkg/tcpip/network/arp"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/tests/utils"
- tcptestutil "gvisor.dev/gvisor/pkg/tcpip/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-func setupStack(t *testing.T, stackOpts stack.Options, host1NICID, host2NICID tcpip.NICID) (*stack.Stack, *stack.Stack) {
- host1Stack := stack.New(stackOpts)
- host2Stack := stack.New(stackOpts)
-
- host1NIC, host2NIC := pipe.New(utils.LinkAddr1, utils.LinkAddr2)
-
- if err := host1Stack.CreateNIC(host1NICID, utils.NewEthernetEndpoint(host1NIC)); err != nil {
- t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err)
- }
- if err := host2Stack.CreateNIC(host2NICID, utils.NewEthernetEndpoint(host2NIC)); err != nil {
- t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err)
- }
-
- if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv4Addr1, stack.AddressProperties{}); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", host1NICID, utils.Ipv4Addr1, err)
- }
- if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv4Addr2, stack.AddressProperties{}); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", host2NICID, utils.Ipv4Addr2, err)
- }
- if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv6Addr1, stack.AddressProperties{}); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", host1NICID, utils.Ipv6Addr1, err)
- }
- if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv6Addr2, stack.AddressProperties{}); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", host2NICID, utils.Ipv6Addr2, err)
- }
-
- host1Stack.SetRouteTable([]tcpip.Route{
- {
- Destination: utils.Ipv4Addr1.AddressWithPrefix.Subnet(),
- NIC: host1NICID,
- },
- {
- Destination: utils.Ipv6Addr1.AddressWithPrefix.Subnet(),
- NIC: host1NICID,
- },
- })
- host2Stack.SetRouteTable([]tcpip.Route{
- {
- Destination: utils.Ipv4Addr2.AddressWithPrefix.Subnet(),
- NIC: host2NICID,
- },
- {
- Destination: utils.Ipv6Addr2.AddressWithPrefix.Subnet(),
- NIC: host2NICID,
- },
- })
-
- return host1Stack, host2Stack
-}
-
-// TestPing tests that two hosts can ping eachother when link resolution is
-// enabled.
-func TestPing(t *testing.T) {
- const (
- host1NICID = 1
- host2NICID = 4
-
- // icmpDataOffset is the offset to the data in both ICMPv4 and ICMPv6 echo
- // request/reply packets.
- icmpDataOffset = 8
- )
-
- tests := []struct {
- name string
- transProto tcpip.TransportProtocolNumber
- netProto tcpip.NetworkProtocolNumber
- remoteAddr tcpip.Address
- icmpBuf func(*testing.T) []byte
- }{
- {
- name: "IPv4 Ping",
- transProto: icmp.ProtocolNumber4,
- netProto: ipv4.ProtocolNumber,
- remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
- icmpBuf: func(t *testing.T) []byte {
- data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
- hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data)))
- hdr.SetType(header.ICMPv4Echo)
- if n := copy(hdr.Payload(), data[:]); n != len(data) {
- t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data))
- }
- return hdr
- },
- },
- {
- name: "IPv6 Ping",
- transProto: icmp.ProtocolNumber6,
- netProto: ipv6.ProtocolNumber,
- remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
- icmpBuf: func(t *testing.T) []byte {
- data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
- hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data)))
- hdr.SetType(header.ICMPv6EchoRequest)
- if n := copy(hdr.Payload(), data[:]); n != len(data) {
- t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data))
- }
- return hdr
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- stackOpts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6},
- }
-
- host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID)
-
- var wq waiter.Queue
- we, waiterCH := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- ep, err := host1Stack.NewEndpoint(test.transProto, test.netProto, &wq)
- if err != nil {
- t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err)
- }
- defer ep.Close()
-
- icmpBuf := test.icmpBuf(t)
- var r bytes.Reader
- r.Reset(icmpBuf)
- wOpts := tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: test.remoteAddr}}
- if n, err := ep.Write(&r, wOpts); err != nil {
- t.Fatalf("ep.Write(_, _): %s", err)
- } else if want := int64(len(icmpBuf)); n != want {
- t.Fatalf("got ep.Write(_, _) = (%d, _), want = (%d, _)", n, want)
- }
-
- // Wait for the endpoint to be readable.
- <-waiterCH
-
- var buf bytes.Buffer
- opts := tcpip.ReadOptions{NeedRemoteAddr: true}
- res, err := ep.Read(&buf, opts)
- if err != nil {
- t.Fatalf("ep.Read(_, %d, %#v): %s", len(icmpBuf), opts, err)
- }
- if diff := cmp.Diff(tcpip.ReadResult{
- Count: buf.Len(),
- Total: buf.Len(),
- RemoteAddr: tcpip.FullAddress{Addr: test.remoteAddr},
- }, res, checker.IgnoreCmpPath(
- "ControlMessages",
- "RemoteAddr.NIC",
- "RemoteAddr.Port",
- )); diff != "" {
- t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
- }
- if diff := cmp.Diff(buf.Bytes()[icmpDataOffset:], icmpBuf[icmpDataOffset:]); diff != "" {
- t.Errorf("received data mismatch (-want +got):\n%s", diff)
- }
- })
- }
-}
-
-type transportError struct {
- origin tcpip.SockErrOrigin
- typ uint8
- code uint8
- info uint32
- kind stack.TransportErrorKind
-}
-
-func TestTCPLinkResolutionFailure(t *testing.T) {
- const (
- host1NICID = 1
- host2NICID = 4
- )
-
- tests := []struct {
- name string
- netProto tcpip.NetworkProtocolNumber
- remoteAddr tcpip.Address
- expectedWriteErr tcpip.Error
- sockError tcpip.SockError
- transErr transportError
- }{
- {
- name: "IPv4 with resolvable remote",
- netProto: ipv4.ProtocolNumber,
- remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
- expectedWriteErr: nil,
- },
- {
- name: "IPv6 with resolvable remote",
- netProto: ipv6.ProtocolNumber,
- remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
- expectedWriteErr: nil,
- },
- {
- name: "IPv4 without resolvable remote",
- netProto: ipv4.ProtocolNumber,
- remoteAddr: utils.Ipv4Addr3.AddressWithPrefix.Address,
- expectedWriteErr: &tcpip.ErrNoRoute{},
- sockError: tcpip.SockError{
- Err: &tcpip.ErrNoRoute{},
- Dst: tcpip.FullAddress{
- NIC: host1NICID,
- Addr: utils.Ipv4Addr3.AddressWithPrefix.Address,
- Port: 1234,
- },
- Offender: tcpip.FullAddress{
- NIC: host1NICID,
- Addr: utils.Ipv4Addr1.AddressWithPrefix.Address,
- },
- NetProto: ipv4.ProtocolNumber,
- },
- transErr: transportError{
- origin: tcpip.SockExtErrorOriginICMP,
- typ: uint8(header.ICMPv4DstUnreachable),
- code: uint8(header.ICMPv4HostUnreachable),
- kind: stack.DestinationHostUnreachableTransportError,
- },
- },
- {
- name: "IPv6 without resolvable remote",
- netProto: ipv6.ProtocolNumber,
- remoteAddr: utils.Ipv6Addr3.AddressWithPrefix.Address,
- expectedWriteErr: &tcpip.ErrNoRoute{},
- sockError: tcpip.SockError{
- Err: &tcpip.ErrNoRoute{},
- Dst: tcpip.FullAddress{
- NIC: host1NICID,
- Addr: utils.Ipv6Addr3.AddressWithPrefix.Address,
- Port: 1234,
- },
- Offender: tcpip.FullAddress{
- NIC: host1NICID,
- Addr: utils.Ipv6Addr1.AddressWithPrefix.Address,
- },
- NetProto: ipv6.ProtocolNumber,
- },
- transErr: transportError{
- origin: tcpip.SockExtErrorOriginICMP6,
- typ: uint8(header.ICMPv6DstUnreachable),
- code: uint8(header.ICMPv6AddressUnreachable),
- kind: stack.DestinationHostUnreachableTransportError,
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- clock := faketime.NewManualClock()
- stackOpts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
- Clock: clock,
- }
-
- host1Stack, host2Stack := setupStack(t, stackOpts, host1NICID, host2NICID)
-
- var listenerWQ waiter.Queue
- listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, test.netProto, &listenerWQ)
- if err != nil {
- t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, test.netProto, err)
- }
- defer listenerEP.Close()
-
- listenerAddr := tcpip.FullAddress{Port: 1234}
- if err := listenerEP.Bind(listenerAddr); err != nil {
- t.Fatalf("listenerEP.Bind(%#v): %s", listenerAddr, err)
- }
-
- if err := listenerEP.Listen(1); err != nil {
- t.Fatalf("listenerEP.Listen(1): %s", err)
- }
-
- var clientWQ waiter.Queue
- we, ch := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&we, waiter.WritableEvents|waiter.EventErr)
- clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, test.netProto, &clientWQ)
- if err != nil {
- t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, test.netProto, err)
- }
- defer clientEP.Close()
-
- sockOpts := clientEP.SocketOptions()
- sockOpts.SetRecvError(true)
-
- remoteAddr := listenerAddr
- remoteAddr.Addr = test.remoteAddr
- {
- err := clientEP.Connect(remoteAddr)
- if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
- t.Fatalf("got clientEP.Connect(%#v) = %s, want = %s", remoteAddr, err, &tcpip.ErrConnectStarted{})
- }
- }
-
- // Wait for an error due to link resolution failing, or the endpoint to be
- // writable.
- if test.expectedWriteErr != nil {
- nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto)
- if err != nil {
- t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err)
- }
- clock.Advance(time.Duration(nudConfigs.MaxMulticastProbes) * nudConfigs.RetransmitTimer)
- } else {
- clock.RunImmediatelyScheduledJobs()
- }
- <-ch
-
- {
- var r bytes.Reader
- r.Reset([]byte{0})
- var wOpts tcpip.WriteOptions
- _, err := clientEP.Write(&r, wOpts)
- if diff := cmp.Diff(test.expectedWriteErr, err); diff != "" {
- t.Errorf("unexpected error from clientEP.Write(_, %#v), (-want, +got):\n%s", wOpts, diff)
- }
- }
-
- if test.expectedWriteErr == nil {
- return
- }
-
- sockErr := sockOpts.DequeueErr()
- if sockErr == nil {
- t.Fatalf("got sockOpts.DequeueErr() = nil, want = non-nil")
- }
-
- sockErrCmpOpts := []cmp.Option{
- cmpopts.IgnoreUnexported(tcpip.SockError{}),
- cmp.Comparer(func(a, b tcpip.Error) bool {
- // tcpip.Error holds an unexported field but the errors netstack uses
- // are pre defined so we can simply compare pointers.
- return a == b
- }),
- checker.IgnoreCmpPath(
- // Ignore the payload since we do not know the TCP seq/ack numbers.
- "Payload",
- // Ignore the cause since we will compare its properties separately
- // since the concrete type of the cause is unknown.
- "Cause",
- ),
- }
-
- if addr, err := clientEP.GetLocalAddress(); err != nil {
- t.Fatalf("clientEP.GetLocalAddress(): %s", err)
- } else {
- test.sockError.Offender.Port = addr.Port
- }
- if diff := cmp.Diff(&test.sockError, sockErr, sockErrCmpOpts...); diff != "" {
- t.Errorf("socket error mismatch (-want +got):\n%s", diff)
- }
-
- transErr, ok := sockErr.Cause.(stack.TransportError)
- if !ok {
- t.Fatalf("socket error cause is not a transport error; cause = %#v", sockErr.Cause)
- }
- if diff := cmp.Diff(
- test.transErr,
- transportError{
- origin: transErr.Origin(),
- typ: transErr.Type(),
- code: transErr.Code(),
- info: transErr.Info(),
- kind: transErr.Kind(),
- },
- cmp.AllowUnexported(transportError{}),
- ); diff != "" {
- t.Errorf("socket error mismatch (-want +got):\n%s", diff)
- }
- })
- }
-}
-
-func TestForwardingWithLinkResolutionFailure(t *testing.T) {
- const (
- incomingNICID = 1
- outgoingNICID = 2
- ttl = 2
- expectedHostUnreachableErrorCount = 1
- )
- outgoingLinkAddr := tcptestutil.MustParseLink("02:03:03:04:05:06")
-
- rxICMPv4EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) {
- utils.RxICMPv4EchoRequest(e, src, dst, ttl)
- }
-
- rxICMPv6EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) {
- utils.RxICMPv6EchoRequest(e, src, dst, ttl)
- }
-
- arpChecker := func(t *testing.T, request channel.PacketInfo, src, dst tcpip.Address) {
- if request.Proto != arp.ProtocolNumber {
- t.Errorf("got request.Proto = %d, want = %d", request.Proto, arp.ProtocolNumber)
- }
- if request.Route.RemoteLinkAddress != header.EthernetBroadcastAddress {
- t.Errorf("got request.Route.RemoteLinkAddress = %s, want = %s", request.Route.RemoteLinkAddress, header.EthernetBroadcastAddress)
- }
- rep := header.ARP(request.Pkt.NetworkHeader().View())
- if got := rep.Op(); got != header.ARPRequest {
- t.Errorf("got Op() = %d, want = %d", got, header.ARPRequest)
- }
- if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != outgoingLinkAddr {
- t.Errorf("got HardwareAddressSender = %s, want = %s", got, outgoingLinkAddr)
- }
- if got := tcpip.Address(rep.ProtocolAddressSender()); got != src {
- t.Errorf("got ProtocolAddressSender = %s, want = %s", got, src)
- }
- if got := tcpip.Address(rep.ProtocolAddressTarget()); got != dst {
- t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, dst)
- }
- }
-
- ndpChecker := func(t *testing.T, request channel.PacketInfo, src, dst tcpip.Address) {
- if request.Proto != header.IPv6ProtocolNumber {
- t.Fatalf("got Proto = %d, want = %d", request.Proto, header.IPv6ProtocolNumber)
- }
-
- snmc := header.SolicitedNodeAddr(dst)
- if want := header.EthernetAddressFromMulticastIPv6Address(snmc); request.Route.RemoteLinkAddress != want {
- t.Errorf("got remote link address = %s, want = %s", request.Route.RemoteLinkAddress, want)
- }
-
- checker.IPv6(t, stack.PayloadSince(request.Pkt.NetworkHeader()),
- checker.SrcAddr(src),
- checker.DstAddr(snmc),
- checker.TTL(header.NDPHopLimit),
- checker.NDPNS(
- checker.NDPNSTargetAddress(dst),
- ))
- }
-
- icmpv4Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) {
- checker.IPv4(t, b,
- checker.SrcAddr(src),
- checker.DstAddr(dst),
- checker.TTL(ipv4.DefaultTTL),
- checker.ICMPv4(
- checker.ICMPv4Checksum(),
- checker.ICMPv4Type(header.ICMPv4DstUnreachable),
- checker.ICMPv4Code(header.ICMPv4HostUnreachable),
- ),
- )
- }
-
- icmpv6Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) {
- checker.IPv6(t, b,
- checker.SrcAddr(src),
- checker.DstAddr(dst),
- checker.TTL(ipv6.DefaultTTL),
- checker.ICMPv6(
- checker.ICMPv6Type(header.ICMPv6DstUnreachable),
- checker.ICMPv6Code(header.ICMPv6AddressUnreachable),
- ),
- )
- }
-
- tests := []struct {
- name string
- networkProtocolFactory []stack.NetworkProtocolFactory
- networkProtocolNumber tcpip.NetworkProtocolNumber
- sourceAddr tcpip.Address
- destAddr tcpip.Address
- incomingAddr tcpip.AddressWithPrefix
- outgoingAddr tcpip.AddressWithPrefix
- transportProtocol func(*stack.Stack) stack.TransportProtocol
- rx func(*channel.Endpoint, tcpip.Address, tcpip.Address)
- linkResolutionRequestChecker func(*testing.T, channel.PacketInfo, tcpip.Address, tcpip.Address)
- icmpReplyChecker func(*testing.T, []byte, tcpip.Address, tcpip.Address)
- mtu uint32
- }{
- {
- name: "IPv4 Host unreachable",
- networkProtocolFactory: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
- networkProtocolNumber: header.IPv4ProtocolNumber,
- sourceAddr: tcptestutil.MustParse4("10.0.0.2"),
- destAddr: tcptestutil.MustParse4("11.0.0.2"),
- incomingAddr: tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()),
- PrefixLen: 8,
- },
- outgoingAddr: tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("11.0.0.1").To4()),
- PrefixLen: 8,
- },
- transportProtocol: icmp.NewProtocol4,
- linkResolutionRequestChecker: arpChecker,
- icmpReplyChecker: icmpv4Checker,
- rx: rxICMPv4EchoRequest,
- mtu: ipv4.MaxTotalSize,
- },
- {
- name: "IPv6 Host unreachable",
- networkProtocolFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
- networkProtocolNumber: header.IPv6ProtocolNumber,
- sourceAddr: tcptestutil.MustParse6("10::2"),
- destAddr: tcptestutil.MustParse6("11::2"),
- incomingAddr: tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("10::1").To16()),
- PrefixLen: 64,
- },
- outgoingAddr: tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("11::1").To16()),
- PrefixLen: 64,
- },
- transportProtocol: icmp.NewProtocol6,
- linkResolutionRequestChecker: ndpChecker,
- icmpReplyChecker: icmpv6Checker,
- rx: rxICMPv6EchoRequest,
- mtu: header.IPv6MinimumMTU,
- },
- }
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- clock := faketime.NewManualClock()
-
- s := stack.New(stack.Options{
- NetworkProtocols: test.networkProtocolFactory,
- TransportProtocols: []stack.TransportProtocolFactory{test.transportProtocol},
- Clock: clock,
- })
-
- // Set up endpoint through which we will receive packets.
- incomingEndpoint := channel.New(1, test.mtu, "")
- if err := s.CreateNIC(incomingNICID, incomingEndpoint); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err)
- }
- incomingProtoAddr := tcpip.ProtocolAddress{
- Protocol: test.networkProtocolNumber,
- AddressWithPrefix: test.incomingAddr,
- }
- if err := s.AddProtocolAddress(incomingNICID, incomingProtoAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", incomingNICID, incomingProtoAddr, err)
- }
-
- // Set up endpoint through which we will attempt to forward packets.
- outgoingEndpoint := channel.New(1, test.mtu, outgoingLinkAddr)
- outgoingEndpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- if err := s.CreateNIC(outgoingNICID, outgoingEndpoint); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err)
- }
- outgoingProtoAddr := tcpip.ProtocolAddress{
- Protocol: test.networkProtocolNumber,
- AddressWithPrefix: test.outgoingAddr,
- }
- if err := s.AddProtocolAddress(outgoingNICID, outgoingProtoAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", outgoingNICID, outgoingProtoAddr, err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: test.incomingAddr.Subnet(),
- NIC: incomingNICID,
- },
- {
- Destination: test.outgoingAddr.Subnet(),
- NIC: outgoingNICID,
- },
- })
-
- if err := s.SetForwardingDefaultAndAllNICs(test.networkProtocolNumber, true); err != nil {
- t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", test.networkProtocolNumber, err)
- }
-
- test.rx(incomingEndpoint, test.sourceAddr, test.destAddr)
-
- nudConfigs, err := s.NUDConfigurations(outgoingNICID, test.networkProtocolNumber)
- if err != nil {
- t.Fatalf("s.NUDConfigurations(%d, %d): %s", outgoingNICID, test.networkProtocolNumber, err)
- }
- // Trigger the first packet on the endpoint.
- clock.RunImmediatelyScheduledJobs()
-
- for i := 0; i < int(nudConfigs.MaxMulticastProbes); i++ {
- request, ok := outgoingEndpoint.Read()
- if !ok {
- t.Fatal("expected ARP packet through outgoing NIC")
- }
-
- test.linkResolutionRequestChecker(t, request, test.outgoingAddr.Address, test.destAddr)
-
- // Advance the clock the span of one request timeout.
- clock.Advance(nudConfigs.RetransmitTimer)
- }
-
- // Next, we make a blocking read to retrieve the error packet. This is
- // necessary because outgoing packets are dequeued asynchronously when
- // link resolution fails, and this dequeue is what triggers the ICMP
- // error.
- reply, ok := incomingEndpoint.Read()
- if !ok {
- t.Fatal("expected ICMP packet through incoming NIC")
- }
-
- test.icmpReplyChecker(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), test.incomingAddr.Address, test.sourceAddr)
-
- // Since link resolution failed, we don't expect the packet to be
- // forwarded.
- forwardedPacket, ok := outgoingEndpoint.Read()
- if ok {
- t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", forwardedPacket)
- }
-
- if got, want := s.Stats().IP.Forwarding.HostUnreachable.Value(), expectedHostUnreachableErrorCount; int(got) != want {
- t.Errorf("got rt.Stats().IP.Forwarding.HostUnreachable.Value() = %d, want = %d", got, want)
- }
- })
- }
-}
-
-func TestGetLinkAddress(t *testing.T) {
- const (
- host1NICID = 1
- host2NICID = 4
- )
-
- tests := []struct {
- name string
- netProto tcpip.NetworkProtocolNumber
- remoteAddr, localAddr tcpip.Address
- expectedErr tcpip.Error
- }{
- {
- name: "IPv4 resolvable",
- netProto: ipv4.ProtocolNumber,
- remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
- expectedErr: nil,
- },
- {
- name: "IPv6 resolvable",
- netProto: ipv6.ProtocolNumber,
- remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
- expectedErr: nil,
- },
- {
- name: "IPv4 not resolvable",
- netProto: ipv4.ProtocolNumber,
- remoteAddr: utils.Ipv4Addr3.AddressWithPrefix.Address,
- expectedErr: &tcpip.ErrTimeout{},
- },
- {
- name: "IPv6 not resolvable",
- netProto: ipv6.ProtocolNumber,
- remoteAddr: utils.Ipv6Addr3.AddressWithPrefix.Address,
- expectedErr: &tcpip.ErrTimeout{},
- },
- {
- name: "IPv4 bad local address",
- netProto: ipv4.ProtocolNumber,
- remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
- localAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
- expectedErr: &tcpip.ErrBadLocalAddress{},
- },
- {
- name: "IPv6 bad local address",
- netProto: ipv6.ProtocolNumber,
- remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
- localAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
- expectedErr: &tcpip.ErrBadLocalAddress{},
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- clock := faketime.NewManualClock()
- stackOpts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
- Clock: clock,
- }
-
- host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID)
-
- ch := make(chan stack.LinkResolutionResult, 1)
- err := host1Stack.GetLinkAddress(host1NICID, test.remoteAddr, test.localAddr, test.netProto, func(r stack.LinkResolutionResult) {
- ch <- r
- })
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", host1NICID, test.remoteAddr, test.netProto, err, &tcpip.ErrWouldBlock{})
- }
- wantRes := stack.LinkResolutionResult{Err: test.expectedErr}
- if test.expectedErr == nil {
- wantRes.LinkAddress = utils.LinkAddr2
- }
-
- nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto)
- if err != nil {
- t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err)
- }
-
- clock.Advance(time.Duration(nudConfigs.MaxMulticastProbes) * nudConfigs.RetransmitTimer)
- select {
- case got := <-ch:
- if diff := cmp.Diff(wantRes, got); diff != "" {
- t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("event didn't arrive")
- }
- })
- }
-}
-
-func TestRouteResolvedFields(t *testing.T) {
- const (
- host1NICID = 1
- host2NICID = 4
- )
-
- tests := []struct {
- name string
- netProto tcpip.NetworkProtocolNumber
- localAddr tcpip.Address
- remoteAddr tcpip.Address
- immediatelyResolvable bool
- expectedErr tcpip.Error
- expectedLinkAddr tcpip.LinkAddress
- }{
- {
- name: "IPv4 immediately resolvable",
- netProto: ipv4.ProtocolNumber,
- localAddr: utils.Ipv4Addr1.AddressWithPrefix.Address,
- remoteAddr: header.IPv4AllSystems,
- immediatelyResolvable: true,
- expectedErr: nil,
- expectedLinkAddr: header.EthernetAddressFromMulticastIPv4Address(header.IPv4AllSystems),
- },
- {
- name: "IPv6 immediately resolvable",
- netProto: ipv6.ProtocolNumber,
- localAddr: utils.Ipv6Addr1.AddressWithPrefix.Address,
- remoteAddr: header.IPv6AllNodesMulticastAddress,
- immediatelyResolvable: true,
- expectedErr: nil,
- expectedLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllNodesMulticastAddress),
- },
- {
- name: "IPv4 resolvable",
- netProto: ipv4.ProtocolNumber,
- localAddr: utils.Ipv4Addr1.AddressWithPrefix.Address,
- remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
- immediatelyResolvable: false,
- expectedErr: nil,
- expectedLinkAddr: utils.LinkAddr2,
- },
- {
- name: "IPv6 resolvable",
- netProto: ipv6.ProtocolNumber,
- localAddr: utils.Ipv6Addr1.AddressWithPrefix.Address,
- remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
- immediatelyResolvable: false,
- expectedErr: nil,
- expectedLinkAddr: utils.LinkAddr2,
- },
- {
- name: "IPv4 not resolvable",
- netProto: ipv4.ProtocolNumber,
- localAddr: utils.Ipv4Addr1.AddressWithPrefix.Address,
- remoteAddr: utils.Ipv4Addr3.AddressWithPrefix.Address,
- immediatelyResolvable: false,
- expectedErr: &tcpip.ErrTimeout{},
- },
- {
- name: "IPv6 not resolvable",
- netProto: ipv6.ProtocolNumber,
- localAddr: utils.Ipv6Addr1.AddressWithPrefix.Address,
- remoteAddr: utils.Ipv6Addr3.AddressWithPrefix.Address,
- immediatelyResolvable: false,
- expectedErr: &tcpip.ErrTimeout{},
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- clock := faketime.NewManualClock()
- stackOpts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
- Clock: clock,
- }
-
- host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID)
- r, err := host1Stack.FindRoute(host1NICID, test.localAddr, test.remoteAddr, test.netProto, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("host1Stack.FindRoute(%d, %s, %s, %d, false): %s", host1NICID, test.localAddr, test.remoteAddr, test.netProto, err)
- }
- defer r.Release()
-
- var wantRouteInfo stack.RouteInfo
- wantRouteInfo.LocalLinkAddress = utils.LinkAddr1
- wantRouteInfo.LocalAddress = test.localAddr
- wantRouteInfo.RemoteAddress = test.remoteAddr
- wantRouteInfo.NetProto = test.netProto
- wantRouteInfo.Loop = stack.PacketOut
- wantRouteInfo.RemoteLinkAddress = test.expectedLinkAddr
-
- ch := make(chan stack.ResolvedFieldsResult, 1)
-
- if !test.immediatelyResolvable {
- wantUnresolvedRouteInfo := wantRouteInfo
- wantUnresolvedRouteInfo.RemoteLinkAddress = ""
-
- err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) {
- ch <- r
- })
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Errorf("got r.ResolvedFields(_) = %s, want = %s", err, &tcpip.ErrWouldBlock{})
- }
-
- nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto)
- if err != nil {
- t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err)
- }
- clock.Advance(time.Duration(nudConfigs.MaxMulticastProbes) * nudConfigs.RetransmitTimer)
-
- select {
- case got := <-ch:
- if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Err: test.expectedErr}, got, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" {
- t.Errorf("route resolve result mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatalf("event didn't arrive")
- }
-
- if test.expectedErr != nil {
- return
- }
-
- // At this point the neighbor table should be populated so the route
- // should be immediately resolvable.
- }
-
- if err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) {
- ch <- r
- }); err != nil {
- t.Errorf("r.ResolvedFields(_): %s", err)
- }
- select {
- case routeResolveRes := <-ch:
- if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Err: nil}, routeResolveRes, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" {
- t.Errorf("route resolve result from resolved route mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected route to be immediately resolvable")
- }
- })
- }
-}
-
-func TestWritePacketsLinkResolution(t *testing.T) {
- const (
- host1NICID = 1
- host2NICID = 4
- )
-
- tests := []struct {
- name string
- netProto tcpip.NetworkProtocolNumber
- remoteAddr tcpip.Address
- expectedWriteErr tcpip.Error
- }{
- {
- name: "IPv4",
- netProto: ipv4.ProtocolNumber,
- remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
- expectedWriteErr: nil,
- },
- {
- name: "IPv6",
- netProto: ipv6.ProtocolNumber,
- remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
- expectedWriteErr: nil,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- stackOpts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- }
-
- host1Stack, host2Stack := setupStack(t, stackOpts, host1NICID, host2NICID)
-
- var serverWQ waiter.Queue
- serverWE, serverCH := waiter.NewChannelEntry(nil)
- serverWQ.EventRegister(&serverWE, waiter.ReadableEvents)
- serverEP, err := host2Stack.NewEndpoint(udp.ProtocolNumber, test.netProto, &serverWQ)
- if err != nil {
- t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.netProto, err)
- }
- defer serverEP.Close()
-
- serverAddr := tcpip.FullAddress{Port: 1234}
- if err := serverEP.Bind(serverAddr); err != nil {
- t.Fatalf("serverEP.Bind(%#v): %s", serverAddr, err)
- }
-
- r, err := host1Stack.FindRoute(host1NICID, "", test.remoteAddr, test.netProto, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("host1Stack.FindRoute(%d, '', %s, %d, false): %s", host1NICID, test.remoteAddr, test.netProto, err)
- }
- defer r.Release()
-
- data := []byte{1, 2}
- var pkts stack.PacketBufferList
- for _, d := range data {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()),
- Data: buffer.View([]byte{d}).ToVectorisedView(),
- })
- pkt.TransportProtocolNumber = udp.ProtocolNumber
- length := uint16(pkt.Size())
- udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
- udpHdr.Encode(&header.UDPFields{
- SrcPort: 5555,
- DstPort: serverAddr.Port,
- Length: length,
- })
- xsum := r.PseudoHeaderChecksum(udp.ProtocolNumber, length)
- xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
- udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum))
-
- pkts.PushBack(pkt)
- }
-
- params := stack.NetworkHeaderParams{
- Protocol: udp.ProtocolNumber,
- TTL: 64,
- TOS: stack.DefaultTOS,
- }
-
- if n, err := r.WritePackets(pkts, params); err != nil {
- t.Fatalf("r.WritePackets(_, %#v): %s", params, err)
- } else if want := pkts.Len(); want != n {
- t.Fatalf("got r.WritePackets(_, %#v) = %d, want = %d", params, n, want)
- }
-
- var writer bytes.Buffer
- count := 0
- for {
- var rOpts tcpip.ReadOptions
- res, err := serverEP.Read(&writer, rOpts)
- if err != nil {
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
- // Should not have anymore bytes to read after we read the sent
- // number of bytes.
- if count == len(data) {
- break
- }
-
- <-serverCH
- continue
- }
-
- t.Fatalf("serverEP.Read(_, %#v): %s", rOpts, err)
- }
- count += res.Count
- }
-
- if got, want := host2Stack.Stats().UDP.PacketsReceived.Value(), uint64(len(data)); got != want {
- t.Errorf("got host2Stack.Stats().UDP.PacketsReceived.Value() = %d, want = %d", got, want)
- }
- if diff := cmp.Diff(data, writer.Bytes()); diff != "" {
- t.Errorf("read bytes mismatch (-want +got):\n%s", diff)
- }
- })
- }
-}
-
-type eventType int
-
-const (
- entryAdded eventType = iota
- entryChanged
- entryRemoved
-)
-
-func (t eventType) String() string {
- switch t {
- case entryAdded:
- return "add"
- case entryChanged:
- return "change"
- case entryRemoved:
- return "remove"
- default:
- return fmt.Sprintf("unknown (%d)", t)
- }
-}
-
-type eventInfo struct {
- eventType eventType
- nicID tcpip.NICID
- entry stack.NeighborEntry
-}
-
-func (e eventInfo) String() string {
- return fmt.Sprintf("%s event for NIC #%d, %#v", e.eventType, e.nicID, e.entry)
-}
-
-var _ stack.NUDDispatcher = (*nudDispatcher)(nil)
-
-type nudDispatcher struct {
- c chan eventInfo
-}
-
-func (d *nudDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry stack.NeighborEntry) {
- e := eventInfo{
- eventType: entryAdded,
- nicID: nicID,
- entry: entry,
- }
- d.c <- e
-}
-
-func (d *nudDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry stack.NeighborEntry) {
- e := eventInfo{
- eventType: entryChanged,
- nicID: nicID,
- entry: entry,
- }
- d.c <- e
-}
-
-func (d *nudDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.NeighborEntry) {
- e := eventInfo{
- eventType: entryRemoved,
- nicID: nicID,
- entry: entry,
- }
- d.c <- e
-}
-
-func (d *nudDispatcher) expectEvent(want eventInfo) error {
- select {
- case got := <-d.c:
- if diff := cmp.Diff(want, got, cmp.AllowUnexported(eventInfo{}), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAt")); diff != "" {
- return fmt.Errorf("got invalid event (-want +got):\n%s", diff)
- }
- return nil
- default:
- return fmt.Errorf("event didn't arrive")
- }
-}
-
-// TestTCPConfirmNeighborReachability tests that TCP informs layers beneath it
-// that the neighbor used for a route is reachable.
-func TestTCPConfirmNeighborReachability(t *testing.T) {
- tests := []struct {
- name string
- netProto tcpip.NetworkProtocolNumber
- remoteAddr tcpip.Address
- neighborAddr tcpip.Address
- getEndpoints func(*testing.T, *stack.Stack, *stack.Stack, *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{})
- isHost1Listener bool
- }{
- {
- name: "IPv4 active connection through neighbor",
- netProto: ipv4.ProtocolNumber,
- remoteAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address,
- neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
- var listenerWQ waiter.Queue
- listenerWE, listenerCH := waiter.NewChannelEntry(nil)
- listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
- listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
- if err != nil {
- t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
- }
- t.Cleanup(listenerEP.Close)
-
- var clientWQ waiter.Queue
- clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
- clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
- if err != nil {
- t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
- }
-
- return listenerEP, listenerCH, clientEP, clientCH
- },
- },
- {
- name: "IPv6 active connection through neighbor",
- netProto: ipv6.ProtocolNumber,
- remoteAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address,
- neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
- var listenerWQ waiter.Queue
- listenerWE, listenerCH := waiter.NewChannelEntry(nil)
- listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
- listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ)
- if err != nil {
- t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
- }
- t.Cleanup(listenerEP.Close)
-
- var clientWQ waiter.Queue
- clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
- clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ)
- if err != nil {
- t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
- }
-
- return listenerEP, listenerCH, clientEP, clientCH
- },
- },
- {
- name: "IPv4 active connection to neighbor",
- netProto: ipv4.ProtocolNumber,
- remoteAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
- neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
- var listenerWQ waiter.Queue
- listenerWE, listenerCH := waiter.NewChannelEntry(nil)
- listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
- listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
- if err != nil {
- t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
- }
- t.Cleanup(listenerEP.Close)
-
- var clientWQ waiter.Queue
- clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
- clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
- if err != nil {
- t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
- }
-
- return listenerEP, listenerCH, clientEP, clientCH
- },
- },
- {
- name: "IPv6 active connection to neighbor",
- netProto: ipv6.ProtocolNumber,
- remoteAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
- neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
- var listenerWQ waiter.Queue
- listenerWE, listenerCH := waiter.NewChannelEntry(nil)
- listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
- listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ)
- if err != nil {
- t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
- }
- t.Cleanup(listenerEP.Close)
-
- var clientWQ waiter.Queue
- clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
- clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ)
- if err != nil {
- t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
- }
-
- return listenerEP, listenerCH, clientEP, clientCH
- },
- },
- {
- name: "IPv4 passive connection to neighbor",
- netProto: ipv4.ProtocolNumber,
- remoteAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
- neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
- var listenerWQ waiter.Queue
- listenerWE, listenerCH := waiter.NewChannelEntry(nil)
- listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
- listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
- if err != nil {
- t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
- }
- t.Cleanup(listenerEP.Close)
-
- var clientWQ waiter.Queue
- clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
- clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
- if err != nil {
- t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
- }
-
- return listenerEP, listenerCH, clientEP, clientCH
- },
- isHost1Listener: true,
- },
- {
- name: "IPv6 passive connection to neighbor",
- netProto: ipv6.ProtocolNumber,
- remoteAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
- neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
- var listenerWQ waiter.Queue
- listenerWE, listenerCH := waiter.NewChannelEntry(nil)
- listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
- listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ)
- if err != nil {
- t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
- }
- t.Cleanup(listenerEP.Close)
-
- var clientWQ waiter.Queue
- clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
- clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ)
- if err != nil {
- t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
- }
-
- return listenerEP, listenerCH, clientEP, clientCH
- },
- isHost1Listener: true,
- },
- {
- name: "IPv4 passive connection through neighbor",
- netProto: ipv4.ProtocolNumber,
- remoteAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
- neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
- var listenerWQ waiter.Queue
- listenerWE, listenerCH := waiter.NewChannelEntry(nil)
- listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
- listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
- if err != nil {
- t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
- }
- t.Cleanup(listenerEP.Close)
-
- var clientWQ waiter.Queue
- clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
- clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
- if err != nil {
- t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
- }
-
- return listenerEP, listenerCH, clientEP, clientCH
- },
- isHost1Listener: true,
- },
- {
- name: "IPv6 passive connection through neighbor",
- netProto: ipv6.ProtocolNumber,
- remoteAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
- neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
- getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) {
- var listenerWQ waiter.Queue
- listenerWE, listenerCH := waiter.NewChannelEntry(nil)
- listenerWQ.EventRegister(&listenerWE, waiter.EventIn)
- listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ)
- if err != nil {
- t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
- }
- t.Cleanup(listenerEP.Close)
-
- var clientWQ waiter.Queue
- clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents)
- clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ)
- if err != nil {
- t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
- }
-
- return listenerEP, listenerCH, clientEP, clientCH
- },
- isHost1Listener: true,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- clock := faketime.NewManualClock()
- nudDisp := nudDispatcher{
- c: make(chan eventInfo, 3),
- }
- stackOpts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
- Clock: clock,
- }
- host1StackOpts := stackOpts
- host1StackOpts.NUDDisp = &nudDisp
-
- host1Stack := stack.New(host1StackOpts)
- routerStack := stack.New(stackOpts)
- host2Stack := stack.New(stackOpts)
- utils.SetupRoutedStacks(t, host1Stack, routerStack, host2Stack)
-
- // Add a reachable dynamic entry to our neighbor table for the remote.
- {
- ch := make(chan stack.LinkResolutionResult, 1)
- err := host1Stack.GetLinkAddress(utils.Host1NICID, test.neighborAddr, "", test.netProto, func(r stack.LinkResolutionResult) {
- ch <- r
- })
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", utils.Host1NICID, test.neighborAddr, test.netProto, err, &tcpip.ErrWouldBlock{})
- }
- if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: utils.LinkAddr2, Err: nil}, <-ch); diff != "" {
- t.Fatalf("link resolution mismatch (-want +got):\n%s", diff)
- }
- }
- if err := nudDisp.expectEvent(eventInfo{
- eventType: entryAdded,
- nicID: utils.Host1NICID,
- entry: stack.NeighborEntry{State: stack.Incomplete, Addr: test.neighborAddr},
- }); err != nil {
- t.Fatalf("error waiting for initial NUD event: %s", err)
- }
- if err := nudDisp.expectEvent(eventInfo{
- eventType: entryChanged,
- nicID: utils.Host1NICID,
- entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
- }); err != nil {
- t.Fatalf("error waiting for reachable NUD event: %s", err)
- }
-
- // Wait for the remote's neighbor entry to be stale before creating a
- // TCP connection from host1 to some remote.
- nudConfigs, err := host1Stack.NUDConfigurations(utils.Host1NICID, test.netProto)
- if err != nil {
- t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", utils.Host1NICID, test.netProto, err)
- }
- // The maximum reachable time for a neighbor is some maximum random factor
- // applied to the base reachable time.
- //
- // See NUDConfigurations.BaseReachableTime for more information.
- maxReachableTime := time.Duration(float32(nudConfigs.BaseReachableTime) * nudConfigs.MaxRandomFactor)
- clock.Advance(maxReachableTime)
- if err := nudDisp.expectEvent(eventInfo{
- eventType: entryChanged,
- nicID: utils.Host1NICID,
- entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
- }); err != nil {
- t.Fatalf("error waiting for stale NUD event: %s", err)
- }
-
- listenerEP, listenerCH, clientEP, clientCH := test.getEndpoints(t, host1Stack, routerStack, host2Stack)
- defer clientEP.Close()
- listenerAddr := tcpip.FullAddress{Addr: test.remoteAddr, Port: 1234}
- if err := listenerEP.Bind(listenerAddr); err != nil {
- t.Fatalf("listenerEP.Bind(%#v): %s", listenerAddr, err)
- }
- if err := listenerEP.Listen(1); err != nil {
- t.Fatalf("listenerEP.Listen(1): %s", err)
- }
- {
- err := clientEP.Connect(listenerAddr)
- if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
- t.Fatalf("got clientEP.Connect(%#v) = %s, want = %s", listenerAddr, err, &tcpip.ErrConnectStarted{})
- }
- }
-
- // Wait for the TCP handshake to complete then make sure the neighbor is
- // reachable without entering the probe state as TCP should provide NUD
- // with confirmation that the neighbor is reachable (indicated by a
- // successful 3-way handshake).
- <-clientCH
- if err := nudDisp.expectEvent(eventInfo{
- eventType: entryChanged,
- nicID: utils.Host1NICID,
- entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
- }); err != nil {
- t.Fatalf("error waiting for delay NUD event: %s", err)
- }
- <-listenerCH
- if err := nudDisp.expectEvent(eventInfo{
- eventType: entryChanged,
- nicID: utils.Host1NICID,
- entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
- }); err != nil {
- t.Fatalf("error waiting for reachable NUD event: %s", err)
- }
-
- peerEP, peerWQ, err := listenerEP.Accept(nil)
- if err != nil {
- t.Fatalf("listenerEP.Accept(): %s", err)
- }
- defer peerEP.Close()
- peerWE, peerCH := waiter.NewChannelEntry(nil)
- peerWQ.EventRegister(&peerWE, waiter.ReadableEvents)
-
- // Wait for the neighbor to be stale again then send data to the remote.
- //
- // On successful transmission, the neighbor should become reachable
- // without probing the neighbor as a TCP ACK would be received which is an
- // indication of the neighbor being reachable.
- clock.Advance(maxReachableTime)
- if err := nudDisp.expectEvent(eventInfo{
- eventType: entryChanged,
- nicID: utils.Host1NICID,
- entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
- }); err != nil {
- t.Fatalf("error waiting for stale NUD event: %s", err)
- }
- {
- var r bytes.Reader
- r.Reset([]byte{0})
- var wOpts tcpip.WriteOptions
- if _, err := clientEP.Write(&r, wOpts); err != nil {
- t.Errorf("clientEP.Write(_, %#v): %s", wOpts, err)
- }
- }
- // Heads up, there is a race here.
- //
- // Incoming TCP segments are handled in
- // tcp.(*endpoint).handleSegmentLocked:
- //
- // - tcp.(*endpoint).rcv.handleRcvdSegment puts the segment on the
- // segment queue and notifies waiting readers (such as this channel)
- //
- // - tcp.(*endpoint).snd.handleRcvdSegment sends an ACK for the segment
- // and notifies the NUD machinery that the peer is reachable
- //
- // Thus we must permit a delay between the readable signal and the
- // expected NUD event.
- //
- // At the time of writing, this race is reliably hit with gotsan.
- <-peerCH
- for len(nudDisp.c) == 0 {
- runtime.Gosched()
- }
- if err := nudDisp.expectEvent(eventInfo{
- eventType: entryChanged,
- nicID: utils.Host1NICID,
- entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
- }); err != nil {
- t.Fatalf("error waiting for delay NUD event: %s", err)
- }
- if test.isHost1Listener {
- // If host1 is not the client, host1 does not send any data so TCP
- // has no way to know it is making forward progress. Because of this,
- // TCP should not mark the route reachable and NUD should go through the
- // probe state.
- clock.Advance(nudConfigs.DelayFirstProbeTime)
- if err := nudDisp.expectEvent(eventInfo{
- eventType: entryChanged,
- nicID: utils.Host1NICID,
- entry: stack.NeighborEntry{State: stack.Probe, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
- }); err != nil {
- t.Fatalf("error waiting for probe NUD event: %s", err)
- }
- }
- {
- var r bytes.Reader
- r.Reset([]byte{0})
- var wOpts tcpip.WriteOptions
- if _, err := peerEP.Write(&r, wOpts); err != nil {
- t.Errorf("peerEP.Write(_, %#v): %s", wOpts, err)
- }
- }
- <-clientCH
- if err := nudDisp.expectEvent(eventInfo{
- eventType: entryChanged,
- nicID: utils.Host1NICID,
- entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2},
- }); err != nil {
- t.Fatalf("error waiting for reachable NUD event: %s", err)
- }
- })
- }
-}
-
-func TestDAD(t *testing.T) {
- dadConfigs := stack.DADConfigurations{
- DupAddrDetectTransmits: 1,
- RetransmitTimer: time.Second,
- }
-
- tests := []struct {
- name string
- netProto tcpip.NetworkProtocolNumber
- dadNetProto tcpip.NetworkProtocolNumber
- remoteAddr tcpip.Address
- expectedResult stack.DADResult
- }{
- {
- name: "IPv4 own address",
- netProto: ipv4.ProtocolNumber,
- dadNetProto: arp.ProtocolNumber,
- remoteAddr: utils.Ipv4Addr1.AddressWithPrefix.Address,
- expectedResult: &stack.DADSucceeded{},
- },
- {
- name: "IPv6 own address",
- netProto: ipv6.ProtocolNumber,
- dadNetProto: ipv6.ProtocolNumber,
- remoteAddr: utils.Ipv6Addr1.AddressWithPrefix.Address,
- expectedResult: &stack.DADSucceeded{},
- },
- {
- name: "IPv4 duplicate address",
- netProto: ipv4.ProtocolNumber,
- dadNetProto: arp.ProtocolNumber,
- remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
- expectedResult: &stack.DADDupAddrDetected{HolderLinkAddress: utils.LinkAddr2},
- },
- {
- name: "IPv6 duplicate address",
- netProto: ipv6.ProtocolNumber,
- dadNetProto: ipv6.ProtocolNumber,
- remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
- expectedResult: &stack.DADDupAddrDetected{HolderLinkAddress: utils.LinkAddr2},
- },
- {
- name: "IPv4 no duplicate address",
- netProto: ipv4.ProtocolNumber,
- dadNetProto: arp.ProtocolNumber,
- remoteAddr: utils.Ipv4Addr3.AddressWithPrefix.Address,
- expectedResult: &stack.DADSucceeded{},
- },
- {
- name: "IPv6 no duplicate address",
- netProto: ipv6.ProtocolNumber,
- dadNetProto: ipv6.ProtocolNumber,
- remoteAddr: utils.Ipv6Addr3.AddressWithPrefix.Address,
- expectedResult: &stack.DADSucceeded{},
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- clock := faketime.NewManualClock()
- stackOpts := stack.Options{
- Clock: clock,
- NetworkProtocols: []stack.NetworkProtocolFactory{
- arp.NewProtocol,
- ipv4.NewProtocol,
- ipv6.NewProtocol,
- },
- }
-
- host1Stack, _ := setupStack(t, stackOpts, utils.Host1NICID, utils.Host2NICID)
-
- // DAD should be disabled by default.
- if res, err := host1Stack.CheckDuplicateAddress(utils.Host1NICID, test.netProto, test.remoteAddr, func(r stack.DADResult) {
- t.Errorf("unexpectedly called DAD completion handler when DAD was supposed to be disabled")
- }); err != nil {
- t.Fatalf("host1Stack.CheckDuplicateAddress(%d, %d, %s, _): %s", utils.Host1NICID, test.netProto, test.remoteAddr, err)
- } else if res != stack.DADDisabled {
- t.Errorf("got host1Stack.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", utils.Host1NICID, test.netProto, test.remoteAddr, res, stack.DADDisabled)
- }
-
- // Enable DAD then attempt to check if an address is duplicated.
- netEP, err := host1Stack.GetNetworkEndpoint(utils.Host1NICID, test.dadNetProto)
- if err != nil {
- t.Fatalf("host1Stack.GetNetworkEndpoint(%d, %d): %s", utils.Host1NICID, test.dadNetProto, err)
- }
- dad, ok := netEP.(stack.DuplicateAddressDetector)
- if !ok {
- t.Fatalf("expected %T to implement stack.DuplicateAddressDetector", netEP)
- }
- dad.SetDADConfigurations(dadConfigs)
- ch := make(chan stack.DADResult, 3)
- if res, err := host1Stack.CheckDuplicateAddress(utils.Host1NICID, test.netProto, test.remoteAddr, func(r stack.DADResult) {
- ch <- r
- }); err != nil {
- t.Fatalf("host1Stack.CheckDuplicateAddress(%d, %d, %s, _): %s", utils.Host1NICID, test.netProto, test.remoteAddr, err)
- } else if res != stack.DADStarting {
- t.Errorf("got host1Stack.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", utils.Host1NICID, test.netProto, test.remoteAddr, res, stack.DADStarting)
- }
-
- expectResults := 1
- if _, ok := test.expectedResult.(*stack.DADSucceeded); ok {
- const delta = time.Nanosecond
- clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer - delta)
- select {
- case r := <-ch:
- t.Fatalf("unexpectedly got DAD result before the DAD timeout; r = %#v", r)
- default:
- }
-
- // If we expect the resolve to succeed try requesting DAD again on the
- // same address. The handler for the new request should be called once
- // the original DAD request completes.
- expectResults = 2
- if res, err := host1Stack.CheckDuplicateAddress(utils.Host1NICID, test.netProto, test.remoteAddr, func(r stack.DADResult) {
- ch <- r
- }); err != nil {
- t.Fatalf("host1Stack.CheckDuplicateAddress(%d, %d, %s, _): %s", utils.Host1NICID, test.netProto, test.remoteAddr, err)
- } else if res != stack.DADAlreadyRunning {
- t.Errorf("got host1Stack.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", utils.Host1NICID, test.netProto, test.remoteAddr, res, stack.DADAlreadyRunning)
- }
-
- clock.Advance(delta)
- }
-
- for i := 0; i < expectResults; i++ {
- if diff := cmp.Diff(test.expectedResult, <-ch); diff != "" {
- t.Errorf("(i=%d) DAD result mismatch (-want +got):\n%s", i, diff)
- }
- }
-
- // Should have no more results.
- select {
- case r := <-ch:
- t.Errorf("unexpectedly got an extra DAD result; r = %#v", r)
- default:
- }
- })
- }
-}
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
deleted file mode 100644
index f33223e79..000000000
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ /dev/null
@@ -1,782 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package loopback_test
-
-import (
- "bytes"
- "testing"
- "time"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/tests/utils"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil)
-
-type ndpDispatcher struct{}
-
-func (*ndpDispatcher) OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) {
-}
-
-func (*ndpDispatcher) OnOffLinkRouteUpdated(tcpip.NICID, tcpip.Subnet, tcpip.Address, header.NDPRoutePreference) {
-}
-
-func (*ndpDispatcher) OnOffLinkRouteInvalidated(tcpip.NICID, tcpip.Subnet, tcpip.Address) {}
-
-func (*ndpDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) {
-}
-
-func (*ndpDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) {}
-
-func (*ndpDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) {
-}
-
-func (*ndpDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) {}
-
-func (*ndpDispatcher) OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) {}
-
-func (*ndpDispatcher) OnRecursiveDNSServerOption(tcpip.NICID, []tcpip.Address, time.Duration) {}
-
-func (*ndpDispatcher) OnDNSSearchListOption(tcpip.NICID, []string, time.Duration) {}
-
-func (*ndpDispatcher) OnDHCPv6Configuration(tcpip.NICID, ipv6.DHCPv6ConfigurationFromNDPRA) {}
-
-// TestInitialLoopbackAddresses tests that the loopback interface does not
-// auto-generate a link-local address when it is brought up.
-func TestInitialLoopbackAddresses(t *testing.T) {
- const nicID = 1
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPDisp: &ndpDispatcher{},
- AutoGenLinkLocal: true,
- OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{
- NICNameFromID: func(nicID tcpip.NICID, nicName string) string {
- t.Fatalf("should not attempt to get name for NIC with ID = %d; nicName = %s", nicID, nicName)
- return ""
- },
- },
- })},
- })
-
- if err := s.CreateNIC(nicID, loopback.New()); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
-
- nicsInfo := s.NICInfo()
- if nicInfo, ok := nicsInfo[nicID]; !ok {
- t.Fatalf("did not find NIC with ID = %d in s.NICInfo() = %#v", nicID, nicsInfo)
- } else if got := len(nicInfo.ProtocolAddresses); got != 0 {
- t.Fatalf("got len(nicInfo.ProtocolAddresses) = %d, want = 0; nicInfo.ProtocolAddresses = %#v", got, nicInfo.ProtocolAddresses)
- }
-}
-
-// TestLoopbackAcceptAllInSubnetUDP tests that a loopback interface considers
-// itself bound to all addresses in the subnet of an assigned address and UDP
-// traffic is sent/received correctly.
-func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) {
- const (
- nicID = 1
- localPort = 80
- )
-
- data := []byte{1, 2, 3, 4}
-
- ipv4ProtocolAddress := tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: utils.Ipv4Addr,
- }
- ipv4Bytes := []byte(ipv4ProtocolAddress.AddressWithPrefix.Address)
- ipv4Bytes[len(ipv4Bytes)-1]++
- otherIPv4Address := tcpip.Address(ipv4Bytes)
-
- ipv6ProtocolAddress := tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: utils.Ipv6Addr,
- }
- ipv6Bytes := []byte(utils.Ipv6Addr.Address)
- ipv6Bytes[len(ipv6Bytes)-1]++
- otherIPv6Address := tcpip.Address(ipv6Bytes)
-
- tests := []struct {
- name string
- addAddress tcpip.ProtocolAddress
- bindAddr tcpip.Address
- dstAddr tcpip.Address
- expectRx bool
- }{
- {
- name: "IPv4 bind to wildcard and send to assigned address",
- addAddress: ipv4ProtocolAddress,
- dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address,
- expectRx: true,
- },
- {
- name: "IPv4 bind to wildcard and send to other subnet-local address",
- addAddress: ipv4ProtocolAddress,
- dstAddr: otherIPv4Address,
- expectRx: true,
- },
- {
- name: "IPv4 bind to wildcard send to other address",
- addAddress: ipv4ProtocolAddress,
- dstAddr: utils.RemoteIPv4Addr,
- expectRx: false,
- },
- {
- name: "IPv4 bind to other subnet-local address and send to assigned address",
- addAddress: ipv4ProtocolAddress,
- bindAddr: otherIPv4Address,
- dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address,
- expectRx: false,
- },
- {
- name: "IPv4 bind and send to other subnet-local address",
- addAddress: ipv4ProtocolAddress,
- bindAddr: otherIPv4Address,
- dstAddr: otherIPv4Address,
- expectRx: true,
- },
- {
- name: "IPv4 bind to assigned address and send to other subnet-local address",
- addAddress: ipv4ProtocolAddress,
- bindAddr: ipv4ProtocolAddress.AddressWithPrefix.Address,
- dstAddr: otherIPv4Address,
- expectRx: false,
- },
-
- {
- name: "IPv6 bind and send to assigned address",
- addAddress: ipv6ProtocolAddress,
- bindAddr: utils.Ipv6Addr.Address,
- dstAddr: utils.Ipv6Addr.Address,
- expectRx: true,
- },
- {
- name: "IPv6 bind to wildcard and send to other subnet-local address",
- addAddress: ipv6ProtocolAddress,
- dstAddr: otherIPv6Address,
- expectRx: false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- })
- if err := s.CreateNIC(nicID, loopback.New()); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
- if err := s.AddProtocolAddress(nicID, test.addAddress, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, test.addAddress, err)
- }
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv4EmptySubnet,
- NIC: nicID,
- },
- {
- Destination: header.IPv6EmptySubnet,
- NIC: nicID,
- },
- })
-
- var wq waiter.Queue
- rep, err := s.NewEndpoint(udp.ProtocolNumber, test.addAddress.Protocol, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err)
- }
- defer rep.Close()
-
- bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: localPort}
- if err := rep.Bind(bindAddr); err != nil {
- t.Fatalf("rep.Bind(%+v): %s", bindAddr, err)
- }
-
- sep, err := s.NewEndpoint(udp.ProtocolNumber, test.addAddress.Protocol, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err)
- }
- defer sep.Close()
-
- wopts := tcpip.WriteOptions{
- To: &tcpip.FullAddress{
- Addr: test.dstAddr,
- Port: localPort,
- },
- }
- var r bytes.Reader
- r.Reset(data)
- n, err := sep.Write(&r, wopts)
- if err != nil {
- t.Fatalf("sep.Write(_, _): %s", err)
- }
- if want := int64(len(data)); n != want {
- t.Fatalf("got sep.Write(_, _) = (%d, nil), want = (%d, nil)", n, want)
- }
-
- var buf bytes.Buffer
- opts := tcpip.ReadOptions{NeedRemoteAddr: true}
- if res, err := rep.Read(&buf, opts); test.expectRx {
- if err != nil {
- t.Fatalf("rep.Read(_, %#v): %s", opts, err)
- }
- if diff := cmp.Diff(tcpip.ReadResult{
- Count: buf.Len(),
- Total: buf.Len(),
- RemoteAddr: tcpip.FullAddress{
- Addr: test.addAddress.AddressWithPrefix.Address,
- },
- }, res,
- checker.IgnoreCmpPath("ControlMessages", "RemoteAddr.NIC", "RemoteAddr.Port"),
- ); diff != "" {
- t.Errorf("rep.Read: unexpected result (-want +got):\n%s", diff)
- }
- if diff := cmp.Diff(data, buf.Bytes()); diff != "" {
- t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
- }
- } else if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got rep.Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), &tcpip.ErrWouldBlock{})
- }
- })
- }
-}
-
-// TestLoopbackSubnetLifetimeBoundToAddr tests that the lifetime of an address
-// in a loopback interface's associated subnet is bound to the permanently bound
-// address.
-func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) {
- const nicID = 1
-
- protoAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: utils.Ipv4Addr,
- }
- addrBytes := []byte(utils.Ipv4Addr.Address)
- addrBytes[len(addrBytes)-1]++
- otherAddr := tcpip.Address(addrBytes)
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- })
- if err := s.CreateNIC(nicID, loopback.New()); err != nil {
- t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
- }
- if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
- }
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv4EmptySubnet,
- NIC: nicID,
- },
- })
-
- r, err := s.FindRoute(nicID, otherAddr, utils.RemoteIPv4Addr, ipv4.ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, otherAddr, utils.RemoteIPv4Addr, ipv4.ProtocolNumber, err)
- }
- defer r.Release()
-
- params := stack.NetworkHeaderParams{
- Protocol: 111,
- TTL: 64,
- TOS: stack.DefaultTOS,
- }
- data := buffer.View([]byte{1, 2, 3, 4})
- if err := r.WritePacket(params, stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength()),
- Data: data.ToVectorisedView(),
- })); err != nil {
- t.Fatalf("r.WritePacket(%#v, _): %s", params, err)
- }
-
- // Removing the address should make the endpoint invalid.
- if err := s.RemoveAddress(nicID, protoAddr.AddressWithPrefix.Address); err != nil {
- t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, protoAddr.AddressWithPrefix.Address, err)
- }
- {
- err := r.WritePacket(params, stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength()),
- Data: data.ToVectorisedView(),
- }))
- if _, ok := err.(*tcpip.ErrInvalidEndpointState); !ok {
- t.Fatalf("got r.WritePacket(%#v, _) = %s, want = %s", params, err, &tcpip.ErrInvalidEndpointState{})
- }
- }
-}
-
-// TestLoopbackAcceptAllInSubnetTCP tests that a loopback interface considers
-// itself bound to all addresses in the subnet of an assigned address and TCP
-// traffic is sent/received correctly.
-func TestLoopbackAcceptAllInSubnetTCP(t *testing.T) {
- const (
- nicID = 1
- localPort = 80
- )
-
- ipv4ProtocolAddress := tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: utils.Ipv4Addr,
- }
- ipv4ProtocolAddress.AddressWithPrefix.PrefixLen = 8
- ipv4Bytes := []byte(ipv4ProtocolAddress.AddressWithPrefix.Address)
- ipv4Bytes[len(ipv4Bytes)-1]++
- otherIPv4Address := tcpip.Address(ipv4Bytes)
-
- ipv6ProtocolAddress := tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: utils.Ipv6Addr,
- }
- ipv6Bytes := []byte(utils.Ipv6Addr.Address)
- ipv6Bytes[len(ipv6Bytes)-1]++
- otherIPv6Address := tcpip.Address(ipv6Bytes)
-
- tests := []struct {
- name string
- addAddress tcpip.ProtocolAddress
- bindAddr tcpip.Address
- dstAddr tcpip.Address
- expectAccept bool
- }{
- {
- name: "IPv4 bind to wildcard and send to assigned address",
- addAddress: ipv4ProtocolAddress,
- dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address,
- expectAccept: true,
- },
- {
- name: "IPv4 bind to wildcard and send to other subnet-local address",
- addAddress: ipv4ProtocolAddress,
- dstAddr: otherIPv4Address,
- expectAccept: true,
- },
- {
- name: "IPv4 bind to wildcard send to other address",
- addAddress: ipv4ProtocolAddress,
- dstAddr: utils.RemoteIPv4Addr,
- expectAccept: false,
- },
- {
- name: "IPv4 bind to other subnet-local address and send to assigned address",
- addAddress: ipv4ProtocolAddress,
- bindAddr: otherIPv4Address,
- dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address,
- expectAccept: false,
- },
- {
- name: "IPv4 bind and send to other subnet-local address",
- addAddress: ipv4ProtocolAddress,
- bindAddr: otherIPv4Address,
- dstAddr: otherIPv4Address,
- expectAccept: true,
- },
- {
- name: "IPv4 bind to assigned address and send to other subnet-local address",
- addAddress: ipv4ProtocolAddress,
- bindAddr: ipv4ProtocolAddress.AddressWithPrefix.Address,
- dstAddr: otherIPv4Address,
- expectAccept: false,
- },
-
- {
- name: "IPv6 bind and send to assigned address",
- addAddress: ipv6ProtocolAddress,
- bindAddr: utils.Ipv6Addr.Address,
- dstAddr: utils.Ipv6Addr.Address,
- expectAccept: true,
- },
- {
- name: "IPv6 bind to wildcard and send to other subnet-local address",
- addAddress: ipv6ProtocolAddress,
- dstAddr: otherIPv6Address,
- expectAccept: false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
- })
- if err := s.CreateNIC(nicID, loopback.New()); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
- if err := s.AddProtocolAddress(nicID, test.addAddress, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, test.addAddress, err)
- }
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv4EmptySubnet,
- NIC: nicID,
- },
- {
- Destination: header.IPv6EmptySubnet,
- NIC: nicID,
- },
- })
-
- var wq waiter.Queue
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
- listeningEndpoint, err := s.NewEndpoint(tcp.ProtocolNumber, test.addAddress.Protocol, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err)
- }
- defer listeningEndpoint.Close()
-
- bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: localPort}
- if err := listeningEndpoint.Bind(bindAddr); err != nil {
- t.Fatalf("listeningEndpoint.Bind(%#v): %s", bindAddr, err)
- }
-
- if err := listeningEndpoint.Listen(1); err != nil {
- t.Fatalf("listeningEndpoint.Listen(1): %s", err)
- }
-
- connectingEndpoint, err := s.NewEndpoint(tcp.ProtocolNumber, test.addAddress.Protocol, &wq)
- if err != nil {
- t.Fatalf("s.NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err)
- }
- defer connectingEndpoint.Close()
-
- connectAddr := tcpip.FullAddress{
- Addr: test.dstAddr,
- Port: localPort,
- }
- {
- err := connectingEndpoint.Connect(connectAddr)
- if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
- t.Fatalf("connectingEndpoint.Connect(%#v): %s", connectAddr, err)
- }
- }
-
- if !test.expectAccept {
- _, _, err := listeningEndpoint.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got listeningEndpoint.Accept(nil) = %s, want = %s", err, &tcpip.ErrWouldBlock{})
- }
- return
- }
-
- // Wait for the listening endpoint to be "readable". That is, wait for a
- // new connection.
- <-ch
- var addr tcpip.FullAddress
- if _, _, err := listeningEndpoint.Accept(&addr); err != nil {
- t.Fatalf("listeningEndpoint.Accept(nil): %s", err)
- }
- if addr.Addr != test.addAddress.AddressWithPrefix.Address {
- t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.addAddress.AddressWithPrefix.Address)
- }
- })
- }
-}
-
-func TestExternalLoopbackTraffic(t *testing.T) {
- const (
- nicID1 = 1
- nicID2 = 2
-
- numPackets = 1
- ttl = 64
- )
- ipv4Loopback := testutil.MustParse4("127.0.0.1")
-
- loopbackSourcedICMPv4 := func(e *channel.Endpoint) {
- utils.RxICMPv4EchoRequest(e, ipv4Loopback, utils.Ipv4Addr.Address, ttl)
- }
-
- loopbackSourcedICMPv6 := func(e *channel.Endpoint) {
- utils.RxICMPv6EchoRequest(e, header.IPv6Loopback, utils.Ipv6Addr.Address, ttl)
- }
-
- loopbackDestinedICMPv4 := func(e *channel.Endpoint) {
- utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, ipv4Loopback, ttl)
- }
-
- loopbackDestinedICMPv6 := func(e *channel.Endpoint) {
- utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, header.IPv6Loopback, ttl)
- }
-
- invalidSrcAddrStat := func(s tcpip.IPStats) *tcpip.StatCounter {
- return s.InvalidSourceAddressesReceived
- }
-
- invalidDestAddrStat := func(s tcpip.IPStats) *tcpip.StatCounter {
- return s.InvalidDestinationAddressesReceived
- }
-
- tests := []struct {
- name string
- allowExternalLoopback bool
- forwarding bool
- rxICMP func(*channel.Endpoint)
- invalidAddressStat func(tcpip.IPStats) *tcpip.StatCounter
- shouldAccept bool
- }{
- {
- name: "IPv4 external loopback sourced traffic without forwarding and drop external loopback disabled",
- allowExternalLoopback: true,
- forwarding: false,
- rxICMP: loopbackSourcedICMPv4,
- invalidAddressStat: invalidSrcAddrStat,
- shouldAccept: true,
- },
- {
- name: "IPv4 external loopback sourced traffic without forwarding and drop external loopback enabled",
- allowExternalLoopback: false,
- forwarding: false,
- rxICMP: loopbackSourcedICMPv4,
- invalidAddressStat: invalidSrcAddrStat,
- shouldAccept: false,
- },
- {
- name: "IPv4 external loopback sourced traffic with forwarding and drop external loopback disabled",
- allowExternalLoopback: true,
- forwarding: true,
- rxICMP: loopbackSourcedICMPv4,
- invalidAddressStat: invalidSrcAddrStat,
- shouldAccept: true,
- },
- {
- name: "IPv4 external loopback sourced traffic with forwarding and drop external loopback enabled",
- allowExternalLoopback: false,
- forwarding: true,
- rxICMP: loopbackSourcedICMPv4,
- invalidAddressStat: invalidSrcAddrStat,
- shouldAccept: false,
- },
- {
- name: "IPv4 external loopback destined traffic without forwarding and drop external loopback disabled",
- allowExternalLoopback: true,
- forwarding: false,
- rxICMP: loopbackDestinedICMPv4,
- invalidAddressStat: invalidDestAddrStat,
- shouldAccept: false,
- },
- {
- name: "IPv4 external loopback destined traffic without forwarding and drop external loopback enabled",
- allowExternalLoopback: false,
- forwarding: false,
- rxICMP: loopbackDestinedICMPv4,
- invalidAddressStat: invalidDestAddrStat,
- shouldAccept: false,
- },
- {
- name: "IPv4 external loopback destined traffic with forwarding and drop external loopback disabled",
- allowExternalLoopback: true,
- forwarding: true,
- rxICMP: loopbackDestinedICMPv4,
- invalidAddressStat: invalidDestAddrStat,
- shouldAccept: true,
- },
- {
- name: "IPv4 external loopback destined traffic with forwarding and drop external loopback enabled",
- allowExternalLoopback: false,
- forwarding: true,
- rxICMP: loopbackDestinedICMPv4,
- invalidAddressStat: invalidDestAddrStat,
- shouldAccept: false,
- },
-
- {
- name: "IPv6 external loopback sourced traffic without forwarding and drop external loopback disabled",
- allowExternalLoopback: true,
- forwarding: false,
- rxICMP: loopbackSourcedICMPv6,
- invalidAddressStat: invalidSrcAddrStat,
- shouldAccept: true,
- },
- {
- name: "IPv6 external loopback sourced traffic without forwarding and drop external loopback enabled",
- allowExternalLoopback: false,
- forwarding: false,
- rxICMP: loopbackSourcedICMPv6,
- invalidAddressStat: invalidSrcAddrStat,
- shouldAccept: false,
- },
- {
- name: "IPv6 external loopback sourced traffic with forwarding and drop external loopback disabled",
- allowExternalLoopback: true,
- forwarding: true,
- rxICMP: loopbackSourcedICMPv6,
- invalidAddressStat: invalidSrcAddrStat,
- shouldAccept: true,
- },
- {
- name: "IPv6 external loopback sourced traffic with forwarding and drop external loopback enabled",
- allowExternalLoopback: false,
- forwarding: true,
- rxICMP: loopbackSourcedICMPv6,
- invalidAddressStat: invalidSrcAddrStat,
- shouldAccept: false,
- },
- {
- name: "IPv6 external loopback destined traffic without forwarding and drop external loopback disabled",
- allowExternalLoopback: true,
- forwarding: false,
- rxICMP: loopbackDestinedICMPv6,
- invalidAddressStat: invalidDestAddrStat,
- shouldAccept: false,
- },
- {
- name: "IPv6 external loopback destined traffic without forwarding and drop external loopback enabled",
- allowExternalLoopback: false,
- forwarding: false,
- rxICMP: loopbackDestinedICMPv6,
- invalidAddressStat: invalidDestAddrStat,
- shouldAccept: false,
- },
- {
- name: "IPv6 external loopback destined traffic with forwarding and drop external loopback disabled",
- allowExternalLoopback: true,
- forwarding: true,
- rxICMP: loopbackDestinedICMPv6,
- invalidAddressStat: invalidDestAddrStat,
- shouldAccept: true,
- },
- {
- name: "IPv6 external loopback destined traffic with forwarding and drop external loopback enabled",
- allowExternalLoopback: false,
- forwarding: true,
- rxICMP: loopbackDestinedICMPv6,
- invalidAddressStat: invalidDestAddrStat,
- shouldAccept: false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{
- ipv4.NewProtocolWithOptions(ipv4.Options{
- AllowExternalLoopbackTraffic: test.allowExternalLoopback,
- }),
- ipv6.NewProtocolWithOptions(ipv6.Options{
- AllowExternalLoopbackTraffic: test.allowExternalLoopback,
- }),
- },
- TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6},
- })
- e := channel.New(1, header.IPv6MinimumMTU, "")
- if err := s.CreateNIC(nicID1, e); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
- }
- v4Addr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: utils.Ipv4Addr,
- }
- if err := s.AddProtocolAddress(nicID1, v4Addr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, v4Addr, err)
- }
- v6Addr := tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: utils.Ipv6Addr,
- }
- if err := s.AddProtocolAddress(nicID1, v6Addr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, v6Addr, err)
- }
-
- if err := s.CreateNIC(nicID2, loopback.New()); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
- }
- protocolAddrV4 := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: ipv4Loopback,
- PrefixLen: 8,
- },
- }
- if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err)
- }
- protocolAddrV6 := tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: header.IPv6Loopback.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err)
- }
-
- if test.forwarding {
- if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
- t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err)
- }
- if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil {
- t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err)
- }
- }
-
- s.SetRouteTable([]tcpip.Route{
- tcpip.Route{
- Destination: header.IPv4EmptySubnet,
- NIC: nicID1,
- },
- tcpip.Route{
- Destination: header.IPv6EmptySubnet,
- NIC: nicID1,
- },
- tcpip.Route{
- Destination: ipv4Loopback.WithPrefix().Subnet(),
- NIC: nicID2,
- },
- tcpip.Route{
- Destination: header.IPv6Loopback.WithPrefix().Subnet(),
- NIC: nicID2,
- },
- })
-
- stats := s.Stats().IP
- invalidAddressStat := test.invalidAddressStat(stats)
- deliveredPacketsStat := stats.PacketsDelivered
- if got := invalidAddressStat.Value(); got != 0 {
- t.Fatalf("got invalidAddressStat.Value() = %d, want = 0", got)
- }
- if got := deliveredPacketsStat.Value(); got != 0 {
- t.Fatalf("got deliveredPacketsStat.Value() = %d, want = 0", got)
- }
- test.rxICMP(e)
- var expectedInvalidPackets uint64
- if !test.shouldAccept {
- expectedInvalidPackets = numPackets
- }
- if got := invalidAddressStat.Value(); got != expectedInvalidPackets {
- t.Fatalf("got invalidAddressStat.Value() = %d, want = %d", got, expectedInvalidPackets)
- }
- if got, want := deliveredPacketsStat.Value(), numPackets-expectedInvalidPackets; got != want {
- t.Fatalf("got deliveredPacketsStat.Value() = %d, want = %d", got, want)
- }
- })
- }
-}
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
deleted file mode 100644
index 7753e7d6e..000000000
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ /dev/null
@@ -1,723 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package multicast_broadcast_test
-
-import (
- "bytes"
- "testing"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/tests/utils"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-const (
- defaultMTU = 1280
- ttl = 255
-)
-
-// TestPingMulticastBroadcast tests that responding to an Echo Request destined
-// to a multicast or broadcast address uses a unicast source address for the
-// reply.
-func TestPingMulticastBroadcast(t *testing.T) {
- const (
- nicID = 1
- ttl = 64
- )
-
- tests := []struct {
- name string
- protoNum tcpip.NetworkProtocolNumber
- rxICMP func(*channel.Endpoint, tcpip.Address, tcpip.Address, uint8)
- srcAddr tcpip.Address
- dstAddr tcpip.Address
- expectedSrc tcpip.Address
- }{
- {
- name: "IPv4 unicast",
- protoNum: header.IPv4ProtocolNumber,
- dstAddr: utils.Ipv4Addr.Address,
- srcAddr: utils.RemoteIPv4Addr,
- rxICMP: utils.RxICMPv4EchoRequest,
- expectedSrc: utils.Ipv4Addr.Address,
- },
- {
- name: "IPv4 directed broadcast",
- protoNum: header.IPv4ProtocolNumber,
- rxICMP: utils.RxICMPv4EchoRequest,
- srcAddr: utils.RemoteIPv4Addr,
- dstAddr: utils.Ipv4SubnetBcast,
- expectedSrc: utils.Ipv4Addr.Address,
- },
- {
- name: "IPv4 broadcast",
- protoNum: header.IPv4ProtocolNumber,
- rxICMP: utils.RxICMPv4EchoRequest,
- srcAddr: utils.RemoteIPv4Addr,
- dstAddr: header.IPv4Broadcast,
- expectedSrc: utils.Ipv4Addr.Address,
- },
- {
- name: "IPv4 all-systems multicast",
- protoNum: header.IPv4ProtocolNumber,
- rxICMP: utils.RxICMPv4EchoRequest,
- srcAddr: utils.RemoteIPv4Addr,
- dstAddr: header.IPv4AllSystems,
- expectedSrc: utils.Ipv4Addr.Address,
- },
- {
- name: "IPv6 unicast",
- protoNum: header.IPv6ProtocolNumber,
- rxICMP: utils.RxICMPv6EchoRequest,
- srcAddr: utils.RemoteIPv6Addr,
- dstAddr: utils.Ipv6Addr.Address,
- expectedSrc: utils.Ipv6Addr.Address,
- },
- {
- name: "IPv6 all-nodes multicast",
- protoNum: header.IPv6ProtocolNumber,
- rxICMP: utils.RxICMPv6EchoRequest,
- srcAddr: utils.RemoteIPv6Addr,
- dstAddr: header.IPv6AllNodesMulticastAddress,
- expectedSrc: utils.Ipv6Addr.Address,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6},
- })
- // We only expect a single packet in response to our ICMP Echo Request.
- e := channel.New(1, defaultMTU, "")
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
- ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: utils.Ipv4Addr}
- if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtoAddr, err)
- }
- ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: utils.Ipv6Addr}
- if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv6ProtoAddr, err)
- }
-
- // Default routes for IPv4 and IPv6 so ICMP can find a route to the remote
- // node when attempting to send the ICMP Echo Reply.
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv6EmptySubnet,
- NIC: nicID,
- },
- {
- Destination: header.IPv4EmptySubnet,
- NIC: nicID,
- },
- })
-
- test.rxICMP(e, test.srcAddr, test.dstAddr, ttl)
- pkt, ok := e.Read()
- if !ok {
- t.Fatal("expected ICMP response")
- }
-
- if pkt.Route.LocalAddress != test.expectedSrc {
- t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.expectedSrc)
- }
- // The destination of the response packet should be the source of the
- // original packet.
- if pkt.Route.RemoteAddress != test.srcAddr {
- t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.srcAddr)
- }
-
- src, dst := s.NetworkProtocolInstance(test.protoNum).ParseAddresses(stack.PayloadSince(pkt.Pkt.NetworkHeader()))
- if src != test.expectedSrc {
- t.Errorf("got pkt source = %s, want = %s", src, test.expectedSrc)
- }
- // The destination of the response packet should be the source of the
- // original packet.
- if dst != test.srcAddr {
- t.Errorf("got pkt destination = %s, want = %s", dst, test.srcAddr)
- }
- })
- }
-
-}
-
-func rxIPv4UDP(e *channel.Endpoint, src, dst tcpip.Address, data []byte) {
- payloadLen := header.UDPMinimumSize + len(data)
- totalLen := header.IPv4MinimumSize + payloadLen
- hdr := buffer.NewPrependable(totalLen)
- u := header.UDP(hdr.Prepend(payloadLen))
- u.Encode(&header.UDPFields{
- SrcPort: utils.RemotePort,
- DstPort: utils.LocalPort,
- Length: uint16(payloadLen),
- })
- copy(u.Payload(), data)
- sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(payloadLen))
- sum = header.Checksum(data, sum)
- u.SetChecksum(^u.CalculateChecksum(sum))
-
- ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
- ip.Encode(&header.IPv4Fields{
- TotalLength: uint16(totalLen),
- Protocol: uint8(udp.ProtocolNumber),
- TTL: ttl,
- SrcAddr: src,
- DstAddr: dst,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
-
- e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
-}
-
-func rxIPv6UDP(e *channel.Endpoint, src, dst tcpip.Address, data []byte) {
- payloadLen := header.UDPMinimumSize + len(data)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadLen)
- u := header.UDP(hdr.Prepend(payloadLen))
- u.Encode(&header.UDPFields{
- SrcPort: utils.RemotePort,
- DstPort: utils.LocalPort,
- Length: uint16(payloadLen),
- })
- copy(u.Payload(), data)
- sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(payloadLen))
- sum = header.Checksum(data, sum)
- u.SetChecksum(^u.CalculateChecksum(sum))
-
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLen),
- TransportProtocol: udp.ProtocolNumber,
- HopLimit: ttl,
- SrcAddr: src,
- DstAddr: dst,
- })
-
- e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
-}
-
-// TestIncomingMulticastAndBroadcast tests receiving a packet destined to some
-// multicast or broadcast address.
-func TestIncomingMulticastAndBroadcast(t *testing.T) {
- const nicID = 1
-
- data := []byte{1, 2, 3, 4}
-
- tests := []struct {
- name string
- proto tcpip.NetworkProtocolNumber
- remoteAddr tcpip.Address
- localAddr tcpip.AddressWithPrefix
- rxUDP func(*channel.Endpoint, tcpip.Address, tcpip.Address, []byte)
- bindAddr tcpip.Address
- dstAddr tcpip.Address
- expectRx bool
- }{
- {
- name: "IPv4 unicast binding to unicast",
- proto: header.IPv4ProtocolNumber,
- remoteAddr: utils.RemoteIPv4Addr,
- localAddr: utils.Ipv4Addr,
- rxUDP: rxIPv4UDP,
- bindAddr: utils.Ipv4Addr.Address,
- dstAddr: utils.Ipv4Addr.Address,
- expectRx: true,
- },
- {
- name: "IPv4 unicast binding to broadcast",
- proto: header.IPv4ProtocolNumber,
- remoteAddr: utils.RemoteIPv4Addr,
- localAddr: utils.Ipv4Addr,
- rxUDP: rxIPv4UDP,
- bindAddr: header.IPv4Broadcast,
- dstAddr: utils.Ipv4Addr.Address,
- expectRx: false,
- },
- {
- name: "IPv4 unicast binding to wildcard",
- proto: header.IPv4ProtocolNumber,
- remoteAddr: utils.RemoteIPv4Addr,
- localAddr: utils.Ipv4Addr,
- rxUDP: rxIPv4UDP,
- dstAddr: utils.Ipv4Addr.Address,
- expectRx: true,
- },
-
- {
- name: "IPv4 directed broadcast binding to subnet broadcast",
- proto: header.IPv4ProtocolNumber,
- remoteAddr: utils.RemoteIPv4Addr,
- localAddr: utils.Ipv4Addr,
- rxUDP: rxIPv4UDP,
- bindAddr: utils.Ipv4SubnetBcast,
- dstAddr: utils.Ipv4SubnetBcast,
- expectRx: true,
- },
- {
- name: "IPv4 directed broadcast binding to broadcast",
- proto: header.IPv4ProtocolNumber,
- remoteAddr: utils.RemoteIPv4Addr,
- localAddr: utils.Ipv4Addr,
- rxUDP: rxIPv4UDP,
- bindAddr: header.IPv4Broadcast,
- dstAddr: utils.Ipv4SubnetBcast,
- expectRx: false,
- },
- {
- name: "IPv4 directed broadcast binding to wildcard",
- proto: header.IPv4ProtocolNumber,
- remoteAddr: utils.RemoteIPv4Addr,
- localAddr: utils.Ipv4Addr,
- rxUDP: rxIPv4UDP,
- dstAddr: utils.Ipv4SubnetBcast,
- expectRx: true,
- },
-
- {
- name: "IPv4 broadcast binding to broadcast",
- proto: header.IPv4ProtocolNumber,
- remoteAddr: utils.RemoteIPv4Addr,
- localAddr: utils.Ipv4Addr,
- rxUDP: rxIPv4UDP,
- bindAddr: header.IPv4Broadcast,
- dstAddr: header.IPv4Broadcast,
- expectRx: true,
- },
- {
- name: "IPv4 broadcast binding to subnet broadcast",
- proto: header.IPv4ProtocolNumber,
- remoteAddr: utils.RemoteIPv4Addr,
- localAddr: utils.Ipv4Addr,
- rxUDP: rxIPv4UDP,
- bindAddr: utils.Ipv4SubnetBcast,
- dstAddr: header.IPv4Broadcast,
- expectRx: false,
- },
- {
- name: "IPv4 broadcast binding to wildcard",
- proto: header.IPv4ProtocolNumber,
- remoteAddr: utils.RemoteIPv4Addr,
- localAddr: utils.Ipv4Addr,
- rxUDP: rxIPv4UDP,
- dstAddr: utils.Ipv4SubnetBcast,
- expectRx: true,
- },
-
- {
- name: "IPv4 all-systems multicast binding to all-systems multicast",
- proto: header.IPv4ProtocolNumber,
- remoteAddr: utils.RemoteIPv4Addr,
- localAddr: utils.Ipv4Addr,
- rxUDP: rxIPv4UDP,
- bindAddr: header.IPv4AllSystems,
- dstAddr: header.IPv4AllSystems,
- expectRx: true,
- },
- {
- name: "IPv4 all-systems multicast binding to wildcard",
- proto: header.IPv4ProtocolNumber,
- remoteAddr: utils.RemoteIPv4Addr,
- localAddr: utils.Ipv4Addr,
- rxUDP: rxIPv4UDP,
- dstAddr: header.IPv4AllSystems,
- expectRx: true,
- },
- {
- name: "IPv4 all-systems multicast binding to unicast",
- proto: header.IPv4ProtocolNumber,
- remoteAddr: utils.RemoteIPv4Addr,
- localAddr: utils.Ipv4Addr,
- rxUDP: rxIPv4UDP,
- bindAddr: utils.Ipv4Addr.Address,
- dstAddr: header.IPv4AllSystems,
- expectRx: false,
- },
-
- // IPv6 has no notion of a broadcast.
- {
- name: "IPv6 unicast binding to wildcard",
- dstAddr: utils.Ipv6Addr.Address,
- proto: header.IPv6ProtocolNumber,
- remoteAddr: utils.RemoteIPv6Addr,
- localAddr: utils.Ipv6Addr,
- rxUDP: rxIPv6UDP,
- expectRx: true,
- },
- {
- name: "IPv6 broadcast-like address binding to wildcard",
- dstAddr: utils.Ipv6SubnetBcast,
- proto: header.IPv6ProtocolNumber,
- remoteAddr: utils.RemoteIPv6Addr,
- localAddr: utils.Ipv6Addr,
- rxUDP: rxIPv6UDP,
- expectRx: false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- })
- e := channel.New(0, defaultMTU, "")
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
- protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr}
- if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
- }
-
- var wq waiter.Queue
- ep, err := s.NewEndpoint(udp.ProtocolNumber, test.proto, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.proto, err)
- }
- defer ep.Close()
-
- bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: utils.LocalPort}
- if err := ep.Bind(bindAddr); err != nil {
- t.Fatalf("ep.Bind(%#v): %s", bindAddr, err)
- }
-
- test.rxUDP(e, test.remoteAddr, test.dstAddr, data)
- var buf bytes.Buffer
- var opts tcpip.ReadOptions
- if res, err := ep.Read(&buf, opts); test.expectRx {
- if err != nil {
- t.Fatalf("ep.Read(_, %#v): %s", opts, err)
- }
- if diff := cmp.Diff(tcpip.ReadResult{
- Count: buf.Len(),
- Total: buf.Len(),
- }, res, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
- t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
- }
- if diff := cmp.Diff(data, buf.Bytes()); diff != "" {
- t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
- }
- } else if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), &tcpip.ErrWouldBlock{})
- }
- })
- }
-}
-
-// TestReuseAddrAndBroadcast makes sure broadcast packets are received by all
-// interested endpoints.
-func TestReuseAddrAndBroadcast(t *testing.T) {
- const (
- nicID = 1
- localPort = 9000
- )
- loopbackBroadcast := testutil.MustParse4("127.255.255.255")
-
- tests := []struct {
- name string
- broadcastAddr tcpip.Address
- }{
- {
- name: "Subnet directed broadcast",
- broadcastAddr: loopbackBroadcast,
- },
- {
- name: "IPv4 broadcast",
- broadcastAddr: header.IPv4Broadcast,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- })
- if err := s.CreateNIC(nicID, loopback.New()); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
- protoAddr := tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: "\x7f\x00\x00\x01",
- PrefixLen: 8,
- },
- }
- if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {
- // We use the empty subnet instead of just the loopback subnet so we
- // also have a route to the IPv4 Broadcast address.
- Destination: header.IPv4EmptySubnet,
- NIC: nicID,
- },
- })
-
- type endpointAndWaiter struct {
- ep tcpip.Endpoint
- ch chan struct{}
- }
- var eps []endpointAndWaiter
- // We create endpoints that bind to both the wildcard address and the
- // broadcast address to make sure both of these types of "broadcast
- // interested" endpoints receive broadcast packets.
- for _, bindWildcard := range []bool{false, true} {
- // Create multiple endpoints for each type of "broadcast interested"
- // endpoint so we can test that all endpoints receive the broadcast
- // packet.
- for i := 0; i < 2; i++ {
- var wq waiter.Queue
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("(eps[%d]) NewEndpoint(%d, %d, _): %s", len(eps), udp.ProtocolNumber, ipv4.ProtocolNumber, err)
- }
- defer ep.Close()
-
- ep.SocketOptions().SetReuseAddress(true)
- ep.SocketOptions().SetBroadcast(true)
-
- bindAddr := tcpip.FullAddress{Port: localPort}
- if bindWildcard {
- if err := ep.Bind(bindAddr); err != nil {
- t.Fatalf("eps[%d].Bind(%#v): %s", len(eps), bindAddr, err)
- }
- } else {
- bindAddr.Addr = test.broadcastAddr
- if err := ep.Bind(bindAddr); err != nil {
- t.Fatalf("eps[%d].Bind(%#v): %s", len(eps), bindAddr, err)
- }
- }
-
- eps = append(eps, endpointAndWaiter{ep: ep, ch: ch})
- }
- }
-
- for i, wep := range eps {
- writeOpts := tcpip.WriteOptions{
- To: &tcpip.FullAddress{
- Addr: test.broadcastAddr,
- Port: localPort,
- },
- }
- data := []byte{byte(i), 2, 3, 4}
- var r bytes.Reader
- r.Reset(data)
- if n, err := wep.ep.Write(&r, writeOpts); err != nil {
- t.Fatalf("eps[%d].Write(_, _): %s", i, err)
- } else if want := int64(len(data)); n != want {
- t.Fatalf("got eps[%d].Write(_, _) = (%d, nil), want = (%d, nil)", i, n, want)
- }
-
- for j, rep := range eps {
- // Wait for the endpoint to become readable.
- <-rep.ch
-
- var buf bytes.Buffer
- result, err := rep.ep.Read(&buf, tcpip.ReadOptions{})
- if err != nil {
- t.Errorf("(eps[%d] write) eps[%d].Read: %s", i, j, err)
- continue
- }
- if diff := cmp.Diff(tcpip.ReadResult{
- Count: buf.Len(),
- Total: buf.Len(),
- }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
- t.Errorf("(eps[%d] write) eps[%d].Read: unexpected result (-want +got):\n%s", i, j, diff)
- }
- if diff := cmp.Diff([]byte(data), buf.Bytes()); diff != "" {
- t.Errorf("(eps[%d] write) got UDP payload from eps[%d] mismatch (-want +got):\n%s", i, j, diff)
- }
- }
- }
- })
- }
-}
-
-func TestUDPAddRemoveMembershipSocketOption(t *testing.T) {
- const (
- nicID = 1
- )
-
- data := []byte{1, 2, 3, 4}
-
- tests := []struct {
- name string
- proto tcpip.NetworkProtocolNumber
- remoteAddr tcpip.Address
- localAddr tcpip.AddressWithPrefix
- rxUDP func(*channel.Endpoint, tcpip.Address, tcpip.Address, []byte)
- multicastAddr tcpip.Address
- }{
- {
- name: "IPv4 unicast binding to unicast",
- multicastAddr: "\xe0\x01\x02\x03",
- proto: header.IPv4ProtocolNumber,
- remoteAddr: utils.RemoteIPv4Addr,
- localAddr: utils.Ipv4Addr,
- rxUDP: rxIPv4UDP,
- },
- {
- name: "IPv6 broadcast-like address binding to wildcard",
- multicastAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04",
- proto: header.IPv6ProtocolNumber,
- remoteAddr: utils.RemoteIPv6Addr,
- localAddr: utils.Ipv6Addr,
- rxUDP: rxIPv6UDP,
- },
- }
-
- subTests := []struct {
- name string
- specifyNICID bool
- specifyNICAddr bool
- }{
- {
- name: "Specify NIC ID and NIC address",
- specifyNICID: true,
- specifyNICAddr: true,
- },
- {
- name: "Don't specify NIC ID or NIC address",
- specifyNICID: false,
- specifyNICAddr: false,
- },
- {
- name: "Specify NIC ID but don't specify NIC address",
- specifyNICID: true,
- specifyNICAddr: false,
- },
- {
- name: "Don't specify NIC ID but specify NIC address",
- specifyNICID: false,
- specifyNICAddr: true,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- for _, subTest := range subTests {
- t.Run(subTest.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- })
- e := channel.New(0, defaultMTU, "")
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
- protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr}
- if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
- }
-
- // Set the route table so that UDP can find a NIC that is
- // routable to the multicast address when the NIC isn't specified.
- if !subTest.specifyNICID && !subTest.specifyNICAddr {
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv6EmptySubnet,
- NIC: nicID,
- },
- {
- Destination: header.IPv4EmptySubnet,
- NIC: nicID,
- },
- })
- }
-
- var wq waiter.Queue
- ep, err := s.NewEndpoint(udp.ProtocolNumber, test.proto, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.proto, err)
- }
- defer ep.Close()
-
- bindAddr := tcpip.FullAddress{Port: utils.LocalPort}
- if err := ep.Bind(bindAddr); err != nil {
- t.Fatalf("ep.Bind(%#v): %s", bindAddr, err)
- }
-
- memOpt := tcpip.MembershipOption{MulticastAddr: test.multicastAddr}
- if subTest.specifyNICID {
- memOpt.NIC = nicID
- }
- if subTest.specifyNICAddr {
- memOpt.InterfaceAddr = test.localAddr.Address
- }
-
- // We should receive UDP packets to the group once we join the
- // multicast group.
- addOpt := tcpip.AddMembershipOption(memOpt)
- if err := ep.SetSockOpt(&addOpt); err != nil {
- t.Fatalf("ep.SetSockOpt(&%#v): %s", addOpt, err)
- }
- test.rxUDP(e, test.remoteAddr, test.multicastAddr, data)
- var buf bytes.Buffer
- result, err := ep.Read(&buf, tcpip.ReadOptions{})
- if err != nil {
- t.Fatalf("ep.Read: %s", err)
- } else {
- if diff := cmp.Diff(tcpip.ReadResult{
- Count: buf.Len(),
- Total: buf.Len(),
- }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
- t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
- }
- if diff := cmp.Diff(data, buf.Bytes()); diff != "" {
- t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
- }
- }
-
- // We should not receive UDP packets to the group once we leave
- // the multicast group.
- removeOpt := tcpip.RemoveMembershipOption(memOpt)
- if err := ep.SetSockOpt(&removeOpt); err != nil {
- t.Fatalf("ep.SetSockOpt(&%#v): %s", removeOpt, err)
- }
- {
- _, err := ep.Read(&buf, tcpip.ReadOptions{})
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got ep.Read = (_, %s), want = (_, %s)", err, &tcpip.ErrWouldBlock{})
- }
- }
- })
- }
- })
- }
-}
diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go
deleted file mode 100644
index 422eb8408..000000000
--- a/pkg/tcpip/tests/integration/route_test.go
+++ /dev/null
@@ -1,441 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package route_test
-
-import (
- "bytes"
- "fmt"
- "testing"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/tests/utils"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-// TestLocalPing tests pinging a remote that is local the stack.
-//
-// This tests that a local route is created and packets do not leave the stack.
-func TestLocalPing(t *testing.T) {
- const (
- nicID = 1
-
- // icmpDataOffset is the offset to the data in both ICMPv4 and ICMPv6 echo
- // request/reply packets.
- icmpDataOffset = 8
- )
- ipv4Loopback := tcpip.AddressWithPrefix{
- Address: testutil.MustParse4("127.0.0.1"),
- PrefixLen: 8,
- }
-
- channelEP := func() stack.LinkEndpoint { return channel.New(1, header.IPv6MinimumMTU, "") }
- channelEPCheck := func(t *testing.T, e stack.LinkEndpoint) {
- channelEP := e.(*channel.Endpoint)
- if n := channelEP.Drain(); n != 0 {
- t.Fatalf("got channelEP.Drain() = %d, want = 0", n)
- }
- }
-
- ipv4ICMPBuf := func(t *testing.T) buffer.View {
- data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
- hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data)))
- hdr.SetType(header.ICMPv4Echo)
- if n := copy(hdr.Payload(), data[:]); n != len(data) {
- t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data))
- }
- return buffer.View(hdr)
- }
-
- ipv6ICMPBuf := func(t *testing.T) buffer.View {
- data := [8]byte{1, 2, 3, 4, 5, 6, 7, 9}
- hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data)))
- hdr.SetType(header.ICMPv6EchoRequest)
- if n := copy(hdr.Payload(), data[:]); n != len(data) {
- t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data))
- }
- return buffer.View(hdr)
- }
-
- tests := []struct {
- name string
- transProto tcpip.TransportProtocolNumber
- netProto tcpip.NetworkProtocolNumber
- linkEndpoint func() stack.LinkEndpoint
- localAddr tcpip.AddressWithPrefix
- icmpBuf func(*testing.T) buffer.View
- expectedConnectErr tcpip.Error
- checkLinkEndpoint func(t *testing.T, e stack.LinkEndpoint)
- }{
- {
- name: "IPv4 loopback",
- transProto: icmp.ProtocolNumber4,
- netProto: ipv4.ProtocolNumber,
- linkEndpoint: loopback.New,
- localAddr: ipv4Loopback,
- icmpBuf: ipv4ICMPBuf,
- checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {},
- },
- {
- name: "IPv6 loopback",
- transProto: icmp.ProtocolNumber6,
- netProto: ipv6.ProtocolNumber,
- linkEndpoint: loopback.New,
- localAddr: header.IPv6Loopback.WithPrefix(),
- icmpBuf: ipv6ICMPBuf,
- checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {},
- },
- {
- name: "IPv4 non-loopback",
- transProto: icmp.ProtocolNumber4,
- netProto: ipv4.ProtocolNumber,
- linkEndpoint: channelEP,
- localAddr: utils.Ipv4Addr,
- icmpBuf: ipv4ICMPBuf,
- checkLinkEndpoint: channelEPCheck,
- },
- {
- name: "IPv6 non-loopback",
- transProto: icmp.ProtocolNumber6,
- netProto: ipv6.ProtocolNumber,
- linkEndpoint: channelEP,
- localAddr: utils.Ipv6Addr,
- icmpBuf: ipv6ICMPBuf,
- checkLinkEndpoint: channelEPCheck,
- },
- {
- name: "IPv4 loopback without local address",
- transProto: icmp.ProtocolNumber4,
- netProto: ipv4.ProtocolNumber,
- linkEndpoint: loopback.New,
- icmpBuf: ipv4ICMPBuf,
- expectedConnectErr: &tcpip.ErrNoRoute{},
- checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {},
- },
- {
- name: "IPv6 loopback without local address",
- transProto: icmp.ProtocolNumber6,
- netProto: ipv6.ProtocolNumber,
- linkEndpoint: loopback.New,
- icmpBuf: ipv6ICMPBuf,
- expectedConnectErr: &tcpip.ErrNoRoute{},
- checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {},
- },
- {
- name: "IPv4 non-loopback without local address",
- transProto: icmp.ProtocolNumber4,
- netProto: ipv4.ProtocolNumber,
- linkEndpoint: channelEP,
- icmpBuf: ipv4ICMPBuf,
- expectedConnectErr: &tcpip.ErrNoRoute{},
- checkLinkEndpoint: channelEPCheck,
- },
- {
- name: "IPv6 non-loopback without local address",
- transProto: icmp.ProtocolNumber6,
- netProto: ipv6.ProtocolNumber,
- linkEndpoint: channelEP,
- icmpBuf: ipv6ICMPBuf,
- expectedConnectErr: &tcpip.ErrNoRoute{},
- checkLinkEndpoint: channelEPCheck,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- for _, allowExternalLoopback := range []bool{true, false} {
- t.Run(fmt.Sprintf("AllowExternalLoopback=%t", allowExternalLoopback), func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{
- ipv4.NewProtocolWithOptions(ipv4.Options{
- AllowExternalLoopbackTraffic: allowExternalLoopback,
- }),
- ipv6.NewProtocolWithOptions(ipv6.Options{
- AllowExternalLoopbackTraffic: allowExternalLoopback,
- }),
- },
- TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6},
- HandleLocal: true,
- })
- e := test.linkEndpoint()
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
- }
-
- if len(test.localAddr.Address) != 0 {
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: test.netProto,
- AddressWithPrefix: test.localAddr,
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
- }
- }
-
- var wq waiter.Queue
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- ep, err := s.NewEndpoint(test.transProto, test.netProto, &wq)
- if err != nil {
- t.Fatalf("s.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err)
- }
- defer ep.Close()
-
- connAddr := tcpip.FullAddress{Addr: test.localAddr.Address}
- if err := ep.Connect(connAddr); err != test.expectedConnectErr {
- t.Fatalf("got ep.Connect(%#v) = %s, want = %s", connAddr, err, test.expectedConnectErr)
- }
-
- if test.expectedConnectErr != nil {
- return
- }
-
- var r bytes.Reader
- payload := test.icmpBuf(t)
- r.Reset(payload)
- var wOpts tcpip.WriteOptions
- if n, err := ep.Write(&r, wOpts); err != nil {
- t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err)
- } else if n != int64(len(payload)) {
- t.Fatalf("got ep.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", payload, wOpts, n, len(payload))
- }
-
- // Wait for the endpoint to become readable.
- <-ch
-
- var w bytes.Buffer
- rr, err := ep.Read(&w, tcpip.ReadOptions{
- NeedRemoteAddr: true,
- })
- if err != nil {
- t.Fatalf("ep.Read(...): %s", err)
- }
- if diff := cmp.Diff(buffer.View(w.Bytes()[icmpDataOffset:]), payload[icmpDataOffset:]); diff != "" {
- t.Errorf("received data mismatch (-want +got):\n%s", diff)
- }
- if rr.RemoteAddr.Addr != test.localAddr.Address {
- t.Errorf("got addr.Addr = %s, want = %s", rr.RemoteAddr.Addr, test.localAddr.Address)
- }
-
- test.checkLinkEndpoint(t, e)
- })
- }
- })
- }
-}
-
-// TestLocalUDP tests sending UDP packets between two endpoints that are local
-// to the stack.
-//
-// This tests that that packets never leave the stack and the addresses
-// used when sending a packet.
-func TestLocalUDP(t *testing.T) {
- const (
- nicID = 1
- )
-
- tests := []struct {
- name string
- canBePrimaryAddr tcpip.ProtocolAddress
- firstPrimaryAddr tcpip.ProtocolAddress
- }{
- {
- name: "IPv4",
- canBePrimaryAddr: utils.Ipv4Addr1,
- firstPrimaryAddr: utils.Ipv4Addr2,
- },
- {
- name: "IPv6",
- canBePrimaryAddr: utils.Ipv6Addr1,
- firstPrimaryAddr: utils.Ipv6Addr2,
- },
- }
-
- subTests := []struct {
- name string
- addAddress bool
- expectedWriteErr tcpip.Error
- }{
- {
- name: "Unassigned local address",
- addAddress: false,
- expectedWriteErr: &tcpip.ErrNoRoute{},
- },
- {
- name: "Assigned local address",
- addAddress: true,
- expectedWriteErr: nil,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- for _, subTest := range subTests {
- t.Run(subTest.name, func(t *testing.T) {
- stackOpts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- HandleLocal: true,
- }
-
- s := stack.New(stackOpts)
- ep := channel.New(1, header.IPv6MinimumMTU, "")
-
- if err := s.CreateNIC(nicID, ep); err != nil {
- t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
- }
-
- if subTest.addAddress {
- if err := s.AddProtocolAddress(nicID, test.canBePrimaryAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, test.canBePrimaryAddr, err)
- }
- properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint}
- if err := s.AddProtocolAddress(nicID, test.firstPrimaryAddr, properties); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %+v, %+v): %s", nicID, test.firstPrimaryAddr, properties, err)
- }
- }
-
- var serverWQ waiter.Queue
- serverWE, serverCH := waiter.NewChannelEntry(nil)
- serverWQ.EventRegister(&serverWE, waiter.ReadableEvents)
- server, err := s.NewEndpoint(udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, &serverWQ)
- if err != nil {
- t.Fatalf("s.NewEndpoint(%d, %d): %s", udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, err)
- }
- defer server.Close()
-
- bindAddr := tcpip.FullAddress{Port: 80}
- if err := server.Bind(bindAddr); err != nil {
- t.Fatalf("server.Bind(%#v): %s", bindAddr, err)
- }
-
- var clientWQ waiter.Queue
- clientWE, clientCH := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientWE, waiter.ReadableEvents)
- client, err := s.NewEndpoint(udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, &clientWQ)
- if err != nil {
- t.Fatalf("s.NewEndpoint(%d, %d): %s", udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, err)
- }
- defer client.Close()
-
- serverAddr := tcpip.FullAddress{
- Addr: test.canBePrimaryAddr.AddressWithPrefix.Address,
- Port: 80,
- }
-
- clientPayload := []byte{1, 2, 3, 4}
- {
- var r bytes.Reader
- r.Reset(clientPayload)
- wOpts := tcpip.WriteOptions{
- To: &serverAddr,
- }
- if n, err := client.Write(&r, wOpts); err != subTest.expectedWriteErr {
- t.Fatalf("got client.Write(%#v, %#v) = (%d, %s), want = (_, %s)", clientPayload, wOpts, n, err, subTest.expectedWriteErr)
- } else if subTest.expectedWriteErr != nil {
- // Nothing else to test if we expected not to be able to send the
- // UDP packet.
- return
- } else if n != int64(len(clientPayload)) {
- t.Fatalf("got client.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", clientPayload, wOpts, n, len(clientPayload))
- }
- }
-
- // Wait for the server endpoint to become readable.
- <-serverCH
-
- var clientAddr tcpip.FullAddress
- var readBuf bytes.Buffer
- if read, err := server.Read(&readBuf, tcpip.ReadOptions{NeedRemoteAddr: true}); err != nil {
- t.Fatalf("server.Read(_): %s", err)
- } else {
- clientAddr = read.RemoteAddr
-
- if diff := cmp.Diff(tcpip.ReadResult{
- Count: readBuf.Len(),
- Total: readBuf.Len(),
- RemoteAddr: tcpip.FullAddress{
- Addr: test.canBePrimaryAddr.AddressWithPrefix.Address,
- },
- }, read, checker.IgnoreCmpPath(
- "ControlMessages",
- "RemoteAddr.NIC",
- "RemoteAddr.Port",
- )); diff != "" {
- t.Errorf("server.Read: unexpected result (-want +got):\n%s", diff)
- }
- if diff := cmp.Diff(buffer.View(clientPayload), buffer.View(readBuf.Bytes())); diff != "" {
- t.Errorf("server read clientPayload mismatch (-want +got):\n%s", diff)
- }
- if t.Failed() {
- t.FailNow()
- }
- }
-
- serverPayload := []byte{1, 2, 3, 4}
- {
- var r bytes.Reader
- r.Reset(serverPayload)
- wOpts := tcpip.WriteOptions{
- To: &clientAddr,
- }
- if n, err := server.Write(&r, wOpts); err != nil {
- t.Fatalf("server.Write(%#v, %#v): %s", serverPayload, wOpts, err)
- } else if n != int64(len(serverPayload)) {
- t.Fatalf("got server.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", serverPayload, wOpts, n, len(serverPayload))
- }
- }
-
- // Wait for the client endpoint to become readable.
- <-clientCH
-
- readBuf.Reset()
- if read, err := client.Read(&readBuf, tcpip.ReadOptions{NeedRemoteAddr: true}); err != nil {
- t.Fatalf("client.Read(_): %s", err)
- } else {
- if diff := cmp.Diff(tcpip.ReadResult{
- Count: readBuf.Len(),
- Total: readBuf.Len(),
- RemoteAddr: tcpip.FullAddress{Addr: serverAddr.Addr},
- }, read, checker.IgnoreCmpPath(
- "ControlMessages",
- "RemoteAddr.NIC",
- "RemoteAddr.Port",
- )); diff != "" {
- t.Errorf("client.Read: unexpected result (-want +got):\n%s", diff)
- }
- if diff := cmp.Diff(buffer.View(serverPayload), buffer.View(readBuf.Bytes())); diff != "" {
- t.Errorf("client read serverPayload mismatch (-want +got):\n%s", diff)
- }
- if t.Failed() {
- t.FailNow()
- }
- }
- })
- }
- })
- }
-}
diff --git a/pkg/tcpip/tests/utils/BUILD b/pkg/tcpip/tests/utils/BUILD
deleted file mode 100644
index a9699a367..000000000
--- a/pkg/tcpip/tests/utils/BUILD
+++ /dev/null
@@ -1,22 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "utils",
- srcs = ["utils.go"],
- visibility = ["//pkg/tcpip/tests:__subpackages__"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/link/ethernet",
- "//pkg/tcpip/link/nested",
- "//pkg/tcpip/link/pipe",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/transport/icmp",
- ],
-)
diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go
deleted file mode 100644
index 947bcc7b1..000000000
--- a/pkg/tcpip/tests/utils/utils.go
+++ /dev/null
@@ -1,390 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package utils holds common testing utilities for tcpip.
-package utils
-
-import (
- "net"
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/ethernet"
- "gvisor.dev/gvisor/pkg/tcpip/link/nested"
- "gvisor.dev/gvisor/pkg/tcpip/link/pipe"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
-)
-
-// Common NIC IDs used by tests.
-const (
- Host1NICID = 1
- RouterNICID1 = 2
- RouterNICID2 = 3
- Host2NICID = 4
-)
-
-// Common link addresses used by tests.
-const (
- LinkAddr1 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
- LinkAddr2 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x07")
- LinkAddr3 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x08")
- LinkAddr4 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
-)
-
-// Common IP addresses used by tests.
-var (
- Ipv4Addr = tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()),
- PrefixLen: 24,
- }
- Ipv4Subnet = Ipv4Addr.Subnet()
- Ipv4SubnetBcast = Ipv4Subnet.Broadcast()
-
- Ipv6Addr = tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("200a::1").To16()),
- PrefixLen: 64,
- }
- Ipv6Subnet = Ipv6Addr.Subnet()
- Ipv6SubnetBcast = Ipv6Subnet.Broadcast()
-
- Ipv4Addr1 = tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()),
- PrefixLen: 24,
- },
- }
- Ipv4Addr2 = tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()),
- PrefixLen: 8,
- },
- }
- Ipv4Addr3 = tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("192.168.0.3").To4()),
- PrefixLen: 8,
- },
- }
- Ipv6Addr1 = tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("a::1").To16()),
- PrefixLen: 64,
- },
- }
- Ipv6Addr2 = tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("a::2").To16()),
- PrefixLen: 64,
- },
- }
- Ipv6Addr3 = tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("a::3").To16()),
- PrefixLen: 64,
- },
- }
-
- // Remote addrs.
- RemoteIPv4Addr = tcpip.Address(net.ParseIP("10.0.0.1").To4())
- RemoteIPv6Addr = tcpip.Address(net.ParseIP("200b::1").To16())
-)
-
-// Common ports for testing.
-const (
- RemotePort = 5555
- LocalPort = 80
-)
-
-// Common IP addresses used for testing.
-var (
- Host1IPv4Addr = tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()),
- PrefixLen: 24,
- },
- }
- RouterNIC1IPv4Addr = tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()),
- PrefixLen: 24,
- },
- }
- RouterNIC2IPv4Addr = tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()),
- PrefixLen: 8,
- },
- }
- Host2IPv4Addr = tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("10.0.0.2").To4()),
- PrefixLen: 8,
- },
- }
- Host1IPv6Addr = tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("a::2").To16()),
- PrefixLen: 64,
- },
- }
- RouterNIC1IPv6Addr = tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("a::1").To16()),
- PrefixLen: 64,
- },
- }
- RouterNIC2IPv6Addr = tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("b::1").To16()),
- PrefixLen: 64,
- },
- }
- Host2IPv6Addr = tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("b::2").To16()),
- PrefixLen: 64,
- },
- }
-)
-
-// NewEthernetEndpoint returns an ethernet link endpoint that wraps an inner
-// link endpoint and checks the destination link address before delivering
-// network packets to the network dispatcher.
-//
-// See ethernet.Endpoint for more details.
-func NewEthernetEndpoint(ep stack.LinkEndpoint) *EndpointWithDestinationCheck {
- var e EndpointWithDestinationCheck
- e.Endpoint.Init(ethernet.New(ep), &e)
- return &e
-}
-
-// EndpointWithDestinationCheck is a link endpoint that checks the destination
-// link address before delivering network packets to the network dispatcher.
-type EndpointWithDestinationCheck struct {
- nested.Endpoint
-}
-
-var _ stack.NetworkDispatcher = (*EndpointWithDestinationCheck)(nil)
-var _ stack.LinkEndpoint = (*EndpointWithDestinationCheck)(nil)
-
-// DeliverNetworkPacket implements stack.NetworkDispatcher.
-func (e *EndpointWithDestinationCheck) DeliverNetworkPacket(src, dst tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
- if dst == e.Endpoint.LinkAddress() || dst == header.EthernetBroadcastAddress || header.IsMulticastEthernetAddress(dst) {
- e.Endpoint.DeliverNetworkPacket(src, dst, proto, pkt)
- }
-}
-
-// SetupRoutedStacks creates the NICs, sets forwarding, adds addresses and sets
-// the route tables for the passed stacks.
-func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) {
- host1NIC, routerNIC1 := pipe.New(LinkAddr1, LinkAddr2)
- routerNIC2, host2NIC := pipe.New(LinkAddr3, LinkAddr4)
-
- if err := host1Stack.CreateNIC(Host1NICID, NewEthernetEndpoint(host1NIC)); err != nil {
- t.Fatalf("host1Stack.CreateNIC(%d, _): %s", Host1NICID, err)
- }
- if err := routerStack.CreateNIC(RouterNICID1, NewEthernetEndpoint(routerNIC1)); err != nil {
- t.Fatalf("routerStack.CreateNIC(%d, _): %s", RouterNICID1, err)
- }
- if err := routerStack.CreateNIC(RouterNICID2, NewEthernetEndpoint(routerNIC2)); err != nil {
- t.Fatalf("routerStack.CreateNIC(%d, _): %s", RouterNICID2, err)
- }
- if err := host2Stack.CreateNIC(Host2NICID, NewEthernetEndpoint(host2NIC)); err != nil {
- t.Fatalf("host2Stack.CreateNIC(%d, _): %s", Host2NICID, err)
- }
-
- if err := routerStack.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
- t.Fatalf("routerStack.SetForwardingDefaultAndAllNICs(%d): %s", ipv4.ProtocolNumber, err)
- }
- if err := routerStack.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil {
- t.Fatalf("routerStack.SetForwardingDefaultAndAllNICs(%d): %s", ipv6.ProtocolNumber, err)
- }
-
- if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv4Addr, stack.AddressProperties{}); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", Host1NICID, Host1IPv4Addr, err)
- }
- if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv4Addr, stack.AddressProperties{}); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID1, RouterNIC1IPv4Addr, err)
- }
- if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv4Addr, stack.AddressProperties{}); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID2, RouterNIC2IPv4Addr, err)
- }
- if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv4Addr, stack.AddressProperties{}); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", Host2NICID, Host2IPv4Addr, err)
- }
- if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv6Addr, stack.AddressProperties{}); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", Host1NICID, Host1IPv6Addr, err)
- }
- if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv6Addr, stack.AddressProperties{}); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID1, RouterNIC1IPv6Addr, err)
- }
- if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv6Addr, stack.AddressProperties{}); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID2, RouterNIC2IPv6Addr, err)
- }
- if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv6Addr, stack.AddressProperties{}); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", Host2NICID, Host2IPv6Addr, err)
- }
-
- host1Stack.SetRouteTable([]tcpip.Route{
- {
- Destination: Host1IPv4Addr.AddressWithPrefix.Subnet(),
- NIC: Host1NICID,
- },
- {
- Destination: Host1IPv6Addr.AddressWithPrefix.Subnet(),
- NIC: Host1NICID,
- },
- {
- Destination: Host2IPv4Addr.AddressWithPrefix.Subnet(),
- Gateway: RouterNIC1IPv4Addr.AddressWithPrefix.Address,
- NIC: Host1NICID,
- },
- {
- Destination: Host2IPv6Addr.AddressWithPrefix.Subnet(),
- Gateway: RouterNIC1IPv6Addr.AddressWithPrefix.Address,
- NIC: Host1NICID,
- },
- })
- routerStack.SetRouteTable([]tcpip.Route{
- {
- Destination: RouterNIC1IPv4Addr.AddressWithPrefix.Subnet(),
- NIC: RouterNICID1,
- },
- {
- Destination: RouterNIC1IPv6Addr.AddressWithPrefix.Subnet(),
- NIC: RouterNICID1,
- },
- {
- Destination: RouterNIC2IPv4Addr.AddressWithPrefix.Subnet(),
- NIC: RouterNICID2,
- },
- {
- Destination: RouterNIC2IPv6Addr.AddressWithPrefix.Subnet(),
- NIC: RouterNICID2,
- },
- })
- host2Stack.SetRouteTable([]tcpip.Route{
- {
- Destination: Host2IPv4Addr.AddressWithPrefix.Subnet(),
- NIC: Host2NICID,
- },
- {
- Destination: Host2IPv6Addr.AddressWithPrefix.Subnet(),
- NIC: Host2NICID,
- },
- {
- Destination: Host1IPv4Addr.AddressWithPrefix.Subnet(),
- Gateway: RouterNIC2IPv4Addr.AddressWithPrefix.Address,
- NIC: Host2NICID,
- },
- {
- Destination: Host1IPv6Addr.AddressWithPrefix.Subnet(),
- Gateway: RouterNIC2IPv6Addr.AddressWithPrefix.Address,
- NIC: Host2NICID,
- },
- })
-}
-
-func rxICMPv4Echo(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8, ty header.ICMPv4Type) {
- totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
- hdr := buffer.NewPrependable(totalLen)
- pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
- pkt.SetType(ty)
- pkt.SetCode(header.ICMPv4UnusedCode)
- pkt.SetChecksum(0)
- pkt.SetChecksum(^header.Checksum(pkt, 0))
- ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
- ip.Encode(&header.IPv4Fields{
- TotalLength: uint16(totalLen),
- Protocol: uint8(icmp.ProtocolNumber4),
- TTL: ttl,
- SrcAddr: src,
- DstAddr: dst,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
-
- e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
-}
-
-// RxICMPv4EchoRequest constructs and injects an ICMPv4 echo request packet on
-// the provided endpoint.
-func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) {
- rxICMPv4Echo(e, src, dst, ttl, header.ICMPv4Echo)
-}
-
-// RxICMPv4EchoReply constructs and injects an ICMPv4 echo reply packet on
-// the provided endpoint.
-func RxICMPv4EchoReply(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) {
- rxICMPv4Echo(e, src, dst, ttl, header.ICMPv4EchoReply)
-}
-
-func rxICMPv6Echo(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8, ty header.ICMPv6Type) {
- totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize
- hdr := buffer.NewPrependable(totalLen)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
- pkt.SetType(ty)
- pkt.SetCode(header.ICMPv6UnusedCode)
- pkt.SetChecksum(0)
- pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: pkt,
- Src: src,
- Dst: dst,
- }))
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: header.ICMPv6MinimumSize,
- TransportProtocol: icmp.ProtocolNumber6,
- HopLimit: ttl,
- SrcAddr: src,
- DstAddr: dst,
- })
-
- e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
-}
-
-// RxICMPv6EchoRequest constructs and injects an ICMPv6 echo request packet on
-// the provided endpoint.
-func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) {
- rxICMPv6Echo(e, src, dst, ttl, header.ICMPv6EchoRequest)
-}
-
-// RxICMPv6EchoReply constructs and injects an ICMPv6 echo reply packet on
-// the provided endpoint.
-func RxICMPv6EchoReply(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) {
- rxICMPv6Echo(e, src, dst, ttl, header.ICMPv6EchoReply)
-}
diff --git a/pkg/tcpip/testutil/BUILD b/pkg/tcpip/testutil/BUILD
deleted file mode 100644
index 02ee86ff1..000000000
--- a/pkg/tcpip/testutil/BUILD
+++ /dev/null
@@ -1,21 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "testutil",
- testonly = True,
- srcs = [
- "testutil.go",
- "testutil_unsafe.go",
- ],
- visibility = ["//visibility:public"],
- deps = ["//pkg/tcpip"],
-)
-
-go_test(
- name = "testutil_test",
- srcs = ["testutil_test.go"],
- library = ":testutil",
- deps = ["//pkg/tcpip"],
-)
diff --git a/pkg/tcpip/testutil/testutil.go b/pkg/tcpip/testutil/testutil.go
deleted file mode 100644
index 94b580a70..000000000
--- a/pkg/tcpip/testutil/testutil.go
+++ /dev/null
@@ -1,123 +0,0 @@
-// Copyright 2021 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package testutil provides helper functions for netstack unit tests.
-package testutil
-
-import (
- "fmt"
- "net"
- "reflect"
- "strings"
-
- "gvisor.dev/gvisor/pkg/tcpip"
-)
-
-// MustParse4 parses an IPv4 string (e.g. "192.168.1.1") into a tcpip.Address.
-// Passing an IPv4-mapped IPv6 address will yield only the 4 IPv4 bytes.
-func MustParse4(addr string) tcpip.Address {
- ip := net.ParseIP(addr).To4()
- if ip == nil {
- panic(fmt.Sprintf("Parse4 expects IPv4 addresses, but was passed %q", addr))
- }
- return tcpip.Address(ip)
-}
-
-// MustParse6 parses an IPv6 string (e.g. "fe80::1") into a tcpip.Address. Passing
-// an IPv4 address will yield an IPv4-mapped IPv6 address.
-func MustParse6(addr string) tcpip.Address {
- ip := net.ParseIP(addr).To16()
- if ip == nil {
- panic(fmt.Sprintf("Parse6 was passed malformed address %q", addr))
- }
- return tcpip.Address(ip)
-}
-
-func checkFieldCounts(ref, multi reflect.Value) error {
- refTypeName := ref.Type().Name()
- multiTypeName := multi.Type().Name()
- refNumField := ref.NumField()
- multiNumField := multi.NumField()
-
- if refNumField != multiNumField {
- return fmt.Errorf("type %s has an incorrect number of fields: got = %d, want = %d (same as type %s)", multiTypeName, multiNumField, refNumField, refTypeName)
- }
-
- return nil
-}
-
-func validateField(ref reflect.Value, refName string, m tcpip.MultiCounterStat, multiName string) error {
- s, ok := ref.Addr().Interface().(**tcpip.StatCounter)
- if !ok {
- return fmt.Errorf("expected ref type's to be *StatCounter, but its type is %s", ref.Type().Elem().Name())
- }
-
- // The field names are expected to match (case insensitive).
- if !strings.EqualFold(refName, multiName) {
- return fmt.Errorf("wrong field name: got = %s, want = %s", multiName, refName)
- }
-
- base := (*s).Value()
- m.Increment()
- if (*s).Value() != base+1 {
- return fmt.Errorf("updates to the '%s MultiCounterStat' counters are not reflected in the '%s CounterStat'", multiName, refName)
- }
-
- return nil
-}
-
-// ValidateMultiCounterStats verifies that every counter stored in multi is
-// correctly tracking its counterpart in the given counters.
-func ValidateMultiCounterStats(multi reflect.Value, counters []reflect.Value) error {
- for _, c := range counters {
- if err := checkFieldCounts(c, multi); err != nil {
- return err
- }
- }
-
- for i := 0; i < multi.NumField(); i++ {
- multiName := multi.Type().Field(i).Name
- multiUnsafe := unsafeExposeUnexportedFields(multi.Field(i))
-
- if m, ok := multiUnsafe.Addr().Interface().(*tcpip.MultiCounterStat); ok {
- for _, c := range counters {
- if err := validateField(unsafeExposeUnexportedFields(c.Field(i)), c.Type().Field(i).Name, *m, multiName); err != nil {
- return err
- }
- }
- } else {
- var countersNextField []reflect.Value
- for _, c := range counters {
- countersNextField = append(countersNextField, c.Field(i))
- }
- if err := ValidateMultiCounterStats(multi.Field(i), countersNextField); err != nil {
- return err
- }
- }
- }
-
- return nil
-}
-
-// MustParseLink parses a Link string into a tcpip.LinkAddress, panicking on
-// error.
-//
-// The string must be in the format aa:bb:cc:dd:ee:ff or aa-bb-cc-dd-ee-ff.
-func MustParseLink(addr string) tcpip.LinkAddress {
- parsed, err := tcpip.ParseMACAddress(addr)
- if err != nil {
- panic(fmt.Sprintf("tcpip.ParseMACAddress(%s): %s", addr, err))
- }
- return parsed
-}
diff --git a/pkg/tcpip/testutil/testutil_test.go b/pkg/tcpip/testutil/testutil_test.go
deleted file mode 100644
index 6aad9585d..000000000
--- a/pkg/tcpip/testutil/testutil_test.go
+++ /dev/null
@@ -1,103 +0,0 @@
-// Copyright 2021 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package testutil
-
-import (
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip"
-)
-
-// Who tests the testutils?
-
-func TestMustParse4(t *testing.T) {
- tcs := []struct {
- str string
- addr tcpip.Address
- shouldPanic bool
- }{
- {
- str: "127.0.0.1",
- addr: "\x7f\x00\x00\x01",
- }, {
- str: "",
- shouldPanic: true,
- }, {
- str: "fe80::1",
- shouldPanic: true,
- }, {
- // In an ideal world this panics too, but net.IP
- // doesn't distinguish between IPv4 and IPv4-mapped
- // addresses.
- str: "::ffff:0.0.0.1",
- addr: "\x00\x00\x00\x01",
- },
- }
-
- for _, tc := range tcs {
- t.Run(tc.str, func(t *testing.T) {
- if tc.shouldPanic {
- defer func() {
- if r := recover(); r == nil {
- t.Errorf("panic expected, but did not occur")
- }
- }()
- }
- if got := MustParse4(tc.str); got != tc.addr {
- t.Errorf("got MustParse4(%s) = %s, want = %s", tc.str, got, tc.addr)
- }
- })
- }
-}
-
-func TestMustParse6(t *testing.T) {
- tcs := []struct {
- str string
- addr tcpip.Address
- shouldPanic bool
- }{
- {
- // In an ideal world this panics too, but net.IP
- // doesn't distinguish between IPv4 and IPv4-mapped
- // addresses.
- str: "127.0.0.1",
- addr: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x7f\x00\x00\x01",
- }, {
- str: "",
- shouldPanic: true,
- }, {
- str: "fe80::1",
- addr: "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
- }, {
- str: "::ffff:0.0.0.1",
- addr: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x01",
- },
- }
-
- for _, tc := range tcs {
- t.Run(tc.str, func(t *testing.T) {
- if tc.shouldPanic {
- defer func() {
- if r := recover(); r == nil {
- t.Errorf("panic expected, but did not occur")
- }
- }()
- }
- if got := MustParse6(tc.str); got != tc.addr {
- t.Errorf("got MustParse6(%s) = %s, want = %s", tc.str, got, tc.addr)
- }
- })
- }
-}
diff --git a/pkg/tcpip/testutil/testutil_unsafe.go b/pkg/tcpip/testutil/testutil_unsafe.go
deleted file mode 100644
index 5ff764800..000000000
--- a/pkg/tcpip/testutil/testutil_unsafe.go
+++ /dev/null
@@ -1,26 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package testutil
-
-import (
- "reflect"
- "unsafe"
-)
-
-// unsafeExposeUnexportedFields takes a Value and returns a version of it in
-// which even unexported fields can be read and written.
-func unsafeExposeUnexportedFields(a reflect.Value) reflect.Value {
- return reflect.NewAt(a.Type(), unsafe.Pointer(a.UnsafeAddr())).Elem()
-}
diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go
deleted file mode 100644
index ed1ed8ac6..000000000
--- a/pkg/tcpip/timer_test.go
+++ /dev/null
@@ -1,353 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tcpip_test
-
-import (
- "math"
- "sync"
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/tcpip"
-)
-
-func TestMonotonicTimeBefore(t *testing.T) {
- var mt tcpip.MonotonicTime
- if mt.Before(mt) {
- t.Errorf("%#v.Before(%#v)", mt, mt)
- }
-
- one := mt.Add(1)
- if one.Before(mt) {
- t.Errorf("%#v.Before(%#v)", one, mt)
- }
- if !mt.Before(one) {
- t.Errorf("!%#v.Before(%#v)", mt, one)
- }
-}
-
-func TestMonotonicTimeAfter(t *testing.T) {
- var mt tcpip.MonotonicTime
- if mt.After(mt) {
- t.Errorf("%#v.After(%#v)", mt, mt)
- }
-
- one := mt.Add(1)
- if mt.After(one) {
- t.Errorf("%#v.After(%#v)", mt, one)
- }
- if !one.After(mt) {
- t.Errorf("!%#v.After(%#v)", one, mt)
- }
-}
-
-func TestMonotonicTimeAddSub(t *testing.T) {
- var mt tcpip.MonotonicTime
- if one, two := mt.Add(2), mt.Add(1).Add(1); one != two {
- t.Errorf("mt.Add(2) != mt.Add(1).Add(1) (%#v != %#v)", one, two)
- }
-
- min := mt.Add(math.MinInt64)
- max := mt.Add(math.MaxInt64)
-
- if overflow := mt.Add(1).Add(math.MaxInt64); overflow != max {
- t.Errorf("mt.Add(math.MaxInt64) != mt.Add(1).Add(math.MaxInt64) (%#v != %#v)", max, overflow)
- }
- if underflow := mt.Add(-1).Add(math.MinInt64); underflow != min {
- t.Errorf("mt.Add(math.MinInt64) != mt.Add(-1).Add(math.MinInt64) (%#v != %#v)", min, underflow)
- }
-
- if got, want := min.Sub(min), time.Duration(0); want != got {
- t.Errorf("got min.Sub(min) = %d, want %d", got, want)
- }
- if got, want := max.Sub(max), time.Duration(0); want != got {
- t.Errorf("got max.Sub(max) = %d, want %d", got, want)
- }
-
- if overflow, want := max.Sub(min), time.Duration(math.MaxInt64); overflow != want {
- t.Errorf("mt.Add(math.MaxInt64).Sub(mt.Add(math.MinInt64) != %s (%#v)", want, overflow)
- }
- if underflow, want := min.Sub(max), time.Duration(math.MinInt64); underflow != want {
- t.Errorf("mt.Add(math.MinInt64).Sub(mt.Add(math.MaxInt64) != %s (%#v)", want, underflow)
- }
-}
-
-func TestMonotonicTimeSub(t *testing.T) {
- var mt tcpip.MonotonicTime
-
- if one, two := mt.Add(2), mt.Add(1).Add(1); one != two {
- t.Errorf("mt.Add(2) != mt.Add(1).Add(1) (%#v != %#v)", one, two)
- }
-
- if max, overflow := mt.Add(math.MaxInt64), mt.Add(1).Add(math.MaxInt64); max != overflow {
- t.Errorf("mt.Add(math.MaxInt64) != mt.Add(1).Add(math.MaxInt64) (%#v != %#v)", max, overflow)
- }
- if max, underflow := mt.Add(math.MinInt64), mt.Add(-1).Add(math.MinInt64); max != underflow {
- t.Errorf("mt.Add(math.MinInt64) != mt.Add(-1).Add(math.MinInt64) (%#v != %#v)", max, underflow)
- }
-}
-
-const (
- shortDuration = 1 * time.Nanosecond
- middleDuration = 100 * time.Millisecond
-)
-
-func TestJobReschedule(t *testing.T) {
- clock := tcpip.NewStdClock()
- var wg sync.WaitGroup
- var lock sync.Mutex
-
- for i := 0; i < 2; i++ {
- wg.Add(1)
-
- go func() {
- lock.Lock()
- // Assigning a new timer value updates the timer's locker and function.
- // This test makes sure there is no data race when reassigning a timer
- // that has an active timer (even if it has been stopped as a stopped
- // timer may be blocked on a lock before it can check if it has been
- // stopped while another goroutine holds the same lock).
- job := tcpip.NewJob(clock, &lock, func() {
- wg.Done()
- })
- job.Schedule(shortDuration)
- lock.Unlock()
- }()
- }
- wg.Wait()
-}
-
-func stdClockWithAfter() (tcpip.Clock, func(time.Duration) <-chan time.Time) {
- return tcpip.NewStdClock(), time.After
-}
-
-func TestJobExecution(t *testing.T) {
- t.Parallel()
-
- clock, after := stdClockWithAfter()
- var lock sync.Mutex
- ch := make(chan struct{})
-
- job := tcpip.NewJob(clock, &lock, func() {
- ch <- struct{}{}
- })
- job.Schedule(shortDuration)
-
- // Wait for timer to fire.
- select {
- case <-ch:
- case <-after(middleDuration):
- t.Fatal("timed out waiting for timer to fire")
- }
-
- // The timer should have fired only once.
- select {
- case <-ch:
- t.Fatal("no other timers should have fired")
- case <-after(middleDuration):
- }
-}
-
-func TestCancellableTimerResetFromLongDuration(t *testing.T) {
- t.Parallel()
-
- clock, after := stdClockWithAfter()
- var lock sync.Mutex
- ch := make(chan struct{})
-
- job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} })
- job.Schedule(middleDuration)
-
- lock.Lock()
- job.Cancel()
- lock.Unlock()
-
- job.Schedule(shortDuration)
-
- // Wait for timer to fire.
- select {
- case <-ch:
- case <-after(middleDuration):
- t.Fatal("timed out waiting for timer to fire")
- }
-
- // The timer should have fired only once.
- select {
- case <-ch:
- t.Fatal("no other timers should have fired")
- case <-after(middleDuration):
- }
-}
-
-func TestJobRescheduleFromShortDuration(t *testing.T) {
- t.Parallel()
-
- clock, after := stdClockWithAfter()
- var lock sync.Mutex
- ch := make(chan struct{})
-
- lock.Lock()
- job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} })
- job.Schedule(shortDuration)
- job.Cancel()
- lock.Unlock()
-
- // Wait for timer to fire if it wasn't correctly stopped.
- select {
- case <-ch:
- t.Fatal("timer fired after being stopped")
- case <-after(middleDuration):
- }
-
- job.Schedule(shortDuration)
-
- // Wait for timer to fire.
- select {
- case <-ch:
- case <-after(middleDuration):
- t.Fatal("timed out waiting for timer to fire")
- }
-
- // The timer should have fired only once.
- select {
- case <-ch:
- t.Fatal("no other timers should have fired")
- case <-after(middleDuration):
- }
-}
-
-func TestJobImmediatelyCancel(t *testing.T) {
- t.Parallel()
-
- clock, after := stdClockWithAfter()
- var lock sync.Mutex
- ch := make(chan struct{})
-
- for i := 0; i < 1000; i++ {
- lock.Lock()
- job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} })
- job.Schedule(shortDuration)
- job.Cancel()
- lock.Unlock()
- }
-
- // Wait for timer to fire if it wasn't correctly stopped.
- select {
- case <-ch:
- t.Fatal("timer fired after being stopped")
- case <-after(middleDuration):
- }
-}
-
-func stdClockWithAfterAndSleep() (tcpip.Clock, func(time.Duration) <-chan time.Time, func(time.Duration)) {
- clock, after := stdClockWithAfter()
- return clock, after, time.Sleep
-}
-
-func TestJobCancelledRescheduleWithoutLock(t *testing.T) {
- t.Parallel()
-
- clock, after, sleep := stdClockWithAfterAndSleep()
- var lock sync.Mutex
- ch := make(chan struct{})
-
- lock.Lock()
- job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} })
- job.Schedule(shortDuration)
- job.Cancel()
- lock.Unlock()
-
- for i := 0; i < 10; i++ {
- job.Schedule(middleDuration)
-
- lock.Lock()
- // Sleep until the timer fires and gets blocked trying to take the lock.
- sleep(middleDuration * 2)
- job.Cancel()
- lock.Unlock()
- }
-
- // Wait for double the duration so timers that weren't correctly stopped can
- // fire.
- select {
- case <-ch:
- t.Fatal("timer fired after being stopped")
- case <-after(middleDuration * 2):
- }
-}
-
-func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) {
- t.Parallel()
-
- clock, after, sleep := stdClockWithAfterAndSleep()
- var lock sync.Mutex
- ch := make(chan struct{})
-
- lock.Lock()
- job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} })
- job.Schedule(shortDuration)
- for i := 0; i < 10; i++ {
- // Sleep until the timer fires and gets blocked trying to take the lock.
- sleep(middleDuration)
- job.Cancel()
- job.Schedule(shortDuration)
- }
- lock.Unlock()
-
- // Wait for double the duration for the last timer to fire.
- select {
- case <-ch:
- case <-after(middleDuration):
- t.Fatal("timed out waiting for timer to fire")
- }
-
- // The timer should have fired only once.
- select {
- case <-ch:
- t.Fatal("no other timers should have fired")
- case <-after(middleDuration):
- }
-}
-
-func TestManyJobReschedulesUnderLock(t *testing.T) {
- t.Parallel()
-
- clock, after := stdClockWithAfter()
- var lock sync.Mutex
- ch := make(chan struct{})
-
- lock.Lock()
- job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} })
- job.Schedule(shortDuration)
- for i := 0; i < 10; i++ {
- job.Cancel()
- job.Schedule(shortDuration)
- }
- lock.Unlock()
-
- // Wait for double the duration for the last timer to fire.
- select {
- case <-ch:
- case <-after(middleDuration):
- t.Fatal("timed out waiting for timer to fire")
- }
-
- // The timer should have fired only once.
- select {
- case <-ch:
- t.Fatal("no other timers should have fired")
- case <-after(middleDuration):
- }
-}
diff --git a/pkg/tcpip/transport/BUILD b/pkg/tcpip/transport/BUILD
deleted file mode 100644
index af332ed91..000000000
--- a/pkg/tcpip/transport/BUILD
+++ /dev/null
@@ -1,13 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "transport",
- srcs = [
- "datagram.go",
- "transport.go",
- ],
- visibility = ["//visibility:public"],
- deps = ["//pkg/tcpip"],
-)
diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD
deleted file mode 100644
index bbc0e3ecc..000000000
--- a/pkg/tcpip/transport/icmp/BUILD
+++ /dev/null
@@ -1,59 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
-package(licenses = ["notice"])
-
-go_template_instance(
- name = "icmp_packet_list",
- out = "icmp_packet_list.go",
- package = "icmp",
- prefix = "icmpPacket",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*icmpPacket",
- "Linker": "*icmpPacket",
- },
-)
-
-go_library(
- name = "icmp",
- srcs = [
- "endpoint.go",
- "endpoint_state.go",
- "icmp_packet_list.go",
- "protocol.go",
- ],
- imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/sleep",
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/ports",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/transport/raw",
- "//pkg/tcpip/transport/tcp",
- "//pkg/waiter",
- ],
-)
-
-go_test(
- name = "icmp_x_test",
- size = "small",
- srcs = ["icmp_test.go"],
- deps = [
- ":icmp",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/checker",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/link/sniffer",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/testutil",
- "//pkg/waiter",
- ],
-)
diff --git a/pkg/tcpip/transport/icmp/icmp_packet_list.go b/pkg/tcpip/transport/icmp/icmp_packet_list.go
new file mode 100644
index 000000000..0aacdad3f
--- /dev/null
+++ b/pkg/tcpip/transport/icmp/icmp_packet_list.go
@@ -0,0 +1,221 @@
+package icmp
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type icmpPacketElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (icmpPacketElementMapper) linkerFor(elem *icmpPacket) *icmpPacket { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type icmpPacketList struct {
+ head *icmpPacket
+ tail *icmpPacket
+}
+
+// Reset resets list l to the empty state.
+func (l *icmpPacketList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+//
+//go:nosplit
+func (l *icmpPacketList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+//
+//go:nosplit
+func (l *icmpPacketList) Front() *icmpPacket {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+//
+//go:nosplit
+func (l *icmpPacketList) Back() *icmpPacket {
+ return l.tail
+}
+
+// Len returns the number of elements in the list.
+//
+// NOTE: This is an O(n) operation.
+//
+//go:nosplit
+func (l *icmpPacketList) Len() (count int) {
+ for e := l.Front(); e != nil; e = (icmpPacketElementMapper{}.linkerFor(e)).Next() {
+ count++
+ }
+ return count
+}
+
+// PushFront inserts the element e at the front of list l.
+//
+//go:nosplit
+func (l *icmpPacketList) PushFront(e *icmpPacket) {
+ linker := icmpPacketElementMapper{}.linkerFor(e)
+ linker.SetNext(l.head)
+ linker.SetPrev(nil)
+ if l.head != nil {
+ icmpPacketElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+//
+//go:nosplit
+func (l *icmpPacketList) PushBack(e *icmpPacket) {
+ linker := icmpPacketElementMapper{}.linkerFor(e)
+ linker.SetNext(nil)
+ linker.SetPrev(l.tail)
+ if l.tail != nil {
+ icmpPacketElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+//
+//go:nosplit
+func (l *icmpPacketList) PushBackList(m *icmpPacketList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ icmpPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ icmpPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+//
+//go:nosplit
+func (l *icmpPacketList) InsertAfter(b, e *icmpPacket) {
+ bLinker := icmpPacketElementMapper{}.linkerFor(b)
+ eLinker := icmpPacketElementMapper{}.linkerFor(e)
+
+ a := bLinker.Next()
+
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ bLinker.SetNext(e)
+
+ if a != nil {
+ icmpPacketElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+//
+//go:nosplit
+func (l *icmpPacketList) InsertBefore(a, e *icmpPacket) {
+ aLinker := icmpPacketElementMapper{}.linkerFor(a)
+ eLinker := icmpPacketElementMapper{}.linkerFor(e)
+
+ b := aLinker.Prev()
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ aLinker.SetPrev(e)
+
+ if b != nil {
+ icmpPacketElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+//
+//go:nosplit
+func (l *icmpPacketList) Remove(e *icmpPacket) {
+ linker := icmpPacketElementMapper{}.linkerFor(e)
+ prev := linker.Prev()
+ next := linker.Next()
+
+ if prev != nil {
+ icmpPacketElementMapper{}.linkerFor(prev).SetNext(next)
+ } else if l.head == e {
+ l.head = next
+ }
+
+ if next != nil {
+ icmpPacketElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else if l.tail == e {
+ l.tail = prev
+ }
+
+ linker.SetNext(nil)
+ linker.SetPrev(nil)
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type icmpPacketEntry struct {
+ next *icmpPacket
+ prev *icmpPacket
+}
+
+// Next returns the entry that follows e in the list.
+//
+//go:nosplit
+func (e *icmpPacketEntry) Next() *icmpPacket {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *icmpPacketEntry) Prev() *icmpPacket {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+//
+//go:nosplit
+func (e *icmpPacketEntry) SetNext(elem *icmpPacket) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *icmpPacketEntry) SetPrev(elem *icmpPacket) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/transport/icmp/icmp_state_autogen.go b/pkg/tcpip/transport/icmp/icmp_state_autogen.go
new file mode 100644
index 000000000..d90b76d9c
--- /dev/null
+++ b/pkg/tcpip/transport/icmp/icmp_state_autogen.go
@@ -0,0 +1,170 @@
+// automatically generated by stateify.
+
+package icmp
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+func (p *icmpPacket) StateTypeName() string {
+ return "pkg/tcpip/transport/icmp.icmpPacket"
+}
+
+func (p *icmpPacket) StateFields() []string {
+ return []string{
+ "icmpPacketEntry",
+ "senderAddress",
+ "data",
+ "receivedAt",
+ }
+}
+
+func (p *icmpPacket) beforeSave() {}
+
+// +checklocksignore
+func (p *icmpPacket) StateSave(stateSinkObject state.Sink) {
+ p.beforeSave()
+ var dataValue buffer.VectorisedView
+ dataValue = p.saveData()
+ stateSinkObject.SaveValue(2, dataValue)
+ var receivedAtValue int64
+ receivedAtValue = p.saveReceivedAt()
+ stateSinkObject.SaveValue(3, receivedAtValue)
+ stateSinkObject.Save(0, &p.icmpPacketEntry)
+ stateSinkObject.Save(1, &p.senderAddress)
+}
+
+func (p *icmpPacket) afterLoad() {}
+
+// +checklocksignore
+func (p *icmpPacket) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &p.icmpPacketEntry)
+ stateSourceObject.Load(1, &p.senderAddress)
+ stateSourceObject.LoadValue(2, new(buffer.VectorisedView), func(y interface{}) { p.loadData(y.(buffer.VectorisedView)) })
+ stateSourceObject.LoadValue(3, new(int64), func(y interface{}) { p.loadReceivedAt(y.(int64)) })
+}
+
+func (e *endpoint) StateTypeName() string {
+ return "pkg/tcpip/transport/icmp.endpoint"
+}
+
+func (e *endpoint) StateFields() []string {
+ return []string{
+ "TransportEndpointInfo",
+ "DefaultSocketOptionsHandler",
+ "waiterQueue",
+ "uniqueID",
+ "rcvReady",
+ "rcvList",
+ "rcvBufSize",
+ "rcvClosed",
+ "shutdownFlags",
+ "state",
+ "ttl",
+ "owner",
+ "ops",
+ "frozen",
+ }
+}
+
+// +checklocksignore
+func (e *endpoint) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+ stateSinkObject.Save(0, &e.TransportEndpointInfo)
+ stateSinkObject.Save(1, &e.DefaultSocketOptionsHandler)
+ stateSinkObject.Save(2, &e.waiterQueue)
+ stateSinkObject.Save(3, &e.uniqueID)
+ stateSinkObject.Save(4, &e.rcvReady)
+ stateSinkObject.Save(5, &e.rcvList)
+ stateSinkObject.Save(6, &e.rcvBufSize)
+ stateSinkObject.Save(7, &e.rcvClosed)
+ stateSinkObject.Save(8, &e.shutdownFlags)
+ stateSinkObject.Save(9, &e.state)
+ stateSinkObject.Save(10, &e.ttl)
+ stateSinkObject.Save(11, &e.owner)
+ stateSinkObject.Save(12, &e.ops)
+ stateSinkObject.Save(13, &e.frozen)
+}
+
+// +checklocksignore
+func (e *endpoint) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &e.TransportEndpointInfo)
+ stateSourceObject.Load(1, &e.DefaultSocketOptionsHandler)
+ stateSourceObject.Load(2, &e.waiterQueue)
+ stateSourceObject.Load(3, &e.uniqueID)
+ stateSourceObject.Load(4, &e.rcvReady)
+ stateSourceObject.Load(5, &e.rcvList)
+ stateSourceObject.Load(6, &e.rcvBufSize)
+ stateSourceObject.Load(7, &e.rcvClosed)
+ stateSourceObject.Load(8, &e.shutdownFlags)
+ stateSourceObject.Load(9, &e.state)
+ stateSourceObject.Load(10, &e.ttl)
+ stateSourceObject.Load(11, &e.owner)
+ stateSourceObject.Load(12, &e.ops)
+ stateSourceObject.Load(13, &e.frozen)
+ stateSourceObject.AfterLoad(e.afterLoad)
+}
+
+func (l *icmpPacketList) StateTypeName() string {
+ return "pkg/tcpip/transport/icmp.icmpPacketList"
+}
+
+func (l *icmpPacketList) StateFields() []string {
+ return []string{
+ "head",
+ "tail",
+ }
+}
+
+func (l *icmpPacketList) beforeSave() {}
+
+// +checklocksignore
+func (l *icmpPacketList) StateSave(stateSinkObject state.Sink) {
+ l.beforeSave()
+ stateSinkObject.Save(0, &l.head)
+ stateSinkObject.Save(1, &l.tail)
+}
+
+func (l *icmpPacketList) afterLoad() {}
+
+// +checklocksignore
+func (l *icmpPacketList) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &l.head)
+ stateSourceObject.Load(1, &l.tail)
+}
+
+func (e *icmpPacketEntry) StateTypeName() string {
+ return "pkg/tcpip/transport/icmp.icmpPacketEntry"
+}
+
+func (e *icmpPacketEntry) StateFields() []string {
+ return []string{
+ "next",
+ "prev",
+ }
+}
+
+func (e *icmpPacketEntry) beforeSave() {}
+
+// +checklocksignore
+func (e *icmpPacketEntry) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+ stateSinkObject.Save(0, &e.next)
+ stateSinkObject.Save(1, &e.prev)
+}
+
+func (e *icmpPacketEntry) afterLoad() {}
+
+// +checklocksignore
+func (e *icmpPacketEntry) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &e.next)
+ stateSourceObject.Load(1, &e.prev)
+}
+
+func init() {
+ state.Register((*icmpPacket)(nil))
+ state.Register((*endpoint)(nil))
+ state.Register((*icmpPacketList)(nil))
+ state.Register((*icmpPacketEntry)(nil))
+}
diff --git a/pkg/tcpip/transport/icmp/icmp_test.go b/pkg/tcpip/transport/icmp/icmp_test.go
deleted file mode 100644
index 729f50e9a..000000000
--- a/pkg/tcpip/transport/icmp/icmp_test.go
+++ /dev/null
@@ -1,239 +0,0 @@
-// Copyright 2021 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package icmp_test
-
-import (
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-// TODO(https://gvisor.dev/issues/5623): Finish unit testing the icmp package.
-// See the issue for remaining areas of work.
-
-var (
- localV4Addr1 = testutil.MustParse4("10.0.0.1")
- localV4Addr2 = testutil.MustParse4("10.0.0.2")
- remoteV4Addr = testutil.MustParse4("10.0.0.3")
-)
-
-func addNICWithDefaultRoute(t *testing.T, s *stack.Stack, id tcpip.NICID, name string, addrV4 tcpip.Address) *channel.Endpoint {
- t.Helper()
-
- ep := channel.New(1 /* size */, header.IPv4MinimumMTU, "" /* linkAddr */)
- t.Cleanup(ep.Close)
-
- wep := stack.LinkEndpoint(ep)
- if testing.Verbose() {
- wep = sniffer.New(ep)
- }
-
- opts := stack.NICOptions{Name: name}
- if err := s.CreateNICWithOptions(id, wep, opts); err != nil {
- t.Fatalf("s.CreateNIC(%d, _) = %s", id, err)
- }
-
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: addrV4.WithPrefix(),
- }
- if err := s.AddProtocolAddress(id, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", id, protocolAddr, err)
- }
-
- s.AddRoute(tcpip.Route{
- Destination: header.IPv4EmptySubnet,
- NIC: id,
- })
-
- return ep
-}
-
-func writePayload(buf []byte) {
- for i := range buf {
- buf[i] = byte(i)
- }
-}
-
-func newICMPv4EchoRequest(payloadSize uint32) buffer.View {
- buf := buffer.NewView(header.ICMPv4MinimumSize + int(payloadSize))
- writePayload(buf[header.ICMPv4MinimumSize:])
-
- icmp := header.ICMPv4(buf)
- icmp.SetType(header.ICMPv4Echo)
- // No need to set the checksum; it is reset by the socket before the packet
- // is sent.
-
- return buf
-}
-
-// TestWriteUnboundWithBindToDevice exercises writing to an unbound ICMP socket
-// when SO_BINDTODEVICE is set to the non-default NIC for that subnet.
-//
-// Only IPv4 is tested. The logic to determine which NIC to use is agnostic to
-// the version of IP.
-func TestWriteUnboundWithBindToDevice(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
- HandleLocal: true,
- })
-
- // Add two NICs, both with default routes on the same subnet. The first NIC
- // added will be the default NIC for that subnet.
- defaultEP := addNICWithDefaultRoute(t, s, 1, "default", localV4Addr1)
- alternateEP := addNICWithDefaultRoute(t, s, 2, "alternate", localV4Addr2)
-
- socket, err := s.NewEndpoint(icmp.ProtocolNumber4, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("s.NewEndpoint(%d, %d, _) = %s", icmp.ProtocolNumber4, ipv4.ProtocolNumber, err)
- }
- defer socket.Close()
-
- echoPayloadSize := defaultEP.MTU() - header.IPv4MinimumSize - header.ICMPv4MinimumSize
-
- // Send a packet without SO_BINDTODEVICE. This verifies that the first NIC
- // to be added is the default NIC to send packets when not explicitly bound.
- {
- buf := newICMPv4EchoRequest(echoPayloadSize)
- r := buf.Reader()
- n, err := socket.Write(&r, tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: remoteV4Addr},
- })
- if err != nil {
- t.Fatalf("socket.Write(_, {To:%s}) = %s", remoteV4Addr, err)
- }
- if n != int64(len(buf)) {
- t.Fatalf("got n = %d, want n = %d", n, len(buf))
- }
-
- // Verify the packet was sent out the default NIC.
- p, ok := defaultEP.Read()
- if !ok {
- t.Fatalf("got defaultEP.Read(_) = _, false; want = _, true (packet wasn't written out)")
- }
-
- vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
- b := vv.ToView()
-
- checker.IPv4(t, b, []checker.NetworkChecker{
- checker.SrcAddr(localV4Addr1),
- checker.DstAddr(remoteV4Addr),
- checker.ICMPv4(
- checker.ICMPv4Type(header.ICMPv4Echo),
- checker.ICMPv4Payload(buf[header.ICMPv4MinimumSize:]),
- ),
- }...)
-
- // Verify the packet was not sent out the alternate NIC.
- if p, ok := alternateEP.Read(); ok {
- t.Fatalf("got alternateEP.Read(_) = %+v, true; want = _, false", p)
- }
- }
-
- // Send a packet with SO_BINDTODEVICE. This exercises reliance on
- // SO_BINDTODEVICE to route the packet to the alternate NIC.
- {
- // Use SO_BINDTODEVICE to send over the alternate NIC by default.
- socket.SocketOptions().SetBindToDevice(2)
-
- buf := newICMPv4EchoRequest(echoPayloadSize)
- r := buf.Reader()
- n, err := socket.Write(&r, tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: remoteV4Addr},
- })
- if err != nil {
- t.Fatalf("socket.Write(_, {To:%s}) = %s", tcpip.Address(remoteV4Addr), err)
- }
- if n != int64(len(buf)) {
- t.Fatalf("got n = %d, want n = %d", n, len(buf))
- }
-
- // Verify the packet was not sent out the default NIC.
- if p, ok := defaultEP.Read(); ok {
- t.Fatalf("got defaultEP.Read(_) = %+v, true; want = _, false", p)
- }
-
- // Verify the packet was sent out the alternate NIC.
- p, ok := alternateEP.Read()
- if !ok {
- t.Fatalf("got alternateEP.Read(_) = _, false; want = _, true (packet wasn't written out)")
- }
-
- vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
- b := vv.ToView()
-
- checker.IPv4(t, b, []checker.NetworkChecker{
- checker.SrcAddr(localV4Addr2),
- checker.DstAddr(remoteV4Addr),
- checker.ICMPv4(
- checker.ICMPv4Type(header.ICMPv4Echo),
- checker.ICMPv4Payload(buf[header.ICMPv4MinimumSize:]),
- ),
- }...)
- }
-
- // Send a packet with SO_BINDTODEVICE cleared. This verifies that clearing
- // the device binding will fallback to using the default NIC to send
- // packets.
- {
- socket.SocketOptions().SetBindToDevice(0)
-
- buf := newICMPv4EchoRequest(echoPayloadSize)
- r := buf.Reader()
- n, err := socket.Write(&r, tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: remoteV4Addr},
- })
- if err != nil {
- t.Fatalf("socket.Write(_, {To:%s}) = %s", tcpip.Address(remoteV4Addr), err)
- }
- if n != int64(len(buf)) {
- t.Fatalf("got n = %d, want n = %d", n, len(buf))
- }
-
- // Verify the packet was sent out the default NIC.
- p, ok := defaultEP.Read()
- if !ok {
- t.Fatalf("got defaultEP.Read(_) = _, false; want = _, true (packet wasn't written out)")
- }
-
- vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
- b := vv.ToView()
-
- checker.IPv4(t, b, []checker.NetworkChecker{
- checker.SrcAddr(localV4Addr1),
- checker.DstAddr(remoteV4Addr),
- checker.ICMPv4(
- checker.ICMPv4Type(header.ICMPv4Echo),
- checker.ICMPv4Payload(buf[header.ICMPv4MinimumSize:]),
- ),
- }...)
-
- // Verify the packet was not sent out the alternate NIC.
- if p, ok := alternateEP.Read(); ok {
- t.Fatalf("got alternateEP.Read(_) = %+v, true; want = _, false", p)
- }
- }
-}
diff --git a/pkg/tcpip/transport/internal/network/BUILD b/pkg/tcpip/transport/internal/network/BUILD
deleted file mode 100644
index b1edce39b..000000000
--- a/pkg/tcpip/transport/internal/network/BUILD
+++ /dev/null
@@ -1,45 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "network",
- srcs = [
- "endpoint.go",
- "endpoint_state.go",
- ],
- visibility = [
- "//pkg/tcpip/transport/raw:__pkg__",
- "//pkg/tcpip/transport/udp:__pkg__",
- ],
- deps = [
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/header",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/transport",
- ],
-)
-
-go_test(
- name = "network_test",
- size = "small",
- srcs = ["endpoint_test.go"],
- deps = [
- ":network",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/checker",
- "//pkg/tcpip/faketime",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/link/loopback",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/testutil",
- "//pkg/tcpip/transport",
- "//pkg/tcpip/transport/udp",
- "@com_github_google_go_cmp//cmp:go_default_library",
- ],
-)
diff --git a/pkg/tcpip/transport/internal/network/endpoint_test.go b/pkg/tcpip/transport/internal/network/endpoint_test.go
deleted file mode 100644
index f263a9ea2..000000000
--- a/pkg/tcpip/transport/internal/network/endpoint_test.go
+++ /dev/null
@@ -1,318 +0,0 @@
-// Copyright 2021 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package network_test
-
-import (
- "fmt"
- "testing"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/transport"
- "gvisor.dev/gvisor/pkg/tcpip/transport/internal/network"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
-)
-
-var (
- ipv4NICAddr = testutil.MustParse4("1.2.3.4")
- ipv6NICAddr = testutil.MustParse6("a::1")
- ipv4RemoteAddr = testutil.MustParse4("6.7.8.9")
- ipv6RemoteAddr = testutil.MustParse6("b::1")
-)
-
-func TestEndpointStateTransitions(t *testing.T) {
- const nicID = 1
-
- data := buffer.View([]byte{1, 2, 4, 5})
- v4Checker := func(t *testing.T, b buffer.View) {
- checker.IPv4(t, b,
- checker.SrcAddr(ipv4NICAddr),
- checker.DstAddr(ipv4RemoteAddr),
- checker.IPPayload(data),
- )
- }
-
- v6Checker := func(t *testing.T, b buffer.View) {
- checker.IPv6(t, b,
- checker.SrcAddr(ipv6NICAddr),
- checker.DstAddr(ipv6RemoteAddr),
- checker.IPPayload(data),
- )
- }
-
- tests := []struct {
- name string
- netProto tcpip.NetworkProtocolNumber
- expectedMaxHeaderLength uint16
- expectedNetProto tcpip.NetworkProtocolNumber
- expectedLocalAddr tcpip.Address
- bindAddr tcpip.Address
- expectedBoundAddr tcpip.Address
- remoteAddr tcpip.Address
- expectedRemoteAddr tcpip.Address
- checker func(*testing.T, buffer.View)
- }{
- {
- name: "IPv4",
- netProto: ipv4.ProtocolNumber,
- expectedMaxHeaderLength: header.IPv4MaximumHeaderSize,
- expectedNetProto: ipv4.ProtocolNumber,
- expectedLocalAddr: ipv4NICAddr,
- bindAddr: header.IPv4AllSystems,
- expectedBoundAddr: header.IPv4AllSystems,
- remoteAddr: ipv4RemoteAddr,
- expectedRemoteAddr: ipv4RemoteAddr,
- checker: v4Checker,
- },
- {
- name: "IPv6",
- netProto: ipv6.ProtocolNumber,
- expectedMaxHeaderLength: header.IPv6FixedHeaderSize,
- expectedNetProto: ipv6.ProtocolNumber,
- expectedLocalAddr: ipv6NICAddr,
- bindAddr: header.IPv6AllNodesMulticastAddress,
- expectedBoundAddr: header.IPv6AllNodesMulticastAddress,
- remoteAddr: ipv6RemoteAddr,
- expectedRemoteAddr: ipv6RemoteAddr,
- checker: v6Checker,
- },
- {
- name: "IPv4-mapped-IPv6",
- netProto: ipv6.ProtocolNumber,
- expectedMaxHeaderLength: header.IPv4MaximumHeaderSize,
- expectedNetProto: ipv4.ProtocolNumber,
- expectedLocalAddr: ipv4NICAddr,
- bindAddr: testutil.MustParse6("::ffff:e000:0001"),
- expectedBoundAddr: header.IPv4AllSystems,
- remoteAddr: testutil.MustParse6("::ffff:0607:0809"),
- expectedRemoteAddr: ipv4RemoteAddr,
- checker: v4Checker,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- Clock: &faketime.NullClock{},
- })
- e := channel.New(1, header.IPv6MinimumMTU, "")
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
- }
-
- ipv4ProtocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: ipv4NICAddr.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, ipv4ProtocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %+v, {}: %s", nicID, ipv4ProtocolAddr, err)
- }
- ipv6ProtocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: ipv6NICAddr.WithPrefix(),
- }
-
- if err := s.AddProtocolAddress(nicID, ipv6ProtocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv6ProtocolAddr, err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {Destination: ipv4RemoteAddr.WithPrefix().Subnet(), NIC: nicID},
- {Destination: ipv6RemoteAddr.WithPrefix().Subnet(), NIC: nicID},
- })
-
- var ops tcpip.SocketOptions
- var ep network.Endpoint
- ep.Init(s, test.netProto, udp.ProtocolNumber, &ops)
- defer ep.Close()
- if state := ep.State(); state != transport.DatagramEndpointStateInitial {
- t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateInitial)
- }
-
- bindAddr := tcpip.FullAddress{Addr: test.bindAddr}
- if err := ep.Bind(bindAddr); err != nil {
- t.Fatalf("ep.Bind(%#v): %s", bindAddr, err)
- }
- if state := ep.State(); state != transport.DatagramEndpointStateBound {
- t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateBound)
- }
- if diff := cmp.Diff(ep.GetLocalAddress(), tcpip.FullAddress{Addr: test.expectedBoundAddr}); diff != "" {
- t.Errorf("ep.GetLocalAddress() mismatch (-want +got):\n%s", diff)
- }
- if addr, connected := ep.GetRemoteAddress(); connected {
- t.Errorf("got ep.GetRemoteAddress() = (true, %#v), want = (false, _)", addr)
- }
-
- connectAddr := tcpip.FullAddress{Addr: test.remoteAddr}
- if err := ep.Connect(connectAddr); err != nil {
- t.Fatalf("ep.Connect(%#v): %s", connectAddr, err)
- }
- if state := ep.State(); state != transport.DatagramEndpointStateConnected {
- t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateConnected)
- }
- if diff := cmp.Diff(ep.GetLocalAddress(), tcpip.FullAddress{Addr: test.expectedLocalAddr}); diff != "" {
- t.Errorf("ep.GetLocalAddress() mismatch (-want +got):\n%s", diff)
- }
- if addr, connected := ep.GetRemoteAddress(); !connected {
- t.Errorf("got ep.GetRemoteAddress() = (false, _), want = (true, %#v)", connectAddr)
- } else if diff := cmp.Diff(addr, tcpip.FullAddress{Addr: test.expectedRemoteAddr}); diff != "" {
- t.Errorf("remote address mismatch (-want +got):\n%s", diff)
- }
-
- ctx, err := ep.AcquireContextForWrite(tcpip.WriteOptions{})
- if err != nil {
- t.Fatalf("ep.AcquireContexForWrite({}): %s", err)
- }
- defer ctx.Release()
- info := ctx.PacketInfo()
- if diff := cmp.Diff(network.WritePacketInfo{
- NetProto: test.expectedNetProto,
- LocalAddress: test.expectedLocalAddr,
- RemoteAddress: test.expectedRemoteAddr,
- MaxHeaderLength: test.expectedMaxHeaderLength,
- RequiresTXTransportChecksum: true,
- }, info); diff != "" {
- t.Errorf("write packet info mismatch (-want +got):\n%s", diff)
- }
- if err := ctx.WritePacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(info.MaxHeaderLength),
- Data: data.ToVectorisedView(),
- }), false /* headerIncluded */); err != nil {
- t.Fatalf("ctx.WritePacket(_, false): %s", err)
- }
- if pkt, ok := e.Read(); !ok {
- t.Fatalf("expected packet to be read from link endpoint")
- } else {
- test.checker(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()))
- }
-
- ep.Close()
- if state := ep.State(); state != transport.DatagramEndpointStateClosed {
- t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateClosed)
- }
- })
- }
-}
-
-func TestBindNICID(t *testing.T) {
- const nicID = 1
-
- tests := []struct {
- name string
- netProto tcpip.NetworkProtocolNumber
- bindAddr tcpip.Address
- unicast bool
- }{
- {
- name: "IPv4 multicast",
- netProto: ipv4.ProtocolNumber,
- bindAddr: header.IPv4AllSystems,
- unicast: false,
- },
- {
- name: "IPv6 multicast",
- netProto: ipv6.ProtocolNumber,
- bindAddr: header.IPv6AllNodesMulticastAddress,
- unicast: false,
- },
- {
- name: "IPv4 unicast",
- netProto: ipv4.ProtocolNumber,
- bindAddr: ipv4NICAddr,
- unicast: true,
- },
- {
- name: "IPv6 unicast",
- netProto: ipv6.ProtocolNumber,
- bindAddr: ipv6NICAddr,
- unicast: true,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- for _, testBindNICID := range []tcpip.NICID{0, nicID} {
- t.Run(fmt.Sprintf("BindNICID=%d", testBindNICID), func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- Clock: &faketime.NullClock{},
- })
- if err := s.CreateNIC(nicID, loopback.New()); err != nil {
- t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
- }
-
- ipv4ProtocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: ipv4NICAddr.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, ipv4ProtocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtocolAddr, err)
- }
- ipv6ProtocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: ipv6NICAddr.WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, ipv6ProtocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv6ProtocolAddr, err)
- }
-
- var ops tcpip.SocketOptions
- var ep network.Endpoint
- ep.Init(s, test.netProto, udp.ProtocolNumber, &ops)
- defer ep.Close()
- if ep.WasBound() {
- t.Fatal("got ep.WasBound() = true, want = false")
- }
- wantInfo := stack.TransportEndpointInfo{NetProto: test.netProto, TransProto: udp.ProtocolNumber}
- if diff := cmp.Diff(wantInfo, ep.Info()); diff != "" {
- t.Fatalf("ep.Info() mismatch (-want +got):\n%s", diff)
- }
-
- bindAddr := tcpip.FullAddress{Addr: test.bindAddr, NIC: testBindNICID}
- if err := ep.Bind(bindAddr); err != nil {
- t.Fatalf("ep.Bind(%#v): %s", bindAddr, err)
- }
- if !ep.WasBound() {
- t.Error("got ep.WasBound() = false, want = true")
- }
- wantInfo.ID = stack.TransportEndpointID{LocalAddress: bindAddr.Addr}
- wantInfo.BindAddr = bindAddr.Addr
- wantInfo.BindNICID = bindAddr.NIC
- if test.unicast {
- wantInfo.RegisterNICID = nicID
- } else {
- wantInfo.RegisterNICID = bindAddr.NIC
- }
- if diff := cmp.Diff(wantInfo, ep.Info()); diff != "" {
- t.Errorf("ep.Info() mismatch (-want +got):\n%s", diff)
- }
- })
- }
- })
- }
-}
diff --git a/pkg/tcpip/transport/internal/network/network_state_autogen.go b/pkg/tcpip/transport/internal/network/network_state_autogen.go
new file mode 100644
index 000000000..1515c8632
--- /dev/null
+++ b/pkg/tcpip/transport/internal/network/network_state_autogen.go
@@ -0,0 +1,110 @@
+// automatically generated by stateify.
+
+package network
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+func (e *Endpoint) StateTypeName() string {
+ return "pkg/tcpip/transport/internal/network.Endpoint"
+}
+
+func (e *Endpoint) StateFields() []string {
+ return []string{
+ "ops",
+ "netProto",
+ "transProto",
+ "wasBound",
+ "owner",
+ "writeShutdown",
+ "effectiveNetProto",
+ "multicastMemberships",
+ "ttl",
+ "multicastTTL",
+ "multicastAddr",
+ "multicastNICID",
+ "ipv4TOS",
+ "ipv6TClass",
+ "info",
+ "state",
+ }
+}
+
+func (e *Endpoint) beforeSave() {}
+
+// +checklocksignore
+func (e *Endpoint) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+ stateSinkObject.Save(0, &e.ops)
+ stateSinkObject.Save(1, &e.netProto)
+ stateSinkObject.Save(2, &e.transProto)
+ stateSinkObject.Save(3, &e.wasBound)
+ stateSinkObject.Save(4, &e.owner)
+ stateSinkObject.Save(5, &e.writeShutdown)
+ stateSinkObject.Save(6, &e.effectiveNetProto)
+ stateSinkObject.Save(7, &e.multicastMemberships)
+ stateSinkObject.Save(8, &e.ttl)
+ stateSinkObject.Save(9, &e.multicastTTL)
+ stateSinkObject.Save(10, &e.multicastAddr)
+ stateSinkObject.Save(11, &e.multicastNICID)
+ stateSinkObject.Save(12, &e.ipv4TOS)
+ stateSinkObject.Save(13, &e.ipv6TClass)
+ stateSinkObject.Save(14, &e.info)
+ stateSinkObject.Save(15, &e.state)
+}
+
+func (e *Endpoint) afterLoad() {}
+
+// +checklocksignore
+func (e *Endpoint) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &e.ops)
+ stateSourceObject.Load(1, &e.netProto)
+ stateSourceObject.Load(2, &e.transProto)
+ stateSourceObject.Load(3, &e.wasBound)
+ stateSourceObject.Load(4, &e.owner)
+ stateSourceObject.Load(5, &e.writeShutdown)
+ stateSourceObject.Load(6, &e.effectiveNetProto)
+ stateSourceObject.Load(7, &e.multicastMemberships)
+ stateSourceObject.Load(8, &e.ttl)
+ stateSourceObject.Load(9, &e.multicastTTL)
+ stateSourceObject.Load(10, &e.multicastAddr)
+ stateSourceObject.Load(11, &e.multicastNICID)
+ stateSourceObject.Load(12, &e.ipv4TOS)
+ stateSourceObject.Load(13, &e.ipv6TClass)
+ stateSourceObject.Load(14, &e.info)
+ stateSourceObject.Load(15, &e.state)
+}
+
+func (m *multicastMembership) StateTypeName() string {
+ return "pkg/tcpip/transport/internal/network.multicastMembership"
+}
+
+func (m *multicastMembership) StateFields() []string {
+ return []string{
+ "nicID",
+ "multicastAddr",
+ }
+}
+
+func (m *multicastMembership) beforeSave() {}
+
+// +checklocksignore
+func (m *multicastMembership) StateSave(stateSinkObject state.Sink) {
+ m.beforeSave()
+ stateSinkObject.Save(0, &m.nicID)
+ stateSinkObject.Save(1, &m.multicastAddr)
+}
+
+func (m *multicastMembership) afterLoad() {}
+
+// +checklocksignore
+func (m *multicastMembership) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &m.nicID)
+ stateSourceObject.Load(1, &m.multicastAddr)
+}
+
+func init() {
+ state.Register((*Endpoint)(nil))
+ state.Register((*multicastMembership)(nil))
+}
diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD
deleted file mode 100644
index b989b1209..000000000
--- a/pkg/tcpip/transport/packet/BUILD
+++ /dev/null
@@ -1,37 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
-package(licenses = ["notice"])
-
-go_template_instance(
- name = "packet_list",
- out = "packet_list.go",
- package = "packet",
- prefix = "packet",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*packet",
- "Linker": "*packet",
- },
-)
-
-go_library(
- name = "packet",
- srcs = [
- "endpoint.go",
- "endpoint_state.go",
- "packet_list.go",
- ],
- imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/log",
- "//pkg/sleep",
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/stack",
- "//pkg/waiter",
- ],
-)
diff --git a/pkg/tcpip/transport/packet/packet_list.go b/pkg/tcpip/transport/packet/packet_list.go
new file mode 100644
index 000000000..2c983aad0
--- /dev/null
+++ b/pkg/tcpip/transport/packet/packet_list.go
@@ -0,0 +1,221 @@
+package packet
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type packetElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (packetElementMapper) linkerFor(elem *packet) *packet { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type packetList struct {
+ head *packet
+ tail *packet
+}
+
+// Reset resets list l to the empty state.
+func (l *packetList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+//
+//go:nosplit
+func (l *packetList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+//
+//go:nosplit
+func (l *packetList) Front() *packet {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+//
+//go:nosplit
+func (l *packetList) Back() *packet {
+ return l.tail
+}
+
+// Len returns the number of elements in the list.
+//
+// NOTE: This is an O(n) operation.
+//
+//go:nosplit
+func (l *packetList) Len() (count int) {
+ for e := l.Front(); e != nil; e = (packetElementMapper{}.linkerFor(e)).Next() {
+ count++
+ }
+ return count
+}
+
+// PushFront inserts the element e at the front of list l.
+//
+//go:nosplit
+func (l *packetList) PushFront(e *packet) {
+ linker := packetElementMapper{}.linkerFor(e)
+ linker.SetNext(l.head)
+ linker.SetPrev(nil)
+ if l.head != nil {
+ packetElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+//
+//go:nosplit
+func (l *packetList) PushBack(e *packet) {
+ linker := packetElementMapper{}.linkerFor(e)
+ linker.SetNext(nil)
+ linker.SetPrev(l.tail)
+ if l.tail != nil {
+ packetElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+//
+//go:nosplit
+func (l *packetList) PushBackList(m *packetList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ packetElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ packetElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+//
+//go:nosplit
+func (l *packetList) InsertAfter(b, e *packet) {
+ bLinker := packetElementMapper{}.linkerFor(b)
+ eLinker := packetElementMapper{}.linkerFor(e)
+
+ a := bLinker.Next()
+
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ bLinker.SetNext(e)
+
+ if a != nil {
+ packetElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+//
+//go:nosplit
+func (l *packetList) InsertBefore(a, e *packet) {
+ aLinker := packetElementMapper{}.linkerFor(a)
+ eLinker := packetElementMapper{}.linkerFor(e)
+
+ b := aLinker.Prev()
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ aLinker.SetPrev(e)
+
+ if b != nil {
+ packetElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+//
+//go:nosplit
+func (l *packetList) Remove(e *packet) {
+ linker := packetElementMapper{}.linkerFor(e)
+ prev := linker.Prev()
+ next := linker.Next()
+
+ if prev != nil {
+ packetElementMapper{}.linkerFor(prev).SetNext(next)
+ } else if l.head == e {
+ l.head = next
+ }
+
+ if next != nil {
+ packetElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else if l.tail == e {
+ l.tail = prev
+ }
+
+ linker.SetNext(nil)
+ linker.SetPrev(nil)
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type packetEntry struct {
+ next *packet
+ prev *packet
+}
+
+// Next returns the entry that follows e in the list.
+//
+//go:nosplit
+func (e *packetEntry) Next() *packet {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *packetEntry) Prev() *packet {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+//
+//go:nosplit
+func (e *packetEntry) SetNext(elem *packet) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *packetEntry) SetPrev(elem *packet) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/transport/packet/packet_state_autogen.go b/pkg/tcpip/transport/packet/packet_state_autogen.go
new file mode 100644
index 000000000..9c6623ffd
--- /dev/null
+++ b/pkg/tcpip/transport/packet/packet_state_autogen.go
@@ -0,0 +1,167 @@
+// automatically generated by stateify.
+
+package packet
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+func (p *packet) StateTypeName() string {
+ return "pkg/tcpip/transport/packet.packet"
+}
+
+func (p *packet) StateFields() []string {
+ return []string{
+ "packetEntry",
+ "data",
+ "receivedAt",
+ "senderAddr",
+ "packetInfo",
+ }
+}
+
+func (p *packet) beforeSave() {}
+
+// +checklocksignore
+func (p *packet) StateSave(stateSinkObject state.Sink) {
+ p.beforeSave()
+ var dataValue buffer.VectorisedView
+ dataValue = p.saveData()
+ stateSinkObject.SaveValue(1, dataValue)
+ var receivedAtValue int64
+ receivedAtValue = p.saveReceivedAt()
+ stateSinkObject.SaveValue(2, receivedAtValue)
+ stateSinkObject.Save(0, &p.packetEntry)
+ stateSinkObject.Save(3, &p.senderAddr)
+ stateSinkObject.Save(4, &p.packetInfo)
+}
+
+func (p *packet) afterLoad() {}
+
+// +checklocksignore
+func (p *packet) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &p.packetEntry)
+ stateSourceObject.Load(3, &p.senderAddr)
+ stateSourceObject.Load(4, &p.packetInfo)
+ stateSourceObject.LoadValue(1, new(buffer.VectorisedView), func(y interface{}) { p.loadData(y.(buffer.VectorisedView)) })
+ stateSourceObject.LoadValue(2, new(int64), func(y interface{}) { p.loadReceivedAt(y.(int64)) })
+}
+
+func (ep *endpoint) StateTypeName() string {
+ return "pkg/tcpip/transport/packet.endpoint"
+}
+
+func (ep *endpoint) StateFields() []string {
+ return []string{
+ "DefaultSocketOptionsHandler",
+ "waiterQueue",
+ "cooked",
+ "ops",
+ "rcvList",
+ "rcvBufSize",
+ "rcvClosed",
+ "rcvDisabled",
+ "closed",
+ "boundNetProto",
+ "boundNIC",
+ "lastError",
+ }
+}
+
+// +checklocksignore
+func (ep *endpoint) StateSave(stateSinkObject state.Sink) {
+ ep.beforeSave()
+ stateSinkObject.Save(0, &ep.DefaultSocketOptionsHandler)
+ stateSinkObject.Save(1, &ep.waiterQueue)
+ stateSinkObject.Save(2, &ep.cooked)
+ stateSinkObject.Save(3, &ep.ops)
+ stateSinkObject.Save(4, &ep.rcvList)
+ stateSinkObject.Save(5, &ep.rcvBufSize)
+ stateSinkObject.Save(6, &ep.rcvClosed)
+ stateSinkObject.Save(7, &ep.rcvDisabled)
+ stateSinkObject.Save(8, &ep.closed)
+ stateSinkObject.Save(9, &ep.boundNetProto)
+ stateSinkObject.Save(10, &ep.boundNIC)
+ stateSinkObject.Save(11, &ep.lastError)
+}
+
+// +checklocksignore
+func (ep *endpoint) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &ep.DefaultSocketOptionsHandler)
+ stateSourceObject.Load(1, &ep.waiterQueue)
+ stateSourceObject.Load(2, &ep.cooked)
+ stateSourceObject.Load(3, &ep.ops)
+ stateSourceObject.Load(4, &ep.rcvList)
+ stateSourceObject.Load(5, &ep.rcvBufSize)
+ stateSourceObject.Load(6, &ep.rcvClosed)
+ stateSourceObject.Load(7, &ep.rcvDisabled)
+ stateSourceObject.Load(8, &ep.closed)
+ stateSourceObject.Load(9, &ep.boundNetProto)
+ stateSourceObject.Load(10, &ep.boundNIC)
+ stateSourceObject.Load(11, &ep.lastError)
+ stateSourceObject.AfterLoad(ep.afterLoad)
+}
+
+func (l *packetList) StateTypeName() string {
+ return "pkg/tcpip/transport/packet.packetList"
+}
+
+func (l *packetList) StateFields() []string {
+ return []string{
+ "head",
+ "tail",
+ }
+}
+
+func (l *packetList) beforeSave() {}
+
+// +checklocksignore
+func (l *packetList) StateSave(stateSinkObject state.Sink) {
+ l.beforeSave()
+ stateSinkObject.Save(0, &l.head)
+ stateSinkObject.Save(1, &l.tail)
+}
+
+func (l *packetList) afterLoad() {}
+
+// +checklocksignore
+func (l *packetList) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &l.head)
+ stateSourceObject.Load(1, &l.tail)
+}
+
+func (e *packetEntry) StateTypeName() string {
+ return "pkg/tcpip/transport/packet.packetEntry"
+}
+
+func (e *packetEntry) StateFields() []string {
+ return []string{
+ "next",
+ "prev",
+ }
+}
+
+func (e *packetEntry) beforeSave() {}
+
+// +checklocksignore
+func (e *packetEntry) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+ stateSinkObject.Save(0, &e.next)
+ stateSinkObject.Save(1, &e.prev)
+}
+
+func (e *packetEntry) afterLoad() {}
+
+// +checklocksignore
+func (e *packetEntry) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &e.next)
+ stateSourceObject.Load(1, &e.prev)
+}
+
+func init() {
+ state.Register((*packet)(nil))
+ state.Register((*endpoint)(nil))
+ state.Register((*packetList)(nil))
+ state.Register((*packetEntry)(nil))
+}
diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD
deleted file mode 100644
index b7e97e218..000000000
--- a/pkg/tcpip/transport/raw/BUILD
+++ /dev/null
@@ -1,41 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
-package(licenses = ["notice"])
-
-go_template_instance(
- name = "raw_packet_list",
- out = "raw_packet_list.go",
- package = "raw",
- prefix = "rawPacket",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*rawPacket",
- "Linker": "*rawPacket",
- },
-)
-
-go_library(
- name = "raw",
- srcs = [
- "endpoint.go",
- "endpoint_state.go",
- "protocol.go",
- "raw_packet_list.go",
- ],
- imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/log",
- "//pkg/sleep",
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/transport",
- "//pkg/tcpip/transport/internal/network",
- "//pkg/tcpip/transport/packet",
- "//pkg/waiter",
- ],
-)
diff --git a/pkg/tcpip/transport/raw/raw_packet_list.go b/pkg/tcpip/transport/raw/raw_packet_list.go
new file mode 100644
index 000000000..48804ff1b
--- /dev/null
+++ b/pkg/tcpip/transport/raw/raw_packet_list.go
@@ -0,0 +1,221 @@
+package raw
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type rawPacketElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (rawPacketElementMapper) linkerFor(elem *rawPacket) *rawPacket { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type rawPacketList struct {
+ head *rawPacket
+ tail *rawPacket
+}
+
+// Reset resets list l to the empty state.
+func (l *rawPacketList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+//
+//go:nosplit
+func (l *rawPacketList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+//
+//go:nosplit
+func (l *rawPacketList) Front() *rawPacket {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+//
+//go:nosplit
+func (l *rawPacketList) Back() *rawPacket {
+ return l.tail
+}
+
+// Len returns the number of elements in the list.
+//
+// NOTE: This is an O(n) operation.
+//
+//go:nosplit
+func (l *rawPacketList) Len() (count int) {
+ for e := l.Front(); e != nil; e = (rawPacketElementMapper{}.linkerFor(e)).Next() {
+ count++
+ }
+ return count
+}
+
+// PushFront inserts the element e at the front of list l.
+//
+//go:nosplit
+func (l *rawPacketList) PushFront(e *rawPacket) {
+ linker := rawPacketElementMapper{}.linkerFor(e)
+ linker.SetNext(l.head)
+ linker.SetPrev(nil)
+ if l.head != nil {
+ rawPacketElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+//
+//go:nosplit
+func (l *rawPacketList) PushBack(e *rawPacket) {
+ linker := rawPacketElementMapper{}.linkerFor(e)
+ linker.SetNext(nil)
+ linker.SetPrev(l.tail)
+ if l.tail != nil {
+ rawPacketElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+//
+//go:nosplit
+func (l *rawPacketList) PushBackList(m *rawPacketList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ rawPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ rawPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+//
+//go:nosplit
+func (l *rawPacketList) InsertAfter(b, e *rawPacket) {
+ bLinker := rawPacketElementMapper{}.linkerFor(b)
+ eLinker := rawPacketElementMapper{}.linkerFor(e)
+
+ a := bLinker.Next()
+
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ bLinker.SetNext(e)
+
+ if a != nil {
+ rawPacketElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+//
+//go:nosplit
+func (l *rawPacketList) InsertBefore(a, e *rawPacket) {
+ aLinker := rawPacketElementMapper{}.linkerFor(a)
+ eLinker := rawPacketElementMapper{}.linkerFor(e)
+
+ b := aLinker.Prev()
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ aLinker.SetPrev(e)
+
+ if b != nil {
+ rawPacketElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+//
+//go:nosplit
+func (l *rawPacketList) Remove(e *rawPacket) {
+ linker := rawPacketElementMapper{}.linkerFor(e)
+ prev := linker.Prev()
+ next := linker.Next()
+
+ if prev != nil {
+ rawPacketElementMapper{}.linkerFor(prev).SetNext(next)
+ } else if l.head == e {
+ l.head = next
+ }
+
+ if next != nil {
+ rawPacketElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else if l.tail == e {
+ l.tail = prev
+ }
+
+ linker.SetNext(nil)
+ linker.SetPrev(nil)
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type rawPacketEntry struct {
+ next *rawPacket
+ prev *rawPacket
+}
+
+// Next returns the entry that follows e in the list.
+//
+//go:nosplit
+func (e *rawPacketEntry) Next() *rawPacket {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *rawPacketEntry) Prev() *rawPacket {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+//
+//go:nosplit
+func (e *rawPacketEntry) SetNext(elem *rawPacket) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *rawPacketEntry) SetPrev(elem *rawPacket) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/transport/raw/raw_state_autogen.go b/pkg/tcpip/transport/raw/raw_state_autogen.go
new file mode 100644
index 000000000..0de2d2264
--- /dev/null
+++ b/pkg/tcpip/transport/raw/raw_state_autogen.go
@@ -0,0 +1,158 @@
+// automatically generated by stateify.
+
+package raw
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+func (p *rawPacket) StateTypeName() string {
+ return "pkg/tcpip/transport/raw.rawPacket"
+}
+
+func (p *rawPacket) StateFields() []string {
+ return []string{
+ "rawPacketEntry",
+ "data",
+ "receivedAt",
+ "senderAddr",
+ }
+}
+
+func (p *rawPacket) beforeSave() {}
+
+// +checklocksignore
+func (p *rawPacket) StateSave(stateSinkObject state.Sink) {
+ p.beforeSave()
+ var dataValue buffer.VectorisedView
+ dataValue = p.saveData()
+ stateSinkObject.SaveValue(1, dataValue)
+ var receivedAtValue int64
+ receivedAtValue = p.saveReceivedAt()
+ stateSinkObject.SaveValue(2, receivedAtValue)
+ stateSinkObject.Save(0, &p.rawPacketEntry)
+ stateSinkObject.Save(3, &p.senderAddr)
+}
+
+func (p *rawPacket) afterLoad() {}
+
+// +checklocksignore
+func (p *rawPacket) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &p.rawPacketEntry)
+ stateSourceObject.Load(3, &p.senderAddr)
+ stateSourceObject.LoadValue(1, new(buffer.VectorisedView), func(y interface{}) { p.loadData(y.(buffer.VectorisedView)) })
+ stateSourceObject.LoadValue(2, new(int64), func(y interface{}) { p.loadReceivedAt(y.(int64)) })
+}
+
+func (e *endpoint) StateTypeName() string {
+ return "pkg/tcpip/transport/raw.endpoint"
+}
+
+func (e *endpoint) StateFields() []string {
+ return []string{
+ "DefaultSocketOptionsHandler",
+ "transProto",
+ "waiterQueue",
+ "associated",
+ "net",
+ "ops",
+ "rcvList",
+ "rcvBufSize",
+ "rcvClosed",
+ "frozen",
+ }
+}
+
+// +checklocksignore
+func (e *endpoint) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+ stateSinkObject.Save(0, &e.DefaultSocketOptionsHandler)
+ stateSinkObject.Save(1, &e.transProto)
+ stateSinkObject.Save(2, &e.waiterQueue)
+ stateSinkObject.Save(3, &e.associated)
+ stateSinkObject.Save(4, &e.net)
+ stateSinkObject.Save(5, &e.ops)
+ stateSinkObject.Save(6, &e.rcvList)
+ stateSinkObject.Save(7, &e.rcvBufSize)
+ stateSinkObject.Save(8, &e.rcvClosed)
+ stateSinkObject.Save(9, &e.frozen)
+}
+
+// +checklocksignore
+func (e *endpoint) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &e.DefaultSocketOptionsHandler)
+ stateSourceObject.Load(1, &e.transProto)
+ stateSourceObject.Load(2, &e.waiterQueue)
+ stateSourceObject.Load(3, &e.associated)
+ stateSourceObject.Load(4, &e.net)
+ stateSourceObject.Load(5, &e.ops)
+ stateSourceObject.Load(6, &e.rcvList)
+ stateSourceObject.Load(7, &e.rcvBufSize)
+ stateSourceObject.Load(8, &e.rcvClosed)
+ stateSourceObject.Load(9, &e.frozen)
+ stateSourceObject.AfterLoad(e.afterLoad)
+}
+
+func (l *rawPacketList) StateTypeName() string {
+ return "pkg/tcpip/transport/raw.rawPacketList"
+}
+
+func (l *rawPacketList) StateFields() []string {
+ return []string{
+ "head",
+ "tail",
+ }
+}
+
+func (l *rawPacketList) beforeSave() {}
+
+// +checklocksignore
+func (l *rawPacketList) StateSave(stateSinkObject state.Sink) {
+ l.beforeSave()
+ stateSinkObject.Save(0, &l.head)
+ stateSinkObject.Save(1, &l.tail)
+}
+
+func (l *rawPacketList) afterLoad() {}
+
+// +checklocksignore
+func (l *rawPacketList) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &l.head)
+ stateSourceObject.Load(1, &l.tail)
+}
+
+func (e *rawPacketEntry) StateTypeName() string {
+ return "pkg/tcpip/transport/raw.rawPacketEntry"
+}
+
+func (e *rawPacketEntry) StateFields() []string {
+ return []string{
+ "next",
+ "prev",
+ }
+}
+
+func (e *rawPacketEntry) beforeSave() {}
+
+// +checklocksignore
+func (e *rawPacketEntry) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+ stateSinkObject.Save(0, &e.next)
+ stateSinkObject.Save(1, &e.prev)
+}
+
+func (e *rawPacketEntry) afterLoad() {}
+
+// +checklocksignore
+func (e *rawPacketEntry) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &e.next)
+ stateSourceObject.Load(1, &e.prev)
+}
+
+func init() {
+ state.Register((*rawPacket)(nil))
+ state.Register((*endpoint)(nil))
+ state.Register((*rawPacketList)(nil))
+ state.Register((*rawPacketEntry)(nil))
+}
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
deleted file mode 100644
index 5148fe157..000000000
--- a/pkg/tcpip/transport/tcp/BUILD
+++ /dev/null
@@ -1,141 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test", "more_shards")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
-package(licenses = ["notice"])
-
-go_template_instance(
- name = "tcp_segment_list",
- out = "tcp_segment_list.go",
- package = "tcp",
- prefix = "segment",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*segment",
- "Linker": "*segment",
- },
-)
-
-go_template_instance(
- name = "tcp_endpoint_list",
- out = "tcp_endpoint_list.go",
- package = "tcp",
- prefix = "endpoint",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*endpoint",
- "Linker": "*endpoint",
- },
-)
-
-go_library(
- name = "tcp",
- srcs = [
- "accept.go",
- "connect.go",
- "connect_unsafe.go",
- "cubic.go",
- "dispatcher.go",
- "endpoint.go",
- "endpoint_state.go",
- "forwarder.go",
- "protocol.go",
- "rack.go",
- "rcv.go",
- "reno.go",
- "reno_recovery.go",
- "sack.go",
- "sack_recovery.go",
- "sack_scoreboard.go",
- "segment.go",
- "segment_heap.go",
- "segment_queue.go",
- "segment_state.go",
- "segment_unsafe.go",
- "snd.go",
- "tcp_endpoint_list.go",
- "tcp_segment_list.go",
- "timer.go",
- ],
- imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/log",
- "//pkg/rand",
- "//pkg/sleep",
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/hash/jenkins",
- "//pkg/tcpip/header",
- "//pkg/tcpip/header/parse",
- "//pkg/tcpip/internal/tcp",
- "//pkg/tcpip/ports",
- "//pkg/tcpip/seqnum",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/transport/raw",
- "//pkg/waiter",
- "@com_github_google_btree//:go_default_library",
- ],
-)
-
-go_test(
- name = "tcp_x_test",
- size = "medium",
- srcs = [
- "dual_stack_test.go",
- "sack_scoreboard_test.go",
- "tcp_noracedetector_test.go",
- "tcp_rack_test.go",
- "tcp_sack_test.go",
- "tcp_test.go",
- "tcp_timestamp_test.go",
- ],
- shard_count = more_shards,
- deps = [
- ":tcp",
- "//pkg/rand",
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/checker",
- "//pkg/tcpip/faketime",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/loopback",
- "//pkg/tcpip/link/sniffer",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/seqnum",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/testutil",
- "//pkg/tcpip/transport/tcp/testing/context",
- "//pkg/test/testutil",
- "//pkg/waiter",
- "@com_github_google_go_cmp//cmp:go_default_library",
- ],
-)
-
-go_test(
- name = "rcv_test",
- size = "small",
- srcs = ["rcv_test.go"],
- deps = [
- "//pkg/tcpip/header",
- "//pkg/tcpip/seqnum",
- ],
-)
-
-go_test(
- name = "tcp_test",
- size = "small",
- srcs = [
- "segment_test.go",
- "timer_test.go",
- ],
- library = ":tcp",
- deps = [
- "//pkg/sleep",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/faketime",
- "//pkg/tcpip/stack",
- "@com_github_google_go_cmp//cmp:go_default_library",
- ],
-)
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
deleted file mode 100644
index 5342aacfd..000000000
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ /dev/null
@@ -1,650 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tcp_test
-
-import (
- "strings"
- "testing"
- "time"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/seqnum"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-func TestV4MappedConnectOnV6Only(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(true)
-
- // Start connection attempt, it must fail.
- err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort})
- if d := cmp.Diff(&tcpip.ErrNoRoute{}, err); d != "" {
- t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d)
- }
-}
-
-func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.NetworkChecker) {
- // Start connection attempt.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.WritableEvents)
- defer c.WQ.EventUnregister(&we)
-
- err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort})
- if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
- t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d)
- }
-
- // Receive SYN packet.
- b := c.GetPacket()
- synCheckers := append(checkers, checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- ))
- checker.IPv4(t, b, synCheckers...)
-
- tcp := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcp.SequenceNumber())
-
- iss := seqnum.Value(789)
- c.SendPacket(nil, &context.Headers{
- SrcPort: tcp.DestinationPort(),
- DstPort: tcp.SourcePort(),
- Flags: header.TCPFlagSyn | header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Receive ACK packet.
- ackCheckers := append(checkers, checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+1),
- ))
- checker.IPv4(t, c.GetPacket(), ackCheckers...)
-
- // Wait for connection to be established.
- select {
- case <-ch:
- if err := c.EP.LastError(); err != nil {
- t.Fatalf("Unexpected error when connecting: %v", err)
- }
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for connection")
- }
-}
-
-func TestV4MappedConnect(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- // Test the connection request.
- testV4Connect(t, c)
-}
-
-func TestV4ConnectWhenBoundToWildcard(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- // Bind to wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Test the connection request.
- testV4Connect(t, c)
-}
-
-func TestV4ConnectWhenBoundToV4MappedWildcard(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- // Bind to v4 mapped wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Test the connection request.
- testV4Connect(t, c)
-}
-
-func TestV4ConnectWhenBoundToV4Mapped(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- // Bind to v4 mapped address.
- if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV4MappedAddr}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Test the connection request.
- testV4Connect(t, c)
-}
-
-func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.NetworkChecker) {
- // Start connection attempt to IPv6 address.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.WritableEvents)
- defer c.WQ.EventUnregister(&we)
-
- err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort})
- if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
- t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d)
- }
-
- // Receive SYN packet.
- b := c.GetV6Packet()
- synCheckers := append(checkers, checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- ))
- checker.IPv6(t, b, synCheckers...)
-
- tcp := header.TCP(header.IPv6(b).Payload())
- c.IRS = seqnum.Value(tcp.SequenceNumber())
-
- iss := seqnum.Value(789)
- c.SendV6Packet(nil, &context.Headers{
- SrcPort: tcp.DestinationPort(),
- DstPort: tcp.SourcePort(),
- Flags: header.TCPFlagSyn | header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Receive ACK packet.
- ackCheckers := append(checkers, checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+1),
- ))
- checker.IPv6(t, c.GetV6Packet(), ackCheckers...)
-
- // Wait for connection to be established.
- select {
- case <-ch:
- if err := c.EP.LastError(); err != nil {
- t.Fatalf("Unexpected error when connecting: %v", err)
- }
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for connection")
- }
-}
-
-func TestV6Connect(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- // Test the connection request.
- testV6Connect(t, c)
-}
-
-func TestV6ConnectV6Only(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(true)
-
- // Test the connection request.
- testV6Connect(t, c)
-}
-
-func TestV6ConnectWhenBoundToWildcard(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- // Bind to wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Test the connection request.
- testV6Connect(t, c)
-}
-
-func TestStackV6OnlyConnectWhenBoundToWildcard(t *testing.T) {
- c := context.NewWithOpts(t, context.Options{
- EnableV6: true,
- MTU: defaultMTU,
- })
- defer c.Cleanup()
-
- // Create a v6 endpoint but don't set the v6-only TCP option.
- c.CreateV6Endpoint(false)
-
- // Bind to wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Test the connection request.
- testV6Connect(t, c)
-}
-
-func TestV6ConnectWhenBoundToLocalAddress(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- // Bind to local address.
- if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV6Addr}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Test the connection request.
- testV6Connect(t, c)
-}
-
-func TestV4RefuseOnV6Only(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(true)
-
- // Bind to wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Start listening.
- if err := c.EP.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
- }
-
- // Send a SYN request.
- irs := seqnum.Value(789)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- })
-
- // Receive the RST reply.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
- checker.TCPAckNum(uint32(irs)+1),
- ),
- )
-}
-
-func TestV6RefuseOnBoundToV4Mapped(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- // Bind and listen.
- if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr, Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- if err := c.EP.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
- }
-
- // Send a SYN request.
- irs := seqnum.Value(789)
- c.SendV6Packet(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- })
-
- // Receive the RST reply.
- checker.IPv6(t, c.GetV6Packet(),
- checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
- checker.TCPAckNum(uint32(irs)+1),
- ),
- )
-}
-
-func testV4Accept(t *testing.T, c *context.Context) {
- c.SetGSOEnabled(true)
- defer c.SetGSOEnabled(false)
-
- // Start listening.
- if err := c.EP.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
- }
-
- // Send a SYN request.
- irs := seqnum.Value(789)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetPacket()
- tcp := header.TCP(header.IPv4(b).Payload())
- iss := seqnum.Value(tcp.SequenceNumber())
- checker.IPv4(t, b,
- checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
- checker.TCPAckNum(uint32(irs)+1),
- ),
- )
-
- // Send ACK.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs + 1,
- AckNum: iss + 1,
- RcvWnd: 30000,
- })
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
-
- nep, _, err := c.EP.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- nep, _, err = c.EP.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %v", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- // Check the peer address.
- addr, err := nep.GetRemoteAddress()
- if err != nil {
- t.Fatalf("GetRemoteAddress failed failed: %v", err)
- }
-
- if addr.Addr != context.TestAddr {
- t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestAddr)
- }
-
- var r strings.Reader
- data := "Don't panic"
- r.Reset(data)
- nep.Write(&r, tcpip.WriteOptions{})
- b = c.GetPacket()
- tcp = header.IPv4(b).Payload()
- if string(tcp.Payload()) != data {
- t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data)
- }
-}
-
-func TestV4AcceptOnV6(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- // Bind to wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Test acceptance.
- testV4Accept(t, c)
-}
-
-func TestV4AcceptOnBoundToV4MappedWildcard(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- // Bind to v4 mapped wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr, Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Test acceptance.
- testV4Accept(t, c)
-}
-
-func TestV4AcceptOnBoundToV4Mapped(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- // Bind and listen.
- if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV4MappedAddr, Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Test acceptance.
- testV4Accept(t, c)
-}
-
-func TestV6AcceptOnV6(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- // Bind and listen.
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- if err := c.EP.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
- }
-
- // Send a SYN request.
- irs := seqnum.Value(789)
- c.SendV6Packet(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetV6Packet()
- tcp := header.TCP(header.IPv6(b).Payload())
- iss := seqnum.Value(tcp.SequenceNumber())
- checker.IPv6(t, b,
- checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
- checker.TCPAckNum(uint32(irs)+1),
- ),
- )
-
- // Send ACK.
- c.SendV6Packet(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs + 1,
- AckNum: iss + 1,
- RcvWnd: 30000,
- })
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
- var addr tcpip.FullAddress
- _, _, err := c.EP.Accept(&addr)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- _, _, err = c.EP.Accept(&addr)
- if err != nil {
- t.Fatalf("Accept failed: %v", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- if addr.Addr != context.TestV6Addr {
- t.Errorf("Unexpected remote address: got %s, want %s", addr.Addr, context.TestV6Addr)
- }
-}
-
-func TestV4AcceptOnV4(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create TCP endpoint.
- var err tcpip.Error
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- // Bind to wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Test acceptance.
- testV4Accept(t, c)
-}
-
-func testV4ListenClose(t *testing.T, c *context.Context) {
- opt := tcpip.TCPAlwaysUseSynCookies(true)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
- }
-
- const n = 32
-
- // Start listening.
- if err := c.EP.Listen(n); err != nil {
- t.Fatalf("Listen failed: %v", err)
- }
-
- irs := seqnum.Value(789)
- for i := uint16(0); i < n; i++ {
- // Send a SYN request.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort + i,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- })
- }
-
- // Each of these ACKs will cause a syn-cookie based connection to be
- // accepted and delivered to the listening endpoint.
- for i := 0; i < n; i++ {
- b := c.GetPacket()
- tcp := header.TCP(header.IPv4(b).Payload())
- iss := seqnum.Value(tcp.SequenceNumber())
- // Send ACK.
- c.SendPacket(nil, &context.Headers{
- SrcPort: tcp.DestinationPort(),
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs + 1,
- AckNum: iss + 1,
- RcvWnd: 30000,
- })
- }
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
- nep, _, err := c.EP.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- nep, _, err = c.EP.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %v", err)
- }
-
- case <-time.After(10 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
- nep.Close()
- c.EP.Close()
-}
-
-func TestV4ListenCloseOnV4(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create TCP endpoint.
- var err tcpip.Error
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- // Bind to wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Test acceptance.
- testV4ListenClose(t, c)
-}
diff --git a/pkg/tcpip/transport/tcp/rcv_test.go b/pkg/tcpip/transport/tcp/rcv_test.go
deleted file mode 100644
index 8a026ec46..000000000
--- a/pkg/tcpip/transport/tcp/rcv_test.go
+++ /dev/null
@@ -1,74 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package rcv_test
-
-import (
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/seqnum"
-)
-
-func TestAcceptable(t *testing.T) {
- for _, tt := range []struct {
- segSeq seqnum.Value
- segLen seqnum.Size
- rcvNxt, rcvAcc seqnum.Value
- want bool
- }{
- // The segment is smaller than the window.
- {105, 2, 100, 104, false},
- {105, 2, 101, 105, true},
- {105, 2, 102, 106, true},
- {105, 2, 103, 107, true},
- {105, 2, 104, 108, true},
- {105, 2, 105, 109, true},
- {105, 2, 106, 110, true},
- {105, 2, 107, 111, false},
-
- // The segment is larger than the window.
- {105, 4, 103, 105, true},
- {105, 4, 104, 106, true},
- {105, 4, 105, 107, true},
- {105, 4, 106, 108, true},
- {105, 4, 107, 109, true},
- {105, 4, 108, 110, true},
- {105, 4, 109, 111, false},
- {105, 4, 110, 112, false},
-
- // The segment has no width.
- {105, 0, 100, 102, false},
- {105, 0, 101, 103, false},
- {105, 0, 102, 104, false},
- {105, 0, 103, 105, true},
- {105, 0, 104, 106, true},
- {105, 0, 105, 107, true},
- {105, 0, 106, 108, false},
- {105, 0, 107, 109, false},
-
- // The receive window has no width.
- {105, 2, 103, 103, false},
- {105, 2, 104, 104, false},
- {105, 2, 105, 105, false},
- {105, 2, 106, 106, false},
- {105, 2, 107, 107, false},
- {105, 2, 108, 108, false},
- {105, 2, 109, 109, false},
- } {
- if got := header.Acceptable(tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc); got != tt.want {
- t.Errorf("header.Acceptable(%d, %d, %d, %d) = %t, want %t", tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc, got, tt.want)
- }
- }
-}
diff --git a/pkg/tcpip/transport/tcp/sack_scoreboard_test.go b/pkg/tcpip/transport/tcp/sack_scoreboard_test.go
deleted file mode 100644
index b4e5ba0df..000000000
--- a/pkg/tcpip/transport/tcp/sack_scoreboard_test.go
+++ /dev/null
@@ -1,249 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tcp_test
-
-import (
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/seqnum"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
-)
-
-const smss = 1500
-
-func initScoreboard(blocks []header.SACKBlock, iss seqnum.Value) *tcp.SACKScoreboard {
- s := tcp.NewSACKScoreboard(smss, iss)
- for _, blk := range blocks {
- s.Insert(blk)
- }
- return s
-}
-
-func TestSACKScoreboardIsSACKED(t *testing.T) {
- type blockTest struct {
- block header.SACKBlock
- sacked bool
- }
- testCases := []struct {
- comment string
- scoreboardBlocks []header.SACKBlock
- blockTests []blockTest
- iss seqnum.Value
- }{
- {
- "Test holes and unsacked SACK blocks in SACKed ranges and insertion of overlapping SACK blocks",
- []header.SACKBlock{{10, 20}, {10, 30}, {30, 40}, {41, 50}, {5, 10}, {1, 50}, {111, 120}, {101, 110}, {52, 120}},
- []blockTest{
- {header.SACKBlock{15, 21}, true},
- {header.SACKBlock{200, 201}, false},
- {header.SACKBlock{50, 51}, false},
- {header.SACKBlock{53, 120}, true},
- },
- 0,
- },
- {
- "Test disjoint SACKBlocks",
- []header.SACKBlock{{2288624809, 2288810057}, {2288811477, 2288838565}},
- []blockTest{
- {header.SACKBlock{2288624809, 2288810057}, true},
- {header.SACKBlock{2288811477, 2288838565}, true},
- {header.SACKBlock{2288810057, 2288811477}, false},
- },
- 2288624809,
- },
- {
- "Test sequence number wrap around",
- []header.SACKBlock{{4294254144, 225652}, {5340409, 5350509}},
- []blockTest{
- {header.SACKBlock{4294254144, 4294254145}, true},
- {header.SACKBlock{4294254143, 4294254144}, false},
- {header.SACKBlock{4294254144, 1}, true},
- {header.SACKBlock{225652, 5350509}, false},
- {header.SACKBlock{5340409, 5350509}, true},
- {header.SACKBlock{5350509, 5350609}, false},
- },
- 4294254144,
- },
- {
- "Test disjoint SACKBlocks out of order",
- []header.SACKBlock{{827450276, 827454536}, {827426028, 827428868}},
- []blockTest{
- {header.SACKBlock{827426028, 827428867}, true},
- {header.SACKBlock{827450168, 827450275}, false},
- },
- 827426000,
- },
- }
- for _, tc := range testCases {
- sb := initScoreboard(tc.scoreboardBlocks, tc.iss)
- for _, blkTest := range tc.blockTests {
- if want, got := blkTest.sacked, sb.IsSACKED(blkTest.block); got != want {
- t.Errorf("%s: s.IsSACKED(%v) = %v, want %v", tc.comment, blkTest.block, got, want)
- }
- }
- }
-}
-
-func TestSACKScoreboardIsRangeLost(t *testing.T) {
- s := tcp.NewSACKScoreboard(10, 0)
- s.Insert(header.SACKBlock{1, 25})
- s.Insert(header.SACKBlock{25, 50})
- s.Insert(header.SACKBlock{51, 100})
- s.Insert(header.SACKBlock{111, 120})
- s.Insert(header.SACKBlock{101, 110})
- s.Insert(header.SACKBlock{121, 141})
- s.Insert(header.SACKBlock{145, 146})
- s.Insert(header.SACKBlock{147, 148})
- s.Insert(header.SACKBlock{149, 150})
- s.Insert(header.SACKBlock{153, 154})
- s.Insert(header.SACKBlock{155, 156})
- testCases := []struct {
- block header.SACKBlock
- lost bool
- }{
- // Block not covered by SACK block and has more than
- // nDupAckThreshold discontiguous SACK blocks after it as well
- // as (nDupAckThreshold -1) * 10 (smss) bytes that have been
- // SACKED above the sequence number covered by this block.
- {block: header.SACKBlock{0, 1}, lost: true},
-
- // These blocks have all been SACKed and should not be
- // considered lost.
- {block: header.SACKBlock{1, 2}, lost: false},
- {block: header.SACKBlock{25, 26}, lost: false},
- {block: header.SACKBlock{1, 45}, lost: false},
-
- // Same as the first case above.
- {block: header.SACKBlock{50, 51}, lost: true},
-
- // This block has been SACKed and should not be considered lost.
- {block: header.SACKBlock{119, 120}, lost: false},
-
- // This one should return true because there are >
- // (nDupAckThreshold - 1) * 10 (smss) bytes that have been
- // sacked above this sequence number.
- {block: header.SACKBlock{120, 121}, lost: true},
-
- // This block has been SACKed and should not be considered lost.
- {block: header.SACKBlock{125, 126}, lost: false},
-
- // This block has not been SACKed and there are nDupAckThreshold
- // number of SACKed blocks after it.
- {block: header.SACKBlock{141, 145}, lost: true},
-
- // This block has not been SACKed and there are less than
- // nDupAckThreshold SACKed sequences after it.
- {block: header.SACKBlock{151, 152}, lost: false},
- }
- for _, tc := range testCases {
- if want, got := tc.lost, s.IsRangeLost(tc.block); got != want {
- t.Errorf("s.IsRangeLost(%v) = %v, want %v", tc.block, got, want)
- }
- }
-}
-
-func TestSACKScoreboardIsLost(t *testing.T) {
- s := tcp.NewSACKScoreboard(10, 0)
- s.Insert(header.SACKBlock{1, 25})
- s.Insert(header.SACKBlock{25, 50})
- s.Insert(header.SACKBlock{51, 100})
- s.Insert(header.SACKBlock{111, 120})
- s.Insert(header.SACKBlock{101, 110})
- s.Insert(header.SACKBlock{121, 141})
- s.Insert(header.SACKBlock{121, 141})
- s.Insert(header.SACKBlock{145, 146})
- s.Insert(header.SACKBlock{147, 148})
- s.Insert(header.SACKBlock{149, 150})
- s.Insert(header.SACKBlock{153, 154})
- s.Insert(header.SACKBlock{155, 156})
- testCases := []struct {
- seq seqnum.Value
- lost bool
- }{
- // Sequence number not covered by SACK block and has more than
- // nDupAckThreshold discontiguous SACK blocks after it as well
- // as (nDupAckThreshold -1) * 10 (smss) bytes that have been
- // SACKED above the sequence number.
- {seq: 0, lost: true},
-
- // These sequence numbers have all been SACKed and should not be
- // considered lost.
- {seq: 1, lost: false},
- {seq: 25, lost: false},
- {seq: 45, lost: false},
-
- // Same as first case above.
- {seq: 50, lost: true},
-
- // This block has been SACKed and should not be considered lost.
- {seq: 119, lost: false},
-
- // This one should return true because there are >
- // (nDupAckThreshold - 1) * 10 (smss) bytes that have been
- // sacked above this sequence number.
- {seq: 120, lost: true},
-
- // This sequence number has been SACKed and should not be
- // considered lost.
- {seq: 125, lost: false},
-
- // This sequence number has not been SACKed and there are
- // nDupAckThreshold number of SACKed blocks after it.
- {seq: 141, lost: true},
-
- // This sequence number has not been SACKed and there are less
- // than nDupAckThreshold SACKed sequences after it.
- {seq: 151, lost: false},
- }
- for _, tc := range testCases {
- if want, got := tc.lost, s.IsLost(tc.seq); got != want {
- t.Errorf("s.IsLost(%v) = %v, want %v", tc.seq, got, want)
- }
- }
-}
-
-func TestSACKScoreboardDelete(t *testing.T) {
- blocks := []header.SACKBlock{{4294254144, 225652}, {5340409, 5350509}}
- s := initScoreboard(blocks, 4294254143)
- s.Delete(5340408)
- if s.Empty() {
- t.Fatalf("s.Empty() = true, want false")
- }
- if got, want := s.Sacked(), blocks[1].Start.Size(blocks[1].End); got != want {
- t.Fatalf("incorrect sacked bytes in scoreboard got: %v, want: %v", got, want)
- }
- s.Delete(5340410)
- if s.Empty() {
- t.Fatal("s.Empty() = true, want false")
- }
- newSB := header.SACKBlock{5340410, 5350509}
- if !s.IsSACKED(newSB) {
- t.Fatalf("s.IsSACKED(%v) = false, want true, scoreboard: %v", newSB, s)
- }
- s.Delete(5350509)
- lastOctet := header.SACKBlock{5350508, 5350509}
- if s.IsSACKED(lastOctet) {
- t.Fatalf("s.IsSACKED(%v) = false, want true", lastOctet)
- }
-
- s.Delete(5350510)
- if !s.Empty() {
- t.Fatal("s.Empty() = false, want true")
- }
- if got, want := s.Sacked(), seqnum.Size(0); got != want {
- t.Fatalf("incorrect sacked bytes in scoreboard got: %v, want: %v", got, want)
- }
-}
diff --git a/pkg/tcpip/transport/tcp/segment_test.go b/pkg/tcpip/transport/tcp/segment_test.go
deleted file mode 100644
index 2d5fdda19..000000000
--- a/pkg/tcpip/transport/tcp/segment_test.go
+++ /dev/null
@@ -1,69 +0,0 @@
-// Copyright 2021 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tcp
-
-import (
- "testing"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
-)
-
-type segmentSizeWants struct {
- DataSize int
- SegMemSize int
-}
-
-func checkSegmentSize(t *testing.T, name string, seg *segment, want segmentSizeWants) {
- t.Helper()
- got := segmentSizeWants{
- DataSize: seg.data.Size(),
- SegMemSize: seg.segMemSize(),
- }
- if diff := cmp.Diff(want, got); diff != "" {
- t.Errorf("%s differs (-want +got):\n%s", name, diff)
- }
-}
-
-func TestSegmentMerge(t *testing.T) {
- var clock faketime.NullClock
- id := stack.TransportEndpointID{}
- seg1 := newOutgoingSegment(id, &clock, buffer.NewView(10))
- defer seg1.decRef()
- seg2 := newOutgoingSegment(id, &clock, buffer.NewView(20))
- defer seg2.decRef()
-
- checkSegmentSize(t, "seg1", seg1, segmentSizeWants{
- DataSize: 10,
- SegMemSize: SegSize + 10,
- })
- checkSegmentSize(t, "seg2", seg2, segmentSizeWants{
- DataSize: 20,
- SegMemSize: SegSize + 20,
- })
-
- seg1.merge(seg2)
-
- checkSegmentSize(t, "seg1", seg1, segmentSizeWants{
- DataSize: 30,
- SegMemSize: SegSize + 30,
- })
- checkSegmentSize(t, "seg2", seg2, segmentSizeWants{
- DataSize: 0,
- SegMemSize: SegSize,
- })
-}
diff --git a/pkg/tcpip/transport/tcp/tcp_endpoint_list.go b/pkg/tcpip/transport/tcp/tcp_endpoint_list.go
new file mode 100644
index 000000000..a7dc5df81
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/tcp_endpoint_list.go
@@ -0,0 +1,221 @@
+package tcp
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type endpointElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (endpointElementMapper) linkerFor(elem *endpoint) *endpoint { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type endpointList struct {
+ head *endpoint
+ tail *endpoint
+}
+
+// Reset resets list l to the empty state.
+func (l *endpointList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+//
+//go:nosplit
+func (l *endpointList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+//
+//go:nosplit
+func (l *endpointList) Front() *endpoint {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+//
+//go:nosplit
+func (l *endpointList) Back() *endpoint {
+ return l.tail
+}
+
+// Len returns the number of elements in the list.
+//
+// NOTE: This is an O(n) operation.
+//
+//go:nosplit
+func (l *endpointList) Len() (count int) {
+ for e := l.Front(); e != nil; e = (endpointElementMapper{}.linkerFor(e)).Next() {
+ count++
+ }
+ return count
+}
+
+// PushFront inserts the element e at the front of list l.
+//
+//go:nosplit
+func (l *endpointList) PushFront(e *endpoint) {
+ linker := endpointElementMapper{}.linkerFor(e)
+ linker.SetNext(l.head)
+ linker.SetPrev(nil)
+ if l.head != nil {
+ endpointElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+//
+//go:nosplit
+func (l *endpointList) PushBack(e *endpoint) {
+ linker := endpointElementMapper{}.linkerFor(e)
+ linker.SetNext(nil)
+ linker.SetPrev(l.tail)
+ if l.tail != nil {
+ endpointElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+//
+//go:nosplit
+func (l *endpointList) PushBackList(m *endpointList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ endpointElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ endpointElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+//
+//go:nosplit
+func (l *endpointList) InsertAfter(b, e *endpoint) {
+ bLinker := endpointElementMapper{}.linkerFor(b)
+ eLinker := endpointElementMapper{}.linkerFor(e)
+
+ a := bLinker.Next()
+
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ bLinker.SetNext(e)
+
+ if a != nil {
+ endpointElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+//
+//go:nosplit
+func (l *endpointList) InsertBefore(a, e *endpoint) {
+ aLinker := endpointElementMapper{}.linkerFor(a)
+ eLinker := endpointElementMapper{}.linkerFor(e)
+
+ b := aLinker.Prev()
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ aLinker.SetPrev(e)
+
+ if b != nil {
+ endpointElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+//
+//go:nosplit
+func (l *endpointList) Remove(e *endpoint) {
+ linker := endpointElementMapper{}.linkerFor(e)
+ prev := linker.Prev()
+ next := linker.Next()
+
+ if prev != nil {
+ endpointElementMapper{}.linkerFor(prev).SetNext(next)
+ } else if l.head == e {
+ l.head = next
+ }
+
+ if next != nil {
+ endpointElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else if l.tail == e {
+ l.tail = prev
+ }
+
+ linker.SetNext(nil)
+ linker.SetPrev(nil)
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type endpointEntry struct {
+ next *endpoint
+ prev *endpoint
+}
+
+// Next returns the entry that follows e in the list.
+//
+//go:nosplit
+func (e *endpointEntry) Next() *endpoint {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *endpointEntry) Prev() *endpoint {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+//
+//go:nosplit
+func (e *endpointEntry) SetNext(elem *endpoint) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *endpointEntry) SetPrev(elem *endpoint) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
deleted file mode 100644
index 84fb1c416..000000000
--- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
+++ /dev/null
@@ -1,559 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-// These tests are flaky when run under the go race detector due to some
-// iterations taking long enough that the retransmit timer can kick in causing
-// the congestion window measurements to fail due to extra packets etc.
-//
-//go:build !race
-// +build !race
-
-package tcp_test
-
-import (
- "bytes"
- "fmt"
- "math"
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
- "gvisor.dev/gvisor/pkg/test/testutil"
-)
-
-func TestFastRecovery(t *testing.T) {
- maxPayload := 32
- c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- const iterations = 3
- data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1)))
- for i := range data {
- data[i] = byte(i)
- }
-
- // Write all the data in one shot. Packets will only be written at the
- // MTU size though.
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Do slow start for a few iterations.
- expected := tcp.InitialCwnd
- bytesRead := 0
- for i := 0; i < iterations; i++ {
- expected = tcp.InitialCwnd << uint(i)
- if i > 0 {
- // Acknowledge all the data received so far if not on
- // first iteration.
- c.SendAck(790, bytesRead)
- }
-
- // Read all packets expected on this iteration. Don't
- // acknowledge any of them just yet, so that we can measure the
- // congestion window.
- for j := 0; j < expected; j++ {
- c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
- bytesRead += maxPayload
- }
-
- // Check we don't receive any more packets on this iteration.
- // The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
- }
-
- // Send 3 duplicate acks. This should force an immediate retransmit of
- // the pending packet and put the sender into fast recovery.
- rtxOffset := bytesRead - maxPayload*expected
- for i := 0; i < 3; i++ {
- c.SendAck(790, rtxOffset)
- }
-
- // Receive the retransmitted packet.
- c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
-
- // Wait before checking metrics.
- metricPollFn := func() error {
- if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want)
- }
- if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.Retransmit.Value = %d, want = %d", got, want)
- }
-
- if got, want := c.Stack().Stats().TCP.FastRecovery.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.FastRecovery.Value = %d, want = %d", got, want)
- }
- return nil
- }
-
- if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
- t.Error(err)
- }
-
- // Now send 7 mode duplicate acks. Each of these should cause a window
- // inflation by 1 and cause the sender to send an extra packet.
- for i := 0; i < 7; i++ {
- c.SendAck(790, rtxOffset)
- }
-
- recover := bytesRead
-
- // Ensure no new packets arrive.
- c.CheckNoPacketTimeout("More packets received than expected during recovery after dupacks for this cwnd.",
- 50*time.Millisecond)
-
- // Acknowledge half of the pending data.
- rtxOffset = bytesRead - expected*maxPayload/2
- c.SendAck(790, rtxOffset)
-
- // Receive the retransmit due to partial ack.
- c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
-
- // Wait before checking metrics.
- metricPollFn = func() error {
- if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(2); got != want {
- return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want)
- }
- if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(2); got != want {
- return fmt.Errorf("got stats.TCP.Retransmit.Value = %d, want = %d", got, want)
- }
- return nil
- }
- if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
- t.Error(err)
- }
-
- // Receive the 10 extra packets that should have been released due to
- // the congestion window inflation in recovery.
- for i := 0; i < 10; i++ {
- c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
- bytesRead += maxPayload
- }
-
- // A partial ACK during recovery should reduce congestion window by the
- // number acked. Since we had "expected" packets outstanding before sending
- // partial ack and we acked expected/2 , the cwnd and outstanding should
- // be expected/2 + 10 (7 dupAcks + 3 for the original 3 dupacks that triggered
- // fast recovery). Which means the sender should not send any more packets
- // till we ack this one.
- c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.",
- 50*time.Millisecond)
-
- // Acknowledge all pending data to recover point.
- c.SendAck(790, recover)
-
- // At this point, the cwnd should reset to expected/2 and there are 10
- // packets outstanding.
- //
- // NOTE: Technically netstack is incorrect in that we adjust the cwnd on
- // the same segment that takes us out of recovery. But because of that
- // the actual cwnd at exit of recovery will be expected/2 + 1 as we
- // acked a cwnd worth of packets which will increase the cwnd further by
- // 1 in congestion avoidance.
- //
- // Now in the first iteration since there are 10 packets outstanding.
- // We would expect to get expected/2 +1 - 10 packets. But subsequent
- // iterations will send us expected/2 + 1 + 1 (per iteration).
- expected = expected/2 + 1 - 10
- for i := 0; i < iterations; i++ {
- // Read all packets expected on this iteration. Don't
- // acknowledge any of them just yet, so that we can measure the
- // congestion window.
- for j := 0; j < expected; j++ {
- c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
- bytesRead += maxPayload
- }
-
- // Check we don't receive any more packets on this iteration.
- // The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout(fmt.Sprintf("More packets received(after deflation) than expected %d for this cwnd.", expected), 50*time.Millisecond)
-
- // Acknowledge all the data received so far.
- c.SendAck(790, bytesRead)
-
- // In cogestion avoidance, the packets trains increase by 1 in
- // each iteration.
- if i == 0 {
- // After the first iteration we expect to get the full
- // congestion window worth of packets in every
- // iteration.
- expected += 10
- }
- expected++
- }
-}
-
-func TestExponentialIncreaseDuringSlowStart(t *testing.T) {
- maxPayload := 32
- c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- const iterations = 3
- data := make([]byte, maxPayload*(tcp.InitialCwnd<<(iterations+1)))
- for i := range data {
- data[i] = byte(i)
- }
-
- // Write all the data in one shot. Packets will only be written at the
- // MTU size though.
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- expected := tcp.InitialCwnd
- bytesRead := 0
- for i := 0; i < iterations; i++ {
- // Read all packets expected on this iteration. Don't
- // acknowledge any of them just yet, so that we can measure the
- // congestion window.
- for j := 0; j < expected; j++ {
- c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
- bytesRead += maxPayload
- }
-
- // Check we don't receive any more packets on this iteration.
- // The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
-
- // Acknowledge all the data received so far.
- c.SendAck(790, bytesRead)
-
- // Double the number of expected packets for the next iteration.
- expected *= 2
- }
-}
-
-func TestCongestionAvoidance(t *testing.T) {
- maxPayload := 32
- c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- const iterations = 3
- data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1)))
- for i := range data {
- data[i] = byte(i)
- }
-
- // Write all the data in one shot. Packets will only be written at the
- // MTU size though.
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Do slow start for a few iterations.
- expected := tcp.InitialCwnd
- bytesRead := 0
- for i := 0; i < iterations; i++ {
- expected = tcp.InitialCwnd << uint(i)
- if i > 0 {
- // Acknowledge all the data received so far if not on
- // first iteration.
- c.SendAck(790, bytesRead)
- }
-
- // Read all packets expected on this iteration. Don't
- // acknowledge any of them just yet, so that we can measure the
- // congestion window.
- for j := 0; j < expected; j++ {
- c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
- bytesRead += maxPayload
- }
-
- // Check we don't receive any more packets on this iteration.
- // The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd (slow start phase).", 50*time.Millisecond)
- }
-
- // Don't acknowledge the first packet of the last packet train. Let's
- // wait for them to time out, which will trigger a restart of slow
- // start, and initialization of ssthresh to cwnd/2.
- rtxOffset := bytesRead - maxPayload*expected
- c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
-
- // Acknowledge all the data received so far.
- c.SendAck(790, bytesRead)
-
- // This part is tricky: when the timeout happened, we had "expected"
- // packets pending, cwnd reset to 1, and ssthresh set to expected/2.
- // By acknowledging "expected" packets, the slow-start part will
- // increase cwnd to expected/2 (which "consumes" expected/2-1 of the
- // acknowledgements), then the congestion avoidance part will consume
- // an extra expected/2 acks to take cwnd to expected/2 + 1. One ack
- // remains in the "ack count" (which will cause cwnd to be incremented
- // once it reaches cwnd acks).
- //
- // So we're straight into congestion avoidance with cwnd set to
- // expected/2 + 1.
- //
- // Check that packets trains of cwnd packets are sent, and that cwnd is
- // incremented by 1 after we acknowledge each packet.
- expected = expected/2 + 1
- for i := 0; i < iterations; i++ {
- // Read all packets expected on this iteration. Don't
- // acknowledge any of them just yet, so that we can measure the
- // congestion window.
- for j := 0; j < expected; j++ {
- c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
- bytesRead += maxPayload
- }
-
- // Check we don't receive any more packets on this iteration.
- // The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd (congestion avoidance phase).", 50*time.Millisecond)
-
- // Acknowledge all the data received so far.
- c.SendAck(790, bytesRead)
-
- // In cogestion avoidance, the packets trains increase by 1 in
- // each iteration.
- expected++
- }
-}
-
-// cubicCwnd returns an estimate of a cubic window given the
-// originalCwnd, wMax, last congestion event time and sRTT.
-func cubicCwnd(origCwnd int, wMax int, congEventTime time.Time, sRTT time.Duration) int {
- cwnd := float64(origCwnd)
- // We wait 50ms between each iteration so sRTT as computed by cubic
- // should be close to 50ms.
- elapsed := (time.Since(congEventTime) + sRTT).Seconds()
- k := math.Cbrt(float64(wMax) * 0.3 / 0.7)
- wtRTT := 0.4*math.Pow(elapsed-k, 3) + float64(wMax)
- cwnd += (wtRTT - cwnd) / cwnd
- return int(cwnd)
-}
-
-func TestCubicCongestionAvoidance(t *testing.T) {
- maxPayload := 32
- c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
- defer c.Cleanup()
-
- enableCUBIC(t, c)
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- const iterations = 3
- data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1)))
- for i := range data {
- data[i] = byte(i)
- }
-
- // Write all the data in one shot. Packets will only be written at the
- // MTU size though.
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Do slow start for a few iterations.
- expected := tcp.InitialCwnd
- bytesRead := 0
- for i := 0; i < iterations; i++ {
- expected = tcp.InitialCwnd << uint(i)
- if i > 0 {
- // Acknowledge all the data received so far if not on
- // first iteration.
- c.SendAck(790, bytesRead)
- }
-
- // Read all packets expected on this iteration. Don't
- // acknowledge any of them just yet, so that we can measure the
- // congestion window.
- for j := 0; j < expected; j++ {
- c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
- bytesRead += maxPayload
- }
-
- // Check we don't receive any more packets on this iteration.
- // The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd (during slow-start phase).", 50*time.Millisecond)
- }
-
- // Don't acknowledge the first packet of the last packet train. Let's
- // wait for them to time out, which will trigger a restart of slow
- // start, and initialization of ssthresh to cwnd * 0.7.
- rtxOffset := bytesRead - maxPayload*expected
- c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
-
- // Acknowledge all pending data.
- c.SendAck(790, bytesRead)
-
- // Store away the time we sent the ACK and assuming a 200ms RTO
- // we estimate that the sender will have an RTO 200ms from now
- // and go back into slow start.
- packetDropTime := time.Now().Add(200 * time.Millisecond)
-
- // This part is tricky: when the timeout happened, we had "expected"
- // packets pending, cwnd reset to 1, and ssthresh set to expected * 0.7.
- // By acknowledging "expected" packets, the slow-start part will
- // increase cwnd to expected/2 essentially putting the connection
- // straight into congestion avoidance.
- wMax := expected
- // Lower expected as per cubic spec after a congestion event.
- expected = int(float64(expected) * 0.7)
- cwnd := expected
- for i := 0; i < iterations; i++ {
- // Cubic grows window independent of ACKs. Cubic Window growth
- // is a function of time elapsed since last congestion event.
- // As a result the congestion window does not grow
- // deterministically in response to ACKs.
- //
- // We need to roughly estimate what the cwnd of the sender is
- // based on when we sent the dupacks.
- cwnd := cubicCwnd(cwnd, wMax, packetDropTime, 50*time.Millisecond)
-
- packetsExpected := cwnd
- for j := 0; j < packetsExpected; j++ {
- c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
- bytesRead += maxPayload
- }
- t.Logf("expected packets received, next trying to receive any extra packets that may come")
-
- // If our estimate was correct there should be no more pending packets.
- // We attempt to read a packet a few times with a short sleep in between
- // to ensure that we don't see the sender send any unexpected packets.
- unexpectedPackets := 0
- for {
- gotPacket := c.ReceiveNonBlockingAndCheckPacket(data, bytesRead, maxPayload)
- if !gotPacket {
- break
- }
- bytesRead += maxPayload
- unexpectedPackets++
- time.Sleep(1 * time.Millisecond)
- }
- if unexpectedPackets != 0 {
- t.Fatalf("received %d unexpected packets for iteration %d", unexpectedPackets, i)
- }
- // Check we don't receive any more packets on this iteration.
- // The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd(congestion avoidance)", 5*time.Millisecond)
-
- // Acknowledge all the data received so far.
- c.SendAck(790, bytesRead)
- }
-}
-
-func TestRetransmit(t *testing.T) {
- maxPayload := 32
- c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- const iterations = 3
- data := make([]byte, maxPayload*(tcp.InitialCwnd<<(iterations+1)))
- for i := range data {
- data[i] = byte(i)
- }
-
- // Write all the data in two shots. Packets will only be written at the
- // MTU size though.
- var r bytes.Reader
- r.Reset(data[:len(data)/2])
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
- r.Reset(data[len(data)/2:])
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Do slow start for a few iterations.
- expected := tcp.InitialCwnd
- bytesRead := 0
- for i := 0; i < iterations; i++ {
- expected = tcp.InitialCwnd << uint(i)
- if i > 0 {
- // Acknowledge all the data received so far if not on
- // first iteration.
- c.SendAck(790, bytesRead)
- }
-
- // Read all packets expected on this iteration. Don't
- // acknowledge any of them just yet, so that we can measure the
- // congestion window.
- for j := 0; j < expected; j++ {
- c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
- bytesRead += maxPayload
- }
-
- // Check we don't receive any more packets on this iteration.
- // The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
- }
-
- // Wait for a timeout and retransmit.
- rtxOffset := bytesRead - maxPayload*expected
- c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
-
- metricPollFn := func() error {
- if got, want := c.Stack().Stats().TCP.Timeouts.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.Timeouts.Value = %d, want = %d", got, want)
- }
-
- if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.Retransmits.Value = %d, want = %d", got, want)
- }
-
- if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Timeouts.Value(), uint64(1); got != want {
- return fmt.Errorf("got EP SendErrors.Timeouts.Value = %d, want = %d", got, want)
- }
-
- if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(1); got != want {
- return fmt.Errorf("got EP stats SendErrors.Retransmits.Value = %d, want = %d", got, want)
- }
-
- if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.SlowStartRetransmits.Value = %d, want = %d", got, want)
- }
-
- return nil
- }
-
- // Poll when checking metrics.
- if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
- t.Error(err)
- }
-
- // Acknowledge half of the pending data.
- rtxOffset = bytesRead - expected*maxPayload/2
- c.SendAck(790, rtxOffset)
-
- // Receive the remaining data, making sure that acknowledged data is not
- // retransmitted.
- for offset := rtxOffset; offset < len(data); offset += maxPayload {
- c.ReceiveAndCheckPacket(data, offset, maxPayload)
- c.SendAck(790, offset+maxPayload)
- }
-
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
-}
diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go
deleted file mode 100644
index c35db7c95..000000000
--- a/pkg/tcpip/transport/tcp/tcp_rack_test.go
+++ /dev/null
@@ -1,1101 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tcp_test
-
-import (
- "bytes"
- "fmt"
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/seqnum"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
- "gvisor.dev/gvisor/pkg/test/testutil"
-)
-
-const (
- maxPayload = 10
- tsOptionSize = 12
- maxTCPOptionSize = 40
- mtu = header.TCPMinimumSize + header.IPv4MinimumSize + maxTCPOptionSize + maxPayload
-)
-
-func setStackTCPRecovery(t *testing.T, c *context.Context, recovery int) {
- t.Helper()
- opt := tcpip.TCPRecovery(recovery)
- if err := c.Stack().SetTransportProtocolOption(header.TCPProtocolNumber, &opt); err != nil {
- t.Fatalf("c.s.SetTransportProtocolOption(%d, &%v(%v)): %s", header.TCPProtocolNumber, opt, opt, err)
- }
-}
-
-// TestRACKUpdate tests the RACK related fields are updated when an ACK is
-// received on a SACK enabled connection.
-func TestRACKUpdate(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- var xmitTime tcpip.MonotonicTime
- probeDone := make(chan struct{})
- c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
- // Validate that the endpoint Sender.RACKState is what we expect.
- if state.Sender.RACKState.XmitTime.Before(xmitTime) {
- t.Fatalf("RACK transmit time failed to update when an ACK is received")
- }
-
- gotSeq := state.Sender.RACKState.EndSequence
- wantSeq := state.Sender.SndNxt
- if !gotSeq.LessThanEq(wantSeq) || gotSeq.LessThan(wantSeq) {
- t.Fatalf("RACK sequence number failed to update, got: %v, but want: %v", gotSeq, wantSeq)
- }
-
- if state.Sender.RACKState.RTT == 0 {
- t.Fatalf("RACK RTT failed to update when an ACK is received, got RACKState.RTT == 0 want != 0")
- }
- close(probeDone)
- })
- setStackSACKPermitted(t, c, true)
- createConnectedWithSACKAndTS(c)
-
- data := make([]byte, maxPayload)
- for i := range data {
- data[i] = byte(i)
- }
-
- // Write the data.
- xmitTime = c.Stack().Clock().NowMonotonic()
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- bytesRead := 0
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
- bytesRead += maxPayload
- c.SendAck(seqnum.Value(context.TestInitialSequenceNumber).Add(1), bytesRead)
-
- // Wait for the probe function to finish processing the ACK before the
- // test completes.
- <-probeDone
-}
-
-// TestRACKDetectReorder tests that RACK detects packet reordering.
-func TestRACKDetectReorder(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- t.Skipf("Skipping this test as reorder detection does not consider DSACK.")
-
- var n int
- const ackNumToVerify = 2
- probeDone := make(chan struct{})
- c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
- gotSeq := state.Sender.RACKState.FACK
- wantSeq := state.Sender.SndNxt
- // FACK should be updated to the highest ending sequence number of the
- // segment acknowledged most recently.
- if !gotSeq.LessThanEq(wantSeq) || gotSeq.LessThan(wantSeq) {
- t.Fatalf("RACK FACK failed to update, got: %v, but want: %v", gotSeq, wantSeq)
- }
-
- n++
- if n < ackNumToVerify {
- if state.Sender.RACKState.Reord {
- t.Fatalf("RACK reorder detected when there is no reordering")
- }
- return
- }
-
- if state.Sender.RACKState.Reord == false {
- t.Fatalf("RACK reorder detection failed")
- }
- close(probeDone)
- })
- setStackSACKPermitted(t, c, true)
- createConnectedWithSACKAndTS(c)
- data := make([]byte, ackNumToVerify*maxPayload)
- for i := range data {
- data[i] = byte(i)
- }
-
- // Write the data.
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- bytesRead := 0
- for i := 0; i < ackNumToVerify; i++ {
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
- bytesRead += maxPayload
- }
-
- start := c.IRS.Add(maxPayload + 1)
- end := start.Add(maxPayload)
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}})
- c.SendAck(seq, bytesRead)
-
- // Wait for the probe function to finish processing the ACK before the
- // test completes.
- <-probeDone
-}
-
-func sendAndReceiveWithSACK(t *testing.T, c *context.Context, numPackets int, enableRACK bool) []byte {
- setStackSACKPermitted(t, c, true)
- if !enableRACK {
- setStackTCPRecovery(t, c, 0)
- }
- // The delay should be below initial RTO (1s) otherwise retransimission
- // will start. Choose a relatively large value so that estimated RTT
- // keeps high even after a few rounds of undelayed RTT samples.
- c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled(), TS: true}, 800*time.Millisecond /* delay */)
-
- data := make([]byte, numPackets*maxPayload)
- for i := range data {
- data[i] = byte(i)
- }
-
- // Write the data.
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- bytesRead := 0
- for i := 0; i < numPackets; i++ {
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
- bytesRead += maxPayload
- }
-
- return data
-}
-
-const (
- validDSACKDetected = 1
- failedToDetectDSACK = 2
- invalidDSACKDetected = 3
-)
-
-func addDSACKSeenCheckerProbe(t *testing.T, c *context.Context, numACK int, probeDone chan int) {
- var n int
- c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
- // Validate that RACK detects DSACK.
- n++
- if n < numACK {
- if state.Sender.RACKState.DSACKSeen {
- probeDone <- invalidDSACKDetected
- }
- return
- }
-
- if !state.Sender.RACKState.DSACKSeen {
- probeDone <- failedToDetectDSACK
- return
- }
- probeDone <- validDSACKDetected
- })
-}
-
-// TestRACKTLPRecovery tests that RACK sends a tail loss probe (TLP) in the
-// case of a tail loss. This simulates a situation where the TLP is able to
-// insinuate the SACK holes and sender is able to retransmit the rest.
-func TestRACKTLPRecovery(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- // Send 8 packets.
- numPackets := 8
- data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */)
-
- // Packets [6-8] are lost. Send cumulative ACK for [1-5].
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- bytesRead := 5 * maxPayload
- c.SendAck(seq, bytesRead)
-
- // PTO should fire and send #8 packet as a TLP.
- c.ReceiveAndCheckPacketWithOptions(data, 7*maxPayload, maxPayload, tsOptionSize)
- var info tcpip.TCPInfoOption
- if err := c.EP.GetSockOpt(&info); err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
- }
-
- // Send the SACK after RTT because RACK RFC states that if the ACK for a
- // retransmission arrives before the smoothed RTT then the sender should not
- // update RACK state as it could be a spurious inference.
- time.Sleep(info.RTT)
-
- // Okay, let the sender know we got #8 using a SACK block.
- eighthPStart := c.IRS.Add(1 + seqnum.Size(7*maxPayload))
- eighthPEnd := eighthPStart.Add(maxPayload)
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{eighthPStart, eighthPEnd}})
-
- // The sender should be entering RACK based loss-recovery and sending #6 and
- // #7 one after another.
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
- bytesRead += maxPayload
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
- bytesRead += 2 * maxPayload
- c.SendAck(seq, bytesRead)
-
- metricPollFn := func() error {
- tcpStats := c.Stack().Stats().TCP
- stats := []struct {
- stat *tcpip.StatCounter
- name string
- want uint64
- }{
- // One fast retransmit after the SACK.
- {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1},
- // Recovery should be SACK recovery.
- {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1},
- // Packets 6, 7 and 8 were retransmitted.
- {tcpStats.Retransmits, "stats.TCP.Retransmits", 3},
- // TLP recovery should have been detected.
- {tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 1},
- // No RTOs should have occurred.
- {tcpStats.Timeouts, "stats.TCP.Timeouts", 0},
- }
- for _, s := range stats {
- if got, want := s.stat.Value(), s.want; got != want {
- return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
- }
- }
- return nil
- }
- if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
- t.Error(err)
- }
-}
-
-// TestRACKTLPFallbackRTO tests that RACK sends a tail loss probe (TLP) in the
-// case of a tail loss. This simulates a situation where either the TLP or its
-// ACK is lost. The sender should retransmit when RTO fires.
-func TestRACKTLPFallbackRTO(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- // Send 8 packets.
- numPackets := 8
- data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */)
-
- // Packets [6-8] are lost. Send cumulative ACK for [1-5].
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- bytesRead := 5 * maxPayload
- c.SendAck(seq, bytesRead)
-
- // PTO should fire and send #8 packet as a TLP.
- c.ReceiveAndCheckPacketWithOptions(data, 7*maxPayload, maxPayload, tsOptionSize)
-
- // Either the TLP or the ACK the receiver sent with SACK blocks was lost.
-
- // Confirm that RTO fires and retransmits packet #6.
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
-
- metricPollFn := func() error {
- tcpStats := c.Stack().Stats().TCP
- stats := []struct {
- stat *tcpip.StatCounter
- name string
- want uint64
- }{
- // No fast retransmits happened.
- {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 0},
- // No SACK recovery happened.
- {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 0},
- // TLP was unsuccessful.
- {tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 0},
- // RTO should have fired.
- {tcpStats.Timeouts, "stats.TCP.Timeouts", 1},
- }
- for _, s := range stats {
- if got, want := s.stat.Value(), s.want; got != want {
- return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
- }
- }
- return nil
- }
- if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
- t.Error(err)
- }
-}
-
-// TestNoTLPRecoveryOnDSACK tests the scenario where the sender speculates a
-// tail loss and sends a TLP. Everything is received and acked. The probe
-// segment is DSACKed. No fast recovery should be triggered in this case.
-func TestNoTLPRecoveryOnDSACK(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- // Send 8 packets.
- numPackets := 8
- data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */)
-
- // Packets [1-5] are received first. [6-8] took a detour and will take a
- // while to arrive. Ack [1-5].
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- bytesRead := 5 * maxPayload
- c.SendAck(seq, bytesRead)
-
- // The tail loss probe (#8 packet) is received.
- c.ReceiveAndCheckPacketWithOptions(data, 7*maxPayload, maxPayload, tsOptionSize)
-
- // Now that all 8 packets are received + duplicate 8th packet, send ack.
- bytesRead += 3 * maxPayload
- eighthPStart := c.IRS.Add(1 + seqnum.Size(7*maxPayload))
- eighthPEnd := eighthPStart.Add(maxPayload)
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{eighthPStart, eighthPEnd}})
-
- // Wait for RTO and make sure that nothing else is received.
- var info tcpip.TCPInfoOption
- if err := c.EP.GetSockOpt(&info); err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
- }
- if p := c.GetPacketWithTimeout(info.RTO); p != nil {
- t.Errorf("received an unexpected packet: %v", p)
- }
-
- metricPollFn := func() error {
- tcpStats := c.Stack().Stats().TCP
- stats := []struct {
- stat *tcpip.StatCounter
- name string
- want uint64
- }{
- // Make sure no recovery was entered.
- {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 0},
- {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 0},
- {tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 0},
- // RTO should not have fired.
- {tcpStats.Timeouts, "stats.TCP.Timeouts", 0},
- // Only #8 was retransmitted.
- {tcpStats.Retransmits, "stats.TCP.Retransmits", 1},
- }
- for _, s := range stats {
- if got, want := s.stat.Value(), s.want; got != want {
- return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
- }
- }
- return nil
- }
- if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
- t.Error(err)
- }
-}
-
-// TestNoTLPOnSACK tests the scenario where there is not exactly a tail loss
-// due to the presence of multiple SACK holes. In such a scenario, TLP should
-// not be sent.
-func TestNoTLPOnSACK(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- // Send 8 packets.
- numPackets := 8
- data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */)
-
- // Packets [1-5] and #7 were received. #6 and #8 were dropped.
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- bytesRead := 5 * maxPayload
- seventhStart := c.IRS.Add(1 + seqnum.Size(6*maxPayload))
- seventhEnd := seventhStart.Add(maxPayload)
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{seventhStart, seventhEnd}})
-
- // The sender should retransmit #6. If the sender sends a TLP, then #8 will
- // received and fail this test.
- c.ReceiveAndCheckPacketWithOptions(data, 5*maxPayload, maxPayload, tsOptionSize)
-
- metricPollFn := func() error {
- tcpStats := c.Stack().Stats().TCP
- stats := []struct {
- stat *tcpip.StatCounter
- name string
- want uint64
- }{
- // #6 was retransmitted due to SACK recovery.
- {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1},
- {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1},
- {tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 0},
- // RTO should not have fired.
- {tcpStats.Timeouts, "stats.TCP.Timeouts", 0},
- // Only #6 was retransmitted.
- {tcpStats.Retransmits, "stats.TCP.Retransmits", 1},
- }
- for _, s := range stats {
- if got, want := s.stat.Value(), s.want; got != want {
- return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
- }
- }
- return nil
- }
- if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
- t.Error(err)
- }
-}
-
-// TestRACKOnePacketTailLoss tests the trivial case of a tail loss of only one
-// packet. The probe should itself repairs the loss instead of having to go
-// into any recovery.
-func TestRACKOnePacketTailLoss(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- // Send 3 packets.
- numPackets := 3
- data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */)
-
- // Packets [1-2] are received. #3 is lost.
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- bytesRead := 2 * maxPayload
- c.SendAck(seq, bytesRead)
-
- // PTO should fire and send #3 packet as a TLP.
- c.ReceiveAndCheckPacketWithOptions(data, 2*maxPayload, maxPayload, tsOptionSize)
- bytesRead += maxPayload
- c.SendAck(seq, bytesRead)
-
- metricPollFn := func() error {
- tcpStats := c.Stack().Stats().TCP
- stats := []struct {
- stat *tcpip.StatCounter
- name string
- want uint64
- }{
- // #3 was retransmitted as TLP.
- {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 0},
- {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1},
- {tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 0},
- // RTO should not have fired.
- {tcpStats.Timeouts, "stats.TCP.Timeouts", 0},
- // Only #3 was retransmitted.
- {tcpStats.Retransmits, "stats.TCP.Retransmits", 1},
- }
- for _, s := range stats {
- if got, want := s.stat.Value(), s.want; got != want {
- return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
- }
- }
- return nil
- }
- if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
- t.Error(err)
- }
-}
-
-// TestRACKDetectDSACK tests that RACK detects DSACK with duplicate segments.
-// See: https://tools.ietf.org/html/rfc2883#section-4.1.1.
-func TestRACKDetectDSACK(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- probeDone := make(chan int)
- const ackNumToVerify = 2
- addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone)
-
- numPackets := 8
- data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */)
-
- // Cumulative ACK for [1-5] packets and SACK #8 packet (to prevent TLP).
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- bytesRead := 5 * maxPayload
- eighthPStart := c.IRS.Add(1 + seqnum.Size(7*maxPayload))
- eighthPEnd := eighthPStart.Add(maxPayload)
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{eighthPStart, eighthPEnd}})
-
- // Expect retransmission of #6 packet after RTO expires.
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
-
- // Send DSACK block for #6 packet indicating both
- // initial and retransmitted packet are received and
- // packets [1-8] are received.
- start := c.IRS.Add(1 + seqnum.Size(bytesRead))
- end := start.Add(maxPayload)
- bytesRead += 3 * maxPayload
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
-
- // Wait for the probe function to finish processing the
- // ACK before the test completes.
- err := <-probeDone
- switch err {
- case failedToDetectDSACK:
- t.Fatalf("RACK DSACK detection failed")
- case invalidDSACKDetected:
- t.Fatalf("RACK DSACK detected when there is no duplicate SACK")
- }
-
- metricPollFn := func() error {
- tcpStats := c.Stack().Stats().TCP
- stats := []struct {
- stat *tcpip.StatCounter
- name string
- want uint64
- }{
- // Check DSACK was received for one segment.
- {tcpStats.SegmentsAckedWithDSACK, "stats.TCP.SegmentsAckedWithDSACK", 1},
- }
- for _, s := range stats {
- if got, want := s.stat.Value(), s.want; got != want {
- return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
- }
- }
- return nil
- }
-
- if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
- t.Error(err)
- }
-}
-
-// TestRACKDetectDSACKWithOutOfOrder tests that RACK detects DSACK with out of
-// order segments.
-// See: https://tools.ietf.org/html/rfc2883#section-4.1.2.
-func TestRACKDetectDSACKWithOutOfOrder(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- probeDone := make(chan int)
- const ackNumToVerify = 2
- addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone)
-
- numPackets := 10
- data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */)
-
- // Cumulative ACK for [1-5] packets and SACK for #7 packet (to prevent TLP).
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- bytesRead := 5 * maxPayload
- seventhPStart := c.IRS.Add(1 + seqnum.Size(6*maxPayload))
- seventhPEnd := seventhPStart.Add(maxPayload)
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{seventhPStart, seventhPEnd}})
-
- // Expect retransmission of #6 packet.
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
-
- // Send DSACK block for #6 packet indicating both
- // initial and retransmitted packet are received and
- // packets [1-7] are received.
- start := c.IRS.Add(1 + seqnum.Size(bytesRead))
- end := start.Add(maxPayload)
- bytesRead += 2 * maxPayload
- // Send DSACK block for #6 along with SACK for out of
- // order #9 packet.
- start1 := c.IRS.Add(1 + seqnum.Size(bytesRead) + maxPayload)
- end1 := start1.Add(maxPayload)
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}, {start1, end1}})
-
- // Wait for the probe function to finish processing the
- // ACK before the test completes.
- err := <-probeDone
- switch err {
- case failedToDetectDSACK:
- t.Fatalf("RACK DSACK detection failed")
- case invalidDSACKDetected:
- t.Fatalf("RACK DSACK detected when there is no duplicate SACK")
- }
-}
-
-// TestRACKDetectDSACKWithOutOfOrderDup tests that DSACK is detected on a
-// duplicate of out of order packet.
-// See: https://tools.ietf.org/html/rfc2883#section-4.1.3
-func TestRACKDetectDSACKWithOutOfOrderDup(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- probeDone := make(chan int)
- const ackNumToVerify = 4
- addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone)
-
- numPackets := 10
- sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */)
-
- // ACK [1-5] packets.
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- bytesRead := 5 * maxPayload
- c.SendAck(seq, bytesRead)
-
- // Send SACK indicating #6 packet is missing and received #7 packet.
- offset := seqnum.Size(bytesRead + maxPayload)
- start := c.IRS.Add(1 + offset)
- end := start.Add(maxPayload)
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
-
- // Send SACK with #6 packet is missing and received [7-8] packets.
- end = start.Add(2 * maxPayload)
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
-
- // Consider #8 packet is duplicated on the network and send DSACK.
- dsackStart := c.IRS.Add(1 + offset + maxPayload)
- dsackEnd := dsackStart.Add(maxPayload)
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{dsackStart, dsackEnd}, {start, end}})
-
- // Wait for the probe function to finish processing the ACK before the
- // test completes.
- err := <-probeDone
- switch err {
- case failedToDetectDSACK:
- t.Fatalf("RACK DSACK detection failed")
- case invalidDSACKDetected:
- t.Fatalf("RACK DSACK detected when there is no duplicate SACK")
- }
-}
-
-// TestRACKDetectDSACKSingleDup tests DSACK for a single duplicate subsegment.
-// See: https://tools.ietf.org/html/rfc2883#section-4.2.1.
-func TestRACKDetectDSACKSingleDup(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- probeDone := make(chan int)
- const ackNumToVerify = 4
- addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone)
-
- numPackets := 4
- data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */)
-
- // Send ACK for #1 packet.
- bytesRead := maxPayload
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- c.SendAck(seq, bytesRead)
-
- // Missing [2-3] packets and received #4 packet.
- seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- start := c.IRS.Add(1 + seqnum.Size(3*maxPayload))
- end := start.Add(seqnum.Size(maxPayload))
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
-
- // Expect retransmission of #2 packet.
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
-
- // ACK for retransmitted #2 packet.
- bytesRead += maxPayload
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
-
- // Simulate receving delayed subsegment of #2 packet and delayed #3 packet by
- // sending DSACK block for the subsegment.
- dsackStart := c.IRS.Add(1 + seqnum.Size(bytesRead))
- dsackEnd := dsackStart.Add(seqnum.Size(maxPayload / 2))
- c.SendAckWithSACK(seq, numPackets*maxPayload, []header.SACKBlock{{dsackStart, dsackEnd}})
-
- // Wait for the probe function to finish processing the ACK before the
- // test completes.
- err := <-probeDone
- switch err {
- case failedToDetectDSACK:
- t.Fatalf("RACK DSACK detection failed")
- case invalidDSACKDetected:
- t.Fatalf("RACK DSACK detected when there is no duplicate SACK")
- }
-
- metricPollFn := func() error {
- tcpStats := c.Stack().Stats().TCP
- stats := []struct {
- stat *tcpip.StatCounter
- name string
- want uint64
- }{
- // Check DSACK was received for a subsegment.
- {tcpStats.SegmentsAckedWithDSACK, "stats.TCP.SegmentsAckedWithDSACK", 1},
- }
- for _, s := range stats {
- if got, want := s.stat.Value(), s.want; got != want {
- return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
- }
- }
- return nil
- }
-
- if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
- t.Error(err)
- }
-}
-
-// TestRACKDetectDSACKDupWithCumulativeACK tests DSACK for two non-contiguous
-// duplicate subsegments covered by the cumulative acknowledgement.
-// See: https://tools.ietf.org/html/rfc2883#section-4.2.2.
-func TestRACKDetectDSACKDupWithCumulativeACK(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- probeDone := make(chan int)
- const ackNumToVerify = 5
- addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone)
-
- numPackets := 6
- data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */)
-
- // Send ACK for #1 packet.
- bytesRead := maxPayload
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- c.SendAck(seq, bytesRead)
-
- // Missing [2-5] packets and received #6 packet.
- seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- start := c.IRS.Add(1 + seqnum.Size(5*maxPayload))
- end := start.Add(seqnum.Size(maxPayload))
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
-
- // Expect retransmission of #2 packet.
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
-
- // Received delayed #2 packet.
- bytesRead += maxPayload
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
-
- // Received delayed #4 packet.
- start1 := c.IRS.Add(1 + seqnum.Size(3*maxPayload))
- end1 := start1.Add(seqnum.Size(maxPayload))
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start1, end1}, {start, end}})
-
- // Simulate receiving retransmitted subsegment for #2 packet and delayed #3
- // packet by sending DSACK block for #2 packet.
- dsackStart := c.IRS.Add(1 + seqnum.Size(maxPayload))
- dsackEnd := dsackStart.Add(seqnum.Size(maxPayload / 2))
- c.SendAckWithSACK(seq, 4*maxPayload, []header.SACKBlock{{dsackStart, dsackEnd}, {start, end}})
-
- // Wait for the probe function to finish processing the ACK before the
- // test completes.
- err := <-probeDone
- switch err {
- case failedToDetectDSACK:
- t.Fatalf("RACK DSACK detection failed")
- case invalidDSACKDetected:
- t.Fatalf("RACK DSACK detected when there is no duplicate SACK")
- }
-}
-
-// TestRACKDetectDSACKDup tests two non-contiguous duplicate subsegments not
-// covered by the cumulative acknowledgement.
-// See: https://tools.ietf.org/html/rfc2883#section-4.2.3.
-func TestRACKDetectDSACKDup(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- probeDone := make(chan int)
- const ackNumToVerify = 5
- addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone)
-
- numPackets := 7
- data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */)
-
- // Send ACK for #1 packet.
- bytesRead := maxPayload
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- c.SendAck(seq, bytesRead)
-
- // Missing [2-6] packets and SACK #7 packet.
- seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- start := c.IRS.Add(1 + seqnum.Size(6*maxPayload))
- end := start.Add(seqnum.Size(maxPayload))
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
-
- // Received delayed #3 packet.
- start1 := c.IRS.Add(1 + seqnum.Size(2*maxPayload))
- end1 := start1.Add(seqnum.Size(maxPayload))
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start1, end1}, {start, end}})
-
- // Expect retransmission of #2 packet.
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
-
- // Consider #2 packet has been dropped and SACK #4 packet.
- start2 := c.IRS.Add(1 + seqnum.Size(3*maxPayload))
- end2 := start2.Add(seqnum.Size(maxPayload))
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start2, end2}, {start1, end1}, {start, end}})
-
- // Simulate receiving retransmitted subsegment for #3 packet and delayed #5
- // packet by sending DSACK block for the subsegment.
- dsackStart := c.IRS.Add(1 + seqnum.Size(2*maxPayload))
- dsackEnd := dsackStart.Add(seqnum.Size(maxPayload / 2))
- end1 = end1.Add(seqnum.Size(2 * maxPayload))
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{dsackStart, dsackEnd}, {start1, end1}})
-
- // Wait for the probe function to finish processing the ACK before the
- // test completes.
- err := <-probeDone
- switch err {
- case failedToDetectDSACK:
- t.Fatalf("RACK DSACK detection failed")
- case invalidDSACKDetected:
- t.Fatalf("RACK DSACK detected when there is no duplicate SACK")
- }
-}
-
-// TestRACKWithInvalidDSACKBlock tests that DSACK is not detected when DSACK
-// is not the first SACK block.
-func TestRACKWithInvalidDSACKBlock(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- probeDone := make(chan struct{})
- const ackNumToVerify = 2
- var n int
- c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
- // Validate that RACK does not detect DSACK when DSACK block is
- // not the first SACK block.
- n++
- t.Helper()
- if state.Sender.RACKState.DSACKSeen {
- t.Fatalf("RACK DSACK detected when there is no duplicate SACK")
- }
-
- if n == ackNumToVerify {
- close(probeDone)
- }
- })
-
- numPackets := 10
- data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */)
-
- // Cumulative ACK for [1-5] packets and SACK for #7 packet (to prevent TLP).
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- bytesRead := 5 * maxPayload
- seventhPStart := c.IRS.Add(1 + seqnum.Size(6*maxPayload))
- seventhPEnd := seventhPStart.Add(maxPayload)
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{seventhPStart, seventhPEnd}})
-
- // Expect retransmission of #6 packet.
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
-
- // Send DSACK block for #6 packet indicating both
- // initial and retransmitted packet are received and
- // packets [1-7] are received.
- start := c.IRS.Add(1 + seqnum.Size(bytesRead))
- end := start.Add(maxPayload)
- bytesRead += 2 * maxPayload
-
- // Send DSACK block as second block. The first block is a SACK for #9 packet.
- start1 := c.IRS.Add(1 + seqnum.Size(bytesRead) + maxPayload)
- end1 := start1.Add(maxPayload)
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start1, end1}, {start, end}})
-
- // Wait for the probe function to finish processing the
- // ACK before the test completes.
- <-probeDone
-}
-
-func addReorderWindowCheckerProbe(c *context.Context, numACK int, probeDone chan error) {
- var n int
- c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
- // Validate that RACK detects DSACK.
- n++
- if n < numACK {
- return
- }
-
- if state.Sender.RACKState.ReoWnd == 0 || state.Sender.RACKState.ReoWnd > state.Sender.RTTState.SRTT {
- probeDone <- fmt.Errorf("got RACKState.ReoWnd: %d, expected it to be greater than 0 and less than %d", state.Sender.RACKState.ReoWnd, state.Sender.RTTState.SRTT)
- return
- }
-
- if state.Sender.RACKState.ReoWndIncr != 1 {
- probeDone <- fmt.Errorf("got RACKState.ReoWndIncr: %v, want: 1", state.Sender.RACKState.ReoWndIncr)
- return
- }
-
- if state.Sender.RACKState.ReoWndPersist > 0 {
- probeDone <- fmt.Errorf("got RACKState.ReoWndPersist: %v, want: greater than 0", state.Sender.RACKState.ReoWndPersist)
- return
- }
- probeDone <- nil
- })
-}
-
-func TestRACKCheckReorderWindow(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- probeDone := make(chan error)
- const ackNumToVerify = 3
- addReorderWindowCheckerProbe(c, ackNumToVerify, probeDone)
-
- const numPackets = 7
- sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */)
-
- // Send ACK for #1 packet.
- bytesRead := maxPayload
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- c.SendAck(seq, bytesRead)
-
- // Missing [2-6] packets and SACK #7 packet.
- start := c.IRS.Add(1 + seqnum.Size(6*maxPayload))
- end := start.Add(seqnum.Size(maxPayload))
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
-
- // Received delayed packets [2-6] which indicates there is reordering
- // in the connection.
- bytesRead += 6 * maxPayload
- c.SendAck(seq, bytesRead)
-
- // Wait for the probe function to finish processing the ACK before the
- // test completes.
- if err := <-probeDone; err != nil {
- t.Fatalf("unexpected values for RACK variables: %v", err)
- }
-}
-
-func TestRACKWithDuplicateACK(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- const numPackets = 4
- data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */)
-
- // Send three duplicate ACKs to trigger fast recovery. The first
- // segment is considered as lost and will be retransmitted after
- // receiving the duplicate ACKs.
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- start := c.IRS.Add(1 + seqnum.Size(maxPayload))
- end := start.Add(seqnum.Size(maxPayload))
- for i := 0; i < 3; i++ {
- c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}})
- end = end.Add(seqnum.Size(maxPayload))
- }
-
- // Receive the retransmitted packet.
- c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, tsOptionSize)
-
- metricPollFn := func() error {
- tcpStats := c.Stack().Stats().TCP
- stats := []struct {
- stat *tcpip.StatCounter
- name string
- want uint64
- }{
- {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1},
- {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1},
- {tcpStats.FastRecovery, "stats.TCP.FastRecovery", 0},
- }
- for _, s := range stats {
- if got, want := s.stat.Value(), s.want; got != want {
- return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
- }
- }
- return nil
- }
-
- if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
- t.Error(err)
- }
-}
-
-// TestRACKUpdateSackedOut tests the sacked out field is updated when a SACK
-// is received.
-func TestRACKUpdateSackedOut(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- probeDone := make(chan struct{})
- ackNum := 0
- c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
- // Validate that the endpoint Sender.SackedOut is what we expect.
- if state.Sender.SackedOut != 2 && ackNum == 0 {
- t.Fatalf("SackedOut got updated to wrong value got: %v want: 2", state.Sender.SackedOut)
- }
-
- if state.Sender.SackedOut != 0 && ackNum == 1 {
- t.Fatalf("SackedOut got updated to wrong value got: %v want: 0", state.Sender.SackedOut)
- }
- if ackNum > 0 {
- close(probeDone)
- }
- ackNum++
- })
-
- sendAndReceiveWithSACK(t, c, 8, true /* enableRACK */)
-
- // ACK for [3-5] packets.
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- start := c.IRS.Add(seqnum.Size(1 + 3*maxPayload))
- bytesRead := 2 * maxPayload
- end := start.Add(seqnum.Size(bytesRead))
- c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
-
- bytesRead += 3 * maxPayload
- c.SendAck(seq, bytesRead)
-
- // Wait for the probe function to finish processing the ACK before the
- // test completes.
- <-probeDone
-}
-
-// TestRACKWithWindowFull tests that RACK honors the receive window size.
-func TestRACKWithWindowFull(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- setStackSACKPermitted(t, c, true)
- createConnectedWithSACKAndTS(c)
-
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- const numPkts = 10
- data := make([]byte, numPkts*maxPayload)
- for i := range data {
- data[i] = byte(i)
- }
-
- // Write the data.
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- bytesRead := 0
- for i := 0; i < numPkts; i++ {
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
- bytesRead += maxPayload
- if i == 0 {
- // Send ACK for the first packet to establish RTT.
- c.SendAck(seq, maxPayload)
- }
- }
-
- // SACK for #10 packet.
- start := c.IRS.Add(seqnum.Size(1 + (numPkts-1)*maxPayload))
- end := start.Add(seqnum.Size(maxPayload))
- c.SendAckWithSACK(seq, 2*maxPayload, []header.SACKBlock{{start, end}})
-
- var info tcpip.TCPInfoOption
- if err := c.EP.GetSockOpt(&info); err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
- }
- // Wait for RTT to trigger recovery.
- time.Sleep(info.RTT)
-
- // Expect retransmission of #2 packet.
- c.ReceiveAndCheckPacketWithOptions(data, 2*maxPayload, maxPayload, tsOptionSize)
-
- // Send ACK for #2 packet.
- c.SendAck(seq, 3*maxPayload)
-
- // Expect retransmission of #3 packet.
- c.ReceiveAndCheckPacketWithOptions(data, 3*maxPayload, maxPayload, tsOptionSize)
-
- // Send ACK with zero window size.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: seq,
- AckNum: c.IRS.Add(1 + 4*maxPayload),
- RcvWnd: 0,
- })
-
- // No packet should be received as the receive window size is zero.
- c.CheckNoPacket("unexpected packet received after userTimeout has expired")
-}
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
deleted file mode 100644
index 6255355bb..000000000
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ /dev/null
@@ -1,704 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tcp_test
-
-import (
- "bytes"
- "fmt"
- "log"
- "reflect"
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/seqnum"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
- "gvisor.dev/gvisor/pkg/test/testutil"
-)
-
-// createConnectedWithSACKPermittedOption creates and connects c.ep with the
-// SACKPermitted option enabled if the stack in the context has the SACK support
-// enabled.
-func createConnectedWithSACKPermittedOption(c *context.Context) *context.RawEndpoint {
- return c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{SACKPermitted: c.SACKEnabled()})
-}
-
-// createConnectedWithSACKAndTS creates and connects c.ep with the SACK & TS
-// option enabled if the stack in the context has SACK and TS enabled.
-func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint {
- return c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{SACKPermitted: c.SACKEnabled(), TS: true})
-}
-
-func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) {
- t.Helper()
- opt := tcpip.TCPSACKEnabled(enable)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("c.s.SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
- }
-}
-
-// TestSackPermittedConnect establishes a connection with the SACK option
-// enabled.
-func TestSackPermittedConnect(t *testing.T) {
- for _, sackEnabled := range []bool{false, true} {
- t.Run(fmt.Sprintf("stack.sackEnabled: %v", sackEnabled), func(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- setStackSACKPermitted(t, c, sackEnabled)
- setStackTCPRecovery(t, c, 0)
- rep := createConnectedWithSACKPermittedOption(c)
- data := []byte{1, 2, 3}
-
- rep.SendPacket(data, nil)
- savedSeqNum := rep.NextSeqNum
- rep.VerifyACKNoSACK()
-
- // Make an out of order packet and send it.
- rep.NextSeqNum += 3
- sackBlocks := []header.SACKBlock{
- {rep.NextSeqNum, rep.NextSeqNum.Add(seqnum.Size(len(data)))},
- }
- rep.SendPacket(data, nil)
-
- // Restore the saved sequence number so that the
- // VerifyXXX calls use the right sequence number for
- // checking ACK numbers.
- rep.NextSeqNum = savedSeqNum
- if sackEnabled {
- rep.VerifyACKHasSACK(sackBlocks)
- } else {
- rep.VerifyACKNoSACK()
- }
-
- // Send the missing segment.
- rep.SendPacket(data, nil)
- // The ACK should contain the cumulative ACK for all 9
- // bytes sent and no SACK blocks.
- rep.NextSeqNum += 3
- // Check that no SACK block is returned in the ACK.
- rep.VerifyACKNoSACK()
- })
- }
-}
-
-// TestSackDisabledConnect establishes a connection with the SACK option
-// disabled and verifies that no SACKs are sent for out of order segments.
-func TestSackDisabledConnect(t *testing.T) {
- for _, sackEnabled := range []bool{false, true} {
- t.Run(fmt.Sprintf("sackEnabled: %v", sackEnabled), func(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- setStackSACKPermitted(t, c, sackEnabled)
- setStackTCPRecovery(t, c, 0)
-
- rep := c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{})
-
- data := []byte{1, 2, 3}
-
- rep.SendPacket(data, nil)
- savedSeqNum := rep.NextSeqNum
- rep.VerifyACKNoSACK()
-
- // Make an out of order packet and send it.
- rep.NextSeqNum += 3
- rep.SendPacket(data, nil)
-
- // The ACK should contain the older sequence number and
- // no SACK blocks.
- rep.NextSeqNum = savedSeqNum
- rep.VerifyACKNoSACK()
-
- // Send the missing segment.
- rep.SendPacket(data, nil)
- // The ACK should contain the cumulative ACK for all 9
- // bytes sent and no SACK blocks.
- rep.NextSeqNum += 3
- // Check that no SACK block is returned in the ACK.
- rep.VerifyACKNoSACK()
- })
- }
-}
-
-// TestSackPermittedAccept accepts and establishes a connection with the
-// SACKPermitted option enabled if the connection request specifies the
-// SACKPermitted option. In case of SYN cookies SACK should be disabled as we
-// don't encode the SACK information in the cookie.
-func TestSackPermittedAccept(t *testing.T) {
- type testCase struct {
- cookieEnabled bool
- sackPermitted bool
- wndScale int
- wndSize uint16
- }
-
- testCases := []testCase{
- // When cookie is used window scaling is disabled.
- {true, false, -1, 0xffff}, // When cookie is used window scaling is disabled.
- {false, true, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default).
- }
-
- for _, tc := range testCases {
- t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) {
- for _, sackEnabled := range []bool{false, true} {
- t.Run(fmt.Sprintf("test stack.sackEnabled: %v", sackEnabled), func(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- if tc.cookieEnabled {
- opt := tcpip.TCPAlwaysUseSynCookies(true)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
- }
- }
- setStackSACKPermitted(t, c, sackEnabled)
- setStackTCPRecovery(t, c, 0)
-
- rep := c.AcceptWithOptionsNoDelay(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, SACKPermitted: tc.sackPermitted})
- // Now verify no SACK blocks are
- // received when sack is disabled.
- data := []byte{1, 2, 3}
- rep.SendPacket(data, nil)
- rep.VerifyACKNoSACK()
-
- savedSeqNum := rep.NextSeqNum
-
- // Make an out of order packet and send
- // it.
- rep.NextSeqNum += 3
- sackBlocks := []header.SACKBlock{
- {rep.NextSeqNum, rep.NextSeqNum.Add(seqnum.Size(len(data)))},
- }
- rep.SendPacket(data, nil)
-
- // The ACK should contain the older
- // sequence number.
- rep.NextSeqNum = savedSeqNum
- if sackEnabled && tc.sackPermitted {
- rep.VerifyACKHasSACK(sackBlocks)
- } else {
- rep.VerifyACKNoSACK()
- }
-
- // Send the missing segment.
- rep.SendPacket(data, nil)
- // The ACK should contain the cumulative
- // ACK for all 9 bytes sent and no SACK
- // blocks.
- rep.NextSeqNum += 3
- // Check that no SACK block is returned
- // in the ACK.
- rep.VerifyACKNoSACK()
- })
- }
- })
- }
-}
-
-// TestSackDisabledAccept accepts and establishes a connection with
-// the SACKPermitted option disabled and verifies that no SACKs are
-// sent for out of order packets.
-func TestSackDisabledAccept(t *testing.T) {
- type testCase struct {
- cookieEnabled bool
- wndScale int
- wndSize uint16
- }
-
- testCases := []testCase{
- // When cookie is used window scaling is disabled.
- {true, -1, 0xffff}, // When cookie is used window scaling is disabled.
- {false, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default).
- }
-
- for _, tc := range testCases {
- t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) {
- for _, sackEnabled := range []bool{false, true} {
- t.Run(fmt.Sprintf("test: sackEnabled: %v", sackEnabled), func(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- if tc.cookieEnabled {
- opt := tcpip.TCPAlwaysUseSynCookies(true)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
- }
- }
-
- setStackSACKPermitted(t, c, sackEnabled)
- setStackTCPRecovery(t, c, 0)
-
- rep := c.AcceptWithOptionsNoDelay(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
-
- // Now verify no SACK blocks are
- // received when sack is disabled.
- data := []byte{1, 2, 3}
- rep.SendPacket(data, nil)
- rep.VerifyACKNoSACK()
- savedSeqNum := rep.NextSeqNum
-
- // Make an out of order packet and send
- // it.
- rep.NextSeqNum += 3
- rep.SendPacket(data, nil)
-
- // The ACK should contain the older
- // sequence number and no SACK blocks.
- rep.NextSeqNum = savedSeqNum
- rep.VerifyACKNoSACK()
-
- // Send the missing segment.
- rep.SendPacket(data, nil)
- // The ACK should contain the cumulative
- // ACK for all 9 bytes sent and no SACK
- // blocks.
- rep.NextSeqNum += 3
- // Check that no SACK block is returned
- // in the ACK.
- rep.VerifyACKNoSACK()
- })
- }
- })
- }
-}
-
-func TestUpdateSACKBlocks(t *testing.T) {
- testCases := []struct {
- segStart seqnum.Value
- segEnd seqnum.Value
- rcvNxt seqnum.Value
- sackBlocks []header.SACKBlock
- updated []header.SACKBlock
- }{
- // Trivial cases where current SACK block list is empty and we
- // have an out of order delivery.
- {10, 11, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 11}}},
- {10, 12, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 12}}},
- {10, 20, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 20}}},
-
- // Cases where current SACK block list is not empty and we have
- // an out of order delivery. Tests that the updated SACK block
- // list has the first block as the one that contains the new
- // SACK block representing the segment that was just delivered.
- {10, 11, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 11}, {12, 20}}},
- {24, 30, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{24, 30}, {12, 20}}},
- {24, 30, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{24, 30}, {12, 20}, {32, 40}}},
-
- // Ensure that we only retain header.MaxSACKBlocks and drop the
- // oldest one if adding a new block exceeds
- // header.MaxSACKBlocks.
- {24, 30, 9,
- []header.SACKBlock{{12, 20}, {32, 40}, {42, 50}, {52, 60}, {62, 70}, {72, 80}},
- []header.SACKBlock{{24, 30}, {12, 20}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}},
-
- // Cases where segment extends an existing SACK block.
- {10, 12, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 20}}},
- {10, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 22}}},
- {10, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 22}}},
- {15, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{12, 22}}},
- {15, 25, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{12, 25}}},
- {11, 25, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{11, 25}}},
- {10, 12, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 20}, {32, 40}}},
- {10, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 22}, {32, 40}}},
- {10, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 22}, {32, 40}}},
- {15, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{12, 22}, {32, 40}}},
- {15, 25, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{12, 25}, {32, 40}}},
- {11, 25, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{11, 25}, {32, 40}}},
-
- // Cases where segment contains rcvNxt.
- {10, 20, 15, []header.SACKBlock{{20, 30}, {40, 50}}, []header.SACKBlock{{40, 50}}},
- }
-
- for _, tc := range testCases {
- var sack tcp.SACKInfo
- copy(sack.Blocks[:], tc.sackBlocks)
- sack.NumBlocks = len(tc.sackBlocks)
- tcp.UpdateSACKBlocks(&sack, tc.segStart, tc.segEnd, tc.rcvNxt)
- if got, want := sack.Blocks[:sack.NumBlocks], tc.updated; !reflect.DeepEqual(got, want) {
- t.Errorf("UpdateSACKBlocks(%v, %v, %v, %v), got: %v, want: %v", tc.sackBlocks, tc.segStart, tc.segEnd, tc.rcvNxt, got, want)
- }
-
- }
-}
-
-func TestTrimSackBlockList(t *testing.T) {
- testCases := []struct {
- rcvNxt seqnum.Value
- sackBlocks []header.SACKBlock
- trimmed []header.SACKBlock
- }{
- // Simple cases where we trim whole entries.
- {2, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}},
- {21, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{22, 30}, {32, 40}}},
- {31, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{32, 40}}},
- {40, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{}},
- // Cases where we need to update a block.
- {12, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{12, 20}, {22, 30}, {32, 40}}},
- {23, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{23, 30}, {32, 40}}},
- {33, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{33, 40}}},
- {41, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{}},
- }
- for _, tc := range testCases {
- var sack tcp.SACKInfo
- copy(sack.Blocks[:], tc.sackBlocks)
- sack.NumBlocks = len(tc.sackBlocks)
- tcp.TrimSACKBlockList(&sack, tc.rcvNxt)
- if got, want := sack.Blocks[:sack.NumBlocks], tc.trimmed; !reflect.DeepEqual(got, want) {
- t.Errorf("TrimSackBlockList(%v, %v), got: %v, want: %v", tc.sackBlocks, tc.rcvNxt, got, want)
- }
- }
-}
-
-func TestSACKRecovery(t *testing.T) {
- const maxPayload = 10
- // See: tcp.makeOptions for why tsOptionSize is set to 12 here.
- const tsOptionSize = 12
- // Enabling SACK means the payload size is reduced to account
- // for the extra space required for the TCP options.
- //
- // We increase the MTU by 40 bytes to account for SACK and Timestamp
- // options.
- const maxTCPOptionSize = 40
-
- c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload))
- defer c.Cleanup()
-
- c.Stack().AddTCPProbe(func(s stack.TCPEndpointState) {
- // We use log.Printf instead of t.Logf here because this probe
- // can fire even when the test function has finished. This is
- // because closing the endpoint in cleanup() does not mean the
- // actual worker loop terminates immediately as it still has to
- // do a full TCP shutdown. But this test can finish running
- // before the shutdown is done. Using t.Logf in such a case
- // causes the test to panic due to logging after test finished.
- log.Printf("state: %+v\n", s)
- })
- setStackSACKPermitted(t, c, true)
- setStackTCPRecovery(t, c, 0)
- createConnectedWithSACKAndTS(c)
-
- const iterations = 3
- data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1)))
- for i := range data {
- data[i] = byte(i)
- }
-
- // Write all the data in one shot. Packets will only be written at the
- // MTU size though.
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Do slow start for a few iterations.
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- expected := tcp.InitialCwnd
- bytesRead := 0
- for i := 0; i < iterations; i++ {
- expected = tcp.InitialCwnd << uint(i)
- if i > 0 {
- // Acknowledge all the data received so far if not on
- // first iteration.
- c.SendAck(seq, bytesRead)
- }
-
- // Read all packets expected on this iteration. Don't
- // acknowledge any of them just yet, so that we can measure the
- // congestion window.
- for j := 0; j < expected; j++ {
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
- bytesRead += maxPayload
- }
-
- // Check we don't receive any more packets on this iteration.
- // The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
- }
-
- // Send 3 duplicate acks. This should force an immediate retransmit of
- // the pending packet and put the sender into fast recovery.
- rtxOffset := bytesRead - maxPayload*expected
- start := c.IRS.Add(seqnum.Size(rtxOffset) + 30 + 1)
- end := start.Add(10)
- for i := 0; i < 3; i++ {
- c.SendAckWithSACK(seq, rtxOffset, []header.SACKBlock{{start, end}})
- end = end.Add(10)
- }
-
- // Receive the retransmitted packet.
- c.ReceiveAndCheckPacketWithOptions(data, rtxOffset, maxPayload, tsOptionSize)
-
- metricPollFn := func() error {
- tcpStats := c.Stack().Stats().TCP
- stats := []struct {
- stat *tcpip.StatCounter
- name string
- want uint64
- }{
- {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1},
- {tcpStats.Retransmits, "stats.TCP.Retransmits", 1},
- {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1},
- {tcpStats.FastRecovery, "stats.TCP.FastRecovery", 0},
- }
- for _, s := range stats {
- if got, want := s.stat.Value(), s.want; got != want {
- return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
- }
- }
- return nil
- }
-
- if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
- t.Error(err)
- }
-
- // Now send 7 mode duplicate ACKs. In SACK TCP dupAcks do not cause
- // window inflation and sending of packets is completely handled by the
- // SACK Recovery algorithm. We should see no packets being released, as
- // the cwnd at this point after entering recovery should be half of the
- // outstanding number of packets in flight.
- for i := 0; i < 7; i++ {
- c.SendAckWithSACK(seq, rtxOffset, []header.SACKBlock{{start, end}})
- end = end.Add(10)
- }
-
- recover := bytesRead
-
- // Ensure no new packets arrive.
- c.CheckNoPacketTimeout("More packets received than expected during recovery after dupacks for this cwnd.",
- 50*time.Millisecond)
-
- // Acknowledge half of the pending data. This along with the 10 sacked
- // segments above should reduce the outstanding below the current
- // congestion window allowing the sender to transmit data.
- rtxOffset = bytesRead - expected*maxPayload/2
-
- // Now send a partial ACK w/ a SACK block that indicates that the next 3
- // segments are lost and we have received 6 segments after the lost
- // segments. This should cause the sender to immediately transmit all 3
- // segments in response to this ACK unlike in FastRecovery where only 1
- // segment is retransmitted per ACK.
- start = c.IRS.Add(seqnum.Size(rtxOffset) + 30 + 1)
- end = start.Add(60)
- c.SendAckWithSACK(seq, rtxOffset, []header.SACKBlock{{start, end}})
-
- // At this point, we acked expected/2 packets and we SACKED 6 packets and
- // 3 segments were considered lost due to the SACK block we sent.
- //
- // So total packets outstanding can be calculated as follows after 7
- // iterations of slow start -> 10/20/40/80/160/320/640. So expected
- // should be 640 at start, then we went to recover at which point the
- // cwnd should be set to 320 + 3 (for the 3 dupAcks which have left the
- // network).
- // Outstanding at this point after acking half the window
- // (320 packets) will be:
- // outstanding = 640-320-6(due to SACK block)-3 = 311
- //
- // The last 3 is due to the fact that the first 3 packets after
- // rtxOffset will be considered lost due to the SACK blocks sent.
- // Receive the retransmit due to partial ack.
-
- c.ReceiveAndCheckPacketWithOptions(data, rtxOffset, maxPayload, tsOptionSize)
- // Receive the 2 extra packets that should have been retransmitted as
- // those should be considered lost and immediately retransmitted based
- // on the SACK information in the previous ACK sent above.
- for i := 0; i < 2; i++ {
- c.ReceiveAndCheckPacketWithOptions(data, rtxOffset+maxPayload*(i+1), maxPayload, tsOptionSize)
- }
-
- // Now we should get 9 more new unsent packets as the cwnd is 323 and
- // outstanding is 311.
- for i := 0; i < 9; i++ {
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
- bytesRead += maxPayload
- }
-
- metricPollFn = func() error {
- // In SACK recovery only the first segment is fast retransmitted when
- // entering recovery.
- if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want {
- return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want)
- }
-
- if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.FastRetransmit.Value(), uint64(1); got != want {
- return fmt.Errorf("got EP stats SendErrors.FastRetransmit = %d, want = %d", got, want)
- }
-
- if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(4); got != want {
- return fmt.Errorf("got stats.TCP.Retransmits.Value = %d, want = %d", got, want)
- }
-
- if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(4); got != want {
- return fmt.Errorf("got EP stats Stats.SendErrors.Retransmits = %d, want = %d", got, want)
- }
- return nil
- }
- if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
- t.Error(err)
- }
-
- c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", 50*time.Millisecond)
-
- // Acknowledge all pending data to recover point.
- c.SendAck(seq, recover)
-
- // At this point, the cwnd should reset to expected/2 and there are 9
- // packets outstanding.
- //
- // Now in the first iteration since there are 9 packets outstanding.
- // We would expect to get expected/2 - 9 packets. But subsequent
- // iterations will send us expected/2 + 1 (per iteration).
- expected = expected/2 - 9
- for i := 0; i < iterations; i++ {
- // Read all packets expected on this iteration. Don't
- // acknowledge any of them just yet, so that we can measure the
- // congestion window.
- for j := 0; j < expected; j++ {
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
- bytesRead += maxPayload
- }
- // Check we don't receive any more packets on this iteration.
- // The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout(fmt.Sprintf("More packets received(after deflation) than expected %d for this cwnd and iteration: %d.", expected, i), 50*time.Millisecond)
-
- // Acknowledge all the data received so far.
- c.SendAck(seq, bytesRead)
-
- // In cogestion avoidance, the packets trains increase by 1 in
- // each iteration.
- if i == 0 {
- // After the first iteration we expect to get the full
- // congestion window worth of packets in every
- // iteration.
- expected += 9
- }
- expected++
- }
-}
-
-// TestRecoveryEntry tests the following two properties of entering recovery:
-// - Fast SACK recovery is entered when SND.UNA is considered lost by the SACK
-// scoreboard but dupack count is still below threshold.
-// - Only enter recovery when at least one more byte of data beyond the highest
-// byte that was outstanding when fast retransmit was last entered is acked.
-func TestRecoveryEntry(t *testing.T) {
- c := context.New(t, uint32(mtu))
- defer c.Cleanup()
-
- numPackets := 5
- data := sendAndReceiveWithSACK(t, c, numPackets, false /* enableRACK */)
-
- // Ack #1 packet.
- seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- c.SendAck(seq, maxPayload)
-
- // Now SACK #3, #4 and #5 packets. This will simulate a situation where
- // SND.UNA should be considered lost and the sender should enter fast recovery
- // (even though dupack count is still below threshold).
- p3Start := c.IRS.Add(1 + seqnum.Size(2*maxPayload))
- p3End := p3Start.Add(maxPayload)
- p4Start := p3End
- p4End := p4Start.Add(maxPayload)
- p5Start := p4End
- p5End := p5Start.Add(maxPayload)
- c.SendAckWithSACK(seq, maxPayload, []header.SACKBlock{{p3Start, p3End}, {p4Start, p4End}, {p5Start, p5End}})
-
- // Expect #2 to be retransmitted.
- c.ReceiveAndCheckPacketWithOptions(data, maxPayload, maxPayload, tsOptionSize)
-
- metricPollFn := func() error {
- tcpStats := c.Stack().Stats().TCP
- stats := []struct {
- stat *tcpip.StatCounter
- name string
- want uint64
- }{
- // SACK recovery must have happened.
- {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1},
- {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1},
- // #2 was retransmitted.
- {tcpStats.Retransmits, "stats.TCP.Retransmits", 1},
- // No RTOs should have fired yet.
- {tcpStats.Timeouts, "stats.TCP.Timeouts", 0},
- }
- for _, s := range stats {
- if got, want := s.stat.Value(), s.want; got != want {
- return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
- }
- }
- return nil
- }
- if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
- t.Error(err)
- }
-
- // Send 4 more packets.
- var r bytes.Reader
- data = append(data, data...)
- r.Reset(data[5*maxPayload : 9*maxPayload])
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- var sackBlocks []header.SACKBlock
- bytesRead := numPackets * maxPayload
- for i := 0; i < 4; i++ {
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
- if i > 0 {
- pStart := c.IRS.Add(1 + seqnum.Size(bytesRead))
- sackBlocks = append(sackBlocks, header.SACKBlock{pStart, pStart.Add(maxPayload)})
- c.SendAckWithSACK(seq, 5*maxPayload, sackBlocks)
- }
- bytesRead += maxPayload
- }
-
- // #6 should be retransmitted after RTO. The sender should NOT enter fast
- // recovery because the highest byte that was outstanding when fast recovery
- // was last entered is #5 packet's end. And the sender requires at least one
- // more byte beyond that (#6 packet start) to be acked to enter recovery.
- c.ReceiveAndCheckPacketWithOptions(data, 5*maxPayload, maxPayload, tsOptionSize)
- c.SendAck(seq, 9*maxPayload)
-
- metricPollFn = func() error {
- tcpStats := c.Stack().Stats().TCP
- stats := []struct {
- stat *tcpip.StatCounter
- name string
- want uint64
- }{
- // Only 1 SACK recovery must have happened.
- {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1},
- {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1},
- // #2 and #6 were retransmitted.
- {tcpStats.Retransmits, "stats.TCP.Retransmits", 2},
- // RTO should have fired once.
- {tcpStats.Timeouts, "stats.TCP.Timeouts", 1},
- }
- for _, s := range stats {
- if got, want := s.stat.Value(), s.want; got != want {
- return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
- }
- }
- return nil
- }
- if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
- t.Error(err)
- }
-}
diff --git a/pkg/tcpip/transport/tcp/tcp_segment_list.go b/pkg/tcpip/transport/tcp/tcp_segment_list.go
new file mode 100644
index 000000000..a14cff27e
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/tcp_segment_list.go
@@ -0,0 +1,221 @@
+package tcp
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type segmentElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (segmentElementMapper) linkerFor(elem *segment) *segment { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type segmentList struct {
+ head *segment
+ tail *segment
+}
+
+// Reset resets list l to the empty state.
+func (l *segmentList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+//
+//go:nosplit
+func (l *segmentList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+//
+//go:nosplit
+func (l *segmentList) Front() *segment {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+//
+//go:nosplit
+func (l *segmentList) Back() *segment {
+ return l.tail
+}
+
+// Len returns the number of elements in the list.
+//
+// NOTE: This is an O(n) operation.
+//
+//go:nosplit
+func (l *segmentList) Len() (count int) {
+ for e := l.Front(); e != nil; e = (segmentElementMapper{}.linkerFor(e)).Next() {
+ count++
+ }
+ return count
+}
+
+// PushFront inserts the element e at the front of list l.
+//
+//go:nosplit
+func (l *segmentList) PushFront(e *segment) {
+ linker := segmentElementMapper{}.linkerFor(e)
+ linker.SetNext(l.head)
+ linker.SetPrev(nil)
+ if l.head != nil {
+ segmentElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+//
+//go:nosplit
+func (l *segmentList) PushBack(e *segment) {
+ linker := segmentElementMapper{}.linkerFor(e)
+ linker.SetNext(nil)
+ linker.SetPrev(l.tail)
+ if l.tail != nil {
+ segmentElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+//
+//go:nosplit
+func (l *segmentList) PushBackList(m *segmentList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ segmentElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ segmentElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+//
+//go:nosplit
+func (l *segmentList) InsertAfter(b, e *segment) {
+ bLinker := segmentElementMapper{}.linkerFor(b)
+ eLinker := segmentElementMapper{}.linkerFor(e)
+
+ a := bLinker.Next()
+
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ bLinker.SetNext(e)
+
+ if a != nil {
+ segmentElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+//
+//go:nosplit
+func (l *segmentList) InsertBefore(a, e *segment) {
+ aLinker := segmentElementMapper{}.linkerFor(a)
+ eLinker := segmentElementMapper{}.linkerFor(e)
+
+ b := aLinker.Prev()
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ aLinker.SetPrev(e)
+
+ if b != nil {
+ segmentElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+//
+//go:nosplit
+func (l *segmentList) Remove(e *segment) {
+ linker := segmentElementMapper{}.linkerFor(e)
+ prev := linker.Prev()
+ next := linker.Next()
+
+ if prev != nil {
+ segmentElementMapper{}.linkerFor(prev).SetNext(next)
+ } else if l.head == e {
+ l.head = next
+ }
+
+ if next != nil {
+ segmentElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else if l.tail == e {
+ l.tail = prev
+ }
+
+ linker.SetNext(nil)
+ linker.SetPrev(nil)
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type segmentEntry struct {
+ next *segment
+ prev *segment
+}
+
+// Next returns the entry that follows e in the list.
+//
+//go:nosplit
+func (e *segmentEntry) Next() *segment {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *segmentEntry) Prev() *segment {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+//
+//go:nosplit
+func (e *segmentEntry) SetNext(elem *segment) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *segmentEntry) SetPrev(elem *segment) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_state_autogen.go
new file mode 100644
index 000000000..13061d2b1
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/tcp_state_autogen.go
@@ -0,0 +1,901 @@
+// automatically generated by stateify.
+
+package tcp
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+func (c *cubicState) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.cubicState"
+}
+
+func (c *cubicState) StateFields() []string {
+ return []string{
+ "TCPCubicState",
+ "numCongestionEvents",
+ "s",
+ }
+}
+
+func (c *cubicState) beforeSave() {}
+
+// +checklocksignore
+func (c *cubicState) StateSave(stateSinkObject state.Sink) {
+ c.beforeSave()
+ stateSinkObject.Save(0, &c.TCPCubicState)
+ stateSinkObject.Save(1, &c.numCongestionEvents)
+ stateSinkObject.Save(2, &c.s)
+}
+
+func (c *cubicState) afterLoad() {}
+
+// +checklocksignore
+func (c *cubicState) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &c.TCPCubicState)
+ stateSourceObject.Load(1, &c.numCongestionEvents)
+ stateSourceObject.Load(2, &c.s)
+}
+
+func (s *SACKInfo) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.SACKInfo"
+}
+
+func (s *SACKInfo) StateFields() []string {
+ return []string{
+ "Blocks",
+ "NumBlocks",
+ }
+}
+
+func (s *SACKInfo) beforeSave() {}
+
+// +checklocksignore
+func (s *SACKInfo) StateSave(stateSinkObject state.Sink) {
+ s.beforeSave()
+ stateSinkObject.Save(0, &s.Blocks)
+ stateSinkObject.Save(1, &s.NumBlocks)
+}
+
+func (s *SACKInfo) afterLoad() {}
+
+// +checklocksignore
+func (s *SACKInfo) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &s.Blocks)
+ stateSourceObject.Load(1, &s.NumBlocks)
+}
+
+func (s *sndQueueInfo) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.sndQueueInfo"
+}
+
+func (s *sndQueueInfo) StateFields() []string {
+ return []string{
+ "TCPSndBufState",
+ }
+}
+
+func (s *sndQueueInfo) beforeSave() {}
+
+// +checklocksignore
+func (s *sndQueueInfo) StateSave(stateSinkObject state.Sink) {
+ s.beforeSave()
+ stateSinkObject.Save(0, &s.TCPSndBufState)
+}
+
+func (s *sndQueueInfo) afterLoad() {}
+
+// +checklocksignore
+func (s *sndQueueInfo) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &s.TCPSndBufState)
+}
+
+func (r *rcvQueueInfo) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.rcvQueueInfo"
+}
+
+func (r *rcvQueueInfo) StateFields() []string {
+ return []string{
+ "TCPRcvBufState",
+ "rcvQueue",
+ }
+}
+
+func (r *rcvQueueInfo) beforeSave() {}
+
+// +checklocksignore
+func (r *rcvQueueInfo) StateSave(stateSinkObject state.Sink) {
+ r.beforeSave()
+ stateSinkObject.Save(0, &r.TCPRcvBufState)
+ stateSinkObject.Save(1, &r.rcvQueue)
+}
+
+func (r *rcvQueueInfo) afterLoad() {}
+
+// +checklocksignore
+func (r *rcvQueueInfo) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &r.TCPRcvBufState)
+ stateSourceObject.LoadWait(1, &r.rcvQueue)
+}
+
+func (a *accepted) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.accepted"
+}
+
+func (a *accepted) StateFields() []string {
+ return []string{
+ "endpoints",
+ "cap",
+ }
+}
+
+func (a *accepted) beforeSave() {}
+
+// +checklocksignore
+func (a *accepted) StateSave(stateSinkObject state.Sink) {
+ a.beforeSave()
+ var endpointsValue []*endpoint
+ endpointsValue = a.saveEndpoints()
+ stateSinkObject.SaveValue(0, endpointsValue)
+ stateSinkObject.Save(1, &a.cap)
+}
+
+func (a *accepted) afterLoad() {}
+
+// +checklocksignore
+func (a *accepted) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(1, &a.cap)
+ stateSourceObject.LoadValue(0, new([]*endpoint), func(y interface{}) { a.loadEndpoints(y.([]*endpoint)) })
+}
+
+func (e *endpoint) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.endpoint"
+}
+
+func (e *endpoint) StateFields() []string {
+ return []string{
+ "TCPEndpointStateInner",
+ "TransportEndpointInfo",
+ "DefaultSocketOptionsHandler",
+ "waiterQueue",
+ "uniqueID",
+ "hardError",
+ "lastError",
+ "rcvQueueInfo",
+ "rcvMemUsed",
+ "ownedByUser",
+ "state",
+ "boundNICID",
+ "ttl",
+ "isConnectNotified",
+ "portFlags",
+ "boundBindToDevice",
+ "boundPortFlags",
+ "boundDest",
+ "effectiveNetProtos",
+ "workerRunning",
+ "workerCleanup",
+ "recentTSTime",
+ "shutdownFlags",
+ "tcpRecovery",
+ "sack",
+ "delay",
+ "scoreboard",
+ "segmentQueue",
+ "synRcvdCount",
+ "userMSS",
+ "maxSynRetries",
+ "windowClamp",
+ "sndQueueInfo",
+ "cc",
+ "keepalive",
+ "userTimeout",
+ "deferAccept",
+ "accepted",
+ "rcv",
+ "snd",
+ "connectingAddress",
+ "amss",
+ "sendTOS",
+ "gso",
+ "tcpLingerTimeout",
+ "closed",
+ "txHash",
+ "owner",
+ "ops",
+ "lastOutOfWindowAckTime",
+ }
+}
+
+// +checklocksignore
+func (e *endpoint) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+ var stateValue EndpointState
+ stateValue = e.saveState()
+ stateSinkObject.SaveValue(10, stateValue)
+ stateSinkObject.Save(0, &e.TCPEndpointStateInner)
+ stateSinkObject.Save(1, &e.TransportEndpointInfo)
+ stateSinkObject.Save(2, &e.DefaultSocketOptionsHandler)
+ stateSinkObject.Save(3, &e.waiterQueue)
+ stateSinkObject.Save(4, &e.uniqueID)
+ stateSinkObject.Save(5, &e.hardError)
+ stateSinkObject.Save(6, &e.lastError)
+ stateSinkObject.Save(7, &e.rcvQueueInfo)
+ stateSinkObject.Save(8, &e.rcvMemUsed)
+ stateSinkObject.Save(9, &e.ownedByUser)
+ stateSinkObject.Save(11, &e.boundNICID)
+ stateSinkObject.Save(12, &e.ttl)
+ stateSinkObject.Save(13, &e.isConnectNotified)
+ stateSinkObject.Save(14, &e.portFlags)
+ stateSinkObject.Save(15, &e.boundBindToDevice)
+ stateSinkObject.Save(16, &e.boundPortFlags)
+ stateSinkObject.Save(17, &e.boundDest)
+ stateSinkObject.Save(18, &e.effectiveNetProtos)
+ stateSinkObject.Save(19, &e.workerRunning)
+ stateSinkObject.Save(20, &e.workerCleanup)
+ stateSinkObject.Save(21, &e.recentTSTime)
+ stateSinkObject.Save(22, &e.shutdownFlags)
+ stateSinkObject.Save(23, &e.tcpRecovery)
+ stateSinkObject.Save(24, &e.sack)
+ stateSinkObject.Save(25, &e.delay)
+ stateSinkObject.Save(26, &e.scoreboard)
+ stateSinkObject.Save(27, &e.segmentQueue)
+ stateSinkObject.Save(28, &e.synRcvdCount)
+ stateSinkObject.Save(29, &e.userMSS)
+ stateSinkObject.Save(30, &e.maxSynRetries)
+ stateSinkObject.Save(31, &e.windowClamp)
+ stateSinkObject.Save(32, &e.sndQueueInfo)
+ stateSinkObject.Save(33, &e.cc)
+ stateSinkObject.Save(34, &e.keepalive)
+ stateSinkObject.Save(35, &e.userTimeout)
+ stateSinkObject.Save(36, &e.deferAccept)
+ stateSinkObject.Save(37, &e.accepted)
+ stateSinkObject.Save(38, &e.rcv)
+ stateSinkObject.Save(39, &e.snd)
+ stateSinkObject.Save(40, &e.connectingAddress)
+ stateSinkObject.Save(41, &e.amss)
+ stateSinkObject.Save(42, &e.sendTOS)
+ stateSinkObject.Save(43, &e.gso)
+ stateSinkObject.Save(44, &e.tcpLingerTimeout)
+ stateSinkObject.Save(45, &e.closed)
+ stateSinkObject.Save(46, &e.txHash)
+ stateSinkObject.Save(47, &e.owner)
+ stateSinkObject.Save(48, &e.ops)
+ stateSinkObject.Save(49, &e.lastOutOfWindowAckTime)
+}
+
+// +checklocksignore
+func (e *endpoint) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &e.TCPEndpointStateInner)
+ stateSourceObject.Load(1, &e.TransportEndpointInfo)
+ stateSourceObject.Load(2, &e.DefaultSocketOptionsHandler)
+ stateSourceObject.LoadWait(3, &e.waiterQueue)
+ stateSourceObject.Load(4, &e.uniqueID)
+ stateSourceObject.Load(5, &e.hardError)
+ stateSourceObject.Load(6, &e.lastError)
+ stateSourceObject.Load(7, &e.rcvQueueInfo)
+ stateSourceObject.Load(8, &e.rcvMemUsed)
+ stateSourceObject.Load(9, &e.ownedByUser)
+ stateSourceObject.Load(11, &e.boundNICID)
+ stateSourceObject.Load(12, &e.ttl)
+ stateSourceObject.Load(13, &e.isConnectNotified)
+ stateSourceObject.Load(14, &e.portFlags)
+ stateSourceObject.Load(15, &e.boundBindToDevice)
+ stateSourceObject.Load(16, &e.boundPortFlags)
+ stateSourceObject.Load(17, &e.boundDest)
+ stateSourceObject.Load(18, &e.effectiveNetProtos)
+ stateSourceObject.Load(19, &e.workerRunning)
+ stateSourceObject.Load(20, &e.workerCleanup)
+ stateSourceObject.Load(21, &e.recentTSTime)
+ stateSourceObject.Load(22, &e.shutdownFlags)
+ stateSourceObject.Load(23, &e.tcpRecovery)
+ stateSourceObject.Load(24, &e.sack)
+ stateSourceObject.Load(25, &e.delay)
+ stateSourceObject.Load(26, &e.scoreboard)
+ stateSourceObject.LoadWait(27, &e.segmentQueue)
+ stateSourceObject.Load(28, &e.synRcvdCount)
+ stateSourceObject.Load(29, &e.userMSS)
+ stateSourceObject.Load(30, &e.maxSynRetries)
+ stateSourceObject.Load(31, &e.windowClamp)
+ stateSourceObject.Load(32, &e.sndQueueInfo)
+ stateSourceObject.Load(33, &e.cc)
+ stateSourceObject.Load(34, &e.keepalive)
+ stateSourceObject.Load(35, &e.userTimeout)
+ stateSourceObject.Load(36, &e.deferAccept)
+ stateSourceObject.Load(37, &e.accepted)
+ stateSourceObject.LoadWait(38, &e.rcv)
+ stateSourceObject.LoadWait(39, &e.snd)
+ stateSourceObject.Load(40, &e.connectingAddress)
+ stateSourceObject.Load(41, &e.amss)
+ stateSourceObject.Load(42, &e.sendTOS)
+ stateSourceObject.Load(43, &e.gso)
+ stateSourceObject.Load(44, &e.tcpLingerTimeout)
+ stateSourceObject.Load(45, &e.closed)
+ stateSourceObject.Load(46, &e.txHash)
+ stateSourceObject.Load(47, &e.owner)
+ stateSourceObject.Load(48, &e.ops)
+ stateSourceObject.Load(49, &e.lastOutOfWindowAckTime)
+ stateSourceObject.LoadValue(10, new(EndpointState), func(y interface{}) { e.loadState(y.(EndpointState)) })
+ stateSourceObject.AfterLoad(e.afterLoad)
+}
+
+func (k *keepalive) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.keepalive"
+}
+
+func (k *keepalive) StateFields() []string {
+ return []string{
+ "idle",
+ "interval",
+ "count",
+ "unacked",
+ }
+}
+
+func (k *keepalive) beforeSave() {}
+
+// +checklocksignore
+func (k *keepalive) StateSave(stateSinkObject state.Sink) {
+ k.beforeSave()
+ stateSinkObject.Save(0, &k.idle)
+ stateSinkObject.Save(1, &k.interval)
+ stateSinkObject.Save(2, &k.count)
+ stateSinkObject.Save(3, &k.unacked)
+}
+
+func (k *keepalive) afterLoad() {}
+
+// +checklocksignore
+func (k *keepalive) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &k.idle)
+ stateSourceObject.Load(1, &k.interval)
+ stateSourceObject.Load(2, &k.count)
+ stateSourceObject.Load(3, &k.unacked)
+}
+
+func (rc *rackControl) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.rackControl"
+}
+
+func (rc *rackControl) StateFields() []string {
+ return []string{
+ "TCPRACKState",
+ "exitedRecovery",
+ "minRTT",
+ "tlpRxtOut",
+ "tlpHighRxt",
+ "snd",
+ }
+}
+
+func (rc *rackControl) beforeSave() {}
+
+// +checklocksignore
+func (rc *rackControl) StateSave(stateSinkObject state.Sink) {
+ rc.beforeSave()
+ stateSinkObject.Save(0, &rc.TCPRACKState)
+ stateSinkObject.Save(1, &rc.exitedRecovery)
+ stateSinkObject.Save(2, &rc.minRTT)
+ stateSinkObject.Save(3, &rc.tlpRxtOut)
+ stateSinkObject.Save(4, &rc.tlpHighRxt)
+ stateSinkObject.Save(5, &rc.snd)
+}
+
+func (rc *rackControl) afterLoad() {}
+
+// +checklocksignore
+func (rc *rackControl) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &rc.TCPRACKState)
+ stateSourceObject.Load(1, &rc.exitedRecovery)
+ stateSourceObject.Load(2, &rc.minRTT)
+ stateSourceObject.Load(3, &rc.tlpRxtOut)
+ stateSourceObject.Load(4, &rc.tlpHighRxt)
+ stateSourceObject.Load(5, &rc.snd)
+}
+
+func (r *receiver) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.receiver"
+}
+
+func (r *receiver) StateFields() []string {
+ return []string{
+ "TCPReceiverState",
+ "ep",
+ "rcvWnd",
+ "rcvWUP",
+ "prevBufUsed",
+ "closed",
+ "pendingRcvdSegments",
+ "lastRcvdAckTime",
+ }
+}
+
+func (r *receiver) beforeSave() {}
+
+// +checklocksignore
+func (r *receiver) StateSave(stateSinkObject state.Sink) {
+ r.beforeSave()
+ stateSinkObject.Save(0, &r.TCPReceiverState)
+ stateSinkObject.Save(1, &r.ep)
+ stateSinkObject.Save(2, &r.rcvWnd)
+ stateSinkObject.Save(3, &r.rcvWUP)
+ stateSinkObject.Save(4, &r.prevBufUsed)
+ stateSinkObject.Save(5, &r.closed)
+ stateSinkObject.Save(6, &r.pendingRcvdSegments)
+ stateSinkObject.Save(7, &r.lastRcvdAckTime)
+}
+
+func (r *receiver) afterLoad() {}
+
+// +checklocksignore
+func (r *receiver) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &r.TCPReceiverState)
+ stateSourceObject.Load(1, &r.ep)
+ stateSourceObject.Load(2, &r.rcvWnd)
+ stateSourceObject.Load(3, &r.rcvWUP)
+ stateSourceObject.Load(4, &r.prevBufUsed)
+ stateSourceObject.Load(5, &r.closed)
+ stateSourceObject.Load(6, &r.pendingRcvdSegments)
+ stateSourceObject.Load(7, &r.lastRcvdAckTime)
+}
+
+func (r *renoState) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.renoState"
+}
+
+func (r *renoState) StateFields() []string {
+ return []string{
+ "s",
+ }
+}
+
+func (r *renoState) beforeSave() {}
+
+// +checklocksignore
+func (r *renoState) StateSave(stateSinkObject state.Sink) {
+ r.beforeSave()
+ stateSinkObject.Save(0, &r.s)
+}
+
+func (r *renoState) afterLoad() {}
+
+// +checklocksignore
+func (r *renoState) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &r.s)
+}
+
+func (rr *renoRecovery) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.renoRecovery"
+}
+
+func (rr *renoRecovery) StateFields() []string {
+ return []string{
+ "s",
+ }
+}
+
+func (rr *renoRecovery) beforeSave() {}
+
+// +checklocksignore
+func (rr *renoRecovery) StateSave(stateSinkObject state.Sink) {
+ rr.beforeSave()
+ stateSinkObject.Save(0, &rr.s)
+}
+
+func (rr *renoRecovery) afterLoad() {}
+
+// +checklocksignore
+func (rr *renoRecovery) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &rr.s)
+}
+
+func (sr *sackRecovery) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.sackRecovery"
+}
+
+func (sr *sackRecovery) StateFields() []string {
+ return []string{
+ "s",
+ }
+}
+
+func (sr *sackRecovery) beforeSave() {}
+
+// +checklocksignore
+func (sr *sackRecovery) StateSave(stateSinkObject state.Sink) {
+ sr.beforeSave()
+ stateSinkObject.Save(0, &sr.s)
+}
+
+func (sr *sackRecovery) afterLoad() {}
+
+// +checklocksignore
+func (sr *sackRecovery) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &sr.s)
+}
+
+func (s *SACKScoreboard) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.SACKScoreboard"
+}
+
+func (s *SACKScoreboard) StateFields() []string {
+ return []string{
+ "smss",
+ "maxSACKED",
+ }
+}
+
+func (s *SACKScoreboard) beforeSave() {}
+
+// +checklocksignore
+func (s *SACKScoreboard) StateSave(stateSinkObject state.Sink) {
+ s.beforeSave()
+ stateSinkObject.Save(0, &s.smss)
+ stateSinkObject.Save(1, &s.maxSACKED)
+}
+
+func (s *SACKScoreboard) afterLoad() {}
+
+// +checklocksignore
+func (s *SACKScoreboard) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &s.smss)
+ stateSourceObject.Load(1, &s.maxSACKED)
+}
+
+func (s *segment) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.segment"
+}
+
+func (s *segment) StateFields() []string {
+ return []string{
+ "segmentEntry",
+ "refCnt",
+ "ep",
+ "qFlags",
+ "srcAddr",
+ "dstAddr",
+ "netProto",
+ "nicID",
+ "data",
+ "hdr",
+ "sequenceNumber",
+ "ackNumber",
+ "flags",
+ "window",
+ "csum",
+ "csumValid",
+ "parsedOptions",
+ "options",
+ "hasNewSACKInfo",
+ "rcvdTime",
+ "xmitTime",
+ "xmitCount",
+ "acked",
+ "dataMemSize",
+ "lost",
+ }
+}
+
+func (s *segment) beforeSave() {}
+
+// +checklocksignore
+func (s *segment) StateSave(stateSinkObject state.Sink) {
+ s.beforeSave()
+ var dataValue buffer.VectorisedView
+ dataValue = s.saveData()
+ stateSinkObject.SaveValue(8, dataValue)
+ var optionsValue []byte
+ optionsValue = s.saveOptions()
+ stateSinkObject.SaveValue(17, optionsValue)
+ stateSinkObject.Save(0, &s.segmentEntry)
+ stateSinkObject.Save(1, &s.refCnt)
+ stateSinkObject.Save(2, &s.ep)
+ stateSinkObject.Save(3, &s.qFlags)
+ stateSinkObject.Save(4, &s.srcAddr)
+ stateSinkObject.Save(5, &s.dstAddr)
+ stateSinkObject.Save(6, &s.netProto)
+ stateSinkObject.Save(7, &s.nicID)
+ stateSinkObject.Save(9, &s.hdr)
+ stateSinkObject.Save(10, &s.sequenceNumber)
+ stateSinkObject.Save(11, &s.ackNumber)
+ stateSinkObject.Save(12, &s.flags)
+ stateSinkObject.Save(13, &s.window)
+ stateSinkObject.Save(14, &s.csum)
+ stateSinkObject.Save(15, &s.csumValid)
+ stateSinkObject.Save(16, &s.parsedOptions)
+ stateSinkObject.Save(18, &s.hasNewSACKInfo)
+ stateSinkObject.Save(19, &s.rcvdTime)
+ stateSinkObject.Save(20, &s.xmitTime)
+ stateSinkObject.Save(21, &s.xmitCount)
+ stateSinkObject.Save(22, &s.acked)
+ stateSinkObject.Save(23, &s.dataMemSize)
+ stateSinkObject.Save(24, &s.lost)
+}
+
+func (s *segment) afterLoad() {}
+
+// +checklocksignore
+func (s *segment) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &s.segmentEntry)
+ stateSourceObject.Load(1, &s.refCnt)
+ stateSourceObject.Load(2, &s.ep)
+ stateSourceObject.Load(3, &s.qFlags)
+ stateSourceObject.Load(4, &s.srcAddr)
+ stateSourceObject.Load(5, &s.dstAddr)
+ stateSourceObject.Load(6, &s.netProto)
+ stateSourceObject.Load(7, &s.nicID)
+ stateSourceObject.Load(9, &s.hdr)
+ stateSourceObject.Load(10, &s.sequenceNumber)
+ stateSourceObject.Load(11, &s.ackNumber)
+ stateSourceObject.Load(12, &s.flags)
+ stateSourceObject.Load(13, &s.window)
+ stateSourceObject.Load(14, &s.csum)
+ stateSourceObject.Load(15, &s.csumValid)
+ stateSourceObject.Load(16, &s.parsedOptions)
+ stateSourceObject.Load(18, &s.hasNewSACKInfo)
+ stateSourceObject.Load(19, &s.rcvdTime)
+ stateSourceObject.Load(20, &s.xmitTime)
+ stateSourceObject.Load(21, &s.xmitCount)
+ stateSourceObject.Load(22, &s.acked)
+ stateSourceObject.Load(23, &s.dataMemSize)
+ stateSourceObject.Load(24, &s.lost)
+ stateSourceObject.LoadValue(8, new(buffer.VectorisedView), func(y interface{}) { s.loadData(y.(buffer.VectorisedView)) })
+ stateSourceObject.LoadValue(17, new([]byte), func(y interface{}) { s.loadOptions(y.([]byte)) })
+}
+
+func (q *segmentQueue) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.segmentQueue"
+}
+
+func (q *segmentQueue) StateFields() []string {
+ return []string{
+ "list",
+ "ep",
+ "frozen",
+ }
+}
+
+func (q *segmentQueue) beforeSave() {}
+
+// +checklocksignore
+func (q *segmentQueue) StateSave(stateSinkObject state.Sink) {
+ q.beforeSave()
+ stateSinkObject.Save(0, &q.list)
+ stateSinkObject.Save(1, &q.ep)
+ stateSinkObject.Save(2, &q.frozen)
+}
+
+func (q *segmentQueue) afterLoad() {}
+
+// +checklocksignore
+func (q *segmentQueue) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.LoadWait(0, &q.list)
+ stateSourceObject.Load(1, &q.ep)
+ stateSourceObject.Load(2, &q.frozen)
+}
+
+func (s *sender) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.sender"
+}
+
+func (s *sender) StateFields() []string {
+ return []string{
+ "TCPSenderState",
+ "ep",
+ "lr",
+ "firstRetransmittedSegXmitTime",
+ "writeNext",
+ "writeList",
+ "rtt",
+ "minRTO",
+ "maxRTO",
+ "maxRetries",
+ "gso",
+ "state",
+ "cc",
+ "rc",
+ }
+}
+
+func (s *sender) beforeSave() {}
+
+// +checklocksignore
+func (s *sender) StateSave(stateSinkObject state.Sink) {
+ s.beforeSave()
+ stateSinkObject.Save(0, &s.TCPSenderState)
+ stateSinkObject.Save(1, &s.ep)
+ stateSinkObject.Save(2, &s.lr)
+ stateSinkObject.Save(3, &s.firstRetransmittedSegXmitTime)
+ stateSinkObject.Save(4, &s.writeNext)
+ stateSinkObject.Save(5, &s.writeList)
+ stateSinkObject.Save(6, &s.rtt)
+ stateSinkObject.Save(7, &s.minRTO)
+ stateSinkObject.Save(8, &s.maxRTO)
+ stateSinkObject.Save(9, &s.maxRetries)
+ stateSinkObject.Save(10, &s.gso)
+ stateSinkObject.Save(11, &s.state)
+ stateSinkObject.Save(12, &s.cc)
+ stateSinkObject.Save(13, &s.rc)
+}
+
+func (s *sender) afterLoad() {}
+
+// +checklocksignore
+func (s *sender) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &s.TCPSenderState)
+ stateSourceObject.Load(1, &s.ep)
+ stateSourceObject.Load(2, &s.lr)
+ stateSourceObject.Load(3, &s.firstRetransmittedSegXmitTime)
+ stateSourceObject.Load(4, &s.writeNext)
+ stateSourceObject.Load(5, &s.writeList)
+ stateSourceObject.Load(6, &s.rtt)
+ stateSourceObject.Load(7, &s.minRTO)
+ stateSourceObject.Load(8, &s.maxRTO)
+ stateSourceObject.Load(9, &s.maxRetries)
+ stateSourceObject.Load(10, &s.gso)
+ stateSourceObject.Load(11, &s.state)
+ stateSourceObject.Load(12, &s.cc)
+ stateSourceObject.Load(13, &s.rc)
+}
+
+func (r *rtt) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.rtt"
+}
+
+func (r *rtt) StateFields() []string {
+ return []string{
+ "TCPRTTState",
+ }
+}
+
+func (r *rtt) beforeSave() {}
+
+// +checklocksignore
+func (r *rtt) StateSave(stateSinkObject state.Sink) {
+ r.beforeSave()
+ stateSinkObject.Save(0, &r.TCPRTTState)
+}
+
+func (r *rtt) afterLoad() {}
+
+// +checklocksignore
+func (r *rtt) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &r.TCPRTTState)
+}
+
+func (l *endpointList) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.endpointList"
+}
+
+func (l *endpointList) StateFields() []string {
+ return []string{
+ "head",
+ "tail",
+ }
+}
+
+func (l *endpointList) beforeSave() {}
+
+// +checklocksignore
+func (l *endpointList) StateSave(stateSinkObject state.Sink) {
+ l.beforeSave()
+ stateSinkObject.Save(0, &l.head)
+ stateSinkObject.Save(1, &l.tail)
+}
+
+func (l *endpointList) afterLoad() {}
+
+// +checklocksignore
+func (l *endpointList) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &l.head)
+ stateSourceObject.Load(1, &l.tail)
+}
+
+func (e *endpointEntry) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.endpointEntry"
+}
+
+func (e *endpointEntry) StateFields() []string {
+ return []string{
+ "next",
+ "prev",
+ }
+}
+
+func (e *endpointEntry) beforeSave() {}
+
+// +checklocksignore
+func (e *endpointEntry) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+ stateSinkObject.Save(0, &e.next)
+ stateSinkObject.Save(1, &e.prev)
+}
+
+func (e *endpointEntry) afterLoad() {}
+
+// +checklocksignore
+func (e *endpointEntry) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &e.next)
+ stateSourceObject.Load(1, &e.prev)
+}
+
+func (l *segmentList) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.segmentList"
+}
+
+func (l *segmentList) StateFields() []string {
+ return []string{
+ "head",
+ "tail",
+ }
+}
+
+func (l *segmentList) beforeSave() {}
+
+// +checklocksignore
+func (l *segmentList) StateSave(stateSinkObject state.Sink) {
+ l.beforeSave()
+ stateSinkObject.Save(0, &l.head)
+ stateSinkObject.Save(1, &l.tail)
+}
+
+func (l *segmentList) afterLoad() {}
+
+// +checklocksignore
+func (l *segmentList) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &l.head)
+ stateSourceObject.Load(1, &l.tail)
+}
+
+func (e *segmentEntry) StateTypeName() string {
+ return "pkg/tcpip/transport/tcp.segmentEntry"
+}
+
+func (e *segmentEntry) StateFields() []string {
+ return []string{
+ "next",
+ "prev",
+ }
+}
+
+func (e *segmentEntry) beforeSave() {}
+
+// +checklocksignore
+func (e *segmentEntry) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+ stateSinkObject.Save(0, &e.next)
+ stateSinkObject.Save(1, &e.prev)
+}
+
+func (e *segmentEntry) afterLoad() {}
+
+// +checklocksignore
+func (e *segmentEntry) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &e.next)
+ stateSourceObject.Load(1, &e.prev)
+}
+
+func init() {
+ state.Register((*cubicState)(nil))
+ state.Register((*SACKInfo)(nil))
+ state.Register((*sndQueueInfo)(nil))
+ state.Register((*rcvQueueInfo)(nil))
+ state.Register((*accepted)(nil))
+ state.Register((*endpoint)(nil))
+ state.Register((*keepalive)(nil))
+ state.Register((*rackControl)(nil))
+ state.Register((*receiver)(nil))
+ state.Register((*renoState)(nil))
+ state.Register((*renoRecovery)(nil))
+ state.Register((*sackRecovery)(nil))
+ state.Register((*SACKScoreboard)(nil))
+ state.Register((*segment)(nil))
+ state.Register((*segmentQueue)(nil))
+ state.Register((*sender)(nil))
+ state.Register((*rtt)(nil))
+ state.Register((*endpointList)(nil))
+ state.Register((*endpointEntry)(nil))
+ state.Register((*segmentList)(nil))
+ state.Register((*segmentEntry)(nil))
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
deleted file mode 100644
index 6f1ee3816..000000000
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ /dev/null
@@ -1,8602 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tcp_test
-
-import (
- "bytes"
- "fmt"
- "io/ioutil"
- "math"
- "strings"
- "testing"
- "time"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/rand"
- "gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
- "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/seqnum"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- tcpiptestutil "gvisor.dev/gvisor/pkg/tcpip/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
- "gvisor.dev/gvisor/pkg/test/testutil"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-// endpointTester provides helper functions to test a tcpip.Endpoint.
-type endpointTester struct {
- ep tcpip.Endpoint
-}
-
-// CheckReadError issues a read to the endpoint and checking for an error.
-func (e *endpointTester) CheckReadError(t *testing.T, want tcpip.Error) {
- t.Helper()
- res, got := e.ep.Read(ioutil.Discard, tcpip.ReadOptions{})
- if got != want {
- t.Fatalf("ep.Read = %s, want %s", got, want)
- }
- if diff := cmp.Diff(tcpip.ReadResult{}, res); diff != "" {
- t.Errorf("ep.Read: unexpected non-zero result (-want +got):\n%s", diff)
- }
-}
-
-// CheckRead issues a read to the endpoint and checking for a success, returning
-// the data read.
-func (e *endpointTester) CheckRead(t *testing.T) []byte {
- t.Helper()
- var buf bytes.Buffer
- res, err := e.ep.Read(&buf, tcpip.ReadOptions{})
- if err != nil {
- t.Fatalf("ep.Read = _, %s; want _, nil", err)
- }
- if diff := cmp.Diff(tcpip.ReadResult{
- Count: buf.Len(),
- Total: buf.Len(),
- }, res, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
- t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
- }
- return buf.Bytes()
-}
-
-// CheckReadFull reads from the endpoint for exactly count bytes.
-func (e *endpointTester) CheckReadFull(t *testing.T, count int, notifyRead <-chan struct{}, timeout time.Duration) []byte {
- t.Helper()
- var buf bytes.Buffer
- w := tcpip.LimitedWriter{
- W: &buf,
- N: int64(count),
- }
- for w.N != 0 {
- _, err := e.ep.Read(&w, tcpip.ReadOptions{})
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for receive to be notified.
- select {
- case <-notifyRead:
- case <-time.After(timeout):
- t.Fatalf("Timed out waiting for data to arrive")
- }
- continue
- } else if err != nil {
- t.Fatalf("ep.Read = _, %s; want _, nil", err)
- }
- }
- return buf.Bytes()
-}
-
-const (
- // defaultMTU is the MTU, in bytes, used throughout the tests, except
- // where another value is explicitly used. It is chosen to match the MTU
- // of loopback interfaces on linux systems.
- defaultMTU = 65535
-
- // defaultIPv4MSS is the MSS sent by the network stack in SYN/SYN-ACK for an
- // IPv4 endpoint when the MTU is set to defaultMTU in the test.
- defaultIPv4MSS = defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize
-)
-
-func TestGiveUpConnect(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- var wq waiter.Queue
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
-
- // Register for notification, then start connection attempt.
- waitEntry, notifyCh := waiter.NewChannelEntry(nil)
- wq.EventRegister(&waitEntry, waiter.EventHUp)
- defer wq.EventUnregister(&waitEntry)
-
- {
- err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
- if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
- t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d)
- }
- }
-
- // Close the connection, wait for completion.
- ep.Close()
-
- // Wait for ep to become writable.
- <-notifyCh
-
- // Call Connect again to retreive the handshake failure status
- // and stats updates.
- {
- err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
- if d := cmp.Diff(&tcpip.ErrAborted{}, err); d != "" {
- t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d)
- }
- }
-
- if got := c.Stack().Stats().TCP.FailedConnectionAttempts.Value(); got != 1 {
- t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = 1", got)
- }
-
- if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
- }
-}
-
-// Test for ICMP error handling without completing handshake.
-func TestConnectICMPError(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- var wq waiter.Queue
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
-
- waitEntry, notifyCh := waiter.NewChannelEntry(nil)
- wq.EventRegister(&waitEntry, waiter.EventHUp)
- defer wq.EventUnregister(&waitEntry)
-
- {
- err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
- if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
- t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d)
- }
- }
-
- syn := c.GetPacket()
- checker.IPv4(t, syn, checker.TCP(checker.TCPFlags(header.TCPFlagSyn)))
-
- wep := ep.(interface {
- StopWork()
- ResumeWork()
- LastErrorLocked() tcpip.Error
- })
-
- // Stop the protocol loop, ensure that the ICMP error is processed and
- // the last ICMP error is read before the loop is resumed. This sanity
- // tests the handshake completion logic on ICMP errors.
- wep.StopWork()
-
- c.SendICMPPacket(header.ICMPv4DstUnreachable, header.ICMPv4HostUnreachable, nil, syn, defaultMTU)
-
- for {
- if err := wep.LastErrorLocked(); err != nil {
- if d := cmp.Diff(&tcpip.ErrNoRoute{}, err); d != "" {
- t.Errorf("ep.LastErrorLocked() mismatch (-want +got):\n%s", d)
- }
- break
- }
- time.Sleep(time.Millisecond)
- }
-
- wep.ResumeWork()
-
- <-notifyCh
-
- // The stack would have unregistered the endpoint because of the ICMP error.
- // Expect a RST for any subsequent packets sent to the endpoint.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: seqnum.Value(context.TestInitialSequenceNumber) + 1,
- AckNum: c.IRS + 1,
- })
-
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS+1)),
- checker.TCPAckNum(0),
- checker.TCPFlags(header.TCPFlagRst)))
-}
-
-func TestConnectIncrementActiveConnection(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- stats := c.Stack().Stats()
- want := stats.TCP.ActiveConnectionOpenings.Value() + 1
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
- if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want {
- t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %d, want = %d", got, want)
- }
-}
-
-func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- stats := c.Stack().Stats()
- want := stats.TCP.FailedConnectionAttempts.Value()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
- if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
- t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want)
- }
- if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want {
- t.Errorf("got EP stats.FailedConnectionAttempts = %d, want = %d", got, want)
- }
-}
-
-func TestActiveFailedConnectionAttemptIncrement(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- stats := c.Stack().Stats()
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- c.EP = ep
- want := stats.TCP.FailedConnectionAttempts.Value() + 1
-
- {
- err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort})
- if d := cmp.Diff(&tcpip.ErrNoRoute{}, err); d != "" {
- t.Errorf("c.EP.Connect(...) mismatch (-want +got):\n%s", d)
- }
- }
-
- if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
- t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want)
- }
- if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want {
- t.Errorf("got EP stats FailedConnectionAttempts = %d, want = %d", got, want)
- }
-}
-
-func TestCloseWithoutConnect(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create TCP endpoint.
- var err tcpip.Error
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
-
- c.EP.Close()
-
- if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
- }
-}
-
-func TestTCPSegmentsSentIncrement(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- stats := c.Stack().Stats()
- // SYN and ACK
- want := stats.TCP.SegmentsSent.Value() + 2
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- if got := stats.TCP.SegmentsSent.Value(); got != want {
- t.Errorf("got stats.TCP.SegmentsSent.Value() = %d, want = %d", got, want)
- }
- if got := c.EP.Stats().(*tcp.Stats).SegmentsSent.Value(); got != want {
- t.Errorf("got EP stats SegmentsSent.Value() = %d, want = %d", got, want)
- }
-}
-
-func TestTCPResetsSentIncrement(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
- stats := c.Stack().Stats()
- wq := &waiter.Queue{}
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- want := stats.TCP.SegmentsSent.Value() + 1
-
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Send a SYN request.
- iss := seqnum.Value(context.TestInitialSequenceNumber)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: iss,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetPacket()
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
-
- ackHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 1,
- // If the AckNum is not the increment of the last sequence number, a RST
- // segment is sent back in response.
- AckNum: c.IRS + 2,
- }
-
- // Send ACK.
- c.SendPacket(nil, ackHeaders)
-
- c.GetPacket()
-
- metricPollFn := func() error {
- if got := stats.TCP.ResetsSent.Value(); got != want {
- return fmt.Errorf("got stats.TCP.ResetsSent.Value() = %d, want = %d", got, want)
- }
- return nil
- }
- if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
- t.Error(err)
- }
-}
-
-// TestTCPResetsSentNoICMP confirms that we don't get an ICMP
-// DstUnreachable packet when we try send a packet which is not part
-// of an active session.
-func TestTCPResetsSentNoICMP(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
- stats := c.Stack().Stats()
-
- // Send a SYN request for a closed port. This should elicit an RST
- // but NOT an ICMPv4 DstUnreachable packet.
- iss := seqnum.Value(context.TestInitialSequenceNumber)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: iss,
- })
-
- // Receive whatever comes back.
- b := c.GetPacket()
- ipHdr := header.IPv4(b)
- if got, want := ipHdr.Protocol(), uint8(header.TCPProtocolNumber); got != want {
- t.Errorf("unexpected protocol, got = %d, want = %d", got, want)
- }
-
- // Read outgoing ICMP stats and check no ICMP DstUnreachable was recorded.
- sent := stats.ICMP.V4.PacketsSent
- if got, want := sent.DstUnreachable.Value(), uint64(0); got != want {
- t.Errorf("got ICMP DstUnreachable.Value() = %d, want = %d", got, want)
- }
-}
-
-// TestTCPResetSentForACKWhenNotUsingSynCookies checks that the stack generates
-// a RST if an ACK is received on the listening socket for which there is no
-// active handshake in progress and we are not using SYN cookies.
-func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Set TCPLingerTimeout to 5 seconds so that sockets are marked closed
- wq := &waiter.Queue{}
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Send a SYN request.
- iss := seqnum.Value(context.TestInitialSequenceNumber)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: iss,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetPacket()
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
-
- ackHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 1,
- AckNum: c.IRS + 1,
- }
-
- // Send ACK.
- c.SendPacket(nil, ackHeaders)
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
-
- c.EP, _, err = ep.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- c.EP, _, err = ep.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- // Lower stackwide TIME_WAIT timeout so that the reservations
- // are released instantly on Close.
- tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &tcpTW); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, tcpTW, tcpTW, err)
- }
-
- c.EP.Close()
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS+1)),
- checker.TCPAckNum(uint32(iss)+1),
- checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
- finHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: iss + 1,
- AckNum: c.IRS + 2,
- }
-
- c.SendPacket(nil, finHeaders)
-
- // Get the ACK to the FIN we just sent.
- c.GetPacket()
-
- // Since an active close was done we need to wait for a little more than
- // tcpLingerTimeout for the port reservations to be released and the
- // socket to move to a CLOSED state.
- time.Sleep(20 * time.Millisecond)
-
- // Now resend the same ACK, this ACK should generate a RST as there
- // should be no endpoint in SYN-RCVD state and we are not using
- // syn-cookies yet. The reason we send the same ACK is we need a valid
- // cookie(IRS) generated by the netstack without which the ACK will be
- // rejected.
- c.SendPacket(nil, ackHeaders)
-
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS+1)),
- checker.TCPAckNum(0),
- checker.TCPFlags(header.TCPFlagRst)))
-}
-
-func TestTCPResetsReceivedIncrement(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- stats := c.Stack().Stats()
- want := stats.TCP.ResetsReceived.Value() + 1
- iss := seqnum.Value(context.TestInitialSequenceNumber)
- rcvWnd := seqnum.Size(30000)
- c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */)
-
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- SeqNum: iss.Add(1),
- AckNum: c.IRS.Add(1),
- RcvWnd: rcvWnd,
- Flags: header.TCPFlagRst,
- })
-
- if got := stats.TCP.ResetsReceived.Value(); got != want {
- t.Errorf("got stats.TCP.ResetsReceived.Value() = %d, want = %d", got, want)
- }
-}
-
-func TestTCPResetsDoNotGenerateResets(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- stats := c.Stack().Stats()
- want := stats.TCP.ResetsReceived.Value() + 1
- iss := seqnum.Value(context.TestInitialSequenceNumber)
- rcvWnd := seqnum.Size(30000)
- c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */)
-
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- SeqNum: iss.Add(1),
- AckNum: c.IRS.Add(1),
- RcvWnd: rcvWnd,
- Flags: header.TCPFlagRst,
- })
-
- if got := stats.TCP.ResetsReceived.Value(); got != want {
- t.Errorf("got stats.TCP.ResetsReceived.Value() = %d, want = %d", got, want)
- }
- c.CheckNoPacketTimeout("got an unexpected packet", 100*time.Millisecond)
-}
-
-func TestActiveHandshake(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-}
-
-func TestNonBlockingClose(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
- ep := c.EP
- c.EP = nil
-
- // Close the endpoint and measure how long it takes.
- t0 := time.Now()
- ep.Close()
- if diff := time.Now().Sub(t0); diff > 3*time.Second {
- t.Fatalf("Took too long to close: %s", diff)
- }
-}
-
-func TestConnectResetAfterClose(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Set TCPLinger to 3 seconds so that sockets are marked closed
- // after 3 second in FIN_WAIT2 state.
- tcpLingerTimeout := 3 * time.Second
- opt := tcpip.TCPLingerTimeoutOption(tcpLingerTimeout)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
- }
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
- ep := c.EP
- c.EP = nil
-
- // Close the endpoint, make sure we get a FIN segment, then acknowledge
- // to complete closure of sender, but don't send our own FIN.
- ep.Close()
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- ),
- )
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(2),
- RcvWnd: 30000,
- })
-
- // Wait for the ep to give up waiting for a FIN.
- time.Sleep(tcpLingerTimeout + 1*time.Second)
-
- // Now send an ACK and it should trigger a RST as the endpoint should
- // not exist anymore.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(2),
- RcvWnd: 30000,
- })
-
- for {
- b := c.GetPacket()
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- if tcpHdr.Flags() == header.TCPFlagAck|header.TCPFlagFin {
- // This is a retransmit of the FIN, ignore it.
- continue
- }
-
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- // RST is always generated with sndNxt which if the FIN
- // has been sent will be 1 higher than the sequence number
- // of the FIN itself.
- checker.TCPSeqNum(uint32(c.IRS)+2),
- checker.TCPAckNum(0),
- checker.TCPFlags(header.TCPFlagRst),
- ),
- )
- break
- }
-}
-
-// TestCurrentConnectedIncrement tests increment of the current
-// established and connected counters.
-func TestCurrentConnectedIncrement(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed
- // after 1 second in TIME_WAIT state.
- tcpTimeWaitTimeout := 1 * time.Second
- opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
- }
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
- ep := c.EP
- c.EP = nil
-
- if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 1 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 1", got)
- }
- gotConnected := c.Stack().Stats().TCP.CurrentConnected.Value()
- if gotConnected != 1 {
- t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 1", gotConnected)
- }
-
- ep.Close()
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- ),
- )
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(2),
- RcvWnd: 30000,
- })
-
- if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
- }
- if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != gotConnected {
- t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = %d", got, gotConnected)
- }
-
- // Ack and send FIN as well.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: iss,
- AckNum: c.IRS.Add(2),
- RcvWnd: 30000,
- })
-
- // Check that the stack acks the FIN.
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+2),
- checker.TCPAckNum(uint32(iss)+1),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-
- // Wait for a little more than the TIME-WAIT duration for the socket to
- // transition to CLOSED state.
- time.Sleep(1200 * time.Millisecond)
-
- if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
- }
- if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
- }
-}
-
-// TestClosingWithEnqueuedSegments tests handling of still enqueued segments
-// when the endpoint transitions to StateClose. The in-flight segments would be
-// re-enqueued to a any listening endpoint.
-func TestClosingWithEnqueuedSegments(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
- ep := c.EP
- c.EP = nil
-
- if got, want := tcp.EndpointState(ep.State()), tcp.StateEstablished; got != want {
- t.Errorf("unexpected endpoint state: want %d, got %d", want, got)
- }
-
- // Send a FIN for ESTABLISHED --> CLOSED-WAIT
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagFin | header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Get the ACK for the FIN we sent.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+1),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-
- // Give the stack a few ms to transition the endpoint out of ESTABLISHED
- // state.
- time.Sleep(10 * time.Millisecond)
-
- if got, want := tcp.EndpointState(ep.State()), tcp.StateCloseWait; got != want {
- t.Errorf("unexpected endpoint state: want %d, got %d", want, got)
- }
-
- // Close the application endpoint for CLOSE_WAIT --> LAST_ACK
- ep.Close()
-
- // Get the FIN
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+1),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- ),
- )
-
- if got, want := tcp.EndpointState(ep.State()), tcp.StateLastAck; got != want {
- t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
- }
-
- // Pause the endpoint`s protocolMainLoop.
- ep.(interface{ StopWork() }).StopWork()
-
- // Enqueue last ACK followed by an ACK matching the endpoint
- //
- // Send Last ACK for LAST_ACK --> CLOSED
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss.Add(1),
- AckNum: c.IRS.Add(2),
- RcvWnd: 30000,
- })
-
- // Send a packet with ACK set, this would generate RST when
- // not using SYN cookies as in this test.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: iss.Add(2),
- AckNum: c.IRS.Add(2),
- RcvWnd: 30000,
- })
-
- // Unpause endpoint`s protocolMainLoop.
- ep.(interface{ ResumeWork() }).ResumeWork()
-
- // Wait for the protocolMainLoop to resume and update state.
- time.Sleep(10 * time.Millisecond)
-
- // Expect the endpoint to be closed.
- if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want {
- t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
- }
-
- if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != 1 {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %d, want = 1", got)
- }
-
- if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
- }
-
- // Check if the endpoint was moved to CLOSED and netstack a reset in
- // response to the ACK packet that we sent after last-ACK.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+2),
- checker.TCPAckNum(0),
- checker.TCPFlags(header.TCPFlagRst),
- ),
- )
-}
-
-func TestSimpleReceive(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
-
- ept := endpointTester{c.EP}
-
- data := []byte{1, 2, 3}
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- c.SendPacket(data, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Wait for receive to be notified.
- select {
- case <-ch:
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for data to arrive")
- }
-
- // Receive data.
- v := ept.CheckRead(t)
- if !bytes.Equal(data, v) {
- t.Fatalf("got data = %v, want = %v", v, data)
- }
-
- // Check that ACK is received.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+uint32(len(data))),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-}
-
-// TestUserSuppliedMSSOnConnect tests that the user supplied MSS is used when
-// creating a new active TCP socket. It should be present in the sent TCP
-// SYN segment.
-func TestUserSuppliedMSSOnConnect(t *testing.T) {
- const mtu = 5000
-
- ips := []struct {
- name string
- createEP func(*context.Context)
- connectAddr tcpip.Address
- checker func(*testing.T, *context.Context, uint16, int)
- maxMSS uint16
- }{
- {
- name: "IPv4",
- createEP: func(c *context.Context) {
- c.Create(-1)
- },
- connectAddr: context.TestAddr,
- checker: func(t *testing.T, c *context.Context, mss uint16, ws int) {
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: ws})))
- },
- maxMSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize,
- },
- {
- name: "IPv6",
- createEP: func(c *context.Context) {
- c.CreateV6Endpoint(true)
- },
- connectAddr: context.TestV6Addr,
- checker: func(t *testing.T, c *context.Context, mss uint16, ws int) {
- checker.IPv6(t, c.GetV6Packet(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: ws})))
- },
- maxMSS: mtu - header.IPv6MinimumSize - header.TCPMinimumSize,
- },
- }
-
- for _, ip := range ips {
- t.Run(ip.name, func(t *testing.T) {
- tests := []struct {
- name string
- setMSS uint16
- expMSS uint16
- }{
- {
- name: "EqualToMaxMSS",
- setMSS: ip.maxMSS,
- expMSS: ip.maxMSS,
- },
- {
- name: "LessThanMaxMSS",
- setMSS: ip.maxMSS - 1,
- expMSS: ip.maxMSS - 1,
- },
- {
- name: "GreaterThanMaxMSS",
- setMSS: ip.maxMSS + 1,
- expMSS: ip.maxMSS,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- c := context.New(t, mtu)
- defer c.Cleanup()
-
- ip.createEP(c)
-
- // Set the MSS socket option.
- if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil {
- t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err)
- }
-
- // Get expected window size.
- rcvBufSize := c.EP.SocketOptions().GetReceiveBufferSize()
- ws := tcp.FindWndScale(seqnum.Size(rcvBufSize))
-
- connectAddr := tcpip.FullAddress{Addr: ip.connectAddr, Port: context.TestPort}
- {
- err := c.EP.Connect(connectAddr)
- if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
- t.Fatalf("Connect(%+v) mismatch (-want +got):\n%s", connectAddr, d)
- }
- }
-
- // Receive SYN packet with our user supplied MSS.
- ip.checker(t, c, test.expMSS, ws)
- })
- }
- })
- }
-}
-
-// TestUserSuppliedMSSOnListenAccept tests that the user supplied MSS is used
-// when completing the handshake for a new TCP connection from a TCP
-// listening socket. It should be present in the sent TCP SYN-ACK segment.
-func TestUserSuppliedMSSOnListenAccept(t *testing.T) {
- const mtu = 5000
-
- ips := []struct {
- name string
- createEP func(*context.Context)
- sendPkt func(*context.Context, *context.Headers)
- checker func(*testing.T, *context.Context, uint16, uint16)
- maxMSS uint16
- }{
- {
- name: "IPv4",
- createEP: func(c *context.Context) {
- c.Create(-1)
- },
- sendPkt: func(c *context.Context, h *context.Headers) {
- c.SendPacket(nil, h)
- },
- checker: func(t *testing.T, c *context.Context, srcPort, mss uint16) {
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.DstPort(srcPort),
- checker.TCPFlags(header.TCPFlagSyn|header.TCPFlagAck),
- checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: -1})))
- },
- maxMSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize,
- },
- {
- name: "IPv6",
- createEP: func(c *context.Context) {
- c.CreateV6Endpoint(false)
- },
- sendPkt: func(c *context.Context, h *context.Headers) {
- c.SendV6Packet(nil, h)
- },
- checker: func(t *testing.T, c *context.Context, srcPort, mss uint16) {
- checker.IPv6(t, c.GetV6Packet(), checker.TCP(
- checker.DstPort(srcPort),
- checker.TCPFlags(header.TCPFlagSyn|header.TCPFlagAck),
- checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: -1})))
- },
- maxMSS: mtu - header.IPv6MinimumSize - header.TCPMinimumSize,
- },
- }
-
- for _, ip := range ips {
- t.Run(ip.name, func(t *testing.T) {
- tests := []struct {
- name string
- setMSS uint16
- expMSS uint16
- }{
- {
- name: "EqualToMaxMSS",
- setMSS: ip.maxMSS,
- expMSS: ip.maxMSS,
- },
- {
- name: "LessThanMaxMSS",
- setMSS: ip.maxMSS - 1,
- expMSS: ip.maxMSS - 1,
- },
- {
- name: "GreaterThanMaxMSS",
- setMSS: ip.maxMSS + 1,
- expMSS: ip.maxMSS,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- c := context.New(t, mtu)
- defer c.Cleanup()
-
- ip.createEP(c)
-
- if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil {
- t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err)
- }
-
- bindAddr := tcpip.FullAddress{Port: context.StackPort}
- if err := c.EP.Bind(bindAddr); err != nil {
- t.Fatalf("Bind(%+v): %s:", bindAddr, err)
- }
-
- backlog := 5
- // Keep the number of client requests twice to the backlog
- // such that half of the connections do not use syncookies
- // and the other half does.
- clientConnects := backlog * 2
-
- if err := c.EP.Listen(backlog); err != nil {
- t.Fatalf("Listen(%d): %s:", backlog, err)
- }
-
- for i := 0; i < clientConnects; i++ {
- // Send a SYN requests.
- iss := seqnum.Value(i)
- srcPort := context.TestPort + uint16(i)
- ip.sendPkt(c, &context.Headers{
- SrcPort: srcPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: iss,
- })
-
- // Receive the SYN-ACK reply.
- ip.checker(t, c, srcPort, test.expMSS)
- }
- })
- }
- })
- }
-}
-func TestSendRstOnListenerRxSynAckV4(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.Create(-1)
-
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatal("Bind failed:", err)
- }
-
- if err := c.EP.Listen(10); err != nil {
- t.Fatal("Listen failed:", err)
- }
-
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn | header.TCPFlagAck,
- SeqNum: 100,
- AckNum: 200,
- })
-
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagRst),
- checker.TCPSeqNum(200)))
-}
-
-func TestSendRstOnListenerRxSynAckV6(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(true)
-
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatal("Bind failed:", err)
- }
-
- if err := c.EP.Listen(10); err != nil {
- t.Fatal("Listen failed:", err)
- }
-
- c.SendV6Packet(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn | header.TCPFlagAck,
- SeqNum: 100,
- AckNum: 200,
- })
-
- checker.IPv6(t, c.GetV6Packet(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagRst),
- checker.TCPSeqNum(200)))
-}
-
-// TestTCPAckBeforeAcceptV4 tests that once the 3-way handshake is complete,
-// peers can send data and expect a response within a reasonable ammount of time
-// without calling Accept on the listening endpoint first.
-//
-// This test uses IPv4.
-func TestTCPAckBeforeAcceptV4(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.Create(-1)
-
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatal("Bind failed:", err)
- }
-
- if err := c.EP.Listen(10); err != nil {
- t.Fatal("Listen failed:", err)
- }
-
- irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
-
- // Send data before accepting the connection.
- c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs + 1,
- AckNum: iss + 1,
- })
-
- // Receive ACK for the data we sent.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.TCPSeqNum(uint32(iss+1)),
- checker.TCPAckNum(uint32(irs+5))))
-}
-
-// TestTCPAckBeforeAcceptV6 tests that once the 3-way handshake is complete,
-// peers can send data and expect a response within a reasonable ammount of time
-// without calling Accept on the listening endpoint first.
-//
-// This test uses IPv6.
-func TestTCPAckBeforeAcceptV6(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(true)
-
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatal("Bind failed:", err)
- }
-
- if err := c.EP.Listen(10); err != nil {
- t.Fatal("Listen failed:", err)
- }
-
- irs, iss := executeV6Handshake(t, c, context.TestPort, false /* synCookiesInUse */)
-
- // Send data before accepting the connection.
- c.SendV6Packet([]byte{1, 2, 3, 4}, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs + 1,
- AckNum: iss + 1,
- })
-
- // Receive ACK for the data we sent.
- checker.IPv6(t, c.GetV6Packet(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.TCPSeqNum(uint32(iss+1)),
- checker.TCPAckNum(uint32(irs+5))))
-}
-
-func TestSendRstOnListenerRxAckV4(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.Create(-1 /* epRcvBuf */)
-
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatal("Bind failed:", err)
- }
-
- if err := c.EP.Listen(10 /* backlog */); err != nil {
- t.Fatal("Listen failed:", err)
- }
-
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagFin | header.TCPFlagAck,
- SeqNum: 100,
- AckNum: 200,
- })
-
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagRst),
- checker.TCPSeqNum(200)))
-}
-
-func TestSendRstOnListenerRxAckV6(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(true /* v6Only */)
-
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatal("Bind failed:", err)
- }
-
- if err := c.EP.Listen(10 /* backlog */); err != nil {
- t.Fatal("Listen failed:", err)
- }
-
- c.SendV6Packet(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagFin | header.TCPFlagAck,
- SeqNum: 100,
- AckNum: 200,
- })
-
- checker.IPv6(t, c.GetV6Packet(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagRst),
- checker.TCPSeqNum(200)))
-}
-
-// TestListenShutdown tests for the listening endpoint replying with RST
-// on read shutdown.
-func TestListenShutdown(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.Create(-1 /* epRcvBuf */)
-
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatal("Bind failed:", err)
- }
-
- if err := c.EP.Listen(1 /* backlog */); err != nil {
- t.Fatal("Listen failed:", err)
- }
-
- if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil {
- t.Fatal("Shutdown failed:", err)
- }
-
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: 100,
- AckNum: 200,
- })
-
- // Expect the listening endpoint to reset the connection.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
- ))
-}
-
-var _ waiter.EntryCallback = (callback)(nil)
-
-type callback func(*waiter.Entry, waiter.EventMask)
-
-func (cb callback) Callback(entry *waiter.Entry, mask waiter.EventMask) {
- cb(entry, mask)
-}
-
-func TestListenerReadinessOnEvent(t *testing.T) {
- s := stack.New(stack.Options{
- TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- })
- {
- ep := loopback.New()
- if testing.Verbose() {
- ep = sniffer.New(ep)
- }
- const id = 1
- if err := s.CreateNIC(id, ep); err != nil {
- t.Fatalf("CreateNIC(%d, %T): %s", id, ep, err)
- }
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: tcpip.Address(context.StackAddr).WithPrefix(),
- }
- if err := s.AddProtocolAddress(id, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", id, protocolAddr, err)
- }
- s.SetRouteTable([]tcpip.Route{
- {Destination: header.IPv4EmptySubnet, NIC: id},
- })
- }
-
- var wq waiter.Queue
- ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, _): %s", err)
- }
- defer ep.Close()
-
- if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr}); err != nil {
- t.Fatalf("Bind(%s): %s", context.StackAddr, err)
- }
- const backlog = 1
- if err := ep.Listen(backlog); err != nil {
- t.Fatalf("Listen(%d): %s", backlog, err)
- }
-
- address, err := ep.GetLocalAddress()
- if err != nil {
- t.Fatalf("GetLocalAddress(): %s", err)
- }
-
- conn, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, _): %s", err)
- }
- defer conn.Close()
-
- events := make(chan waiter.EventMask)
- // Scope `entry` to allow a binding of the same name below.
- {
- entry := waiter.Entry{Callback: callback(func(_ *waiter.Entry, mask waiter.EventMask) {
- events <- ep.Readiness(mask)
- })}
- wq.EventRegister(&entry, waiter.EventIn)
- defer wq.EventUnregister(&entry)
- }
-
- entry, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&entry, waiter.EventOut)
- defer wq.EventUnregister(&entry)
-
- switch err := conn.Connect(address).(type) {
- case *tcpip.ErrConnectStarted:
- default:
- t.Fatalf("Connect(%#v): %v", address, err)
- }
-
- // Read at least one event.
- got := <-events
- for {
- select {
- case event := <-events:
- got |= event
- continue
- case <-ch:
- if want := waiter.ReadableEvents; got != want {
- t.Errorf("observed events = %b, want %b", got, want)
- }
- }
- break
- }
-}
-
-// TestListenCloseWhileConnect tests for the listening endpoint to
-// drain the accept-queue when closed. This should reset all of the
-// pending connections that are waiting to be accepted.
-func TestListenCloseWhileConnect(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.Create(-1 /* epRcvBuf */)
-
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatal("Bind failed:", err)
- }
-
- if err := c.EP.Listen(1 /* backlog */); err != nil {
- t.Fatal("Listen failed:", err)
- }
-
- waitEntry, notifyCh := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&waitEntry, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&waitEntry)
-
- executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
- // Wait for the new endpoint created because of handshake to be delivered
- // to the listening endpoint's accept queue.
- <-notifyCh
-
- // Close the listening endpoint.
- c.EP.Close()
-
- // Expect the listening endpoint to reset the connection.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
- ))
-}
-
-func TestTOSV4(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- c.EP = ep
-
- const tos = 0xC0
- if err := c.EP.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil {
- t.Errorf("SetSockOptInt(IPv4TOSOption, %d) failed: %s", tos, err)
- }
-
- v, err := c.EP.GetSockOptInt(tcpip.IPv4TOSOption)
- if err != nil {
- t.Errorf("GetSockoptInt(IPv4TOSOption) failed: %s", err)
- }
-
- if v != tos {
- t.Errorf("got GetSockOptInt(IPv4TOSOption) = %d, want = %d", v, tos)
- }
-
- testV4Connect(t, c, checker.TOS(tos, 0))
-
- data := []byte{1, 2, 3}
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Check that data is received.
- b := c.GetPacket()
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- checker.IPv4(t, b,
- checker.PayloadLen(len(data)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)), // Acknum is initial sequence number + 1
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- checker.TOS(tos, 0),
- )
-
- if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
- t.Errorf("got data = %x, want = %x", p, data)
- }
-}
-
-func TestTrafficClassV6(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- const tos = 0xC0
- if err := c.EP.SetSockOptInt(tcpip.IPv6TrafficClassOption, tos); err != nil {
- t.Errorf("SetSockOpInt(IPv6TrafficClassOption, %d) failed: %s", tos, err)
- }
-
- v, err := c.EP.GetSockOptInt(tcpip.IPv6TrafficClassOption)
- if err != nil {
- t.Fatalf("GetSockoptInt(IPv6TrafficClassOption) failed: %s", err)
- }
-
- if v != tos {
- t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = %d, want = %d", v, tos)
- }
-
- // Test the connection request.
- testV6Connect(t, c, checker.TOS(tos, 0))
-
- data := []byte{1, 2, 3}
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Check that data is received.
- b := c.GetV6Packet()
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- checker.IPv6(t, b,
- checker.PayloadLen(len(data)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- checker.TOS(tos, 0),
- )
-
- if p := b[header.IPv6MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
- t.Errorf("got data = %x, want = %x", p, data)
- }
-}
-
-func TestConnectBindToDevice(t *testing.T) {
- for _, test := range []struct {
- name string
- device tcpip.NICID
- want tcp.EndpointState
- }{
- {"RightDevice", 1, tcp.StateEstablished},
- {"WrongDevice", 2, tcp.StateSynSent},
- {"AnyDevice", 0, tcp.StateEstablished},
- } {
- t.Run(test.name, func(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.Create(-1)
- if err := c.EP.SocketOptions().SetBindToDevice(int32(test.device)); err != nil {
- t.Fatalf("c.EP.SetSockOpt(&%T(%d)): %s", test.device, test.device, err)
- }
- // Start connection attempt.
- waitEntry, _ := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&waitEntry, waiter.WritableEvents)
- defer c.WQ.EventUnregister(&waitEntry)
-
- err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
- if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
- t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d)
- }
-
- // Receive SYN packet.
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- ),
- )
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
- t.Fatalf("unexpected endpoint state: want %s, got %s", want, got)
- }
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
-
- iss := seqnum.Value(context.TestInitialSequenceNumber)
- rcvWnd := seqnum.Size(30000)
- c.SendPacket(nil, &context.Headers{
- SrcPort: tcpHdr.DestinationPort(),
- DstPort: tcpHdr.SourcePort(),
- Flags: header.TCPFlagSyn | header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: rcvWnd,
- TCPOpts: nil,
- })
-
- c.GetPacket()
- if got, want := tcp.EndpointState(c.EP.State()), test.want; got != want {
- t.Fatalf("unexpected endpoint state: want %s, got %s", want, got)
- }
- })
- }
-}
-
-func TestShutdownConnectingSocket(t *testing.T) {
- for _, test := range []struct {
- name string
- shutdownMode tcpip.ShutdownFlags
- }{
- {"ShutdownRead", tcpip.ShutdownRead},
- {"ShutdownWrite", tcpip.ShutdownWrite},
- {"ShutdownReadWrite", tcpip.ShutdownRead | tcpip.ShutdownWrite},
- } {
- t.Run(test.name, func(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create an endpoint, don't handshake because we want to interfere with
- // the handshake process.
- c.Create(-1)
-
- waitEntry, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&waitEntry, waiter.EventHUp)
- defer c.WQ.EventUnregister(&waitEntry)
-
- // Start connection attempt.
- addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}
- if d := cmp.Diff(&tcpip.ErrConnectStarted{}, c.EP.Connect(addr)); d != "" {
- t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d)
- }
-
- // Check the SYN packet.
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- ),
- )
-
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
- t.Fatalf("got State() = %s, want %s", got, want)
- }
-
- if err := c.EP.Shutdown(test.shutdownMode); err != nil {
- t.Fatalf("Shutdown failed: %s", err)
- }
-
- // The endpoint internal state is updated immediately.
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
- t.Fatalf("got State() = %s, want %s", got, want)
- }
-
- select {
- case <-ch:
- default:
- t.Fatal("endpoint was not notified")
- }
-
- ept := endpointTester{c.EP}
- ept.CheckReadError(t, &tcpip.ErrConnectionReset{})
-
- // If the endpoint is not properly shutdown, it'll re-attempt to connect
- // by sending another ACK packet.
- c.CheckNoPacketTimeout("got an unexpected packet", tcp.InitialRTO+(500*time.Millisecond))
- })
- }
-}
-
-func TestSynSent(t *testing.T) {
- for _, test := range []struct {
- name string
- reset bool
- }{
- {"RstOnSynSent", true},
- {"CloseOnSynSent", false},
- } {
- t.Run(test.name, func(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create an endpoint, don't handshake because we want to interfere with the
- // handshake process.
- c.Create(-1)
-
- // Start connection attempt.
- waitEntry, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&waitEntry, waiter.EventHUp)
- defer c.WQ.EventUnregister(&waitEntry)
-
- addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}
- err := c.EP.Connect(addr)
- if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
- t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d)
- }
-
- // Receive SYN packet.
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- ),
- )
-
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
- t.Fatalf("got State() = %s, want %s", got, want)
- }
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
-
- if test.reset {
- // Send a packet with a proper ACK and a RST flag to cause the socket
- // to error and close out.
- iss := seqnum.Value(context.TestInitialSequenceNumber)
- rcvWnd := seqnum.Size(30000)
- c.SendPacket(nil, &context.Headers{
- SrcPort: tcpHdr.DestinationPort(),
- DstPort: tcpHdr.SourcePort(),
- Flags: header.TCPFlagRst | header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: rcvWnd,
- TCPOpts: nil,
- })
- } else {
- c.EP.Close()
- }
-
- // Wait for receive to be notified.
- select {
- case <-ch:
- case <-time.After(3 * time.Second):
- t.Fatal("timed out waiting for packet to arrive")
- }
-
- ept := endpointTester{c.EP}
- if test.reset {
- ept.CheckReadError(t, &tcpip.ErrConnectionRefused{})
- } else {
- ept.CheckReadError(t, &tcpip.ErrAborted{})
- }
-
- if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
- }
-
- // Due to the RST the endpoint should be in an error state.
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
- t.Fatalf("got State() = %s, want %s", got, want)
- }
- })
- }
-}
-
-func TestOutOfOrderReceive(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
-
- ept := endpointTester{c.EP}
- ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
-
- // Send second half of data first, with seqnum 3 ahead of expected.
- data := []byte{1, 2, 3, 4, 5, 6}
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- c.SendPacket(data[3:], &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss.Add(3),
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Check that we get an ACK specifying which seqnum is expected.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-
- // Wait 200ms and check that no data has been received.
- time.Sleep(200 * time.Millisecond)
- ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
-
- // Send the first 3 bytes now.
- c.SendPacket(data[:3], &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Receive data.
- read := ept.CheckReadFull(t, 6, ch, 5*time.Second)
-
- // Check that we received the data in proper order.
- if !bytes.Equal(data, read) {
- t.Fatalf("got data = %v, want = %v", read, data)
- }
-
- // Check that the whole data is acknowledged.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+uint32(len(data))),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-}
-
-func TestOutOfOrderFlood(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- rcvBufSz := math.MaxUint16
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBufSz)
-
- ept := endpointTester{c.EP}
- ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
-
- // Send 100 packets before the actual one that is expected.
- data := []byte{1, 2, 3, 4, 5, 6}
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- for i := 0; i < 100; i++ {
- c.SendPacket(data[3:], &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss.Add(6),
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
- }
-
- // Send packet with seqnum as initial + 3. It must be discarded because the
- // out-of-order buffer was filled by the previous packets.
- c.SendPacket(data[3:], &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss.Add(3),
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-
- // Now send the expected packet with initial sequence number.
- c.SendPacket(data[:3], &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Check that only packet with initial sequence number is acknowledged.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+3),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-}
-
-func TestRstOnCloseWithUnreadData(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
-
- ept := endpointTester{c.EP}
- ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
-
- data := []byte{1, 2, 3}
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- c.SendPacket(data, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Wait for receive to be notified.
- select {
- case <-ch:
- case <-time.After(3 * time.Second):
- t.Fatalf("Timed out waiting for data to arrive")
- }
-
- // Check that ACK is received, this happens regardless of the read.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+uint32(len(data))),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-
- // Now that we know we have unread data, let's just close the connection
- // and verify that netstack sends an RST rather than a FIN.
- c.EP.Close()
-
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
- // We shouldn't consume a sequence number on RST.
- checker.TCPSeqNum(uint32(c.IRS)+1),
- ))
- // The RST puts the endpoint into an error state.
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
- t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
- }
-
- // This final ACK should be ignored because an ACK on a reset doesn't mean
- // anything.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss.Add(seqnum.Size(len(data))),
- AckNum: c.IRS.Add(seqnum.Size(2)),
- RcvWnd: 30000,
- })
-}
-
-func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
-
- ept := endpointTester{c.EP}
- ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
-
- data := []byte{1, 2, 3}
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- c.SendPacket(data, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Wait for receive to be notified.
- select {
- case <-ch:
- case <-time.After(3 * time.Second):
- t.Fatalf("Timed out waiting for data to arrive")
- }
-
- // Check that ACK is received, this happens regardless of the read.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+uint32(len(data))),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-
- // Cause a FIN to be generated.
- if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %s", err)
- }
-
- // Make sure we get the FIN but DON't ACK IT.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- ))
-
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want {
- t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
- }
-
- // Cause a RST to be generated by closing the read end now since we have
- // unread data.
- if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil {
- t.Fatalf("Shutdown failed: %s", err)
- }
-
- // Make sure we get the RST
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
- // RST is always generated with sndNxt which if the FIN
- // has been sent will be 1 higher than the sequence
- // number of the FIN itself.
- checker.TCPSeqNum(uint32(c.IRS)+2),
- ))
- // The RST puts the endpoint into an error state.
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
- t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
- }
-
- // The ACK to the FIN should now be rejected since the connection has been
- // closed by a RST.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss.Add(seqnum.Size(len(data))),
- AckNum: c.IRS.Add(seqnum.Size(2)),
- RcvWnd: 30000,
- })
-}
-
-func TestShutdownRead(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- ept := endpointTester{c.EP}
- ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
-
- if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil {
- t.Fatalf("Shutdown failed: %s", err)
- }
-
- ept.CheckReadError(t, &tcpip.ErrClosedForReceive{})
- var want uint64 = 1
- if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want {
- t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %d want %d", got, want)
- }
-}
-
-func TestFullWindowReceive(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- const rcvBufSz = 10
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBufSz)
-
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
-
- ept := endpointTester{c.EP}
- ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
-
- // Fill up the window w/ tcp.SegOverheadFactor*rcvBufSz as netstack multiplies
- // the provided buffer value by tcp.SegOverheadFactor to calculate the actual
- // receive buffer size.
- data := make([]byte, tcp.SegOverheadFactor*rcvBufSz)
- for i := range data {
- data[i] = byte(i % 255)
- }
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- c.SendPacket(data, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Wait for receive to be notified.
- select {
- case <-ch:
- case <-time.After(5 * time.Second):
- t.Fatalf("Timed out waiting for data to arrive")
- }
-
- // Check that data is acknowledged, and window goes to zero.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+uint32(len(data))),
- checker.TCPFlags(header.TCPFlagAck),
- checker.TCPWindow(0),
- ),
- )
-
- // Receive data and check it.
- v := ept.CheckRead(t)
- if !bytes.Equal(data, v) {
- t.Fatalf("got data = %v, want = %v", v, data)
- }
-
- var want uint64 = 1
- if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ZeroRcvWindowState.Value(); got != want {
- t.Fatalf("got EP stats ReceiveErrors.ZeroRcvWindowState got %d want %d", got, want)
- }
-
- // Check that we get an ACK for the newly non-zero window.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+uint32(len(data))),
- checker.TCPFlags(header.TCPFlagAck),
- checker.TCPWindow(10),
- ),
- )
-}
-
-func TestSmallReceiveBufferReadiness(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
- })
-
- ep := loopback.New()
- if testing.Verbose() {
- ep = sniffer.New(ep)
- }
-
- const nicID = 1
- nicOpts := stack.NICOptions{Name: "nic1"}
- if err := s.CreateNICWithOptions(nicID, ep, nicOpts); err != nil {
- t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %s", nicOpts, err)
- }
-
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: tcpip.Address("\x7f\x00\x00\x01"),
- PrefixLen: 8,
- },
- }
- if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}) failed: %s", nicID, protocolAddr, err)
- }
-
- {
- subnet, err := tcpip.NewSubnet("\x7f\x00\x00\x00", "\xff\x00\x00\x00")
- if err != nil {
- t.Fatalf("tcpip.NewSubnet failed: %s", err)
- }
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: subnet,
- NIC: nicID,
- },
- })
- }
-
- listenerEntry, listenerCh := waiter.NewChannelEntry(nil)
- var listenerWQ waiter.Queue
- listener, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- defer listener.Close()
- listenerWQ.EventRegister(&listenerEntry, waiter.ReadableEvents)
- defer listenerWQ.EventUnregister(&listenerEntry)
-
- if err := listener.Bind(tcpip.FullAddress{}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
- if err := listener.Listen(1); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- localAddress, err := listener.GetLocalAddress()
- if err != nil {
- t.Fatalf("GetLocalAddress failed: %s", err)
- }
-
- for i := 8; i > 0; i /= 2 {
- size := int64(i << 10)
- t.Run(fmt.Sprintf("size=%d", size), func(t *testing.T) {
- var clientWQ waiter.Queue
- client, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- defer client.Close()
- switch err := client.Connect(localAddress).(type) {
- case nil:
- t.Fatal("Connect returned nil error")
- case *tcpip.ErrConnectStarted:
- default:
- t.Fatalf("Connect failed: %s", err)
- }
-
- <-listenerCh
- server, serverWQ, err := listener.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
- defer server.Close()
-
- client.SocketOptions().SetReceiveBufferSize(size, true)
- // Send buffer size doesn't seem to affect this test.
- // server.SocketOptions().SetSendBufferSize(size, true)
-
- clientEntry, clientCh := waiter.NewChannelEntry(nil)
- clientWQ.EventRegister(&clientEntry, waiter.ReadableEvents)
- defer clientWQ.EventUnregister(&clientEntry)
-
- serverEntry, serverCh := waiter.NewChannelEntry(nil)
- serverWQ.EventRegister(&serverEntry, waiter.WritableEvents)
- defer serverWQ.EventUnregister(&serverEntry)
-
- var total int64
- for {
- var b [64 << 10]byte
- var r bytes.Reader
- r.Reset(b[:])
- switch n, err := server.Write(&r, tcpip.WriteOptions{}); err.(type) {
- case nil:
- t.Logf("wrote %d bytes", n)
- total += n
- continue
- case *tcpip.ErrWouldBlock:
- select {
- case <-serverCh:
- continue
- case <-time.After(100 * time.Millisecond):
- // Well and truly full.
- t.Logf("send and receive queues are full")
- }
- default:
- t.Fatalf("Write failed: %s", err)
- }
- break
- }
- t.Logf("wrote %d bytes in total", total)
-
- var wg sync.WaitGroup
- defer wg.Wait()
-
- wg.Add(2)
- go func() {
- defer wg.Done()
-
- var b [64 << 10]byte
- var r bytes.Reader
- r.Reset(b[:])
- if err := func() error {
- var total int64
- defer t.Logf("wrote %d bytes in total", total)
- for r.Len() != 0 {
- switch n, err := server.Write(&r, tcpip.WriteOptions{}); err.(type) {
- case nil:
- t.Logf("wrote %d bytes", n)
- total += n
- case *tcpip.ErrWouldBlock:
- for {
- t.Logf("waiting on server")
- select {
- case <-serverCh:
- case <-time.After(time.Second):
- if readiness := server.Readiness(waiter.WritableEvents); readiness != 0 {
- t.Logf("server.Readiness(%b) = %b but channel not signaled", waiter.WritableEvents, readiness)
- }
- continue
- }
- break
- }
- default:
- return fmt.Errorf("server.Write failed: %s", err)
- }
- }
- if err := server.Shutdown(tcpip.ShutdownWrite); err != nil {
- return fmt.Errorf("server.Shutdown failed: %s", err)
- }
- t.Logf("server end shutdown done")
- return nil
- }(); err != nil {
- t.Error(err)
- }
- }()
-
- go func() {
- defer wg.Done()
-
- if err := func() error {
- total := 0
- defer t.Logf("read %d bytes in total", total)
- for {
- switch res, err := client.Read(ioutil.Discard, tcpip.ReadOptions{}); err.(type) {
- case nil:
- t.Logf("read %d bytes", res.Count)
- total += res.Count
- t.Logf("read total %d bytes till now", total)
- case *tcpip.ErrClosedForReceive:
- return nil
- case *tcpip.ErrWouldBlock:
- for {
- t.Logf("waiting on client")
- select {
- case <-clientCh:
- case <-time.After(time.Second):
- if readiness := client.Readiness(waiter.ReadableEvents); readiness != 0 {
- return fmt.Errorf("client.Readiness(%b) = %b but channel not signaled", waiter.ReadableEvents, readiness)
- }
- continue
- }
- break
- }
- default:
- return fmt.Errorf("client.Write failed: %s", err)
- }
- }
- }(); err != nil {
- t.Error(err)
- }
- }()
- })
- }
-}
-
-// Test the stack receive window advertisement on receiving segments smaller than
-// segment overhead. It tests for the right edge of the window to not grow when
-// the endpoint is not being read from.
-func TestSmallSegReceiveWindowAdvertisement(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- opt := tcpip.TCPReceiveBufferSizeRangeOption{
- Min: 1,
- Default: tcp.DefaultReceiveBufferSize,
- Max: tcp.DefaultReceiveBufferSize << tcp.FindWndScale(seqnum.Size(tcp.DefaultReceiveBufferSize)),
- }
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
- }
-
- c.AcceptWithOptionsNoDelay(tcp.FindWndScale(seqnum.Size(opt.Default)), header.TCPSynOptions{MSS: defaultIPv4MSS})
-
- // Bump up the receive buffer size such that, when the receive window grows,
- // the scaled window exceeds maxUint16.
- c.EP.SocketOptions().SetReceiveBufferSize(int64(opt.Max)*2, true /* notify */)
-
- // Keep the payload size < segment overhead and such that it is a multiple
- // of the window scaled value. This enables the test to perform equality
- // checks on the incoming receive window.
- payloadSize := 1 << c.RcvdWindowScale
- if payloadSize >= tcp.SegSize {
- t.Fatalf("payload size of %d is not less than the segment overhead of %d", payloadSize, tcp.SegSize)
- }
- payload := generateRandomPayload(t, payloadSize)
- payloadLen := seqnum.Size(len(payload))
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
-
- // Send payload to the endpoint and return the advertised receive window
- // from the endpoint.
- getIncomingRcvWnd := func() uint32 {
- c.SendPacket(payload, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- Flags: header.TCPFlagAck,
- RcvWnd: 30000,
- })
- iss = iss.Add(payloadLen)
-
- pkt := c.GetPacket()
- return uint32(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.RcvdWindowScale
- }
-
- // Read the advertised receive window with the ACK for payload.
- rcvWnd := getIncomingRcvWnd()
-
- // Check if the subsequent ACK to our send has not grown the right edge of
- // the window.
- if got, want := getIncomingRcvWnd(), rcvWnd-uint32(len(payload)); got != want {
- t.Fatalf("got incomingRcvwnd %d want %d", got, want)
- }
-
- // Read the data so that the subsequent ACK from the endpoint
- // grows the right edge of the window.
- var buf bytes.Buffer
- if _, err := c.EP.Read(&buf, tcpip.ReadOptions{}); err != nil {
- t.Fatalf("c.EP.Read: %s", err)
- }
-
- // Check if we have received max uint16 as our advertised
- // scaled window now after a read above.
- maxRcv := uint32(math.MaxUint16 << c.RcvdWindowScale)
- if got, want := getIncomingRcvWnd(), maxRcv; got != want {
- t.Fatalf("got incomingRcvwnd %d want %d", got, want)
- }
-
- // Check if the subsequent ACK to our send has not grown the right edge of
- // the window.
- if got, want := getIncomingRcvWnd(), maxRcv-uint32(len(payload)); got != want {
- t.Fatalf("got incomingRcvwnd %d want %d", got, want)
- }
-}
-
-func TestNoWindowShrinking(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Start off with a certain receive buffer then cut it in half and verify that
- // the right edge of the window does not shrink.
- // NOTE: Netstack doubles the value specified here.
- rcvBufSize := 65536
- // Enable window scaling with a scale of zero from our end.
- c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, rcvBufSize, []byte{
- header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
- })
-
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
-
- ept := endpointTester{c.EP}
- ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
-
- // Send a 1 byte payload so that we can record the current receive window.
- // Send a payload of half the size of rcvBufSize.
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- payload := []byte{1}
- c.SendPacket(payload, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Wait for receive to be notified.
- select {
- case <-ch:
- case <-time.After(5 * time.Second):
- t.Fatalf("Timed out waiting for data to arrive")
- }
-
- // Read the 1 byte payload we just sent.
- if got, want := payload, ept.CheckRead(t); !bytes.Equal(got, want) {
- t.Fatalf("got data: %v, want: %v", got, want)
- }
-
- // Verify that the ACK does not shrink the window.
- pkt := c.GetPacket()
- iss = iss.Add(1)
- checker.IPv4(t, pkt,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
- // Stash the initial window.
- initialWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale
- initialLastAcceptableSeq := iss.Add(seqnum.Size(initialWnd))
- // Now shrink the receive buffer to half its original size.
- c.EP.SocketOptions().SetReceiveBufferSize(int64(rcvBufSize), true /* notify */)
-
- data := generateRandomPayload(t, rcvBufSize)
- // Send a payload of half the size of rcvBufSize.
- c.SendPacket(data[:rcvBufSize/2], &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
- iss = iss.Add(seqnum.Size(rcvBufSize / 2))
-
- // Verify that the ACK does not shrink the window.
- pkt = c.GetPacket()
- checker.IPv4(t, pkt,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
- newWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale
- newLastAcceptableSeq := iss.Add(seqnum.Size(newWnd))
- if newLastAcceptableSeq.LessThan(initialLastAcceptableSeq) {
- t.Fatalf("receive window shrunk unexpectedly got: %d, want >= %d", newLastAcceptableSeq, initialLastAcceptableSeq)
- }
-
- // Send another payload of half the size of rcvBufSize. This should fill up the
- // socket receive buffer and we should see a zero window.
- c.SendPacket(data[rcvBufSize/2:], &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
- iss = iss.Add(seqnum.Size(rcvBufSize / 2))
-
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck),
- checker.TCPWindow(0),
- ),
- )
-
- // Receive data and check it.
- read := ept.CheckReadFull(t, len(data), ch, 5*time.Second)
- if !bytes.Equal(data, read) {
- t.Fatalf("got data = %v, want = %v", read, data)
- }
-
- // Check that we get an ACK for the newly non-zero window, which is the new
- // receive buffer size we set after the connection was established.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck),
- checker.TCPWindow(uint16(rcvBufSize/2)>>c.RcvdWindowScale),
- ),
- )
-}
-
-func TestSimpleSend(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- data := []byte{1, 2, 3}
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Check that data is received.
- b := c.GetPacket()
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- checker.IPv4(t, b,
- checker.PayloadLen(len(data)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-
- if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
- t.Fatalf("got data = %v, want = %v", p, data)
- }
-
- // Acknowledge the data.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1 + seqnum.Size(len(data))),
- RcvWnd: 30000,
- })
-}
-
-func TestZeroWindowSend(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 0 /* rcvWnd */, -1 /* epRcvBuf */)
-
- data := []byte{1, 2, 3}
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Check if we got a zero-window probe.
- b := c.GetPacket()
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- checker.IPv4(t, b,
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-
- // Open up the window. Data should be received now.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Check that data is received.
- b = c.GetPacket()
- checker.IPv4(t, b,
- checker.PayloadLen(len(data)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-
- if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
- t.Fatalf("got data = %v, want = %v", p, data)
- }
-
- // Acknowledge the data.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1 + seqnum.Size(len(data))),
- RcvWnd: 30000,
- })
-}
-
-func TestScaledWindowConnect(t *testing.T) {
- // This test ensures that window scaling is used when the peer
- // does advertise it and connection is established with Connect().
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Set the window size greater than the maximum non-scaled window.
- c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, 65535*3, []byte{
- header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
- })
-
- data := []byte{1, 2, 3}
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Check that data is received, and that advertised window is 0x5fff,
- // that is, that it is scaled.
- b := c.GetPacket()
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- checker.IPv4(t, b,
- checker.PayloadLen(len(data)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPWindow(0x5fff),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-}
-
-func TestNonScaledWindowConnect(t *testing.T) {
- // This test ensures that window scaling is not used when the peer
- // doesn't advertise it and connection is established with Connect().
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Set the window size greater than the maximum non-scaled window.
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, 65535*3)
-
- data := []byte{1, 2, 3}
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Check that data is received, and that advertised window is 0xffff,
- // that is, that it's not scaled.
- b := c.GetPacket()
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- checker.IPv4(t, b,
- checker.PayloadLen(len(data)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPWindow(0xffff),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-}
-
-func TestScaledWindowAccept(t *testing.T) {
- // This test ensures that window scaling is used when the peer
- // does advertise it and connection is established with Accept().
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create EP and start listening.
- wq := &waiter.Queue{}
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- defer ep.Close()
-
- // Set the window size greater than the maximum non-scaled window.
- ep.SocketOptions().SetReceiveBufferSize(65535*6, true /* notify */)
-
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Do 3-way handshake.
- // wndScale expected is 3 as 65535 * 3 * 2 < 65535 * 2^3 but > 65535 *2 *2
- c.PassiveConnectWithOptions(100, 3 /* wndScale */, header.TCPSynOptions{MSS: defaultIPv4MSS}, 0 /* delay */)
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
-
- c.EP, _, err = ep.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- c.EP, _, err = ep.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- data := []byte{1, 2, 3}
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Check that data is received, and that advertised window is 0x5fff,
- // that is, that it is scaled.
- b := c.GetPacket()
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- checker.IPv4(t, b,
- checker.PayloadLen(len(data)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPWindow(0x5fff),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-}
-
-func TestNonScaledWindowAccept(t *testing.T) {
- // This test ensures that window scaling is not used when the peer
- // doesn't advertise it and connection is established with Accept().
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create EP and start listening.
- wq := &waiter.Queue{}
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- defer ep.Close()
-
- // Set the window size greater than the maximum non-scaled window.
- ep.SocketOptions().SetReceiveBufferSize(65535*6, true /* notify */)
-
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Do 3-way handshake w/ window scaling disabled. The SYN-ACK to the SYN
- // should not carry the window scaling option.
- c.PassiveConnect(100, -1, header.TCPSynOptions{MSS: defaultIPv4MSS})
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
-
- c.EP, _, err = ep.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- c.EP, _, err = ep.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- data := []byte{1, 2, 3}
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Check that data is received, and that advertised window is 0xffff,
- // that is, that it's not scaled.
- b := c.GetPacket()
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- checker.IPv4(t, b,
- checker.PayloadLen(len(data)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPWindow(0xffff),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-}
-
-func TestZeroScaledWindowReceive(t *testing.T) {
- // This test ensures that the endpoint sends a non-zero window size
- // advertisement when the scaled window transitions from 0 to non-zero,
- // but the actual window (not scaled) hasn't gotten to zero.
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Set the buffer size such that a window scale of 5 will be used.
- const bufSz = 65535 * 10
- const ws = uint32(5)
- c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, bufSz, []byte{
- header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
- })
-
- // Write chunks of 50000 bytes.
- remain := 0
- sent := 0
- data := make([]byte, 50000)
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- // Keep writing till the window drops below len(data).
- for {
- c.SendPacket(data, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss.Add(seqnum.Size(sent)),
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
- sent += len(data)
- pkt := c.GetPacket()
- checker.IPv4(t, pkt,
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+uint32(sent)),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
- // Don't reduce window to zero here.
- if wnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()); wnd<<ws < len(data) {
- remain = wnd << ws
- break
- }
- }
-
- // Make the window non-zero, but the scaled window zero.
- for remain >= 16 {
- data = data[:remain-15]
- c.SendPacket(data, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss.Add(seqnum.Size(sent)),
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
- sent += len(data)
- pkt := c.GetPacket()
- checker.IPv4(t, pkt,
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+uint32(sent)),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
- // Since the receive buffer is split between window advertisement and
- // application data buffer the window does not always reflect the space
- // available and actual space available can be a bit more than what is
- // advertised in the window.
- wnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize())
- if wnd == 0 {
- break
- }
- remain = wnd << ws
- }
-
- // Read at least 2MSS of data. An ack should be sent in response to that.
- // Since buffer space is now split in half between window and application
- // data we need to read more than 1 MSS(65536) of data for a non-zero window
- // update to be sent. For 1MSS worth of window to be available we need to
- // read at least 128KB. Since our segments above were 50KB each it means
- // we need to read at 3 packets.
- w := tcpip.LimitedWriter{
- W: ioutil.Discard,
- N: defaultMTU * 2,
- }
- for w.N != 0 {
- res, err := c.EP.Read(&w, tcpip.ReadOptions{})
- t.Logf("err=%v res=%#v", err, res)
- if err != nil {
- t.Fatalf("Read failed: %s", err)
- }
- }
-
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+uint32(sent)),
- checker.TCPWindowGreaterThanEq(uint16(defaultMTU>>ws)),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-}
-
-func TestSegmentMerging(t *testing.T) {
- tests := []struct {
- name string
- stop func(tcpip.Endpoint)
- resume func(tcpip.Endpoint)
- }{
- {
- "stop work",
- func(ep tcpip.Endpoint) {
- ep.(interface{ StopWork() }).StopWork()
- },
- func(ep tcpip.Endpoint) {
- ep.(interface{ ResumeWork() }).ResumeWork()
- },
- },
- {
- "cork",
- func(ep tcpip.Endpoint) {
- ep.SocketOptions().SetCorkOption(true)
- },
- func(ep tcpip.Endpoint) {
- ep.SocketOptions().SetCorkOption(false)
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- // Send tcp.InitialCwnd number of segments to fill up
- // InitialWindow but don't ACK. That should prevent
- // anymore packets from going out.
- var r bytes.Reader
- for i := 0; i < tcp.InitialCwnd; i++ {
- r.Reset([]byte{0})
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write #%d failed: %s", i+1, err)
- }
- }
-
- // Now send the segments that should get merged as the congestion
- // window is full and we won't be able to send any more packets.
- var allData []byte
- for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} {
- allData = append(allData, data...)
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write #%d failed: %s", i+1, err)
- }
- }
-
- // Check that we get tcp.InitialCwnd packets.
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- for i := 0; i < tcp.InitialCwnd; i++ {
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.PayloadLen(header.TCPMinimumSize+1),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+uint32(i)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
- }
-
- // Acknowledge the data.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1 + 10), // 10 for the 10 bytes of payload.
- RcvWnd: 30000,
- })
-
- // Check that data is received.
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.PayloadLen(len(allData)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+11),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-
- if got := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(got, allData) {
- t.Fatalf("got data = %v, want = %v", got, allData)
- }
-
- // Acknowledge the data.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(11 + seqnum.Size(len(allData))),
- RcvWnd: 30000,
- })
- })
- }
-}
-
-func TestDelay(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- c.EP.SocketOptions().SetDelayOption(true)
-
- var allData []byte
- for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} {
- allData = append(allData, data...)
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write #%d failed: %s", i+1, err)
- }
- }
-
- seq := c.IRS.Add(1)
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- for _, want := range [][]byte{allData[:1], allData[1:]} {
- // Check that data is received.
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.PayloadLen(len(want)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(seq)),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-
- if got := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(got, want) {
- t.Fatalf("got data = %v, want = %v", got, want)
- }
-
- seq = seq.Add(seqnum.Size(len(want)))
- // Acknowledge the data.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: seq,
- RcvWnd: 30000,
- })
- }
-}
-
-func TestUndelay(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- c.EP.SocketOptions().SetDelayOption(true)
-
- allData := [][]byte{{0}, {1, 2, 3}}
- for i, data := range allData {
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write #%d failed: %s", i+1, err)
- }
- }
-
- seq := c.IRS.Add(1)
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- // Check that data is received.
- first := c.GetPacket()
- checker.IPv4(t, first,
- checker.PayloadLen(len(allData[0])+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(seq)),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-
- if got, want := first[header.IPv4MinimumSize+header.TCPMinimumSize:], allData[0]; !bytes.Equal(got, want) {
- t.Fatalf("got first packet's data = %v, want = %v", got, want)
- }
-
- seq = seq.Add(seqnum.Size(len(allData[0])))
-
- // Check that we don't get the second packet yet.
- c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond)
-
- c.EP.SocketOptions().SetDelayOption(false)
-
- // Check that data is received.
- second := c.GetPacket()
- checker.IPv4(t, second,
- checker.PayloadLen(len(allData[1])+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(seq)),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-
- if got, want := second[header.IPv4MinimumSize+header.TCPMinimumSize:], allData[1]; !bytes.Equal(got, want) {
- t.Fatalf("got second packet's data = %v, want = %v", got, want)
- }
-
- seq = seq.Add(seqnum.Size(len(allData[1])))
-
- // Acknowledge the data.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: seq,
- RcvWnd: 30000,
- })
-}
-
-func TestMSSNotDelayed(t *testing.T) {
- tests := []struct {
- name string
- fn func(tcpip.Endpoint)
- }{
- {"no-op", func(tcpip.Endpoint) {}},
- {"delay", func(ep tcpip.Endpoint) { ep.SocketOptions().SetDelayOption(true) }},
- {"cork", func(ep tcpip.Endpoint) { ep.SocketOptions().SetCorkOption(true) }},
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- const maxPayload = 100
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */, []byte{
- header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
- })
-
- test.fn(c.EP)
-
- allData := [][]byte{{0}, make([]byte, maxPayload), make([]byte, maxPayload)}
- for i, data := range allData {
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write #%d failed: %s", i+1, err)
- }
- }
-
- seq := c.IRS.Add(1)
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- for i, data := range allData {
- // Check that data is received.
- packet := c.GetPacket()
- checker.IPv4(t, packet,
- checker.PayloadLen(len(data)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(seq)),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-
- if got, want := packet[header.IPv4MinimumSize+header.TCPMinimumSize:], data; !bytes.Equal(got, want) {
- t.Fatalf("got packet #%d's data = %v, want = %v", i+1, got, want)
- }
-
- seq = seq.Add(seqnum.Size(len(data)))
- }
-
- // Acknowledge the data.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: seq,
- RcvWnd: 30000,
- })
- })
- }
-}
-
-func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) {
- payloadMultiplier := 10
- dataLen := payloadMultiplier * maxPayload
- data := make([]byte, dataLen)
- for i := range data {
- data[i] = byte(i)
- }
-
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Check that data is received in chunks.
- bytesReceived := 0
- numPackets := 0
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- for bytesReceived != dataLen {
- b := c.GetPacket()
- numPackets++
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- payloadLen := len(tcpHdr.Payload())
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1+uint32(bytesReceived)),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-
- pdata := data[bytesReceived : bytesReceived+payloadLen]
- if p := tcpHdr.Payload(); !bytes.Equal(pdata, p) {
- t.Fatalf("got data = %v, want = %v", p, pdata)
- }
- bytesReceived += payloadLen
- var options []byte
- if c.TimeStampEnabled {
- // If timestamp option is enabled, echo back the timestamp and increment
- // the TSEcr value included in the packet and send that back as the TSVal.
- parsedOpts := tcpHdr.ParsedOptions()
- tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
- header.EncodeTSOption(parsedOpts.TSEcr+1, parsedOpts.TSVal, tsOpt[2:])
- options = tsOpt[:]
- }
- // Acknowledge the data.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1 + seqnum.Size(bytesReceived)),
- RcvWnd: 30000,
- TCPOpts: options,
- })
- }
- if numPackets == 1 {
- t.Fatalf("expected write to be broken up into multiple packets, but got 1 packet")
- }
-}
-
-func TestSendGreaterThanMTU(t *testing.T) {
- const maxPayload = 100
- c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
- testBrokenUpWrite(t, c, maxPayload)
-}
-
-func TestSetTTL(t *testing.T) {
- for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} {
- t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) {
- c := context.New(t, 65535)
- defer c.Cleanup()
-
- var err tcpip.Error
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
-
- if err := c.EP.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil {
- t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err)
- }
-
- {
- err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
- if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
- t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d)
- }
- }
-
- // Receive SYN packet.
- b := c.GetPacket()
-
- checker.IPv4(t, b, checker.TTL(wantTTL))
- })
- }
-}
-
-func TestActiveSendMSSLessThanMTU(t *testing.T) {
- const maxPayload = 100
- c := context.New(t, 65535)
- defer c.Cleanup()
-
- c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */, []byte{
- header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
- })
- testBrokenUpWrite(t, c, maxPayload)
-}
-
-func TestPassiveSendMSSLessThanMTU(t *testing.T) {
- const maxPayload = 100
- const mtu = 1200
- c := context.New(t, mtu)
- defer c.Cleanup()
-
- // Create EP and start listening.
- wq := &waiter.Queue{}
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- defer ep.Close()
-
- // Set the buffer size to a deterministic size so that we can check the
- // window scaling option.
- const rcvBufferSize = 0x20000
- ep.SocketOptions().SetReceiveBufferSize(rcvBufferSize*2, true /* notify */)
-
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Do 3-way handshake.
- c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize})
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
-
- c.EP, _, err = ep.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- c.EP, _, err = ep.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- // Check that data gets properly segmented.
- testBrokenUpWrite(t, c, maxPayload)
-}
-
-func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) {
- const maxPayload = 536
- const mtu = 2000
- c := context.New(t, mtu)
- defer c.Cleanup()
-
- opt := tcpip.TCPAlwaysUseSynCookies(true)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
- }
-
- // Create EP and start listening.
- wq := &waiter.Queue{}
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- defer ep.Close()
-
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Do 3-way handshake.
- c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize})
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
-
- c.EP, _, err = ep.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- c.EP, _, err = ep.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- // Check that data gets properly segmented.
- testBrokenUpWrite(t, c, maxPayload)
-}
-
-func TestForwarderSendMSSLessThanMTU(t *testing.T) {
- const maxPayload = 100
- const mtu = 1200
- c := context.New(t, mtu)
- defer c.Cleanup()
-
- s := c.Stack()
- ch := make(chan tcpip.Error, 1)
- f := tcp.NewForwarder(s, 65536, 10, func(r *tcp.ForwarderRequest) {
- var err tcpip.Error
- c.EP, err = r.CreateEndpoint(&c.WQ)
- ch <- err
- })
- s.SetTransportProtocolHandler(tcp.ProtocolNumber, f.HandlePacket)
-
- // Do 3-way handshake.
- c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize})
-
- // Wait for connection to be available.
- select {
- case err := <-ch:
- if err != nil {
- t.Fatalf("Error creating endpoint: %s", err)
- }
- case <-time.After(2 * time.Second):
- t.Fatalf("Timed out waiting for connection")
- }
-
- // Check that data gets properly segmented.
- testBrokenUpWrite(t, c, maxPayload)
-}
-
-func TestSynOptionsOnActiveConnect(t *testing.T) {
- const mtu = 1400
- c := context.New(t, mtu)
- defer c.Cleanup()
-
- // Create TCP endpoint.
- var err tcpip.Error
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
-
- // Set the buffer size to a deterministic size so that we can check the
- // window scaling option.
- const rcvBufferSize = 0x20000
- const wndScale = 3
- c.EP.SocketOptions().SetReceiveBufferSize(rcvBufferSize*2, true /* notify */)
-
- // Start connection attempt.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.WritableEvents)
- defer c.WQ.EventUnregister(&we)
-
- {
- err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
- if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
- t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d)
- }
- }
-
- // Receive SYN packet.
- b := c.GetPacket()
- mss := uint16(mtu - header.IPv4MinimumSize - header.TCPMinimumSize)
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}),
- ),
- )
-
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
-
- // Wait for retransmit.
- time.Sleep(1 * time.Second)
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- checker.SrcPort(tcpHdr.SourcePort()),
- checker.TCPSeqNum(tcpHdr.SequenceNumber()),
- checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}),
- ),
- )
-
- // Send SYN-ACK.
- iss := seqnum.Value(context.TestInitialSequenceNumber)
- c.SendPacket(nil, &context.Headers{
- SrcPort: tcpHdr.DestinationPort(),
- DstPort: tcpHdr.SourcePort(),
- Flags: header.TCPFlagSyn | header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Receive ACK packet.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+1),
- ),
- )
-
- // Wait for connection to be established.
- select {
- case <-ch:
- if err := c.EP.LastError(); err != nil {
- t.Fatalf("Connect failed: %s", err)
- }
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for connection")
- }
-}
-
-func TestCloseListener(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create listener.
- var wq waiter.Queue
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
-
- if err := ep.Bind(tcpip.FullAddress{}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Close the listener and measure how long it takes.
- t0 := time.Now()
- ep.Close()
- if diff := time.Now().Sub(t0); diff > 3*time.Second {
- t.Fatalf("Took too long to close: %s", diff)
- }
-}
-
-func TestReceiveOnResetConnection(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- // Send RST segment.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagRst,
- SeqNum: iss,
- RcvWnd: 30000,
- })
-
- // Try to read.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
-
-loop:
- for {
- switch _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err.(type) {
- case *tcpip.ErrWouldBlock:
- <-ch
- // Expect the state to be StateError and subsequent Reads to fail with HardError.
- _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{})
- if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" {
- t.Fatalf("c.EP.Read() mismatch (-want +got):\n%s", d)
- }
- break loop
- case *tcpip.ErrConnectionReset:
- break loop
- default:
- t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, &tcpip.ErrConnectionReset{})
- }
- }
-
- if tcp.EndpointState(c.EP.State()) != tcp.StateError {
- t.Fatalf("got EP state is not StateError")
- }
-
- checkValid := func() []error {
- var errors []error
- if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 {
- errors = append(errors, fmt.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got))
- }
- if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- errors = append(errors, fmt.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got))
- }
- if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
- errors = append(errors, fmt.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got))
- }
- return errors
- }
-
- start := time.Now()
- for time.Since(start) < time.Minute && len(checkValid()) > 0 {
- time.Sleep(50 * time.Millisecond)
- }
- for _, err := range checkValid() {
- t.Error(err)
- }
-}
-
-func TestSendOnResetConnection(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- // Send RST segment.
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagRst,
- SeqNum: iss,
- RcvWnd: 30000,
- })
-
- // Wait for the RST to be received.
- time.Sleep(1 * time.Second)
-
- // Try to write.
- var r bytes.Reader
- r.Reset(make([]byte, 10))
- _, err := c.EP.Write(&r, tcpip.WriteOptions{})
- if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" {
- t.Fatalf("c.EP.Write(...) mismatch (-want +got):\n%s", d)
- }
-}
-
-// TestMaxRetransmitsTimeout tests if the connection is timed out after
-// a segment has been retransmitted MaxRetries times.
-func TestMaxRetransmitsTimeout(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- const numRetries = 2
- opt := tcpip.TCPMaxRetriesOption(numRetries)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
- }
-
- // Wait for the connection to timeout after MaxRetries retransmits.
- initRTO := time.Second
- minRTOOpt := tcpip.TCPMinRTOOption(initRTO)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err)
- }
- c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
-
- waitEntry, notifyCh := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&waitEntry, waiter.EventHUp)
- defer c.WQ.EventUnregister(&waitEntry)
-
- var r bytes.Reader
- r.Reset(make([]byte, 1))
- _, err := c.EP.Write(&r, tcpip.WriteOptions{})
- if err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Expect first transmit and MaxRetries retransmits.
- for i := 0; i < numRetries+1; i++ {
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh),
- ),
- )
- }
- select {
- case <-notifyCh:
- case <-time.After((2 << numRetries) * initRTO):
- t.Fatalf("connection still alive after maximum retransmits.\n")
- }
-
- // Send an ACK and expect a RST as the connection would have been closed.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- })
-
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagRst),
- ),
- )
-
- if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got)
- }
- if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
- }
-}
-
-// TestMaxRTO tests if the retransmit interval caps to MaxRTO.
-func TestMaxRTO(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- rto := 1 * time.Second
- minRTOOpt := tcpip.TCPMinRTOOption(rto / 2)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err)
- }
- maxRTOOpt := tcpip.TCPMaxRTOOption(rto)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &maxRTOOpt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, maxRTOOpt, maxRTOOpt, err)
- }
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
-
- var r bytes.Reader
- r.Reset(make([]byte, 1))
- _, err := c.EP.Write(&r, tcpip.WriteOptions{})
- if err != nil {
- t.Fatalf("Write failed: %s", err)
- }
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
- const numRetransmits = 2
- for i := 0; i < numRetransmits; i++ {
- start := time.Now()
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
- if elapsed := time.Since(start); elapsed.Round(time.Second).Seconds() != rto.Seconds() {
- t.Errorf("Retransmit interval not capped to MaxRTO(%s). %s", rto, elapsed)
- }
- }
-}
-
-// TestZeroSizedWriteRetransmit tests that a zero sized write should not
-// result in a panic on an RTO as no segment should have been queued for
-// a zero sized write.
-func TestZeroSizedWriteRetransmit(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
-
- var r bytes.Reader
- _, err := c.EP.Write(&r, tcpip.WriteOptions{})
- if err != nil {
- t.Fatalf("Write failed: %s", err)
- }
- // Now do a non-zero sized write to trigger actual sending of data.
- r.Reset(make([]byte, 1))
- _, err = c.EP.Write(&r, tcpip.WriteOptions{})
- if err != nil {
- t.Fatalf("Write failed: %s", err)
- }
- // Do not ACK the packet and expect an original transmit and a
- // retransmit. This should not cause a panic.
- for i := 0; i < 2; i++ {
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
- }
-}
-
-// TestRetransmitIPv4IDUniqueness tests that the IPv4 Identification field is
-// unique on retransmits.
-func TestRetransmitIPv4IDUniqueness(t *testing.T) {
- for _, tc := range []struct {
- name string
- size int
- }{
- {"1Byte", 1},
- {"512Bytes", 512},
- } {
- t.Run(tc.name, func(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- minRTOOpt := tcpip.TCPMinRTOOption(time.Second)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err)
- }
- c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
-
- // Disabling PMTU discovery causes all packets sent from this socket to
- // have DF=0. This needs to be done because the IPv4 ID uniqueness
- // applies only to non-atomic IPv4 datagrams as defined in RFC 6864
- // Section 4, and datagrams with DF=0 are non-atomic.
- if err := c.EP.SetSockOptInt(tcpip.MTUDiscoverOption, tcpip.PMTUDiscoveryDont); err != nil {
- t.Fatalf("disabling PMTU discovery via sockopt to force DF=0 failed: %s", err)
- }
-
- var r bytes.Reader
- r.Reset(make([]byte, tc.size))
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
- pkt := c.GetPacket()
- checker.IPv4(t, pkt,
- checker.FragmentFlags(0),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
- idSet := map[uint16]struct{}{header.IPv4(pkt).ID(): {}}
- // Expect two retransmitted packets, and that all packets received have
- // unique IPv4 ID values.
- for i := 0; i <= 2; i++ {
- pkt := c.GetPacket()
- checker.IPv4(t, pkt,
- checker.FragmentFlags(0),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
- id := header.IPv4(pkt).ID()
- if _, exists := idSet[id]; exists {
- t.Fatalf("duplicate IPv4 ID=%d found in retransmitted packet", id)
- }
- idSet[id] = struct{}{}
- }
- })
- }
-}
-
-func TestFinImmediately(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- // Shutdown immediately, check that we get a FIN.
- if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %s", err)
- }
-
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- ),
- )
-
- // Ack and send FIN as well.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: iss,
- AckNum: c.IRS.Add(2),
- RcvWnd: 30000,
- })
-
- // Check that the stack acks the FIN.
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+2),
- checker.TCPAckNum(uint32(iss)+1),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-}
-
-func TestFinRetransmit(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- // Shutdown immediately, check that we get a FIN.
- if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %s", err)
- }
-
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- ),
- )
-
- // Don't acknowledge yet. We should get a retransmit of the FIN.
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- ),
- )
-
- // Ack and send FIN as well.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: iss,
- AckNum: c.IRS.Add(2),
- RcvWnd: 30000,
- })
-
- // Check that the stack acks the FIN.
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+2),
- checker.TCPAckNum(uint32(iss)+1),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-}
-
-func TestFinWithNoPendingData(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- // Write something out, and have it acknowledged.
- view := make([]byte, 10)
- var r bytes.Reader
- r.Reset(view)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- next := uint32(c.IRS) + 1
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(len(view)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(next),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
- next += uint32(len(view))
-
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: seqnum.Value(next),
- RcvWnd: 30000,
- })
-
- // Shutdown, check that we get a FIN.
- if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %s", err)
- }
-
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(next),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- ),
- )
- next++
-
- // Ack and send FIN as well.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: iss,
- AckNum: seqnum.Value(next),
- RcvWnd: 30000,
- })
-
- // Check that the stack acks the FIN.
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(next),
- checker.TCPAckNum(uint32(iss)+1),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-}
-
-func TestFinWithPendingDataCwndFull(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- // Write enough segments to fill the congestion window before ACK'ing
- // any of them.
- view := make([]byte, 10)
- var r bytes.Reader
- for i := tcp.InitialCwnd; i > 0; i-- {
- r.Reset(view)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
- }
-
- next := uint32(c.IRS) + 1
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- for i := tcp.InitialCwnd; i > 0; i-- {
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(len(view)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(next),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
- next += uint32(len(view))
- }
-
- // Shutdown the connection, check that the FIN segment isn't sent
- // because the congestion window doesn't allow it. Wait until a
- // retransmit is received.
- if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %s", err)
- }
-
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(len(view)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-
- // Send the ACK that will allow the FIN to be sent as well.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: seqnum.Value(next),
- RcvWnd: 30000,
- })
-
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(next),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- ),
- )
- next++
-
- // Send a FIN that acknowledges everything. Get an ACK back.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: iss,
- AckNum: seqnum.Value(next),
- RcvWnd: 30000,
- })
-
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(next),
- checker.TCPAckNum(uint32(iss)+1),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-}
-
-func TestFinWithPendingData(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- // Write something out, and acknowledge it to get cwnd to 2.
- view := make([]byte, 10)
- var r bytes.Reader
- r.Reset(view)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- next := uint32(c.IRS) + 1
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(len(view)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(next),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
- next += uint32(len(view))
-
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: seqnum.Value(next),
- RcvWnd: 30000,
- })
-
- // Write new data, but don't acknowledge it.
- r.Reset(view)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(len(view)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(next),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
- next += uint32(len(view))
-
- // Shutdown the connection, check that we do get a FIN.
- if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %s", err)
- }
-
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(next),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- ),
- )
- next++
-
- // Send a FIN that acknowledges everything. Get an ACK back.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: iss,
- AckNum: seqnum.Value(next),
- RcvWnd: 30000,
- })
-
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(next),
- checker.TCPAckNum(uint32(iss)+1),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-}
-
-func TestFinWithPartialAck(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- // Write something out, and acknowledge it to get cwnd to 2. Also send
- // FIN from the test side.
- view := make([]byte, 10)
- var r bytes.Reader
- r.Reset(view)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- next := uint32(c.IRS) + 1
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(len(view)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(next),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
- next += uint32(len(view))
-
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: iss,
- AckNum: seqnum.Value(next),
- RcvWnd: 30000,
- })
-
- // Check that we get an ACK for the fin.
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(next),
- checker.TCPAckNum(uint32(iss)+1),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-
- // Write new data, but don't acknowledge it.
- r.Reset(view)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(len(view)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(next),
- checker.TCPAckNum(uint32(iss)+1),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
- next += uint32(len(view))
-
- // Shutdown the connection, check that we do get a FIN.
- if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %s", err)
- }
-
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(next),
- checker.TCPAckNum(uint32(iss)+1),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- ),
- )
- next++
-
- // Send an ACK for the data, but not for the FIN yet.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss.Add(1),
- AckNum: seqnum.Value(next - 1),
- RcvWnd: 30000,
- })
-
- // Check that we don't get a retransmit of the FIN.
- c.CheckNoPacketTimeout("FIN retransmitted when data was ack'd", 100*time.Millisecond)
-
- // Ack the FIN.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: iss.Add(1),
- AckNum: seqnum.Value(next),
- RcvWnd: 30000,
- })
-}
-
-func TestUpdateListenBacklog(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create listener.
- var wq waiter.Queue
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
-
- if err := ep.Bind(tcpip.FullAddress{}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Update the backlog with another Listen() on the same endpoint.
- if err := ep.Listen(20); err != nil {
- t.Fatalf("Listen failed to update backlog: %s", err)
- }
-
- ep.Close()
-}
-
-func scaledSendWindow(t *testing.T, scale uint8) {
- // This test ensures that the endpoint is using the right scaling by
- // sending a buffer that is larger than the window size, and ensuring
- // that the endpoint doesn't send more than allowed.
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- maxPayload := defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize
- c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 0, -1 /* epRcvBuf */, []byte{
- header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
- header.TCPOptionWS, 3, scale, header.TCPOptionNOP,
- })
-
- // Open up the window with a scaled value.
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 1,
- })
-
- // Send some data. Check that it's capped by the window size.
- view := make([]byte, 65535)
- var r bytes.Reader
- r.Reset(view)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Check that only data that fits in the scaled window is sent.
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen((1<<scale)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-
- // Reset the connection to free resources.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagRst,
- SeqNum: iss,
- })
-}
-
-func TestScaledSendWindow(t *testing.T) {
- for scale := uint8(0); scale <= 14; scale++ {
- scaledSendWindow(t, scale)
- }
-}
-
-func TestReceivedValidSegmentCountIncrement(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
- stats := c.Stack().Stats()
- want := stats.TCP.ValidSegmentsReceived.Value() + 1
-
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- if got := stats.TCP.ValidSegmentsReceived.Value(); got != want {
- t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %d, want = %d", got, want)
- }
- if got := c.EP.Stats().(*tcp.Stats).SegmentsReceived.Value(); got != want {
- t.Errorf("got EP stats Stats.SegmentsReceived = %d, want = %d", got, want)
- }
- // Ensure there were no errors during handshake. If these stats have
- // incremented, then the connection should not have been established.
- if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoRoute.Value(); got != 0 {
- t.Errorf("got EP stats Stats.SendErrors.NoRoute = %d, want = %d", got, 0)
- }
-}
-
-func TestReceivedInvalidSegmentCountIncrement(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
- stats := c.Stack().Stats()
- want := stats.TCP.InvalidSegmentsReceived.Value() + 1
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- vv := c.BuildSegment(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
- tcpbuf := vv.ToView()[header.IPv4MinimumSize:]
- tcpbuf[header.TCPDataOffset] = ((header.TCPMinimumSize - 1) / 4) << 4
-
- c.SendSegment(vv)
-
- if got := stats.TCP.InvalidSegmentsReceived.Value(); got != want {
- t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %d, want = %d", got, want)
- }
- if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want)
- }
-}
-
-func TestReceivedIncorrectChecksumIncrement(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
- stats := c.Stack().Stats()
- want := stats.TCP.ChecksumErrors.Value() + 1
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- vv := c.BuildSegment([]byte{0x1, 0x2, 0x3}, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
- tcpbuf := vv.ToView()[header.IPv4MinimumSize:]
- // Overwrite a byte in the payload which should cause checksum
- // verification to fail.
- tcpbuf[(tcpbuf[header.TCPDataOffset]>>4)*4] = 0x4
-
- c.SendSegment(vv)
-
- if got := stats.TCP.ChecksumErrors.Value(); got != want {
- t.Errorf("got stats.TCP.ChecksumErrors.Value() = %d, want = %d", got, want)
- }
- if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ChecksumErrors.Value(); got != want {
- t.Errorf("got EP stats Stats.ReceiveErrors.ChecksumErrors = %d, want = %d", got, want)
- }
-}
-
-func TestReceivedSegmentQueuing(t *testing.T) {
- // This test sends 200 segments containing a few bytes each to an
- // endpoint and checks that they're all received and acknowledged by
- // the endpoint, that is, that none of the segments are dropped by
- // internal queues.
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- // Send 200 segments.
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- data := []byte{1, 2, 3}
- for i := 0; i < 200; i++ {
- c.SendPacket(data, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss.Add(seqnum.Size(i * len(data))),
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
- }
-
- // Receive ACKs for all segments.
- last := iss.Add(seqnum.Size(200 * len(data)))
- for {
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- ack := seqnum.Value(tcpHdr.AckNumber())
- if ack == last {
- break
- }
-
- if last.LessThan(ack) {
- t.Fatalf("Acknowledge (%v) beyond the expected (%v)", ack, last)
- }
- }
-}
-
-func TestReadAfterClosedState(t *testing.T) {
- // This test ensures that calling Read() or Peek() after the endpoint
- // has transitioned to closedState still works if there is pending
- // data. To transition to stateClosed without calling Close(), we must
- // shutdown the send path and the peer must send its own FIN.
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed
- // after 1 second in TIME_WAIT state.
- tcpTimeWaitTimeout := 1 * time.Second
- opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
- }
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
-
- ept := endpointTester{c.EP}
- ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
-
- // Shutdown immediately for write, check that we get a FIN.
- if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %s", err)
- }
-
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- ),
- )
-
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want {
- t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
- }
-
- // Send some data and acknowledge the FIN.
- data := []byte{1, 2, 3}
- c.SendPacket(data, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: iss,
- AckNum: c.IRS.Add(2),
- RcvWnd: 30000,
- })
-
- // Check that ACK is received.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+2),
- checker.TCPAckNum(uint32(iss)+uint32(len(data))+1),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-
- // Give the stack the chance to transition to closed state from
- // TIME_WAIT.
- time.Sleep(tcpTimeWaitTimeout * 2)
-
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want {
- t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
- }
-
- // Wait for receive to be notified.
- select {
- case <-ch:
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for data to arrive")
- }
-
- // Check that peek works.
- var peekBuf bytes.Buffer
- res, err := c.EP.Read(&peekBuf, tcpip.ReadOptions{Peek: true})
- if err != nil {
- t.Fatalf("Peek failed: %s", err)
- }
-
- if got, want := res.Count, len(data); got != want {
- t.Fatalf("res.Count = %d, want %d", got, want)
- }
- if !bytes.Equal(data, peekBuf.Bytes()) {
- t.Fatalf("got data = %v, want = %v", peekBuf.Bytes(), data)
- }
-
- // Receive data.
- v := ept.CheckRead(t)
- if !bytes.Equal(data, v) {
- t.Fatalf("got data = %v, want = %v", v, data)
- }
-
- // Now that we drained the queue, check that functions fail with the
- // right error code.
- ept.CheckReadError(t, &tcpip.ErrClosedForReceive{})
- var buf bytes.Buffer
- {
- _, err := c.EP.Read(&buf, tcpip.ReadOptions{Peek: true})
- if d := cmp.Diff(&tcpip.ErrClosedForReceive{}, err); d != "" {
- t.Fatalf("c.EP.Read(_, {Peek: true}) mismatch (-want +got):\n%s", d)
- }
- }
-}
-
-func TestReusePort(t *testing.T) {
- // This test ensures that ports are immediately available for reuse
- // after Close on the endpoints using them returns.
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // First case, just an endpoint that was bound.
- var err tcpip.Error
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed; %s", err)
- }
- c.EP.SocketOptions().SetReuseAddress(true)
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- c.EP.Close()
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed; %s", err)
- }
- c.EP.SocketOptions().SetReuseAddress(true)
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
- c.EP.Close()
-
- // Second case, an endpoint that was bound and is connecting..
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed; %s", err)
- }
- c.EP.SocketOptions().SetReuseAddress(true)
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
- {
- err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
- if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
- t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d)
- }
- }
- c.EP.Close()
-
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed; %s", err)
- }
- c.EP.SocketOptions().SetReuseAddress(true)
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
- c.EP.Close()
-
- // Third case, an endpoint that was bound and is listening.
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed; %s", err)
- }
- c.EP.SocketOptions().SetReuseAddress(true)
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
- if err := c.EP.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
- c.EP.Close()
-
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed; %s", err)
- }
- c.EP.SocketOptions().SetReuseAddress(true)
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
- if err := c.EP.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-}
-
-func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
- t.Helper()
-
- s := ep.SocketOptions().GetReceiveBufferSize()
- if int(s) != v {
- t.Fatalf("got receive buffer size = %d, want = %d", s, v)
- }
-}
-
-func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
- t.Helper()
-
- if s := ep.SocketOptions().GetSendBufferSize(); int(s) != v {
- t.Fatalf("got send buffer size = %d, want = %d", s, v)
- }
-}
-
-func TestDefaultBufferSizes(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
- })
-
- // Check the default values.
- ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed; %s", err)
- }
- defer func() {
- if ep != nil {
- ep.Close()
- }
- }()
-
- checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize)
- checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize)
-
- // Change the default send buffer size.
- {
- opt := tcpip.TCPSendBufferSizeRangeOption{
- Min: 1,
- Default: tcp.DefaultSendBufferSize * 2,
- Max: tcp.DefaultSendBufferSize * 20,
- }
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
- }
- }
-
- ep.Close()
- ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed; %s", err)
- }
-
- checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2)
- checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize)
-
- // Change the default receive buffer size.
- {
- opt := tcpip.TCPReceiveBufferSizeRangeOption{
- Min: 1,
- Default: tcp.DefaultReceiveBufferSize * 3,
- Max: tcp.DefaultReceiveBufferSize * 30,
- }
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
- }
- }
-
- ep.Close()
- ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed; %s", err)
- }
-
- checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2)
- checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*3)
-}
-
-func TestBindToDeviceOption(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}})
-
- ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed; %s", err)
- }
- defer ep.Close()
-
- if err := s.CreateNIC(321, loopback.New()); err != nil {
- t.Errorf("CreateNIC failed: %s", err)
- }
-
- // nicIDPtr is used instead of taking the address of NICID literals, which is
- // a compiler error.
- nicIDPtr := func(s tcpip.NICID) *tcpip.NICID {
- return &s
- }
-
- testActions := []struct {
- name string
- setBindToDevice *tcpip.NICID
- setBindToDeviceError tcpip.Error
- getBindToDevice int32
- }{
- {"GetDefaultValue", nil, nil, 0},
- {"BindToNonExistent", nicIDPtr(999), &tcpip.ErrUnknownDevice{}, 0},
- {"BindToExistent", nicIDPtr(321), nil, 321},
- {"UnbindToDevice", nicIDPtr(0), nil, 0},
- }
- for _, testAction := range testActions {
- t.Run(testAction.name, func(t *testing.T) {
- if testAction.setBindToDevice != nil {
- bindToDevice := int32(*testAction.setBindToDevice)
- if gotErr, wantErr := ep.SocketOptions().SetBindToDevice(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
- t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr)
- }
- }
- bindToDevice := ep.SocketOptions().GetBindToDevice()
- if bindToDevice != testAction.getBindToDevice {
- t.Errorf("got bindToDevice = %d, want %d", bindToDevice, testAction.getBindToDevice)
- }
- })
- }
-}
-
-func makeStack() (*stack.Stack, tcpip.Error) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{
- ipv4.NewProtocol,
- ipv6.NewProtocol,
- },
- TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
- })
-
- id := loopback.New()
- if testing.Verbose() {
- id = sniffer.New(id)
- }
-
- if err := s.CreateNIC(1, id); err != nil {
- return nil, err
- }
-
- for _, ct := range []struct {
- number tcpip.NetworkProtocolNumber
- addrWithPrefix tcpip.AddressWithPrefix
- }{
- {ipv4.ProtocolNumber, context.StackAddrWithPrefix},
- {ipv6.ProtocolNumber, context.StackV6AddrWithPrefix},
- } {
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: ct.number,
- AddressWithPrefix: ct.addrWithPrefix,
- }
- if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
- return nil, err
- }
- }
-
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv4EmptySubnet,
- NIC: 1,
- },
- {
- Destination: header.IPv6EmptySubnet,
- NIC: 1,
- },
- })
-
- return s, nil
-}
-
-func TestSelfConnect(t *testing.T) {
- // This test ensures that intentional self-connects work. In particular,
- // it checks that if an endpoint binds to say 127.0.0.1:1000 then
- // connects to 127.0.0.1:1000, then it will be connected to itself, and
- // is able to send and receive data through the same endpoint.
- s, err := makeStack()
- if err != nil {
- t.Fatal(err)
- }
-
- var wq waiter.Queue
- ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- defer ep.Close()
-
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- // Register for notification, then start connection attempt.
- waitEntry, notifyCh := waiter.NewChannelEntry(nil)
- wq.EventRegister(&waitEntry, waiter.WritableEvents)
- defer wq.EventUnregister(&waitEntry)
-
- {
- err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort})
- if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
- t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d)
- }
- }
-
- <-notifyCh
- if err := ep.LastError(); err != nil {
- t.Fatalf("Connect failed: %s", err)
- }
-
- // Write something.
- data := []byte{1, 2, 3}
- var r bytes.Reader
- r.Reset(data)
- if _, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Read back what was written.
- wq.EventUnregister(&waitEntry)
- wq.EventRegister(&waitEntry, waiter.ReadableEvents)
- ept := endpointTester{ep}
- rd := ept.CheckReadFull(t, len(data), notifyCh, 5*time.Second)
-
- if !bytes.Equal(data, rd) {
- t.Fatalf("got data = %v, want = %v", rd, data)
- }
-}
-
-func TestConnectAvoidsBoundPorts(t *testing.T) {
- addressTypes := func(t *testing.T, network string) []string {
- switch network {
- case "ipv4":
- return []string{"v4"}
- case "ipv6":
- return []string{"v6"}
- case "dual":
- return []string{"v6", "mapped"}
- default:
- t.Fatalf("unknown network: '%s'", network)
- }
-
- panic("unreachable")
- }
-
- address := func(t *testing.T, addressType string, isAny bool) tcpip.Address {
- switch addressType {
- case "v4":
- if isAny {
- return ""
- }
- return context.StackAddr
- case "v6":
- if isAny {
- return ""
- }
- return context.StackV6Addr
- case "mapped":
- if isAny {
- return context.V4MappedWildcardAddr
- }
- return context.StackV4MappedAddr
- default:
- t.Fatalf("unknown address type: '%s'", addressType)
- }
-
- panic("unreachable")
- }
- // This test ensures that Endpoint.Connect doesn't select already-bound ports.
- networks := []string{"ipv4", "ipv6", "dual"}
- for _, exhaustedNetwork := range networks {
- t.Run(fmt.Sprintf("exhaustedNetwork=%s", exhaustedNetwork), func(t *testing.T) {
- for _, exhaustedAddressType := range addressTypes(t, exhaustedNetwork) {
- t.Run(fmt.Sprintf("exhaustedAddressType=%s", exhaustedAddressType), func(t *testing.T) {
- for _, isAny := range []bool{false, true} {
- t.Run(fmt.Sprintf("isAny=%t", isAny), func(t *testing.T) {
- for _, candidateNetwork := range networks {
- t.Run(fmt.Sprintf("candidateNetwork=%s", candidateNetwork), func(t *testing.T) {
- for _, candidateAddressType := range addressTypes(t, candidateNetwork) {
- t.Run(fmt.Sprintf("candidateAddressType=%s", candidateAddressType), func(t *testing.T) {
- s, err := makeStack()
- if err != nil {
- t.Fatal(err)
- }
-
- var wq waiter.Queue
- var eps []tcpip.Endpoint
- defer func() {
- for _, ep := range eps {
- ep.Close()
- }
- }()
- makeEP := func(network string) tcpip.Endpoint {
- var networkProtocolNumber tcpip.NetworkProtocolNumber
- switch network {
- case "ipv4":
- networkProtocolNumber = ipv4.ProtocolNumber
- case "ipv6", "dual":
- networkProtocolNumber = ipv6.ProtocolNumber
- default:
- t.Fatalf("unknown network: '%s'", network)
- }
- ep, err := s.NewEndpoint(tcp.ProtocolNumber, networkProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- eps = append(eps, ep)
- switch network {
- case "ipv4":
- case "ipv6":
- ep.SocketOptions().SetV6Only(true)
- case "dual":
- ep.SocketOptions().SetV6Only(false)
- default:
- t.Fatalf("unknown network: '%s'", network)
- }
- return ep
- }
-
- var v4reserved, v6reserved bool
- switch exhaustedAddressType {
- case "v4", "mapped":
- v4reserved = true
- case "v6":
- v6reserved = true
- // Dual stack sockets bound to v6 any reserve on v4 as
- // well.
- if isAny {
- switch exhaustedNetwork {
- case "ipv6":
- case "dual":
- v4reserved = true
- default:
- t.Fatalf("unknown address type: '%s'", exhaustedNetwork)
- }
- }
- default:
- t.Fatalf("unknown address type: '%s'", exhaustedAddressType)
- }
- var collides bool
- switch candidateAddressType {
- case "v4", "mapped":
- collides = v4reserved
- case "v6":
- collides = v6reserved
- default:
- t.Fatalf("unknown address type: '%s'", candidateAddressType)
- }
-
- const (
- start = 16000
- end = 16050
- )
- if err := s.SetPortRange(start, end); err != nil {
- t.Fatalf("got s.SetPortRange(%d, %d) = %s, want = nil", start, end, err)
- }
- for i := start; i <= end; i++ {
- if err := makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil {
- t.Fatalf("Bind(%d) failed: %s", i, err)
- }
- }
- var want tcpip.Error = &tcpip.ErrConnectStarted{}
- if collides {
- want = &tcpip.ErrNoPortAvailable{}
- }
- if err := makeEP(candidateNetwork).Connect(tcpip.FullAddress{Addr: address(t, candidateAddressType, false), Port: 31337}); err != want {
- t.Fatalf("got ep.Connect(..) = %s, want = %s", err, want)
- }
- })
- }
- })
- }
- })
- }
- })
- }
- })
- }
-}
-
-func TestPathMTUDiscovery(t *testing.T) {
- // This test verifies the stack retransmits packets after it receives an
- // ICMP packet indicating that the path MTU has been exceeded.
- c := context.New(t, 1500)
- defer c.Cleanup()
-
- // Create new connection with MSS of 1460.
- const maxPayload = 1500 - header.TCPMinimumSize - header.IPv4MinimumSize
- c.CreateConnectedWithRawOptions(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */, []byte{
- header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
- })
-
- // Send 3200 bytes of data.
- const writeSize = 3200
- data := make([]byte, writeSize)
- for i := range data {
- data[i] = byte(i)
- }
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- receivePackets := func(c *context.Context, sizes []int, which int, seqNum uint32) []byte {
- var ret []byte
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- for i, size := range sizes {
- p := c.GetPacket()
- if i == which {
- ret = p
- }
- checker.IPv4(t, p,
- checker.PayloadLen(size+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(seqNum),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
- seqNum += uint32(size)
- }
- return ret
- }
-
- // Receive three packets.
- sizes := []int{maxPayload, maxPayload, writeSize - 2*maxPayload}
- first := receivePackets(c, sizes, 0, uint32(c.IRS)+1)
-
- // Send "packet too big" messages back to netstack.
- const newMTU = 1200
- const newMaxPayload = newMTU - header.IPv4MinimumSize - header.TCPMinimumSize
- mtu := []byte{0, 0, newMTU / 256, newMTU % 256}
- c.SendICMPPacket(header.ICMPv4DstUnreachable, header.ICMPv4FragmentationNeeded, mtu, first, newMTU)
-
- // See retransmitted packets. None exceeding the new max.
- sizes = []int{newMaxPayload, maxPayload - newMaxPayload, newMaxPayload, maxPayload - newMaxPayload, writeSize - 2*maxPayload}
- receivePackets(c, sizes, -1, uint32(c.IRS)+1)
-}
-
-func TestTCPEndpointProbe(t *testing.T) {
- c := context.New(t, 1500)
- defer c.Cleanup()
-
- invoked := make(chan struct{})
- c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
- // Validate that the endpoint ID is what we expect.
- //
- // We don't do an extensive validation of every field but a
- // basic sanity test.
- if got, want := state.ID.LocalAddress, tcpip.Address(context.StackAddr); got != want {
- t.Fatalf("got LocalAddress: %q, want: %q", got, want)
- }
- if got, want := state.ID.LocalPort, c.Port; got != want {
- t.Fatalf("got LocalPort: %d, want: %d", got, want)
- }
- if got, want := state.ID.RemoteAddress, tcpip.Address(context.TestAddr); got != want {
- t.Fatalf("got RemoteAddress: %q, want: %q", got, want)
- }
- if got, want := state.ID.RemotePort, uint16(context.TestPort); got != want {
- t.Fatalf("got RemotePort: %d, want: %d", got, want)
- }
-
- invoked <- struct{}{}
- })
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- data := []byte{1, 2, 3}
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- c.SendPacket(data, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- select {
- case <-invoked:
- case <-time.After(100 * time.Millisecond):
- t.Fatalf("TCP Probe function was not called")
- }
-}
-
-func TestStackSetCongestionControl(t *testing.T) {
- testCases := []struct {
- cc tcpip.CongestionControlOption
- err tcpip.Error
- }{
- {"reno", nil},
- {"cubic", nil},
- {"blahblah", &tcpip.ErrNoSuchFile{}},
- }
-
- for _, tc := range testCases {
- t.Run(fmt.Sprintf("SetTransportProtocolOption(.., %v)", tc.cc), func(t *testing.T) {
- c := context.New(t, 1500)
- defer c.Cleanup()
-
- s := c.Stack()
-
- var oldCC tcpip.CongestionControlOption
- if err := s.TransportProtocolOption(tcp.ProtocolNumber, &oldCC); err != nil {
- t.Fatalf("s.TransportProtocolOption(%v, %v) = %s", tcp.ProtocolNumber, &oldCC, err)
- }
-
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &tc.cc); err != tc.err {
- t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%s)) = %s, want = %s", tcp.ProtocolNumber, tc.cc, tc.cc, err, tc.err)
- }
-
- var cc tcpip.CongestionControlOption
- if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil {
- t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err)
- }
-
- got, want := cc, oldCC
- // If SetTransportProtocolOption is expected to succeed
- // then the returned value for congestion control should
- // match the one specified in the
- // SetTransportProtocolOption call above, else it should
- // be what it was before the call to
- // SetTransportProtocolOption.
- if tc.err == nil {
- want = tc.cc
- }
- if got != want {
- t.Fatalf("got congestion control: %v, want: %v", got, want)
- }
- })
- }
-}
-
-func TestStackAvailableCongestionControl(t *testing.T) {
- c := context.New(t, 1500)
- defer c.Cleanup()
-
- s := c.Stack()
-
- // Query permitted congestion control algorithms.
- var aCC tcpip.TCPAvailableCongestionControlOption
- if err := s.TransportProtocolOption(tcp.ProtocolNumber, &aCC); err != nil {
- t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &aCC, err)
- }
- if got, want := aCC, tcpip.TCPAvailableCongestionControlOption("reno cubic"); got != want {
- t.Fatalf("got tcpip.TCPAvailableCongestionControlOption: %v, want: %v", got, want)
- }
-}
-
-func TestStackSetAvailableCongestionControl(t *testing.T) {
- c := context.New(t, 1500)
- defer c.Cleanup()
-
- s := c.Stack()
-
- // Setting AvailableCongestionControlOption should fail.
- aCC := tcpip.TCPAvailableCongestionControlOption("xyz")
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &aCC); err == nil {
- t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%s)) = nil, want non-nil", tcp.ProtocolNumber, aCC, aCC)
- }
-
- // Verify that we still get the expected list of congestion control options.
- var cc tcpip.TCPAvailableCongestionControlOption
- if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil {
- t.Fatalf("s.TransportProtocolOptio(%d, &%T(%s)): %s", tcp.ProtocolNumber, cc, cc, err)
- }
- if got, want := cc, tcpip.TCPAvailableCongestionControlOption("reno cubic"); got != want {
- t.Fatalf("got tcpip.TCPAvailableCongestionControlOption = %s, want = %s", got, want)
- }
-}
-
-func TestEndpointSetCongestionControl(t *testing.T) {
- testCases := []struct {
- cc tcpip.CongestionControlOption
- err tcpip.Error
- }{
- {"reno", nil},
- {"cubic", nil},
- {"blahblah", &tcpip.ErrNoSuchFile{}},
- }
-
- for _, connected := range []bool{false, true} {
- for _, tc := range testCases {
- t.Run(fmt.Sprintf("SetSockOpt(.., %v) w/ connected = %v", tc.cc, connected), func(t *testing.T) {
- c := context.New(t, 1500)
- defer c.Cleanup()
-
- // Create TCP endpoint.
- var err tcpip.Error
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
-
- var oldCC tcpip.CongestionControlOption
- if err := c.EP.GetSockOpt(&oldCC); err != nil {
- t.Fatalf("c.EP.GetSockOpt(&%T) = %s", oldCC, err)
- }
-
- if connected {
- c.Connect(context.TestInitialSequenceNumber, 32768 /* rcvWnd */, nil)
- }
-
- if err := c.EP.SetSockOpt(&tc.cc); err != tc.err {
- t.Fatalf("got c.EP.SetSockOpt(&%#v) = %s, want %s", tc.cc, err, tc.err)
- }
-
- var cc tcpip.CongestionControlOption
- if err := c.EP.GetSockOpt(&cc); err != nil {
- t.Fatalf("c.EP.GetSockOpt(&%T): %s", cc, err)
- }
-
- got, want := cc, oldCC
- // If SetSockOpt is expected to succeed then the
- // returned value for congestion control should match
- // the one specified in the SetSockOpt above, else it
- // should be what it was before the call to SetSockOpt.
- if tc.err == nil {
- want = tc.cc
- }
- if got != want {
- t.Fatalf("got congestion control = %+v, want = %+v", got, want)
- }
- })
- }
- }
-}
-
-func enableCUBIC(t *testing.T, c *context.Context) {
- t.Helper()
- opt := tcpip.CongestionControlOption("cubic")
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)) %s", tcp.ProtocolNumber, opt, opt, err)
- }
-}
-
-func TestKeepalive(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- const keepAliveIdle = 100 * time.Millisecond
- const keepAliveInterval = 3 * time.Second
- keepAliveIdleOpt := tcpip.KeepaliveIdleOption(keepAliveIdle)
- if err := c.EP.SetSockOpt(&keepAliveIdleOpt); err != nil {
- t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIdleOpt, keepAliveIdle, err)
- }
- keepAliveIntervalOpt := tcpip.KeepaliveIntervalOption(keepAliveInterval)
- if err := c.EP.SetSockOpt(&keepAliveIntervalOpt); err != nil {
- t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIntervalOpt, keepAliveInterval, err)
- }
- c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5)
- if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5); err != nil {
- t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5): %s", err)
- }
- c.EP.SocketOptions().SetKeepAlive(true)
-
- // 5 unacked keepalives are sent. ACK each one, and check that the
- // connection stays alive after 5.
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- for i := 0; i < 10; i++ {
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-
- // Acknowledge the keepalive.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS,
- RcvWnd: 30000,
- })
- }
-
- // Check that the connection is still alive.
- ept := endpointTester{c.EP}
- ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
-
- // Send some data and wait before ACKing it. Keepalives should be disabled
- // during this period.
- view := make([]byte, 3)
- var r bytes.Reader
- r.Reset(view)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- next := uint32(c.IRS) + 1
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(len(view)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(next),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-
- // Wait for the packet to be retransmitted. Verify that no keepalives
- // were sent.
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(len(view)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(next),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh),
- ),
- )
- c.CheckNoPacket("Keepalive packet received while unACKed data is pending")
-
- next += uint32(len(view))
-
- // Send ACK. Keepalives should start sending again.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: seqnum.Value(next),
- RcvWnd: 30000,
- })
-
- // Now receive 5 keepalives, but don't ACK them. The connection
- // should be reset after 5.
- for i := 0; i < 5; i++ {
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(next-1),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
- }
-
- // Sleep for a litte over the KeepAlive interval to make sure
- // the timer has time to fire after the last ACK and close the
- // close the socket.
- time.Sleep(keepAliveInterval + keepAliveInterval/2)
-
- // The connection should be terminated after 5 unacked keepalives.
- // Send an ACK to trigger a RST from the stack as the endpoint should
- // be dead.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: seqnum.Value(next),
- RcvWnd: 30000,
- })
-
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(uint32(0)), checker.TCPFlags(header.TCPFlagRst)),
- )
-
- if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got)
- }
-
- ept.CheckReadError(t, &tcpip.ErrTimeout{})
-
- if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
- }
- if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
- }
-}
-
-func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) {
- t.Helper()
- // Send a SYN request.
- irs = seqnum.Value(context.TestInitialSequenceNumber)
- c.SendPacket(nil, &context.Headers{
- SrcPort: srcPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetPacket()
- tcp := header.TCP(header.IPv4(b).Payload())
- iss = seqnum.Value(tcp.SequenceNumber())
- tcpCheckers := []checker.TransportChecker{
- checker.SrcPort(context.StackPort),
- checker.DstPort(srcPort),
- checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
- checker.TCPAckNum(uint32(irs) + 1),
- }
-
- if synCookieInUse {
- // When cookies are in use window scaling is disabled.
- tcpCheckers = append(tcpCheckers, checker.TCPSynOptions(header.TCPSynOptions{
- WS: -1,
- MSS: c.MSSWithoutOptions(),
- }))
- }
-
- checker.IPv4(t, b, checker.TCP(tcpCheckers...))
-
- // Send ACK.
- c.SendPacket(nil, &context.Headers{
- SrcPort: srcPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs + 1,
- AckNum: iss + 1,
- RcvWnd: 30000,
- })
- return irs, iss
-}
-
-func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) {
- t.Helper()
- // Send a SYN request.
- irs = seqnum.Value(context.TestInitialSequenceNumber)
- c.SendV6Packet(nil, &context.Headers{
- SrcPort: srcPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetV6Packet()
- tcp := header.TCP(header.IPv6(b).Payload())
- iss = seqnum.Value(tcp.SequenceNumber())
- tcpCheckers := []checker.TransportChecker{
- checker.SrcPort(context.StackPort),
- checker.DstPort(srcPort),
- checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
- checker.TCPAckNum(uint32(irs) + 1),
- }
-
- if synCookieInUse {
- // When cookies are in use window scaling is disabled.
- tcpCheckers = append(tcpCheckers, checker.TCPSynOptions(header.TCPSynOptions{
- WS: -1,
- MSS: c.MSSWithoutOptionsV6(),
- }))
- }
-
- checker.IPv6(t, b, checker.TCP(tcpCheckers...))
-
- // Send ACK.
- c.SendV6Packet(nil, &context.Headers{
- SrcPort: srcPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs + 1,
- AckNum: iss + 1,
- RcvWnd: 30000,
- })
- return irs, iss
-}
-
-// TestListenBacklogFull tests that netstack does not complete handshakes if the
-// listen backlog for the endpoint is full.
-func TestListenBacklogFull(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create TCP endpoint.
- var err tcpip.Error
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
-
- // Bind to wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- // Test acceptance.
- // Start listening.
- listenBacklog := 10
- if err := c.EP.Listen(listenBacklog); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- lastPortOffset := uint16(0)
- for ; int(lastPortOffset) < listenBacklog; lastPortOffset++ {
- executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */)
- }
-
- time.Sleep(50 * time.Millisecond)
-
- // Now execute send one more SYN. The stack should not respond as the backlog
- // is full at this point.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort + lastPortOffset,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: seqnum.Value(context.TestInitialSequenceNumber),
- RcvWnd: 30000,
- })
- c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond)
-
- // Try to accept the connections in the backlog.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
-
- for i := 0; i < listenBacklog; i++ {
- _, _, err = c.EP.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- _, _, err = c.EP.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
- }
-
- // Now verify that there are no more connections that can be accepted.
- _, _, err = c.EP.Accept(nil)
- if !cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- select {
- case <-ch:
- t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP)
- case <-time.After(1 * time.Second):
- }
- }
-
- // Now a new handshake must succeed.
- executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */)
-
- newEP, _, err := c.EP.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- newEP, _, err = c.EP.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- // Now verify that the TCP socket is usable and in a connected state.
- data := "Don't panic"
- var r strings.Reader
- r.Reset(data)
- newEP.Write(&r, tcpip.WriteOptions{})
- b := c.GetPacket()
- tcp := header.TCP(header.IPv4(b).Payload())
- if string(tcp.Payload()) != data {
- t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data)
- }
-}
-
-// TestListenNoAcceptMulticastBroadcastV4 makes sure that TCP segments with a
-// non unicast IPv4 address are not accepted.
-func TestListenNoAcceptNonUnicastV4(t *testing.T) {
- multicastAddr := tcpiptestutil.MustParse4("224.0.1.2")
- otherMulticastAddr := tcpiptestutil.MustParse4("224.0.1.3")
- subnet := context.StackAddrWithPrefix.Subnet()
- subnetBroadcastAddr := subnet.Broadcast()
-
- tests := []struct {
- name string
- srcAddr tcpip.Address
- dstAddr tcpip.Address
- }{
- {
- name: "SourceUnspecified",
- srcAddr: header.IPv4Any,
- dstAddr: context.StackAddr,
- },
- {
- name: "SourceBroadcast",
- srcAddr: header.IPv4Broadcast,
- dstAddr: context.StackAddr,
- },
- {
- name: "SourceOurMulticast",
- srcAddr: multicastAddr,
- dstAddr: context.StackAddr,
- },
- {
- name: "SourceOtherMulticast",
- srcAddr: otherMulticastAddr,
- dstAddr: context.StackAddr,
- },
- {
- name: "DestUnspecified",
- srcAddr: context.TestAddr,
- dstAddr: header.IPv4Any,
- },
- {
- name: "DestBroadcast",
- srcAddr: context.TestAddr,
- dstAddr: header.IPv4Broadcast,
- },
- {
- name: "DestOurMulticast",
- srcAddr: context.TestAddr,
- dstAddr: multicastAddr,
- },
- {
- name: "DestOtherMulticast",
- srcAddr: context.TestAddr,
- dstAddr: otherMulticastAddr,
- },
- {
- name: "SrcSubnetBroadcast",
- srcAddr: subnetBroadcastAddr,
- dstAddr: context.StackAddr,
- },
- {
- name: "DestSubnetBroadcast",
- srcAddr: context.TestAddr,
- dstAddr: subnetBroadcastAddr,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.Create(-1)
-
- if err := c.Stack().JoinGroup(header.IPv4ProtocolNumber, 1, multicastAddr); err != nil {
- t.Fatalf("JoinGroup failed: %s", err)
- }
-
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- if err := c.EP.Listen(1); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- irs := seqnum.Value(context.TestInitialSequenceNumber)
- c.SendPacketWithAddrs(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- }, test.srcAddr, test.dstAddr)
- c.CheckNoPacket("Should not have received a response")
-
- // Handle normal packet.
- c.SendPacketWithAddrs(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- }, context.TestAddr, context.StackAddr)
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
- checker.TCPAckNum(uint32(irs)+1)))
- })
- }
-}
-
-// TestListenNoAcceptMulticastBroadcastV6 makes sure that TCP segments with a
-// non unicast IPv6 address are not accepted.
-func TestListenNoAcceptNonUnicastV6(t *testing.T) {
- multicastAddr := tcpiptestutil.MustParse6("ff0e::101")
- otherMulticastAddr := tcpiptestutil.MustParse6("ff0e::102")
-
- tests := []struct {
- name string
- srcAddr tcpip.Address
- dstAddr tcpip.Address
- }{
- {
- "SourceUnspecified",
- header.IPv6Any,
- context.StackV6Addr,
- },
- {
- "SourceAllNodes",
- header.IPv6AllNodesMulticastAddress,
- context.StackV6Addr,
- },
- {
- "SourceOurMulticast",
- multicastAddr,
- context.StackV6Addr,
- },
- {
- "SourceOtherMulticast",
- otherMulticastAddr,
- context.StackV6Addr,
- },
- {
- "DestUnspecified",
- context.TestV6Addr,
- header.IPv6Any,
- },
- {
- "DestAllNodes",
- context.TestV6Addr,
- header.IPv6AllNodesMulticastAddress,
- },
- {
- "DestOurMulticast",
- context.TestV6Addr,
- multicastAddr,
- },
- {
- "DestOtherMulticast",
- context.TestV6Addr,
- otherMulticastAddr,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(true)
-
- if err := c.Stack().JoinGroup(header.IPv6ProtocolNumber, 1, multicastAddr); err != nil {
- t.Fatalf("JoinGroup failed: %s", err)
- }
-
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- if err := c.EP.Listen(1); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- irs := seqnum.Value(context.TestInitialSequenceNumber)
- c.SendV6PacketWithAddrs(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- }, test.srcAddr, test.dstAddr)
- c.CheckNoPacket("Should not have received a response")
-
- // Handle normal packet.
- c.SendV6PacketWithAddrs(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- }, context.TestV6Addr, context.StackV6Addr)
- checker.IPv6(t, c.GetV6Packet(),
- checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
- checker.TCPAckNum(uint32(irs)+1)))
- })
- }
-}
-
-func TestListenSynRcvdQueueFull(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create TCP endpoint.
- var err tcpip.Error
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
-
- // Bind to wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- // Test acceptance.
- if err := c.EP.Listen(1); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Send two SYN's the first one should get a SYN-ACK, the
- // second one should not get any response and is dropped as
- // the accept queue is full.
- irs := seqnum.Value(context.TestInitialSequenceNumber)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetPacket()
- tcp := header.TCP(header.IPv4(b).Payload())
- iss := seqnum.Value(tcp.SequenceNumber())
- tcpCheckers := []checker.TransportChecker{
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
- checker.TCPAckNum(uint32(irs) + 1),
- }
- checker.IPv4(t, b, checker.TCP(tcpCheckers...))
-
- // Now complete the previous connection.
- // Send ACK.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs + 1,
- AckNum: iss + 1,
- RcvWnd: 30000,
- })
-
- // Verify if that is delivered to the accept queue.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
- <-ch
-
- // Now execute send one more SYN. The stack should not respond as the backlog
- // is full at this point.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort + 1,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: seqnum.Value(889),
- RcvWnd: 30000,
- })
- c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond)
-
- // Try to accept the connections in the backlog.
- newEP, _, err := c.EP.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- newEP, _, err = c.EP.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- // Now verify that the TCP socket is usable and in a connected state.
- data := "Don't panic"
- var r strings.Reader
- r.Reset(data)
- newEP.Write(&r, tcpip.WriteOptions{})
- pkt := c.GetPacket()
- tcp = header.IPv4(pkt).Payload()
- if string(tcp.Payload()) != data {
- t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data)
- }
-}
-
-func TestListenBacklogFullSynCookieInUse(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create TCP endpoint.
- var err tcpip.Error
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
-
- // Bind to wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- // Test for SynCookies usage after filling up the backlog.
- if err := c.EP.Listen(1); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- executeHandshake(t, c, context.TestPort, false)
-
- // Wait for this to be delivered to the accept queue.
- time.Sleep(50 * time.Millisecond)
-
- // Send a SYN request.
- irs := seqnum.Value(context.TestInitialSequenceNumber)
- c.SendPacket(nil, &context.Headers{
- // pick a different src port for new SYN.
- SrcPort: context.TestPort + 1,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- })
- // The Syn should be dropped as the endpoint's backlog is full.
- c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond)
-
- // Verify that there is only one acceptable connection at this point.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
-
- _, _, err = c.EP.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- _, _, err = c.EP.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- // Now verify that there are no more connections that can be accepted.
- _, _, err = c.EP.Accept(nil)
- if !cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- select {
- case <-ch:
- t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP)
- case <-time.After(1 * time.Second):
- }
- }
-}
-
-func TestSYNRetransmit(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create TCP endpoint.
- var err tcpip.Error
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
-
- // Bind to wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- // Start listening.
- if err := c.EP.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Send the same SYN packet multiple times. We should still get a valid SYN-ACK
- // reply.
- irs := seqnum.Value(context.TestInitialSequenceNumber)
- for i := 0; i < 5; i++ {
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- })
- }
-
- // Receive the SYN-ACK reply.
- tcpCheckers := []checker.TransportChecker{
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
- checker.TCPAckNum(uint32(irs) + 1),
- }
- checker.IPv4(t, c.GetPacket(), checker.TCP(tcpCheckers...))
-}
-
-func TestSynRcvdBadSeqNumber(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create TCP endpoint.
- var err tcpip.Error
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
-
- // Bind to wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- // Start listening.
- if err := c.EP.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Send a SYN to get a SYN-ACK. This should put the ep into SYN-RCVD state
- irs := seqnum.Value(context.TestInitialSequenceNumber)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetPacket()
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- iss := seqnum.Value(tcpHdr.SequenceNumber())
- tcpCheckers := []checker.TransportChecker{
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
- checker.TCPAckNum(uint32(irs) + 1),
- }
- checker.IPv4(t, b, checker.TCP(tcpCheckers...))
-
- // Now send a packet with an out-of-window sequence number
- largeSeqnum := irs + seqnum.Value(tcpHdr.WindowSize()) + 1
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: largeSeqnum,
- AckNum: iss + 1,
- RcvWnd: 30000,
- })
-
- // Should receive an ACK with the expected SEQ number
- b = c.GetPacket()
- tcpCheckers = []checker.TransportChecker{
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.TCPAckNum(uint32(irs) + 1),
- checker.TCPSeqNum(uint32(iss + 1)),
- }
- checker.IPv4(t, b, checker.TCP(tcpCheckers...))
-
- // Now that the socket replied appropriately with the ACK,
- // complete the connection to test that the large SEQ num
- // did not change the state from SYN-RCVD.
-
- // Get setup to be notified about connection establishment.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
-
- // Send ACK to move to ESTABLISHED state.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs + 1,
- AckNum: iss + 1,
- RcvWnd: 30000,
- })
-
- <-ch
- newEP, _, err := c.EP.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- // Now verify that the TCP socket is usable and in a connected state.
- data := "Don't panic"
- var r strings.Reader
- r.Reset(data)
- if _, err := newEP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- pkt := c.GetPacket()
- tcpHdr = header.IPv4(pkt).Payload()
- if string(tcpHdr.Payload()) != data {
- t.Fatalf("unexpected data: got %s, want %s", string(tcpHdr.Payload()), data)
- }
-}
-
-func TestPassiveConnectionAttemptIncrement(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- c.EP = ep
- if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
- if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want {
- t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
- }
- if err := c.EP.Listen(1); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateListen; got != want {
- t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
- }
-
- stats := c.Stack().Stats()
- want := stats.TCP.PassiveConnectionOpenings.Value() + 1
-
- srcPort := uint16(context.TestPort)
- executeHandshake(t, c, srcPort+1, false)
-
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
-
- // Verify that there is only one acceptable connection at this point.
- _, _, err = c.EP.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- _, _, err = c.EP.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- if got := stats.TCP.PassiveConnectionOpenings.Value(); got != want {
- t.Errorf("got stats.TCP.PassiveConnectionOpenings.Value() = %d, want = %d", got, want)
- }
-}
-
-func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- stats := c.Stack().Stats()
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- c.EP = ep
- if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
- if err := c.EP.Listen(1); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- srcPort := uint16(context.TestPort)
- // Now attempt a handshakes it will fill up the accept backlog.
- executeHandshake(t, c, srcPort, false)
-
- // Give time for the final ACK to be processed as otherwise the next handshake could
- // get accepted before the previous one based on goroutine scheduling.
- time.Sleep(50 * time.Millisecond)
-
- want := stats.TCP.ListenOverflowSynDrop.Value() + 1
-
- // Now we will send one more SYN and this one should get dropped
- // Send a SYN request.
- c.SendPacket(nil, &context.Headers{
- SrcPort: srcPort + 2,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: seqnum.Value(context.TestInitialSequenceNumber),
- RcvWnd: 30000,
- })
-
- checkValid := func() []error {
- var errors []error
- if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want {
- errors = append(errors, fmt.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %d, want = %d", got, want))
- }
- if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want {
- errors = append(errors, fmt.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %d, want = %d", got, want))
- }
- return errors
- }
-
- start := time.Now()
- for time.Since(start) < time.Minute && len(checkValid()) > 0 {
- time.Sleep(50 * time.Millisecond)
- }
- for _, err := range checkValid() {
- t.Error(err)
- }
- if t.Failed() {
- t.FailNow()
- }
-
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
-
- // Now check that there is one acceptable connections.
- _, _, err = c.EP.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- <-ch
- _, _, err = c.EP.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
- }
-}
-
-func TestListenDropIncrement(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- stats := c.Stack().Stats()
- c.Create(-1 /*epRcvBuf*/)
-
- if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
- if err := c.EP.Listen(1 /*backlog*/); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- initialDropped := stats.DroppedPackets.Value()
-
- // Send RST, FIN segments, that are expected to be dropped by the listener.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagRst,
- })
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagFin,
- })
-
- // To ensure that the RST, FIN sent earlier are indeed received and ignored
- // by the listener, send a SYN and wait for the SYN to be ACKd.
- irs := seqnum.Value(context.TestInitialSequenceNumber)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- })
- checker.IPv4(t, c.GetPacket(), checker.TCP(checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
- checker.TCPAckNum(uint32(irs)+1),
- ))
-
- if got, want := stats.DroppedPackets.Value(), initialDropped+2; got != want {
- t.Fatalf("got stats.DroppedPackets.Value() = %d, want = %d", got, want)
- }
-}
-
-func TestEndpointBindListenAcceptState(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
- wq := &waiter.Queue{}
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
-
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
- if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want {
- t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
- }
-
- ept := endpointTester{ep}
- ept.CheckReadError(t, &tcpip.ErrNotConnected{})
- if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 {
- t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %d want %d", got, 1)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
- if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
- t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
- }
-
- c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS}, 0 /* delay */)
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
-
- aep, _, err := ep.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- aep, _, err = ep.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
- if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want {
- t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
- }
- {
- err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
- if d := cmp.Diff(&tcpip.ErrAlreadyConnected{}, err); d != "" {
- t.Errorf("Connect(...) mismatch (-want +got):\n%s", d)
- }
- }
- // Listening endpoint remains in listen state.
- if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
- t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
- }
-
- ep.Close()
- // Give worker goroutines time to receive the close notification.
- time.Sleep(1 * time.Second)
- if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want {
- t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
- }
- // Accepted endpoint remains open when the listen endpoint is closed.
- if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want {
- t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
- }
-
-}
-
-// This test verifies that the auto tuning does not grow the receive buffer if
-// the application is not reading the data actively.
-func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
- const mtu = 1500
- const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize
-
- c := context.New(t, mtu)
- defer c.Cleanup()
-
- stk := c.Stack()
- // Set lower limits for auto-tuning tests. This is required because the
- // test stops the worker which can cause packets to be dropped because
- // the segment queue holding unprocessed packets is limited to 500.
- const receiveBufferSize = 80 << 10 // 80KB.
- const maxReceiveBufferSize = receiveBufferSize * 10
- {
- opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}
- if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
- }
- }
-
- // Enable auto-tuning.
- {
- opt := tcpip.TCPModerateReceiveBufferOption(true)
- if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
- }
- }
- // Change the expected window scale to match the value needed for the
- // maximum buffer size defined above.
- c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize))
-
- rawEP := c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{TS: true, WS: 4})
-
- // NOTE: The timestamp values in the sent packets are meaningless to the
- // peer so we just increment the timestamp value by 1 every batch as we
- // are not really using them for anything. Send a single byte to verify
- // the advertised window.
- tsVal := rawEP.TSVal + 1
-
- // Introduce a 25ms latency by delaying the first byte.
- latency := 25 * time.Millisecond
- time.Sleep(latency)
- // Send an initial payload with atleast segment overhead size. The receive
- // window would not grow for smaller segments.
- rawEP.SendPacketWithTS(make([]byte, tcp.SegSize), tsVal)
-
- pkt := rawEP.VerifyAndReturnACKWithTS(tsVal)
- rcvWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize()
-
- time.Sleep(25 * time.Millisecond)
-
- // Allocate a large enough payload for the test.
- payloadSize := receiveBufferSize * 2
- b := make([]byte, payloadSize)
-
- worker := (c.EP).(interface {
- StopWork()
- ResumeWork()
- })
- tsVal++
-
- // Stop the worker goroutine.
- worker.StopWork()
- start := 0
- end := payloadSize / 2
- packetsSent := 0
- for ; start < end; start += mss {
- packetEnd := start + mss
- if start+mss > end {
- packetEnd = end
- }
- rawEP.SendPacketWithTS(b[start:packetEnd], tsVal)
- packetsSent++
- }
-
- // Resume the worker so that it only sees the packets once all of them
- // are waiting to be read.
- worker.ResumeWork()
-
- // Since we sent almost the full receive buffer worth of data (some may have
- // been dropped due to segment overheads), we should get a zero window back.
- pkt = c.GetPacket()
- tcpHdr := header.TCP(header.IPv4(pkt).Payload())
- gotRcvWnd := tcpHdr.WindowSize()
- wantAckNum := tcpHdr.AckNumber()
- if got, want := int(gotRcvWnd), 0; got != want {
- t.Fatalf("got rcvWnd: %d, want: %d", got, want)
- }
-
- time.Sleep(25 * time.Millisecond)
- // Verify that sending more data when receiveBuffer is exhausted.
- rawEP.SendPacketWithTS(b[start:start+mss], tsVal)
-
- // Now read all the data from the endpoint and verify that advertised
- // window increases to the full available buffer size.
- for {
- _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{})
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- break
- }
- }
-
- // Verify that we receive a non-zero window update ACK. When running
- // under thread santizer this test can end up sending more than 1
- // ack, 1 for the non-zero window
- p := c.GetPacket()
- checker.IPv4(t, p, checker.TCP(
- checker.TCPAckNum(wantAckNum),
- func(t *testing.T, h header.Transport) {
- tcp, ok := h.(header.TCP)
- if !ok {
- return
- }
- // We use 10% here as the error margin upwards as the initial window we
- // got was afer 1 segment was already in the receive buffer queue.
- tolerance := 1.1
- if w := tcp.WindowSize(); w == 0 || w > uint16(float64(rcvWnd)*tolerance) {
- t.Errorf("expected a non-zero window: got %d, want <= %d", w, uint16(float64(rcvWnd)*tolerance))
- }
- },
- ))
-}
-
-// This test verifies that the advertised window is auto-tuned up as the
-// application is reading the data that is being received.
-func TestReceiveBufferAutoTuning(t *testing.T) {
- const mtu = 1500
- const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize
-
- c := context.New(t, mtu)
- defer c.Cleanup()
-
- // Enable Auto-tuning.
- stk := c.Stack()
- // Disable out of window rate limiting for this test by setting it to 0 as we
- // use out of window ACKs to measure the advertised window.
- var tcpInvalidRateLimit stack.TCPInvalidRateLimitOption
- if err := stk.SetOption(tcpInvalidRateLimit); err != nil {
- t.Fatalf("e.stack.SetOption(%#v) = %s", tcpInvalidRateLimit, err)
- }
-
- const receiveBufferSize = 80 << 10 // 80KB.
- const maxReceiveBufferSize = receiveBufferSize * 10
- {
- opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}
- if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
- }
- }
-
- // Enable auto-tuning.
- {
- opt := tcpip.TCPModerateReceiveBufferOption(true)
- if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
- }
- }
- // Change the expected window scale to match the value needed for the
- // maximum buffer size used by stack.
- c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize))
-
- rawEP := c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{TS: true, WS: 4})
- tsVal := rawEP.TSVal
- rawEP.NextSeqNum--
- rawEP.SendPacketWithTS(nil, tsVal)
- rawEP.NextSeqNum++
- pkt := rawEP.VerifyAndReturnACKWithTS(tsVal)
- curRcvWnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale
- scaleRcvWnd := func(rcvWnd int) uint16 {
- return uint16(rcvWnd >> c.WindowScale)
- }
- // Allocate a large array to send to the endpoint.
- b := make([]byte, receiveBufferSize*48)
-
- // In every iteration we will send double the number of bytes sent in
- // the previous iteration and read the same from the app. The received
- // window should grow by at least 2x of bytes read by the app in every
- // RTT.
- offset := 0
- payloadSize := receiveBufferSize / 8
- worker := (c.EP).(interface {
- StopWork()
- ResumeWork()
- })
- latency := 1 * time.Millisecond
- for i := 0; i < 5; i++ {
- tsVal++
-
- // Stop the worker goroutine.
- worker.StopWork()
- start := offset
- end := offset + payloadSize
- totalSent := 0
- packetsSent := 0
- for ; start < end; start += mss {
- rawEP.SendPacketWithTS(b[start:start+mss], tsVal)
- totalSent += mss
- packetsSent++
- }
-
- // Resume it so that it only sees the packets once all of them
- // are waiting to be read.
- worker.ResumeWork()
-
- // Give 1ms for the worker to process the packets.
- time.Sleep(1 * time.Millisecond)
-
- lastACK := c.GetPacket()
- // Discard any intermediate ACKs and only check the last ACK we get in a
- // short time period of few ms.
- for {
- time.Sleep(1 * time.Millisecond)
- pkt := c.GetPacketNonBlocking()
- if pkt == nil {
- break
- }
- lastACK = pkt
- }
- if got, want := int(header.TCP(header.IPv4(lastACK).Payload()).WindowSize()), int(scaleRcvWnd(curRcvWnd)); got > want {
- t.Fatalf("advertised window got: %d, want <= %d", got, want)
- }
-
- // Now read all the data from the endpoint and invoke the
- // moderation API to allow for receive buffer auto-tuning
- // to happen before we measure the new window.
- totalCopied := 0
- for {
- res, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{})
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- break
- }
- totalCopied += res.Count
- }
-
- // Invoke the moderation API. This is required for auto-tuning
- // to happen. This method is normally expected to be invoked
- // from a higher layer than tcpip.Endpoint. So we simulate
- // copying to userspace by invoking it explicitly here.
- c.EP.ModerateRecvBuf(totalCopied)
-
- // Now send a keep-alive packet to trigger an ACK so that we can
- // measure the new window.
- rawEP.NextSeqNum--
- rawEP.SendPacketWithTS(nil, tsVal)
- rawEP.NextSeqNum++
-
- if i == 0 {
- // In the first iteration the receiver based RTT is not
- // yet known as a result the moderation code should not
- // increase the advertised window.
- rawEP.VerifyACKRcvWnd(scaleRcvWnd(curRcvWnd))
- } else {
- // Read loop above could generate an ACK if the window had dropped to
- // zero and then read had opened it up.
- lastACK := c.GetPacket()
- // Discard any intermediate ACKs and only check the last ACK we get in a
- // short time period of few ms.
- for {
- time.Sleep(1 * time.Millisecond)
- pkt := c.GetPacketNonBlocking()
- if pkt == nil {
- break
- }
- lastACK = pkt
- }
- curRcvWnd = int(header.TCP(header.IPv4(lastACK).Payload()).WindowSize()) << c.WindowScale
- // If thew new current window is close maxReceiveBufferSize then terminate
- // the loop. This can happen before all iterations are done due to timing
- // differences when running the test.
- if int(float64(curRcvWnd)*1.1) > maxReceiveBufferSize/2 {
- break
- }
- // Increase the latency after first two iterations to
- // establish a low RTT value in the receiver since it
- // only tracks the lowest value. This ensures that when
- // ModerateRcvBuf is called the elapsed time is always >
- // rtt. Without this the test is flaky due to delays due
- // to scheduling/wakeup etc.
- latency += 50 * time.Millisecond
- }
- time.Sleep(latency)
- offset += payloadSize
- payloadSize *= 2
- }
- // Check that at the end of our iterations the receive window grew close to the maximum
- // permissible size of maxReceiveBufferSize/2
- if got, want := int(float64(curRcvWnd)*1.1), maxReceiveBufferSize/2; got < want {
- t.Fatalf("unexpected rcvWnd got: %d, want > %d", got, want)
- }
-
-}
-
-func TestDelayEnabled(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
- checkDelayOption(t, c, false, false) // Delay is disabled by default.
-
- for _, delayEnabled := range []bool{false, true} {
- t.Run(fmt.Sprintf("delayEnabled=%t", delayEnabled), func(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
- opt := tcpip.TCPDelayEnabled(delayEnabled)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, delayEnabled, err)
- }
- checkDelayOption(t, c, opt, delayEnabled)
- })
- }
-}
-
-func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcpip.TCPDelayEnabled, wantDelayOption bool) {
- t.Helper()
-
- var gotDelayEnabled tcpip.TCPDelayEnabled
- if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil {
- t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %s", err)
- }
- if gotDelayEnabled != wantDelayEnabled {
- t.Errorf("TransportProtocolOption(tcp, &gotDelayEnabled) got %t, want %t", gotDelayEnabled, wantDelayEnabled)
- }
-
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, new(waiter.Queue))
- if err != nil {
- t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %s", err)
- }
- gotDelayOption := ep.SocketOptions().GetDelayOption()
- if gotDelayOption != wantDelayOption {
- t.Errorf("ep.GetSockOptBool(tcpip.DelayOption) got: %t, want: %t", gotDelayOption, wantDelayOption)
- }
-}
-
-func TestTCPLingerTimeout(t *testing.T) {
- c := context.New(t, 1500 /* mtu */)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- testCases := []struct {
- name string
- tcpLingerTimeout time.Duration
- want time.Duration
- }{
- {"NegativeLingerTimeout", -123123, -1},
- // Zero is treated same as the stack's default TCP_LINGER2 timeout.
- {"ZeroLingerTimeout", 0, tcp.DefaultTCPLingerTimeout},
- {"InRangeLingerTimeout", 10 * time.Second, 10 * time.Second},
- // Values > stack's TCPLingerTimeout are capped to the stack's
- // value. Defaults to tcp.DefaultTCPLingerTimeout(60 seconds)
- {"AboveMaxLingerTimeout", tcp.MaxTCPLingerTimeout + 5*time.Second, tcp.MaxTCPLingerTimeout},
- }
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- v := tcpip.TCPLingerTimeoutOption(tc.tcpLingerTimeout)
- if err := c.EP.SetSockOpt(&v); err != nil {
- t.Fatalf("SetSockOpt(&%T(%s)) = %s", v, tc.tcpLingerTimeout, err)
- }
-
- v = 0
- if err := c.EP.GetSockOpt(&v); err != nil {
- t.Fatalf("GetSockOpt(&%T) = %s", v, err)
- }
- if got, want := time.Duration(v), tc.want; got != want {
- t.Fatalf("got linger timeout = %s, want = %s", got, want)
- }
- })
- }
-}
-
-func TestTCPTimeWaitRSTIgnored(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- wq := &waiter.Queue{}
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Send a SYN request.
- iss := seqnum.Value(context.TestInitialSequenceNumber)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: iss,
- RcvWnd: 30000,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetPacket()
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
-
- ackHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 1,
- AckNum: c.IRS + 1,
- }
-
- // Send ACK.
- c.SendPacket(nil, ackHeaders)
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
-
- c.EP, _, err = ep.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- c.EP, _, err = ep.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- c.EP.Close()
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS+1)),
- checker.TCPAckNum(uint32(iss)+1),
- checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
-
- finHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: iss + 1,
- AckNum: c.IRS + 2,
- }
-
- c.SendPacket(nil, finHeaders)
-
- // Get the ACK to the FIN we just sent.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS+2)),
- checker.TCPAckNum(uint32(iss)+2),
- checker.TCPFlags(header.TCPFlagAck)))
-
- // Now send a RST and this should be ignored and not
- // generate an ACK.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagRst,
- SeqNum: iss + 1,
- AckNum: c.IRS + 2,
- })
-
- c.CheckNoPacketTimeout("unexpected packet received in TIME_WAIT state", 1*time.Second)
-
- // Out of order ACK should generate an immediate ACK in
- // TIME_WAIT.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 1,
- AckNum: c.IRS + 3,
- })
-
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS+2)),
- checker.TCPAckNum(uint32(iss)+2),
- checker.TCPFlags(header.TCPFlagAck)))
-}
-
-func TestTCPTimeWaitOutOfOrder(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- wq := &waiter.Queue{}
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Send a SYN request.
- iss := seqnum.Value(context.TestInitialSequenceNumber)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: iss,
- RcvWnd: 30000,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetPacket()
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
-
- ackHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 1,
- AckNum: c.IRS + 1,
- }
-
- // Send ACK.
- c.SendPacket(nil, ackHeaders)
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
-
- c.EP, _, err = ep.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- c.EP, _, err = ep.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- c.EP.Close()
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS+1)),
- checker.TCPAckNum(uint32(iss)+1),
- checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
-
- finHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: iss + 1,
- AckNum: c.IRS + 2,
- }
-
- c.SendPacket(nil, finHeaders)
-
- // Get the ACK to the FIN we just sent.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS+2)),
- checker.TCPAckNum(uint32(iss)+2),
- checker.TCPFlags(header.TCPFlagAck)))
-
- // Out of order ACK should generate an immediate ACK in
- // TIME_WAIT.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 1,
- AckNum: c.IRS + 3,
- })
-
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS+2)),
- checker.TCPAckNum(uint32(iss)+2),
- checker.TCPFlags(header.TCPFlagAck)))
-}
-
-func TestTCPTimeWaitNewSyn(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- wq := &waiter.Queue{}
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Send a SYN request.
- iss := seqnum.Value(context.TestInitialSequenceNumber)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: iss,
- RcvWnd: 30000,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetPacket()
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
-
- ackHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 1,
- AckNum: c.IRS + 1,
- }
-
- // Send ACK.
- c.SendPacket(nil, ackHeaders)
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
-
- c.EP, _, err = ep.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- c.EP, _, err = ep.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- c.EP.Close()
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS+1)),
- checker.TCPAckNum(uint32(iss)+1),
- checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
-
- finHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: iss + 1,
- AckNum: c.IRS + 2,
- }
-
- c.SendPacket(nil, finHeaders)
-
- // Get the ACK to the FIN we just sent.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS+2)),
- checker.TCPAckNum(uint32(iss)+2),
- checker.TCPFlags(header.TCPFlagAck)))
-
- // Send a SYN request w/ sequence number lower than
- // the highest sequence number sent. We just reuse
- // the same number.
- iss = seqnum.Value(context.TestInitialSequenceNumber)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: iss,
- RcvWnd: 30000,
- })
-
- c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second)
-
- // drain any older notifications from the notification channel before attempting
- // 2nd connection.
- select {
- case <-ch:
- default:
- }
-
- // Send a SYN request w/ sequence number higher than
- // the highest sequence number sent.
- iss = iss.Add(3)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: iss,
- RcvWnd: 30000,
- })
-
- // Receive the SYN-ACK reply.
- b = c.GetPacket()
- tcpHdr = header.IPv4(b).Payload()
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
-
- ackHeaders = &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 1,
- AckNum: c.IRS + 1,
- }
-
- // Send ACK.
- c.SendPacket(nil, ackHeaders)
-
- // Try to accept the connection.
- c.EP, _, err = ep.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- c.EP, _, err = ep.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-}
-
-func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed
- // after 5 seconds in TIME_WAIT state.
- tcpTimeWaitTimeout := 5 * time.Second
- opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, tcpTimeWaitTimeout, err)
- }
-
- want := c.Stack().Stats().TCP.EstablishedClosed.Value() + 1
-
- wq := &waiter.Queue{}
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Send a SYN request.
- iss := seqnum.Value(context.TestInitialSequenceNumber)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: iss,
- RcvWnd: 30000,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetPacket()
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
-
- ackHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 1,
- AckNum: c.IRS + 1,
- }
-
- // Send ACK.
- c.SendPacket(nil, ackHeaders)
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
-
- c.EP, _, err = ep.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- c.EP, _, err = ep.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- c.EP.Close()
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS+1)),
- checker.TCPAckNum(uint32(iss)+1),
- checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
-
- finHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: iss + 1,
- AckNum: c.IRS + 2,
- }
-
- c.SendPacket(nil, finHeaders)
-
- // Get the ACK to the FIN we just sent.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS+2)),
- checker.TCPAckNum(uint32(iss)+2),
- checker.TCPFlags(header.TCPFlagAck)))
-
- time.Sleep(2 * time.Second)
-
- // Now send a duplicate FIN. This should cause the TIME_WAIT to extend
- // by another 5 seconds and also send us a duplicate ACK as it should
- // indicate that the final ACK was potentially lost.
- c.SendPacket(nil, finHeaders)
-
- // Get the ACK to the FIN we just sent.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS+2)),
- checker.TCPAckNum(uint32(iss)+2),
- checker.TCPFlags(header.TCPFlagAck)))
-
- // Sleep for 4 seconds so at this point we are 1 second past the
- // original tcpLingerTimeout of 5 seconds.
- time.Sleep(4 * time.Second)
-
- // Send an ACK and it should not generate any packet as the socket
- // should still be in TIME_WAIT for another another 5 seconds due
- // to the duplicate FIN we sent earlier.
- *ackHeaders = *finHeaders
- ackHeaders.SeqNum = ackHeaders.SeqNum + 1
- ackHeaders.Flags = header.TCPFlagAck
- c.SendPacket(nil, ackHeaders)
-
- c.CheckNoPacketTimeout("unexpected packet received from endpoint in TIME_WAIT", 1*time.Second)
- // Now sleep for another 2 seconds so that we are past the
- // extended TIME_WAIT of 7 seconds (2 + 5).
- time.Sleep(2 * time.Second)
-
- // Resend the same ACK.
- c.SendPacket(nil, ackHeaders)
-
- // Receive the RST that should be generated as there is no valid
- // endpoint.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(ackHeaders.AckNum)),
- checker.TCPAckNum(0),
- checker.TCPFlags(header.TCPFlagRst)))
-
- if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %d, want = %d", got, want)
- }
- if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
- }
-}
-
-func TestTCPCloseWithData(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed
- // after 5 seconds in TIME_WAIT state.
- tcpTimeWaitTimeout := 5 * time.Second
- opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, tcpTimeWaitTimeout, err)
- }
-
- wq := &waiter.Queue{}
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Send a SYN request.
- iss := seqnum.Value(context.TestInitialSequenceNumber)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: iss,
- RcvWnd: 30000,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetPacket()
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
-
- ackHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 1,
- AckNum: c.IRS + 1,
- RcvWnd: 30000,
- }
-
- // Send ACK.
- c.SendPacket(nil, ackHeaders)
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
-
- c.EP, _, err = ep.Accept(nil)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- c.EP, _, err = ep.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- // Now trigger a passive close by sending a FIN.
- finHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: iss + 1,
- AckNum: c.IRS + 2,
- RcvWnd: 30000,
- }
-
- c.SendPacket(nil, finHeaders)
-
- // Get the ACK to the FIN we just sent.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS+1)),
- checker.TCPAckNum(uint32(iss)+2),
- checker.TCPFlags(header.TCPFlagAck)))
-
- // Now write a few bytes and then close the endpoint.
- data := []byte{1, 2, 3}
-
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Check that data is received.
- b = c.GetPacket()
- checker.IPv4(t, b,
- checker.PayloadLen(len(data)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+2), // Acknum is initial sequence number + 1
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-
- if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
- t.Errorf("got data = %x, want = %x", p, data)
- }
-
- c.EP.Close()
- // Check the FIN.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS+1)+uint32(len(data))),
- checker.TCPAckNum(uint32(iss+2)),
- checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
-
- // First send a partial ACK.
- ackHeaders = &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 2,
- AckNum: c.IRS + 1 + seqnum.Value(len(data)-1),
- RcvWnd: 30000,
- }
- c.SendPacket(nil, ackHeaders)
-
- // Now send a full ACK.
- ackHeaders = &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 2,
- AckNum: c.IRS + 1 + seqnum.Value(len(data)),
- RcvWnd: 30000,
- }
- c.SendPacket(nil, ackHeaders)
-
- // Now ACK the FIN.
- ackHeaders.AckNum++
- c.SendPacket(nil, ackHeaders)
-
- // Now send an ACK and we should get a RST back as the endpoint should
- // be in CLOSED state.
- ackHeaders = &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 2,
- AckNum: c.IRS + 1 + seqnum.Value(len(data)),
- RcvWnd: 30000,
- }
- c.SendPacket(nil, ackHeaders)
-
- // Check the RST.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(ackHeaders.AckNum)),
- checker.TCPAckNum(0),
- checker.TCPFlags(header.TCPFlagRst)))
-}
-
-func TestTCPUserTimeout(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- initRTO := 1 * time.Second
- minRTOOpt := tcpip.TCPMinRTOOption(initRTO)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err)
- }
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- waitEntry, notifyCh := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&waitEntry, waiter.EventHUp)
- defer c.WQ.EventUnregister(&waitEntry)
-
- origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value()
-
- // Ensure that on the next retransmit timer fire, the user timeout has
- // expired.
- userTimeout := initRTO / 2
- v := tcpip.TCPUserTimeoutOption(userTimeout)
- if err := c.EP.SetSockOpt(&v); err != nil {
- t.Fatalf("c.EP.SetSockOpt(&%T(%s): %s", v, userTimeout, err)
- }
-
- // Send some data and wait before ACKing it.
- view := make([]byte, 3)
- var r bytes.Reader
- r.Reset(view)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- next := uint32(c.IRS) + 1
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(len(view)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(next),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-
- // Wait for the retransmit timer to be fired and the user timeout to cause
- // close of the connection.
- select {
- case <-notifyCh:
- case <-time.After(2 * initRTO):
- t.Fatalf("connection still alive after %s, should have been closed after %s", 2*initRTO, userTimeout)
- }
-
- // No packet should be received as the connection should be silently
- // closed due to timeout.
- c.CheckNoPacket("unexpected packet received after userTimeout has expired")
-
- next += uint32(len(view))
-
- // The connection should be terminated after userTimeout has expired.
- // Send an ACK to trigger a RST from the stack as the endpoint should
- // be dead.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: seqnum.Value(next),
- RcvWnd: 30000,
- })
-
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(next),
- checker.TCPAckNum(uint32(0)),
- checker.TCPFlags(header.TCPFlagRst),
- ),
- )
-
- ept := endpointTester{c.EP}
- ept.CheckReadError(t, &tcpip.ErrTimeout{})
-
- if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want)
- }
- if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
- }
-}
-
-func TestKeepaliveWithUserTimeout(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value()
-
- const keepAliveIdle = 100 * time.Millisecond
- const keepAliveInterval = 3 * time.Second
- keepAliveIdleOption := tcpip.KeepaliveIdleOption(keepAliveIdle)
- if err := c.EP.SetSockOpt(&keepAliveIdleOption); err != nil {
- t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIdleOption, keepAliveIdle, err)
- }
- keepAliveIntervalOption := tcpip.KeepaliveIntervalOption(keepAliveInterval)
- if err := c.EP.SetSockOpt(&keepAliveIntervalOption); err != nil {
- t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIntervalOption, keepAliveInterval, err)
- }
- if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10); err != nil {
- t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10): %s", err)
- }
- c.EP.SocketOptions().SetKeepAlive(true)
-
- // Set userTimeout to be the duration to be 1 keepalive
- // probes. Which means that after the first probe is sent
- // the second one should cause the connection to be
- // closed due to userTimeout being hit.
- userTimeout := tcpip.TCPUserTimeoutOption(keepAliveInterval)
- if err := c.EP.SetSockOpt(&userTimeout); err != nil {
- t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", userTimeout, keepAliveInterval, err)
- }
-
- // Check that the connection is still alive.
- ept := endpointTester{c.EP}
- ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
-
- // Now receive 1 keepalives, but don't ACK it.
- b := c.GetPacket()
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)),
- checker.TCPAckNum(uint32(iss)),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-
- // Sleep for a litte over the KeepAlive interval to make sure
- // the timer has time to fire after the last ACK and close the
- // close the socket.
- time.Sleep(keepAliveInterval + keepAliveInterval/2)
-
- // The connection should be closed with a timeout.
- // Send an ACK to trigger a RST from the stack as the endpoint should
- // be dead.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS + 1,
- RcvWnd: 30000,
- })
-
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS+1)),
- checker.TCPAckNum(uint32(0)),
- checker.TCPFlags(header.TCPFlagRst),
- ),
- )
-
- ept.CheckReadError(t, &tcpip.ErrTimeout{})
- if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want)
- }
- if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
- }
-}
-
-func TestIncreaseWindowOnRead(t *testing.T) {
- // This test ensures that the endpoint sends an ack,
- // after read() when the window grows by more than 1 MSS.
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- const rcvBuf = 65535 * 10
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBuf)
-
- // Write chunks of ~30000 bytes. It's important that two
- // payloads make it equal or longer than MSS.
- remain := rcvBuf * 2
- sent := 0
- data := make([]byte, defaultMTU/2)
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- for remain > len(data) {
- c.SendPacket(data, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss.Add(seqnum.Size(sent)),
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
- sent += len(data)
- remain -= len(data)
- pkt := c.GetPacket()
- checker.IPv4(t, pkt,
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+uint32(sent)),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
- // Break once the window drops below defaultMTU/2
- if wnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize(); wnd < defaultMTU/2 {
- break
- }
- }
-
- // We now have < 1 MSS in the buffer space. Read at least > 2 MSS
- // worth of data as receive buffer space
- w := tcpip.LimitedWriter{
- W: ioutil.Discard,
- // defaultMTU is a good enough estimate for the MSS used for this
- // connection.
- N: defaultMTU * 2,
- }
- for w.N != 0 {
- _, err := c.EP.Read(&w, tcpip.ReadOptions{})
- if err != nil {
- t.Fatalf("Read failed: %s", err)
- }
- }
-
- // After reading > MSS worth of data, we surely crossed MSS. See the ack:
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+uint32(sent)),
- checker.TCPWindow(uint16(0xffff)),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-}
-
-func TestIncreaseWindowOnBufferResize(t *testing.T) {
- // This test ensures that the endpoint sends an ack,
- // after available recv buffer grows to more than 1 MSS.
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- const rcvBuf = 65535 * 10
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, rcvBuf)
-
- // Write chunks of ~30000 bytes. It's important that two
- // payloads make it equal or longer than MSS.
- remain := rcvBuf
- sent := 0
- data := make([]byte, defaultMTU/2)
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- for remain > len(data) {
- c.SendPacket(data, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss.Add(seqnum.Size(sent)),
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
- sent += len(data)
- remain -= len(data)
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+uint32(sent)),
- checker.TCPWindowLessThanEq(0xffff),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
- }
-
- // Increasing the buffer from should generate an ACK,
- // since window grew from small value to larger equal MSS
- c.EP.SocketOptions().SetReceiveBufferSize(rcvBuf*4, true /* notify */)
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+uint32(sent)),
- checker.TCPWindow(uint16(0xffff)),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-}
-
-func TestTCPDeferAccept(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.Create(-1)
-
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatal("Bind failed:", err)
- }
-
- if err := c.EP.Listen(10); err != nil {
- t.Fatal("Listen failed:", err)
- }
-
- const tcpDeferAccept = 1 * time.Second
- tcpDeferAcceptOption := tcpip.TCPDeferAcceptOption(tcpDeferAccept)
- if err := c.EP.SetSockOpt(&tcpDeferAcceptOption); err != nil {
- t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", tcpDeferAcceptOption, tcpDeferAccept, err)
- }
-
- irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
-
- _, _, err := c.EP.Accept(nil)
- if d := cmp.Diff(&tcpip.ErrWouldBlock{}, err); d != "" {
- t.Fatalf("c.EP.Accept(nil) mismatch (-want +got):\n%s", d)
- }
-
- // Send data. This should result in an acceptable endpoint.
- c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs + 1,
- AckNum: iss + 1,
- })
-
- // Receive ACK for the data we sent.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.TCPSeqNum(uint32(iss+1)),
- checker.TCPAckNum(uint32(irs+5))))
-
- // Give a bit of time for the socket to be delivered to the accept queue.
- time.Sleep(50 * time.Millisecond)
- aep, _, err := c.EP.Accept(nil)
- if err != nil {
- t.Fatalf("got c.EP.Accept(nil) = %s, want: nil", err)
- }
-
- aep.Close()
- // Closing aep without reading the data should trigger a RST.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
- checker.TCPSeqNum(uint32(iss+1)),
- checker.TCPAckNum(uint32(irs+5))))
-}
-
-func TestTCPDeferAcceptTimeout(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.Create(-1)
-
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatal("Bind failed:", err)
- }
-
- if err := c.EP.Listen(10); err != nil {
- t.Fatal("Listen failed:", err)
- }
-
- const tcpDeferAccept = 1 * time.Second
- tcpDeferAcceptOpt := tcpip.TCPDeferAcceptOption(tcpDeferAccept)
- if err := c.EP.SetSockOpt(&tcpDeferAcceptOpt); err != nil {
- t.Fatalf("c.EP.SetSockOpt(&%T(%s)) failed: %s", tcpDeferAcceptOpt, tcpDeferAccept, err)
- }
-
- irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
-
- _, _, err := c.EP.Accept(nil)
- if d := cmp.Diff(&tcpip.ErrWouldBlock{}, err); d != "" {
- t.Fatalf("c.EP.Accept(nil) mismatch (-want +got):\n%s", d)
- }
-
- // Sleep for a little of the tcpDeferAccept timeout.
- time.Sleep(tcpDeferAccept + 100*time.Millisecond)
-
- // On timeout expiry we should get a SYN-ACK retransmission.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
- checker.TCPAckNum(uint32(irs)+1)))
-
- // Send data. This should result in an acceptable endpoint.
- c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs + 1,
- AckNum: iss + 1,
- })
-
- // Receive ACK for the data we sent.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.TCPSeqNum(uint32(iss+1)),
- checker.TCPAckNum(uint32(irs+5))))
-
- // Give sometime for the endpoint to be delivered to the accept queue.
- time.Sleep(50 * time.Millisecond)
- aep, _, err := c.EP.Accept(nil)
- if err != nil {
- t.Fatalf("got c.EP.Accept(nil) = %s, want: nil", err)
- }
-
- aep.Close()
- // Closing aep without reading the data should trigger a RST.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
- checker.TCPSeqNum(uint32(iss+1)),
- checker.TCPAckNum(uint32(irs+5))))
-}
-
-func TestResetDuringClose(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRecvBuf */)
- // Send some data to make sure there is some unread
- // data to trigger a reset on c.Close.
- irs := c.IRS
- iss := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
- c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss,
- AckNum: irs.Add(1),
- RcvWnd: 30000,
- })
-
- // Receive ACK for the data we sent.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.TCPSeqNum(uint32(irs.Add(1))),
- checker.TCPAckNum(uint32(iss)+4)))
-
- // Close in a separate goroutine so that we can trigger
- // a race with the RST we send below. This should not
- // panic due to the route being released depeding on
- // whether Close() sends an active RST or the RST sent
- // below is processed by the worker first.
- var wg sync.WaitGroup
-
- wg.Add(1)
- go func() {
- defer wg.Done()
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- SeqNum: iss.Add(4),
- AckNum: c.IRS.Add(5),
- RcvWnd: 30000,
- Flags: header.TCPFlagRst,
- })
- }()
-
- wg.Add(1)
- go func() {
- defer wg.Done()
- c.EP.Close()
- }()
-
- wg.Wait()
-}
-
-func TestStackTimeWaitReuse(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- s := c.Stack()
- var twReuse tcpip.TCPTimeWaitReuseOption
- if err := s.TransportProtocolOption(tcp.ProtocolNumber, &twReuse); err != nil {
- t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &twReuse, err)
- }
- if got, want := twReuse, tcpip.TCPTimeWaitReuseLoopbackOnly; got != want {
- t.Fatalf("got tcpip.TCPTimeWaitReuseOption: %v, want: %v", got, want)
- }
-}
-
-func TestSetStackTimeWaitReuse(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- s := c.Stack()
- testCases := []struct {
- v int
- err tcpip.Error
- }{
- {int(tcpip.TCPTimeWaitReuseDisabled), nil},
- {int(tcpip.TCPTimeWaitReuseGlobal), nil},
- {int(tcpip.TCPTimeWaitReuseLoopbackOnly), nil},
- {int(tcpip.TCPTimeWaitReuseLoopbackOnly) + 1, &tcpip.ErrInvalidOptionValue{}},
- {int(tcpip.TCPTimeWaitReuseDisabled) - 1, &tcpip.ErrInvalidOptionValue{}},
- }
-
- for _, tc := range testCases {
- opt := tcpip.TCPTimeWaitReuseOption(tc.v)
- err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt)
- if got, want := err, tc.err; got != want {
- t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%d)) = %s, want = %s", tcp.ProtocolNumber, tc.v, tc.v, err, tc.err)
- }
- if tc.err != nil {
- continue
- }
-
- var twReuse tcpip.TCPTimeWaitReuseOption
- if err := s.TransportProtocolOption(tcp.ProtocolNumber, &twReuse); err != nil {
- t.Fatalf("s.TransportProtocolOption(%v, %v) = %v, want nil", tcp.ProtocolNumber, &twReuse, err)
- }
-
- if got, want := twReuse, tcpip.TCPTimeWaitReuseOption(tc.v); got != want {
- t.Fatalf("got tcpip.TCPTimeWaitReuseOption: %v, want: %v", got, want)
- }
- }
-}
-
-func TestHandshakeRTT(t *testing.T) {
- type testCase struct {
- connect bool
- tsEnabled bool
- useCookie bool
- retrans bool
- delay time.Duration
- wantRTT time.Duration
- }
- var testCases []testCase
- for _, connect := range []bool{false, true} {
- for _, tsEnabled := range []bool{false, true} {
- for _, useCookie := range []bool{false, true} {
- for _, retrans := range []bool{false, true} {
- if connect && useCookie {
- continue
- }
- delay := 800 * time.Millisecond
- if retrans {
- delay = 1200 * time.Millisecond
- }
- wantRTT := delay
- // If syncookie is enabled, sample RTT only when TS option is enabled.
- if !retrans && useCookie && !tsEnabled {
- wantRTT = 0
- }
- // If retransmitted, sample RTT only when TS option is enabled.
- if retrans && !tsEnabled {
- wantRTT = 0
- }
- testCases = append(testCases, testCase{connect, tsEnabled, useCookie, retrans, delay, wantRTT})
- }
- }
- }
- }
- for _, tt := range testCases {
- tt := tt
- t.Run(fmt.Sprintf("connect=%t,TS=%t,cookie=%t,retrans=%t)", tt.connect, tt.tsEnabled, tt.useCookie, tt.retrans), func(t *testing.T) {
- t.Parallel()
- c := context.New(t, defaultMTU)
- if tt.useCookie {
- opt := tcpip.TCPAlwaysUseSynCookies(true)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
- }
- }
- synOpts := header.TCPSynOptions{}
- if tt.tsEnabled {
- synOpts.TS = true
- synOpts.TSVal = 42
- }
- if tt.connect {
- c.CreateConnectedWithOptions(synOpts, tt.delay)
- } else {
- synOpts.MSS = defaultIPv4MSS
- synOpts.WS = -1
- c.AcceptWithOptions(-1, synOpts, tt.delay)
- }
- var info tcpip.TCPInfoOption
- if err := c.EP.GetSockOpt(&info); err != nil {
- t.Fatalf("c.EP.GetSockOpt(&%T) = %s", info, err)
- }
- if got := info.RTT.Round(tt.wantRTT); got != tt.wantRTT {
- t.Fatalf("got info.RTT=%s, expect %s", got, tt.wantRTT)
- }
- if info.RTTVar != 0 && tt.wantRTT == 0 {
- t.Fatalf("got info.RTTVar=%s, expect 0", info.RTTVar)
- }
- if info.RTTVar == 0 && tt.wantRTT != 0 {
- t.Fatalf("got info.RTTVar=0, expect non zero")
- }
- })
- }
-}
-
-func TestSetRTO(t *testing.T) {
- c := context.New(t, defaultMTU)
- minRTO, maxRTO := tcpRTOMinMax(t, c)
- for _, tt := range []struct {
- name string
- RTO time.Duration
- minRTO time.Duration
- maxRTO time.Duration
- err tcpip.Error
- }{
- {
- name: "invalid minRTO",
- minRTO: maxRTO + time.Second,
- err: &tcpip.ErrInvalidOptionValue{},
- },
- {
- name: "invalid maxRTO",
- maxRTO: minRTO - time.Millisecond,
- err: &tcpip.ErrInvalidOptionValue{},
- },
- {
- name: "valid minRTO",
- minRTO: maxRTO - time.Second,
- },
- {
- name: "valid maxRTO",
- maxRTO: minRTO + time.Millisecond,
- },
- } {
- t.Run(tt.name, func(t *testing.T) {
- c := context.New(t, defaultMTU)
- var opt tcpip.SettableTransportProtocolOption
- if tt.minRTO > 0 {
- min := tcpip.TCPMinRTOOption(tt.minRTO)
- opt = &min
- }
- if tt.maxRTO > 0 {
- max := tcpip.TCPMaxRTOOption(tt.maxRTO)
- opt = &max
- }
- err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt)
- if got, want := err, tt.err; got != want {
- t.Fatalf("c.Stack().SetTransportProtocolOption(TCP, &%T(%v)) = %v, want = %v", opt, opt, got, want)
- }
- if tt.err == nil {
- minRTO, maxRTO := tcpRTOMinMax(t, c)
- if tt.minRTO > 0 && tt.minRTO != minRTO {
- t.Fatalf("got minRTO = %s, want %s", minRTO, tt.minRTO)
- }
- if tt.maxRTO > 0 && tt.maxRTO != maxRTO {
- t.Fatalf("got maxRTO = %s, want %s", maxRTO, tt.maxRTO)
- }
- }
- })
- }
-}
-
-func tcpRTOMinMax(t *testing.T, c *context.Context) (time.Duration, time.Duration) {
- t.Helper()
- var minOpt tcpip.TCPMinRTOOption
- var maxOpt tcpip.TCPMaxRTOOption
- if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &minOpt); err != nil {
- t.Fatalf("c.Stack().TransportProtocolOption(TCP, %T): %s", minOpt, err)
- }
- if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &maxOpt); err != nil {
- t.Fatalf("c.Stack().TransportProtocolOption(TCP, %T): %s", maxOpt, err)
- }
- return time.Duration(minOpt), time.Duration(maxOpt)
-}
-
-// generateRandomPayload generates a random byte slice of the specified length
-// causing a fatal test failure if it is unable to do so.
-func generateRandomPayload(t *testing.T, n int) []byte {
- t.Helper()
- buf := make([]byte, n)
- if _, err := rand.Read(buf); err != nil {
- t.Fatalf("rand.Read(buf) failed: %s", err)
- }
- return buf
-}
-
-func TestSendBufferTuning(t *testing.T) {
- const maxPayload = 536
- const mtu = header.TCPMinimumSize + header.IPv4MinimumSize + maxTCPOptionSize + maxPayload
- const packetOverheadFactor = 2
-
- testCases := []struct {
- name string
- autoTuningDisabled bool
- }{
- {"autoTuningDisabled", true},
- {"autoTuningEnabled", false},
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- c := context.New(t, mtu)
- defer c.Cleanup()
-
- // Set the stack option for send buffer size.
- const defaultSndBufSz = maxPayload * tcp.InitialCwnd
- const maxSndBufSz = defaultSndBufSz * 10
- {
- opt := tcpip.TCPSendBufferSizeRangeOption{Min: 1, Default: defaultSndBufSz, Max: maxSndBufSz}
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
- }
- }
-
- c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
-
- oldSz := c.EP.SocketOptions().GetSendBufferSize()
- if oldSz != defaultSndBufSz {
- t.Fatalf("Wrong send buffer size got %d want %d", oldSz, defaultSndBufSz)
- }
-
- if tc.autoTuningDisabled {
- c.EP.SocketOptions().SetSendBufferSize(defaultSndBufSz, true /* notify */)
- }
-
- data := make([]byte, maxPayload)
- for i := range data {
- data[i] = byte(i)
- }
-
- w, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&w, waiter.WritableEvents)
- defer c.WQ.EventUnregister(&w)
-
- bytesRead := 0
- for {
- // Packets will be sent till the send buffer
- // size is reached.
- var r bytes.Reader
- r.Reset(data[bytesRead : bytesRead+maxPayload])
- _, err := c.EP.Write(&r, tcpip.WriteOptions{})
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- break
- }
-
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, 0)
- bytesRead += maxPayload
- data = append(data, data...)
- }
-
- // Send an ACK and wait for connection to become writable again.
- c.SendAck(seqnum.Value(context.TestInitialSequenceNumber).Add(1), bytesRead)
- select {
- case <-ch:
- if err := c.EP.LastError(); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for connection")
- }
-
- outSz := int64(defaultSndBufSz)
- if !tc.autoTuningDisabled {
- // Calculate the new auto tuned send buffer.
- var info tcpip.TCPInfoOption
- if err := c.EP.GetSockOpt(&info); err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
- }
- outSz = int64(info.SndCwnd) * packetOverheadFactor * maxPayload
- }
-
- if newSz := c.EP.SocketOptions().GetSendBufferSize(); newSz != outSz {
- t.Fatalf("Wrong send buffer size, got %d want %d", newSz, outSz)
- }
- })
- }
-}
-
-func TestTimestampSynCookies(t *testing.T) {
- clock := faketime.NewManualClock()
- tsNow := func() uint32 {
- return uint32(clock.NowMonotonic().Sub(tcpip.MonotonicTime{}).Milliseconds())
- }
- // Advance the clock so that NowMonotonic is non-zero.
- clock.Advance(time.Second)
- c := context.NewWithOpts(t, context.Options{
- EnableV4: true,
- EnableV6: true,
- MTU: defaultMTU,
- Clock: clock,
- })
- defer c.Cleanup()
- opt := tcpip.TCPAlwaysUseSynCookies(true)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
- }
- wq := &waiter.Queue{}
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- defer ep.Close()
-
- tcpOpts := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
- header.EncodeTSOption(42, 0, tcpOpts[2:])
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
- iss := seqnum.Value(context.TestInitialSequenceNumber)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- RcvWnd: seqnum.Size(512),
- SeqNum: iss,
- TCPOpts: tcpOpts[:],
- })
- // Get the TSVal of SYN-ACK.
- b := c.GetPacket()
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
- initialTSVal := tcpHdr.ParsedOptions().TSVal
- // derive the tsOffset.
- tsOffset := initialTSVal - tsNow()
-
- header.EncodeTSOption(420, initialTSVal, tcpOpts[2:])
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- RcvWnd: seqnum.Size(512),
- SeqNum: iss + 1,
- AckNum: c.IRS + 1,
- TCPOpts: tcpOpts[:],
- })
- c.EP, _, err = ep.Accept(nil)
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Wait for connection to be established.
- select {
- case <-ch:
- c.EP, _, err = ep.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- } else if err != nil {
- t.Fatalf("failed to accept: %s", err)
- }
-
- // Advance the clock again so that we expect the next TSVal to change.
- clock.Advance(time.Second)
- data := []byte{1, 2, 3}
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // The endpoint should have a correct TSOffset so that the received TSVal
- // should match our expectation.
- if got, want := header.TCP(header.IPv4(c.GetPacket()).Payload()).ParsedOptions().TSVal, tsNow()+tsOffset; got != want {
- t.Fatalf("got TSVal = %d, want %d", got, want)
- }
-}
diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
deleted file mode 100644
index 65925daa5..000000000
--- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
+++ /dev/null
@@ -1,311 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tcp_test
-
-import (
- "bytes"
- "math/rand"
- "testing"
- "time"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-// createConnectedWithTimestampOption creates and connects c.ep with the
-// timestamp option enabled.
-func createConnectedWithTimestampOption(c *context.Context) *context.RawEndpoint {
- return c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{TS: true, TSVal: 1})
-}
-
-// TestTimeStampEnabledConnect tests that netstack sends the timestamp option on
-// an active connect and sets the TS Echo Reply fields correctly when the
-// SYN-ACK also indicates support for the TS option and provides a TSVal.
-func TestTimeStampEnabledConnect(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- rep := createConnectedWithTimestampOption(c)
-
- // Register for read and validate that we have data to read.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
-
- // The following tests ensure that TS option once enabled behaves
- // correctly as described in
- // https://tools.ietf.org/html/rfc7323#section-4.3.
- //
- // We are not testing delayed ACKs here, but we do test out of order
- // packet delivery and filling the sequence number hole created due to
- // the out of order packet.
- //
- // The test also verifies that the sequence numbers and timestamps are
- // as expected.
- data := []byte{1, 2, 3}
-
- // First we increment tsVal by a small amount.
- tsVal := rep.TSVal + 100
- rep.SendPacketWithTS(data, tsVal)
- rep.VerifyACKWithTS(tsVal)
-
- // Next we send an out of order packet.
- rep.NextSeqNum += 3
- tsVal += 200
- rep.SendPacketWithTS(data, tsVal)
-
- // The ACK should contain the original sequenceNumber and an older TS.
- rep.NextSeqNum -= 6
- rep.VerifyACKWithTS(tsVal - 200)
-
- // Next we fill the hole and the returned ACK should contain the
- // cumulative sequence number acking all data sent till now and have the
- // latest timestamp sent below in its TSEcr field.
- tsVal -= 100
- rep.SendPacketWithTS(data, tsVal)
- rep.NextSeqNum += 3
- rep.VerifyACKWithTS(tsVal)
-
- // Increment tsVal by a large value that doesn't result in a wrap around.
- tsVal += 0x7fffffff
- rep.SendPacketWithTS(data, tsVal)
- rep.VerifyACKWithTS(tsVal)
-
- // Increment tsVal again by a large value which should cause the
- // timestamp value to wrap around. The returned ACK should contain the
- // wrapped around timestamp in its tsEcr field and not the tsVal from
- // the previous packet sent above.
- tsVal += 0x7fffffff
- rep.SendPacketWithTS(data, tsVal)
- rep.VerifyACKWithTS(tsVal)
-
- select {
- case <-ch:
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for data to arrive")
- }
-
- // There should be 5 views to read and each of them should
- // contain the same data.
- for i := 0; i < 5; i++ {
- buf := make([]byte, len(data))
- w := tcpip.SliceWriter(buf)
- result, err := c.EP.Read(&w, tcpip.ReadOptions{})
- if err != nil {
- t.Fatalf("Unexpected error from Read: %v", err)
- }
- if diff := cmp.Diff(tcpip.ReadResult{
- Count: len(buf),
- Total: len(buf),
- }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
- t.Errorf("Read: unexpected result (-want +got):\n%s", diff)
- }
- if got, want := buf, data; bytes.Compare(got, want) != 0 {
- t.Fatalf("Data is different: got: %v, want: %v", got, want)
- }
- }
-}
-
-// TestTimeStampDisabledConnect tests that netstack sends timestamp option on an
-// active connect but if the SYN-ACK doesn't specify the TS option then
-// timestamp option is not enabled and future packets do not contain a
-// timestamp.
-func TestTimeStampDisabledConnect(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{})
-}
-
-func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- if cookieEnabled {
- opt := tcpip.TCPAlwaysUseSynCookies(true)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
- }
- }
-
- t.Logf("Test w/ CookieEnabled = %v", cookieEnabled)
- tsVal := rand.Uint32()
- c.AcceptWithOptionsNoDelay(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, TS: true, TSVal: tsVal})
-
- // Now send some data and validate that timestamp is echoed correctly in the ACK.
- data := []byte{1, 2, 3}
-
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Unexpected error from Write: %s", err)
- }
-
- // Check that data is received and that the timestamp option TSEcr field
- // matches the expected value.
- b := c.GetPacket()
- checker.IPv4(t, b,
- // Add 12 bytes for the timestamp option + 2 NOPs to align at 4
- // byte boundary.
- checker.PayloadLen(len(data)+header.TCPMinimumSize+12),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
- checker.TCPWindow(wndSize),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- checker.TCPTimestampChecker(true, 0, tsVal+1),
- ),
- )
-}
-
-// TestTimeStampEnabledAccept tests that if the SYN on a passive connect
-// specifies the Timestamp option then the Timestamp option is sent on a SYN-ACK
-// and echoes the tsVal field of the original SYN in the tcEcr field of the
-// SYN-ACK. We cover the cases where SYN cookies are enabled/disabled and verify
-// that Timestamp option is enabled in both cases if requested in the original
-// SYN.
-func TestTimeStampEnabledAccept(t *testing.T) {
- testCases := []struct {
- cookieEnabled bool
- wndScale int
- wndSize uint16
- }{
- {true, -1, 0xffff}, // When cookie is used window scaling is disabled.
- // DefaultReceiveBufferSize is 1MB >> 5. Advertised window will be 1/2 of that.
- {false, 5, 0x4000},
- }
- for _, tc := range testCases {
- timeStampEnabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize)
- }
-}
-
-func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- if cookieEnabled {
- opt := tcpip.TCPAlwaysUseSynCookies(true)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
- }
- }
-
- t.Logf("Test w/ CookieEnabled = %v", cookieEnabled)
- c.AcceptWithOptionsNoDelay(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
-
- // Now send some data with the accepted connection endpoint and validate
- // that no timestamp option is sent in the TCP segment.
- data := []byte{1, 2, 3}
-
- var r bytes.Reader
- r.Reset(data)
- if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Unexpected error from Write: %s", err)
- }
-
- // Check that data is received and that the timestamp option is disabled
- // when SYN cookies are enabled/disabled.
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.PayloadLen(len(data)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(790),
- checker.TCPWindow(wndSize),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- checker.TCPTimestampChecker(false, 0, 0),
- ),
- )
-}
-
-// TestTimeStampDisabledAccept tests that Timestamp option is not used when the
-// peer doesn't advertise it and connection is established with Accept().
-func TestTimeStampDisabledAccept(t *testing.T) {
- testCases := []struct {
- cookieEnabled bool
- wndScale int
- wndSize uint16
- }{
- {true, -1, 0xffff}, // When cookie is used window scaling is disabled.
- // DefaultReceiveBufferSize is 1MB >> 5. Advertised window will be half of
- // that.
- {false, 5, 0x4000},
- }
- for _, tc := range testCases {
- timeStampDisabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize)
- }
-}
-
-func TestSendGreaterThanMTUWithOptions(t *testing.T) {
- const maxPayload = 100
- c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
- defer c.Cleanup()
-
- createConnectedWithTimestampOption(c)
- testBrokenUpWrite(t, c, maxPayload)
-}
-
-func TestSegmentNotDroppedWhenTimestampMissing(t *testing.T) {
- const maxPayload = 100
- c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
- defer c.Cleanup()
-
- rep := createConnectedWithTimestampOption(c)
-
- // Register for read.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
-
- droppedPacketsStat := c.Stack().Stats().DroppedPackets
- droppedPackets := droppedPacketsStat.Value()
- data := []byte{1, 2, 3}
- // Send a packet with no TCP options/timestamp.
- rep.SendPacket(data, nil)
-
- select {
- case <-ch:
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for data to arrive")
- }
-
- // Assert that DroppedPackets was not incremented.
- if got, want := droppedPacketsStat.Value(), droppedPackets; got != want {
- t.Fatalf("incorrect number of dropped packets, got: %v, want: %v", got, want)
- }
-
- // Issue a read and we should data.
- var buf bytes.Buffer
- result, err := c.EP.Read(&buf, tcpip.ReadOptions{})
- if err != nil {
- t.Fatalf("Unexpected error from Read: %v", err)
- }
- if diff := cmp.Diff(tcpip.ReadResult{
- Count: buf.Len(),
- Total: buf.Len(),
- }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
- t.Errorf("Read: unexpected result (-want +got):\n%s", diff)
- }
- if got, want := buf.Bytes(), data; bytes.Compare(got, want) != 0 {
- t.Fatalf("Data is different: got: %v, want: %v", got, want)
- }
-}
diff --git a/pkg/tcpip/transport/tcp/tcp_unsafe_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_unsafe_state_autogen.go
new file mode 100644
index 000000000..4cb82fcc9
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/tcp_unsafe_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package tcp
diff --git a/pkg/tcpip/transport/tcp/testing/context/BUILD b/pkg/tcpip/transport/tcp/testing/context/BUILD
deleted file mode 100644
index ce6a2c31d..000000000
--- a/pkg/tcpip/transport/tcp/testing/context/BUILD
+++ /dev/null
@@ -1,26 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "context",
- testonly = 1,
- srcs = ["context.go"],
- visibility = [
- "//visibility:public",
- ],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/checker",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/link/sniffer",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/seqnum",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/transport/tcp",
- "//pkg/waiter",
- ],
-)
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
deleted file mode 100644
index 88bb99354..000000000
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ /dev/null
@@ -1,1268 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package context provides a test context for use in tcp tests. It also
-// provides helper methods to assert/check certain behaviours.
-package context
-
-import (
- "bytes"
- "context"
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/seqnum"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-const (
- // StackAddr is the IPv4 address assigned to the stack.
- StackAddr = "\x0a\x00\x00\x01"
-
- // StackPort is used as the listening port in tests for passive
- // connects.
- StackPort = 1234
-
- // TestAddr is the source address for packets sent to the stack via the
- // link layer endpoint.
- TestAddr = "\x0a\x00\x00\x02"
-
- // TestPort is the TCP port used for packets sent to the stack
- // via the link layer endpoint.
- TestPort = 4096
-
- // StackV6Addr is the IPv6 address assigned to the stack.
- StackV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
-
- // TestV6Addr is the source address for packets sent to the stack via
- // the link layer endpoint.
- TestV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
-
- // StackV4MappedAddr is StackAddr as a mapped v6 address.
- StackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + StackAddr
-
- // TestV4MappedAddr is TestAddr as a mapped v6 address.
- TestV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + TestAddr
-
- // V4MappedWildcardAddr is the mapped v6 representation of 0.0.0.0.
- V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"
-
- // TestInitialSequenceNumber is the initial sequence number sent in packets that
- // are sent in response to a SYN or in the initial SYN sent to the stack.
- TestInitialSequenceNumber = 789
-)
-
-// StackAddrWithPrefix is StackAddr with its associated prefix length.
-var StackAddrWithPrefix = tcpip.AddressWithPrefix{
- Address: StackAddr,
- PrefixLen: 24,
-}
-
-// StackV6AddrWithPrefix is StackV6Addr with its associated prefix length.
-var StackV6AddrWithPrefix = tcpip.AddressWithPrefix{
- Address: StackV6Addr,
- PrefixLen: header.IIDOffsetInIPv6Address * 8,
-}
-
-// Headers is used to represent the TCP header fields when building a
-// new packet.
-type Headers struct {
- // SrcPort holds the src port value to be used in the packet.
- SrcPort uint16
-
- // DstPort holds the destination port value to be used in the packet.
- DstPort uint16
-
- // SeqNum is the value of the sequence number field in the TCP header.
- SeqNum seqnum.Value
-
- // AckNum represents the acknowledgement number field in the TCP header.
- AckNum seqnum.Value
-
- // Flags are the TCP flags in the TCP header.
- Flags header.TCPFlags
-
- // RcvWnd is the window to be advertised in the ReceiveWindow field of
- // the TCP header.
- RcvWnd seqnum.Size
-
- // TCPOpts holds the options to be sent in the option field of the TCP
- // header.
- TCPOpts []byte
-}
-
-// Options contains options for creating a new test context.
-type Options struct {
- // EnableV4 indicates whether IPv4 should be enabled.
- EnableV4 bool
-
- // EnableV6 indicates whether IPv4 should be enabled.
- EnableV6 bool
-
- // MTU indicates the maximum transmission unit on the link layer.
- MTU uint32
-
- // Clock that is used by Stack.
- Clock tcpip.Clock
-}
-
-// Context provides an initialized Network stack and a link layer endpoint
-// for use in TCP tests.
-type Context struct {
- t *testing.T
- linkEP *channel.Endpoint
- s *stack.Stack
-
- // IRS holds the initial sequence number in the SYN sent by endpoint in
- // case of an active connect or the sequence number sent by the endpoint
- // in the SYN-ACK sent in response to a SYN when listening in passive
- // mode.
- IRS seqnum.Value
-
- // Port holds the port bound by EP below in case of an active connect or
- // the listening port number in case of a passive connect.
- Port uint16
-
- // EP is the test endpoint in the stack owned by this context. This endpoint
- // is used in various tests to either initiate an active connect or is used
- // as a passive listening endpoint to accept inbound connections.
- EP tcpip.Endpoint
-
- // Wq is the wait queue associated with EP and is used to block for events
- // on EP.
- WQ waiter.Queue
-
- // TimeStampEnabled is true if ep is connected with the timestamp option
- // enabled.
- TimeStampEnabled bool
-
- // WindowScale is the expected window scale in SYN packets sent by
- // the stack.
- WindowScale uint8
-
- // RcvdWindowScale is the actual window scale sent by the stack in
- // SYN/SYN-ACK.
- RcvdWindowScale uint8
-}
-
-// New allocates and initializes a test context containing a new
-// stack and a link-layer endpoint.
-func New(t *testing.T, mtu uint32) *Context {
- return NewWithOpts(t, Options{
- EnableV4: true,
- EnableV6: true,
- MTU: mtu,
- })
-}
-
-// NewWithOpts allocates and initializes a test context containing a new
-// stack and a link-layer endpoint with specific options.
-func NewWithOpts(t *testing.T, opts Options) *Context {
- if opts.MTU == 0 {
- panic("MTU must be greater than 0")
- }
-
- stackOpts := stack.Options{
- TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
- Clock: opts.Clock,
- }
- if opts.EnableV4 {
- stackOpts.NetworkProtocols = append(stackOpts.NetworkProtocols, ipv4.NewProtocol)
- }
- if opts.EnableV6 {
- stackOpts.NetworkProtocols = append(stackOpts.NetworkProtocols, ipv6.NewProtocol)
- }
- s := stack.New(stackOpts)
-
- const sendBufferSize = 1 << 20 // 1 MiB
- const recvBufferSize = 1 << 20 // 1 MiB
- // Allow minimum send/receive buffer sizes to be 1 during tests.
- sendBufOpt := tcpip.TCPSendBufferSizeRangeOption{Min: 1, Default: sendBufferSize, Max: 10 * sendBufferSize}
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &sendBufOpt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%#v) failed: %s", tcp.ProtocolNumber, sendBufOpt, err)
- }
-
- rcvBufOpt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: recvBufferSize, Max: 10 * recvBufferSize}
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &rcvBufOpt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%#v) failed: %s", tcp.ProtocolNumber, rcvBufOpt, err)
- }
-
- // Increase minimum RTO in tests to avoid test flakes due to early
- // retransmit in case the test executors are overloaded and cause timers
- // to fire earlier than expected.
- minRTOOpt := tcpip.TCPMinRTOOption(3 * time.Second)
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil {
- t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err)
- }
-
- // Some of the congestion control tests send up to 640 packets, we so
- // set the channel size to 1000.
- ep := channel.New(1000, opts.MTU, "")
- wep := stack.LinkEndpoint(ep)
- if testing.Verbose() {
- wep = sniffer.New(ep)
- }
- nicOpts := stack.NICOptions{Name: "nic1"}
- if err := s.CreateNICWithOptions(1, wep, nicOpts); err != nil {
- t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err)
- }
- wep2 := stack.LinkEndpoint(channel.New(1000, opts.MTU, ""))
- if testing.Verbose() {
- wep2 = sniffer.New(channel.New(1000, opts.MTU, ""))
- }
- opts2 := stack.NICOptions{Name: "nic2"}
- if err := s.CreateNICWithOptions(2, wep2, opts2); err != nil {
- t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err)
- }
-
- var routeTable []tcpip.Route
-
- if opts.EnableV4 {
- v4ProtocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: StackAddrWithPrefix,
- }
- if err := s.AddProtocolAddress(1, v4ProtocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(1, %+v, {}): %s", v4ProtocolAddr, err)
- }
- routeTable = append(routeTable, tcpip.Route{
- Destination: header.IPv4EmptySubnet,
- NIC: 1,
- })
- }
-
- if opts.EnableV6 {
- v6ProtocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: StackV6AddrWithPrefix,
- }
- if err := s.AddProtocolAddress(1, v6ProtocolAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(1, %+v, {}): %s", v6ProtocolAddr, err)
- }
- routeTable = append(routeTable, tcpip.Route{
- Destination: header.IPv6EmptySubnet,
- NIC: 1,
- })
- }
-
- s.SetRouteTable(routeTable)
-
- return &Context{
- t: t,
- s: s,
- linkEP: ep,
- WindowScale: uint8(tcp.FindWndScale(recvBufferSize)),
- }
-}
-
-// Cleanup closes the context endpoint if required.
-func (c *Context) Cleanup() {
- if c.EP != nil {
- c.EP.Close()
- }
- c.Stack().Close()
-}
-
-// Stack returns a reference to the stack in the Context.
-func (c *Context) Stack() *stack.Stack {
- return c.s
-}
-
-// CheckNoPacketTimeout verifies that no packet is received during the time
-// specified by wait.
-func (c *Context) CheckNoPacketTimeout(errMsg string, wait time.Duration) {
- c.t.Helper()
-
- ctx, cancel := context.WithTimeout(context.Background(), wait)
- defer cancel()
- if _, ok := c.linkEP.ReadContext(ctx); ok {
- c.t.Fatal(errMsg)
- }
-}
-
-// CheckNoPacket verifies that no packet is received for 1 second.
-func (c *Context) CheckNoPacket(errMsg string) {
- c.CheckNoPacketTimeout(errMsg, 1*time.Second)
-}
-
-// GetPacketWithTimeout reads a packet from the link layer endpoint and verifies
-// that it is an IPv4 packet with the expected source and destination
-// addresses. If no packet is received in the specified timeout it will return
-// nil.
-func (c *Context) GetPacketWithTimeout(timeout time.Duration) []byte {
- c.t.Helper()
-
- ctx, cancel := context.WithTimeout(context.Background(), timeout)
- defer cancel()
- p, ok := c.linkEP.ReadContext(ctx)
- if !ok {
- return nil
- }
-
- if p.Proto != ipv4.ProtocolNumber {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
- }
-
- // Just check that the stack set the transport protocol number for outbound
- // TCP messages.
- // TODO(gvisor.dev/issues/3810): Remove when protocol numbers are part
- // of the headerinfo.
- if p.Pkt.TransportProtocolNumber != tcp.ProtocolNumber {
- c.t.Fatalf("got p.Pkt.TransportProtocolNumber = %d, want = %d", p.Pkt.TransportProtocolNumber, tcp.ProtocolNumber)
- }
-
- vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
- b := vv.ToView()
-
- if p.Pkt.GSOOptions.Type != stack.GSONone && p.Pkt.GSOOptions.L3HdrLen != header.IPv4MinimumSize {
- c.t.Errorf("got L3HdrLen = %d, want = %d", p.Pkt.GSOOptions.L3HdrLen, header.IPv4MinimumSize)
- }
-
- checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
- return b
-}
-
-// GetPacket reads a packet from the link layer endpoint and verifies
-// that it is an IPv4 packet with the expected source and destination
-// addresses.
-func (c *Context) GetPacket() []byte {
- c.t.Helper()
-
- p := c.GetPacketWithTimeout(5 * time.Second)
- if p == nil {
- c.t.Fatalf("Packet wasn't written out")
- return nil
- }
-
- return p
-}
-
-// GetPacketNonBlocking reads a packet from the link layer endpoint
-// and verifies that it is an IPv4 packet with the expected source
-// and destination address. If no packet is available it will return
-// nil immediately.
-func (c *Context) GetPacketNonBlocking() []byte {
- c.t.Helper()
-
- p, ok := c.linkEP.Read()
- if !ok {
- return nil
- }
-
- if p.Proto != ipv4.ProtocolNumber {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
- }
-
- // Just check that the stack set the transport protocol number for outbound
- // TCP messages.
- // TODO(gvisor.dev/issues/3810): Remove when protocol numbers are part
- // of the headerinfo.
- if p.Pkt.TransportProtocolNumber != tcp.ProtocolNumber {
- c.t.Fatalf("got p.Pkt.TransportProtocolNumber = %d, want = %d", p.Pkt.TransportProtocolNumber, tcp.ProtocolNumber)
- }
-
- vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
- b := vv.ToView()
-
- checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
- return b
-}
-
-// SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint.
-func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code header.ICMPv4Code, p1, p2 []byte, maxTotalSize int) {
- // Allocate a buffer data and headers.
- buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p2))
- if len(buf) > maxTotalSize {
- buf = buf[:maxTotalSize]
- }
-
- ip := header.IPv4(buf)
- ip.Encode(&header.IPv4Fields{
- TotalLength: uint16(len(buf)),
- TTL: 65,
- Protocol: uint8(header.ICMPv4ProtocolNumber),
- SrcAddr: TestAddr,
- DstAddr: StackAddr,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
-
- icmp := header.ICMPv4(buf[header.IPv4MinimumSize:])
- icmp.SetType(typ)
- icmp.SetCode(code)
- const icmpv4VariableHeaderOffset = 4
- copy(icmp[icmpv4VariableHeaderOffset:], p1)
- copy(icmp[header.ICMPv4PayloadOffset:], p2)
- icmp.SetChecksum(0)
- checksum := ^header.Checksum(icmp, 0 /* initial */)
- icmp.SetChecksum(checksum)
-
- // Inject packet.
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- })
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt)
-}
-
-// BuildSegment builds a TCP segment based on the given Headers and payload.
-func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView {
- return c.BuildSegmentWithAddrs(payload, h, TestAddr, StackAddr)
-}
-
-// BuildSegmentWithAddrs builds a TCP segment based on the given Headers,
-// payload and source and destination IPv4 addresses.
-func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) buffer.VectorisedView {
- // Allocate a buffer for data and headers.
- buf := buffer.NewView(header.TCPMinimumSize + header.IPv4MinimumSize + len(h.TCPOpts) + len(payload))
- copy(buf[len(buf)-len(payload):], payload)
- copy(buf[len(buf)-len(payload)-len(h.TCPOpts):], h.TCPOpts)
-
- // Initialize the IP header.
- ip := header.IPv4(buf)
- ip.Encode(&header.IPv4Fields{
- TotalLength: uint16(len(buf)),
- TTL: 65,
- Protocol: uint8(tcp.ProtocolNumber),
- SrcAddr: src,
- DstAddr: dst,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
-
- // Initialize the TCP header.
- t := header.TCP(buf[header.IPv4MinimumSize:])
- t.Encode(&header.TCPFields{
- SrcPort: h.SrcPort,
- DstPort: h.DstPort,
- SeqNum: uint32(h.SeqNum),
- AckNum: uint32(h.AckNum),
- DataOffset: uint8(header.TCPMinimumSize + len(h.TCPOpts)),
- Flags: h.Flags,
- WindowSize: uint16(h.RcvWnd),
- })
-
- // Calculate the TCP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t)))
-
- // Calculate the TCP checksum and set it.
- xsum = header.Checksum(payload, xsum)
- t.SetChecksum(^t.CalculateChecksum(xsum))
-
- // Inject packet.
- return buf.ToVectorisedView()
-}
-
-// SendSegment sends a TCP segment that has already been built and written to a
-// buffer.VectorisedView.
-func (c *Context) SendSegment(s buffer.VectorisedView) {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: s,
- })
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt)
-}
-
-// SendPacket builds and sends a TCP segment(with the provided payload & TCP
-// headers) in an IPv4 packet via the link layer endpoint.
-func (c *Context) SendPacket(payload []byte, h *Headers) {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: c.BuildSegment(payload, h),
- })
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt)
-}
-
-// SendPacketWithAddrs builds and sends a TCP segment(with the provided payload
-// & TCPheaders) in an IPv4 packet via the link layer endpoint using the
-// provided source and destination IPv4 addresses.
-func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: c.BuildSegmentWithAddrs(payload, h, src, dst),
- })
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt)
-}
-
-// SendAck sends an ACK packet.
-func (c *Context) SendAck(seq seqnum.Value, bytesReceived int) {
- c.SendAckWithSACK(seq, bytesReceived, nil)
-}
-
-// SendAckWithSACK sends an ACK packet which includes the sackBlocks specified.
-func (c *Context) SendAckWithSACK(seq seqnum.Value, bytesReceived int, sackBlocks []header.SACKBlock) {
- options := make([]byte, 40)
- offset := 0
- if len(sackBlocks) > 0 {
- offset += header.EncodeNOP(options[offset:])
- offset += header.EncodeNOP(options[offset:])
- offset += header.EncodeSACKBlocks(sackBlocks, options[offset:])
- }
-
- c.SendPacket(nil, &Headers{
- SrcPort: TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: seq,
- AckNum: c.IRS.Add(1 + seqnum.Size(bytesReceived)),
- RcvWnd: 30000,
- TCPOpts: options[:offset],
- })
-}
-
-// ReceiveAndCheckPacket reads a packet from the link layer endpoint and
-// verifies that the packet packet payload of packet matches the slice
-// of data indicated by offset & size.
-func (c *Context) ReceiveAndCheckPacket(data []byte, offset, size int) {
- c.t.Helper()
-
- c.ReceiveAndCheckPacketWithOptions(data, offset, size, 0)
-}
-
-// ReceiveAndCheckPacketWithOptions reads a packet from the link layer endpoint
-// and verifies that the packet packet payload of packet matches the slice of
-// data indicated by offset & size and skips optlen bytes in addition to the IP
-// TCP headers when comparing the data.
-func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, optlen int) {
- c.t.Helper()
-
- b := c.GetPacket()
- checker.IPv4(c.t, b,
- checker.PayloadLen(size+header.TCPMinimumSize+optlen),
- checker.TCP(
- checker.DstPort(TestPort),
- checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
- checker.TCPAckNum(uint32(seqnum.Value(TestInitialSequenceNumber).Add(1))),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-
- pdata := data[offset:][:size]
- if p := b[header.IPv4MinimumSize+header.TCPMinimumSize+optlen:]; bytes.Compare(pdata, p) != 0 {
- c.t.Fatalf("Data is different: expected %v, got %v", pdata, p)
- }
-}
-
-// ReceiveNonBlockingAndCheckPacket reads a packet from the link layer endpoint
-// and verifies that the packet packet payload of packet matches the slice of
-// data indicated by offset & size. It returns true if a packet was received and
-// processed.
-func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int) bool {
- c.t.Helper()
-
- b := c.GetPacketNonBlocking()
- if b == nil {
- return false
- }
- checker.IPv4(c.t, b,
- checker.PayloadLen(size+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(TestPort),
- checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
- checker.TCPAckNum(uint32(seqnum.Value(TestInitialSequenceNumber).Add(1))),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
- ),
- )
-
- pdata := data[offset:][:size]
- if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; bytes.Compare(pdata, p) != 0 {
- c.t.Fatalf("Data is different: expected %v, got %v", pdata, p)
- }
- return true
-}
-
-// CreateV6Endpoint creates and initializes c.ep as a IPv6 Endpoint. If v6Only
-// is true then it sets the IP_V6ONLY option on the socket to make it a IPv6
-// only endpoint instead of a default dual stack socket.
-func (c *Context) CreateV6Endpoint(v6only bool) {
- var err tcpip.Error
- c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &c.WQ)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- c.EP.SocketOptions().SetV6Only(v6only)
-}
-
-// GetV6Packet reads a single packet from the link layer endpoint of the context
-// and asserts that it is an IPv6 Packet with the expected src/dest addresses.
-func (c *Context) GetV6Packet() []byte {
- c.t.Helper()
-
- ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
- defer cancel()
- p, ok := c.linkEP.ReadContext(ctx)
- if !ok {
- c.t.Fatalf("Packet wasn't written out")
- return nil
- }
-
- if p.Proto != ipv6.ProtocolNumber {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber)
- }
- vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
- b := vv.ToView()
-
- checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr))
- return b
-}
-
-// SendV6Packet builds and sends an IPv6 Packet via the link layer endpoint of
-// the context.
-func (c *Context) SendV6Packet(payload []byte, h *Headers) {
- c.SendV6PacketWithAddrs(payload, h, TestV6Addr, StackV6Addr)
-}
-
-// SendV6PacketWithAddrs builds and sends an IPv6 Packet via the link layer
-// endpoint of the context using the provided source and destination IPv6
-// addresses.
-func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) {
- // Allocate a buffer for data and headers.
- buf := buffer.NewView(header.TCPMinimumSize + header.IPv6MinimumSize + len(payload))
- copy(buf[len(buf)-len(payload):], payload)
-
- // Initialize the IP header.
- ip := header.IPv6(buf)
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(header.TCPMinimumSize + len(payload)),
- TransportProtocol: tcp.ProtocolNumber,
- HopLimit: 65,
- SrcAddr: src,
- DstAddr: dst,
- })
-
- // Initialize the TCP header.
- t := header.TCP(buf[header.IPv6MinimumSize:])
- t.Encode(&header.TCPFields{
- SrcPort: h.SrcPort,
- DstPort: h.DstPort,
- SeqNum: uint32(h.SeqNum),
- AckNum: uint32(h.AckNum),
- DataOffset: header.TCPMinimumSize,
- Flags: h.Flags,
- WindowSize: uint16(h.RcvWnd),
- })
-
- // Calculate the TCP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t)))
-
- // Calculate the TCP checksum and set it.
- xsum = header.Checksum(payload, xsum)
- t.SetChecksum(^t.CalculateChecksum(xsum))
-
- // Inject packet.
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- })
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, pkt)
-}
-
-// CreateConnected creates a connected TCP endpoint.
-func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int) {
- c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil)
-}
-
-// Connect performs the 3-way handshake for c.EP with the provided Initial
-// Sequence Number (iss) and receive window(rcvWnd) and any options if
-// specified.
-//
-// It also sets the receive buffer for the endpoint to the specified
-// value in epRcvBuf.
-//
-// PreCondition: c.EP must already be created.
-func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) {
- c.t.Helper()
-
- // Start connection attempt.
- waitEntry, notifyCh := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&waitEntry, waiter.WritableEvents)
- defer c.WQ.EventUnregister(&waitEntry)
-
- err := c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort})
- if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
- c.t.Fatalf("Unexpected return value from Connect: %v", err)
- }
-
- // Receive SYN packet.
- b := c.GetPacket()
- checker.IPv4(c.t, b,
- checker.TCP(
- checker.DstPort(TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- ),
- )
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
- c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- synOpts := header.ParseSynOptions(tcpHdr.Options(), false /* isAck */)
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
-
- c.SendPacket(nil, &Headers{
- SrcPort: tcpHdr.DestinationPort(),
- DstPort: tcpHdr.SourcePort(),
- Flags: header.TCPFlagSyn | header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: rcvWnd,
- TCPOpts: options,
- })
-
- // Receive ACK packet.
- checker.IPv4(c.t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.TCPSeqNum(uint32(c.IRS)+1),
- checker.TCPAckNum(uint32(iss)+1),
- ),
- )
-
- // Wait for connection to be established.
- select {
- case <-notifyCh:
- if err := c.EP.LastError(); err != nil {
- c.t.Fatalf("Unexpected error when connecting: %v", err)
- }
- case <-time.After(1 * time.Second):
- c.t.Fatalf("Timed out waiting for connection")
- }
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want {
- c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- c.RcvdWindowScale = uint8(synOpts.WS)
- c.Port = tcpHdr.SourcePort()
-}
-
-// Create creates a TCP endpoint.
-func (c *Context) Create(epRcvBuf int) {
- // Create TCP endpoint.
- var err tcpip.Error
- c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- if epRcvBuf != -1 {
- c.EP.SocketOptions().SetReceiveBufferSize(int64(epRcvBuf)*2, true /* notify */)
- }
-}
-
-// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends
-// the specified option bytes as the Option field in the initial SYN packet.
-//
-// It also sets the receive buffer for the endpoint to the specified
-// value in epRcvBuf.
-func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int, options []byte) {
- c.Create(epRcvBuf)
- c.Connect(iss, rcvWnd, options)
-}
-
-// RawEndpoint is just a small wrapper around a TCP endpoint's state to make
-// sending data and ACK packets easy while being able to manipulate the sequence
-// numbers and timestamp values as needed.
-type RawEndpoint struct {
- C *Context
- SrcPort uint16
- DstPort uint16
- Flags header.TCPFlags
- NextSeqNum seqnum.Value
- AckNum seqnum.Value
- WndSize seqnum.Size
- RecentTS uint32 // Stores the latest timestamp to echo back.
- TSVal uint32 // TSVal stores the last timestamp sent by this endpoint.
-
- // SackPermitted is true if SACKPermitted option was negotiated for this endpoint.
- SACKPermitted bool
-}
-
-// SendPacketWithTS embeds the provided tsVal in the Timestamp option
-// for the packet to be sent out.
-func (r *RawEndpoint) SendPacketWithTS(payload []byte, tsVal uint32) {
- r.TSVal = tsVal
- tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
- header.EncodeTSOption(r.TSVal, r.RecentTS, tsOpt[2:])
- r.SendPacket(payload, tsOpt[:])
-}
-
-// SendPacket is a small wrapper function to build and send packets.
-func (r *RawEndpoint) SendPacket(payload []byte, opts []byte) {
- packetHeaders := &Headers{
- SrcPort: r.SrcPort,
- DstPort: r.DstPort,
- Flags: r.Flags,
- SeqNum: r.NextSeqNum,
- AckNum: r.AckNum,
- RcvWnd: r.WndSize,
- TCPOpts: opts,
- }
- r.C.SendPacket(payload, packetHeaders)
- r.NextSeqNum = r.NextSeqNum.Add(seqnum.Size(len(payload)))
-}
-
-// VerifyAndReturnACKWithTS verifies that the tsEcr field int he ACK matches
-// the provided tsVal as well as returns the original packet.
-func (r *RawEndpoint) VerifyAndReturnACKWithTS(tsVal uint32) []byte {
- r.C.t.Helper()
- // Read ACK and verify that tsEcr of ACK packet is [1,2,3,4]
- ackPacket := r.C.GetPacket()
- checker.IPv4(r.C.t, ackPacket,
- checker.TCP(
- checker.DstPort(r.SrcPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.TCPSeqNum(uint32(r.AckNum)),
- checker.TCPAckNum(uint32(r.NextSeqNum)),
- checker.TCPTimestampChecker(true, 0, tsVal),
- ),
- )
- // Store the parsed TSVal from the ack as recentTS.
- tcpSeg := header.TCP(header.IPv4(ackPacket).Payload())
- opts := tcpSeg.ParsedOptions()
- r.RecentTS = opts.TSVal
- return ackPacket
-}
-
-// VerifyACKWithTS verifies that the tsEcr field in the ack matches the provided
-// tsVal.
-func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) {
- r.C.t.Helper()
- _ = r.VerifyAndReturnACKWithTS(tsVal)
-}
-
-// VerifyACKRcvWnd verifies that the window advertised by the incoming ACK
-// matches the provided rcvWnd.
-func (r *RawEndpoint) VerifyACKRcvWnd(rcvWnd uint16) {
- r.C.t.Helper()
- ackPacket := r.C.GetPacket()
- checker.IPv4(r.C.t, ackPacket,
- checker.TCP(
- checker.DstPort(r.SrcPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.TCPSeqNum(uint32(r.AckNum)),
- checker.TCPAckNum(uint32(r.NextSeqNum)),
- checker.TCPWindow(rcvWnd),
- ),
- )
-}
-
-// VerifyACKNoSACK verifies that the ACK does not contain a SACK block.
-func (r *RawEndpoint) VerifyACKNoSACK() {
- r.VerifyACKHasSACK(nil)
-}
-
-// VerifyACKHasSACK verifies that the ACK contains the specified SACKBlocks.
-func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) {
- // Read ACK and verify that the TCP options in the segment do
- // not contain a SACK block.
- ackPacket := r.C.GetPacket()
- checker.IPv4(r.C.t, ackPacket,
- checker.TCP(
- checker.DstPort(r.SrcPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.TCPSeqNum(uint32(r.AckNum)),
- checker.TCPAckNum(uint32(r.NextSeqNum)),
- checker.TCPSACKBlockChecker(sackBlocks),
- ),
- )
-}
-
-// CreateConnectedWithOptionsNoDelay just calls CreateConnectedWithOptions
-// without delay.
-func (c *Context) CreateConnectedWithOptionsNoDelay(wantOptions header.TCPSynOptions) *RawEndpoint {
- return c.CreateConnectedWithOptions(wantOptions, 0 /* delay */)
-}
-
-// CreateConnectedWithOptions creates and connects c.ep with the specified TCP
-// options enabled and returns a RawEndpoint which represents the other end of
-// the connection. It delays before a SYNACK is sent. This makes c.EP have a
-// higher RTT estimate so that spurious TLPs aren't sent in tests, which helps
-// reduce flakiness.
-//
-// It also verifies where required(eg.Timestamp) that the ACK to the SYN-ACK
-// does not carry an option that was not requested.
-func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions, delay time.Duration) *RawEndpoint {
- var err tcpip.Error
- c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- c.t.Fatalf("c.s.NewEndpoint(tcp, ipv4...) = %v", err)
- }
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateInitial; got != want {
- c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- // Start connection attempt.
- waitEntry, notifyCh := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&waitEntry, waiter.WritableEvents)
- defer c.WQ.EventUnregister(&waitEntry)
-
- testFullAddr := tcpip.FullAddress{Addr: TestAddr, Port: TestPort}
- err = c.EP.Connect(testFullAddr)
- if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
- c.t.Fatalf("c.ep.Connect(%v) = %v", testFullAddr, err)
- }
- // Receive SYN packet.
- b := c.GetPacket()
- // Validate that the syn has the timestamp option and a valid
- // TS value.
- mss := uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize)
-
- synChecker := checker.TCP(
- checker.DstPort(TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- checker.TCPSynOptions(header.TCPSynOptions{
- MSS: mss,
- TS: true,
- WS: int(c.WindowScale),
- SACKPermitted: c.SACKEnabled(),
- }),
- )
- checker.IPv4(c.t, b, synChecker)
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
- c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- tcpSeg := header.TCP(header.IPv4(b).Payload())
- synOptions := header.ParseSynOptions(tcpSeg.Options(), false)
-
- // Build options w/ tsVal to be sent in the SYN-ACK.
- synAckOptions := make([]byte, header.TCPOptionsMaximumSize)
- offset := 0
- if wantOptions.WS != -1 {
- offset += header.EncodeWSOption(wantOptions.WS, synAckOptions[offset:])
- }
- if wantOptions.TS {
- offset += header.EncodeTSOption(wantOptions.TSVal, synOptions.TSVal, synAckOptions[offset:])
- }
- if wantOptions.SACKPermitted {
- offset += header.EncodeSACKPermittedOption(synAckOptions[offset:])
- }
-
- offset += header.AddTCPOptionPadding(synAckOptions, offset)
-
- // Build SYN-ACK.
- c.IRS = seqnum.Value(tcpSeg.SequenceNumber())
- iss := seqnum.Value(TestInitialSequenceNumber)
- if delay > 0 {
- // Sleep so that RTT is increased.
- time.Sleep(delay)
- }
- c.SendPacket(nil, &Headers{
- SrcPort: tcpSeg.DestinationPort(),
- DstPort: tcpSeg.SourcePort(),
- Flags: header.TCPFlagSyn | header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- TCPOpts: synAckOptions[:offset],
- })
-
- // Read ACK.
- var ackPacket []byte
- // Ignore retransimitted SYN packets.
- for {
- packet := c.GetPacket()
- if header.TCP(header.IPv4(packet).Payload()).Flags()&header.TCPFlagSyn != 0 {
- checker.IPv4(c.t, packet, synChecker)
- } else {
- ackPacket = packet
- break
- }
- }
-
- // Verify TCP header fields.
- tcpCheckers := []checker.TransportChecker{
- checker.DstPort(TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.TCPSeqNum(uint32(c.IRS) + 1),
- checker.TCPAckNum(uint32(iss) + 1),
- }
-
- // Verify that tsEcr of ACK packet is wantOptions.TSVal if the
- // timestamp option was enabled, if not then we verify that
- // there is no timestamp in the ACK packet.
- if wantOptions.TS {
- tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(true, 0, wantOptions.TSVal))
- } else {
- tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(false, 0, 0))
- }
-
- checker.IPv4(c.t, ackPacket, checker.TCP(tcpCheckers...))
-
- ackSeg := header.TCP(header.IPv4(ackPacket).Payload())
- ackOptions := ackSeg.ParsedOptions()
-
- // Wait for connection to be established.
- select {
- case <-notifyCh:
- if err := c.EP.LastError(); err != nil {
- c.t.Fatalf("Unexpected error when connecting: %v", err)
- }
- case <-time.After(1 * time.Second):
- c.t.Fatalf("Timed out waiting for connection")
- }
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want {
- c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- // Store the source port in use by the endpoint.
- c.Port = tcpSeg.SourcePort()
-
- // Mark in context that timestamp option is enabled for this endpoint.
- c.TimeStampEnabled = true
- c.RcvdWindowScale = uint8(synOptions.WS)
- return &RawEndpoint{
- C: c,
- SrcPort: tcpSeg.DestinationPort(),
- DstPort: tcpSeg.SourcePort(),
- Flags: header.TCPFlagAck | header.TCPFlagPsh,
- NextSeqNum: iss + 1,
- AckNum: c.IRS.Add(1),
- WndSize: 30000,
- RecentTS: ackOptions.TSVal,
- TSVal: wantOptions.TSVal,
- SACKPermitted: wantOptions.SACKPermitted,
- }
-}
-
-// AcceptWithOptionsNoDelay delegates call to AcceptWithOptions without delay.
-func (c *Context) AcceptWithOptionsNoDelay(wndScale int, synOptions header.TCPSynOptions) *RawEndpoint {
- return c.AcceptWithOptions(wndScale, synOptions, 0 /* delay */)
-}
-
-// AcceptWithOptions initializes a listening endpoint and connects to it with
-// the provided options enabled. It delays before the final ACK of the 3WHS is
-// sent. It also verifies that the SYN-ACK has the expected values for the
-// provided options.
-//
-// The function returns a RawEndpoint representing the other end of the accepted
-// endpoint.
-func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOptions, delay time.Duration) *RawEndpoint {
- // Create EP and start listening.
- wq := &waiter.Queue{}
- ep, err := c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
- defer ep.Close()
-
- if err := ep.Bind(tcpip.FullAddress{Port: StackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
- }
- if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want {
- c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- if err := ep.Listen(10); err != nil {
- c.t.Fatalf("Listen failed: %v", err)
- }
- if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
- c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- rep := c.PassiveConnectWithOptions(100, wndScale, synOptions, delay)
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.ReadableEvents)
- defer wq.EventUnregister(&we)
-
- c.EP, _, err = ep.Accept(nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
- // Wait for connection to be established.
- select {
- case <-ch:
- c.EP, _, err = ep.Accept(nil)
- if err != nil {
- c.t.Fatalf("Accept failed: %v", err)
- }
-
- case <-time.After(1 * time.Second):
- c.t.Fatalf("Timed out waiting for accept")
- }
- }
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want {
- c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- return rep
-}
-
-// PassiveConnect just disables WindowScaling and delegates the call to
-// PassiveConnectWithOptions.
-func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCPSynOptions) {
- synOptions.WS = -1
- c.PassiveConnectWithOptions(maxPayload, wndScale, synOptions, 0 /* delay */)
-}
-
-// PassiveConnectWithOptions initiates a new connection (with the specified TCP
-// options enabled) to the port on which the Context.ep is listening for new
-// connections. It also validates that the SYN-ACK has the expected values for
-// the enabled options. The final ACK of the handshake is delayed by specified
-// duration.
-//
-// NOTE: MSS is not a negotiated option and it can be asymmetric
-// in each direction. This function uses the maxPayload to set the MSS to be
-// sent to the peer on a connect and validates that the MSS in the SYN-ACK
-// response is equal to the MTU - (tcphdr len + iphdr len).
-//
-// wndScale is the expected window scale in the SYN-ACK and synOptions.WS is the
-// value of the window scaling option to be sent in the SYN. If synOptions.WS >
-// 0 then we send the WindowScale option.
-func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions, delay time.Duration) *RawEndpoint {
- c.t.Helper()
- opts := make([]byte, header.TCPOptionsMaximumSize)
- offset := 0
- offset += header.EncodeMSSOption(uint32(maxPayload), opts)
-
- if synOptions.WS >= 0 {
- offset += header.EncodeWSOption(3, opts[offset:])
- }
- if synOptions.TS {
- offset += header.EncodeTSOption(synOptions.TSVal, synOptions.TSEcr, opts[offset:])
- }
-
- if synOptions.SACKPermitted {
- offset += header.EncodeSACKPermittedOption(opts[offset:])
- }
-
- paddingToAdd := 4 - offset%4
- // Now add any padding bytes that might be required to quad align the
- // options.
- for i := offset; i < offset+paddingToAdd; i++ {
- opts[i] = header.TCPOptionNOP
- }
- offset += paddingToAdd
-
- // Send a SYN request.
- iss := seqnum.Value(TestInitialSequenceNumber)
- c.SendPacket(nil, &Headers{
- SrcPort: TestPort,
- DstPort: StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: iss,
- RcvWnd: 30000,
- TCPOpts: opts[:offset],
- })
-
- // Receive the SYN-ACK reply. Make sure MSS and other expected options
- // are present.
- b := c.GetPacket()
- tcp := header.TCP(header.IPv4(b).Payload())
- rcvdSynOptions := header.ParseSynOptions(tcp.Options(), true /* isAck */)
- c.IRS = seqnum.Value(tcp.SequenceNumber())
-
- tcpCheckers := []checker.TransportChecker{
- checker.SrcPort(StackPort),
- checker.DstPort(TestPort),
- checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
- checker.TCPAckNum(uint32(iss) + 1),
- checker.TCPSynOptions(header.TCPSynOptions{MSS: synOptions.MSS, WS: wndScale, SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled()}),
- }
-
- // If TS option was enabled in the original SYN then add a checker to
- // validate the Timestamp option in the SYN-ACK.
- if synOptions.TS {
- tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(synOptions.TS, 0, synOptions.TSVal))
- } else {
- tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(false, 0, 0))
- }
-
- checker.IPv4(c.t, b, checker.TCP(tcpCheckers...))
- rcvWnd := seqnum.Size(30000)
- ackHeaders := &Headers{
- SrcPort: TestPort,
- DstPort: StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 1,
- AckNum: c.IRS + 1,
- RcvWnd: rcvWnd,
- }
-
- // If WS was expected to be in effect then scale the advertised window
- // correspondingly.
- if synOptions.WS > 0 {
- ackHeaders.RcvWnd = rcvWnd >> byte(synOptions.WS)
- }
-
- parsedOpts := tcp.ParsedOptions()
- if synOptions.TS {
- // Echo the tsVal back to the peer in the tsEcr field of the
- // timestamp option.
- // Increment TSVal by 1 from the value sent in the SYN and echo
- // the TSVal in the SYN-ACK in the TSEcr field.
- opts := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
- header.EncodeTSOption(synOptions.TSVal+1, parsedOpts.TSVal, opts[2:])
- ackHeaders.TCPOpts = opts[:]
- }
-
- // Send ACK, delay if needed.
- if delay > 0 {
- time.Sleep(delay)
- }
- c.SendPacket(nil, ackHeaders)
-
- c.RcvdWindowScale = uint8(rcvdSynOptions.WS)
- c.Port = StackPort
-
- return &RawEndpoint{
- C: c,
- SrcPort: TestPort,
- DstPort: StackPort,
- Flags: header.TCPFlagPsh | header.TCPFlagAck,
- NextSeqNum: iss + 1,
- AckNum: c.IRS + 1,
- WndSize: rcvWnd,
- SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled(),
- RecentTS: parsedOpts.TSVal,
- TSVal: synOptions.TSVal + 1,
- }
-}
-
-// SACKEnabled returns true if the TCP Protocol option SACKEnabled is set to true
-// for the Stack in the context.
-func (c *Context) SACKEnabled() bool {
- var v tcpip.TCPSACKEnabled
- if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &v); err != nil {
- // Stack doesn't support SACK. So just return.
- return false
- }
- return bool(v)
-}
-
-// SetGSOEnabled enables or disables generic segmentation offload.
-func (c *Context) SetGSOEnabled(enable bool) {
- if enable {
- c.linkEP.SupportedGSOKind = stack.HWGSOSupported
- } else {
- c.linkEP.SupportedGSOKind = stack.GSONotSupported
- }
-}
-
-// MSSWithoutOptions returns the value for the MSS used by the stack when no
-// options are in use.
-func (c *Context) MSSWithoutOptions() uint16 {
- return uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize)
-}
-
-// MSSWithoutOptionsV6 returns the value for the MSS used by the stack when no
-// options are in use for IPv6 packets.
-func (c *Context) MSSWithoutOptionsV6() uint16 {
- return uint16(c.linkEP.MTU() - header.IPv6MinimumSize - header.TCPMinimumSize)
-}
diff --git a/pkg/tcpip/transport/tcp/timer_test.go b/pkg/tcpip/transport/tcp/timer_test.go
deleted file mode 100644
index 479752de7..000000000
--- a/pkg/tcpip/transport/tcp/timer_test.go
+++ /dev/null
@@ -1,50 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tcp
-
-import (
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/sleep"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
-)
-
-func TestCleanup(t *testing.T) {
- const (
- timerDurationSeconds = 2
- isAssertedTimeoutSeconds = timerDurationSeconds + 1
- )
-
- clock := faketime.NewManualClock()
-
- tmr := timer{}
- w := sleep.Waker{}
- tmr.init(clock, &w)
- tmr.enable(timerDurationSeconds * time.Second)
- tmr.cleanup()
-
- if want := (timer{}); tmr != want {
- t.Errorf("got tmr = %+v, want = %+v", tmr, want)
- }
-
- // The waker should not be asserted.
- for i := 0; i < isAssertedTimeoutSeconds; i++ {
- clock.Advance(time.Second)
- if w.IsAsserted() {
- t.Fatalf("waker asserted unexpectedly")
- }
- }
-}
diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD
deleted file mode 100644
index 3ad6994a7..000000000
--- a/pkg/tcpip/transport/tcpconntrack/BUILD
+++ /dev/null
@@ -1,23 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "tcpconntrack",
- srcs = ["tcp_conntrack.go"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/tcpip/header",
- "//pkg/tcpip/seqnum",
- ],
-)
-
-go_test(
- name = "tcpconntrack_test",
- size = "small",
- srcs = ["tcp_conntrack_test.go"],
- deps = [
- ":tcpconntrack",
- "//pkg/tcpip/header",
- ],
-)
diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go
deleted file mode 100644
index 6c5ddc3c7..000000000
--- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go
+++ /dev/null
@@ -1,511 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tcpconntrack_test
-
-import (
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack"
-)
-
-// connected creates a connection tracker TCB and sets it to a connected state
-// by performing a 3-way handshake.
-func connected(t *testing.T, iss, irs uint32, isw, irw uint16) *tcpconntrack.TCB {
- // Send SYN.
- tcp := make(header.TCP, header.TCPMinimumSize)
- tcp.Encode(&header.TCPFields{
- SeqNum: iss,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn,
- WindowSize: irw,
- })
-
- tcb := tcpconntrack.TCB{}
- tcb.Init(tcp)
-
- // Receive SYN-ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: irs,
- AckNum: iss + 1,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn | header.TCPFlagAck,
- WindowSize: isw,
- })
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Send ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: iss + 1,
- AckNum: irs + 1,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck,
- WindowSize: irw,
- })
-
- if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- return &tcb
-}
-
-func TestConnectionRefused(t *testing.T) {
- // Send SYN.
- tcp := make(header.TCP, header.TCPMinimumSize)
- tcp.Encode(&header.TCPFields{
- SeqNum: 1234,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn,
- WindowSize: 30000,
- })
-
- tcb := tcpconntrack.TCB{}
- tcb.Init(tcp)
-
- // Receive RST.
- tcp.Encode(&header.TCPFields{
- SeqNum: 789,
- AckNum: 1235,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagRst | header.TCPFlagAck,
- WindowSize: 50000,
- })
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultReset {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset)
- }
-}
-
-func TestConnectionRefusedInSynRcvd(t *testing.T) {
- // Send SYN.
- tcp := make(header.TCP, header.TCPMinimumSize)
- tcp.Encode(&header.TCPFields{
- SeqNum: 1234,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn,
- WindowSize: 30000,
- })
-
- tcb := tcpconntrack.TCB{}
- tcb.Init(tcp)
-
- // Receive SYN.
- tcp.Encode(&header.TCPFields{
- SeqNum: 789,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn,
- WindowSize: 50000,
- })
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Receive RST with no ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: 790,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagRst,
- WindowSize: 50000,
- })
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultReset {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset)
- }
-}
-
-func TestConnectionResetInSynRcvd(t *testing.T) {
- // Send SYN.
- tcp := make(header.TCP, header.TCPMinimumSize)
- tcp.Encode(&header.TCPFields{
- SeqNum: 1234,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn,
- WindowSize: 30000,
- })
-
- tcb := tcpconntrack.TCB{}
- tcb.Init(tcp)
-
- // Receive SYN.
- tcp.Encode(&header.TCPFields{
- SeqNum: 789,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn,
- WindowSize: 50000,
- })
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Send RST with no ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: 1235,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagRst,
- })
-
- if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultReset {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset)
- }
-}
-
-func TestRetransmitOnSynSent(t *testing.T) {
- // Send initial SYN.
- tcp := make(header.TCP, header.TCPMinimumSize)
- tcp.Encode(&header.TCPFields{
- SeqNum: 1234,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn,
- WindowSize: 30000,
- })
-
- tcb := tcpconntrack.TCB{}
- tcb.Init(tcp)
-
- // Retransmit the same SYN.
- if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultConnecting {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultConnecting)
- }
-}
-
-func TestRetransmitOnSynRcvd(t *testing.T) {
- // Send initial SYN.
- tcp := make(header.TCP, header.TCPMinimumSize)
- tcp.Encode(&header.TCPFields{
- SeqNum: 1234,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn,
- WindowSize: 30000,
- })
-
- tcb := tcpconntrack.TCB{}
- tcb.Init(tcp)
-
- // Receive SYN. This will cause the state to go to SYN-RCVD.
- tcp.Encode(&header.TCPFields{
- SeqNum: 789,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn,
- WindowSize: 50000,
- })
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Retransmit the original SYN.
- tcp.Encode(&header.TCPFields{
- SeqNum: 1234,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn,
- WindowSize: 30000,
- })
-
- if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Transmit a SYN-ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: 1234,
- AckNum: 790,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn | header.TCPFlagAck,
- WindowSize: 30000,
- })
-
- if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-}
-
-func TestClosedBySelf(t *testing.T) {
- tcb := connected(t, 1234, 789, 30000, 50000)
-
- // Send FIN.
- tcp := make(header.TCP, header.TCPMinimumSize)
- tcp.Encode(&header.TCPFields{
- SeqNum: 1235,
- AckNum: 790,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- WindowSize: 30000,
- })
-
- if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Receive FIN/ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: 790,
- AckNum: 1236,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- WindowSize: 50000,
- })
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Send ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: 1236,
- AckNum: 791,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck,
- WindowSize: 30000,
- })
-
- if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultClosedBySelf {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedBySelf)
- }
-}
-
-func TestClosedByPeer(t *testing.T) {
- tcb := connected(t, 1234, 789, 30000, 50000)
-
- // Receive FIN.
- tcp := make(header.TCP, header.TCPMinimumSize)
- tcp.Encode(&header.TCPFields{
- SeqNum: 790,
- AckNum: 1235,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- WindowSize: 50000,
- })
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Send FIN/ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: 1235,
- AckNum: 791,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- WindowSize: 30000,
- })
-
- if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Receive ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: 791,
- AckNum: 1236,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck,
- WindowSize: 50000,
- })
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultClosedByPeer {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedByPeer)
- }
-}
-
-func TestSendAndReceiveDataClosedBySelf(t *testing.T) {
- sseq := uint32(1234)
- rseq := uint32(789)
- tcb := connected(t, sseq, rseq, 30000, 50000)
- sseq++
- rseq++
-
- // Send some data.
- tcp := make(header.TCP, header.TCPMinimumSize+1024)
-
- for i := uint32(0); i < 10; i++ {
- // Send some data.
- tcp.Encode(&header.TCPFields{
- SeqNum: sseq,
- AckNum: rseq,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck,
- WindowSize: 30000,
- })
- sseq += uint32(len(tcp)) - header.TCPMinimumSize
-
- if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Receive ack for data.
- tcp.Encode(&header.TCPFields{
- SeqNum: rseq,
- AckNum: sseq,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck,
- WindowSize: 50000,
- })
-
- if r := tcb.UpdateStateInbound(tcp[:header.TCPMinimumSize]); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
- }
-
- for i := uint32(0); i < 10; i++ {
- // Receive some data.
- tcp.Encode(&header.TCPFields{
- SeqNum: rseq,
- AckNum: sseq,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck,
- WindowSize: 50000,
- })
- rseq += uint32(len(tcp)) - header.TCPMinimumSize
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Send ack for data.
- tcp.Encode(&header.TCPFields{
- SeqNum: sseq,
- AckNum: rseq,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck,
- WindowSize: 30000,
- })
-
- if r := tcb.UpdateStateOutbound(tcp[:header.TCPMinimumSize]); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
- }
-
- // Send FIN.
- tcp = tcp[:header.TCPMinimumSize]
- tcp.Encode(&header.TCPFields{
- SeqNum: sseq,
- AckNum: rseq,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- WindowSize: 30000,
- })
- sseq++
-
- if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Receive FIN/ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: rseq,
- AckNum: sseq,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- WindowSize: 50000,
- })
- rseq++
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Send ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: sseq,
- AckNum: rseq,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck,
- WindowSize: 30000,
- })
-
- if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultClosedBySelf {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedBySelf)
- }
-}
-
-func TestIgnoreBadResetOnSynSent(t *testing.T) {
- // Send SYN.
- tcp := make(header.TCP, header.TCPMinimumSize)
- tcp.Encode(&header.TCPFields{
- SeqNum: 1234,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn,
- WindowSize: 30000,
- })
-
- tcb := tcpconntrack.TCB{}
- tcb.Init(tcp)
-
- // Receive a RST with a bad ACK, it should not cause the connection to
- // be reset.
- acks := []uint32{1234, 1236, 1000, 5000}
- flags := []header.TCPFlags{header.TCPFlagRst, header.TCPFlagRst | header.TCPFlagAck}
- for _, a := range acks {
- for _, f := range flags {
- tcp.Encode(&header.TCPFields{
- SeqNum: 789,
- AckNum: a,
- DataOffset: header.TCPMinimumSize,
- Flags: f,
- WindowSize: 50000,
- })
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultConnecting {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
- }
- }
-
- // Complete the handshake.
- // Receive SYN-ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: 789,
- AckNum: 1235,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn | header.TCPFlagAck,
- WindowSize: 50000,
- })
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Send ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: 1235,
- AckNum: 790,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck,
- WindowSize: 30000,
- })
-
- if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-}
diff --git a/pkg/tcpip/transport/tcpconntrack/tcpconntrack_state_autogen.go b/pkg/tcpip/transport/tcpconntrack/tcpconntrack_state_autogen.go
new file mode 100644
index 000000000..ff53204da
--- /dev/null
+++ b/pkg/tcpip/transport/tcpconntrack/tcpconntrack_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package tcpconntrack
diff --git a/pkg/tcpip/transport/transport_state_autogen.go b/pkg/tcpip/transport/transport_state_autogen.go
new file mode 100644
index 000000000..c023165ec
--- /dev/null
+++ b/pkg/tcpip/transport/transport_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package transport
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
deleted file mode 100644
index d2c0963b0..000000000
--- a/pkg/tcpip/transport/udp/BUILD
+++ /dev/null
@@ -1,68 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
-package(licenses = ["notice"])
-
-go_template_instance(
- name = "udp_packet_list",
- out = "udp_packet_list.go",
- package = "udp",
- prefix = "udpPacket",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*udpPacket",
- "Linker": "*udpPacket",
- },
-)
-
-go_library(
- name = "udp",
- srcs = [
- "endpoint.go",
- "endpoint_state.go",
- "forwarder.go",
- "protocol.go",
- "udp_packet_list.go",
- ],
- imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/sleep",
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/header/parse",
- "//pkg/tcpip/ports",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/transport",
- "//pkg/tcpip/transport/internal/network",
- "//pkg/tcpip/transport/raw",
- "//pkg/waiter",
- ],
-)
-
-go_test(
- name = "udp_x_test",
- size = "small",
- srcs = ["udp_test.go"],
- deps = [
- ":udp",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/checker",
- "//pkg/tcpip/faketime",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/link/loopback",
- "//pkg/tcpip/link/sniffer",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/testutil",
- "//pkg/tcpip/transport/icmp",
- "//pkg/waiter",
- "@com_github_google_go_cmp//cmp:go_default_library",
- "@org_golang_x_time//rate:go_default_library",
- ],
-)
diff --git a/pkg/tcpip/transport/udp/udp_packet_list.go b/pkg/tcpip/transport/udp/udp_packet_list.go
new file mode 100644
index 000000000..c396f77c9
--- /dev/null
+++ b/pkg/tcpip/transport/udp/udp_packet_list.go
@@ -0,0 +1,221 @@
+package udp
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type udpPacketElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (udpPacketElementMapper) linkerFor(elem *udpPacket) *udpPacket { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type udpPacketList struct {
+ head *udpPacket
+ tail *udpPacket
+}
+
+// Reset resets list l to the empty state.
+func (l *udpPacketList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+//
+//go:nosplit
+func (l *udpPacketList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+//
+//go:nosplit
+func (l *udpPacketList) Front() *udpPacket {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+//
+//go:nosplit
+func (l *udpPacketList) Back() *udpPacket {
+ return l.tail
+}
+
+// Len returns the number of elements in the list.
+//
+// NOTE: This is an O(n) operation.
+//
+//go:nosplit
+func (l *udpPacketList) Len() (count int) {
+ for e := l.Front(); e != nil; e = (udpPacketElementMapper{}.linkerFor(e)).Next() {
+ count++
+ }
+ return count
+}
+
+// PushFront inserts the element e at the front of list l.
+//
+//go:nosplit
+func (l *udpPacketList) PushFront(e *udpPacket) {
+ linker := udpPacketElementMapper{}.linkerFor(e)
+ linker.SetNext(l.head)
+ linker.SetPrev(nil)
+ if l.head != nil {
+ udpPacketElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+//
+//go:nosplit
+func (l *udpPacketList) PushBack(e *udpPacket) {
+ linker := udpPacketElementMapper{}.linkerFor(e)
+ linker.SetNext(nil)
+ linker.SetPrev(l.tail)
+ if l.tail != nil {
+ udpPacketElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+//
+//go:nosplit
+func (l *udpPacketList) PushBackList(m *udpPacketList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ udpPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ udpPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+//
+//go:nosplit
+func (l *udpPacketList) InsertAfter(b, e *udpPacket) {
+ bLinker := udpPacketElementMapper{}.linkerFor(b)
+ eLinker := udpPacketElementMapper{}.linkerFor(e)
+
+ a := bLinker.Next()
+
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ bLinker.SetNext(e)
+
+ if a != nil {
+ udpPacketElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+//
+//go:nosplit
+func (l *udpPacketList) InsertBefore(a, e *udpPacket) {
+ aLinker := udpPacketElementMapper{}.linkerFor(a)
+ eLinker := udpPacketElementMapper{}.linkerFor(e)
+
+ b := aLinker.Prev()
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ aLinker.SetPrev(e)
+
+ if b != nil {
+ udpPacketElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+//
+//go:nosplit
+func (l *udpPacketList) Remove(e *udpPacket) {
+ linker := udpPacketElementMapper{}.linkerFor(e)
+ prev := linker.Prev()
+ next := linker.Next()
+
+ if prev != nil {
+ udpPacketElementMapper{}.linkerFor(prev).SetNext(next)
+ } else if l.head == e {
+ l.head = next
+ }
+
+ if next != nil {
+ udpPacketElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else if l.tail == e {
+ l.tail = prev
+ }
+
+ linker.SetNext(nil)
+ linker.SetPrev(nil)
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type udpPacketEntry struct {
+ next *udpPacket
+ prev *udpPacket
+}
+
+// Next returns the entry that follows e in the list.
+//
+//go:nosplit
+func (e *udpPacketEntry) Next() *udpPacket {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *udpPacketEntry) Prev() *udpPacket {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+//
+//go:nosplit
+func (e *udpPacketEntry) SetNext(elem *udpPacket) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+//
+//go:nosplit
+func (e *udpPacketEntry) SetPrev(elem *udpPacket) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/transport/udp/udp_state_autogen.go b/pkg/tcpip/transport/udp/udp_state_autogen.go
new file mode 100644
index 000000000..e25607e3f
--- /dev/null
+++ b/pkg/tcpip/transport/udp/udp_state_autogen.go
@@ -0,0 +1,194 @@
+// automatically generated by stateify.
+
+package udp
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+func (p *udpPacket) StateTypeName() string {
+ return "pkg/tcpip/transport/udp.udpPacket"
+}
+
+func (p *udpPacket) StateFields() []string {
+ return []string{
+ "udpPacketEntry",
+ "netProto",
+ "senderAddress",
+ "destinationAddress",
+ "packetInfo",
+ "data",
+ "receivedAt",
+ "tos",
+ }
+}
+
+func (p *udpPacket) beforeSave() {}
+
+// +checklocksignore
+func (p *udpPacket) StateSave(stateSinkObject state.Sink) {
+ p.beforeSave()
+ var dataValue buffer.VectorisedView
+ dataValue = p.saveData()
+ stateSinkObject.SaveValue(5, dataValue)
+ var receivedAtValue int64
+ receivedAtValue = p.saveReceivedAt()
+ stateSinkObject.SaveValue(6, receivedAtValue)
+ stateSinkObject.Save(0, &p.udpPacketEntry)
+ stateSinkObject.Save(1, &p.netProto)
+ stateSinkObject.Save(2, &p.senderAddress)
+ stateSinkObject.Save(3, &p.destinationAddress)
+ stateSinkObject.Save(4, &p.packetInfo)
+ stateSinkObject.Save(7, &p.tos)
+}
+
+func (p *udpPacket) afterLoad() {}
+
+// +checklocksignore
+func (p *udpPacket) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &p.udpPacketEntry)
+ stateSourceObject.Load(1, &p.netProto)
+ stateSourceObject.Load(2, &p.senderAddress)
+ stateSourceObject.Load(3, &p.destinationAddress)
+ stateSourceObject.Load(4, &p.packetInfo)
+ stateSourceObject.Load(7, &p.tos)
+ stateSourceObject.LoadValue(5, new(buffer.VectorisedView), func(y interface{}) { p.loadData(y.(buffer.VectorisedView)) })
+ stateSourceObject.LoadValue(6, new(int64), func(y interface{}) { p.loadReceivedAt(y.(int64)) })
+}
+
+func (e *endpoint) StateTypeName() string {
+ return "pkg/tcpip/transport/udp.endpoint"
+}
+
+func (e *endpoint) StateFields() []string {
+ return []string{
+ "DefaultSocketOptionsHandler",
+ "waiterQueue",
+ "uniqueID",
+ "net",
+ "ops",
+ "rcvReady",
+ "rcvList",
+ "rcvBufSize",
+ "rcvClosed",
+ "lastError",
+ "portFlags",
+ "boundBindToDevice",
+ "boundPortFlags",
+ "readShutdown",
+ "effectiveNetProtos",
+ "frozen",
+ "localPort",
+ "remotePort",
+ }
+}
+
+// +checklocksignore
+func (e *endpoint) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+ stateSinkObject.Save(0, &e.DefaultSocketOptionsHandler)
+ stateSinkObject.Save(1, &e.waiterQueue)
+ stateSinkObject.Save(2, &e.uniqueID)
+ stateSinkObject.Save(3, &e.net)
+ stateSinkObject.Save(4, &e.ops)
+ stateSinkObject.Save(5, &e.rcvReady)
+ stateSinkObject.Save(6, &e.rcvList)
+ stateSinkObject.Save(7, &e.rcvBufSize)
+ stateSinkObject.Save(8, &e.rcvClosed)
+ stateSinkObject.Save(9, &e.lastError)
+ stateSinkObject.Save(10, &e.portFlags)
+ stateSinkObject.Save(11, &e.boundBindToDevice)
+ stateSinkObject.Save(12, &e.boundPortFlags)
+ stateSinkObject.Save(13, &e.readShutdown)
+ stateSinkObject.Save(14, &e.effectiveNetProtos)
+ stateSinkObject.Save(15, &e.frozen)
+ stateSinkObject.Save(16, &e.localPort)
+ stateSinkObject.Save(17, &e.remotePort)
+}
+
+// +checklocksignore
+func (e *endpoint) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &e.DefaultSocketOptionsHandler)
+ stateSourceObject.Load(1, &e.waiterQueue)
+ stateSourceObject.Load(2, &e.uniqueID)
+ stateSourceObject.Load(3, &e.net)
+ stateSourceObject.Load(4, &e.ops)
+ stateSourceObject.Load(5, &e.rcvReady)
+ stateSourceObject.Load(6, &e.rcvList)
+ stateSourceObject.Load(7, &e.rcvBufSize)
+ stateSourceObject.Load(8, &e.rcvClosed)
+ stateSourceObject.Load(9, &e.lastError)
+ stateSourceObject.Load(10, &e.portFlags)
+ stateSourceObject.Load(11, &e.boundBindToDevice)
+ stateSourceObject.Load(12, &e.boundPortFlags)
+ stateSourceObject.Load(13, &e.readShutdown)
+ stateSourceObject.Load(14, &e.effectiveNetProtos)
+ stateSourceObject.Load(15, &e.frozen)
+ stateSourceObject.Load(16, &e.localPort)
+ stateSourceObject.Load(17, &e.remotePort)
+ stateSourceObject.AfterLoad(e.afterLoad)
+}
+
+func (l *udpPacketList) StateTypeName() string {
+ return "pkg/tcpip/transport/udp.udpPacketList"
+}
+
+func (l *udpPacketList) StateFields() []string {
+ return []string{
+ "head",
+ "tail",
+ }
+}
+
+func (l *udpPacketList) beforeSave() {}
+
+// +checklocksignore
+func (l *udpPacketList) StateSave(stateSinkObject state.Sink) {
+ l.beforeSave()
+ stateSinkObject.Save(0, &l.head)
+ stateSinkObject.Save(1, &l.tail)
+}
+
+func (l *udpPacketList) afterLoad() {}
+
+// +checklocksignore
+func (l *udpPacketList) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &l.head)
+ stateSourceObject.Load(1, &l.tail)
+}
+
+func (e *udpPacketEntry) StateTypeName() string {
+ return "pkg/tcpip/transport/udp.udpPacketEntry"
+}
+
+func (e *udpPacketEntry) StateFields() []string {
+ return []string{
+ "next",
+ "prev",
+ }
+}
+
+func (e *udpPacketEntry) beforeSave() {}
+
+// +checklocksignore
+func (e *udpPacketEntry) StateSave(stateSinkObject state.Sink) {
+ e.beforeSave()
+ stateSinkObject.Save(0, &e.next)
+ stateSinkObject.Save(1, &e.prev)
+}
+
+func (e *udpPacketEntry) afterLoad() {}
+
+// +checklocksignore
+func (e *udpPacketEntry) StateLoad(stateSourceObject state.Source) {
+ stateSourceObject.Load(0, &e.next)
+ stateSourceObject.Load(1, &e.prev)
+}
+
+func init() {
+ state.Register((*udpPacket)(nil))
+ state.Register((*endpoint)(nil))
+ state.Register((*udpPacketList)(nil))
+ state.Register((*udpPacketEntry)(nil))
+}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
deleted file mode 100644
index b3199489c..000000000
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ /dev/null
@@ -1,2602 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package udp_test
-
-import (
- "bytes"
- "fmt"
- "io/ioutil"
- "math/rand"
- "testing"
-
- "github.com/google/go-cmp/cmp"
- "golang.org/x/time/rate"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
- "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/testutil"
- "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-// Addresses and ports used for testing. It is recommended that tests stick to
-// using these addresses as it allows using the testFlow helper.
-// Naming rules: 'stack*'' denotes local addresses and ports, while 'test*'
-// represents the remote endpoint.
-const (
- v4MappedAddrPrefix = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff"
- stackV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- testV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- stackV4MappedAddr = v4MappedAddrPrefix + stackAddr
- testV4MappedAddr = v4MappedAddrPrefix + testAddr
- multicastV4MappedAddr = v4MappedAddrPrefix + multicastAddr
- broadcastV4MappedAddr = v4MappedAddrPrefix + broadcastAddr
- v4MappedWildcardAddr = v4MappedAddrPrefix + "\x00\x00\x00\x00"
-
- stackAddr = "\x0a\x00\x00\x01"
- stackPort = 1234
- testAddr = "\x0a\x00\x00\x02"
- testPort = 4096
- invalidPort = 8192
- multicastAddr = "\xe8\x2b\xd3\xea"
- multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
- broadcastAddr = header.IPv4Broadcast
- testTOS = 0x80
-
- // defaultMTU is the MTU, in bytes, used throughout the tests, except
- // where another value is explicitly used. It is chosen to match the MTU
- // of loopback interfaces on linux systems.
- defaultMTU = 65536
-)
-
-// header4Tuple stores the 4-tuple {src-IP, src-port, dst-IP, dst-port} used in
-// a packet header. These values are used to populate a header or verify one.
-// Note that because they are used in packet headers, the addresses are never in
-// a V4-mapped format.
-type header4Tuple struct {
- srcAddr tcpip.FullAddress
- dstAddr tcpip.FullAddress
-}
-
-// testFlow implements a helper type used for sending and receiving test
-// packets. A given test flow value defines 1) the socket endpoint used for the
-// test and 2) the type of packet send or received on the endpoint. E.g., a
-// multicastV6Only flow is a V6 multicast packet passing through a V6-only
-// endpoint. The type provides helper methods to characterize the flow (e.g.,
-// isV4) as well as return a proper header4Tuple for it.
-type testFlow int
-
-const (
- unicastV4 testFlow = iota // V4 unicast on a V4 socket
- unicastV4in6 // V4-mapped unicast on a V6-dual socket
- unicastV6 // V6 unicast on a V6 socket
- unicastV6Only // V6 unicast on a V6-only socket
- multicastV4 // V4 multicast on a V4 socket
- multicastV4in6 // V4-mapped multicast on a V6-dual socket
- multicastV6 // V6 multicast on a V6 socket
- multicastV6Only // V6 multicast on a V6-only socket
- broadcast // V4 broadcast on a V4 socket
- broadcastIn6 // V4-mapped broadcast on a V6-dual socket
- reverseMulticast4 // V4 multicast src. Must fail.
- reverseMulticast6 // V6 multicast src. Must fail.
-)
-
-func (flow testFlow) String() string {
- switch flow {
- case unicastV4:
- return "unicastV4"
- case unicastV6:
- return "unicastV6"
- case unicastV6Only:
- return "unicastV6Only"
- case unicastV4in6:
- return "unicastV4in6"
- case multicastV4:
- return "multicastV4"
- case multicastV6:
- return "multicastV6"
- case multicastV6Only:
- return "multicastV6Only"
- case multicastV4in6:
- return "multicastV4in6"
- case broadcast:
- return "broadcast"
- case broadcastIn6:
- return "broadcastIn6"
- case reverseMulticast4:
- return "reverseMulticast4"
- case reverseMulticast6:
- return "reverseMulticast6"
- default:
- return "unknown"
- }
-}
-
-// packetDirection explains if a flow is incoming (read) or outgoing (write).
-type packetDirection int
-
-const (
- incoming packetDirection = iota
- outgoing
-)
-
-// header4Tuple returns the header4Tuple for the given flow and direction. Note
-// that the tuple contains no mapped addresses as those only exist at the socket
-// level but not at the packet header level.
-func (flow testFlow) header4Tuple(d packetDirection) header4Tuple {
- var h header4Tuple
- if flow.isV4() {
- if d == outgoing {
- h = header4Tuple{
- srcAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort},
- dstAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort},
- }
- } else {
- h = header4Tuple{
- srcAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort},
- dstAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort},
- }
- }
- if flow.isMulticast() {
- h.dstAddr.Addr = multicastAddr
- } else if flow.isBroadcast() {
- h.dstAddr.Addr = broadcastAddr
- }
- } else { // IPv6
- if d == outgoing {
- h = header4Tuple{
- srcAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort},
- dstAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
- }
- } else {
- h = header4Tuple{
- srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
- dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort},
- }
- }
- if flow.isMulticast() {
- h.dstAddr.Addr = multicastV6Addr
- }
- }
- if flow.isReverseMulticast() {
- h.srcAddr.Addr = flow.getMcastAddr()
- }
- return h
-}
-
-func (flow testFlow) getMcastAddr() tcpip.Address {
- if flow.isV4() {
- return multicastAddr
- }
- return multicastV6Addr
-}
-
-// mapAddrIfApplicable converts the given V4 address into its V4-mapped version
-// if it is applicable to the flow.
-func (flow testFlow) mapAddrIfApplicable(v4Addr tcpip.Address) tcpip.Address {
- if flow.isMapped() {
- return v4MappedAddrPrefix + v4Addr
- }
- return v4Addr
-}
-
-// netProto returns the protocol number used for the network packet.
-func (flow testFlow) netProto() tcpip.NetworkProtocolNumber {
- if flow.isV4() {
- return ipv4.ProtocolNumber
- }
- return ipv6.ProtocolNumber
-}
-
-// sockProto returns the protocol number used when creating the socket
-// endpoint for this flow.
-func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber {
- switch flow {
- case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6, reverseMulticast6:
- return ipv6.ProtocolNumber
- case unicastV4, multicastV4, broadcast, reverseMulticast4:
- return ipv4.ProtocolNumber
- default:
- panic(fmt.Sprintf("invalid testFlow given: %d", flow))
- }
-}
-
-func (flow testFlow) checkerFn() func(*testing.T, []byte, ...checker.NetworkChecker) {
- if flow.isV4() {
- return checker.IPv4
- }
- return checker.IPv6
-}
-
-func (flow testFlow) isV6() bool { return !flow.isV4() }
-func (flow testFlow) isV4() bool {
- return flow.sockProto() == ipv4.ProtocolNumber || flow.isMapped()
-}
-
-func (flow testFlow) isV6Only() bool {
- switch flow {
- case unicastV6Only, multicastV6Only:
- return true
- case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6:
- return false
- default:
- panic(fmt.Sprintf("invalid testFlow given: %d", flow))
- }
-}
-
-func (flow testFlow) isMulticast() bool {
- switch flow {
- case multicastV4, multicastV4in6, multicastV6, multicastV6Only:
- return true
- case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6:
- return false
- default:
- panic(fmt.Sprintf("invalid testFlow given: %d", flow))
- }
-}
-
-func (flow testFlow) isBroadcast() bool {
- switch flow {
- case broadcast, broadcastIn6:
- return true
- case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only, reverseMulticast4, reverseMulticast6:
- return false
- default:
- panic(fmt.Sprintf("invalid testFlow given: %d", flow))
- }
-}
-
-func (flow testFlow) isMapped() bool {
- switch flow {
- case unicastV4in6, multicastV4in6, broadcastIn6:
- return true
- case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast, reverseMulticast4, reverseMulticast6:
- return false
- default:
- panic(fmt.Sprintf("invalid testFlow given: %d", flow))
- }
-}
-
-func (flow testFlow) isReverseMulticast() bool {
- switch flow {
- case reverseMulticast4, reverseMulticast6:
- return true
- default:
- return false
- }
-}
-
-type testContext struct {
- t *testing.T
- linkEP *channel.Endpoint
- s *stack.Stack
- nicID tcpip.NICID
-
- ep tcpip.Endpoint
- wq waiter.Queue
-}
-
-func newDualTestContext(t *testing.T, mtu uint32) *testContext {
- t.Helper()
- return newDualTestContextWithHandleLocal(t, mtu, true)
-}
-
-func newDualTestContextWithHandleLocal(t *testing.T, mtu uint32, handleLocal bool) *testContext {
- const nicID = 1
-
- t.Helper()
-
- options := stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
- HandleLocal: handleLocal,
- Clock: &faketime.NullClock{},
- }
- s := stack.New(options)
- // Disable ICMP rate limiter because we're using Null clock, which never advances time and thus
- // never allows ICMP messages.
- s.SetICMPLimit(rate.Inf)
- ep := channel.New(256, mtu, "")
- wep := stack.LinkEndpoint(ep)
-
- if testing.Verbose() {
- wep = sniffer.New(ep)
- }
- if err := s.CreateNIC(nicID, wep); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
-
- protocolAddrV4 := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: tcpip.Address(stackAddr).WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddrV4, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV4, err)
- }
-
- protocolAddrV6 := tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: tcpip.Address(stackV6Addr).WithPrefix(),
- }
- if err := s.AddProtocolAddress(nicID, protocolAddrV6, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV6, err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv4EmptySubnet,
- NIC: nicID,
- },
- {
- Destination: header.IPv6EmptySubnet,
- NIC: nicID,
- },
- })
-
- return &testContext{
- t: t,
- s: s,
- nicID: nicID,
- linkEP: ep,
- }
-}
-
-func (c *testContext) cleanup() {
- if c.ep != nil {
- c.ep.Close()
- }
-}
-
-func (c *testContext) createEndpoint(proto tcpip.NetworkProtocolNumber) {
- c.t.Helper()
-
- var err tcpip.Error
- c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, proto, &c.wq)
- if err != nil {
- c.t.Fatal("NewEndpoint failed: ", err)
- }
-}
-
-func (c *testContext) createEndpointForFlow(flow testFlow) {
- c.t.Helper()
-
- c.createEndpoint(flow.sockProto())
- if flow.isV6Only() {
- c.ep.SocketOptions().SetV6Only(true)
- } else if flow.isBroadcast() {
- c.ep.SocketOptions().SetBroadcast(true)
- }
-}
-
-// getPacketAndVerify reads a packet from the link endpoint and verifies the
-// header against expected values from the given test flow. In addition, it
-// calls any extra checker functions provided.
-func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte {
- c.t.Helper()
-
- p, ok := c.linkEP.Read()
- if !ok {
- c.t.Fatalf("Packet wasn't written out")
- return nil
- }
-
- if p.Proto != flow.netProto() {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto())
- }
-
- if got, want := p.Pkt.TransportProtocolNumber, header.UDPProtocolNumber; got != want {
- c.t.Errorf("got p.Pkt.TransportProtocolNumber = %d, want = %d", got, want)
- }
-
- vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
- b := vv.ToView()
-
- h := flow.header4Tuple(outgoing)
- checkers = append(
- checkers,
- checker.SrcAddr(h.srcAddr.Addr),
- checker.DstAddr(h.dstAddr.Addr),
- checker.UDP(checker.DstPort(h.dstAddr.Port)),
- )
- flow.checkerFn()(c.t, b, checkers...)
- return b
-}
-
-// injectPacket creates a packet of the given flow and with the given payload,
-// and injects it into the link endpoint. If badChecksum is true, the packet has
-// a bad checksum in the UDP header.
-func (c *testContext) injectPacket(flow testFlow, payload []byte, badChecksum bool) {
- c.t.Helper()
-
- h := flow.header4Tuple(incoming)
- if flow.isV4() {
- buf := c.buildV4Packet(payload, &h)
- if badChecksum {
- // Invalidate the UDP header checksum field, taking care to avoid
- // overflow to zero, which would disable checksum validation.
- for u := header.UDP(buf[header.IPv4MinimumSize:]); ; {
- u.SetChecksum(u.Checksum() + 1)
- if u.Checksum() != 0 {
- break
- }
- }
- }
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
- } else {
- buf := c.buildV6Packet(payload, &h)
- if badChecksum {
- // Invalidate the UDP header checksum field (Unlike IPv4, zero is
- // a valid checksum value for IPv6 so no need to avoid it).
- u := header.UDP(buf[header.IPv6MinimumSize:])
- u.SetChecksum(u.Checksum() + 1)
- }
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
- }
-}
-
-// buildV6Packet creates a V6 test packet with the given payload and header
-// values in a buffer.
-func (c *testContext) buildV6Packet(payload []byte, h *header4Tuple) buffer.View {
- // Allocate a buffer for data and headers.
- buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
- payloadStart := len(buf) - len(payload)
- copy(buf[payloadStart:], payload)
-
- // Initialize the IP header.
- ip := header.IPv6(buf)
- ip.Encode(&header.IPv6Fields{
- TrafficClass: testTOS,
- PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
- TransportProtocol: udp.ProtocolNumber,
- HopLimit: 65,
- SrcAddr: h.srcAddr.Addr,
- DstAddr: h.dstAddr.Addr,
- })
-
- // Initialize the UDP header.
- u := header.UDP(buf[header.IPv6MinimumSize:])
- u.Encode(&header.UDPFields{
- SrcPort: h.srcAddr.Port,
- DstPort: h.dstAddr.Port,
- Length: uint16(header.UDPMinimumSize + len(payload)),
- })
-
- // Calculate the UDP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u)))
-
- // Calculate the UDP checksum and set it.
- xsum = header.Checksum(payload, xsum)
- u.SetChecksum(^u.CalculateChecksum(xsum))
-
- return buf
-}
-
-// buildV4Packet creates a V4 test packet with the given payload and header
-// values in a buffer.
-func (c *testContext) buildV4Packet(payload []byte, h *header4Tuple) buffer.View {
- // Allocate a buffer for data and headers.
- buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload))
- payloadStart := len(buf) - len(payload)
- copy(buf[payloadStart:], payload)
-
- // Initialize the IP header.
- ip := header.IPv4(buf)
- ip.Encode(&header.IPv4Fields{
- TOS: testTOS,
- TotalLength: uint16(len(buf)),
- TTL: 65,
- Protocol: uint8(udp.ProtocolNumber),
- SrcAddr: h.srcAddr.Addr,
- DstAddr: h.dstAddr.Addr,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
-
- // Initialize the UDP header.
- u := header.UDP(buf[header.IPv4MinimumSize:])
- u.Encode(&header.UDPFields{
- SrcPort: h.srcAddr.Port,
- DstPort: h.dstAddr.Port,
- Length: uint16(header.UDPMinimumSize + len(payload)),
- })
-
- // Calculate the UDP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u)))
-
- // Calculate the UDP checksum and set it.
- xsum = header.Checksum(payload, xsum)
- u.SetChecksum(^u.CalculateChecksum(xsum))
-
- return buf
-}
-
-func newPayload() []byte {
- return newMinPayload(30)
-}
-
-func newMinPayload(minSize int) []byte {
- b := make([]byte, minSize+rand.Intn(100))
- for i := range b {
- b[i] = byte(rand.Intn(256))
- }
- return b
-}
-
-func TestBindToDeviceOption(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- Clock: &faketime.NullClock{},
- })
-
- ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed; %s", err)
- }
- defer ep.Close()
-
- opts := stack.NICOptions{Name: "my_device"}
- if err := s.CreateNICWithOptions(321, loopback.New(), opts); err != nil {
- t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %s", opts, err)
- }
-
- // nicIDPtr is used instead of taking the address of NICID literals, which is
- // a compiler error.
- nicIDPtr := func(s tcpip.NICID) *tcpip.NICID {
- return &s
- }
-
- testActions := []struct {
- name string
- setBindToDevice *tcpip.NICID
- setBindToDeviceError tcpip.Error
- getBindToDevice int32
- }{
- {"GetDefaultValue", nil, nil, 0},
- {"BindToNonExistent", nicIDPtr(999), &tcpip.ErrUnknownDevice{}, 0},
- {"BindToExistent", nicIDPtr(321), nil, 321},
- {"UnbindToDevice", nicIDPtr(0), nil, 0},
- }
- for _, testAction := range testActions {
- t.Run(testAction.name, func(t *testing.T) {
- if testAction.setBindToDevice != nil {
- bindToDevice := int32(*testAction.setBindToDevice)
- if gotErr, wantErr := ep.SocketOptions().SetBindToDevice(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
- t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr)
- }
- }
- bindToDevice := ep.SocketOptions().GetBindToDevice()
- if bindToDevice != testAction.getBindToDevice {
- t.Errorf("got bindToDevice = %d, want = %d", bindToDevice, testAction.getBindToDevice)
- }
- })
- }
-}
-
-// testReadInternal sends a packet of the given test flow into the stack by
-// injecting it into the link endpoint. It then attempts to read it from the
-// UDP endpoint and depending on if this was expected to succeed verifies its
-// correctness including any additional checker functions provided.
-func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool, checkers ...checker.ControlMessagesChecker) {
- c.t.Helper()
-
- payload := newPayload()
- c.injectPacket(flow, payload, false)
-
- // Try to receive the data.
- we, ch := waiter.NewChannelEntry(nil)
- c.wq.EventRegister(&we, waiter.ReadableEvents)
- defer c.wq.EventUnregister(&we)
-
- // Take a snapshot of the stats to validate them at the end of the test.
- epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
-
- var buf bytes.Buffer
- res, err := c.ep.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true})
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
- // Wait for data to become available.
- select {
- case <-ch:
- res, err = c.ep.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true})
-
- default:
- if packetShouldBeDropped {
- return // expected to time out
- }
- c.t.Fatal("timed out waiting for data")
- }
- }
-
- if expectReadError && err != nil {
- c.checkEndpointReadStats(1, epstats, err)
- return
- }
-
- if err != nil {
- c.t.Fatal("Read failed:", err)
- }
-
- if packetShouldBeDropped {
- c.t.Fatalf("Read unexpectedly received data from %s", res.RemoteAddr.Addr)
- }
-
- // Check the read result.
- h := flow.header4Tuple(incoming)
- if diff := cmp.Diff(tcpip.ReadResult{
- Count: buf.Len(),
- Total: buf.Len(),
- RemoteAddr: tcpip.FullAddress{Addr: h.srcAddr.Addr},
- }, res, checker.IgnoreCmpPath(
- "ControlMessages", // ControlMessages will be checked later.
- "RemoteAddr.NIC",
- "RemoteAddr.Port",
- )); diff != "" {
- c.t.Fatalf("Read: unexpected result (-want +got):\n%s", diff)
- }
-
- // Check the payload.
- v := buf.Bytes()
- if !bytes.Equal(payload, v) {
- c.t.Fatalf("got payload = %x, want = %x", v, payload)
- }
-
- // Run any checkers against the ControlMessages.
- for _, f := range checkers {
- f(c.t, res.ControlMessages)
- }
-
- c.checkEndpointReadStats(1, epstats, err)
-}
-
-// testRead sends a packet of the given test flow into the stack by injecting it
-// into the link endpoint. It then reads it from the UDP endpoint and verifies
-// its correctness including any additional checker functions provided.
-func testRead(c *testContext, flow testFlow, checkers ...checker.ControlMessagesChecker) {
- c.t.Helper()
- testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */, checkers...)
-}
-
-// testFailingRead sends a packet of the given test flow into the stack by
-// injecting it into the link endpoint. It then tries to read it from the UDP
-// endpoint and expects this to fail.
-func testFailingRead(c *testContext, flow testFlow, expectReadError bool) {
- c.t.Helper()
- testReadInternal(c, flow, true /* packetShouldBeDropped */, expectReadError)
-}
-
-func TestBindEphemeralPort(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- if err := c.ep.Bind(tcpip.FullAddress{}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %s", err)
- }
-}
-
-func TestBindReservedPort(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %s", err)
- }
-
- addr, err := c.ep.GetLocalAddress()
- if err != nil {
- t.Fatalf("GetLocalAddress failed: %s", err)
- }
-
- // We can't bind the address reserved by the connected endpoint above.
- {
- ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- defer ep.Close()
- {
- err := ep.Bind(addr)
- if _, ok := err.(*tcpip.ErrPortInUse); !ok {
- t.Fatalf("got ep.Bind(...) = %s, want = %s", err, &tcpip.ErrPortInUse{})
- }
- }
- }
-
- func() {
- ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- defer ep.Close()
- // We can't bind ipv4-any on the port reserved by the connected endpoint
- // above, since the endpoint is dual-stack.
- {
- err := ep.Bind(tcpip.FullAddress{Port: addr.Port})
- if _, ok := err.(*tcpip.ErrPortInUse); !ok {
- t.Fatalf("got ep.Bind(...) = %s, want = %s", err, &tcpip.ErrPortInUse{})
- }
- }
- // We can bind an ipv4 address on this port, though.
- if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %s", err)
- }
- }()
-
- // Once the connected endpoint releases its port reservation, we are able to
- // bind ipv4-any once again.
- c.ep.Close()
- func() {
- ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- defer ep.Close()
- if err := ep.Bind(tcpip.FullAddress{Port: addr.Port}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %s", err)
- }
- }()
-}
-
-func TestV4ReadOnV6(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(unicastV4in6)
-
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- // Test acceptance.
- testRead(c, unicastV4in6)
-}
-
-func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(unicastV4in6)
-
- // Bind to v4 mapped wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: v4MappedWildcardAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- // Test acceptance.
- testRead(c, unicastV4in6)
-}
-
-func TestV4ReadOnBoundToV4Mapped(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(unicastV4in6)
-
- // Bind to local address.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- // Test acceptance.
- testRead(c, unicastV4in6)
-}
-
-func TestV6ReadOnV6(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(unicastV6)
-
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- // Test acceptance.
- testRead(c, unicastV6)
-}
-
-// TestV4ReadSelfSource checks that packets coming from a local IP address are
-// correctly dropped when handleLocal is true and not otherwise.
-func TestV4ReadSelfSource(t *testing.T) {
- for _, tt := range []struct {
- name string
- handleLocal bool
- wantErr tcpip.Error
- wantInvalidSource uint64
- }{
- {"HandleLocal", false, nil, 0},
- {"NoHandleLocal", true, &tcpip.ErrWouldBlock{}, 1},
- } {
- t.Run(tt.name, func(t *testing.T) {
- c := newDualTestContextWithHandleLocal(t, defaultMTU, tt.handleLocal)
- defer c.cleanup()
-
- c.createEndpointForFlow(unicastV4)
-
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- payload := newPayload()
- h := unicastV4.header4Tuple(incoming)
- h.srcAddr = h.dstAddr
-
- buf := c.buildV4Packet(payload, &h)
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != tt.wantInvalidSource {
- t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource)
- }
-
- if _, err := c.ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tt.wantErr {
- t.Errorf("got c.ep.Read = %s, want = %s", err, tt.wantErr)
- }
- })
- }
-}
-
-func TestV4ReadOnV4(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(unicastV4)
-
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- // Test acceptance.
- testRead(c, unicastV4)
-}
-
-// TestReadOnBoundToMulticast checks that an endpoint can bind to a multicast
-// address and receive data sent to that address.
-func TestReadOnBoundToMulticast(t *testing.T) {
- // FIXME(b/128189410): multicastV4in6 currently doesn't work as
- // AddMembershipOption doesn't handle V4in6 addresses.
- for _, flow := range []testFlow{multicastV4, multicastV6, multicastV6Only} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Bind to multicast address.
- mcastAddr := flow.mapAddrIfApplicable(flow.getMcastAddr())
- if err := c.ep.Bind(tcpip.FullAddress{Addr: mcastAddr, Port: stackPort}); err != nil {
- c.t.Fatal("Bind failed:", err)
- }
-
- // Join multicast group.
- ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: mcastAddr}
- if err := c.ep.SetSockOpt(&ifoptSet); err != nil {
- c.t.Fatalf("SetSockOpt(&%#v): %s", ifoptSet, err)
- }
-
- // Check that we receive multicast packets but not unicast or broadcast
- // ones.
- testRead(c, flow)
- testFailingRead(c, broadcast, false /* expectReadError */)
- testFailingRead(c, unicastV4, false /* expectReadError */)
- })
- }
-}
-
-// TestV4ReadOnBoundToBroadcast checks that an endpoint can bind to a broadcast
-// address and can receive only broadcast data.
-func TestV4ReadOnBoundToBroadcast(t *testing.T) {
- for _, flow := range []testFlow{broadcast, broadcastIn6} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Bind to broadcast address.
- bcastAddr := flow.mapAddrIfApplicable(broadcastAddr)
- if err := c.ep.Bind(tcpip.FullAddress{Addr: bcastAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- // Check that we receive broadcast packets but not unicast ones.
- testRead(c, flow)
- testFailingRead(c, unicastV4, false /* expectReadError */)
- })
- }
-}
-
-// TestReadFromMulticast checks that an endpoint will NOT receive a packet
-// that was sent with multicast SOURCE address.
-func TestReadFromMulticast(t *testing.T) {
- for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
- testFailingRead(c, flow, false /* expectReadError */)
- })
- }
-}
-
-// TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY
-// and receive broadcast and unicast data.
-func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) {
- for _, flow := range []testFlow{broadcast, broadcastIn6} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s (", err)
- }
-
- // Check that we receive both broadcast and unicast packets.
- testRead(c, flow)
- testRead(c, unicastV4)
- })
- }
-}
-
-// testFailingWrite sends a packet of the given test flow into the UDP endpoint
-// and verifies it fails with the provided error code.
-func testFailingWrite(c *testContext, flow testFlow, wantErr tcpip.Error) {
- c.t.Helper()
- // Take a snapshot of the stats to validate them at the end of the test.
- epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
- h := flow.header4Tuple(outgoing)
- writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr)
-
- var r bytes.Reader
- r.Reset(newPayload())
- _, gotErr := c.ep.Write(&r, tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port},
- })
- c.checkEndpointWriteStats(1, epstats, gotErr)
- if gotErr != wantErr {
- c.t.Fatalf("Write returned unexpected error: got %v, want %v", gotErr, wantErr)
- }
-}
-
-// testWrite sends a packet of the given test flow from the UDP endpoint to the
-// flow's destination address:port. It then receives it from the link endpoint
-// and verifies its correctness including any additional checker functions
-// provided.
-func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 {
- c.t.Helper()
- return testWriteAndVerifyInternal(c, flow, true, checkers...)
-}
-
-// testWriteWithoutDestination sends a packet of the given test flow from the
-// UDP endpoint without giving a destination address:port. It then receives it
-// from the link endpoint and verifies its correctness including any additional
-// checker functions provided.
-func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 {
- c.t.Helper()
- return testWriteAndVerifyInternal(c, flow, false, checkers...)
-}
-
-func testWriteNoVerify(c *testContext, flow testFlow, setDest bool) buffer.View {
- c.t.Helper()
- // Take a snapshot of the stats to validate them at the end of the test.
- epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
-
- writeOpts := tcpip.WriteOptions{}
- if setDest {
- h := flow.header4Tuple(outgoing)
- writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr)
- writeOpts = tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port},
- }
- }
- var r bytes.Reader
- payload := newPayload()
- r.Reset(payload)
- n, err := c.ep.Write(&r, writeOpts)
- if err != nil {
- c.t.Fatalf("Write failed: %s", err)
- }
- if n != int64(len(payload)) {
- c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
- }
- c.checkEndpointWriteStats(1, epstats, err)
- return payload
-}
-
-func testWriteAndVerifyInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 {
- c.t.Helper()
- payload := testWriteNoVerify(c, flow, setDest)
- // Received the packet and check the payload.
- b := c.getPacketAndVerify(flow, checkers...)
- var udpH header.UDP
- if flow.isV4() {
- udpH = header.IPv4(b).Payload()
- } else {
- udpH = header.IPv6(b).Payload()
- }
- if !bytes.Equal(payload, udpH.Payload()) {
- c.t.Fatalf("Bad payload: got %x, want %x", udpH.Payload(), payload)
- }
-
- return udpH.SourcePort()
-}
-
-func testDualWrite(c *testContext) uint16 {
- c.t.Helper()
-
- v4Port := testWrite(c, unicastV4in6)
- v6Port := testWrite(c, unicastV6)
- if v4Port != v6Port {
- c.t.Fatalf("expected v4 and v6 ports to be equal: got v4Port = %d, v6Port = %d", v4Port, v6Port)
- }
-
- return v4Port
-}
-
-func TestDualWriteUnbound(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- testDualWrite(c)
-}
-
-func TestDualWriteBoundToWildcard(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- p := testDualWrite(c)
- if p != stackPort {
- c.t.Fatalf("Bad port: got %v, want %v", p, stackPort)
- }
-}
-
-func TestDualWriteConnectedToV6(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- // Connect to v6 address.
- if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- testWrite(c, unicastV6)
-
- // Write to V4 mapped address.
- testFailingWrite(c, unicastV4in6, &tcpip.ErrNetworkUnreachable{})
- const want = 1
- if got := c.ep.Stats().(*tcpip.TransportEndpointStats).SendErrors.NoRoute.Value(); got != want {
- c.t.Fatalf("Endpoint stat not updated. got %d want %d", got, want)
- }
-}
-
-func TestDualWriteConnectedToV4Mapped(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- // Connect to v4 mapped address.
- if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- testWrite(c, unicastV4in6)
-
- // Write to v6 address.
- testFailingWrite(c, unicastV6, &tcpip.ErrInvalidEndpointState{})
-}
-
-func TestV4WriteOnV6Only(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(unicastV6Only)
-
- // Write to V4 mapped address.
- testFailingWrite(c, unicastV4in6, &tcpip.ErrNoRoute{})
-}
-
-func TestV6WriteOnBoundToV4Mapped(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- // Bind to v4 mapped address.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- // Write to v6 address.
- testFailingWrite(c, unicastV6, &tcpip.ErrInvalidEndpointState{})
-}
-
-func TestV6WriteOnConnected(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- // Connect to v6 address.
- if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %s", err)
- }
-
- testWriteWithoutDestination(c, unicastV6)
-}
-
-func TestV4WriteOnConnected(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- // Connect to v4 mapped address.
- if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %s", err)
- }
-
- testWriteWithoutDestination(c, unicastV4)
-}
-
-func TestWriteOnConnectedInvalidPort(t *testing.T) {
- protocols := map[string]tcpip.NetworkProtocolNumber{
- "ipv4": ipv4.ProtocolNumber,
- "ipv6": ipv6.ProtocolNumber,
- }
- for name, pn := range protocols {
- t.Run(name, func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(pn)
- if err := c.ep.Connect(tcpip.FullAddress{Addr: stackAddr, Port: invalidPort}); err != nil {
- c.t.Fatalf("Connect failed: %s", err)
- }
- writeOpts := tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: stackAddr, Port: invalidPort},
- }
- var r bytes.Reader
- payload := newPayload()
- r.Reset(payload)
- n, err := c.ep.Write(&r, writeOpts)
- if err != nil {
- c.t.Fatalf("c.ep.Write(...) = %s, want nil", err)
- }
- if got, want := n, int64(len(payload)); got != want {
- c.t.Fatalf("c.ep.Write(...) wrote %d bytes, want %d bytes", got, want)
- }
-
- {
- err := c.ep.LastError()
- if _, ok := err.(*tcpip.ErrConnectionRefused); !ok {
- c.t.Fatalf("expected c.ep.LastError() == ErrConnectionRefused, got: %+v", err)
- }
- }
- })
- }
-}
-
-// TestWriteOnBoundToV4Multicast checks that we can send packets out of a socket
-// that is bound to a V4 multicast address.
-func TestWriteOnBoundToV4Multicast(t *testing.T) {
- for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
- t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Bind to V4 mcast address.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastAddr, Port: stackPort}); err != nil {
- c.t.Fatal("Bind failed:", err)
- }
-
- testWrite(c, flow)
- })
- }
-}
-
-// TestWriteOnBoundToV4MappedMulticast checks that we can send packets out of a
-// socket that is bound to a V4-mapped multicast address.
-func TestWriteOnBoundToV4MappedMulticast(t *testing.T) {
- for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} {
- t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Bind to V4Mapped mcast address.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV4MappedAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- testWrite(c, flow)
- })
- }
-}
-
-// TestWriteOnBoundToV6Multicast checks that we can send packets out of a
-// socket that is bound to a V6 multicast address.
-func TestWriteOnBoundToV6Multicast(t *testing.T) {
- for _, flow := range []testFlow{unicastV6, multicastV6} {
- t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Bind to V6 mcast address.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- testWrite(c, flow)
- })
- }
-}
-
-// TestWriteOnBoundToV6Multicast checks that we can send packets out of a
-// V6-only socket that is bound to a V6 multicast address.
-func TestWriteOnBoundToV6OnlyMulticast(t *testing.T) {
- for _, flow := range []testFlow{unicastV6Only, multicastV6Only} {
- t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Bind to V6 mcast address.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- testWrite(c, flow)
- })
- }
-}
-
-// TestWriteOnBoundToBroadcast checks that we can send packets out of a
-// socket that is bound to the broadcast address.
-func TestWriteOnBoundToBroadcast(t *testing.T) {
- for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
- t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Bind to V4 broadcast address.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastAddr, Port: stackPort}); err != nil {
- c.t.Fatal("Bind failed:", err)
- }
-
- testWrite(c, flow)
- })
- }
-}
-
-// TestWriteOnBoundToV4MappedBroadcast checks that we can send packets out of a
-// socket that is bound to the V4-mapped broadcast address.
-func TestWriteOnBoundToV4MappedBroadcast(t *testing.T) {
- for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} {
- t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Bind to V4Mapped mcast address.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastV4MappedAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- testWrite(c, flow)
- })
- }
-}
-
-func TestReadIncrementsPacketsReceived(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- // Create IPv4 UDP endpoint
- c.createEndpoint(ipv6.ProtocolNumber)
-
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- testRead(c, unicastV4)
-
- var want uint64 = 1
- if got := c.s.Stats().UDP.PacketsReceived.Value(); got != want {
- c.t.Fatalf("Read did not increment PacketsReceived: got %v, want %v", got, want)
- }
-}
-
-func TestReadIPPacketInfo(t *testing.T) {
- tests := []struct {
- name string
- proto tcpip.NetworkProtocolNumber
- flow testFlow
- checker func(tcpip.NICID) checker.ControlMessagesChecker
- }{
- {
- name: "IPv4 unicast",
- proto: header.IPv4ProtocolNumber,
- flow: unicastV4,
- checker: func(id tcpip.NICID) checker.ControlMessagesChecker {
- return checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
- NIC: id,
- LocalAddr: stackAddr,
- DestinationAddr: stackAddr,
- })
- },
- },
- {
- name: "IPv4 multicast",
- proto: header.IPv4ProtocolNumber,
- flow: multicastV4,
- checker: func(id tcpip.NICID) checker.ControlMessagesChecker {
- return checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
- NIC: id,
- // TODO(gvisor.dev/issue/3556): Check for a unicast address.
- LocalAddr: multicastAddr,
- DestinationAddr: multicastAddr,
- })
- },
- },
- {
- name: "IPv4 broadcast",
- proto: header.IPv4ProtocolNumber,
- flow: broadcast,
- checker: func(id tcpip.NICID) checker.ControlMessagesChecker {
- return checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
- NIC: id,
- // TODO(gvisor.dev/issue/3556): Check for a unicast address.
- LocalAddr: broadcastAddr,
- DestinationAddr: broadcastAddr,
- })
- },
- },
- {
- name: "IPv6 unicast",
- proto: header.IPv6ProtocolNumber,
- flow: unicastV6,
- checker: func(id tcpip.NICID) checker.ControlMessagesChecker {
- return checker.ReceiveIPv6PacketInfo(tcpip.IPv6PacketInfo{
- NIC: id,
- Addr: stackV6Addr,
- })
- },
- },
- {
- name: "IPv6 multicast",
- proto: header.IPv6ProtocolNumber,
- flow: multicastV6,
- checker: func(id tcpip.NICID) checker.ControlMessagesChecker {
- return checker.ReceiveIPv6PacketInfo(tcpip.IPv6PacketInfo{
- NIC: id,
- Addr: multicastV6Addr,
- })
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(test.proto)
-
- bindAddr := tcpip.FullAddress{Port: stackPort}
- if err := c.ep.Bind(bindAddr); err != nil {
- t.Fatalf("Bind(%+v): %s", bindAddr, err)
- }
-
- if test.flow.isMulticast() {
- ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()}
- if err := c.ep.SetSockOpt(&ifoptSet); err != nil {
- c.t.Fatalf("SetSockOpt(&%#v): %s:", ifoptSet, err)
- }
- }
-
- switch f := test.flow.netProto(); f {
- case header.IPv4ProtocolNumber:
- c.ep.SocketOptions().SetReceivePacketInfo(true)
- case header.IPv6ProtocolNumber:
- c.ep.SocketOptions().SetIPv6ReceivePacketInfo(true)
- default:
- t.Fatalf("unhandled protocol number = %d", f)
- }
-
- testRead(c, test.flow, test.checker(c.nicID))
-
- if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 {
- t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got)
- }
- })
- }
-}
-
-func TestReadRecvOriginalDstAddr(t *testing.T) {
- tests := []struct {
- name string
- proto tcpip.NetworkProtocolNumber
- flow testFlow
- expectedOriginalDstAddr tcpip.FullAddress
- }{
- {
- name: "IPv4 unicast",
- proto: header.IPv4ProtocolNumber,
- flow: unicastV4,
- expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: stackAddr, Port: stackPort},
- },
- {
- name: "IPv4 multicast",
- proto: header.IPv4ProtocolNumber,
- flow: multicastV4,
- // This should actually be a unicast address assigned to the interface.
- //
- // TODO(gvisor.dev/issue/3556): This check is validating incorrect
- // behaviour. We still include the test so that once the bug is
- // resolved, this test will start to fail and the individual tasked
- // with fixing this bug knows to also fix this test :).
- expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: multicastAddr, Port: stackPort},
- },
- {
- name: "IPv4 broadcast",
- proto: header.IPv4ProtocolNumber,
- flow: broadcast,
- // This should actually be a unicast address assigned to the interface.
- //
- // TODO(gvisor.dev/issue/3556): This check is validating incorrect
- // behaviour. We still include the test so that once the bug is
- // resolved, this test will start to fail and the individual tasked
- // with fixing this bug knows to also fix this test :).
- expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: broadcastAddr, Port: stackPort},
- },
- {
- name: "IPv6 unicast",
- proto: header.IPv6ProtocolNumber,
- flow: unicastV6,
- expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: stackV6Addr, Port: stackPort},
- },
- {
- name: "IPv6 multicast",
- proto: header.IPv6ProtocolNumber,
- flow: multicastV6,
- // This should actually be a unicast address assigned to the interface.
- //
- // TODO(gvisor.dev/issue/3556): This check is validating incorrect
- // behaviour. We still include the test so that once the bug is
- // resolved, this test will start to fail and the individual tasked
- // with fixing this bug knows to also fix this test :).
- expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: multicastV6Addr, Port: stackPort},
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(test.proto)
-
- bindAddr := tcpip.FullAddress{Port: stackPort}
- if err := c.ep.Bind(bindAddr); err != nil {
- t.Fatalf("Bind(%#v): %s", bindAddr, err)
- }
-
- if test.flow.isMulticast() {
- ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()}
- if err := c.ep.SetSockOpt(&ifoptSet); err != nil {
- c.t.Fatalf("SetSockOpt(&%#v): %s:", ifoptSet, err)
- }
- }
-
- c.ep.SocketOptions().SetReceiveOriginalDstAddress(true)
-
- testRead(c, test.flow, checker.ReceiveOriginalDstAddr(test.expectedOriginalDstAddr))
-
- if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 {
- t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got)
- }
- })
- }
-}
-
-func TestWriteIncrementsPacketsSent(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- testDualWrite(c)
-
- var want uint64 = 2
- if got := c.s.Stats().UDP.PacketsSent.Value(); got != want {
- c.t.Fatalf("Write did not increment PacketsSent: got %v, want %v", got, want)
- }
-}
-
-func TestNoChecksum(t *testing.T) {
- for _, flow := range []testFlow{unicastV4, unicastV6} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Disable the checksum generation.
- c.ep.SocketOptions().SetNoChecksum(true)
- // This option is effective on IPv4 only.
- testWrite(c, flow, checker.UDP(checker.NoChecksum(flow.isV4())))
-
- // Enable the checksum generation.
- c.ep.SocketOptions().SetNoChecksum(false)
- testWrite(c, flow, checker.UDP(checker.NoChecksum(false)))
- })
- }
-}
-
-var _ stack.NetworkInterface = (*testInterface)(nil)
-
-type testInterface struct {
- stack.NetworkInterface
-}
-
-func (*testInterface) ID() tcpip.NICID {
- return 0
-}
-
-func (*testInterface) Enabled() bool {
- return true
-}
-
-func TestTTL(t *testing.T) {
- for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- const multicastTTL = 42
- if err := c.ep.SetSockOptInt(tcpip.MulticastTTLOption, multicastTTL); err != nil {
- c.t.Fatalf("SetSockOptInt failed: %s", err)
- }
-
- var wantTTL uint8
- if flow.isMulticast() {
- wantTTL = multicastTTL
- } else {
- var p stack.NetworkProtocolFactory
- var n tcpip.NetworkProtocolNumber
- if flow.isV4() {
- p = ipv4.NewProtocol
- n = ipv4.ProtocolNumber
- } else {
- p = ipv6.NewProtocol
- n = ipv6.ProtocolNumber
- }
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{p},
- Clock: &faketime.NullClock{},
- })
- ep := s.NetworkProtocolInstance(n).NewEndpoint(&testInterface{}, nil)
- wantTTL = ep.DefaultTTL()
- ep.Close()
- }
-
- testWrite(c, flow, checker.TTL(wantTTL))
- })
- }
-}
-
-func TestSetTTL(t *testing.T) {
- for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} {
- t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- if err := c.ep.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil {
- c.t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err)
- }
-
- testWrite(c, flow, checker.TTL(wantTTL))
- })
- }
- })
- }
-}
-
-var v4PacketFlows = [...]testFlow{unicastV4, multicastV4, broadcast, unicastV4in6, multicastV4in6, broadcastIn6}
-
-func TestSetTOS(t *testing.T) {
- for _, flow := range v4PacketFlows {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- const tos = testTOS
- v, err := c.ep.GetSockOptInt(tcpip.IPv4TOSOption)
- if err != nil {
- c.t.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err)
- }
- // Test for expected default value.
- if v != 0 {
- c.t.Errorf("got GetSockOptInt(IPv4TOSOption) = 0x%x, want = 0x%x", v, 0)
- }
-
- if err := c.ep.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil {
- c.t.Errorf("SetSockOptInt(IPv4TOSOption, 0x%x) failed: %s", tos, err)
- }
-
- v, err = c.ep.GetSockOptInt(tcpip.IPv4TOSOption)
- if err != nil {
- c.t.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err)
- }
-
- if v != tos {
- c.t.Errorf("got GetSockOptInt(IPv4TOSOption) = 0x%x, want = 0x%x", v, tos)
- }
-
- testWrite(c, flow, checker.TOS(tos, 0))
- })
- }
-}
-
-var v6PacketFlows = [...]testFlow{unicastV6, unicastV6Only, multicastV6}
-
-func TestSetTClass(t *testing.T) {
- for _, flow := range v6PacketFlows {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- const tClass = testTOS
- v, err := c.ep.GetSockOptInt(tcpip.IPv6TrafficClassOption)
- if err != nil {
- c.t.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err)
- }
- // Test for expected default value.
- if v != 0 {
- c.t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, 0)
- }
-
- if err := c.ep.SetSockOptInt(tcpip.IPv6TrafficClassOption, tClass); err != nil {
- c.t.Errorf("SetSockOptInt(IPv6TrafficClassOption, 0x%x) failed: %s", tClass, err)
- }
-
- v, err = c.ep.GetSockOptInt(tcpip.IPv6TrafficClassOption)
- if err != nil {
- c.t.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err)
- }
-
- if v != tClass {
- c.t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, tClass)
- }
-
- // The header getter for TClass is called TOS, so use that checker.
- testWrite(c, flow, checker.TOS(tClass, 0))
- })
- }
-}
-
-func TestReceiveTosTClass(t *testing.T) {
- const RcvTOSOpt = "ReceiveTosOption"
- const RcvTClassOpt = "ReceiveTClassOption"
-
- testCases := []struct {
- name string
- tests []testFlow
- }{
- {
- name: RcvTOSOpt,
- tests: v4PacketFlows[:],
- },
- {
- name: RcvTClassOpt,
- tests: v6PacketFlows[:],
- },
- }
- for _, testCase := range testCases {
- for _, flow := range testCase.tests {
- t.Run(fmt.Sprintf("%s:flow:%s", testCase.name, flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
- name := testCase.name
-
- if flow.isMulticast() {
- netProto := flow.netProto()
- addr := flow.getMcastAddr()
- if err := c.s.JoinGroup(netProto, c.nicID, addr); err != nil {
- c.t.Fatalf("JoinGroup(%d, %d, %s): %s", netProto, c.nicID, addr, err)
- }
- }
-
- var optionGetter func() bool
- var optionSetter func(bool)
- switch name {
- case RcvTOSOpt:
- optionGetter = c.ep.SocketOptions().GetReceiveTOS
- optionSetter = c.ep.SocketOptions().SetReceiveTOS
- case RcvTClassOpt:
- optionGetter = c.ep.SocketOptions().GetReceiveTClass
- optionSetter = c.ep.SocketOptions().SetReceiveTClass
- default:
- t.Fatalf("unkown test variant: %s", name)
- }
-
- // Verify that setting and reading the option works.
- v := optionGetter()
- // Test for expected default value.
- if v != false {
- c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, v, false)
- }
-
- const want = true
- optionSetter(want)
-
- got := optionGetter()
- if got != want {
- c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, got, want)
- }
-
- // Verify that the correct received TOS or TClass is handed through as
- // ancillary data to the ControlMessages struct.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
- switch name {
- case RcvTClassOpt:
- testRead(c, flow, checker.ReceiveTClass(testTOS))
- case RcvTOSOpt:
- testRead(c, flow, checker.ReceiveTOS(testTOS))
- default:
- t.Fatalf("unknown test variant: %s", name)
- }
- })
- }
- }
-}
-
-func TestMulticastInterfaceOption(t *testing.T) {
- for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- for _, bindTyp := range []string{"bound", "unbound"} {
- t.Run(bindTyp, func(t *testing.T) {
- for _, optTyp := range []string{"use local-addr", "use NICID", "use local-addr and NIC"} {
- t.Run(optTyp, func(t *testing.T) {
- h := flow.header4Tuple(outgoing)
- mcastAddr := h.dstAddr.Addr
- localIfAddr := h.srcAddr.Addr
-
- var ifoptSet tcpip.MulticastInterfaceOption
- switch optTyp {
- case "use local-addr":
- ifoptSet.InterfaceAddr = localIfAddr
- case "use NICID":
- ifoptSet.NIC = 1
- case "use local-addr and NIC":
- ifoptSet.InterfaceAddr = localIfAddr
- ifoptSet.NIC = 1
- default:
- t.Fatal("unknown test variant")
- }
-
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(flow.sockProto())
-
- if bindTyp == "bound" {
- // Bind the socket by connecting to the multicast address.
- // This may have an influence on how the multicast interface
- // is set.
- addr := tcpip.FullAddress{
- Addr: flow.mapAddrIfApplicable(mcastAddr),
- Port: stackPort,
- }
- if err := c.ep.Connect(addr); err != nil {
- c.t.Fatalf("Connect failed: %s", err)
- }
- }
-
- if err := c.ep.SetSockOpt(&ifoptSet); err != nil {
- c.t.Fatalf("SetSockOpt(&%#v): %s", ifoptSet, err)
- }
-
- // Verify multicast interface addr and NIC were set correctly.
- // Note that NIC must be 1 since this is our outgoing interface.
- var ifoptGot tcpip.MulticastInterfaceOption
- if err := c.ep.GetSockOpt(&ifoptGot); err != nil {
- c.t.Fatalf("GetSockOpt(&%T): %s", ifoptGot, err)
- } else if ifoptWant := (tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}); ifoptGot != ifoptWant {
- c.t.Errorf("got multicast interface option = %#v, want = %#v", ifoptGot, ifoptWant)
- }
- })
- }
- })
- }
- })
- }
-}
-
-// TestV4UnknownDestination verifies that we generate an ICMPv4 Destination
-// Unreachable message when a udp datagram is received on ports for which there
-// is no bound udp socket.
-func TestV4UnknownDestination(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- testCases := []struct {
- flow testFlow
- icmpRequired bool
- // largePayload if true, will result in a payload large enough
- // so that the final generated IPv4 packet is larger than
- // header.IPv4MinimumProcessableDatagramSize.
- largePayload bool
- // badChecksum if true, will set an invalid checksum in the
- // header.
- badChecksum bool
- }{
- {unicastV4, true, false, false},
- {unicastV4, true, true, false},
- {unicastV4, false, false, true},
- {unicastV4, false, true, true},
- {multicastV4, false, false, false},
- {multicastV4, false, true, false},
- {broadcast, false, false, false},
- {broadcast, false, true, false},
- }
- checksumErrors := uint64(0)
- for _, tc := range testCases {
- t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t badChecksum:%t", tc.flow, tc.icmpRequired, tc.largePayload, tc.badChecksum), func(t *testing.T) {
- payload := newPayload()
- if tc.largePayload {
- payload = newMinPayload(576)
- }
- c.injectPacket(tc.flow, payload, tc.badChecksum)
- if tc.badChecksum {
- checksumErrors++
- if got, want := c.s.Stats().UDP.ChecksumErrors.Value(), checksumErrors; got != want {
- t.Fatalf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
- }
- }
- if !tc.icmpRequired {
- if p, ok := c.linkEP.Read(); ok {
- t.Fatalf("unexpected packet received: %+v", p)
- }
- return
- }
-
- // ICMP required.
- p, ok := c.linkEP.Read()
- if !ok {
- t.Fatalf("packet wasn't written out")
- return
- }
-
- vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
- pkt := vv.ToView()
- if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want {
- t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
- }
-
- hdr := header.IPv4(pkt)
- checker.IPv4(t, hdr, checker.ICMPv4(
- checker.ICMPv4Type(header.ICMPv4DstUnreachable),
- checker.ICMPv4Code(header.ICMPv4PortUnreachable)))
-
- // We need to compare the included data part of the UDP packet that is in
- // the ICMP packet with the matching original data.
- icmpPkt := header.ICMPv4(hdr.Payload())
- payloadIPHeader := header.IPv4(icmpPkt.Payload())
- incomingHeaderLength := header.IPv4MinimumSize + header.UDPMinimumSize
- wantLen := len(payload)
- if tc.largePayload {
- // To work out the data size we need to simulate what the sender would
- // have done. The wanted size is the total available minus the sum of
- // the headers in the UDP AND ICMP packets, given that we know the test
- // had only a minimal IP header but the ICMP sender will have allowed
- // for a maximally sized packet header.
- wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength
- }
-
- // In the case of large payloads the IP packet may be truncated. Update
- // the length field before retrieving the udp datagram payload.
- // Add back the two headers within the payload.
- payloadIPHeader.SetTotalLength(uint16(wantLen + incomingHeaderLength))
-
- origDgram := header.UDP(payloadIPHeader.Payload())
- if got, want := len(origDgram.Payload()), wantLen; got != want {
- t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
- }
- if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
- t.Fatalf("unexpected payload got: %d, want: %d", got, want)
- }
- })
- }
-}
-
-// TestV6UnknownDestination verifies that we generate an ICMPv6 Destination
-// Unreachable message when a udp datagram is received on ports for which there
-// is no bound udp socket.
-func TestV6UnknownDestination(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- testCases := []struct {
- flow testFlow
- icmpRequired bool
- // largePayload if true will result in a payload large enough to
- // create an IPv6 packet > header.IPv6MinimumMTU bytes.
- largePayload bool
- // badChecksum if true, will set an invalid checksum in the
- // header.
- badChecksum bool
- }{
- {unicastV6, true, false, false},
- {unicastV6, true, true, false},
- {unicastV6, false, false, true},
- {unicastV6, false, true, true},
- {multicastV6, false, false, false},
- {multicastV6, false, true, false},
- }
- checksumErrors := uint64(0)
- for _, tc := range testCases {
- t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t badChecksum:%t", tc.flow, tc.icmpRequired, tc.largePayload, tc.badChecksum), func(t *testing.T) {
- payload := newPayload()
- if tc.largePayload {
- payload = newMinPayload(1280)
- }
- c.injectPacket(tc.flow, payload, tc.badChecksum)
- if tc.badChecksum {
- checksumErrors++
- if got, want := c.s.Stats().UDP.ChecksumErrors.Value(), checksumErrors; got != want {
- t.Fatalf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
- }
- }
- if !tc.icmpRequired {
- if p, ok := c.linkEP.Read(); ok {
- t.Fatalf("unexpected packet received: %+v", p)
- }
- return
- }
-
- // ICMP required.
- p, ok := c.linkEP.Read()
- if !ok {
- t.Fatalf("packet wasn't written out")
- return
- }
-
- vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
- pkt := vv.ToView()
- if got, want := len(pkt), header.IPv6MinimumMTU; got > want {
- t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
- }
-
- hdr := header.IPv6(pkt)
- checker.IPv6(t, hdr, checker.ICMPv6(
- checker.ICMPv6Type(header.ICMPv6DstUnreachable),
- checker.ICMPv6Code(header.ICMPv6PortUnreachable)))
-
- icmpPkt := header.ICMPv6(hdr.Payload())
- payloadIPHeader := header.IPv6(icmpPkt.Payload())
- wantLen := len(payload)
- if tc.largePayload {
- wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize
- }
- // In case of large payloads the IP packet may be truncated. Update
- // the length field before retrieving the udp datagram payload.
- payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize))
-
- origDgram := header.UDP(payloadIPHeader.Payload())
- if got, want := len(origDgram.Payload()), wantLen; got != want {
- t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
- }
- if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
- t.Fatalf("unexpected payload got: %v, want: %v", got, want)
- }
- })
- }
-}
-
-// TestIncrementMalformedPacketsReceived verifies if the malformed received
-// global and endpoint stats are incremented.
-func TestIncrementMalformedPacketsReceived(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- payload := newPayload()
- h := unicastV6.header4Tuple(incoming)
- buf := c.buildV6Packet(payload, &h)
-
- // Invalidate the UDP header length field.
- u := header.UDP(buf[header.IPv6MinimumSize:])
- u.SetLength(u.Length() + 1)
-
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- const want = 1
- if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want {
- t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %d, want = %d", got, want)
- }
- if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want)
- }
-}
-
-// TestShortHeader verifies that when a packet with a too-short UDP header is
-// received, the malformed received global stat gets incremented.
-func TestShortHeader(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- h := unicastV6.header4Tuple(incoming)
-
- // Allocate a buffer for an IPv6 and too-short UDP header.
- const udpSize = header.UDPMinimumSize - 1
- buf := buffer.NewView(header.IPv6MinimumSize + udpSize)
- // Initialize the IP header.
- ip := header.IPv6(buf)
- ip.Encode(&header.IPv6Fields{
- TrafficClass: testTOS,
- PayloadLength: uint16(udpSize),
- TransportProtocol: udp.ProtocolNumber,
- HopLimit: 65,
- SrcAddr: h.srcAddr.Addr,
- DstAddr: h.dstAddr.Addr,
- })
-
- // Initialize the UDP header.
- udpHdr := header.UDP(buffer.NewView(header.UDPMinimumSize))
- udpHdr.Encode(&header.UDPFields{
- SrcPort: h.srcAddr.Port,
- DstPort: h.dstAddr.Port,
- Length: header.UDPMinimumSize,
- })
- // Calculate the UDP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(udpHdr)))
- udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum))
- // Copy all but the last byte of the UDP header into the packet.
- copy(buf[header.IPv6MinimumSize:], udpHdr)
-
- // Inject packet.
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- if got, want := c.s.Stats().NICs.MalformedL4RcvdPackets.Value(), uint64(1); got != want {
- t.Errorf("got c.s.Stats().NIC.MalformedL4RcvdPackets.Value() = %d, want = %d", got, want)
- }
-}
-
-// TestBadChecksumErrors verifies if a checksum error is detected,
-// global and endpoint stats are incremented.
-func TestBadChecksumErrors(t *testing.T) {
- for _, flow := range []testFlow{unicastV4, unicastV6} {
- t.Run(flow.String(), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(flow.sockProto())
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- payload := newPayload()
- c.injectPacket(flow, payload, true /* badChecksum */)
-
- const want = 1
- if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
- t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
- }
- if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
- }
- })
- }
-}
-
-// TestPayloadModifiedV4 verifies if a checksum error is detected,
-// global and endpoint stats are incremented.
-func TestPayloadModifiedV4(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv4.ProtocolNumber)
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- payload := newPayload()
- h := unicastV4.header4Tuple(incoming)
- buf := c.buildV4Packet(payload, &h)
- // Modify the payload so that the checksum value in the UDP header will be
- // incorrect.
- buf[len(buf)-1]++
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- const want = 1
- if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
- t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
- }
- if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
- }
-}
-
-// TestPayloadModifiedV6 verifies if a checksum error is detected,
-// global and endpoint stats are incremented.
-func TestPayloadModifiedV6(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- payload := newPayload()
- h := unicastV6.header4Tuple(incoming)
- buf := c.buildV6Packet(payload, &h)
- // Modify the payload so that the checksum value in the UDP header will be
- // incorrect.
- buf[len(buf)-1]++
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- const want = 1
- if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
- t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
- }
- if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
- }
-}
-
-// TestChecksumZeroV4 verifies if the checksum value is zero, global and
-// endpoint states are *not* incremented (UDP checksum is optional on IPv4).
-func TestChecksumZeroV4(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv4.ProtocolNumber)
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- payload := newPayload()
- h := unicastV4.header4Tuple(incoming)
- buf := c.buildV4Packet(payload, &h)
- // Set the checksum field in the UDP header to zero.
- u := header.UDP(buf[header.IPv4MinimumSize:])
- u.SetChecksum(0)
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- const want = 0
- if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
- t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
- }
- if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
- }
-}
-
-// TestChecksumZeroV6 verifies if the checksum value is zero, global and
-// endpoint states are incremented (UDP checksum is *not* optional on IPv6).
-func TestChecksumZeroV6(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- payload := newPayload()
- h := unicastV6.header4Tuple(incoming)
- buf := c.buildV6Packet(payload, &h)
- // Set the checksum field in the UDP header to zero.
- u := header.UDP(buf[header.IPv6MinimumSize:])
- u.SetChecksum(0)
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- const want = 1
- if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
- t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
- }
- if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
- }
-}
-
-// TestShutdownRead verifies endpoint read shutdown and error
-// stats increment on packet receive.
-func TestShutdownRead(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %s", err)
- }
-
- if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil {
- t.Fatalf("Shutdown failed: %s", err)
- }
-
- testFailingRead(c, unicastV6, true /* expectReadError */)
-
- var want uint64 = 1
- if got := c.s.Stats().UDP.ReceiveBufferErrors.Value(); got != want {
- t.Errorf("got stats.UDP.ReceiveBufferErrors.Value() = %v, want = %v", got, want)
- }
- if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ClosedReceiver.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.ClosedReceiver stats = %v, want = %v", got, want)
- }
-}
-
-// TestShutdownWrite verifies endpoint write shutdown and error
-// stats increment on packet write.
-func TestShutdownWrite(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %s", err)
- }
-
- if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %s", err)
- }
-
- testFailingWrite(c, unicastV6, &tcpip.ErrClosedForSend{})
-}
-
-func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportEndpointStats, err tcpip.Error) {
- got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
- switch err.(type) {
- case nil:
- want.PacketsSent.IncrementBy(incr)
- case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue:
- want.WriteErrors.InvalidArgs.IncrementBy(incr)
- case *tcpip.ErrClosedForSend:
- want.WriteErrors.WriteClosed.IncrementBy(incr)
- case *tcpip.ErrInvalidEndpointState:
- want.WriteErrors.InvalidEndpointState.IncrementBy(incr)
- case *tcpip.ErrNoRoute, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable:
- want.SendErrors.NoRoute.IncrementBy(incr)
- default:
- want.SendErrors.SendToNetworkFailed.IncrementBy(incr)
- }
- if got != want {
- c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want)
- }
-}
-
-func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEndpointStats, err tcpip.Error) {
- got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
- switch err.(type) {
- case nil, *tcpip.ErrWouldBlock:
- case *tcpip.ErrClosedForReceive:
- want.ReadErrors.ReadClosed.IncrementBy(incr)
- default:
- c.t.Errorf("Endpoint error missing stats update err %v", err)
- }
- if got != want {
- c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want)
- }
-}
-
-func TestOutgoingSubnetBroadcast(t *testing.T) {
- const nicID1 = 1
-
- ipv4Addr := tcpip.AddressWithPrefix{
- Address: "\xc0\xa8\x01\x3a",
- PrefixLen: 24,
- }
- ipv4Subnet := ipv4Addr.Subnet()
- ipv4SubnetBcast := ipv4Subnet.Broadcast()
- ipv4Gateway := testutil.MustParse4("192.168.1.1")
- ipv4AddrPrefix31 := tcpip.AddressWithPrefix{
- Address: "\xc0\xa8\x01\x3a",
- PrefixLen: 31,
- }
- ipv4Subnet31 := ipv4AddrPrefix31.Subnet()
- ipv4Subnet31Bcast := ipv4Subnet31.Broadcast()
- ipv4AddrPrefix32 := tcpip.AddressWithPrefix{
- Address: "\xc0\xa8\x01\x3a",
- PrefixLen: 32,
- }
- ipv4Subnet32 := ipv4AddrPrefix32.Subnet()
- ipv4Subnet32Bcast := ipv4Subnet32.Broadcast()
- ipv6Addr := tcpip.AddressWithPrefix{
- Address: "\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
- PrefixLen: 64,
- }
- ipv6Subnet := ipv6Addr.Subnet()
- ipv6SubnetBcast := ipv6Subnet.Broadcast()
- remNetAddr := tcpip.AddressWithPrefix{
- Address: "\x64\x0a\x7b\x18",
- PrefixLen: 24,
- }
- remNetSubnet := remNetAddr.Subnet()
- remNetSubnetBcast := remNetSubnet.Broadcast()
-
- tests := []struct {
- name string
- nicAddr tcpip.ProtocolAddress
- routes []tcpip.Route
- remoteAddr tcpip.Address
- requiresBroadcastOpt bool
- }{
- {
- name: "IPv4 Broadcast to local subnet",
- nicAddr: tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: ipv4Addr,
- },
- routes: []tcpip.Route{
- {
- Destination: ipv4Subnet,
- NIC: nicID1,
- },
- },
- remoteAddr: ipv4SubnetBcast,
- requiresBroadcastOpt: true,
- },
- {
- name: "IPv4 Broadcast to local /31 subnet",
- nicAddr: tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: ipv4AddrPrefix31,
- },
- routes: []tcpip.Route{
- {
- Destination: ipv4Subnet31,
- NIC: nicID1,
- },
- },
- remoteAddr: ipv4Subnet31Bcast,
- requiresBroadcastOpt: false,
- },
- {
- name: "IPv4 Broadcast to local /32 subnet",
- nicAddr: tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: ipv4AddrPrefix32,
- },
- routes: []tcpip.Route{
- {
- Destination: ipv4Subnet32,
- NIC: nicID1,
- },
- },
- remoteAddr: ipv4Subnet32Bcast,
- requiresBroadcastOpt: false,
- },
- // IPv6 has no notion of a broadcast.
- {
- name: "IPv6 'Broadcast' to local subnet",
- nicAddr: tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: ipv6Addr,
- },
- routes: []tcpip.Route{
- {
- Destination: ipv6Subnet,
- NIC: nicID1,
- },
- },
- remoteAddr: ipv6SubnetBcast,
- requiresBroadcastOpt: false,
- },
- {
- name: "IPv4 Broadcast to remote subnet",
- nicAddr: tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: ipv4Addr,
- },
- routes: []tcpip.Route{
- {
- Destination: remNetSubnet,
- Gateway: ipv4Gateway,
- NIC: nicID1,
- },
- },
- remoteAddr: remNetSubnetBcast,
- // TODO(gvisor.dev/issue/3938): Once we support marking a route as
- // broadcast, this test should require the broadcast option to be set.
- requiresBroadcastOpt: false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- Clock: &faketime.NullClock{},
- })
- e := channel.New(0, defaultMTU, "")
- if err := s.CreateNIC(nicID1, e); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
- }
- if err := s.AddProtocolAddress(nicID1, test.nicAddr, stack.AddressProperties{}); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, test.nicAddr, err)
- }
-
- s.SetRouteTable(test.routes)
-
- var netProto tcpip.NetworkProtocolNumber
- switch l := len(test.remoteAddr); l {
- case header.IPv4AddressSize:
- netProto = header.IPv4ProtocolNumber
- case header.IPv6AddressSize:
- netProto = header.IPv6ProtocolNumber
- default:
- t.Fatalf("got unexpected address length = %d bytes", l)
- }
-
- wq := waiter.Queue{}
- ep, err := s.NewEndpoint(udp.ProtocolNumber, netProto, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netProto, err)
- }
- defer ep.Close()
-
- var r bytes.Reader
- data := []byte{1, 2, 3, 4}
- to := tcpip.FullAddress{
- Addr: test.remoteAddr,
- Port: 80,
- }
- opts := tcpip.WriteOptions{To: &to}
- expectedErrWithoutBcastOpt := func(err tcpip.Error) tcpip.Error {
- if _, ok := err.(*tcpip.ErrBroadcastDisabled); ok {
- return nil
- }
- return &tcpip.ErrBroadcastDisabled{}
- }
- if !test.requiresBroadcastOpt {
- expectedErrWithoutBcastOpt = nil
- }
-
- r.Reset(data)
- {
- n, err := ep.Write(&r, opts)
- if expectedErrWithoutBcastOpt != nil {
- if want := expectedErrWithoutBcastOpt(err); want != nil {
- t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, want)
- }
- } else if err != nil {
- t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err)
- }
- }
-
- ep.SocketOptions().SetBroadcast(true)
-
- r.Reset(data)
- if n, err := ep.Write(&r, opts); err != nil {
- t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err)
- }
-
- ep.SocketOptions().SetBroadcast(false)
-
- r.Reset(data)
- {
- n, err := ep.Write(&r, opts)
- if expectedErrWithoutBcastOpt != nil {
- if want := expectedErrWithoutBcastOpt(err); want != nil {
- t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, want)
- }
- } else if err != nil {
- t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err)
- }
- }
- })
- }
-}