summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/BUILD34
-rw-r--r--pkg/tcpip/adapters/gonet/BUILD37
-rwxr-xr-x[-rw-r--r--]pkg/tcpip/adapters/gonet/gonet.go0
-rwxr-xr-xpkg/tcpip/adapters/gonet/gonet_state_autogen.go3
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go683
-rw-r--r--pkg/tcpip/buffer/BUILD19
-rwxr-xr-xpkg/tcpip/buffer/buffer_state_autogen.go24
-rw-r--r--pkg/tcpip/buffer/view_test.go235
-rw-r--r--pkg/tcpip/checker/BUILD16
-rw-r--r--pkg/tcpip/checker/checker.go842
-rw-r--r--pkg/tcpip/hash/jenkins/BUILD18
-rwxr-xr-xpkg/tcpip/hash/jenkins/jenkins_state_autogen.go3
-rw-r--r--pkg/tcpip/hash/jenkins/jenkins_test.go176
-rw-r--r--pkg/tcpip/header/BUILD65
-rw-r--r--pkg/tcpip/header/checksum_test.go171
-rw-r--r--pkg/tcpip/header/eth_test.go102
-rwxr-xr-xpkg/tcpip/header/header_state_autogen.go42
-rw-r--r--pkg/tcpip/header/ipv6_test.go331
-rw-r--r--pkg/tcpip/header/ipversion_test.go67
-rwxr-xr-x[-rw-r--r--]pkg/tcpip/header/ndp_neighbor_advert.go0
-rwxr-xr-x[-rw-r--r--]pkg/tcpip/header/ndp_neighbor_solicit.go0
-rwxr-xr-x[-rw-r--r--]pkg/tcpip/header/ndp_options.go0
-rwxr-xr-x[-rw-r--r--]pkg/tcpip/header/ndp_router_advert.go0
-rwxr-xr-x[-rw-r--r--]pkg/tcpip/header/ndp_router_solicit.go0
-rw-r--r--pkg/tcpip/header/ndp_test.go937
-rw-r--r--pkg/tcpip/header/tcp_test.go148
-rw-r--r--pkg/tcpip/iptables/BUILD18
-rwxr-xr-xpkg/tcpip/iptables/iptables_state_autogen.go3
-rw-r--r--pkg/tcpip/link/channel/BUILD14
-rwxr-xr-x[-rw-r--r--]pkg/tcpip/link/channel/channel.go0
-rwxr-xr-xpkg/tcpip/link/channel/channel_state_autogen.go3
-rw-r--r--pkg/tcpip/link/fdbased/BUILD39
-rw-r--r--pkg/tcpip/link/fdbased/endpoint_test.go468
-rwxr-xr-xpkg/tcpip/link/fdbased/fdbased_state_autogen.go10
-rw-r--r--pkg/tcpip/link/loopback/BUILD15
-rwxr-xr-xpkg/tcpip/link/loopback/loopback_state_autogen.go3
-rw-r--r--pkg/tcpip/link/muxed/BUILD28
-rwxr-xr-x[-rw-r--r--]pkg/tcpip/link/muxed/injectable.go0
-rw-r--r--pkg/tcpip/link/muxed/injectable_test.go98
-rwxr-xr-xpkg/tcpip/link/muxed/muxed_state_autogen.go3
-rw-r--r--pkg/tcpip/link/rawfile/BUILD20
-rwxr-xr-xpkg/tcpip/link/rawfile/rawfile_state_autogen.go10
-rw-r--r--pkg/tcpip/link/sharedmem/BUILD41
-rw-r--r--pkg/tcpip/link/sharedmem/pipe/BUILD23
-rwxr-xr-x[-rw-r--r--]pkg/tcpip/link/sharedmem/pipe/pipe.go0
-rwxr-xr-xpkg/tcpip/link/sharedmem/pipe/pipe_state_autogen.go3
-rw-r--r--pkg/tcpip/link/sharedmem/pipe/pipe_test.go518
-rwxr-xr-x[-rw-r--r--]pkg/tcpip/link/sharedmem/pipe/pipe_unsafe.go0
-rwxr-xr-x[-rw-r--r--]pkg/tcpip/link/sharedmem/pipe/rx.go0
-rwxr-xr-x[-rw-r--r--]pkg/tcpip/link/sharedmem/pipe/tx.go0
-rw-r--r--pkg/tcpip/link/sharedmem/queue/BUILD27
-rwxr-xr-xpkg/tcpip/link/sharedmem/queue/queue_state_autogen.go3
-rw-r--r--pkg/tcpip/link/sharedmem/queue/queue_test.go517
-rwxr-xr-x[-rw-r--r--]pkg/tcpip/link/sharedmem/queue/rx.go0
-rwxr-xr-x[-rw-r--r--]pkg/tcpip/link/sharedmem/queue/tx.go0
-rwxr-xr-x[-rw-r--r--]pkg/tcpip/link/sharedmem/rx.go0
-rwxr-xr-x[-rw-r--r--]pkg/tcpip/link/sharedmem/sharedmem.go0
-rwxr-xr-xpkg/tcpip/link/sharedmem/sharedmem_state_autogen.go6
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_test.go812
-rwxr-xr-x[-rw-r--r--]pkg/tcpip/link/sharedmem/sharedmem_unsafe.go0
-rwxr-xr-x[-rw-r--r--]pkg/tcpip/link/sharedmem/tx.go0
-rw-r--r--pkg/tcpip/link/sniffer/BUILD19
-rwxr-xr-xpkg/tcpip/link/sniffer/sniffer_state_autogen.go3
-rw-r--r--pkg/tcpip/link/tun/BUILD9
-rwxr-xr-xpkg/tcpip/link/tun/tun_state_autogen.go5
-rwxr-xr-x[-rw-r--r--]pkg/tcpip/link/tun/tun_unsafe.go0
-rw-r--r--pkg/tcpip/link/waitable/BUILD30
-rwxr-xr-x[-rw-r--r--]pkg/tcpip/link/waitable/waitable.go0
-rwxr-xr-xpkg/tcpip/link/waitable/waitable_state_autogen.go3
-rw-r--r--pkg/tcpip/link/waitable/waitable_test.go173
-rw-r--r--pkg/tcpip/network/BUILD22
-rw-r--r--pkg/tcpip/network/arp/BUILD32
-rwxr-xr-xpkg/tcpip/network/arp/arp_state_autogen.go3
-rw-r--r--pkg/tcpip/network/arp/arp_test.go145
-rw-r--r--pkg/tcpip/network/fragmentation/BUILD45
-rw-r--r--pkg/tcpip/network/fragmentation/frag_heap_test.go126
-rwxr-xr-xpkg/tcpip/network/fragmentation/fragmentation_state_autogen.go38
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation_test.go165
-rwxr-xr-xpkg/tcpip/network/fragmentation/reassembler_list.go173
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler_test.go105
-rw-r--r--pkg/tcpip/network/hash/BUILD13
-rwxr-xr-xpkg/tcpip/network/hash/hash_state_autogen.go3
-rw-r--r--pkg/tcpip/network/ip_test.go655
-rw-r--r--pkg/tcpip/network/ipv4/BUILD39
-rwxr-xr-xpkg/tcpip/network/ipv4/ipv4_state_autogen.go3
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go475
-rw-r--r--pkg/tcpip/network/ipv6/BUILD40
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go958
-rwxr-xr-xpkg/tcpip/network/ipv6/ipv6_state_autogen.go3
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go270
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go613
-rwxr-xr-x[-rw-r--r--]pkg/tcpip/packet_buffer.go0
-rwxr-xr-x[-rw-r--r--]pkg/tcpip/packet_buffer_state.go0
-rw-r--r--pkg/tcpip/ports/BUILD22
-rwxr-xr-xpkg/tcpip/ports/ports_state_autogen.go24
-rw-r--r--pkg/tcpip/ports/ports_test.go402
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/BUILD22
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/main.go225
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/BUILD21
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/main.go203
-rw-r--r--pkg/tcpip/seqnum/BUILD9
-rwxr-xr-xpkg/tcpip/seqnum/seqnum_state_autogen.go3
-rw-r--r--pkg/tcpip/stack/BUILD92
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go277
-rwxr-xr-xpkg/tcpip/stack/linkaddrentry_list.go173
-rwxr-xr-x[-rw-r--r--]pkg/tcpip/stack/ndp.go0
-rw-r--r--pkg/tcpip/stack/ndp_test.go3612
-rw-r--r--pkg/tcpip/stack/nic_test.go62
-rwxr-xr-xpkg/tcpip/stack/stack_state_autogen.go131
-rw-r--r--pkg/tcpip/stack/stack_test.go3089
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go348
-rw-r--r--pkg/tcpip/stack/transport_test.go637
-rwxr-xr-xpkg/tcpip/tcpip_state_autogen.go108
-rw-r--r--pkg/tcpip/tcpip_test.go228
-rw-r--r--pkg/tcpip/time.s15
-rwxr-xr-x[-rw-r--r--]pkg/tcpip/timer.go0
-rw-r--r--pkg/tcpip/timer_test.go236
-rw-r--r--pkg/tcpip/transport/icmp/BUILD40
-rwxr-xr-xpkg/tcpip/transport/icmp/icmp_packet_list.go173
-rwxr-xr-xpkg/tcpip/transport/icmp/icmp_state_autogen.go92
-rw-r--r--pkg/tcpip/transport/packet/BUILD38
-rwxr-xr-x[-rw-r--r--]pkg/tcpip/transport/packet/endpoint.go0
-rwxr-xr-x[-rw-r--r--]pkg/tcpip/transport/packet/endpoint_state.go0
-rwxr-xr-xpkg/tcpip/transport/packet/packet_list.go173
-rwxr-xr-xpkg/tcpip/transport/packet/packet_state_autogen.go88
-rw-r--r--pkg/tcpip/transport/raw/BUILD40
-rwxr-xr-xpkg/tcpip/transport/raw/raw_packet_list.go173
-rwxr-xr-xpkg/tcpip/transport/raw/raw_state_autogen.go90
-rw-r--r--pkg/tcpip/transport/tcp/BUILD109
-rwxr-xr-x[-rw-r--r--]pkg/tcpip/transport/tcp/dispatcher.go0
-rw-r--r--pkg/tcpip/transport/tcp/dual_stack_test.go652
-rwxr-xr-x[-rw-r--r--]pkg/tcpip/transport/tcp/rcv_state.go0
-rw-r--r--pkg/tcpip/transport/tcp/sack_scoreboard_test.go249
-rwxr-xr-xpkg/tcpip/transport/tcp/tcp_endpoint_list.go173
-rw-r--r--pkg/tcpip/transport/tcp/tcp_noracedetector_test.go527
-rw-r--r--pkg/tcpip/transport/tcp/tcp_sack_test.go572
-rwxr-xr-xpkg/tcpip/transport/tcp/tcp_segment_list.go173
-rwxr-xr-xpkg/tcpip/transport/tcp/tcp_state_autogen.go531
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go6969
-rw-r--r--pkg/tcpip/transport/tcp/tcp_timestamp_test.go295
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/BUILD26
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go1102
-rw-r--r--pkg/tcpip/transport/tcpconntrack/BUILD23
-rwxr-xr-x[-rw-r--r--]pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go0
-rw-r--r--pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go511
-rwxr-xr-xpkg/tcpip/transport/tcpconntrack/tcpconntrack_state_autogen.go3
-rw-r--r--pkg/tcpip/transport/udp/BUILD61
-rwxr-xr-xpkg/tcpip/transport/udp/udp_packet_list.go173
-rwxr-xr-xpkg/tcpip/transport/udp/udp_state_autogen.go144
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go1802
150 files changed, 2775 insertions, 32969 deletions
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
deleted file mode 100644
index 26f7ba86b..000000000
--- a/pkg/tcpip/BUILD
+++ /dev/null
@@ -1,34 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "tcpip",
- srcs = [
- "packet_buffer.go",
- "packet_buffer_state.go",
- "tcpip.go",
- "time_unsafe.go",
- "timer.go",
- ],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/sync",
- "//pkg/tcpip/buffer",
- "//pkg/waiter",
- ],
-)
-
-go_test(
- name = "tcpip_test",
- size = "small",
- srcs = ["tcpip_test.go"],
- library = ":tcpip",
-)
-
-go_test(
- name = "tcpip_x_test",
- size = "small",
- srcs = ["timer_test.go"],
- deps = [":tcpip"],
-)
diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD
deleted file mode 100644
index a984f1712..000000000
--- a/pkg/tcpip/adapters/gonet/BUILD
+++ /dev/null
@@ -1,37 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "gonet",
- srcs = ["gonet.go"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/transport/tcp",
- "//pkg/tcpip/transport/udp",
- "//pkg/waiter",
- ],
-)
-
-go_test(
- name = "gonet_test",
- size = "small",
- srcs = ["gonet_test.go"],
- library = ":gonet",
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/loopback",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/transport/tcp",
- "//pkg/tcpip/transport/udp",
- "//pkg/waiter",
- "@org_golang_x_net//nettest:go_default_library",
- ],
-)
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index 6e0db2741..6e0db2741 100644..100755
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
diff --git a/pkg/tcpip/adapters/gonet/gonet_state_autogen.go b/pkg/tcpip/adapters/gonet/gonet_state_autogen.go
new file mode 100755
index 000000000..7a5c5419e
--- /dev/null
+++ b/pkg/tcpip/adapters/gonet/gonet_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package gonet
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
deleted file mode 100644
index ea0a0409a..000000000
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ /dev/null
@@ -1,683 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package gonet
-
-import (
- "context"
- "fmt"
- "io"
- "net"
- "reflect"
- "strings"
- "testing"
- "time"
-
- "golang.org/x/net/nettest"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-const (
- NICID = 1
-)
-
-func TestTimeouts(t *testing.T) {
- nc := NewTCPConn(nil, nil)
- dlfs := []struct {
- name string
- f func(time.Time) error
- }{
- {"SetDeadline", nc.SetDeadline},
- {"SetReadDeadline", nc.SetReadDeadline},
- {"SetWriteDeadline", nc.SetWriteDeadline},
- }
-
- for _, dlf := range dlfs {
- if err := dlf.f(time.Time{}); err != nil {
- t.Errorf("got %s(time.Time{}) = %v, want = %v", dlf.name, err, nil)
- }
- }
-}
-
-func newLoopbackStack() (*stack.Stack, *tcpip.Error) {
- // Create the stack and add a NIC.
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol()},
- })
-
- if err := s.CreateNIC(NICID, loopback.New()); err != nil {
- return nil, err
- }
-
- // Add default route.
- s.SetRouteTable([]tcpip.Route{
- // IPv4
- {
- Destination: header.IPv4EmptySubnet,
- NIC: NICID,
- },
-
- // IPv6
- {
- Destination: header.IPv6EmptySubnet,
- NIC: NICID,
- },
- })
-
- return s, nil
-}
-
-type testConnection struct {
- wq *waiter.Queue
- e *waiter.Entry
- ch chan struct{}
- ep tcpip.Endpoint
-}
-
-func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, *tcpip.Error) {
- wq := &waiter.Queue{}
- ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
-
- entry, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&entry, waiter.EventOut)
-
- err = ep.Connect(addr)
- if err == tcpip.ErrConnectStarted {
- <-ch
- err = ep.GetSockOpt(tcpip.ErrorOption{})
- }
- if err != nil {
- return nil, err
- }
-
- wq.EventUnregister(&entry)
- wq.EventRegister(&entry, waiter.EventIn)
-
- return &testConnection{wq, &entry, ch, ep}, nil
-}
-
-func (c *testConnection) close() {
- c.wq.EventUnregister(c.e)
- c.ep.Close()
-}
-
-// TestCloseReader tests that Conn.Close() causes Conn.Read() to unblock.
-func TestCloseReader(t *testing.T) {
- s, err := newLoopbackStack()
- if err != nil {
- t.Fatalf("newLoopbackStack() = %v", err)
- }
-
- addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
-
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
-
- l, e := ListenTCP(s, addr, ipv4.ProtocolNumber)
- if e != nil {
- t.Fatalf("NewListener() = %v", e)
- }
- done := make(chan struct{})
- go func() {
- defer close(done)
- c, err := l.Accept()
- if err != nil {
- t.Fatalf("l.Accept() = %v", err)
- }
-
- // Give c.Read() a chance to block before closing the connection.
- time.AfterFunc(time.Millisecond*50, func() {
- c.Close()
- })
-
- buf := make([]byte, 256)
- n, err := c.Read(buf)
- if n != 0 || err != io.EOF {
- t.Errorf("c.Read() = (%d, %v), want (0, EOF)", n, err)
- }
- }()
- sender, err := connect(s, addr)
- if err != nil {
- t.Fatalf("connect() = %v", err)
- }
-
- select {
- case <-done:
- case <-time.After(5 * time.Second):
- t.Errorf("c.Read() didn't unblock")
- }
- sender.close()
-}
-
-// TestCloseReaderWithForwarder tests that TCPConn.Close wakes TCPConn.Read when
-// using tcp.Forwarder.
-func TestCloseReaderWithForwarder(t *testing.T) {
- s, err := newLoopbackStack()
- if err != nil {
- t.Fatalf("newLoopbackStack() = %v", err)
- }
-
- addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
-
- done := make(chan struct{})
-
- fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
- defer close(done)
-
- var wq waiter.Queue
- ep, err := r.CreateEndpoint(&wq)
- if err != nil {
- t.Fatalf("r.CreateEndpoint() = %v", err)
- }
- defer ep.Close()
- r.Complete(false)
-
- c := NewTCPConn(&wq, ep)
-
- // Give c.Read() a chance to block before closing the connection.
- time.AfterFunc(time.Millisecond*50, func() {
- c.Close()
- })
-
- buf := make([]byte, 256)
- n, e := c.Read(buf)
- if n != 0 || e != io.EOF {
- t.Errorf("c.Read() = (%d, %v), want (0, EOF)", n, e)
- }
- })
- s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket)
-
- sender, err := connect(s, addr)
- if err != nil {
- t.Fatalf("connect() = %v", err)
- }
-
- select {
- case <-done:
- case <-time.After(5 * time.Second):
- t.Errorf("c.Read() didn't unblock")
- }
- sender.close()
-}
-
-func TestCloseRead(t *testing.T) {
- s, terr := newLoopbackStack()
- if terr != nil {
- t.Fatalf("newLoopbackStack() = %v", terr)
- }
-
- addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
-
- fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
- var wq waiter.Queue
- ep, err := r.CreateEndpoint(&wq)
- if err != nil {
- t.Fatalf("r.CreateEndpoint() = %v", err)
- }
- defer ep.Close()
- r.Complete(false)
-
- c := NewTCPConn(&wq, ep)
-
- buf := make([]byte, 256)
- n, e := c.Read(buf)
- if e != nil || string(buf[:n]) != "abc123" {
- t.Fatalf("c.Read() = (%d, %v), want (6, nil)", n, e)
- }
-
- if n, e = c.Write([]byte("abc123")); e != nil {
- t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, e)
- }
- })
-
- s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket)
-
- tc, terr := connect(s, addr)
- if terr != nil {
- t.Fatalf("connect() = %v", terr)
- }
- c := NewTCPConn(tc.wq, tc.ep)
-
- if err := c.CloseRead(); err != nil {
- t.Errorf("c.CloseRead() = %v", err)
- }
-
- buf := make([]byte, 256)
- if n, err := c.Read(buf); err != io.EOF {
- t.Errorf("c.Read() = (%d, %v), want (0, io.EOF)", n, err)
- }
-
- if n, err := c.Write([]byte("abc123")); n != 6 || err != nil {
- t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, err)
- }
-}
-
-func TestCloseWrite(t *testing.T) {
- s, terr := newLoopbackStack()
- if terr != nil {
- t.Fatalf("newLoopbackStack() = %v", terr)
- }
-
- addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
-
- fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
- var wq waiter.Queue
- ep, err := r.CreateEndpoint(&wq)
- if err != nil {
- t.Fatalf("r.CreateEndpoint() = %v", err)
- }
- defer ep.Close()
- r.Complete(false)
-
- c := NewTCPConn(&wq, ep)
-
- n, e := c.Read(make([]byte, 256))
- if n != 0 || e != io.EOF {
- t.Errorf("c.Read() = (%d, %v), want (0, io.EOF)", n, e)
- }
-
- if n, e = c.Write([]byte("abc123")); n != 6 || e != nil {
- t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, e)
- }
- })
-
- s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket)
-
- tc, terr := connect(s, addr)
- if terr != nil {
- t.Fatalf("connect() = %v", terr)
- }
- c := NewTCPConn(tc.wq, tc.ep)
-
- if err := c.CloseWrite(); err != nil {
- t.Errorf("c.CloseWrite() = %v", err)
- }
-
- buf := make([]byte, 256)
- n, err := c.Read(buf)
- if err != nil || string(buf[:n]) != "abc123" {
- t.Fatalf("c.Read() = (%d, %v), want (6, nil)", n, err)
- }
-
- n, err = c.Write([]byte("abc123"))
- got, ok := err.(*net.OpError)
- want := "endpoint is closed for send"
- if n != 0 || !ok || got.Op != "write" || got.Err == nil || !strings.HasSuffix(got.Err.Error(), want) {
- t.Errorf("c.Write() = (%d, %v), want (0, OpError(Op: write, Err: %s))", n, err, want)
- }
-}
-
-func TestUDPForwarder(t *testing.T) {
- s, terr := newLoopbackStack()
- if terr != nil {
- t.Fatalf("newLoopbackStack() = %v", terr)
- }
-
- ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
- addr1 := tcpip.FullAddress{NICID, ip1, 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip1)
- ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4())
- addr2 := tcpip.FullAddress{NICID, ip2, 11311}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip2)
-
- done := make(chan struct{})
- fwd := udp.NewForwarder(s, func(r *udp.ForwarderRequest) {
- defer close(done)
-
- var wq waiter.Queue
- ep, err := r.CreateEndpoint(&wq)
- if err != nil {
- t.Fatalf("r.CreateEndpoint() = %v", err)
- }
- defer ep.Close()
-
- c := NewTCPConn(&wq, ep)
-
- buf := make([]byte, 256)
- n, e := c.Read(buf)
- if e != nil {
- t.Errorf("c.Read() = %v", e)
- }
-
- if _, e := c.Write(buf[:n]); e != nil {
- t.Errorf("c.Write() = %v", e)
- }
- })
- s.SetTransportProtocolHandler(udp.ProtocolNumber, fwd.HandlePacket)
-
- c2, err := DialUDP(s, &addr2, nil, ipv4.ProtocolNumber)
- if err != nil {
- t.Fatal("DialUDP(bind port 5):", err)
- }
-
- sent := "abc123"
- sendAddr := fullToUDPAddr(addr1)
- if n, err := c2.WriteTo([]byte(sent), sendAddr); err != nil || n != len(sent) {
- t.Errorf("c1.WriteTo(%q, %v) = %d, %v, want = %d, %v", sent, sendAddr, n, err, len(sent), nil)
- }
-
- buf := make([]byte, 256)
- n, recvAddr, err := c2.ReadFrom(buf)
- if err != nil || recvAddr.String() != sendAddr.String() {
- t.Errorf("c1.ReadFrom() = %d, %v, %v, want = %d, %v, %v", n, recvAddr, err, len(sent), sendAddr, nil)
- }
-}
-
-// TestDeadlineChange tests that changing the deadline affects currently blocked reads.
-func TestDeadlineChange(t *testing.T) {
- s, err := newLoopbackStack()
- if err != nil {
- t.Fatalf("newLoopbackStack() = %v", err)
- }
-
- addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
-
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
-
- l, e := ListenTCP(s, addr, ipv4.ProtocolNumber)
- if e != nil {
- t.Fatalf("NewListener() = %v", e)
- }
- done := make(chan struct{})
- go func() {
- defer close(done)
- c, err := l.Accept()
- if err != nil {
- t.Fatalf("l.Accept() = %v", err)
- }
-
- c.SetDeadline(time.Now().Add(time.Minute))
- // Give c.Read() a chance to block before closing the connection.
- time.AfterFunc(time.Millisecond*50, func() {
- c.SetDeadline(time.Now().Add(time.Millisecond * 10))
- })
-
- buf := make([]byte, 256)
- n, err := c.Read(buf)
- got, ok := err.(*net.OpError)
- want := "i/o timeout"
- if n != 0 || !ok || got.Err == nil || got.Err.Error() != want {
- t.Errorf("c.Read() = (%d, %v), want (0, OpError(%s))", n, err, want)
- }
- }()
- sender, err := connect(s, addr)
- if err != nil {
- t.Fatalf("connect() = %v", err)
- }
-
- select {
- case <-done:
- case <-time.After(time.Millisecond * 500):
- t.Errorf("c.Read() didn't unblock")
- }
- sender.close()
-}
-
-func TestPacketConnTransfer(t *testing.T) {
- s, e := newLoopbackStack()
- if e != nil {
- t.Fatalf("newLoopbackStack() = %v", e)
- }
-
- ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
- addr1 := tcpip.FullAddress{NICID, ip1, 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip1)
- ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4())
- addr2 := tcpip.FullAddress{NICID, ip2, 11311}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip2)
-
- c1, err := DialUDP(s, &addr1, nil, ipv4.ProtocolNumber)
- if err != nil {
- t.Fatal("DialUDP(bind port 4):", err)
- }
- c2, err := DialUDP(s, &addr2, nil, ipv4.ProtocolNumber)
- if err != nil {
- t.Fatal("DialUDP(bind port 5):", err)
- }
-
- c1.SetDeadline(time.Now().Add(time.Second))
- c2.SetDeadline(time.Now().Add(time.Second))
-
- sent := "abc123"
- sendAddr := fullToUDPAddr(addr2)
- if n, err := c1.WriteTo([]byte(sent), sendAddr); err != nil || n != len(sent) {
- t.Errorf("got c1.WriteTo(%q, %v) = %d, %v, want = %d, %v", sent, sendAddr, n, err, len(sent), nil)
- }
- recv := make([]byte, len(sent))
- n, recvAddr, err := c2.ReadFrom(recv)
- if err != nil || n != len(recv) {
- t.Errorf("got c2.ReadFrom() = %d, %v, want = %d, %v", n, err, len(recv), nil)
- }
-
- if recv := string(recv); recv != sent {
- t.Errorf("got recv = %q, want = %q", recv, sent)
- }
-
- if want := fullToUDPAddr(addr1); !reflect.DeepEqual(recvAddr, want) {
- t.Errorf("got recvAddr = %v, want = %v", recvAddr, want)
- }
-
- if err := c1.Close(); err != nil {
- t.Error("c1.Close():", err)
- }
- if err := c2.Close(); err != nil {
- t.Error("c2.Close():", err)
- }
-}
-
-func TestConnectedPacketConnTransfer(t *testing.T) {
- s, e := newLoopbackStack()
- if e != nil {
- t.Fatalf("newLoopbackStack() = %v", e)
- }
-
- ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
- addr := tcpip.FullAddress{NICID, ip, 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip)
-
- c1, err := DialUDP(s, &addr, nil, ipv4.ProtocolNumber)
- if err != nil {
- t.Fatal("DialUDP(bind port 4):", err)
- }
- c2, err := DialUDP(s, nil, &addr, ipv4.ProtocolNumber)
- if err != nil {
- t.Fatal("DialUDP(bind port 5):", err)
- }
-
- c1.SetDeadline(time.Now().Add(time.Second))
- c2.SetDeadline(time.Now().Add(time.Second))
-
- sent := "abc123"
- if n, err := c2.Write([]byte(sent)); err != nil || n != len(sent) {
- t.Errorf("got c2.Write(%q) = %d, %v, want = %d, %v", sent, n, err, len(sent), nil)
- }
- recv := make([]byte, len(sent))
- n, err := c1.Read(recv)
- if err != nil || n != len(recv) {
- t.Errorf("got c1.Read() = %d, %v, want = %d, %v", n, err, len(recv), nil)
- }
-
- if recv := string(recv); recv != sent {
- t.Errorf("got recv = %q, want = %q", recv, sent)
- }
-
- if err := c1.Close(); err != nil {
- t.Error("c1.Close():", err)
- }
- if err := c2.Close(); err != nil {
- t.Error("c2.Close():", err)
- }
-}
-
-func makePipe() (c1, c2 net.Conn, stop func(), err error) {
- s, e := newLoopbackStack()
- if e != nil {
- return nil, nil, nil, fmt.Errorf("newLoopbackStack() = %v", e)
- }
-
- ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
- addr := tcpip.FullAddress{NICID, ip, 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip)
-
- l, err := ListenTCP(s, addr, ipv4.ProtocolNumber)
- if err != nil {
- return nil, nil, nil, fmt.Errorf("NewListener: %v", err)
- }
-
- c1, err = DialTCP(s, addr, ipv4.ProtocolNumber)
- if err != nil {
- l.Close()
- return nil, nil, nil, fmt.Errorf("DialTCP: %v", err)
- }
-
- c2, err = l.Accept()
- if err != nil {
- l.Close()
- c1.Close()
- return nil, nil, nil, fmt.Errorf("l.Accept: %v", err)
- }
-
- stop = func() {
- c1.Close()
- c2.Close()
- }
-
- if err := l.Close(); err != nil {
- stop()
- return nil, nil, nil, fmt.Errorf("l.Close(): %v", err)
- }
-
- return c1, c2, stop, nil
-}
-
-func TestTCPConnTransfer(t *testing.T) {
- c1, c2, _, err := makePipe()
- if err != nil {
- t.Fatal(err)
- }
- defer func() {
- if err := c1.Close(); err != nil {
- t.Error("c1.Close():", err)
- }
- if err := c2.Close(); err != nil {
- t.Error("c2.Close():", err)
- }
- }()
-
- c1.SetDeadline(time.Now().Add(time.Second))
- c2.SetDeadline(time.Now().Add(time.Second))
-
- const sent = "abc123"
-
- tests := []struct {
- name string
- c1 net.Conn
- c2 net.Conn
- }{
- {"connected to accepted", c1, c2},
- {"accepted to connected", c2, c1},
- }
-
- for _, test := range tests {
- if n, err := test.c1.Write([]byte(sent)); err != nil || n != len(sent) {
- t.Errorf("%s: got test.c1.Write(%q) = %d, %v, want = %d, %v", test.name, sent, n, err, len(sent), nil)
- continue
- }
-
- recv := make([]byte, len(sent))
- n, err := test.c2.Read(recv)
- if err != nil || n != len(recv) {
- t.Errorf("%s: got test.c2.Read() = %d, %v, want = %d, %v", test.name, n, err, len(recv), nil)
- continue
- }
-
- if recv := string(recv); recv != sent {
- t.Errorf("%s: got recv = %q, want = %q", test.name, recv, sent)
- }
- }
-}
-
-func TestTCPDialError(t *testing.T) {
- s, e := newLoopbackStack()
- if e != nil {
- t.Fatalf("newLoopbackStack() = %v", e)
- }
-
- ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
- addr := tcpip.FullAddress{NICID, ip, 11211}
-
- _, err := DialTCP(s, addr, ipv4.ProtocolNumber)
- got, ok := err.(*net.OpError)
- want := tcpip.ErrNoRoute
- if !ok || got.Err.Error() != want.String() {
- t.Errorf("Got DialTCP() = %v, want = %v", err, tcpip.ErrNoRoute)
- }
-}
-
-func TestDialContextTCPCanceled(t *testing.T) {
- s, err := newLoopbackStack()
- if err != nil {
- t.Fatalf("newLoopbackStack() = %v", err)
- }
-
- addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
-
- ctx := context.Background()
- ctx, cancel := context.WithCancel(ctx)
- cancel()
-
- if _, err := DialContextTCP(ctx, s, addr, ipv4.ProtocolNumber); err != context.Canceled {
- t.Errorf("got DialContextTCP(...) = %v, want = %v", err, context.Canceled)
- }
-}
-
-func TestDialContextTCPTimeout(t *testing.T) {
- s, err := newLoopbackStack()
- if err != nil {
- t.Fatalf("newLoopbackStack() = %v", err)
- }
-
- addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
-
- fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
- time.Sleep(time.Second)
- r.Complete(true)
- })
- s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket)
-
- ctx := context.Background()
- ctx, cancel := context.WithDeadline(ctx, time.Now().Add(100*time.Millisecond))
- defer cancel()
-
- if _, err := DialContextTCP(ctx, s, addr, ipv4.ProtocolNumber); err != context.DeadlineExceeded {
- t.Errorf("got DialContextTCP(...) = %v, want = %v", err, context.DeadlineExceeded)
- }
-}
-
-func TestNetTest(t *testing.T) {
- nettest.TestConn(t, makePipe)
-}
diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD
deleted file mode 100644
index 563bc78ea..000000000
--- a/pkg/tcpip/buffer/BUILD
+++ /dev/null
@@ -1,19 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "buffer",
- srcs = [
- "prependable.go",
- "view.go",
- ],
- visibility = ["//visibility:public"],
-)
-
-go_test(
- name = "buffer_test",
- size = "small",
- srcs = ["view_test.go"],
- library = ":buffer",
-)
diff --git a/pkg/tcpip/buffer/buffer_state_autogen.go b/pkg/tcpip/buffer/buffer_state_autogen.go
new file mode 100755
index 000000000..954487771
--- /dev/null
+++ b/pkg/tcpip/buffer/buffer_state_autogen.go
@@ -0,0 +1,24 @@
+// automatically generated by stateify.
+
+package buffer
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+func (x *VectorisedView) beforeSave() {}
+func (x *VectorisedView) save(m state.Map) {
+ x.beforeSave()
+ m.Save("views", &x.views)
+ m.Save("size", &x.size)
+}
+
+func (x *VectorisedView) afterLoad() {}
+func (x *VectorisedView) load(m state.Map) {
+ m.Load("views", &x.views)
+ m.Load("size", &x.size)
+}
+
+func init() {
+ state.Register("pkg/tcpip/buffer.VectorisedView", (*VectorisedView)(nil), state.Fns{Save: (*VectorisedView).save, Load: (*VectorisedView).load})
+}
diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go
deleted file mode 100644
index ebc3a17b7..000000000
--- a/pkg/tcpip/buffer/view_test.go
+++ /dev/null
@@ -1,235 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package buffer_test contains tests for the VectorisedView type.
-package buffer
-
-import (
- "reflect"
- "testing"
-)
-
-// copy returns a deep-copy of the vectorised view.
-func (vv VectorisedView) copy() VectorisedView {
- uu := VectorisedView{
- views: make([]View, 0, len(vv.views)),
- size: vv.size,
- }
- for _, v := range vv.views {
- uu.views = append(uu.views, append(View(nil), v...))
- }
- return uu
-}
-
-// vv is an helper to build VectorisedView from different strings.
-func vv(size int, pieces ...string) VectorisedView {
- views := make([]View, len(pieces))
- for i, p := range pieces {
- views[i] = []byte(p)
- }
-
- return NewVectorisedView(size, views)
-}
-
-var capLengthTestCases = []struct {
- comment string
- in VectorisedView
- length int
- want VectorisedView
-}{
- {
- comment: "Simple case",
- in: vv(2, "12"),
- length: 1,
- want: vv(1, "1"),
- },
- {
- comment: "Case spanning across two Views",
- in: vv(4, "123", "4"),
- length: 2,
- want: vv(2, "12"),
- },
- {
- comment: "Corner case with negative length",
- in: vv(1, "1"),
- length: -1,
- want: vv(0),
- },
- {
- comment: "Corner case with length = 0",
- in: vv(3, "12", "3"),
- length: 0,
- want: vv(0),
- },
- {
- comment: "Corner case with length = size",
- in: vv(1, "1"),
- length: 1,
- want: vv(1, "1"),
- },
- {
- comment: "Corner case with length > size",
- in: vv(1, "1"),
- length: 2,
- want: vv(1, "1"),
- },
-}
-
-func TestCapLength(t *testing.T) {
- for _, c := range capLengthTestCases {
- orig := c.in.copy()
- c.in.CapLength(c.length)
- if !reflect.DeepEqual(c.in, c.want) {
- t.Errorf("Test \"%s\" failed when calling CapLength(%d) on %v. Got %v. Want %v",
- c.comment, c.length, orig, c.in, c.want)
- }
- }
-}
-
-var trimFrontTestCases = []struct {
- comment string
- in VectorisedView
- count int
- want VectorisedView
-}{
- {
- comment: "Simple case",
- in: vv(2, "12"),
- count: 1,
- want: vv(1, "2"),
- },
- {
- comment: "Case where we trim an entire View",
- in: vv(2, "1", "2"),
- count: 1,
- want: vv(1, "2"),
- },
- {
- comment: "Case spanning across two Views",
- in: vv(3, "1", "23"),
- count: 2,
- want: vv(1, "3"),
- },
- {
- comment: "Corner case with negative count",
- in: vv(1, "1"),
- count: -1,
- want: vv(1, "1"),
- },
- {
- comment: " Corner case with count = 0",
- in: vv(1, "1"),
- count: 0,
- want: vv(1, "1"),
- },
- {
- comment: "Corner case with count = size",
- in: vv(1, "1"),
- count: 1,
- want: vv(0),
- },
- {
- comment: "Corner case with count > size",
- in: vv(1, "1"),
- count: 2,
- want: vv(0),
- },
-}
-
-func TestTrimFront(t *testing.T) {
- for _, c := range trimFrontTestCases {
- orig := c.in.copy()
- c.in.TrimFront(c.count)
- if !reflect.DeepEqual(c.in, c.want) {
- t.Errorf("Test \"%s\" failed when calling TrimFront(%d) on %v. Got %v. Want %v",
- c.comment, c.count, orig, c.in, c.want)
- }
- }
-}
-
-var toViewCases = []struct {
- comment string
- in VectorisedView
- want View
-}{
- {
- comment: "Simple case",
- in: vv(2, "12"),
- want: []byte("12"),
- },
- {
- comment: "Case with multiple views",
- in: vv(2, "1", "2"),
- want: []byte("12"),
- },
- {
- comment: "Empty case",
- in: vv(0),
- want: []byte(""),
- },
-}
-
-func TestToView(t *testing.T) {
- for _, c := range toViewCases {
- got := c.in.ToView()
- if !reflect.DeepEqual(got, c.want) {
- t.Errorf("Test \"%s\" failed when calling ToView() on %v. Got %v. Want %v",
- c.comment, c.in, got, c.want)
- }
- }
-}
-
-var toCloneCases = []struct {
- comment string
- inView VectorisedView
- inBuffer []View
-}{
- {
- comment: "Simple case",
- inView: vv(1, "1"),
- inBuffer: make([]View, 1),
- },
- {
- comment: "Case with multiple views",
- inView: vv(2, "1", "2"),
- inBuffer: make([]View, 2),
- },
- {
- comment: "Case with buffer too small",
- inView: vv(2, "1", "2"),
- inBuffer: make([]View, 1),
- },
- {
- comment: "Case with buffer larger than needed",
- inView: vv(1, "1"),
- inBuffer: make([]View, 2),
- },
- {
- comment: "Case with nil buffer",
- inView: vv(1, "1"),
- inBuffer: nil,
- },
-}
-
-func TestToClone(t *testing.T) {
- for _, c := range toCloneCases {
- t.Run(c.comment, func(t *testing.T) {
- got := c.inView.Clone(c.inBuffer)
- if !reflect.DeepEqual(got, c.inView) {
- t.Fatalf("got (%+v).Clone(%+v) = %+v, want = %+v",
- c.inView, c.inBuffer, got, c.inView)
- }
- })
- }
-}
diff --git a/pkg/tcpip/checker/BUILD b/pkg/tcpip/checker/BUILD
deleted file mode 100644
index ed434807f..000000000
--- a/pkg/tcpip/checker/BUILD
+++ /dev/null
@@ -1,16 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "checker",
- testonly = 1,
- srcs = ["checker.go"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/seqnum",
- ],
-)
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
deleted file mode 100644
index c6c160dfc..000000000
--- a/pkg/tcpip/checker/checker.go
+++ /dev/null
@@ -1,842 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package checker provides helper functions to check networking packets for
-// validity.
-package checker
-
-import (
- "encoding/binary"
- "reflect"
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/seqnum"
-)
-
-// NetworkChecker is a function to check a property of a network packet.
-type NetworkChecker func(*testing.T, []header.Network)
-
-// TransportChecker is a function to check a property of a transport packet.
-type TransportChecker func(*testing.T, header.Transport)
-
-// ControlMessagesChecker is a function to check a property of ancillary data.
-type ControlMessagesChecker func(*testing.T, tcpip.ControlMessages)
-
-// IPv4 checks the validity and properties of the given IPv4 packet. It is
-// expected to be used in conjunction with other network checkers for specific
-// properties. For example, to check the source and destination address, one
-// would call:
-//
-// checker.IPv4(t, b, checker.SrcAddr(x), checker.DstAddr(y))
-func IPv4(t *testing.T, b []byte, checkers ...NetworkChecker) {
- t.Helper()
-
- ipv4 := header.IPv4(b)
-
- if !ipv4.IsValid(len(b)) {
- t.Error("Not a valid IPv4 packet")
- }
-
- xsum := ipv4.CalculateChecksum()
- if xsum != 0 && xsum != 0xffff {
- t.Errorf("Bad checksum: 0x%x, checksum in packet: 0x%x", xsum, ipv4.Checksum())
- }
-
- for _, f := range checkers {
- f(t, []header.Network{ipv4})
- }
- if t.Failed() {
- t.FailNow()
- }
-}
-
-// IPv6 checks the validity and properties of the given IPv6 packet. The usage
-// is similar to IPv4.
-func IPv6(t *testing.T, b []byte, checkers ...NetworkChecker) {
- t.Helper()
-
- ipv6 := header.IPv6(b)
- if !ipv6.IsValid(len(b)) {
- t.Error("Not a valid IPv6 packet")
- }
-
- for _, f := range checkers {
- f(t, []header.Network{ipv6})
- }
- if t.Failed() {
- t.FailNow()
- }
-}
-
-// SrcAddr creates a checker that checks the source address.
-func SrcAddr(addr tcpip.Address) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- if a := h[0].SourceAddress(); a != addr {
- t.Errorf("Bad source address, got %v, want %v", a, addr)
- }
- }
-}
-
-// DstAddr creates a checker that checks the destination address.
-func DstAddr(addr tcpip.Address) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- if a := h[0].DestinationAddress(); a != addr {
- t.Errorf("Bad destination address, got %v, want %v", a, addr)
- }
- }
-}
-
-// TTL creates a checker that checks the TTL (ipv4) or HopLimit (ipv6).
-func TTL(ttl uint8) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- var v uint8
- switch ip := h[0].(type) {
- case header.IPv4:
- v = ip.TTL()
- case header.IPv6:
- v = ip.HopLimit()
- }
- if v != ttl {
- t.Fatalf("Bad TTL, got %v, want %v", v, ttl)
- }
- }
-}
-
-// PayloadLen creates a checker that checks the payload length.
-func PayloadLen(plen int) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- if l := len(h[0].Payload()); l != plen {
- t.Errorf("Bad payload length, got %v, want %v", l, plen)
- }
- }
-}
-
-// FragmentOffset creates a checker that checks the FragmentOffset field.
-func FragmentOffset(offset uint16) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- // We only do this of IPv4 for now.
- switch ip := h[0].(type) {
- case header.IPv4:
- if v := ip.FragmentOffset(); v != offset {
- t.Errorf("Bad fragment offset, got %v, want %v", v, offset)
- }
- }
- }
-}
-
-// FragmentFlags creates a checker that checks the fragment flags field.
-func FragmentFlags(flags uint8) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- // We only do this of IPv4 for now.
- switch ip := h[0].(type) {
- case header.IPv4:
- if v := ip.Flags(); v != flags {
- t.Errorf("Bad fragment offset, got %v, want %v", v, flags)
- }
- }
- }
-}
-
-// ReceiveTClass creates a checker that checks the TCLASS field in
-// ControlMessages.
-func ReceiveTClass(want uint32) ControlMessagesChecker {
- return func(t *testing.T, cm tcpip.ControlMessages) {
- t.Helper()
- if !cm.HasTClass {
- t.Fatalf("got cm.HasTClass = %t, want cm.TClass = %d", cm.HasTClass, want)
- }
- if got := cm.TClass; got != want {
- t.Fatalf("got cm.TClass = %d, want %d", got, want)
- }
- }
-}
-
-// ReceiveTOS creates a checker that checks the TOS field in ControlMessages.
-func ReceiveTOS(want uint8) ControlMessagesChecker {
- return func(t *testing.T, cm tcpip.ControlMessages) {
- t.Helper()
- if !cm.HasTOS {
- t.Fatalf("got cm.HasTOS = %t, want cm.TOS = %d", cm.HasTOS, want)
- }
- if got := cm.TOS; got != want {
- t.Fatalf("got cm.TOS = %d, want %d", got, want)
- }
- }
-}
-
-// TOS creates a checker that checks the TOS field.
-func TOS(tos uint8, label uint32) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- if v, l := h[0].TOS(); v != tos || l != label {
- t.Errorf("Bad TOS, got (%v, %v), want (%v,%v)", v, l, tos, label)
- }
- }
-}
-
-// Raw creates a checker that checks the bytes of payload.
-// The checker always checks the payload of the last network header.
-// For instance, in case of IPv6 fragments, the payload that will be checked
-// is the one containing the actual data that the packet is carrying, without
-// the bytes added by the IPv6 fragmentation.
-func Raw(want []byte) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- if got := h[len(h)-1].Payload(); !reflect.DeepEqual(got, want) {
- t.Errorf("Wrong payload, got %v, want %v", got, want)
- }
- }
-}
-
-// IPv6Fragment creates a checker that validates an IPv6 fragment.
-func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- if p := h[0].TransportProtocol(); p != header.IPv6FragmentHeader {
- t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber)
- }
-
- ipv6Frag := header.IPv6Fragment(h[0].Payload())
- if !ipv6Frag.IsValid() {
- t.Error("Not a valid IPv6 fragment")
- }
-
- for _, f := range checkers {
- f(t, []header.Network{h[0], ipv6Frag})
- }
- if t.Failed() {
- t.FailNow()
- }
- }
-}
-
-// TCP creates a checker that checks that the transport protocol is TCP and
-// potentially additional transport header fields.
-func TCP(checkers ...TransportChecker) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- first := h[0]
- last := h[len(h)-1]
-
- if p := last.TransportProtocol(); p != header.TCPProtocolNumber {
- t.Errorf("Bad protocol, got %v, want %v", p, header.TCPProtocolNumber)
- }
-
- // Verify the checksum.
- tcp := header.TCP(last.Payload())
- l := uint16(len(tcp))
-
- xsum := header.Checksum([]byte(first.SourceAddress()), 0)
- xsum = header.Checksum([]byte(first.DestinationAddress()), xsum)
- xsum = header.Checksum([]byte{0, byte(last.TransportProtocol())}, xsum)
- xsum = header.Checksum([]byte{byte(l >> 8), byte(l)}, xsum)
- xsum = header.Checksum(tcp, xsum)
-
- if xsum != 0 && xsum != 0xffff {
- t.Errorf("Bad checksum: 0x%x, checksum in segment: 0x%x", xsum, tcp.Checksum())
- }
-
- // Run the transport checkers.
- for _, f := range checkers {
- f(t, tcp)
- }
- if t.Failed() {
- t.FailNow()
- }
- }
-}
-
-// UDP creates a checker that checks that the transport protocol is UDP and
-// potentially additional transport header fields.
-func UDP(checkers ...TransportChecker) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- last := h[len(h)-1]
-
- if p := last.TransportProtocol(); p != header.UDPProtocolNumber {
- t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber)
- }
-
- udp := header.UDP(last.Payload())
- for _, f := range checkers {
- f(t, udp)
- }
- if t.Failed() {
- t.FailNow()
- }
- }
-}
-
-// SrcPort creates a checker that checks the source port.
-func SrcPort(port uint16) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- if p := h.SourcePort(); p != port {
- t.Errorf("Bad source port, got %v, want %v", p, port)
- }
- }
-}
-
-// DstPort creates a checker that checks the destination port.
-func DstPort(port uint16) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- if p := h.DestinationPort(); p != port {
- t.Errorf("Bad destination port, got %v, want %v", p, port)
- }
- }
-}
-
-// SeqNum creates a checker that checks the sequence number.
-func SeqNum(seq uint32) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- tcp, ok := h.(header.TCP)
- if !ok {
- return
- }
-
- if s := tcp.SequenceNumber(); s != seq {
- t.Errorf("Bad sequence number, got %v, want %v", s, seq)
- }
- }
-}
-
-// AckNum creates a checker that checks the ack number.
-func AckNum(seq uint32) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
- tcp, ok := h.(header.TCP)
- if !ok {
- return
- }
-
- if s := tcp.AckNumber(); s != seq {
- t.Errorf("Bad ack number, got %v, want %v", s, seq)
- }
- }
-}
-
-// Window creates a checker that checks the tcp window.
-func Window(window uint16) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- tcp, ok := h.(header.TCP)
- if !ok {
- return
- }
-
- if w := tcp.WindowSize(); w != window {
- t.Errorf("Bad window, got 0x%x, want 0x%x", w, window)
- }
- }
-}
-
-// TCPFlags creates a checker that checks the tcp flags.
-func TCPFlags(flags uint8) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- tcp, ok := h.(header.TCP)
- if !ok {
- return
- }
-
- if f := tcp.Flags(); f != flags {
- t.Errorf("Bad flags, got 0x%x, want 0x%x", f, flags)
- }
- }
-}
-
-// TCPFlagsMatch creates a checker that checks that the tcp flags, masked by the
-// given mask, match the supplied flags.
-func TCPFlagsMatch(flags, mask uint8) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- tcp, ok := h.(header.TCP)
- if !ok {
- return
- }
-
- if f := tcp.Flags(); (f & mask) != (flags & mask) {
- t.Errorf("Bad masked flags, got 0x%x, want 0x%x, mask 0x%x", f, flags, mask)
- }
- }
-}
-
-// TCPSynOptions creates a checker that checks the presence of TCP options in
-// SYN segments.
-//
-// If wndscale is negative, the window scale option must not be present.
-func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- tcp, ok := h.(header.TCP)
- if !ok {
- return
- }
- opts := tcp.Options()
- limit := len(opts)
- foundMSS := false
- foundWS := false
- foundTS := false
- foundSACKPermitted := false
- tsVal := uint32(0)
- tsEcr := uint32(0)
- for i := 0; i < limit; {
- switch opts[i] {
- case header.TCPOptionEOL:
- i = limit
- case header.TCPOptionNOP:
- i++
- case header.TCPOptionMSS:
- v := uint16(opts[i+2])<<8 | uint16(opts[i+3])
- if wantOpts.MSS != v {
- t.Errorf("Bad MSS: got %v, want %v", v, wantOpts.MSS)
- }
- foundMSS = true
- i += 4
- case header.TCPOptionWS:
- if wantOpts.WS < 0 {
- t.Error("WS present when it shouldn't be")
- }
- v := int(opts[i+2])
- if v != wantOpts.WS {
- t.Errorf("Bad WS: got %v, want %v", v, wantOpts.WS)
- }
- foundWS = true
- i += 3
- case header.TCPOptionTS:
- if i+9 >= limit {
- t.Errorf("TS Option truncated , option is only: %d bytes, want 10", limit-i)
- }
- if opts[i+1] != 10 {
- t.Errorf("Bad length %d for TS option, limit: %d", opts[i+1], limit)
- }
- tsVal = binary.BigEndian.Uint32(opts[i+2:])
- tsEcr = uint32(0)
- if tcp.Flags()&header.TCPFlagAck != 0 {
- // If the syn is an SYN-ACK then read
- // the tsEcr value as well.
- tsEcr = binary.BigEndian.Uint32(opts[i+6:])
- }
- foundTS = true
- i += 10
- case header.TCPOptionSACKPermitted:
- if i+1 >= limit {
- t.Errorf("SACKPermitted option truncated, option is only : %d bytes, want 2", limit-i)
- }
- if opts[i+1] != 2 {
- t.Errorf("Bad length %d for SACKPermitted option, limit: %d", opts[i+1], limit)
- }
- foundSACKPermitted = true
- i += 2
-
- default:
- i += int(opts[i+1])
- }
- }
-
- if !foundMSS {
- t.Errorf("MSS option not found. Options: %x", opts)
- }
-
- if !foundWS && wantOpts.WS >= 0 {
- t.Errorf("WS option not found. Options: %x", opts)
- }
- if wantOpts.TS && !foundTS {
- t.Errorf("TS option not found. Options: %x", opts)
- }
- if foundTS && tsVal == 0 {
- t.Error("TS option specified but the timestamp value is zero")
- }
- if foundTS && tsEcr == 0 && wantOpts.TSEcr != 0 {
- t.Errorf("TS option specified but TSEcr is incorrect: got %d, want: %d", tsEcr, wantOpts.TSEcr)
- }
- if wantOpts.SACKPermitted && !foundSACKPermitted {
- t.Errorf("SACKPermitted option not found. Options: %x", opts)
- }
- }
-}
-
-// TCPTimestampChecker creates a checker that validates that a TCP segment has a
-// TCP Timestamp option if wantTS is true, it also compares the wantTSVal and
-// wantTSEcr values with those in the TCP segment (if present).
-//
-// If wantTSVal or wantTSEcr is zero then the corresponding comparison is
-// skipped.
-func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- tcp, ok := h.(header.TCP)
- if !ok {
- return
- }
- opts := []byte(tcp.Options())
- limit := len(opts)
- foundTS := false
- tsVal := uint32(0)
- tsEcr := uint32(0)
- for i := 0; i < limit; {
- switch opts[i] {
- case header.TCPOptionEOL:
- i = limit
- case header.TCPOptionNOP:
- i++
- case header.TCPOptionTS:
- if i+9 >= limit {
- t.Errorf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i)
- }
- if opts[i+1] != 10 {
- t.Errorf("TS option found, but bad length specified: %d, want: 10", opts[i+1])
- }
- tsVal = binary.BigEndian.Uint32(opts[i+2:])
- tsEcr = binary.BigEndian.Uint32(opts[i+6:])
- foundTS = true
- i += 10
- default:
- // We don't recognize this option, just skip over it.
- if i+2 > limit {
- return
- }
- l := int(opts[i+1])
- if i < 2 || i+l > limit {
- return
- }
- i += l
- }
- }
-
- if wantTS != foundTS {
- t.Errorf("TS Option mismatch: got TS= %v, want TS= %v", foundTS, wantTS)
- }
- if wantTS && wantTSVal != 0 && wantTSVal != tsVal {
- t.Errorf("Timestamp value is incorrect: got: %d, want: %d", tsVal, wantTSVal)
- }
- if wantTS && wantTSEcr != 0 && tsEcr != wantTSEcr {
- t.Errorf("Timestamp Echo Reply is incorrect: got: %d, want: %d", tsEcr, wantTSEcr)
- }
- }
-}
-
-// TCPNoSACKBlockChecker creates a checker that verifies that the segment does not
-// contain any SACK blocks in the TCP options.
-func TCPNoSACKBlockChecker() TransportChecker {
- return TCPSACKBlockChecker(nil)
-}
-
-// TCPSACKBlockChecker creates a checker that verifies that the segment does
-// contain the specified SACK blocks in the TCP options.
-func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
- tcp, ok := h.(header.TCP)
- if !ok {
- return
- }
- var gotSACKBlocks []header.SACKBlock
-
- opts := []byte(tcp.Options())
- limit := len(opts)
- for i := 0; i < limit; {
- switch opts[i] {
- case header.TCPOptionEOL:
- i = limit
- case header.TCPOptionNOP:
- i++
- case header.TCPOptionSACK:
- if i+2 > limit {
- // Malformed SACK block.
- t.Errorf("malformed SACK option in options: %v", opts)
- }
- sackOptionLen := int(opts[i+1])
- if i+sackOptionLen > limit || (sackOptionLen-2)%8 != 0 {
- // Malformed SACK block.
- t.Errorf("malformed SACK option length in options: %v", opts)
- }
- numBlocks := sackOptionLen / 8
- for j := 0; j < numBlocks; j++ {
- start := binary.BigEndian.Uint32(opts[i+2+j*8:])
- end := binary.BigEndian.Uint32(opts[i+2+j*8+4:])
- gotSACKBlocks = append(gotSACKBlocks, header.SACKBlock{
- Start: seqnum.Value(start),
- End: seqnum.Value(end),
- })
- }
- i += sackOptionLen
- default:
- // We don't recognize this option, just skip over it.
- if i+2 > limit {
- break
- }
- l := int(opts[i+1])
- if l < 2 || i+l > limit {
- break
- }
- i += l
- }
- }
-
- if !reflect.DeepEqual(gotSACKBlocks, sackBlocks) {
- t.Errorf("SACKBlocks are not equal, got: %v, want: %v", gotSACKBlocks, sackBlocks)
- }
- }
-}
-
-// Payload creates a checker that checks the payload.
-func Payload(want []byte) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- if got := h.Payload(); !reflect.DeepEqual(got, want) {
- t.Errorf("Wrong payload, got %v, want %v", got, want)
- }
- }
-}
-
-// ICMPv4 creates a checker that checks that the transport protocol is ICMPv4 and
-// potentially additional ICMPv4 header fields.
-func ICMPv4(checkers ...TransportChecker) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- last := h[len(h)-1]
-
- if p := last.TransportProtocol(); p != header.ICMPv4ProtocolNumber {
- t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv4ProtocolNumber)
- }
-
- icmp := header.ICMPv4(last.Payload())
- for _, f := range checkers {
- f(t, icmp)
- }
- if t.Failed() {
- t.FailNow()
- }
- }
-}
-
-// ICMPv4Type creates a checker that checks the ICMPv4 Type field.
-func ICMPv4Type(want header.ICMPv4Type) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
- icmpv4, ok := h.(header.ICMPv4)
- if !ok {
- t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h)
- }
- if got := icmpv4.Type(); got != want {
- t.Fatalf("unexpected icmp type got: %d, want: %d", got, want)
- }
- }
-}
-
-// ICMPv4Code creates a checker that checks the ICMPv4 Code field.
-func ICMPv4Code(want byte) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
- icmpv4, ok := h.(header.ICMPv4)
- if !ok {
- t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h)
- }
- if got := icmpv4.Code(); got != want {
- t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want)
- }
- }
-}
-
-// ICMPv6 creates a checker that checks that the transport protocol is ICMPv6 and
-// potentially additional ICMPv6 header fields.
-//
-// ICMPv6 will validate the checksum field before calling checkers.
-func ICMPv6(checkers ...TransportChecker) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- last := h[len(h)-1]
-
- if p := last.TransportProtocol(); p != header.ICMPv6ProtocolNumber {
- t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv6ProtocolNumber)
- }
-
- icmp := header.ICMPv6(last.Payload())
- if got, want := icmp.Checksum(), header.ICMPv6Checksum(icmp, last.SourceAddress(), last.DestinationAddress(), buffer.VectorisedView{}); got != want {
- t.Fatalf("Bad ICMPv6 checksum; got %d, want %d", got, want)
- }
-
- for _, f := range checkers {
- f(t, icmp)
- }
- if t.Failed() {
- t.FailNow()
- }
- }
-}
-
-// ICMPv6Type creates a checker that checks the ICMPv6 Type field.
-func ICMPv6Type(want header.ICMPv6Type) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
- icmpv6, ok := h.(header.ICMPv6)
- if !ok {
- t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h)
- }
- if got := icmpv6.Type(); got != want {
- t.Fatalf("unexpected icmp type got: %d, want: %d", got, want)
- }
- }
-}
-
-// ICMPv6Code creates a checker that checks the ICMPv6 Code field.
-func ICMPv6Code(want byte) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
- icmpv6, ok := h.(header.ICMPv6)
- if !ok {
- t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h)
- }
- if got := icmpv6.Code(); got != want {
- t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want)
- }
- }
-}
-
-// NDP creates a checker that checks that the packet contains a valid NDP
-// message for type of ty, with potentially additional checks specified by
-// checkers.
-//
-// checkers may assume that a valid ICMPv6 is passed to it containing a valid
-// NDP message as far as the size of the message (minSize) is concerned. The
-// values within the message are up to checkers to validate.
-func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- // Check normal ICMPv6 first.
- ICMPv6(
- ICMPv6Type(msgType),
- ICMPv6Code(0))(t, h)
-
- last := h[len(h)-1]
-
- icmp := header.ICMPv6(last.Payload())
- if got := len(icmp.NDPPayload()); got < minSize {
- t.Fatalf("ICMPv6 NDP (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize)
- }
-
- for _, f := range checkers {
- f(t, icmp)
- }
- if t.Failed() {
- t.FailNow()
- }
- }
-}
-
-// NDPNS creates a checker that checks that the packet contains a valid NDP
-// Neighbor Solicitation message (as per the raw wire format), with potentially
-// additional checks specified by checkers.
-//
-// checkers may assume that a valid ICMPv6 is passed to it containing a valid
-// NDPNS message as far as the size of the messages concerned. The values within
-// the message are up to checkers to validate.
-func NDPNS(checkers ...TransportChecker) NetworkChecker {
- return NDP(header.ICMPv6NeighborSolicit, header.NDPNSMinimumSize, checkers...)
-}
-
-// NDPNSTargetAddress creates a checker that checks the Target Address field of
-// a header.NDPNeighborSolicit.
-//
-// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
-// containing a valid NDPNS message as far as the size is concerned.
-func NDPNSTargetAddress(want tcpip.Address) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmp := h.(header.ICMPv6)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
-
- if got := ns.TargetAddress(); got != want {
- t.Fatalf("got %T.TargetAddress = %s, want = %s", ns, got, want)
- }
- }
-}
-
-// NDPNSOptions creates a checker that checks that the packet contains the
-// provided NDP options within an NDP Neighbor Solicitation message.
-//
-// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
-// containing a valid NDPNS message as far as the size is concerned.
-func NDPNSOptions(opts []header.NDPOption) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmp := h.(header.ICMPv6)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
- it, err := ns.Options().Iter(true)
- if err != nil {
- t.Errorf("opts.Iter(true): %s", err)
- return
- }
-
- i := 0
- for {
- opt, done, _ := it.Next()
- if done {
- break
- }
-
- if i >= len(opts) {
- t.Errorf("got unexpected option: %s", opt)
- continue
- }
-
- switch wantOpt := opts[i].(type) {
- case header.NDPSourceLinkLayerAddressOption:
- gotOpt, ok := opt.(header.NDPSourceLinkLayerAddressOption)
- if !ok {
- t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt)
- } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want {
- t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want)
- }
- default:
- panic("not implemented")
- }
-
- i++
- }
-
- if missing := opts[i:]; len(missing) > 0 {
- t.Errorf("missing options: %s", missing)
- }
- }
-}
-
-// NDPRS creates a checker that checks that the packet contains a valid NDP
-// Router Solicitation message (as per the raw wire format).
-func NDPRS() NetworkChecker {
- return NDP(header.ICMPv6RouterSolicit, header.NDPRSMinimumSize)
-}
diff --git a/pkg/tcpip/hash/jenkins/BUILD b/pkg/tcpip/hash/jenkins/BUILD
deleted file mode 100644
index ff2719291..000000000
--- a/pkg/tcpip/hash/jenkins/BUILD
+++ /dev/null
@@ -1,18 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "jenkins",
- srcs = ["jenkins.go"],
- visibility = ["//visibility:public"],
-)
-
-go_test(
- name = "jenkins_test",
- size = "small",
- srcs = [
- "jenkins_test.go",
- ],
- library = ":jenkins",
-)
diff --git a/pkg/tcpip/hash/jenkins/jenkins_state_autogen.go b/pkg/tcpip/hash/jenkins/jenkins_state_autogen.go
new file mode 100755
index 000000000..216cc5a2e
--- /dev/null
+++ b/pkg/tcpip/hash/jenkins/jenkins_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package jenkins
diff --git a/pkg/tcpip/hash/jenkins/jenkins_test.go b/pkg/tcpip/hash/jenkins/jenkins_test.go
deleted file mode 100644
index 4c78b5808..000000000
--- a/pkg/tcpip/hash/jenkins/jenkins_test.go
+++ /dev/null
@@ -1,176 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-package jenkins
-
-import (
- "bytes"
- "encoding/binary"
- "hash"
- "hash/fnv"
- "math"
- "testing"
-)
-
-func TestGolden32(t *testing.T) {
- var golden32 = []struct {
- out []byte
- in string
- }{
- {[]byte{0x00, 0x00, 0x00, 0x00}, ""},
- {[]byte{0xca, 0x2e, 0x94, 0x42}, "a"},
- {[]byte{0x45, 0xe6, 0x1e, 0x58}, "ab"},
- {[]byte{0xed, 0x13, 0x1f, 0x5b}, "abc"},
- }
-
- hash := New32()
-
- for _, g := range golden32 {
- hash.Reset()
- done, error := hash.Write([]byte(g.in))
- if error != nil {
- t.Fatalf("write error: %s", error)
- }
- if done != len(g.in) {
- t.Fatalf("wrote only %d out of %d bytes", done, len(g.in))
- }
- if actual := hash.Sum(nil); !bytes.Equal(g.out, actual) {
- t.Errorf("hash(%q) = 0x%x want 0x%x", g.in, actual, g.out)
- }
- }
-}
-
-func TestIntegrity32(t *testing.T) {
- data := []byte{'1', '2', 3, 4, 5}
-
- h := New32()
- h.Write(data)
- sum := h.Sum(nil)
-
- if size := h.Size(); size != len(sum) {
- t.Fatalf("Size()=%d but len(Sum())=%d", size, len(sum))
- }
-
- if a := h.Sum(nil); !bytes.Equal(sum, a) {
- t.Fatalf("first Sum()=0x%x, second Sum()=0x%x", sum, a)
- }
-
- h.Reset()
- h.Write(data)
- if a := h.Sum(nil); !bytes.Equal(sum, a) {
- t.Fatalf("Sum()=0x%x, but after Reset() Sum()=0x%x", sum, a)
- }
-
- h.Reset()
- h.Write(data[:2])
- h.Write(data[2:])
- if a := h.Sum(nil); !bytes.Equal(sum, a) {
- t.Fatalf("Sum()=0x%x, but with partial writes, Sum()=0x%x", sum, a)
- }
-
- sum32 := h.(hash.Hash32).Sum32()
- if sum32 != binary.BigEndian.Uint32(sum) {
- t.Fatalf("Sum()=0x%x, but Sum32()=0x%x", sum, sum32)
- }
-}
-
-func BenchmarkJenkins32KB(b *testing.B) {
- h := New32()
-
- b.SetBytes(1024)
- data := make([]byte, 1024)
- for i := range data {
- data[i] = byte(i)
- }
- in := make([]byte, 0, h.Size())
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- h.Reset()
- h.Write(data)
- h.Sum(in)
- }
-}
-
-func BenchmarkFnv32(b *testing.B) {
- arr := make([]int64, 1000)
- for i := 0; i < b.N; i++ {
- var payload [8]byte
- binary.BigEndian.PutUint32(payload[:4], uint32(i))
- binary.BigEndian.PutUint32(payload[4:], uint32(i))
-
- h := fnv.New32()
- h.Write(payload[:])
- idx := int(h.Sum32()) % len(arr)
- arr[idx]++
- }
- b.StopTimer()
- c := 0
- if b.N > 1000000 {
- for i := 0; i < len(arr)-1; i++ {
- if math.Abs(float64(arr[i]-arr[i+1]))/float64(arr[i]) > float64(0.1) {
- if c == 0 {
- b.Logf("i %d val[i] %d val[i+1] %d b.N %b\n", i, arr[i], arr[i+1], b.N)
- }
- c++
- }
- }
- if c > 0 {
- b.Logf("Unbalanced buckets: %d", c)
- }
- }
-}
-
-func BenchmarkSum32(b *testing.B) {
- arr := make([]int64, 1000)
- for i := 0; i < b.N; i++ {
- var payload [8]byte
- binary.BigEndian.PutUint32(payload[:4], uint32(i))
- binary.BigEndian.PutUint32(payload[4:], uint32(i))
- h := Sum32(0)
- h.Write(payload[:])
- idx := int(h.Sum32()) % len(arr)
- arr[idx]++
- }
- b.StopTimer()
- if b.N > 1000000 {
- for i := 0; i < len(arr)-1; i++ {
- if math.Abs(float64(arr[i]-arr[i+1]))/float64(arr[i]) > float64(0.1) {
- b.Logf("val[%3d]=%8d\tval[%3d]=%8d\tb.N=%b\n", i, arr[i], i+1, arr[i+1], b.N)
- break
- }
- }
- }
-}
-
-func BenchmarkNew32(b *testing.B) {
- arr := make([]int64, 1000)
- for i := 0; i < b.N; i++ {
- var payload [8]byte
- binary.BigEndian.PutUint32(payload[:4], uint32(i))
- binary.BigEndian.PutUint32(payload[4:], uint32(i))
- h := New32()
- h.Write(payload[:])
- idx := int(h.Sum32()) % len(arr)
- arr[idx]++
- }
- b.StopTimer()
- if b.N > 1000000 {
- for i := 0; i < len(arr)-1; i++ {
- if math.Abs(float64(arr[i]-arr[i+1]))/float64(arr[i]) > float64(0.1) {
- b.Logf("val[%3d]=%8d\tval[%3d]=%8d\tb.N=%b\n", i, arr[i], i+1, arr[i+1], b.N)
- break
- }
- }
- }
-}
diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD
deleted file mode 100644
index 9da0d71f8..000000000
--- a/pkg/tcpip/header/BUILD
+++ /dev/null
@@ -1,65 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "header",
- srcs = [
- "arp.go",
- "checksum.go",
- "eth.go",
- "gue.go",
- "icmpv4.go",
- "icmpv6.go",
- "interfaces.go",
- "ipv4.go",
- "ipv6.go",
- "ipv6_fragment.go",
- "ndp_neighbor_advert.go",
- "ndp_neighbor_solicit.go",
- "ndp_options.go",
- "ndp_router_advert.go",
- "ndp_router_solicit.go",
- "tcp.go",
- "udp.go",
- ],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/seqnum",
- "@com_github_google_btree//:go_default_library",
- ],
-)
-
-go_test(
- name = "header_x_test",
- size = "small",
- srcs = [
- "checksum_test.go",
- "ipv6_test.go",
- "ipversion_test.go",
- "tcp_test.go",
- ],
- deps = [
- ":header",
- "//pkg/rand",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "@com_github_google_go-cmp//cmp:go_default_library",
- ],
-)
-
-go_test(
- name = "header_test",
- size = "small",
- srcs = [
- "eth_test.go",
- "ndp_test.go",
- ],
- library = ":header",
- deps = [
- "//pkg/tcpip",
- "@com_github_google_go-cmp//cmp:go_default_library",
- ],
-)
diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go
deleted file mode 100644
index 309403482..000000000
--- a/pkg/tcpip/header/checksum_test.go
+++ /dev/null
@@ -1,171 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package header provides the implementation of the encoding and decoding of
-// network protocol headers.
-package header_test
-
-import (
- "fmt"
- "math/rand"
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
-)
-
-func TestChecksumVVWithOffset(t *testing.T) {
- testCases := []struct {
- name string
- vv buffer.VectorisedView
- off, size int
- initial uint16
- want uint16
- }{
- {
- name: "empty",
- vv: buffer.NewVectorisedView(0, []buffer.View{
- buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}),
- }),
- off: 0,
- size: 0,
- want: 0,
- },
- {
- name: "OneView",
- vv: buffer.NewVectorisedView(0, []buffer.View{
- buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}),
- }),
- off: 0,
- size: 5,
- want: 1294,
- },
- {
- name: "TwoViews",
- vv: buffer.NewVectorisedView(0, []buffer.View{
- buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}),
- buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}),
- }),
- off: 0,
- size: 11,
- want: 33819,
- },
- {
- name: "TwoViewsWithOffset",
- vv: buffer.NewVectorisedView(0, []buffer.View{
- buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}),
- buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}),
- }),
- off: 1,
- size: 11,
- want: 33819,
- },
- {
- name: "ThreeViewsWithOffset",
- vv: buffer.NewVectorisedView(0, []buffer.View{
- buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}),
- buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}),
- buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}),
- }),
- off: 7,
- size: 11,
- want: 33819,
- },
- {
- name: "ThreeViewsWithInitial",
- vv: buffer.NewVectorisedView(0, []buffer.View{
- buffer.NewViewFromBytes([]byte{77, 11, 33, 0, 55, 44}),
- buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}),
- buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123, 99}),
- }),
- initial: 77,
- off: 7,
- size: 11,
- want: 33896,
- },
- }
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- if got, want := header.ChecksumVVWithOffset(tc.vv, tc.initial, tc.off, tc.size), tc.want; got != want {
- t.Errorf("header.ChecksumVVWithOffset(%v) = %v, want: %v", tc, got, tc.want)
- }
- v := tc.vv.ToView()
- v.TrimFront(tc.off)
- v.CapLength(tc.size)
- if got, want := header.Checksum(v, tc.initial), tc.want; got != want {
- t.Errorf("header.Checksum(%v) = %v, want: %v", tc, got, tc.want)
- }
- })
- }
-}
-
-func TestChecksum(t *testing.T) {
- var bufSizes = []int{0, 1, 2, 3, 4, 7, 8, 15, 16, 31, 32, 63, 64, 127, 128, 255, 256, 257, 1023, 1024}
- type testCase struct {
- buf []byte
- initial uint16
- csumOrig uint16
- csumNew uint16
- }
- testCases := make([]testCase, 100000)
- // Ensure same buffer generation for test consistency.
- rnd := rand.New(rand.NewSource(42))
- for i := range testCases {
- testCases[i].buf = make([]byte, bufSizes[i%len(bufSizes)])
- testCases[i].initial = uint16(rnd.Intn(65536))
- rnd.Read(testCases[i].buf)
- }
-
- for i := range testCases {
- testCases[i].csumOrig = header.ChecksumOld(testCases[i].buf, testCases[i].initial)
- testCases[i].csumNew = header.Checksum(testCases[i].buf, testCases[i].initial)
- if got, want := testCases[i].csumNew, testCases[i].csumOrig; got != want {
- t.Fatalf("new checksum for (buf = %x, initial = %d) does not match old got: %d, want: %d", testCases[i].buf, testCases[i].initial, got, want)
- }
- }
-}
-
-func BenchmarkChecksum(b *testing.B) {
- var bufSizes = []int{64, 128, 256, 512, 1024, 1500, 2048, 4096, 8192, 16384, 32767, 32768, 65535, 65536}
-
- checkSumImpls := []struct {
- fn func([]byte, uint16) uint16
- name string
- }{
- {header.ChecksumOld, fmt.Sprintf("checksum_old")},
- {header.Checksum, fmt.Sprintf("checksum")},
- }
-
- for _, csumImpl := range checkSumImpls {
- // Ensure same buffer generation for test consistency.
- rnd := rand.New(rand.NewSource(42))
- for _, bufSz := range bufSizes {
- b.Run(fmt.Sprintf("%s_%d", csumImpl.name, bufSz), func(b *testing.B) {
- tc := struct {
- buf []byte
- initial uint16
- csum uint16
- }{
- buf: make([]byte, bufSz),
- initial: uint16(rnd.Intn(65536)),
- }
- rnd.Read(tc.buf)
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- tc.csum = csumImpl.fn(tc.buf, tc.initial)
- }
- })
- }
- }
-}
diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go
deleted file mode 100644
index 7a0014ad9..000000000
--- a/pkg/tcpip/header/eth_test.go
+++ /dev/null
@@ -1,102 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package header
-
-import (
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip"
-)
-
-func TestIsValidUnicastEthernetAddress(t *testing.T) {
- tests := []struct {
- name string
- addr tcpip.LinkAddress
- expected bool
- }{
- {
- "Nil",
- tcpip.LinkAddress([]byte(nil)),
- false,
- },
- {
- "Empty",
- tcpip.LinkAddress(""),
- false,
- },
- {
- "InvalidLength",
- tcpip.LinkAddress("\x01\x02\x03"),
- false,
- },
- {
- "Unspecified",
- unspecifiedEthernetAddress,
- false,
- },
- {
- "Multicast",
- tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
- false,
- },
- {
- "Valid",
- tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06"),
- true,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- if got := IsValidUnicastEthernetAddress(test.addr); got != test.expected {
- t.Fatalf("got IsValidUnicastEthernetAddress = %t, want = %t", got, test.expected)
- }
- })
- }
-}
-
-func TestEthernetAddressFromMulticastIPv4Address(t *testing.T) {
- tests := []struct {
- name string
- addr tcpip.Address
- expectedLinkAddr tcpip.LinkAddress
- }{
- {
- name: "IPv4 Multicast without 24th bit set",
- addr: "\xe0\x7e\xdc\xba",
- expectedLinkAddr: "\x01\x00\x5e\x7e\xdc\xba",
- },
- {
- name: "IPv4 Multicast with 24th bit set",
- addr: "\xe0\xfe\xdc\xba",
- expectedLinkAddr: "\x01\x00\x5e\x7e\xdc\xba",
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- if got := EthernetAddressFromMulticastIPv4Address(test.addr); got != test.expectedLinkAddr {
- t.Fatalf("got EthernetAddressFromMulticastIPv4Address(%s) = %s, want = %s", got, test.expectedLinkAddr)
- }
- })
- }
-}
-
-func TestEthernetAddressFromMulticastIPv6Address(t *testing.T) {
- addr := tcpip.Address("\xff\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x1a")
- if got, want := EthernetAddressFromMulticastIPv6Address(addr), tcpip.LinkAddress("\x33\x33\x0d\x0e\x0f\x1a"); got != want {
- t.Fatalf("got EthernetAddressFromMulticastIPv6Address(%s) = %s, want = %s", addr, got, want)
- }
-}
diff --git a/pkg/tcpip/header/header_state_autogen.go b/pkg/tcpip/header/header_state_autogen.go
new file mode 100755
index 000000000..015d7e12a
--- /dev/null
+++ b/pkg/tcpip/header/header_state_autogen.go
@@ -0,0 +1,42 @@
+// automatically generated by stateify.
+
+package header
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+func (x *SACKBlock) beforeSave() {}
+func (x *SACKBlock) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Start", &x.Start)
+ m.Save("End", &x.End)
+}
+
+func (x *SACKBlock) afterLoad() {}
+func (x *SACKBlock) load(m state.Map) {
+ m.Load("Start", &x.Start)
+ m.Load("End", &x.End)
+}
+
+func (x *TCPOptions) beforeSave() {}
+func (x *TCPOptions) save(m state.Map) {
+ x.beforeSave()
+ m.Save("TS", &x.TS)
+ m.Save("TSVal", &x.TSVal)
+ m.Save("TSEcr", &x.TSEcr)
+ m.Save("SACKBlocks", &x.SACKBlocks)
+}
+
+func (x *TCPOptions) afterLoad() {}
+func (x *TCPOptions) load(m state.Map) {
+ m.Load("TS", &x.TS)
+ m.Load("TSVal", &x.TSVal)
+ m.Load("TSEcr", &x.TSEcr)
+ m.Load("SACKBlocks", &x.SACKBlocks)
+}
+
+func init() {
+ state.Register("pkg/tcpip/header.SACKBlock", (*SACKBlock)(nil), state.Fns{Save: (*SACKBlock).save, Load: (*SACKBlock).load})
+ state.Register("pkg/tcpip/header.TCPOptions", (*TCPOptions)(nil), state.Fns{Save: (*TCPOptions).save, Load: (*TCPOptions).load})
+}
diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go
deleted file mode 100644
index c3ad503aa..000000000
--- a/pkg/tcpip/header/ipv6_test.go
+++ /dev/null
@@ -1,331 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package header_test
-
-import (
- "bytes"
- "crypto/sha256"
- "fmt"
- "testing"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/rand"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/header"
-)
-
-const (
- linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
- linkLocalAddr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- globalAddr = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
-)
-
-func TestEthernetAdddressToModifiedEUI64(t *testing.T) {
- expectedIID := [header.IIDSize]byte{0, 2, 3, 255, 254, 4, 5, 6}
-
- if diff := cmp.Diff(expectedIID, header.EthernetAddressToModifiedEUI64(linkAddr)); diff != "" {
- t.Errorf("EthernetAddressToModifiedEUI64(%s) mismatch (-want +got):\n%s", linkAddr, diff)
- }
-
- var buf [header.IIDSize]byte
- header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, buf[:])
- if diff := cmp.Diff(expectedIID, buf); diff != "" {
- t.Errorf("EthernetAddressToModifiedEUI64IntoBuf(%s, _) mismatch (-want +got):\n%s", linkAddr, diff)
- }
-}
-
-func TestLinkLocalAddr(t *testing.T) {
- if got, want := header.LinkLocalAddr(linkAddr), tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x02\x03\xff\xfe\x04\x05\x06"); got != want {
- t.Errorf("got LinkLocalAddr(%s) = %s, want = %s", linkAddr, got, want)
- }
-}
-
-func TestAppendOpaqueInterfaceIdentifier(t *testing.T) {
- var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes * 2]byte
- if n, err := rand.Read(secretKeyBuf[:]); err != nil {
- t.Fatalf("rand.Read(_): %s", err)
- } else if want := header.OpaqueIIDSecretKeyMinBytes * 2; n != want {
- t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", want, n)
- }
-
- tests := []struct {
- name string
- prefix tcpip.Subnet
- nicName string
- dadCounter uint8
- secretKey []byte
- }{
- {
- name: "SecretKey of minimum size",
- prefix: header.IPv6LinkLocalPrefix.Subnet(),
- nicName: "eth0",
- dadCounter: 0,
- secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes],
- },
- {
- name: "SecretKey of less than minimum size",
- prefix: func() tcpip.Subnet {
- addrWithPrefix := tcpip.AddressWithPrefix{
- Address: "\x01\x02\x03\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
- PrefixLen: header.IIDOffsetInIPv6Address * 8,
- }
- return addrWithPrefix.Subnet()
- }(),
- nicName: "eth10",
- dadCounter: 1,
- secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes/2],
- },
- {
- name: "SecretKey of more than minimum size",
- prefix: func() tcpip.Subnet {
- addrWithPrefix := tcpip.AddressWithPrefix{
- Address: "\x01\x02\x03\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
- PrefixLen: header.IIDOffsetInIPv6Address * 8,
- }
- return addrWithPrefix.Subnet()
- }(),
- nicName: "eth11",
- dadCounter: 2,
- secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes*2],
- },
- {
- name: "Nil SecretKey and empty nicName",
- prefix: func() tcpip.Subnet {
- addrWithPrefix := tcpip.AddressWithPrefix{
- Address: "\x01\x02\x03\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
- PrefixLen: header.IIDOffsetInIPv6Address * 8,
- }
- return addrWithPrefix.Subnet()
- }(),
- nicName: "",
- dadCounter: 3,
- secretKey: nil,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- h := sha256.New()
- h.Write([]byte(test.prefix.ID()[:header.IIDOffsetInIPv6Address]))
- h.Write([]byte(test.nicName))
- h.Write([]byte{test.dadCounter})
- if k := test.secretKey; k != nil {
- h.Write(k)
- }
- var hashSum [sha256.Size]byte
- h.Sum(hashSum[:0])
- want := hashSum[:header.IIDSize]
-
- // Passing a nil buffer should result in a new buffer returned with the
- // IID.
- if got := header.AppendOpaqueInterfaceIdentifier(nil, test.prefix, test.nicName, test.dadCounter, test.secretKey); !bytes.Equal(got, want) {
- t.Errorf("got AppendOpaqueInterfaceIdentifier(nil, %s, %s, %d, %x) = %x, want = %x", test.prefix, test.nicName, test.dadCounter, test.secretKey, got, want)
- }
-
- // Passing a buffer with sufficient capacity for the IID should populate
- // the buffer provided.
- var iidBuf [header.IIDSize]byte
- if got := header.AppendOpaqueInterfaceIdentifier(iidBuf[:0], test.prefix, test.nicName, test.dadCounter, test.secretKey); !bytes.Equal(got, want) {
- t.Errorf("got AppendOpaqueInterfaceIdentifier(iidBuf[:0], %s, %s, %d, %x) = %x, want = %x", test.prefix, test.nicName, test.dadCounter, test.secretKey, got, want)
- }
- if got := iidBuf[:]; !bytes.Equal(got, want) {
- t.Errorf("got iidBuf = %x, want = %x", got, want)
- }
- })
- }
-}
-
-func TestLinkLocalAddrWithOpaqueIID(t *testing.T) {
- var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes * 2]byte
- if n, err := rand.Read(secretKeyBuf[:]); err != nil {
- t.Fatalf("rand.Read(_): %s", err)
- } else if want := header.OpaqueIIDSecretKeyMinBytes * 2; n != want {
- t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", want, n)
- }
-
- prefix := header.IPv6LinkLocalPrefix.Subnet()
-
- tests := []struct {
- name string
- prefix tcpip.Subnet
- nicName string
- dadCounter uint8
- secretKey []byte
- }{
- {
- name: "SecretKey of minimum size",
- nicName: "eth0",
- dadCounter: 0,
- secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes],
- },
- {
- name: "SecretKey of less than minimum size",
- nicName: "eth10",
- dadCounter: 1,
- secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes/2],
- },
- {
- name: "SecretKey of more than minimum size",
- nicName: "eth11",
- dadCounter: 2,
- secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes*2],
- },
- {
- name: "Nil SecretKey and empty nicName",
- nicName: "",
- dadCounter: 3,
- secretKey: nil,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- addrBytes := [header.IPv6AddressSize]byte{
- 0: 0xFE,
- 1: 0x80,
- }
-
- want := tcpip.Address(header.AppendOpaqueInterfaceIdentifier(
- addrBytes[:header.IIDOffsetInIPv6Address],
- prefix,
- test.nicName,
- test.dadCounter,
- test.secretKey,
- ))
-
- if got := header.LinkLocalAddrWithOpaqueIID(test.nicName, test.dadCounter, test.secretKey); got != want {
- t.Errorf("got LinkLocalAddrWithOpaqueIID(%s, %d, %x) = %s, want = %s", test.nicName, test.dadCounter, test.secretKey, got, want)
- }
- })
- }
-}
-
-func TestIsV6UniqueLocalAddress(t *testing.T) {
- tests := []struct {
- name string
- addr tcpip.Address
- expected bool
- }{
- {
- name: "Valid Unique 1",
- addr: uniqueLocalAddr1,
- expected: true,
- },
- {
- name: "Valid Unique 2",
- addr: uniqueLocalAddr1,
- expected: true,
- },
- {
- name: "Link Local",
- addr: linkLocalAddr,
- expected: false,
- },
- {
- name: "Global",
- addr: globalAddr,
- expected: false,
- },
- {
- name: "IPv4",
- addr: "\x01\x02\x03\x04",
- expected: false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- if got := header.IsV6UniqueLocalAddress(test.addr); got != test.expected {
- t.Errorf("got header.IsV6UniqueLocalAddress(%s) = %t, want = %t", test.addr, got, test.expected)
- }
- })
- }
-}
-
-func TestScopeForIPv6Address(t *testing.T) {
- tests := []struct {
- name string
- addr tcpip.Address
- scope header.IPv6AddressScope
- err *tcpip.Error
- }{
- {
- name: "Unique Local",
- addr: uniqueLocalAddr1,
- scope: header.UniqueLocalScope,
- err: nil,
- },
- {
- name: "Link Local",
- addr: linkLocalAddr,
- scope: header.LinkLocalScope,
- err: nil,
- },
- {
- name: "Global",
- addr: globalAddr,
- scope: header.GlobalScope,
- err: nil,
- },
- {
- name: "IPv4",
- addr: "\x01\x02\x03\x04",
- scope: header.GlobalScope,
- err: tcpip.ErrBadAddress,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- got, err := header.ScopeForIPv6Address(test.addr)
- if err != test.err {
- t.Errorf("got header.IsV6UniqueLocalAddress(%s) = (_, %v), want = (_, %v)", test.addr, err, test.err)
- }
- if got != test.scope {
- t.Errorf("got header.IsV6UniqueLocalAddress(%s) = (%d, _), want = (%d, _)", test.addr, got, test.scope)
- }
- })
- }
-}
-
-func TestSolicitedNodeAddr(t *testing.T) {
- tests := []struct {
- addr tcpip.Address
- want tcpip.Address
- }{
- {
- addr: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\xa0",
- want: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\x0e\x0f\xa0",
- },
- {
- addr: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\xdd\x0e\x0f\xa0",
- want: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\x0e\x0f\xa0",
- },
- {
- addr: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\xdd\x01\x02\x03",
- want: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\x01\x02\x03",
- },
- }
-
- for _, test := range tests {
- t.Run(fmt.Sprintf("%s", test.addr), func(t *testing.T) {
- if got := header.SolicitedNodeAddr(test.addr); got != test.want {
- t.Fatalf("got header.SolicitedNodeAddr(%s) = %s, want = %s", test.addr, got, test.want)
- }
- })
- }
-}
diff --git a/pkg/tcpip/header/ipversion_test.go b/pkg/tcpip/header/ipversion_test.go
deleted file mode 100644
index b5540bf66..000000000
--- a/pkg/tcpip/header/ipversion_test.go
+++ /dev/null
@@ -1,67 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package header_test
-
-import (
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip/header"
-)
-
-func TestIPv4(t *testing.T) {
- b := header.IPv4(make([]byte, header.IPv4MinimumSize))
- b.Encode(&header.IPv4Fields{})
-
- const want = header.IPv4Version
- if v := header.IPVersion(b); v != want {
- t.Fatalf("Bad version, want %v, got %v", want, v)
- }
-}
-
-func TestIPv6(t *testing.T) {
- b := header.IPv6(make([]byte, header.IPv6MinimumSize))
- b.Encode(&header.IPv6Fields{})
-
- const want = header.IPv6Version
- if v := header.IPVersion(b); v != want {
- t.Fatalf("Bad version, want %v, got %v", want, v)
- }
-}
-
-func TestOtherVersion(t *testing.T) {
- const want = header.IPv4Version + header.IPv6Version
- b := make([]byte, 1)
- b[0] = want << 4
-
- if v := header.IPVersion(b); v != want {
- t.Fatalf("Bad version, want %v, got %v", want, v)
- }
-}
-
-func TestTooShort(t *testing.T) {
- b := make([]byte, 1)
- b[0] = (header.IPv4Version + header.IPv6Version) << 4
-
- // Get the version of a zero-length slice.
- const want = -1
- if v := header.IPVersion(b[:0]); v != want {
- t.Fatalf("Bad version, want %v, got %v", want, v)
- }
-
- // Get the version of a nil slice.
- if v := header.IPVersion(nil); v != want {
- t.Fatalf("Bad version, want %v, got %v", want, v)
- }
-}
diff --git a/pkg/tcpip/header/ndp_neighbor_advert.go b/pkg/tcpip/header/ndp_neighbor_advert.go
index 505c92668..505c92668 100644..100755
--- a/pkg/tcpip/header/ndp_neighbor_advert.go
+++ b/pkg/tcpip/header/ndp_neighbor_advert.go
diff --git a/pkg/tcpip/header/ndp_neighbor_solicit.go b/pkg/tcpip/header/ndp_neighbor_solicit.go
index 3a1b8e139..3a1b8e139 100644..100755
--- a/pkg/tcpip/header/ndp_neighbor_solicit.go
+++ b/pkg/tcpip/header/ndp_neighbor_solicit.go
diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go
index e6a6ad39b..e6a6ad39b 100644..100755
--- a/pkg/tcpip/header/ndp_options.go
+++ b/pkg/tcpip/header/ndp_options.go
diff --git a/pkg/tcpip/header/ndp_router_advert.go b/pkg/tcpip/header/ndp_router_advert.go
index bf7610863..bf7610863 100644..100755
--- a/pkg/tcpip/header/ndp_router_advert.go
+++ b/pkg/tcpip/header/ndp_router_advert.go
diff --git a/pkg/tcpip/header/ndp_router_solicit.go b/pkg/tcpip/header/ndp_router_solicit.go
index 9e67ba95d..9e67ba95d 100644..100755
--- a/pkg/tcpip/header/ndp_router_solicit.go
+++ b/pkg/tcpip/header/ndp_router_solicit.go
diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go
deleted file mode 100644
index 1cb9f5dc8..000000000
--- a/pkg/tcpip/header/ndp_test.go
+++ /dev/null
@@ -1,937 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package header
-
-import (
- "bytes"
- "testing"
- "time"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip"
-)
-
-// TestNDPNeighborSolicit tests the functions of NDPNeighborSolicit.
-func TestNDPNeighborSolicit(t *testing.T) {
- b := []byte{
- 0, 0, 0, 0,
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- }
-
- // Test getting the Target Address.
- ns := NDPNeighborSolicit(b)
- addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10")
- if got := ns.TargetAddress(); got != addr {
- t.Errorf("got ns.TargetAddress = %s, want %s", got, addr)
- }
-
- // Test updating the Target Address.
- addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11")
- ns.SetTargetAddress(addr2)
- if got := ns.TargetAddress(); got != addr2 {
- t.Errorf("got ns.TargetAddress = %s, want %s", got, addr2)
- }
- // Make sure the address got updated in the backing buffer.
- if got := tcpip.Address(b[ndpNSTargetAddessOffset:][:IPv6AddressSize]); got != addr2 {
- t.Errorf("got targetaddress buffer = %s, want %s", got, addr2)
- }
-}
-
-// TestNDPNeighborAdvert tests the functions of NDPNeighborAdvert.
-func TestNDPNeighborAdvert(t *testing.T) {
- b := []byte{
- 160, 0, 0, 0,
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- }
-
- // Test getting the Target Address.
- na := NDPNeighborAdvert(b)
- addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10")
- if got := na.TargetAddress(); got != addr {
- t.Errorf("got TargetAddress = %s, want %s", got, addr)
- }
-
- // Test getting the Router Flag.
- if got := na.RouterFlag(); !got {
- t.Errorf("got RouterFlag = false, want = true")
- }
-
- // Test getting the Solicited Flag.
- if got := na.SolicitedFlag(); got {
- t.Errorf("got SolicitedFlag = true, want = false")
- }
-
- // Test getting the Override Flag.
- if got := na.OverrideFlag(); !got {
- t.Errorf("got OverrideFlag = false, want = true")
- }
-
- // Test updating the Target Address.
- addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11")
- na.SetTargetAddress(addr2)
- if got := na.TargetAddress(); got != addr2 {
- t.Errorf("got TargetAddress = %s, want %s", got, addr2)
- }
- // Make sure the address got updated in the backing buffer.
- if got := tcpip.Address(b[ndpNATargetAddressOffset:][:IPv6AddressSize]); got != addr2 {
- t.Errorf("got targetaddress buffer = %s, want %s", got, addr2)
- }
-
- // Test updating the Router Flag.
- na.SetRouterFlag(false)
- if got := na.RouterFlag(); got {
- t.Errorf("got RouterFlag = true, want = false")
- }
-
- // Test updating the Solicited Flag.
- na.SetSolicitedFlag(true)
- if got := na.SolicitedFlag(); !got {
- t.Errorf("got SolicitedFlag = false, want = true")
- }
-
- // Test updating the Override Flag.
- na.SetOverrideFlag(false)
- if got := na.OverrideFlag(); got {
- t.Errorf("got OverrideFlag = true, want = false")
- }
-
- // Make sure flags got updated in the backing buffer.
- if got := b[ndpNAFlagsOffset]; got != 64 {
- t.Errorf("got flags byte = %d, want = 64")
- }
-}
-
-func TestNDPRouterAdvert(t *testing.T) {
- b := []byte{
- 64, 128, 1, 2,
- 3, 4, 5, 6,
- 7, 8, 9, 10,
- }
-
- ra := NDPRouterAdvert(b)
-
- if got := ra.CurrHopLimit(); got != 64 {
- t.Errorf("got ra.CurrHopLimit = %d, want = 64", got)
- }
-
- if got := ra.ManagedAddrConfFlag(); !got {
- t.Errorf("got ManagedAddrConfFlag = false, want = true")
- }
-
- if got := ra.OtherConfFlag(); got {
- t.Errorf("got OtherConfFlag = true, want = false")
- }
-
- if got, want := ra.RouterLifetime(), time.Second*258; got != want {
- t.Errorf("got ra.RouterLifetime = %d, want = %d", got, want)
- }
-
- if got, want := ra.ReachableTime(), time.Millisecond*50595078; got != want {
- t.Errorf("got ra.ReachableTime = %d, want = %d", got, want)
- }
-
- if got, want := ra.RetransTimer(), time.Millisecond*117967114; got != want {
- t.Errorf("got ra.RetransTimer = %d, want = %d", got, want)
- }
-}
-
-// TestNDPSourceLinkLayerAddressOptionEthernetAddress tests getting the
-// Ethernet address from an NDPSourceLinkLayerAddressOption.
-func TestNDPSourceLinkLayerAddressOptionEthernetAddress(t *testing.T) {
- tests := []struct {
- name string
- buf []byte
- expected tcpip.LinkAddress
- }{
- {
- "ValidMAC",
- []byte{1, 2, 3, 4, 5, 6},
- tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
- },
- {
- "SLLBodyTooShort",
- []byte{1, 2, 3, 4, 5},
- tcpip.LinkAddress([]byte(nil)),
- },
- {
- "SLLBodyLargerThanNeeded",
- []byte{1, 2, 3, 4, 5, 6, 7, 8},
- tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- sll := NDPSourceLinkLayerAddressOption(test.buf)
- if got := sll.EthernetAddress(); got != test.expected {
- t.Errorf("got sll.EthernetAddress = %s, want = %s", got, test.expected)
- }
- })
- }
-}
-
-// TestNDPSourceLinkLayerAddressOptionSerialize tests serializing a
-// NDPSourceLinkLayerAddressOption.
-func TestNDPSourceLinkLayerAddressOptionSerialize(t *testing.T) {
- tests := []struct {
- name string
- buf []byte
- expectedBuf []byte
- addr tcpip.LinkAddress
- }{
- {
- "Ethernet",
- make([]byte, 8),
- []byte{1, 1, 1, 2, 3, 4, 5, 6},
- "\x01\x02\x03\x04\x05\x06",
- },
- {
- "Padding",
- []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
- []byte{1, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0},
- "\x01\x02\x03\x04\x05\x06\x07\x08",
- },
- {
- "Empty",
- nil,
- nil,
- "",
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- opts := NDPOptions(test.buf)
- serializer := NDPOptionsSerializer{
- NDPSourceLinkLayerAddressOption(test.addr),
- }
- if got, want := int(serializer.Length()), len(test.expectedBuf); got != want {
- t.Fatalf("got Length = %d, want = %d", got, want)
- }
- opts.Serialize(serializer)
- if !bytes.Equal(test.buf, test.expectedBuf) {
- t.Fatalf("got b = %d, want = %d", test.buf, test.expectedBuf)
- }
-
- it, err := opts.Iter(true)
- if err != nil {
- t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
- }
-
- if len(test.expectedBuf) > 0 {
- next, done, err := it.Next()
- if err != nil {
- t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
- }
- if done {
- t.Fatal("got Next = (_, true, _), want = (_, false, _)")
- }
- if got := next.Type(); got != NDPSourceLinkLayerAddressOptionType {
- t.Fatalf("got Type = %d, want = %d", got, NDPSourceLinkLayerAddressOptionType)
- }
- sll := next.(NDPSourceLinkLayerAddressOption)
- if got, want := []byte(sll), test.expectedBuf[2:]; !bytes.Equal(got, want) {
- t.Fatalf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
- }
-
- if got, want := sll.EthernetAddress(), tcpip.LinkAddress(test.expectedBuf[2:][:EthernetAddressSize]); got != want {
- t.Errorf("got sll.EthernetAddress = %s, want = %s", got, want)
- }
- }
-
- // Iterator should not return anything else.
- next, done, err := it.Next()
- if err != nil {
- t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
- }
- if !done {
- t.Error("got Next = (_, false, _), want = (_, true, _)")
- }
- if next != nil {
- t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
- }
- })
- }
-}
-
-// TestNDPTargetLinkLayerAddressOptionEthernetAddress tests getting the
-// Ethernet address from an NDPTargetLinkLayerAddressOption.
-func TestNDPTargetLinkLayerAddressOptionEthernetAddress(t *testing.T) {
- tests := []struct {
- name string
- buf []byte
- expected tcpip.LinkAddress
- }{
- {
- "ValidMAC",
- []byte{1, 2, 3, 4, 5, 6},
- tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
- },
- {
- "TLLBodyTooShort",
- []byte{1, 2, 3, 4, 5},
- tcpip.LinkAddress([]byte(nil)),
- },
- {
- "TLLBodyLargerThanNeeded",
- []byte{1, 2, 3, 4, 5, 6, 7, 8},
- tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- tll := NDPTargetLinkLayerAddressOption(test.buf)
- if got := tll.EthernetAddress(); got != test.expected {
- t.Errorf("got tll.EthernetAddress = %s, want = %s", got, test.expected)
- }
- })
- }
-}
-
-// TestNDPTargetLinkLayerAddressOptionSerialize tests serializing a
-// NDPTargetLinkLayerAddressOption.
-func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) {
- tests := []struct {
- name string
- buf []byte
- expectedBuf []byte
- addr tcpip.LinkAddress
- }{
- {
- "Ethernet",
- make([]byte, 8),
- []byte{2, 1, 1, 2, 3, 4, 5, 6},
- "\x01\x02\x03\x04\x05\x06",
- },
- {
- "Padding",
- []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
- []byte{2, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0},
- "\x01\x02\x03\x04\x05\x06\x07\x08",
- },
- {
- "Empty",
- nil,
- nil,
- "",
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- opts := NDPOptions(test.buf)
- serializer := NDPOptionsSerializer{
- NDPTargetLinkLayerAddressOption(test.addr),
- }
- if got, want := int(serializer.Length()), len(test.expectedBuf); got != want {
- t.Fatalf("got Length = %d, want = %d", got, want)
- }
- opts.Serialize(serializer)
- if !bytes.Equal(test.buf, test.expectedBuf) {
- t.Fatalf("got b = %d, want = %d", test.buf, test.expectedBuf)
- }
-
- it, err := opts.Iter(true)
- if err != nil {
- t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
- }
-
- if len(test.expectedBuf) > 0 {
- next, done, err := it.Next()
- if err != nil {
- t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
- }
- if done {
- t.Fatal("got Next = (_, true, _), want = (_, false, _)")
- }
- if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType {
- t.Fatalf("got Type = %d, want = %d", got, NDPTargetLinkLayerAddressOptionType)
- }
- tll := next.(NDPTargetLinkLayerAddressOption)
- if got, want := []byte(tll), test.expectedBuf[2:]; !bytes.Equal(got, want) {
- t.Fatalf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
- }
-
- if got, want := tll.EthernetAddress(), tcpip.LinkAddress(test.expectedBuf[2:][:EthernetAddressSize]); got != want {
- t.Errorf("got tll.EthernetAddress = %s, want = %s", got, want)
- }
- }
-
- // Iterator should not return anything else.
- next, done, err := it.Next()
- if err != nil {
- t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
- }
- if !done {
- t.Error("got Next = (_, false, _), want = (_, true, _)")
- }
- if next != nil {
- t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
- }
- })
- }
-}
-
-// TestNDPPrefixInformationOption tests the field getters and serialization of a
-// NDPPrefixInformation.
-func TestNDPPrefixInformationOption(t *testing.T) {
- b := []byte{
- 43, 127,
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- 5, 5, 5, 5,
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- 17, 18, 19, 20,
- 21, 22, 23, 24,
- }
-
- targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}
- opts := NDPOptions(targetBuf)
- serializer := NDPOptionsSerializer{
- NDPPrefixInformation(b),
- }
- opts.Serialize(serializer)
- expectedBuf := []byte{
- 3, 4, 43, 64,
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- 0, 0, 0, 0,
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- 17, 18, 19, 20,
- 21, 22, 23, 24,
- }
- if !bytes.Equal(targetBuf, expectedBuf) {
- t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expectedBuf)
- }
-
- it, err := opts.Iter(true)
- if err != nil {
- t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
- }
-
- next, done, err := it.Next()
- if err != nil {
- t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
- }
- if done {
- t.Fatal("got Next = (_, true, _), want = (_, false, _)")
- }
- if got := next.Type(); got != NDPPrefixInformationType {
- t.Errorf("got Type = %d, want = %d", got, NDPPrefixInformationType)
- }
-
- pi := next.(NDPPrefixInformation)
-
- if got := pi.Type(); got != 3 {
- t.Errorf("got Type = %d, want = 3", got)
- }
-
- if got := pi.Length(); got != 30 {
- t.Errorf("got Length = %d, want = 30", got)
- }
-
- if got := pi.PrefixLength(); got != 43 {
- t.Errorf("got PrefixLength = %d, want = 43", got)
- }
-
- if pi.OnLinkFlag() {
- t.Error("got OnLinkFlag = true, want = false")
- }
-
- if !pi.AutonomousAddressConfigurationFlag() {
- t.Error("got AutonomousAddressConfigurationFlag = false, want = true")
- }
-
- if got, want := pi.ValidLifetime(), 16909060*time.Second; got != want {
- t.Errorf("got ValidLifetime = %d, want = %d", got, want)
- }
-
- if got, want := pi.PreferredLifetime(), 84281096*time.Second; got != want {
- t.Errorf("got PreferredLifetime = %d, want = %d", got, want)
- }
-
- if got, want := pi.Prefix(), tcpip.Address("\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18"); got != want {
- t.Errorf("got Prefix = %s, want = %s", got, want)
- }
-
- // Iterator should not return anything else.
- next, done, err = it.Next()
- if err != nil {
- t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
- }
- if !done {
- t.Error("got Next = (_, false, _), want = (_, true, _)")
- }
- if next != nil {
- t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
- }
-}
-
-func TestNDPRecursiveDNSServerOptionSerialize(t *testing.T) {
- b := []byte{
- 9, 8,
- 1, 2, 4, 8,
- 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
- }
- targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}
- expected := []byte{
- 25, 3, 0, 0,
- 1, 2, 4, 8,
- 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
- }
- opts := NDPOptions(targetBuf)
- serializer := NDPOptionsSerializer{
- NDPRecursiveDNSServer(b),
- }
- if got, want := opts.Serialize(serializer), len(expected); got != want {
- t.Errorf("got Serialize = %d, want = %d", got, want)
- }
- if !bytes.Equal(targetBuf, expected) {
- t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expected)
- }
-
- it, err := opts.Iter(true)
- if err != nil {
- t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
- }
-
- next, done, err := it.Next()
- if err != nil {
- t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
- }
- if done {
- t.Fatal("got Next = (_, true, _), want = (_, false, _)")
- }
- if got := next.Type(); got != NDPRecursiveDNSServerOptionType {
- t.Errorf("got Type = %d, want = %d", got, NDPRecursiveDNSServerOptionType)
- }
-
- opt, ok := next.(NDPRecursiveDNSServer)
- if !ok {
- t.Fatalf("next (type = %T) cannot be casted to an NDPRecursiveDNSServer", next)
- }
- if got := opt.Type(); got != 25 {
- t.Errorf("got Type = %d, want = 31", got)
- }
- if got := opt.Length(); got != 22 {
- t.Errorf("got Length = %d, want = 22", got)
- }
- if got, want := opt.Lifetime(), 16909320*time.Second; got != want {
- t.Errorf("got Lifetime = %s, want = %s", got, want)
- }
- want := []tcpip.Address{
- "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
- }
- if got := opt.Addresses(); !cmp.Equal(got, want) {
- t.Errorf("got Addresses = %v, want = %v", got, want)
- }
-
- // Iterator should not return anything else.
- next, done, err = it.Next()
- if err != nil {
- t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
- }
- if !done {
- t.Error("got Next = (_, false, _), want = (_, true, _)")
- }
- if next != nil {
- t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
- }
-}
-
-func TestNDPRecursiveDNSServerOption(t *testing.T) {
- tests := []struct {
- name string
- buf []byte
- lifetime time.Duration
- addrs []tcpip.Address
- }{
- {
- "Valid1Addr",
- []byte{
- 25, 3, 0, 0,
- 0, 0, 0, 0,
- 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
- },
- 0,
- []tcpip.Address{
- "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
- },
- },
- {
- "Valid2Addr",
- []byte{
- 25, 5, 0, 0,
- 0, 0, 0, 0,
- 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
- 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16,
- },
- 0,
- []tcpip.Address{
- "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
- "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x10",
- },
- },
- {
- "Valid3Addr",
- []byte{
- 25, 7, 0, 0,
- 0, 0, 0, 0,
- 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
- 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16,
- 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 17,
- },
- 0,
- []tcpip.Address{
- "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
- "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x10",
- "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x11",
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- opts := NDPOptions(test.buf)
- it, err := opts.Iter(true)
- if err != nil {
- t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
- }
-
- // Iterator should get our option.
- next, done, err := it.Next()
- if err != nil {
- t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
- }
- if done {
- t.Fatal("got Next = (_, true, _), want = (_, false, _)")
- }
- if got := next.Type(); got != NDPRecursiveDNSServerOptionType {
- t.Fatalf("got Type = %d, want = %d", got, NDPRecursiveDNSServerOptionType)
- }
-
- opt, ok := next.(NDPRecursiveDNSServer)
- if !ok {
- t.Fatalf("next (type = %T) cannot be casted to an NDPRecursiveDNSServer", next)
- }
- if got := opt.Lifetime(); got != test.lifetime {
- t.Errorf("got Lifetime = %d, want = %d", got, test.lifetime)
- }
- if got := opt.Addresses(); !cmp.Equal(got, test.addrs) {
- t.Errorf("got Addresses = %v, want = %v", got, test.addrs)
- }
-
- // Iterator should not return anything else.
- next, done, err = it.Next()
- if err != nil {
- t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
- }
- if !done {
- t.Error("got Next = (_, false, _), want = (_, true, _)")
- }
- if next != nil {
- t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
- }
- })
- }
-}
-
-// TestNDPOptionsIterCheck tests that Iter will return false if the NDPOptions
-// the iterator was returned for is malformed.
-func TestNDPOptionsIterCheck(t *testing.T) {
- tests := []struct {
- name string
- buf []byte
- expected error
- }{
- {
- "ZeroLengthField",
- []byte{0, 0, 0, 0, 0, 0, 0, 0},
- ErrNDPOptZeroLength,
- },
- {
- "ValidSourceLinkLayerAddressOption",
- []byte{1, 1, 1, 2, 3, 4, 5, 6},
- nil,
- },
- {
- "TooSmallSourceLinkLayerAddressOption",
- []byte{1, 1, 1, 2, 3, 4, 5},
- ErrNDPOptBufExhausted,
- },
- {
- "ValidTargetLinkLayerAddressOption",
- []byte{2, 1, 1, 2, 3, 4, 5, 6},
- nil,
- },
- {
- "TooSmallTargetLinkLayerAddressOption",
- []byte{2, 1, 1, 2, 3, 4, 5},
- ErrNDPOptBufExhausted,
- },
- {
- "ValidPrefixInformation",
- []byte{
- 3, 4, 43, 64,
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- 0, 0, 0, 0,
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- 17, 18, 19, 20,
- 21, 22, 23, 24,
- },
- nil,
- },
- {
- "TooSmallPrefixInformation",
- []byte{
- 3, 4, 43, 64,
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- 0, 0, 0, 0,
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- 17, 18, 19, 20,
- 21, 22, 23,
- },
- ErrNDPOptBufExhausted,
- },
- {
- "InvalidPrefixInformationLength",
- []byte{
- 3, 3, 43, 64,
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- 0, 0, 0, 0,
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- },
- ErrNDPOptMalformedBody,
- },
- {
- "ValidSourceAndTargetLinkLayerAddressWithPrefixInformation",
- []byte{
- // Source Link-Layer Address.
- 1, 1, 1, 2, 3, 4, 5, 6,
-
- // Target Link-Layer Address.
- 2, 1, 7, 8, 9, 10, 11, 12,
-
- // Prefix information.
- 3, 4, 43, 64,
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- 0, 0, 0, 0,
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- 17, 18, 19, 20,
- 21, 22, 23, 24,
- },
- nil,
- },
- {
- "ValidSourceAndTargetLinkLayerAddressWithPrefixInformationWithUnrecognized",
- []byte{
- // Source Link-Layer Address.
- 1, 1, 1, 2, 3, 4, 5, 6,
-
- // Target Link-Layer Address.
- 2, 1, 7, 8, 9, 10, 11, 12,
-
- // 255 is an unrecognized type. If 255 ends up
- // being the type for some recognized type,
- // update 255 to some other unrecognized value.
- 255, 2, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8,
-
- // Prefix information.
- 3, 4, 43, 64,
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- 0, 0, 0, 0,
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- 17, 18, 19, 20,
- 21, 22, 23, 24,
- },
- nil,
- },
- {
- "InvalidRecursiveDNSServerCutsOffAddress",
- []byte{
- 25, 4, 0, 0,
- 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
- 0, 1, 2, 3, 4, 5, 6, 7,
- },
- ErrNDPOptMalformedBody,
- },
- {
- "InvalidRecursiveDNSServerInvalidLengthField",
- []byte{
- 25, 2, 0, 0,
- 0, 0, 0, 0,
- 0, 1, 2, 3, 4, 5, 6, 7, 8,
- },
- ErrNDPInvalidLength,
- },
- {
- "RecursiveDNSServerTooSmall",
- []byte{
- 25, 1, 0, 0,
- 0, 0, 0,
- },
- ErrNDPOptBufExhausted,
- },
- {
- "RecursiveDNSServerMulticast",
- []byte{
- 25, 3, 0, 0,
- 0, 0, 0, 0,
- 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
- },
- ErrNDPOptMalformedBody,
- },
- {
- "RecursiveDNSServerUnspecified",
- []byte{
- 25, 3, 0, 0,
- 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- },
- ErrNDPOptMalformedBody,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- opts := NDPOptions(test.buf)
-
- if _, err := opts.Iter(true); err != test.expected {
- t.Fatalf("got Iter(true) = (_, %v), want = (_, %v)", err, test.expected)
- }
-
- // test.buf may be malformed but we chose not to check
- // the iterator so it must return true.
- if _, err := opts.Iter(false); err != nil {
- t.Fatalf("got Iter(false) = (_, %s), want = (_, nil)", err)
- }
- })
- }
-}
-
-// TestNDPOptionsIter tests that we can iterator over a valid NDPOptions. Note,
-// this test does not actually check any of the option's getters, it simply
-// checks the option Type and Body. We have other tests that tests the option
-// field gettings given an option body and don't need to duplicate those tests
-// here.
-func TestNDPOptionsIter(t *testing.T) {
- buf := []byte{
- // Source Link-Layer Address.
- 1, 1, 1, 2, 3, 4, 5, 6,
-
- // Target Link-Layer Address.
- 2, 1, 7, 8, 9, 10, 11, 12,
-
- // 255 is an unrecognized type. If 255 ends up being the type
- // for some recognized type, update 255 to some other
- // unrecognized value. Note, this option should be skipped when
- // iterating.
- 255, 2, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8,
-
- // Prefix information.
- 3, 4, 43, 64,
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- 0, 0, 0, 0,
- 9, 10, 11, 12,
- 13, 14, 15, 16,
- 17, 18, 19, 20,
- 21, 22, 23, 24,
- }
-
- opts := NDPOptions(buf)
- it, err := opts.Iter(true)
- if err != nil {
- t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
- }
-
- // Test the first (Source Link-Layer) option.
- next, done, err := it.Next()
- if err != nil {
- t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
- }
- if done {
- t.Fatal("got Next = (_, true, _), want = (_, false, _)")
- }
- if got, want := []byte(next.(NDPSourceLinkLayerAddressOption)), buf[2:][:6]; !bytes.Equal(got, want) {
- t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
- }
- if got := next.Type(); got != NDPSourceLinkLayerAddressOptionType {
- t.Errorf("got Type = %d, want = %d", got, NDPSourceLinkLayerAddressOptionType)
- }
-
- // Test the next (Target Link-Layer) option.
- next, done, err = it.Next()
- if err != nil {
- t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
- }
- if done {
- t.Fatal("got Next = (_, true, _), want = (_, false, _)")
- }
- if got, want := []byte(next.(NDPTargetLinkLayerAddressOption)), buf[10:][:6]; !bytes.Equal(got, want) {
- t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
- }
- if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType {
- t.Errorf("got Type = %d, want = %d", got, NDPTargetLinkLayerAddressOptionType)
- }
-
- // Test the next (Prefix Information) option.
- // Note, the unrecognized option should be skipped.
- next, done, err = it.Next()
- if err != nil {
- t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
- }
- if done {
- t.Fatal("got Next = (_, true, _), want = (_, false, _)")
- }
- if got, want := next.(NDPPrefixInformation), buf[34:][:30]; !bytes.Equal(got, want) {
- t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
- }
- if got := next.Type(); got != NDPPrefixInformationType {
- t.Errorf("got Type = %d, want = %d", got, NDPPrefixInformationType)
- }
-
- // Iterator should not return anything else.
- next, done, err = it.Next()
- if err != nil {
- t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
- }
- if !done {
- t.Error("got Next = (_, false, _), want = (_, true, _)")
- }
- if next != nil {
- t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
- }
-}
diff --git a/pkg/tcpip/header/tcp_test.go b/pkg/tcpip/header/tcp_test.go
deleted file mode 100644
index 72563837b..000000000
--- a/pkg/tcpip/header/tcp_test.go
+++ /dev/null
@@ -1,148 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package header_test
-
-import (
- "reflect"
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip/header"
-)
-
-func TestEncodeSACKBlocks(t *testing.T) {
- testCases := []struct {
- sackBlocks []header.SACKBlock
- want []header.SACKBlock
- bufSize int
- }{
- {
- []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}},
- []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}},
- 40,
- },
- {
- []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}},
- []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}},
- 30,
- },
- {
- []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}},
- []header.SACKBlock{{10, 20}, {22, 30}},
- 20,
- },
- {
- []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}},
- []header.SACKBlock{{10, 20}},
- 10,
- },
- {
- []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}},
- nil,
- 8,
- },
- {
- []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}},
- []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}},
- 60,
- },
- }
- for _, tc := range testCases {
- b := make([]byte, tc.bufSize)
- t.Logf("testing: %v", tc)
- header.EncodeSACKBlocks(tc.sackBlocks, b)
- opts := header.ParseTCPOptions(b)
- if got, want := opts.SACKBlocks, tc.want; !reflect.DeepEqual(got, want) {
- t.Errorf("header.EncodeSACKBlocks(%v, %v), encoded blocks got: %v, want: %v", tc.sackBlocks, b, got, want)
- }
- }
-}
-
-func TestTCPParseOptions(t *testing.T) {
- type tsOption struct {
- tsVal uint32
- tsEcr uint32
- }
-
- generateOptions := func(tsOpt *tsOption, sackBlocks []header.SACKBlock) []byte {
- l := 0
- if tsOpt != nil {
- l += 10
- }
- if len(sackBlocks) != 0 {
- l += len(sackBlocks)*8 + 2
- }
- b := make([]byte, l)
- offset := 0
- if tsOpt != nil {
- offset = header.EncodeTSOption(tsOpt.tsVal, tsOpt.tsEcr, b)
- }
- header.EncodeSACKBlocks(sackBlocks, b[offset:])
- return b
- }
-
- testCases := []struct {
- b []byte
- want header.TCPOptions
- }{
- // Trivial cases.
- {nil, header.TCPOptions{false, 0, 0, nil}},
- {[]byte{header.TCPOptionNOP}, header.TCPOptions{false, 0, 0, nil}},
- {[]byte{header.TCPOptionNOP, header.TCPOptionNOP}, header.TCPOptions{false, 0, 0, nil}},
- {[]byte{header.TCPOptionEOL}, header.TCPOptions{false, 0, 0, nil}},
- {[]byte{header.TCPOptionNOP, header.TCPOptionEOL, header.TCPOptionTS, 10, 1, 1}, header.TCPOptions{false, 0, 0, nil}},
-
- // Test timestamp parsing.
- {[]byte{header.TCPOptionNOP, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{true, 1, 1, nil}},
- {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{true, 1, 1, nil}},
-
- // Test malformed timestamp option.
- {[]byte{header.TCPOptionTS, 8, 1, 1}, header.TCPOptions{false, 0, 0, nil}},
- {[]byte{header.TCPOptionNOP, header.TCPOptionTS, 8, 1, 1}, header.TCPOptions{false, 0, 0, nil}},
- {[]byte{header.TCPOptionNOP, header.TCPOptionTS, 8, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, nil}},
-
- // Test SACKBlock parsing.
- {[]byte{header.TCPOptionSACK, 10, 0, 0, 0, 1, 0, 0, 0, 10}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{1, 10}}}},
- {[]byte{header.TCPOptionSACK, 18, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, 12}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{1, 10}, {11, 12}}}},
-
- // Test malformed SACK option.
- {[]byte{header.TCPOptionSACK, 0}, header.TCPOptions{false, 0, 0, nil}},
- {[]byte{header.TCPOptionSACK, 8, 0, 0, 0, 1, 0, 0, 0, 10}, header.TCPOptions{false, 0, 0, nil}},
- {[]byte{header.TCPOptionSACK, 11, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, 12}, header.TCPOptions{false, 0, 0, nil}},
- {[]byte{header.TCPOptionSACK, 17, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, 12}, header.TCPOptions{false, 0, 0, nil}},
- {[]byte{header.TCPOptionSACK}, header.TCPOptions{false, 0, 0, nil}},
- {[]byte{header.TCPOptionSACK, 10}, header.TCPOptions{false, 0, 0, nil}},
- {[]byte{header.TCPOptionSACK, 10, 0, 0, 0, 1, 0, 0, 0}, header.TCPOptions{false, 0, 0, nil}},
-
- // Test Timestamp + SACK block parsing.
- {generateOptions(&tsOption{1, 1}, []header.SACKBlock{{1, 10}, {11, 12}}), header.TCPOptions{true, 1, 1, []header.SACKBlock{{1, 10}, {11, 12}}}},
- {generateOptions(&tsOption{1, 2}, []header.SACKBlock{{1, 10}, {11, 12}}), header.TCPOptions{true, 1, 2, []header.SACKBlock{{1, 10}, {11, 12}}}},
- {generateOptions(&tsOption{1, 3}, []header.SACKBlock{{1, 10}, {11, 12}, {13, 14}, {14, 15}, {15, 16}}), header.TCPOptions{true, 1, 3, []header.SACKBlock{{1, 10}, {11, 12}, {13, 14}, {14, 15}}}},
-
- // Test valid timestamp + malformed SACK block parsing.
- {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK}, header.TCPOptions{true, 1, 1, nil}},
- {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK, 10}, header.TCPOptions{true, 1, 1, nil}},
- {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK, 10, 0, 0, 0}, header.TCPOptions{true, 1, 1, nil}},
- {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK, 11, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{true, 1, 1, nil}},
- {[]byte{header.TCPOptionSACK, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, nil}},
- {[]byte{header.TCPOptionSACK, 10, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{134873088, 65536}}}},
- {[]byte{header.TCPOptionSACK, 10, 0, 0, 0, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{8, 167772160}}}},
- {[]byte{header.TCPOptionSACK, 11, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, nil}},
- }
- for _, tc := range testCases {
- if got, want := header.ParseTCPOptions(tc.b), tc.want; !reflect.DeepEqual(got, want) {
- t.Errorf("ParseTCPOptions(%v) = %v, want: %v", tc.b, got, tc.want)
- }
- }
-}
diff --git a/pkg/tcpip/iptables/BUILD b/pkg/tcpip/iptables/BUILD
deleted file mode 100644
index d1b73cfdf..000000000
--- a/pkg/tcpip/iptables/BUILD
+++ /dev/null
@@ -1,18 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "iptables",
- srcs = [
- "iptables.go",
- "targets.go",
- "types.go",
- ],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/log",
- "//pkg/tcpip",
- "//pkg/tcpip/header",
- ],
-)
diff --git a/pkg/tcpip/iptables/iptables_state_autogen.go b/pkg/tcpip/iptables/iptables_state_autogen.go
new file mode 100755
index 000000000..e75169fa7
--- /dev/null
+++ b/pkg/tcpip/iptables/iptables_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package iptables
diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD
deleted file mode 100644
index 3974c464e..000000000
--- a/pkg/tcpip/link/channel/BUILD
+++ /dev/null
@@ -1,14 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "channel",
- srcs = ["channel.go"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/stack",
- ],
-)
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index 78d447acd..78d447acd 100644..100755
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
diff --git a/pkg/tcpip/link/channel/channel_state_autogen.go b/pkg/tcpip/link/channel/channel_state_autogen.go
new file mode 100755
index 000000000..24066fb83
--- /dev/null
+++ b/pkg/tcpip/link/channel/channel_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package channel
diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD
deleted file mode 100644
index abe725548..000000000
--- a/pkg/tcpip/link/fdbased/BUILD
+++ /dev/null
@@ -1,39 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "fdbased",
- srcs = [
- "endpoint.go",
- "endpoint_unsafe.go",
- "mmap.go",
- "mmap_stub.go",
- "mmap_unsafe.go",
- "packet_dispatchers.go",
- ],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/rawfile",
- "//pkg/tcpip/stack",
- "@org_golang_x_sys//unix:go_default_library",
- ],
-)
-
-go_test(
- name = "fdbased_test",
- size = "small",
- srcs = ["endpoint_test.go"],
- library = ":fdbased",
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/rawfile",
- "//pkg/tcpip/stack",
- ],
-)
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go
deleted file mode 100644
index 2066987eb..000000000
--- a/pkg/tcpip/link/fdbased/endpoint_test.go
+++ /dev/null
@@ -1,468 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// +build linux
-
-package fdbased
-
-import (
- "bytes"
- "fmt"
- "math/rand"
- "reflect"
- "syscall"
- "testing"
- "time"
- "unsafe"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
-)
-
-const (
- mtu = 1500
- laddr = tcpip.LinkAddress("\x11\x22\x33\x44\x55\x66")
- raddr = tcpip.LinkAddress("\x77\x88\x99\xaa\xbb\xcc")
- proto = 10
- csumOffset = 48
- gsoMSS = 500
-)
-
-type packetInfo struct {
- raddr tcpip.LinkAddress
- proto tcpip.NetworkProtocolNumber
- contents tcpip.PacketBuffer
-}
-
-type context struct {
- t *testing.T
- fds [2]int
- ep stack.LinkEndpoint
- ch chan packetInfo
- done chan struct{}
-}
-
-func newContext(t *testing.T, opt *Options) *context {
- fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0)
- if err != nil {
- t.Fatalf("Socketpair failed: %v", err)
- }
-
- done := make(chan struct{}, 1)
- opt.ClosedFunc = func(*tcpip.Error) {
- done <- struct{}{}
- }
-
- opt.FDs = []int{fds[1]}
- ep, err := New(opt)
- if err != nil {
- t.Fatalf("Failed to create FD endpoint: %v", err)
- }
-
- c := &context{
- t: t,
- fds: fds,
- ep: ep,
- ch: make(chan packetInfo, 100),
- done: done,
- }
-
- ep.Attach(c)
-
- return c
-}
-
-func (c *context) cleanup() {
- syscall.Close(c.fds[0])
- <-c.done
- syscall.Close(c.fds[1])
-}
-
-func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
- c.ch <- packetInfo{remote, protocol, pkt}
-}
-
-func TestNoEthernetProperties(t *testing.T) {
- c := newContext(t, &Options{MTU: mtu})
- defer c.cleanup()
-
- if want, v := uint16(0), c.ep.MaxHeaderLength(); want != v {
- t.Fatalf("MaxHeaderLength() = %v, want %v", v, want)
- }
-
- if want, v := uint32(mtu), c.ep.MTU(); want != v {
- t.Fatalf("MTU() = %v, want %v", v, want)
- }
-}
-
-func TestEthernetProperties(t *testing.T) {
- c := newContext(t, &Options{EthernetHeader: true, MTU: mtu})
- defer c.cleanup()
-
- if want, v := uint16(header.EthernetMinimumSize), c.ep.MaxHeaderLength(); want != v {
- t.Fatalf("MaxHeaderLength() = %v, want %v", v, want)
- }
-
- if want, v := uint32(mtu), c.ep.MTU(); want != v {
- t.Fatalf("MTU() = %v, want %v", v, want)
- }
-}
-
-func TestAddress(t *testing.T) {
- addrs := []tcpip.LinkAddress{"", "abc", "def"}
- for _, a := range addrs {
- t.Run(fmt.Sprintf("Address: %q", a), func(t *testing.T) {
- c := newContext(t, &Options{Address: a, MTU: mtu})
- defer c.cleanup()
-
- if want, v := a, c.ep.LinkAddress(); want != v {
- t.Fatalf("LinkAddress() = %v, want %v", v, want)
- }
- })
- }
-}
-
-func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32) {
- c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth, GSOMaxSize: gsoMaxSize})
- defer c.cleanup()
-
- r := &stack.Route{
- RemoteLinkAddress: raddr,
- }
-
- // Build header.
- hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()) + 100)
- b := hdr.Prepend(100)
- for i := range b {
- b[i] = uint8(rand.Intn(256))
- }
-
- // Build payload and write.
- payload := make(buffer.View, plen)
- for i := range payload {
- payload[i] = uint8(rand.Intn(256))
- }
- want := append(hdr.View(), payload...)
- var gso *stack.GSO
- if gsoMaxSize != 0 {
- gso = &stack.GSO{
- Type: stack.GSOTCPv6,
- NeedsCsum: true,
- CsumOffset: csumOffset,
- MSS: gsoMSS,
- MaxSize: gsoMaxSize,
- L3HdrLen: header.IPv4MaximumHeaderSize,
- }
- }
- if err := c.ep.WritePacket(r, gso, proto, tcpip.PacketBuffer{
- Header: hdr,
- Data: payload.ToVectorisedView(),
- }); err != nil {
- t.Fatalf("WritePacket failed: %v", err)
- }
-
- // Read from fd, then compare with what we wrote.
- b = make([]byte, mtu)
- n, err := syscall.Read(c.fds[0], b)
- if err != nil {
- t.Fatalf("Read failed: %v", err)
- }
- b = b[:n]
- if gsoMaxSize != 0 {
- vnetHdr := *(*virtioNetHdr)(unsafe.Pointer(&b[0]))
- if vnetHdr.flags&_VIRTIO_NET_HDR_F_NEEDS_CSUM == 0 {
- t.Fatalf("virtioNetHdr.flags %v doesn't contain %v", vnetHdr.flags, _VIRTIO_NET_HDR_F_NEEDS_CSUM)
- }
- csumStart := header.EthernetMinimumSize + gso.L3HdrLen
- if vnetHdr.csumStart != csumStart {
- t.Fatalf("vnetHdr.csumStart = %v, want %v", vnetHdr.csumStart, csumStart)
- }
- if vnetHdr.csumOffset != csumOffset {
- t.Fatalf("vnetHdr.csumOffset = %v, want %v", vnetHdr.csumOffset, csumOffset)
- }
- gsoType := uint8(0)
- if int(gso.MSS) < plen {
- gsoType = _VIRTIO_NET_HDR_GSO_TCPV6
- }
- if vnetHdr.gsoType != gsoType {
- t.Fatalf("vnetHdr.gsoType = %v, want %v", vnetHdr.gsoType, gsoType)
- }
- b = b[virtioNetHdrSize:]
- }
- if eth {
- h := header.Ethernet(b)
- b = b[header.EthernetMinimumSize:]
-
- if a := h.SourceAddress(); a != laddr {
- t.Fatalf("SourceAddress() = %v, want %v", a, laddr)
- }
-
- if a := h.DestinationAddress(); a != raddr {
- t.Fatalf("DestinationAddress() = %v, want %v", a, raddr)
- }
-
- if et := h.Type(); et != proto {
- t.Fatalf("Type() = %v, want %v", et, proto)
- }
- }
- if len(b) != len(want) {
- t.Fatalf("Read returned %v bytes, want %v", len(b), len(want))
- }
- if !bytes.Equal(b, want) {
- t.Fatalf("Read returned %x, want %x", b, want)
- }
-}
-
-func TestWritePacket(t *testing.T) {
- lengths := []int{0, 100, 1000}
- eths := []bool{true, false}
- gsos := []uint32{0, 32768}
-
- for _, eth := range eths {
- for _, plen := range lengths {
- for _, gso := range gsos {
- t.Run(
- fmt.Sprintf("Eth=%v,PayloadLen=%v,GSOMaxSize=%v", eth, plen, gso),
- func(t *testing.T) {
- testWritePacket(t, plen, eth, gso)
- },
- )
- }
- }
- }
-}
-
-func TestPreserveSrcAddress(t *testing.T) {
- baddr := tcpip.LinkAddress("\xcc\xbb\xaa\x77\x88\x99")
-
- c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: true})
- defer c.cleanup()
-
- // Set LocalLinkAddress in route to the value of the bridged address.
- r := &stack.Route{
- RemoteLinkAddress: raddr,
- LocalLinkAddress: baddr,
- }
-
- // WritePacket panics given a prependable with anything less than
- // the minimum size of the ethernet header.
- hdr := buffer.NewPrependable(header.EthernetMinimumSize)
- if err := c.ep.WritePacket(r, nil /* gso */, proto, tcpip.PacketBuffer{
- Header: hdr,
- Data: buffer.VectorisedView{},
- }); err != nil {
- t.Fatalf("WritePacket failed: %v", err)
- }
-
- // Read from the FD, then compare with what we wrote.
- b := make([]byte, mtu)
- n, err := syscall.Read(c.fds[0], b)
- if err != nil {
- t.Fatalf("Read failed: %v", err)
- }
- b = b[:n]
- h := header.Ethernet(b)
-
- if a := h.SourceAddress(); a != baddr {
- t.Fatalf("SourceAddress() = %v, want %v", a, baddr)
- }
-}
-
-func TestDeliverPacket(t *testing.T) {
- lengths := []int{100, 1000}
- eths := []bool{true, false}
-
- for _, eth := range eths {
- for _, plen := range lengths {
- t.Run(fmt.Sprintf("Eth=%v,PayloadLen=%v", eth, plen), func(t *testing.T) {
- c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth})
- defer c.cleanup()
-
- // Build packet.
- b := make([]byte, plen)
- all := b
- for i := range b {
- b[i] = uint8(rand.Intn(256))
- }
-
- var hdr header.Ethernet
- if !eth {
- // So that it looks like an IPv4 packet.
- b[0] = 0x40
- } else {
- hdr = make(header.Ethernet, header.EthernetMinimumSize)
- hdr.Encode(&header.EthernetFields{
- SrcAddr: raddr,
- DstAddr: laddr,
- Type: proto,
- })
- all = append(hdr, b...)
- }
-
- // Write packet via the file descriptor.
- if _, err := syscall.Write(c.fds[0], all); err != nil {
- t.Fatalf("Write failed: %v", err)
- }
-
- // Receive packet through the endpoint.
- select {
- case pi := <-c.ch:
- want := packetInfo{
- raddr: raddr,
- proto: proto,
- contents: tcpip.PacketBuffer{
- Data: buffer.View(b).ToVectorisedView(),
- LinkHeader: buffer.View(hdr),
- },
- }
- if !eth {
- want.proto = header.IPv4ProtocolNumber
- want.raddr = ""
- }
- // want.contents.Data will be a single
- // view, so make pi do the same for the
- // DeepEqual check.
- pi.contents.Data = pi.contents.Data.ToView().ToVectorisedView()
- if !reflect.DeepEqual(want, pi) {
- t.Fatalf("Unexpected received packet: %+v, want %+v", pi, want)
- }
- case <-time.After(10 * time.Second):
- t.Fatalf("Timed out waiting for packet")
- }
- })
- }
- }
-}
-
-func TestBufConfigMaxLength(t *testing.T) {
- got := 0
- for _, i := range BufConfig {
- got += i
- }
- want := header.MaxIPPacketSize // maximum TCP packet size
- if got < want {
- t.Errorf("total buffer size is invalid: got %d, want >= %d", got, want)
- }
-}
-
-func TestBufConfigFirst(t *testing.T) {
- // The stack assumes that the TCP/IP header is enterily contained in the first view.
- // Therefore, the first view needs to be large enough to contain the maximum TCP/IP
- // header, which is 120 bytes (60 bytes for IP + 60 bytes for TCP).
- want := 120
- got := BufConfig[0]
- if got < want {
- t.Errorf("first view has an invalid size: got %d, want >= %d", got, want)
- }
-}
-
-var capLengthTestCases = []struct {
- comment string
- config []int
- n int
- wantUsed int
- wantLengths []int
-}{
- {
- comment: "Single slice",
- config: []int{2},
- n: 1,
- wantUsed: 1,
- wantLengths: []int{1},
- },
- {
- comment: "Multiple slices",
- config: []int{1, 2},
- n: 2,
- wantUsed: 2,
- wantLengths: []int{1, 1},
- },
- {
- comment: "Entire buffer",
- config: []int{1, 2},
- n: 3,
- wantUsed: 2,
- wantLengths: []int{1, 2},
- },
- {
- comment: "Entire buffer but not on the last slice",
- config: []int{1, 2, 3},
- n: 3,
- wantUsed: 2,
- wantLengths: []int{1, 2, 3},
- },
-}
-
-func TestReadVDispatcherCapLength(t *testing.T) {
- for _, c := range capLengthTestCases {
- // fd does not matter for this test.
- d := readVDispatcher{fd: -1, e: &endpoint{}}
- d.views = make([]buffer.View, len(c.config))
- d.iovecs = make([]syscall.Iovec, len(c.config))
- d.allocateViews(c.config)
-
- used := d.capViews(c.n, c.config)
- if used != c.wantUsed {
- t.Errorf("Test %q failed when calling capViews(%d, %v). Got %d. Want %d", c.comment, c.n, c.config, used, c.wantUsed)
- }
- lengths := make([]int, len(d.views))
- for i, v := range d.views {
- lengths[i] = len(v)
- }
- if !reflect.DeepEqual(lengths, c.wantLengths) {
- t.Errorf("Test %q failed when calling capViews(%d, %v). Got %v. Want %v", c.comment, c.n, c.config, lengths, c.wantLengths)
- }
- }
-}
-
-func TestRecvMMsgDispatcherCapLength(t *testing.T) {
- for _, c := range capLengthTestCases {
- d := recvMMsgDispatcher{
- fd: -1, // fd does not matter for this test.
- e: &endpoint{},
- views: make([][]buffer.View, 1),
- iovecs: make([][]syscall.Iovec, 1),
- msgHdrs: make([]rawfile.MMsgHdr, 1),
- }
-
- for i, _ := range d.views {
- d.views[i] = make([]buffer.View, len(c.config))
- }
- for i := range d.iovecs {
- d.iovecs[i] = make([]syscall.Iovec, len(c.config))
- }
- for k, msgHdr := range d.msgHdrs {
- msgHdr.Msg.Iov = &d.iovecs[k][0]
- msgHdr.Msg.Iovlen = uint64(len(c.config))
- }
-
- d.allocateViews(c.config)
-
- used := d.capViews(0, c.n, c.config)
- if used != c.wantUsed {
- t.Errorf("Test %q failed when calling capViews(%d, %v). Got %d. Want %d", c.comment, c.n, c.config, used, c.wantUsed)
- }
- lengths := make([]int, len(d.views[0]))
- for i, v := range d.views[0] {
- lengths[i] = len(v)
- }
- if !reflect.DeepEqual(lengths, c.wantLengths) {
- t.Errorf("Test %q failed when calling capViews(%d, %v). Got %v. Want %v", c.comment, c.n, c.config, lengths, c.wantLengths)
- }
-
- }
-}
diff --git a/pkg/tcpip/link/fdbased/fdbased_state_autogen.go b/pkg/tcpip/link/fdbased/fdbased_state_autogen.go
new file mode 100755
index 000000000..97cb3958e
--- /dev/null
+++ b/pkg/tcpip/link/fdbased/fdbased_state_autogen.go
@@ -0,0 +1,10 @@
+// automatically generated by stateify.
+
+// +build linux
+// +build linux
+// +build linux,amd64 linux,arm64
+// +build !linux !amd64,!arm64
+// +build linux,amd64 linux,arm64
+// +build linux
+
+package fdbased
diff --git a/pkg/tcpip/link/loopback/BUILD b/pkg/tcpip/link/loopback/BUILD
deleted file mode 100644
index 6bf3805b7..000000000
--- a/pkg/tcpip/link/loopback/BUILD
+++ /dev/null
@@ -1,15 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "loopback",
- srcs = ["loopback.go"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/stack",
- ],
-)
diff --git a/pkg/tcpip/link/loopback/loopback_state_autogen.go b/pkg/tcpip/link/loopback/loopback_state_autogen.go
new file mode 100755
index 000000000..c00fd9f19
--- /dev/null
+++ b/pkg/tcpip/link/loopback/loopback_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package loopback
diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD
deleted file mode 100644
index 82b441b79..000000000
--- a/pkg/tcpip/link/muxed/BUILD
+++ /dev/null
@@ -1,28 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "muxed",
- srcs = ["injectable.go"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/stack",
- ],
-)
-
-go_test(
- name = "muxed_test",
- size = "small",
- srcs = ["injectable_test.go"],
- library = ":muxed",
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/link/fdbased",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/stack",
- ],
-)
diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go
index 445b22c17..445b22c17 100644..100755
--- a/pkg/tcpip/link/muxed/injectable.go
+++ b/pkg/tcpip/link/muxed/injectable.go
diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go
deleted file mode 100644
index 63b249837..000000000
--- a/pkg/tcpip/link/muxed/injectable_test.go
+++ /dev/null
@@ -1,98 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package muxed
-
-import (
- "bytes"
- "net"
- "os"
- "syscall"
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/link/fdbased"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
-)
-
-func TestInjectableEndpointRawDispatch(t *testing.T) {
- endpoint, sock, dstIP := makeTestInjectableEndpoint(t)
-
- endpoint.InjectOutbound(dstIP, []byte{0xFA})
-
- buf := make([]byte, ipv4.MaxTotalSize)
- bytesRead, err := sock.Read(buf)
- if err != nil {
- t.Fatalf("Unable to read from socketpair: %v", err)
- }
- if got, want := buf[:bytesRead], []byte{0xFA}; !bytes.Equal(got, want) {
- t.Fatalf("Read %v from the socketpair, wanted %v", got, want)
- }
-}
-
-func TestInjectableEndpointDispatch(t *testing.T) {
- endpoint, sock, dstIP := makeTestInjectableEndpoint(t)
-
- hdr := buffer.NewPrependable(1)
- hdr.Prepend(1)[0] = 0xFA
- packetRoute := stack.Route{RemoteAddress: dstIP}
-
- endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, tcpip.PacketBuffer{
- Header: hdr,
- Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(),
- })
-
- buf := make([]byte, 6500)
- bytesRead, err := sock.Read(buf)
- if err != nil {
- t.Fatalf("Unable to read from socketpair: %v", err)
- }
- if got, want := buf[:bytesRead], []byte{0xFA, 0xFB}; !bytes.Equal(got, want) {
- t.Fatalf("Read %v from the socketpair, wanted %v", got, want)
- }
-}
-
-func TestInjectableEndpointDispatchHdrOnly(t *testing.T) {
- endpoint, sock, dstIP := makeTestInjectableEndpoint(t)
- hdr := buffer.NewPrependable(1)
- hdr.Prepend(1)[0] = 0xFA
- packetRoute := stack.Route{RemoteAddress: dstIP}
- endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, tcpip.PacketBuffer{
- Header: hdr,
- Data: buffer.NewView(0).ToVectorisedView(),
- })
- buf := make([]byte, 6500)
- bytesRead, err := sock.Read(buf)
- if err != nil {
- t.Fatalf("Unable to read from socketpair: %v", err)
- }
- if got, want := buf[:bytesRead], []byte{0xFA}; !bytes.Equal(got, want) {
- t.Fatalf("Read %v from the socketpair, wanted %v", got, want)
- }
-}
-
-func makeTestInjectableEndpoint(t *testing.T) (*InjectableEndpoint, *os.File, tcpip.Address) {
- dstIP := tcpip.Address(net.ParseIP("1.2.3.4").To4())
- pair, err := syscall.Socketpair(syscall.AF_UNIX,
- syscall.SOCK_SEQPACKET|syscall.SOCK_CLOEXEC|syscall.SOCK_NONBLOCK, 0)
- if err != nil {
- t.Fatal("Failed to create socket pair:", err)
- }
- underlyingEndpoint := fdbased.NewInjectable(pair[1], 6500, stack.CapabilityNone)
- routes := map[tcpip.Address]stack.InjectableLinkEndpoint{dstIP: underlyingEndpoint}
- endpoint := NewInjectableEndpoint(routes)
- return endpoint, os.NewFile(uintptr(pair[0]), "test route end"), dstIP
-}
diff --git a/pkg/tcpip/link/muxed/muxed_state_autogen.go b/pkg/tcpip/link/muxed/muxed_state_autogen.go
new file mode 100755
index 000000000..56330e2a5
--- /dev/null
+++ b/pkg/tcpip/link/muxed/muxed_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package muxed
diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD
deleted file mode 100644
index 14b527bc2..000000000
--- a/pkg/tcpip/link/rawfile/BUILD
+++ /dev/null
@@ -1,20 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "rawfile",
- srcs = [
- "blockingpoll_amd64.s",
- "blockingpoll_arm64.s",
- "blockingpoll_noyield_unsafe.go",
- "blockingpoll_yield_unsafe.go",
- "errors.go",
- "rawfile_unsafe.go",
- ],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/tcpip",
- "@org_golang_x_sys//unix:go_default_library",
- ],
-)
diff --git a/pkg/tcpip/link/rawfile/rawfile_state_autogen.go b/pkg/tcpip/link/rawfile/rawfile_state_autogen.go
new file mode 100755
index 000000000..6b6816bae
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/rawfile_state_autogen.go
@@ -0,0 +1,10 @@
+// automatically generated by stateify.
+
+// +build linux,!amd64,!arm64
+// +build linux,amd64 linux,arm64
+// +build go1.12
+// +build !go1.15
+// +build linux
+// +build linux
+
+package rawfile
diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD
deleted file mode 100644
index 13243ebbb..000000000
--- a/pkg/tcpip/link/sharedmem/BUILD
+++ /dev/null
@@ -1,41 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "sharedmem",
- srcs = [
- "rx.go",
- "sharedmem.go",
- "sharedmem_unsafe.go",
- "tx.go",
- ],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/log",
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/rawfile",
- "//pkg/tcpip/link/sharedmem/queue",
- "//pkg/tcpip/stack",
- ],
-)
-
-go_test(
- name = "sharedmem_test",
- srcs = [
- "sharedmem_test.go",
- ],
- library = ":sharedmem",
- deps = [
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/sharedmem/pipe",
- "//pkg/tcpip/link/sharedmem/queue",
- "//pkg/tcpip/stack",
- ],
-)
diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD
deleted file mode 100644
index 87020ec08..000000000
--- a/pkg/tcpip/link/sharedmem/pipe/BUILD
+++ /dev/null
@@ -1,23 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "pipe",
- srcs = [
- "pipe.go",
- "pipe_unsafe.go",
- "rx.go",
- "tx.go",
- ],
- visibility = ["//visibility:public"],
-)
-
-go_test(
- name = "pipe_test",
- srcs = [
- "pipe_test.go",
- ],
- library = ":pipe",
- deps = ["//pkg/sync"],
-)
diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe.go b/pkg/tcpip/link/sharedmem/pipe/pipe.go
index 74c9f0311..74c9f0311 100644..100755
--- a/pkg/tcpip/link/sharedmem/pipe/pipe.go
+++ b/pkg/tcpip/link/sharedmem/pipe/pipe.go
diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_state_autogen.go b/pkg/tcpip/link/sharedmem/pipe/pipe_state_autogen.go
new file mode 100755
index 000000000..d3b40feb4
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/pipe/pipe_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package pipe
diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go
deleted file mode 100644
index dc239a0d0..000000000
--- a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go
+++ /dev/null
@@ -1,518 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package pipe
-
-import (
- "math/rand"
- "reflect"
- "runtime"
- "testing"
-
- "gvisor.dev/gvisor/pkg/sync"
-)
-
-func TestSimpleReadWrite(t *testing.T) {
- // Check that a simple write can be properly read from the rx side.
- tr := rand.New(rand.NewSource(99))
- rr := rand.New(rand.NewSource(99))
-
- b := make([]byte, 100)
- var tx Tx
- tx.Init(b)
-
- wb := tx.Push(10)
- if wb == nil {
- t.Fatalf("Push failed on empty pipe")
- }
- for i := range wb {
- wb[i] = byte(tr.Intn(256))
- }
- tx.Flush()
-
- var rx Rx
- rx.Init(b)
- rb := rx.Pull()
- if len(rb) != 10 {
- t.Fatalf("Bad buffer size returned: got %v, want %v", len(rb), 10)
- }
-
- for i := range rb {
- if v := byte(rr.Intn(256)); v != rb[i] {
- t.Fatalf("Bad read buffer at index %v: got %v, want %v", i, rb[i], v)
- }
- }
- rx.Flush()
-}
-
-func TestEmptyRead(t *testing.T) {
- // Check that pulling from an empty pipe fails.
- b := make([]byte, 100)
- var tx Tx
- tx.Init(b)
-
- var rx Rx
- rx.Init(b)
- if rb := rx.Pull(); rb != nil {
- t.Fatalf("Pull succeeded on empty pipe")
- }
-}
-
-func TestTooLargeWrite(t *testing.T) {
- // Check that writes that are too large are properly rejected.
- b := make([]byte, 96)
- var tx Tx
- tx.Init(b)
-
- if wb := tx.Push(96); wb != nil {
- t.Fatalf("Write of 96 bytes succeeded on 96-byte pipe")
- }
-
- if wb := tx.Push(88); wb != nil {
- t.Fatalf("Write of 88 bytes succeeded on 96-byte pipe")
- }
-
- if wb := tx.Push(80); wb == nil {
- t.Fatalf("Write of 80 bytes failed on 96-byte pipe")
- }
-}
-
-func TestFullWrite(t *testing.T) {
- // Check that writes fail when the pipe is full.
- b := make([]byte, 100)
- var tx Tx
- tx.Init(b)
-
- if wb := tx.Push(80); wb == nil {
- t.Fatalf("Write of 80 bytes failed on 96-byte pipe")
- }
-
- if wb := tx.Push(1); wb != nil {
- t.Fatalf("Write succeeded on full pipe")
- }
-}
-
-func TestFullAndFlushedWrite(t *testing.T) {
- // Check that writes fail when the pipe is full and has already been
- // flushed.
- b := make([]byte, 100)
- var tx Tx
- tx.Init(b)
-
- if wb := tx.Push(80); wb == nil {
- t.Fatalf("Write of 80 bytes failed on 96-byte pipe")
- }
-
- tx.Flush()
-
- if wb := tx.Push(1); wb != nil {
- t.Fatalf("Write succeeded on full pipe")
- }
-}
-
-func TestTxFlushTwice(t *testing.T) {
- // Checks that a second consecutive tx flush is a no-op.
- b := make([]byte, 100)
- var tx Tx
- tx.Init(b)
-
- if wb := tx.Push(50); wb == nil {
- t.Fatalf("Push failed on empty pipe")
- }
- tx.Flush()
-
- // Make copy of original tx queue, flush it, then check that it didn't
- // change.
- orig := tx
- tx.Flush()
-
- if !reflect.DeepEqual(orig, tx) {
- t.Fatalf("Flush mutated tx pipe: got %v, want %v", tx, orig)
- }
-}
-
-func TestRxFlushTwice(t *testing.T) {
- // Checks that a second consecutive rx flush is a no-op.
- b := make([]byte, 100)
- var tx Tx
- tx.Init(b)
-
- if wb := tx.Push(50); wb == nil {
- t.Fatalf("Push failed on empty pipe")
- }
- tx.Flush()
-
- var rx Rx
- rx.Init(b)
- if rb := rx.Pull(); rb == nil {
- t.Fatalf("Pull failed on non-empty pipe")
- }
- rx.Flush()
-
- // Make copy of original rx queue, flush it, then check that it didn't
- // change.
- orig := rx
- rx.Flush()
-
- if !reflect.DeepEqual(orig, rx) {
- t.Fatalf("Flush mutated rx pipe: got %v, want %v", rx, orig)
- }
-}
-
-func TestWrapInMiddleOfTransaction(t *testing.T) {
- // Check that writes are not flushed when we need to wrap the buffer
- // around.
- b := make([]byte, 100)
- var tx Tx
- tx.Init(b)
-
- if wb := tx.Push(50); wb == nil {
- t.Fatalf("Push failed on empty pipe")
- }
- tx.Flush()
-
- var rx Rx
- rx.Init(b)
- if rb := rx.Pull(); rb == nil {
- t.Fatalf("Pull failed on non-empty pipe")
- }
- rx.Flush()
-
- // At this point the ring buffer is empty, but the write is at offset
- // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment).
- if wb := tx.Push(10); wb == nil {
- t.Fatalf("Push failed on empty pipe")
- }
-
- if wb := tx.Push(50); wb == nil {
- t.Fatalf("Push failed on non-full pipe")
- }
-
- // We haven't flushed yet, so pull must return nil.
- if rb := rx.Pull(); rb != nil {
- t.Fatalf("Pull succeeded on non-flushed pipe")
- }
-
- tx.Flush()
-
- // The two buffers must be available now.
- if rb := rx.Pull(); rb == nil {
- t.Fatalf("Pull failed on non-empty pipe")
- }
-
- if rb := rx.Pull(); rb == nil {
- t.Fatalf("Pull failed on non-empty pipe")
- }
-}
-
-func TestWriteAbort(t *testing.T) {
- // Check that a read fails on a pipe that has had data pushed to it but
- // has aborted the push.
- b := make([]byte, 100)
- var tx Tx
- tx.Init(b)
-
- if wb := tx.Push(10); wb == nil {
- t.Fatalf("Write failed on empty pipe")
- }
-
- var rx Rx
- rx.Init(b)
- if rb := rx.Pull(); rb != nil {
- t.Fatalf("Pull succeeded on empty pipe")
- }
-
- tx.Abort()
- if rb := rx.Pull(); rb != nil {
- t.Fatalf("Pull succeeded on empty pipe")
- }
-}
-
-func TestWrappedWriteAbort(t *testing.T) {
- // Check that writes are properly aborted even if the writes wrap
- // around.
- b := make([]byte, 100)
- var tx Tx
- tx.Init(b)
-
- if wb := tx.Push(50); wb == nil {
- t.Fatalf("Push failed on empty pipe")
- }
- tx.Flush()
-
- var rx Rx
- rx.Init(b)
- if rb := rx.Pull(); rb == nil {
- t.Fatalf("Pull failed on non-empty pipe")
- }
- rx.Flush()
-
- // At this point the ring buffer is empty, but the write is at offset
- // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment).
- if wb := tx.Push(10); wb == nil {
- t.Fatalf("Push failed on empty pipe")
- }
-
- if wb := tx.Push(50); wb == nil {
- t.Fatalf("Push failed on non-full pipe")
- }
-
- // We haven't flushed yet, so pull must return nil.
- if rb := rx.Pull(); rb != nil {
- t.Fatalf("Pull succeeded on non-flushed pipe")
- }
-
- tx.Abort()
-
- // The pushes were aborted, so no data should be readable.
- if rb := rx.Pull(); rb != nil {
- t.Fatalf("Pull succeeded on non-flushed pipe")
- }
-
- // Try the same transactions again, but flush this time.
- if wb := tx.Push(10); wb == nil {
- t.Fatalf("Push failed on empty pipe")
- }
-
- if wb := tx.Push(50); wb == nil {
- t.Fatalf("Push failed on non-full pipe")
- }
-
- tx.Flush()
-
- // The two buffers must be available now.
- if rb := rx.Pull(); rb == nil {
- t.Fatalf("Pull failed on non-empty pipe")
- }
-
- if rb := rx.Pull(); rb == nil {
- t.Fatalf("Pull failed on non-empty pipe")
- }
-}
-
-func TestEmptyReadOnNonFlushedWrite(t *testing.T) {
- // Check that a read fails on a pipe that has had data pushed to it
- // but not yet flushed.
- b := make([]byte, 100)
- var tx Tx
- tx.Init(b)
-
- if wb := tx.Push(10); wb == nil {
- t.Fatalf("Write failed on empty pipe")
- }
-
- var rx Rx
- rx.Init(b)
- if rb := rx.Pull(); rb != nil {
- t.Fatalf("Pull succeeded on empty pipe")
- }
-
- tx.Flush()
- if rb := rx.Pull(); rb == nil {
- t.Fatalf("Pull on failed on non-empty pipe")
- }
-}
-
-func TestPullAfterPullingEntirePipe(t *testing.T) {
- // Check that Pull fails when the pipe is full, but all of it has
- // already been pulled but not yet flushed.
- b := make([]byte, 100)
- var tx Tx
- tx.Init(b)
-
- if wb := tx.Push(50); wb == nil {
- t.Fatalf("Push failed on empty pipe")
- }
- tx.Flush()
-
- var rx Rx
- rx.Init(b)
- if rb := rx.Pull(); rb == nil {
- t.Fatalf("Pull failed on non-empty pipe")
- }
- rx.Flush()
-
- // At this point the ring buffer is empty, but the write is at offset
- // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment). Write 3
- // buffers that will fill the pipe.
- if wb := tx.Push(10); wb == nil {
- t.Fatalf("Push failed on empty pipe")
- }
-
- if wb := tx.Push(20); wb == nil {
- t.Fatalf("Push failed on non-full pipe")
- }
-
- if wb := tx.Push(24); wb == nil {
- t.Fatalf("Push failed on non-full pipe")
- }
-
- tx.Flush()
-
- // The three buffers must be available now.
- if rb := rx.Pull(); rb == nil {
- t.Fatalf("Pull failed on non-empty pipe")
- }
-
- if rb := rx.Pull(); rb == nil {
- t.Fatalf("Pull failed on non-empty pipe")
- }
-
- if rb := rx.Pull(); rb == nil {
- t.Fatalf("Pull failed on non-empty pipe")
- }
-
- // Fourth pull must fail.
- if rb := rx.Pull(); rb != nil {
- t.Fatalf("Pull succeeded on empty pipe")
- }
-}
-
-func TestNoRoomToWrapOnPush(t *testing.T) {
- // Check that Push fails when it tries to allocate room to add a wrap
- // message.
- b := make([]byte, 100)
- var tx Tx
- tx.Init(b)
-
- if wb := tx.Push(50); wb == nil {
- t.Fatalf("Push failed on empty pipe")
- }
- tx.Flush()
-
- var rx Rx
- rx.Init(b)
- if rb := rx.Pull(); rb == nil {
- t.Fatalf("Pull failed on non-empty pipe")
- }
- rx.Flush()
-
- // At this point the ring buffer is empty, but the write is at offset
- // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment). Write 20,
- // which won't fit (64+20+8+padding = 96, which wouldn't leave room for
- // the padding), so it wraps around.
- if wb := tx.Push(20); wb == nil {
- t.Fatalf("Push failed on empty pipe")
- }
-
- tx.Flush()
-
- // Buffer offset is at 28. Try to write 70, which would require a wrap
- // slot which cannot be created now.
- if wb := tx.Push(70); wb != nil {
- t.Fatalf("Push succeeded on pipe with no room for wrap message")
- }
-}
-
-func TestRxImplicitFlushOfWrapMessage(t *testing.T) {
- // Check if the first read is that of a wrapping message, that it gets
- // immediately flushed.
- b := make([]byte, 100)
- var tx Tx
- tx.Init(b)
-
- if wb := tx.Push(50); wb == nil {
- t.Fatalf("Push failed on empty pipe")
- }
- tx.Flush()
-
- // This will cause a wrapping message to written.
- if wb := tx.Push(60); wb != nil {
- t.Fatalf("Push succeeded when there is no room in pipe")
- }
-
- var rx Rx
- rx.Init(b)
-
- // Read the first message.
- if rb := rx.Pull(); rb == nil {
- t.Fatalf("Pull failed on non-empty pipe")
- }
- rx.Flush()
-
- // This should fail because of the wrapping message is taking up space.
- if wb := tx.Push(60); wb != nil {
- t.Fatalf("Push succeeded when there is no room in pipe")
- }
-
- // Try to read the next one. This should consume the wrapping message.
- rx.Pull()
-
- // This must now succeed.
- if wb := tx.Push(60); wb == nil {
- t.Fatalf("Push failed on empty pipe")
- }
-}
-
-func TestConcurrentReaderWriter(t *testing.T) {
- // Push a million buffers of random sizes and random contents. Check
- // that buffers read match what was written.
- tr := rand.New(rand.NewSource(99))
- rr := rand.New(rand.NewSource(99))
-
- b := make([]byte, 100)
- var tx Tx
- tx.Init(b)
-
- var rx Rx
- rx.Init(b)
-
- const count = 1000000
- var wg sync.WaitGroup
- wg.Add(1)
- go func() {
- defer wg.Done()
- runtime.Gosched()
- for i := 0; i < count; i++ {
- n := 1 + tr.Intn(80)
- wb := tx.Push(uint64(n))
- for wb == nil {
- wb = tx.Push(uint64(n))
- }
-
- for j := range wb {
- wb[j] = byte(tr.Intn(256))
- }
-
- tx.Flush()
- }
- }()
-
- wg.Add(1)
- go func() {
- defer wg.Done()
- runtime.Gosched()
- for i := 0; i < count; i++ {
- n := 1 + rr.Intn(80)
- rb := rx.Pull()
- for rb == nil {
- rb = rx.Pull()
- }
-
- if n != len(rb) {
- t.Fatalf("Bad %v-th buffer length: got %v, want %v", i, len(rb), n)
- }
-
- for j := range rb {
- if v := byte(rr.Intn(256)); v != rb[j] {
- t.Fatalf("Bad %v-th read buffer at index %v: got %v, want %v", i, j, rb[j], v)
- }
- }
-
- rx.Flush()
- }
- }()
-
- wg.Wait()
-}
diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_unsafe.go b/pkg/tcpip/link/sharedmem/pipe/pipe_unsafe.go
index 62d17029e..62d17029e 100644..100755
--- a/pkg/tcpip/link/sharedmem/pipe/pipe_unsafe.go
+++ b/pkg/tcpip/link/sharedmem/pipe/pipe_unsafe.go
diff --git a/pkg/tcpip/link/sharedmem/pipe/rx.go b/pkg/tcpip/link/sharedmem/pipe/rx.go
index f22e533ac..f22e533ac 100644..100755
--- a/pkg/tcpip/link/sharedmem/pipe/rx.go
+++ b/pkg/tcpip/link/sharedmem/pipe/rx.go
diff --git a/pkg/tcpip/link/sharedmem/pipe/tx.go b/pkg/tcpip/link/sharedmem/pipe/tx.go
index 9841eb231..9841eb231 100644..100755
--- a/pkg/tcpip/link/sharedmem/pipe/tx.go
+++ b/pkg/tcpip/link/sharedmem/pipe/tx.go
diff --git a/pkg/tcpip/link/sharedmem/queue/BUILD b/pkg/tcpip/link/sharedmem/queue/BUILD
deleted file mode 100644
index 3ba06af73..000000000
--- a/pkg/tcpip/link/sharedmem/queue/BUILD
+++ /dev/null
@@ -1,27 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "queue",
- srcs = [
- "rx.go",
- "tx.go",
- ],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/log",
- "//pkg/tcpip/link/sharedmem/pipe",
- ],
-)
-
-go_test(
- name = "queue_test",
- srcs = [
- "queue_test.go",
- ],
- library = ":queue",
- deps = [
- "//pkg/tcpip/link/sharedmem/pipe",
- ],
-)
diff --git a/pkg/tcpip/link/sharedmem/queue/queue_state_autogen.go b/pkg/tcpip/link/sharedmem/queue/queue_state_autogen.go
new file mode 100755
index 000000000..563d4fbb4
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/queue/queue_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package queue
diff --git a/pkg/tcpip/link/sharedmem/queue/queue_test.go b/pkg/tcpip/link/sharedmem/queue/queue_test.go
deleted file mode 100644
index 9a0aad5d7..000000000
--- a/pkg/tcpip/link/sharedmem/queue/queue_test.go
+++ /dev/null
@@ -1,517 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package queue
-
-import (
- "encoding/binary"
- "reflect"
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe"
-)
-
-func TestBasicTxQueue(t *testing.T) {
- // Tests that a basic transmit on a queue works, and that completion
- // gets properly reported as well.
- pb1 := make([]byte, 100)
- pb2 := make([]byte, 100)
-
- var rxp pipe.Rx
- rxp.Init(pb1)
-
- var txp pipe.Tx
- txp.Init(pb2)
-
- var q Tx
- q.Init(pb1, pb2)
-
- // Enqueue two buffers.
- b := []TxBuffer{
- {nil, 100, 60},
- {nil, 200, 40},
- }
-
- b[0].Next = &b[1]
-
- const usedID = 1002
- const usedTotalSize = 100
- if !q.Enqueue(usedID, usedTotalSize, 2, &b[0]) {
- t.Fatalf("Enqueue failed on empty queue")
- }
-
- // Check the contents of the pipe.
- d := rxp.Pull()
- if d == nil {
- t.Fatalf("Tx pipe is empty after Enqueue")
- }
-
- want := []byte{
- 234, 3, 0, 0, 0, 0, 0, 0, // id
- 100, 0, 0, 0, // total size
- 0, 0, 0, 0, // reserved
- 100, 0, 0, 0, 0, 0, 0, 0, // offset 1
- 60, 0, 0, 0, // size 1
- 200, 0, 0, 0, 0, 0, 0, 0, // offset 2
- 40, 0, 0, 0, // size 2
- }
-
- if !reflect.DeepEqual(want, d) {
- t.Fatalf("Bad posted packet: got %v, want %v", d, want)
- }
-
- rxp.Flush()
-
- // Check that there are no completions yet.
- if _, ok := q.CompletedPacket(); ok {
- t.Fatalf("Packet reported as completed too soon")
- }
-
- // Post a completion.
- d = txp.Push(8)
- if d == nil {
- t.Fatalf("Unable to push to rx pipe")
- }
- binary.LittleEndian.PutUint64(d, usedID)
- txp.Flush()
-
- // Check that completion is properly reported.
- id, ok := q.CompletedPacket()
- if !ok {
- t.Fatalf("Completion not reported")
- }
-
- if id != usedID {
- t.Fatalf("Bad completion id: got %v, want %v", id, usedID)
- }
-}
-
-func TestBasicRxQueue(t *testing.T) {
- // Tests that a basic receive on a queue works.
- pb1 := make([]byte, 100)
- pb2 := make([]byte, 100)
-
- var rxp pipe.Rx
- rxp.Init(pb1)
-
- var txp pipe.Tx
- txp.Init(pb2)
-
- var q Rx
- q.Init(pb1, pb2, nil)
-
- // Post two buffers.
- b := []RxBuffer{
- {100, 60, 1077, 0},
- {200, 40, 2123, 0},
- }
-
- if !q.PostBuffers(b) {
- t.Fatalf("PostBuffers failed on empty queue")
- }
-
- // Check the contents of the pipe.
- want := [][]byte{
- {
- 100, 0, 0, 0, 0, 0, 0, 0, // Offset1
- 60, 0, 0, 0, // Size1
- 0, 0, 0, 0, // Remaining in group 1
- 0, 0, 0, 0, 0, 0, 0, 0, // User data 1
- 53, 4, 0, 0, 0, 0, 0, 0, // ID 1
- },
- {
- 200, 0, 0, 0, 0, 0, 0, 0, // Offset2
- 40, 0, 0, 0, // Size2
- 0, 0, 0, 0, // Remaining in group 2
- 0, 0, 0, 0, 0, 0, 0, 0, // User data 2
- 75, 8, 0, 0, 0, 0, 0, 0, // ID 2
- },
- }
-
- for i := range b {
- d := rxp.Pull()
- if d == nil {
- t.Fatalf("Tx pipe is empty after PostBuffers")
- }
-
- if !reflect.DeepEqual(want[i], d) {
- t.Fatalf("Bad posted packet: got %v, want %v", d, want[i])
- }
-
- rxp.Flush()
- }
-
- // Check that there are no completions.
- if _, n := q.Dequeue(nil); n != 0 {
- t.Fatalf("Packet reported as received too soon")
- }
-
- // Post a completion.
- d := txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer)
- if d == nil {
- t.Fatalf("Unable to push to rx pipe")
- }
-
- copy(d, []byte{
- 100, 0, 0, 0, // packet size
- 0, 0, 0, 0, // reserved
-
- 100, 0, 0, 0, 0, 0, 0, 0, // offset 1
- 60, 0, 0, 0, // size 1
- 0, 0, 0, 0, 0, 0, 0, 0, // user data 1
- 53, 4, 0, 0, 0, 0, 0, 0, // ID 1
-
- 200, 0, 0, 0, 0, 0, 0, 0, // offset 2
- 40, 0, 0, 0, // size 2
- 0, 0, 0, 0, 0, 0, 0, 0, // user data 2
- 75, 8, 0, 0, 0, 0, 0, 0, // ID 2
- })
-
- txp.Flush()
-
- // Check that completion is properly reported.
- bufs, n := q.Dequeue(nil)
- if n != 100 {
- t.Fatalf("Bad packet size: got %v, want %v", n, 100)
- }
-
- if !reflect.DeepEqual(bufs, b) {
- t.Fatalf("Bad returned buffers: got %v, want %v", bufs, b)
- }
-}
-
-func TestBadTxCompletion(t *testing.T) {
- // Check that tx completions with bad sizes are properly ignored.
- pb1 := make([]byte, 100)
- pb2 := make([]byte, 100)
-
- var rxp pipe.Rx
- rxp.Init(pb1)
-
- var txp pipe.Tx
- txp.Init(pb2)
-
- var q Tx
- q.Init(pb1, pb2)
-
- // Post a completion that is too short, and check that it is ignored.
- if d := txp.Push(7); d == nil {
- t.Fatalf("Unable to push to rx pipe")
- }
- txp.Flush()
-
- if _, ok := q.CompletedPacket(); ok {
- t.Fatalf("Bad completion not ignored")
- }
-
- // Post a completion that is too long, and check that it is ignored.
- if d := txp.Push(10); d == nil {
- t.Fatalf("Unable to push to rx pipe")
- }
- txp.Flush()
-
- if _, ok := q.CompletedPacket(); ok {
- t.Fatalf("Bad completion not ignored")
- }
-}
-
-func TestBadRxCompletion(t *testing.T) {
- // Check that bad rx completions are properly ignored.
- pb1 := make([]byte, 100)
- pb2 := make([]byte, 100)
-
- var rxp pipe.Rx
- rxp.Init(pb1)
-
- var txp pipe.Tx
- txp.Init(pb2)
-
- var q Rx
- q.Init(pb1, pb2, nil)
-
- // Post a completion that is too short, and check that it is ignored.
- if d := txp.Push(7); d == nil {
- t.Fatalf("Unable to push to rx pipe")
- }
- txp.Flush()
-
- if b, _ := q.Dequeue(nil); b != nil {
- t.Fatalf("Bad completion not ignored")
- }
-
- // Post a completion whose buffer sizes add up to less than the total
- // size.
- d := txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer)
- if d == nil {
- t.Fatalf("Unable to push to rx pipe")
- }
-
- copy(d, []byte{
- 100, 0, 0, 0, // packet size
- 0, 0, 0, 0, // reserved
-
- 100, 0, 0, 0, 0, 0, 0, 0, // offset 1
- 10, 0, 0, 0, // size 1
- 0, 0, 0, 0, 0, 0, 0, 0, // user data 1
- 53, 4, 0, 0, 0, 0, 0, 0, // ID 1
-
- 200, 0, 0, 0, 0, 0, 0, 0, // offset 2
- 10, 0, 0, 0, // size 2
- 0, 0, 0, 0, 0, 0, 0, 0, // user data 2
- 75, 8, 0, 0, 0, 0, 0, 0, // ID 2
- })
-
- txp.Flush()
- if b, _ := q.Dequeue(nil); b != nil {
- t.Fatalf("Bad completion not ignored")
- }
-
- // Post a completion whose buffer sizes will cause a 32-bit overflow,
- // but adds up to the right number.
- d = txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer)
- if d == nil {
- t.Fatalf("Unable to push to rx pipe")
- }
-
- copy(d, []byte{
- 100, 0, 0, 0, // packet size
- 0, 0, 0, 0, // reserved
-
- 100, 0, 0, 0, 0, 0, 0, 0, // offset 1
- 255, 255, 255, 255, // size 1
- 0, 0, 0, 0, 0, 0, 0, 0, // user data 1
- 53, 4, 0, 0, 0, 0, 0, 0, // ID 1
-
- 200, 0, 0, 0, 0, 0, 0, 0, // offset 2
- 101, 0, 0, 0, // size 2
- 0, 0, 0, 0, 0, 0, 0, 0, // user data 2
- 75, 8, 0, 0, 0, 0, 0, 0, // ID 2
- })
-
- txp.Flush()
- if b, _ := q.Dequeue(nil); b != nil {
- t.Fatalf("Bad completion not ignored")
- }
-}
-
-func TestFillTxPipe(t *testing.T) {
- // Check that transmitting a new buffer when the buffer pipe is full
- // fails gracefully.
- pb1 := make([]byte, 104)
- pb2 := make([]byte, 104)
-
- var rxp pipe.Rx
- rxp.Init(pb1)
-
- var txp pipe.Tx
- txp.Init(pb2)
-
- var q Tx
- q.Init(pb1, pb2)
-
- // Transmit twice, which should fill the tx pipe.
- b := []TxBuffer{
- {nil, 100, 60},
- {nil, 200, 40},
- }
-
- b[0].Next = &b[1]
-
- const usedID = 1002
- const usedTotalSize = 100
- for i := uint64(0); i < 2; i++ {
- if !q.Enqueue(usedID+i, usedTotalSize, 2, &b[0]) {
- t.Fatalf("Failed to transmit buffer")
- }
- }
-
- // Transmit another packet now that the tx pipe is full.
- if q.Enqueue(usedID+2, usedTotalSize, 2, &b[0]) {
- t.Fatalf("Enqueue succeeded when tx pipe is full")
- }
-}
-
-func TestFillRxPipe(t *testing.T) {
- // Check that posting a new buffer when the buffer pipe is full fails
- // gracefully.
- pb1 := make([]byte, 100)
- pb2 := make([]byte, 100)
-
- var rxp pipe.Rx
- rxp.Init(pb1)
-
- var txp pipe.Tx
- txp.Init(pb2)
-
- var q Rx
- q.Init(pb1, pb2, nil)
-
- // Post a buffer twice, it should fill the tx pipe.
- b := []RxBuffer{
- {100, 60, 1077, 0},
- }
-
- for i := 0; i < 2; i++ {
- if !q.PostBuffers(b) {
- t.Fatalf("PostBuffers failed on non-full queue")
- }
- }
-
- // Post another buffer now that the tx pipe is full.
- if q.PostBuffers(b) {
- t.Fatalf("PostBuffers succeeded on full queue")
- }
-}
-
-func TestLotsOfTransmissions(t *testing.T) {
- // Make sure pipes are being properly flushed when transmitting packets.
- pb1 := make([]byte, 100)
- pb2 := make([]byte, 100)
-
- var rxp pipe.Rx
- rxp.Init(pb1)
-
- var txp pipe.Tx
- txp.Init(pb2)
-
- var q Tx
- q.Init(pb1, pb2)
-
- // Prepare packet with two buffers.
- b := []TxBuffer{
- {nil, 100, 60},
- {nil, 200, 40},
- }
-
- b[0].Next = &b[1]
-
- const usedID = 1002
- const usedTotalSize = 100
-
- // Post 100000 packets and completions.
- for i := 100000; i > 0; i-- {
- if !q.Enqueue(usedID, usedTotalSize, 2, &b[0]) {
- t.Fatalf("Enqueue failed on non-full queue")
- }
-
- if d := rxp.Pull(); d == nil {
- t.Fatalf("Tx pipe is empty after Enqueue")
- }
- rxp.Flush()
-
- d := txp.Push(8)
- if d == nil {
- t.Fatalf("Unable to write to rx pipe")
- }
- binary.LittleEndian.PutUint64(d, usedID)
- txp.Flush()
- if _, ok := q.CompletedPacket(); !ok {
- t.Fatalf("Completion not returned")
- }
- }
-}
-
-func TestLotsOfReceptions(t *testing.T) {
- // Make sure pipes are being properly flushed when receiving packets.
- pb1 := make([]byte, 100)
- pb2 := make([]byte, 100)
-
- var rxp pipe.Rx
- rxp.Init(pb1)
-
- var txp pipe.Tx
- txp.Init(pb2)
-
- var q Rx
- q.Init(pb1, pb2, nil)
-
- // Prepare for posting two buffers.
- b := []RxBuffer{
- {100, 60, 1077, 0},
- {200, 40, 2123, 0},
- }
-
- // Post 100000 buffers and completions.
- for i := 100000; i > 0; i-- {
- if !q.PostBuffers(b) {
- t.Fatalf("PostBuffers failed on non-full queue")
- }
-
- if d := rxp.Pull(); d == nil {
- t.Fatalf("Tx pipe is empty after PostBuffers")
- }
- rxp.Flush()
-
- if d := rxp.Pull(); d == nil {
- t.Fatalf("Tx pipe is empty after PostBuffers")
- }
- rxp.Flush()
-
- d := txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer)
- if d == nil {
- t.Fatalf("Unable to push to rx pipe")
- }
-
- copy(d, []byte{
- 100, 0, 0, 0, // packet size
- 0, 0, 0, 0, // reserved
-
- 100, 0, 0, 0, 0, 0, 0, 0, // offset 1
- 60, 0, 0, 0, // size 1
- 0, 0, 0, 0, 0, 0, 0, 0, // user data 1
- 53, 4, 0, 0, 0, 0, 0, 0, // ID 1
-
- 200, 0, 0, 0, 0, 0, 0, 0, // offset 2
- 40, 0, 0, 0, // size 2
- 0, 0, 0, 0, 0, 0, 0, 0, // user data 2
- 75, 8, 0, 0, 0, 0, 0, 0, // ID 2
- })
-
- txp.Flush()
-
- if _, n := q.Dequeue(nil); n == 0 {
- t.Fatalf("Dequeue failed when there is a completion")
- }
- }
-}
-
-func TestRxEnableNotification(t *testing.T) {
- // Check that enabling nofifications results in properly updated state.
- pb1 := make([]byte, 100)
- pb2 := make([]byte, 100)
-
- var state uint32
- var q Rx
- q.Init(pb1, pb2, &state)
-
- q.EnableNotification()
- if state != eventFDEnabled {
- t.Fatalf("Bad value in shared state: got %v, want %v", state, eventFDEnabled)
- }
-}
-
-func TestRxDisableNotification(t *testing.T) {
- // Check that disabling nofifications results in properly updated state.
- pb1 := make([]byte, 100)
- pb2 := make([]byte, 100)
-
- var state uint32
- var q Rx
- q.Init(pb1, pb2, &state)
-
- q.DisableNotification()
- if state != eventFDDisabled {
- t.Fatalf("Bad value in shared state: got %v, want %v", state, eventFDDisabled)
- }
-}
diff --git a/pkg/tcpip/link/sharedmem/queue/rx.go b/pkg/tcpip/link/sharedmem/queue/rx.go
index 696e6c9e5..696e6c9e5 100644..100755
--- a/pkg/tcpip/link/sharedmem/queue/rx.go
+++ b/pkg/tcpip/link/sharedmem/queue/rx.go
diff --git a/pkg/tcpip/link/sharedmem/queue/tx.go b/pkg/tcpip/link/sharedmem/queue/tx.go
index beffe807b..beffe807b 100644..100755
--- a/pkg/tcpip/link/sharedmem/queue/tx.go
+++ b/pkg/tcpip/link/sharedmem/queue/tx.go
diff --git a/pkg/tcpip/link/sharedmem/rx.go b/pkg/tcpip/link/sharedmem/rx.go
index eec11e4cb..eec11e4cb 100644..100755
--- a/pkg/tcpip/link/sharedmem/rx.go
+++ b/pkg/tcpip/link/sharedmem/rx.go
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index 655e537c4..655e537c4 100644..100755
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_state_autogen.go b/pkg/tcpip/link/sharedmem/sharedmem_state_autogen.go
new file mode 100755
index 000000000..bc12017b2
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/sharedmem_state_autogen.go
@@ -0,0 +1,6 @@
+// automatically generated by stateify.
+
+// +build linux
+// +build linux
+
+package sharedmem
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
deleted file mode 100644
index 5c729a439..000000000
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ /dev/null
@@ -1,812 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// +build linux
-
-package sharedmem
-
-import (
- "bytes"
- "io/ioutil"
- "math/rand"
- "os"
- "strings"
- "syscall"
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe"
- "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
-)
-
-const (
- localLinkAddr = "\xde\xad\xbe\xef\x56\x78"
- remoteLinkAddr = "\xde\xad\xbe\xef\x12\x34"
-
- queueDataSize = 1024 * 1024
- queuePipeSize = 4096
-)
-
-type queueBuffers struct {
- data []byte
- rx pipe.Tx
- tx pipe.Rx
-}
-
-func initQueue(t *testing.T, q *queueBuffers, c *QueueConfig) {
- // Prepare tx pipe.
- b, err := getBuffer(c.TxPipeFD)
- if err != nil {
- t.Fatalf("getBuffer failed: %v", err)
- }
- q.tx.Init(b)
-
- // Prepare rx pipe.
- b, err = getBuffer(c.RxPipeFD)
- if err != nil {
- t.Fatalf("getBuffer failed: %v", err)
- }
- q.rx.Init(b)
-
- // Get data slice.
- q.data, err = getBuffer(c.DataFD)
- if err != nil {
- t.Fatalf("getBuffer failed: %v", err)
- }
-}
-
-func (q *queueBuffers) cleanup() {
- syscall.Munmap(q.tx.Bytes())
- syscall.Munmap(q.rx.Bytes())
- syscall.Munmap(q.data)
-}
-
-type packetInfo struct {
- addr tcpip.LinkAddress
- proto tcpip.NetworkProtocolNumber
- vv buffer.VectorisedView
- linkHeader buffer.View
-}
-
-type testContext struct {
- t *testing.T
- ep *endpoint
- txCfg QueueConfig
- rxCfg QueueConfig
- txq queueBuffers
- rxq queueBuffers
-
- packetCh chan struct{}
- mu sync.Mutex
- packets []packetInfo
-}
-
-func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress) *testContext {
- var err error
- c := &testContext{
- t: t,
- packetCh: make(chan struct{}, 1000000),
- }
- c.txCfg = createQueueFDs(t, queueSizes{
- dataSize: queueDataSize,
- txPipeSize: queuePipeSize,
- rxPipeSize: queuePipeSize,
- sharedDataSize: 4096,
- })
-
- c.rxCfg = createQueueFDs(t, queueSizes{
- dataSize: queueDataSize,
- txPipeSize: queuePipeSize,
- rxPipeSize: queuePipeSize,
- sharedDataSize: 4096,
- })
-
- initQueue(t, &c.txq, &c.txCfg)
- initQueue(t, &c.rxq, &c.rxCfg)
-
- ep, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg)
- if err != nil {
- t.Fatalf("New failed: %v", err)
- }
-
- c.ep = ep.(*endpoint)
- c.ep.Attach(c)
-
- return c
-}
-
-func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
- c.mu.Lock()
- c.packets = append(c.packets, packetInfo{
- addr: remoteLinkAddr,
- proto: proto,
- vv: pkt.Data.Clone(nil),
- })
- c.mu.Unlock()
-
- c.packetCh <- struct{}{}
-}
-
-func (c *testContext) cleanup() {
- c.ep.Close()
- closeFDs(&c.txCfg)
- closeFDs(&c.rxCfg)
- c.txq.cleanup()
- c.rxq.cleanup()
-}
-
-func (c *testContext) waitForPackets(n int, to <-chan time.Time, errorStr string) {
- for i := 0; i < n; i++ {
- select {
- case <-c.packetCh:
- case <-to:
- c.t.Fatalf(errorStr)
- }
- }
-}
-
-func (c *testContext) pushRxCompletion(size uint32, bs []queue.RxBuffer) {
- b := c.rxq.rx.Push(queue.RxCompletionSize(len(bs)))
- queue.EncodeRxCompletion(b, size, 0)
- for i := range bs {
- queue.EncodeRxCompletionBuffer(b, i, queue.RxBuffer{
- Offset: bs[i].Offset,
- Size: bs[i].Size,
- ID: bs[i].ID,
- })
- }
-}
-
-func randomFill(b []byte) {
- for i := range b {
- b[i] = byte(rand.Intn(256))
- }
-}
-
-func shuffle(b []int) {
- for i := len(b) - 1; i >= 0; i-- {
- j := rand.Intn(i + 1)
- b[i], b[j] = b[j], b[i]
- }
-}
-
-func createFile(t *testing.T, size int64, initQueue bool) int {
- tmpDir := os.Getenv("TEST_TMPDIR")
- if tmpDir == "" {
- tmpDir = os.Getenv("TMPDIR")
- }
- f, err := ioutil.TempFile(tmpDir, "sharedmem_test")
- if err != nil {
- t.Fatalf("TempFile failed: %v", err)
- }
- defer f.Close()
- syscall.Unlink(f.Name())
-
- if initQueue {
- // Write the "slot-free" flag in the initial queue.
- _, err := f.WriteAt([]byte{0, 0, 0, 0, 0, 0, 0, 0x80}, 0)
- if err != nil {
- t.Fatalf("WriteAt failed: %v", err)
- }
- }
-
- fd, err := syscall.Dup(int(f.Fd()))
- if err != nil {
- t.Fatalf("Dup failed: %v", err)
- }
-
- if err := syscall.Ftruncate(fd, size); err != nil {
- syscall.Close(fd)
- t.Fatalf("Ftruncate failed: %v", err)
- }
-
- return fd
-}
-
-func closeFDs(c *QueueConfig) {
- syscall.Close(c.DataFD)
- syscall.Close(c.EventFD)
- syscall.Close(c.TxPipeFD)
- syscall.Close(c.RxPipeFD)
- syscall.Close(c.SharedDataFD)
-}
-
-type queueSizes struct {
- dataSize int64
- txPipeSize int64
- rxPipeSize int64
- sharedDataSize int64
-}
-
-func createQueueFDs(t *testing.T, s queueSizes) QueueConfig {
- fd, _, err := syscall.RawSyscall(syscall.SYS_EVENTFD2, 0, 0, 0)
- if err != 0 {
- t.Fatalf("eventfd failed: %v", error(err))
- }
-
- return QueueConfig{
- EventFD: int(fd),
- DataFD: createFile(t, s.dataSize, false),
- TxPipeFD: createFile(t, s.txPipeSize, true),
- RxPipeFD: createFile(t, s.rxPipeSize, true),
- SharedDataFD: createFile(t, s.sharedDataSize, false),
- }
-}
-
-// TestSimpleSend sends 1000 packets with random header and payload sizes,
-// then checks that the right payload is received on the shared memory queues.
-func TestSimpleSend(t *testing.T) {
- c := newTestContext(t, 20000, 1500, localLinkAddr)
- defer c.cleanup()
-
- // Prepare route.
- r := stack.Route{
- RemoteLinkAddress: remoteLinkAddr,
- }
-
- for iters := 1000; iters > 0; iters-- {
- func() {
- // Prepare and send packet.
- n := rand.Intn(10000)
- hdr := buffer.NewPrependable(n + int(c.ep.MaxHeaderLength()))
- hdrBuf := hdr.Prepend(n)
- randomFill(hdrBuf)
-
- n = rand.Intn(10000)
- buf := buffer.NewView(n)
- randomFill(buf)
-
- proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000))
- if err := c.ep.WritePacket(&r, nil /* gso */, proto, tcpip.PacketBuffer{
- Header: hdr,
- Data: buf.ToVectorisedView(),
- }); err != nil {
- t.Fatalf("WritePacket failed: %v", err)
- }
-
- // Receive packet.
- desc := c.txq.tx.Pull()
- pi := queue.DecodeTxPacketHeader(desc)
- if pi.Reserved != 0 {
- t.Fatalf("Reserved value is non-zero: 0x%x", pi.Reserved)
- }
- contents := make([]byte, 0, pi.Size)
- for i := 0; i < pi.BufferCount; i++ {
- bi := queue.DecodeTxBufferHeader(desc, i)
- contents = append(contents, c.txq.data[bi.Offset:][:bi.Size]...)
- }
- c.txq.tx.Flush()
-
- defer func() {
- // Tell the endpoint about the completion of the write.
- b := c.txq.rx.Push(8)
- queue.EncodeTxCompletion(b, pi.ID)
- c.txq.rx.Flush()
- }()
-
- // Check the ethernet header.
- ethTemplate := make(header.Ethernet, header.EthernetMinimumSize)
- ethTemplate.Encode(&header.EthernetFields{
- SrcAddr: localLinkAddr,
- DstAddr: remoteLinkAddr,
- Type: proto,
- })
- if got := contents[:header.EthernetMinimumSize]; !bytes.Equal(got, []byte(ethTemplate)) {
- t.Fatalf("Bad ethernet header in packet: got %x, want %x", got, ethTemplate)
- }
-
- // Compare contents skipping the ethernet header added by the
- // endpoint.
- merged := append(hdrBuf, buf...)
- if uint32(len(contents)) < pi.Size {
- t.Fatalf("Sum of buffers is less than packet size: %v < %v", len(contents), pi.Size)
- }
- contents = contents[:pi.Size][header.EthernetMinimumSize:]
-
- if !bytes.Equal(contents, merged) {
- t.Fatalf("Buffers are different: got %x (%v bytes), want %x (%v bytes)", contents, len(contents), merged, len(merged))
- }
- }()
- }
-}
-
-// TestPreserveSrcAddressInSend calls WritePacket once with LocalLinkAddress
-// set in Route (using much of the same code as TestSimpleSend), then checks
-// that the encoded ethernet header received includes the correct SrcAddr.
-func TestPreserveSrcAddressInSend(t *testing.T) {
- c := newTestContext(t, 20000, 1500, localLinkAddr)
- defer c.cleanup()
-
- newLocalLinkAddress := tcpip.LinkAddress(strings.Repeat("0xFE", 6))
- // Set both remote and local link address in route.
- r := stack.Route{
- RemoteLinkAddress: remoteLinkAddr,
- LocalLinkAddress: newLocalLinkAddress,
- }
-
- // WritePacket panics given a prependable with anything less than
- // the minimum size of the ethernet header.
- hdr := buffer.NewPrependable(header.EthernetMinimumSize)
-
- proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000))
- if err := c.ep.WritePacket(&r, nil /* gso */, proto, tcpip.PacketBuffer{
- Header: hdr,
- }); err != nil {
- t.Fatalf("WritePacket failed: %v", err)
- }
-
- // Receive packet.
- desc := c.txq.tx.Pull()
- pi := queue.DecodeTxPacketHeader(desc)
- if pi.Reserved != 0 {
- t.Fatalf("Reserved value is non-zero: 0x%x", pi.Reserved)
- }
- contents := make([]byte, 0, pi.Size)
- for i := 0; i < pi.BufferCount; i++ {
- bi := queue.DecodeTxBufferHeader(desc, i)
- contents = append(contents, c.txq.data[bi.Offset:][:bi.Size]...)
- }
- c.txq.tx.Flush()
-
- defer func() {
- // Tell the endpoint about the completion of the write.
- b := c.txq.rx.Push(8)
- queue.EncodeTxCompletion(b, pi.ID)
- c.txq.rx.Flush()
- }()
-
- // Check that the ethernet header contains the expected SrcAddr.
- ethTemplate := make(header.Ethernet, header.EthernetMinimumSize)
- ethTemplate.Encode(&header.EthernetFields{
- SrcAddr: newLocalLinkAddress,
- DstAddr: remoteLinkAddr,
- Type: proto,
- })
- if got := contents[:header.EthernetMinimumSize]; !bytes.Equal(got, []byte(ethTemplate)) {
- t.Fatalf("Bad ethernet header in packet: got %x, want %x", got, ethTemplate)
- }
-}
-
-// TestFillTxQueue sends packets until the queue is full.
-func TestFillTxQueue(t *testing.T) {
- c := newTestContext(t, 20000, 1500, localLinkAddr)
- defer c.cleanup()
-
- // Prepare to send a packet.
- r := stack.Route{
- RemoteLinkAddress: remoteLinkAddr,
- }
-
- buf := buffer.NewView(100)
-
- // Each packet is uses no more than 40 bytes, so write that many packets
- // until the tx queue if full.
- ids := make(map[uint64]struct{})
- for i := queuePipeSize / 40; i > 0; i-- {
- hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
-
- if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
- Header: hdr,
- Data: buf.ToVectorisedView(),
- }); err != nil {
- t.Fatalf("WritePacket failed unexpectedly: %v", err)
- }
-
- // Check that they have different IDs.
- desc := c.txq.tx.Pull()
- pi := queue.DecodeTxPacketHeader(desc)
- if _, ok := ids[pi.ID]; ok {
- t.Fatalf("ID (%v) reused", pi.ID)
- }
- ids[pi.ID] = struct{}{}
- }
-
- // Next attempt to write must fail.
- hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
- Header: hdr,
- Data: buf.ToVectorisedView(),
- }); err != want {
- t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want)
- }
-}
-
-// TestFillTxQueueAfterBadCompletion sends a bad completion, then sends packets
-// until the queue is full.
-func TestFillTxQueueAfterBadCompletion(t *testing.T) {
- c := newTestContext(t, 20000, 1500, localLinkAddr)
- defer c.cleanup()
-
- // Send a bad completion.
- queue.EncodeTxCompletion(c.txq.rx.Push(8), 1)
- c.txq.rx.Flush()
-
- // Prepare to send a packet.
- r := stack.Route{
- RemoteLinkAddress: remoteLinkAddr,
- }
-
- buf := buffer.NewView(100)
-
- // Send two packets so that the id slice has at least two slots.
- for i := 2; i > 0; i-- {
- hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
- Header: hdr,
- Data: buf.ToVectorisedView(),
- }); err != nil {
- t.Fatalf("WritePacket failed unexpectedly: %v", err)
- }
- }
-
- // Complete the two writes twice.
- for i := 2; i > 0; i-- {
- pi := queue.DecodeTxPacketHeader(c.txq.tx.Pull())
-
- queue.EncodeTxCompletion(c.txq.rx.Push(8), pi.ID)
- queue.EncodeTxCompletion(c.txq.rx.Push(8), pi.ID)
- c.txq.rx.Flush()
- }
- c.txq.tx.Flush()
-
- // Each packet is uses no more than 40 bytes, so write that many packets
- // until the tx queue if full.
- ids := make(map[uint64]struct{})
- for i := queuePipeSize / 40; i > 0; i-- {
- hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
- Header: hdr,
- Data: buf.ToVectorisedView(),
- }); err != nil {
- t.Fatalf("WritePacket failed unexpectedly: %v", err)
- }
-
- // Check that they have different IDs.
- desc := c.txq.tx.Pull()
- pi := queue.DecodeTxPacketHeader(desc)
- if _, ok := ids[pi.ID]; ok {
- t.Fatalf("ID (%v) reused", pi.ID)
- }
- ids[pi.ID] = struct{}{}
- }
-
- // Next attempt to write must fail.
- hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
- Header: hdr,
- Data: buf.ToVectorisedView(),
- }); err != want {
- t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want)
- }
-}
-
-// TestFillTxMemory sends packets until the we run out of shared memory.
-func TestFillTxMemory(t *testing.T) {
- const bufferSize = 1500
- c := newTestContext(t, 20000, bufferSize, localLinkAddr)
- defer c.cleanup()
-
- // Prepare to send a packet.
- r := stack.Route{
- RemoteLinkAddress: remoteLinkAddr,
- }
-
- buf := buffer.NewView(100)
-
- // Each packet is uses up one buffer, so write as many as possible until
- // we fill the memory.
- ids := make(map[uint64]struct{})
- for i := queueDataSize / bufferSize; i > 0; i-- {
- hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
- Header: hdr,
- Data: buf.ToVectorisedView(),
- }); err != nil {
- t.Fatalf("WritePacket failed unexpectedly: %v", err)
- }
-
- // Check that they have different IDs.
- desc := c.txq.tx.Pull()
- pi := queue.DecodeTxPacketHeader(desc)
- if _, ok := ids[pi.ID]; ok {
- t.Fatalf("ID (%v) reused", pi.ID)
- }
- ids[pi.ID] = struct{}{}
- c.txq.tx.Flush()
- }
-
- // Next attempt to write must fail.
- hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
- Header: hdr,
- Data: buf.ToVectorisedView(),
- })
- if want := tcpip.ErrWouldBlock; err != want {
- t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want)
- }
-}
-
-// TestFillTxMemoryWithMultiBuffer sends packets until the we run out of
-// shared memory for a 2-buffer packet, but still with room for a 1-buffer
-// packet.
-func TestFillTxMemoryWithMultiBuffer(t *testing.T) {
- const bufferSize = 1500
- c := newTestContext(t, 20000, bufferSize, localLinkAddr)
- defer c.cleanup()
-
- // Prepare to send a packet.
- r := stack.Route{
- RemoteLinkAddress: remoteLinkAddr,
- }
-
- buf := buffer.NewView(100)
-
- // Each packet is uses up one buffer, so write as many as possible
- // until there is only one buffer left.
- for i := queueDataSize/bufferSize - 1; i > 0; i-- {
- hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
- Header: hdr,
- Data: buf.ToVectorisedView(),
- }); err != nil {
- t.Fatalf("WritePacket failed unexpectedly: %v", err)
- }
-
- // Pull the posted buffer.
- c.txq.tx.Pull()
- c.txq.tx.Flush()
- }
-
- // Attempt to write a two-buffer packet. It must fail.
- {
- hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- uu := buffer.NewView(bufferSize).ToVectorisedView()
- if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
- Header: hdr,
- Data: uu,
- }); err != want {
- t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want)
- }
- }
-
- // Attempt to write the one-buffer packet again. It must succeed.
- {
- hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
- Header: hdr,
- Data: buf.ToVectorisedView(),
- }); err != nil {
- t.Fatalf("WritePacket failed unexpectedly: %v", err)
- }
- }
-}
-
-func pollPull(t *testing.T, p *pipe.Rx, to <-chan time.Time, errStr string) []byte {
- t.Helper()
-
- for {
- b := p.Pull()
- if b != nil {
- return b
- }
-
- select {
- case <-time.After(10 * time.Millisecond):
- case <-to:
- t.Fatal(errStr)
- }
- }
-}
-
-// TestSimpleReceive completes 1000 different receives with random payload and
-// random number of buffers. It checks that the contents match the expected
-// values.
-func TestSimpleReceive(t *testing.T) {
- const bufferSize = 1500
- c := newTestContext(t, 20000, bufferSize, localLinkAddr)
- defer c.cleanup()
-
- // Check that buffers have been posted.
- limit := c.ep.rx.q.PostedBuffersLimit()
- for i := uint64(0); i < limit; i++ {
- timeout := time.After(2 * time.Second)
- bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for all buffers to be posted"))
-
- if want := i * bufferSize; want != bi.Offset {
- t.Fatalf("Bad posted offset: got %v, want %v", bi.Offset, want)
- }
-
- if want := i; want != bi.ID {
- t.Fatalf("Bad posted ID: got %v, want %v", bi.ID, want)
- }
-
- if bufferSize != bi.Size {
- t.Fatalf("Bad posted bufferSize: got %v, want %v", bi.Size, bufferSize)
- }
- }
- c.rxq.tx.Flush()
-
- // Create a slice with the indices 0..limit-1.
- idx := make([]int, limit)
- for i := range idx {
- idx[i] = i
- }
-
- // Complete random packets 1000 times.
- for iters := 1000; iters > 0; iters-- {
- timeout := time.After(2 * time.Second)
- // Prepare a random packet.
- shuffle(idx)
- n := 1 + rand.Intn(10)
- bufs := make([]queue.RxBuffer, n)
- contents := make([]byte, bufferSize*n-rand.Intn(500))
- randomFill(contents)
- for i := range bufs {
- j := idx[i]
- bufs[i].Size = bufferSize
- bufs[i].Offset = uint64(bufferSize * j)
- bufs[i].ID = uint64(j)
-
- copy(c.rxq.data[bufs[i].Offset:][:bufferSize], contents[i*bufferSize:])
- }
-
- // Push completion.
- c.pushRxCompletion(uint32(len(contents)), bufs)
- c.rxq.rx.Flush()
- syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
-
- // Wait for packet to be received, then check it.
- c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet")
- c.mu.Lock()
- rcvd := []byte(c.packets[0].vv.First())
- c.packets = c.packets[:0]
- c.mu.Unlock()
-
- if contents := contents[header.EthernetMinimumSize:]; !bytes.Equal(contents, rcvd) {
- t.Fatalf("Unexpected buffer contents: got %x, want %x", rcvd, contents)
- }
-
- // Check that buffers have been reposted.
- for i := range bufs {
- bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffers to be reposted"))
- if bi != bufs[i] {
- t.Fatalf("Unexpected buffer reposted: got %x, want %x", bi, bufs[i])
- }
- }
- c.rxq.tx.Flush()
- }
-}
-
-// TestRxBuffersReposted tests that rx buffers get reposted after they have been
-// completed.
-func TestRxBuffersReposted(t *testing.T) {
- const bufferSize = 1500
- c := newTestContext(t, 20000, bufferSize, localLinkAddr)
- defer c.cleanup()
-
- // Receive all posted buffers.
- limit := c.ep.rx.q.PostedBuffersLimit()
- buffers := make([]queue.RxBuffer, 0, limit)
- for i := limit; i > 0; i-- {
- timeout := time.After(2 * time.Second)
- buffers = append(buffers, queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for all buffers")))
- }
- c.rxq.tx.Flush()
-
- // Check that all buffers are reposted when individually completed.
- for i := range buffers {
- timeout := time.After(2 * time.Second)
- // Complete the buffer.
- c.pushRxCompletion(buffers[i].Size, buffers[i:][:1])
- c.rxq.rx.Flush()
- syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
-
- // Wait for it to be reposted.
- bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffer to be reposted"))
- if bi != buffers[i] {
- t.Fatalf("Different buffer posted: got %v, want %v", bi, buffers[i])
- }
- }
- c.rxq.tx.Flush()
-
- // Check that all buffers are reposted when completed in pairs.
- for i := 0; i < len(buffers)/2; i++ {
- timeout := time.After(2 * time.Second)
- // Complete with two buffers.
- c.pushRxCompletion(2*bufferSize, buffers[2*i:][:2])
- c.rxq.rx.Flush()
- syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
-
- // Wait for them to be reposted.
- for j := 0; j < 2; j++ {
- bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffer to be reposted"))
- if bi != buffers[2*i+j] {
- t.Fatalf("Different buffer posted: got %v, want %v", bi, buffers[2*i+j])
- }
- }
- }
- c.rxq.tx.Flush()
-}
-
-// TestReceivePostingIsFull checks that the endpoint will properly handle the
-// case when a received buffer cannot be immediately reposted because it hasn't
-// been pulled from the tx pipe yet.
-func TestReceivePostingIsFull(t *testing.T) {
- const bufferSize = 1500
- c := newTestContext(t, 20000, bufferSize, localLinkAddr)
- defer c.cleanup()
-
- // Complete first posted buffer before flushing it from the tx pipe.
- first := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for first buffer to be posted"))
- c.pushRxCompletion(first.Size, []queue.RxBuffer{first})
- c.rxq.rx.Flush()
- syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
-
- // Check that packet is received.
- c.waitForPackets(1, time.After(time.Second), "Timeout waiting for completed packet")
-
- // Complete another buffer.
- second := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for second buffer to be posted"))
- c.pushRxCompletion(second.Size, []queue.RxBuffer{second})
- c.rxq.rx.Flush()
- syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
-
- // Check that no packet is received yet, as the worker is blocked trying
- // to repost.
- select {
- case <-time.After(500 * time.Millisecond):
- case <-c.packetCh:
- t.Fatalf("Unexpected packet received")
- }
-
- // Flush tx queue, which will allow the first buffer to be reposted,
- // and the second completion to be pulled.
- c.rxq.tx.Flush()
- syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
-
- // Check that second packet completes.
- c.waitForPackets(1, time.After(time.Second), "Timeout waiting for second completed packet")
-}
-
-// TestCloseWhileWaitingToPost closes the endpoint while it is waiting to
-// repost a buffer. Make sure it backs out.
-func TestCloseWhileWaitingToPost(t *testing.T) {
- const bufferSize = 1500
- c := newTestContext(t, 20000, bufferSize, localLinkAddr)
- cleaned := false
- defer func() {
- if !cleaned {
- c.cleanup()
- }
- }()
-
- // Complete first posted buffer before flushing it from the tx pipe.
- bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for initial buffer to be posted"))
- c.pushRxCompletion(bi.Size, []queue.RxBuffer{bi})
- c.rxq.rx.Flush()
- syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
-
- // Wait for packet to be indicated.
- c.waitForPackets(1, time.After(time.Second), "Timeout waiting for completed packet")
-
- // Cleanup and wait for worker to complete.
- c.cleanup()
- cleaned = true
- c.ep.Wait()
-}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go b/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go
index f7e816a41..f7e816a41 100644..100755
--- a/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go
diff --git a/pkg/tcpip/link/sharedmem/tx.go b/pkg/tcpip/link/sharedmem/tx.go
index 6b8d7859d..6b8d7859d 100644..100755
--- a/pkg/tcpip/link/sharedmem/tx.go
+++ b/pkg/tcpip/link/sharedmem/tx.go
diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD
deleted file mode 100644
index 230a8d53a..000000000
--- a/pkg/tcpip/link/sniffer/BUILD
+++ /dev/null
@@ -1,19 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "sniffer",
- srcs = [
- "pcap.go",
- "sniffer.go",
- ],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/log",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/stack",
- ],
-)
diff --git a/pkg/tcpip/link/sniffer/sniffer_state_autogen.go b/pkg/tcpip/link/sniffer/sniffer_state_autogen.go
new file mode 100755
index 000000000..8d79defea
--- /dev/null
+++ b/pkg/tcpip/link/sniffer/sniffer_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package sniffer
diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD
deleted file mode 100644
index e5096ea38..000000000
--- a/pkg/tcpip/link/tun/BUILD
+++ /dev/null
@@ -1,9 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "tun",
- srcs = ["tun_unsafe.go"],
- visibility = ["//visibility:public"],
-)
diff --git a/pkg/tcpip/link/tun/tun_state_autogen.go b/pkg/tcpip/link/tun/tun_state_autogen.go
new file mode 100755
index 000000000..149299ea3
--- /dev/null
+++ b/pkg/tcpip/link/tun/tun_state_autogen.go
@@ -0,0 +1,5 @@
+// automatically generated by stateify.
+
+// +build linux
+
+package tun
diff --git a/pkg/tcpip/link/tun/tun_unsafe.go b/pkg/tcpip/link/tun/tun_unsafe.go
index 09ca9b527..09ca9b527 100644..100755
--- a/pkg/tcpip/link/tun/tun_unsafe.go
+++ b/pkg/tcpip/link/tun/tun_unsafe.go
diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD
deleted file mode 100644
index 0956d2c65..000000000
--- a/pkg/tcpip/link/waitable/BUILD
+++ /dev/null
@@ -1,30 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "waitable",
- srcs = [
- "waitable.go",
- ],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/gate",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/stack",
- ],
-)
-
-go_test(
- name = "waitable_test",
- srcs = [
- "waitable_test.go",
- ],
- library = ":waitable",
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/stack",
- ],
-)
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
index a8de38979..a8de38979 100644..100755
--- a/pkg/tcpip/link/waitable/waitable.go
+++ b/pkg/tcpip/link/waitable/waitable.go
diff --git a/pkg/tcpip/link/waitable/waitable_state_autogen.go b/pkg/tcpip/link/waitable/waitable_state_autogen.go
new file mode 100755
index 000000000..059424fa0
--- /dev/null
+++ b/pkg/tcpip/link/waitable/waitable_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package waitable
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
deleted file mode 100644
index 31b11a27a..000000000
--- a/pkg/tcpip/link/waitable/waitable_test.go
+++ /dev/null
@@ -1,173 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package waitable
-
-import (
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
-)
-
-type countedEndpoint struct {
- dispatchCount int
- writeCount int
- attachCount int
-
- mtu uint32
- capabilities stack.LinkEndpointCapabilities
- hdrLen uint16
- linkAddr tcpip.LinkAddress
-
- dispatcher stack.NetworkDispatcher
-}
-
-func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
- e.dispatchCount++
-}
-
-func (e *countedEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
- e.attachCount++
- e.dispatcher = dispatcher
-}
-
-// IsAttached implements stack.LinkEndpoint.IsAttached.
-func (e *countedEndpoint) IsAttached() bool {
- return e.dispatcher != nil
-}
-
-func (e *countedEndpoint) MTU() uint32 {
- return e.mtu
-}
-
-func (e *countedEndpoint) Capabilities() stack.LinkEndpointCapabilities {
- return e.capabilities
-}
-
-func (e *countedEndpoint) MaxHeaderLength() uint16 {
- return e.hdrLen
-}
-
-func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress {
- return e.linkAddr
-}
-
-func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
- e.writeCount++
- return nil
-}
-
-// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
- e.writeCount += len(pkts)
- return len(pkts), nil
-}
-
-func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
- e.writeCount++
- return nil
-}
-
-// Wait implements stack.LinkEndpoint.Wait.
-func (*countedEndpoint) Wait() {}
-
-func TestWaitWrite(t *testing.T) {
- ep := &countedEndpoint{}
- wep := New(ep)
-
- // Write and check that it goes through.
- wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{})
- if want := 1; ep.writeCount != want {
- t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
- }
-
- // Wait on dispatches, then try to write. It must go through.
- wep.WaitDispatch()
- wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{})
- if want := 2; ep.writeCount != want {
- t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
- }
-
- // Wait on writes, then try to write. It must not go through.
- wep.WaitWrite()
- wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{})
- if want := 2; ep.writeCount != want {
- t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
- }
-}
-
-func TestWaitDispatch(t *testing.T) {
- ep := &countedEndpoint{}
- wep := New(ep)
-
- // Check that attach happens.
- wep.Attach(ep)
- if want := 1; ep.attachCount != want {
- t.Fatalf("Unexpected attachCount: got=%v, want=%v", ep.attachCount, want)
- }
-
- // Dispatch and check that it goes through.
- ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{})
- if want := 1; ep.dispatchCount != want {
- t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
- }
-
- // Wait on writes, then try to dispatch. It must go through.
- wep.WaitWrite()
- ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{})
- if want := 2; ep.dispatchCount != want {
- t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
- }
-
- // Wait on dispatches, then try to dispatch. It must not go through.
- wep.WaitDispatch()
- ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{})
- if want := 2; ep.dispatchCount != want {
- t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
- }
-}
-
-func TestOtherMethods(t *testing.T) {
- const (
- mtu = 0xdead
- capabilities = 0xbeef
- hdrLen = 0x1234
- linkAddr = "test address"
- )
- ep := &countedEndpoint{
- mtu: mtu,
- capabilities: capabilities,
- hdrLen: hdrLen,
- linkAddr: linkAddr,
- }
- wep := New(ep)
-
- if v := wep.MTU(); v != mtu {
- t.Fatalf("Unexpected mtu: got=%v, want=%v", v, mtu)
- }
-
- if v := wep.Capabilities(); v != capabilities {
- t.Fatalf("Unexpected capabilities: got=%v, want=%v", v, capabilities)
- }
-
- if v := wep.MaxHeaderLength(); v != hdrLen {
- t.Fatalf("Unexpected MaxHeaderLength: got=%v, want=%v", v, hdrLen)
- }
-
- if v := wep.LinkAddress(); v != linkAddr {
- t.Fatalf("Unexpected LinkAddress: got=%q, want=%q", v, linkAddr)
- }
-}
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
deleted file mode 100644
index 6a4839fb8..000000000
--- a/pkg/tcpip/network/BUILD
+++ /dev/null
@@ -1,22 +0,0 @@
-load("//tools:defs.bzl", "go_test")
-
-package(licenses = ["notice"])
-
-go_test(
- name = "ip_test",
- size = "small",
- srcs = [
- "ip_test.go",
- ],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/loopback",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/transport/tcp",
- "//pkg/tcpip/transport/udp",
- ],
-)
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
deleted file mode 100644
index eddf7b725..000000000
--- a/pkg/tcpip/network/arp/BUILD
+++ /dev/null
@@ -1,32 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "arp",
- srcs = ["arp.go"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/stack",
- ],
-)
-
-go_test(
- name = "arp_test",
- size = "small",
- srcs = ["arp_test.go"],
- deps = [
- ":arp",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/link/sniffer",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/transport/icmp",
- ],
-)
diff --git a/pkg/tcpip/network/arp/arp_state_autogen.go b/pkg/tcpip/network/arp/arp_state_autogen.go
new file mode 100755
index 000000000..5cd8535e3
--- /dev/null
+++ b/pkg/tcpip/network/arp/arp_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package arp
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
deleted file mode 100644
index 03cf03b6d..000000000
--- a/pkg/tcpip/network/arp/arp_test.go
+++ /dev/null
@@ -1,145 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package arp_test
-
-import (
- "context"
- "strconv"
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
- "gvisor.dev/gvisor/pkg/tcpip/network/arp"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
-)
-
-const (
- stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c")
- stackAddr1 = tcpip.Address("\x0a\x00\x00\x01")
- stackAddr2 = tcpip.Address("\x0a\x00\x00\x02")
- stackAddrBad = tcpip.Address("\x0a\x00\x00\x03")
-)
-
-type testContext struct {
- t *testing.T
- linkEP *channel.Endpoint
- s *stack.Stack
-}
-
-func newTestContext(t *testing.T) *testContext {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()},
- })
-
- const defaultMTU = 65536
- ep := channel.New(256, defaultMTU, stackLinkAddr)
- wep := stack.LinkEndpoint(ep)
-
- if testing.Verbose() {
- wep = sniffer.New(ep)
- }
- if err := s.CreateNIC(1, wep); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
-
- if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr1); err != nil {
- t.Fatalf("AddAddress for ipv4 failed: %v", err)
- }
- if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr2); err != nil {
- t.Fatalf("AddAddress for ipv4 failed: %v", err)
- }
- if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
- t.Fatalf("AddAddress for arp failed: %v", err)
- }
-
- s.SetRouteTable([]tcpip.Route{{
- Destination: header.IPv4EmptySubnet,
- NIC: 1,
- }})
-
- return &testContext{
- t: t,
- s: s,
- linkEP: ep,
- }
-}
-
-func (c *testContext) cleanup() {
- c.linkEP.Close()
-}
-
-func TestDirectRequest(t *testing.T) {
- c := newTestContext(t)
- defer c.cleanup()
-
- const senderMAC = "\x01\x02\x03\x04\x05\x06"
- const senderIPv4 = "\x0a\x00\x00\x02"
-
- v := make(buffer.View, header.ARPSize)
- h := header.ARP(v)
- h.SetIPv4OverEthernet()
- h.SetOp(header.ARPRequest)
- copy(h.HardwareAddressSender(), senderMAC)
- copy(h.ProtocolAddressSender(), senderIPv4)
-
- inject := func(addr tcpip.Address) {
- copy(h.ProtocolAddressTarget(), addr)
- c.linkEP.InjectInbound(arp.ProtocolNumber, tcpip.PacketBuffer{
- Data: v.ToVectorisedView(),
- })
- }
-
- for i, address := range []tcpip.Address{stackAddr1, stackAddr2} {
- t.Run(strconv.Itoa(i), func(t *testing.T) {
- inject(address)
- pi, _ := c.linkEP.ReadContext(context.Background())
- if pi.Proto != arp.ProtocolNumber {
- t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto)
- }
- rep := header.ARP(pi.Pkt.Header.View())
- if !rep.IsValid() {
- t.Fatalf("invalid ARP response pi.Pkt.Header.UsedLength()=%d", pi.Pkt.Header.UsedLength())
- }
- if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want {
- t.Errorf("got HardwareAddressSender = %s, want = %s", got, want)
- }
- if got, want := tcpip.Address(rep.ProtocolAddressSender()), tcpip.Address(h.ProtocolAddressTarget()); got != want {
- t.Errorf("got ProtocolAddressSender = %s, want = %s", got, want)
- }
- if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress(h.HardwareAddressSender()); got != want {
- t.Errorf("got HardwareAddressTarget = %s, want = %s", got, want)
- }
- if got, want := tcpip.Address(rep.ProtocolAddressTarget()), tcpip.Address(h.ProtocolAddressSender()); got != want {
- t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, want)
- }
- })
- }
-
- inject(stackAddrBad)
- // Sleep tests are gross, but this will only potentially flake
- // if there's a bug. If there is no bug this will reliably
- // succeed.
- ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond)
- if pkt, ok := c.linkEP.ReadContext(ctx); ok {
- t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto)
- }
-}
diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD
deleted file mode 100644
index d1c728ccf..000000000
--- a/pkg/tcpip/network/fragmentation/BUILD
+++ /dev/null
@@ -1,45 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
-package(licenses = ["notice"])
-
-go_template_instance(
- name = "reassembler_list",
- out = "reassembler_list.go",
- package = "fragmentation",
- prefix = "reassembler",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*reassembler",
- "Linker": "*reassembler",
- },
-)
-
-go_library(
- name = "fragmentation",
- srcs = [
- "frag_heap.go",
- "fragmentation.go",
- "reassembler.go",
- "reassembler_list.go",
- ],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/log",
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- ],
-)
-
-go_test(
- name = "fragmentation_test",
- size = "small",
- srcs = [
- "frag_heap_test.go",
- "fragmentation_test.go",
- "reassembler_test.go",
- ],
- library = ":fragmentation",
- deps = ["//pkg/tcpip/buffer"],
-)
diff --git a/pkg/tcpip/network/fragmentation/frag_heap_test.go b/pkg/tcpip/network/fragmentation/frag_heap_test.go
deleted file mode 100644
index 9ececcb9f..000000000
--- a/pkg/tcpip/network/fragmentation/frag_heap_test.go
+++ /dev/null
@@ -1,126 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package fragmentation
-
-import (
- "container/heap"
- "reflect"
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
-)
-
-var reassambleTestCases = []struct {
- comment string
- in []fragment
- want buffer.VectorisedView
-}{
- {
- comment: "Non-overlapping in-order",
- in: []fragment{
- {offset: 0, vv: vv(1, "0")},
- {offset: 1, vv: vv(1, "1")},
- },
- want: vv(2, "0", "1"),
- },
- {
- comment: "Non-overlapping out-of-order",
- in: []fragment{
- {offset: 1, vv: vv(1, "1")},
- {offset: 0, vv: vv(1, "0")},
- },
- want: vv(2, "0", "1"),
- },
- {
- comment: "Duplicated packets",
- in: []fragment{
- {offset: 0, vv: vv(1, "0")},
- {offset: 0, vv: vv(1, "0")},
- },
- want: vv(1, "0"),
- },
- {
- comment: "Overlapping in-order",
- in: []fragment{
- {offset: 0, vv: vv(2, "01")},
- {offset: 1, vv: vv(2, "12")},
- },
- want: vv(3, "01", "2"),
- },
- {
- comment: "Overlapping out-of-order",
- in: []fragment{
- {offset: 1, vv: vv(2, "12")},
- {offset: 0, vv: vv(2, "01")},
- },
- want: vv(3, "01", "2"),
- },
- {
- comment: "Overlapping subset in-order",
- in: []fragment{
- {offset: 0, vv: vv(3, "012")},
- {offset: 1, vv: vv(1, "1")},
- },
- want: vv(3, "012"),
- },
- {
- comment: "Overlapping subset out-of-order",
- in: []fragment{
- {offset: 1, vv: vv(1, "1")},
- {offset: 0, vv: vv(3, "012")},
- },
- want: vv(3, "012"),
- },
-}
-
-func TestReassamble(t *testing.T) {
- for _, c := range reassambleTestCases {
- t.Run(c.comment, func(t *testing.T) {
- h := make(fragHeap, 0, 8)
- heap.Init(&h)
- for _, f := range c.in {
- heap.Push(&h, f)
- }
- got, err := h.reassemble()
- if err != nil {
- t.Fatal(err)
- }
- if !reflect.DeepEqual(got, c.want) {
- t.Errorf("got reassemble(%+v) = %v, want = %v", c.in, got, c.want)
- }
- })
- }
-}
-
-func TestReassambleFailsForNonZeroOffset(t *testing.T) {
- h := make(fragHeap, 0, 8)
- heap.Init(&h)
- heap.Push(&h, fragment{offset: 1, vv: vv(1, "0")})
- _, err := h.reassemble()
- if err == nil {
- t.Errorf("reassemble() did not fail when the first packet had offset != 0")
- }
-}
-
-func TestReassambleFailsForHoles(t *testing.T) {
- h := make(fragHeap, 0, 8)
- heap.Init(&h)
- heap.Push(&h, fragment{offset: 0, vv: vv(1, "0")})
- heap.Push(&h, fragment{offset: 2, vv: vv(1, "1")})
- _, err := h.reassemble()
- if err == nil {
- t.Errorf("reassemble() did not fail when there was a hole in the packet")
- }
-}
diff --git a/pkg/tcpip/network/fragmentation/fragmentation_state_autogen.go b/pkg/tcpip/network/fragmentation/fragmentation_state_autogen.go
new file mode 100755
index 000000000..cbaecdaa7
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/fragmentation_state_autogen.go
@@ -0,0 +1,38 @@
+// automatically generated by stateify.
+
+package fragmentation
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+func (x *reassemblerList) beforeSave() {}
+func (x *reassemblerList) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *reassemblerList) afterLoad() {}
+func (x *reassemblerList) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *reassemblerEntry) beforeSave() {}
+func (x *reassemblerEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *reassemblerEntry) afterLoad() {}
+func (x *reassemblerEntry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
+func init() {
+ state.Register("pkg/tcpip/network/fragmentation.reassemblerList", (*reassemblerList)(nil), state.Fns{Save: (*reassemblerList).save, Load: (*reassemblerList).load})
+ state.Register("pkg/tcpip/network/fragmentation.reassemblerEntry", (*reassemblerEntry)(nil), state.Fns{Save: (*reassemblerEntry).save, Load: (*reassemblerEntry).load})
+}
diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go
deleted file mode 100644
index 72c0f53be..000000000
--- a/pkg/tcpip/network/fragmentation/fragmentation_test.go
+++ /dev/null
@@ -1,165 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package fragmentation
-
-import (
- "reflect"
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
-)
-
-// vv is a helper to build VectorisedView from different strings.
-func vv(size int, pieces ...string) buffer.VectorisedView {
- views := make([]buffer.View, len(pieces))
- for i, p := range pieces {
- views[i] = []byte(p)
- }
-
- return buffer.NewVectorisedView(size, views)
-}
-
-type processInput struct {
- id uint32
- first uint16
- last uint16
- more bool
- vv buffer.VectorisedView
-}
-
-type processOutput struct {
- vv buffer.VectorisedView
- done bool
-}
-
-var processTestCases = []struct {
- comment string
- in []processInput
- out []processOutput
-}{
- {
- comment: "One ID",
- in: []processInput{
- {id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")},
- {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")},
- },
- out: []processOutput{
- {vv: buffer.VectorisedView{}, done: false},
- {vv: vv(4, "01", "23"), done: true},
- },
- },
- {
- comment: "Two IDs",
- in: []processInput{
- {id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")},
- {id: 1, first: 0, last: 1, more: true, vv: vv(2, "ab")},
- {id: 1, first: 2, last: 3, more: false, vv: vv(2, "cd")},
- {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")},
- },
- out: []processOutput{
- {vv: buffer.VectorisedView{}, done: false},
- {vv: buffer.VectorisedView{}, done: false},
- {vv: vv(4, "ab", "cd"), done: true},
- {vv: vv(4, "01", "23"), done: true},
- },
- },
-}
-
-func TestFragmentationProcess(t *testing.T) {
- for _, c := range processTestCases {
- t.Run(c.comment, func(t *testing.T) {
- f := NewFragmentation(1024, 512, DefaultReassembleTimeout)
- for i, in := range c.in {
- vv, done, err := f.Process(in.id, in.first, in.last, in.more, in.vv)
- if err != nil {
- t.Fatalf("f.Process(%+v, %+d, %+d, %t, %+v) failed: %v", in.id, in.first, in.last, in.more, in.vv, err)
- }
- if !reflect.DeepEqual(vv, c.out[i].vv) {
- t.Errorf("got Process(%d) = %+v, want = %+v", i, vv, c.out[i].vv)
- }
- if done != c.out[i].done {
- t.Errorf("got Process(%d) = %+v, want = %+v", i, done, c.out[i].done)
- }
- if c.out[i].done {
- if _, ok := f.reassemblers[in.id]; ok {
- t.Errorf("Process(%d) did not remove buffer from reassemblers", i)
- }
- for n := f.rList.Front(); n != nil; n = n.Next() {
- if n.id == in.id {
- t.Errorf("Process(%d) did not remove buffer from rList", i)
- }
- }
- }
- }
- })
- }
-}
-
-func TestReassemblingTimeout(t *testing.T) {
- timeout := time.Millisecond
- f := NewFragmentation(1024, 512, timeout)
- // Send first fragment with id = 0, first = 0, last = 0, and more = true.
- f.Process(0, 0, 0, true, vv(1, "0"))
- // Sleep more than the timeout.
- time.Sleep(2 * timeout)
- // Send another fragment that completes a packet.
- // However, no packet should be reassembled because the fragment arrived after the timeout.
- _, done, err := f.Process(0, 1, 1, false, vv(1, "1"))
- if err != nil {
- t.Fatalf("f.Process(0, 1, 1, false, vv(1, \"1\")) failed: %v", err)
- }
- if done {
- t.Errorf("Fragmentation does not respect the reassembling timeout.")
- }
-}
-
-func TestMemoryLimits(t *testing.T) {
- f := NewFragmentation(3, 1, DefaultReassembleTimeout)
- // Send first fragment with id = 0.
- f.Process(0, 0, 0, true, vv(1, "0"))
- // Send first fragment with id = 1.
- f.Process(1, 0, 0, true, vv(1, "1"))
- // Send first fragment with id = 2.
- f.Process(2, 0, 0, true, vv(1, "2"))
-
- // Send first fragment with id = 3. This should caused id = 0 and id = 1 to be
- // evicted.
- f.Process(3, 0, 0, true, vv(1, "3"))
-
- if _, ok := f.reassemblers[0]; ok {
- t.Errorf("Memory limits are not respected: id=0 has not been evicted.")
- }
- if _, ok := f.reassemblers[1]; ok {
- t.Errorf("Memory limits are not respected: id=1 has not been evicted.")
- }
- if _, ok := f.reassemblers[3]; !ok {
- t.Errorf("Implementation of memory limits is wrong: id=3 is not present.")
- }
-}
-
-func TestMemoryLimitsIgnoresDuplicates(t *testing.T) {
- f := NewFragmentation(1, 0, DefaultReassembleTimeout)
- // Send first fragment with id = 0.
- f.Process(0, 0, 0, true, vv(1, "0"))
- // Send the same packet again.
- f.Process(0, 0, 0, true, vv(1, "0"))
-
- got := f.size
- want := 1
- if got != want {
- t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want)
- }
-}
diff --git a/pkg/tcpip/network/fragmentation/reassembler_list.go b/pkg/tcpip/network/fragmentation/reassembler_list.go
new file mode 100755
index 000000000..3189cae29
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/reassembler_list.go
@@ -0,0 +1,173 @@
+package fragmentation
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type reassemblerElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (reassemblerElementMapper) linkerFor(elem *reassembler) *reassembler { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type reassemblerList struct {
+ head *reassembler
+ tail *reassembler
+}
+
+// Reset resets list l to the empty state.
+func (l *reassemblerList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *reassemblerList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *reassemblerList) Front() *reassembler {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *reassemblerList) Back() *reassembler {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *reassemblerList) PushFront(e *reassembler) {
+ reassemblerElementMapper{}.linkerFor(e).SetNext(l.head)
+ reassemblerElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ reassemblerElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *reassemblerList) PushBack(e *reassembler) {
+ reassemblerElementMapper{}.linkerFor(e).SetNext(nil)
+ reassemblerElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ reassemblerElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *reassemblerList) PushBackList(m *reassemblerList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ reassemblerElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ reassemblerElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *reassemblerList) InsertAfter(b, e *reassembler) {
+ a := reassemblerElementMapper{}.linkerFor(b).Next()
+ reassemblerElementMapper{}.linkerFor(e).SetNext(a)
+ reassemblerElementMapper{}.linkerFor(e).SetPrev(b)
+ reassemblerElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ reassemblerElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *reassemblerList) InsertBefore(a, e *reassembler) {
+ b := reassemblerElementMapper{}.linkerFor(a).Prev()
+ reassemblerElementMapper{}.linkerFor(e).SetNext(a)
+ reassemblerElementMapper{}.linkerFor(e).SetPrev(b)
+ reassemblerElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ reassemblerElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *reassemblerList) Remove(e *reassembler) {
+ prev := reassemblerElementMapper{}.linkerFor(e).Prev()
+ next := reassemblerElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ reassemblerElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ reassemblerElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type reassemblerEntry struct {
+ next *reassembler
+ prev *reassembler
+}
+
+// Next returns the entry that follows e in the list.
+func (e *reassemblerEntry) Next() *reassembler {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *reassemblerEntry) Prev() *reassembler {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *reassemblerEntry) SetNext(elem *reassembler) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *reassemblerEntry) SetPrev(elem *reassembler) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go
deleted file mode 100644
index 7eee0710d..000000000
--- a/pkg/tcpip/network/fragmentation/reassembler_test.go
+++ /dev/null
@@ -1,105 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package fragmentation
-
-import (
- "math"
- "reflect"
- "testing"
-)
-
-type updateHolesInput struct {
- first uint16
- last uint16
- more bool
-}
-
-var holesTestCases = []struct {
- comment string
- in []updateHolesInput
- want []hole
-}{
- {
- comment: "No fragments. Expected holes: {[0 -> inf]}.",
- in: []updateHolesInput{},
- want: []hole{{first: 0, last: math.MaxUint16, deleted: false}},
- },
- {
- comment: "One fragment at beginning. Expected holes: {[2, inf]}.",
- in: []updateHolesInput{{first: 0, last: 1, more: true}},
- want: []hole{
- {first: 0, last: math.MaxUint16, deleted: true},
- {first: 2, last: math.MaxUint16, deleted: false},
- },
- },
- {
- comment: "One fragment in the middle. Expected holes: {[0, 0], [3, inf]}.",
- in: []updateHolesInput{{first: 1, last: 2, more: true}},
- want: []hole{
- {first: 0, last: math.MaxUint16, deleted: true},
- {first: 0, last: 0, deleted: false},
- {first: 3, last: math.MaxUint16, deleted: false},
- },
- },
- {
- comment: "One fragment at the end. Expected holes: {[0, 0]}.",
- in: []updateHolesInput{{first: 1, last: 2, more: false}},
- want: []hole{
- {first: 0, last: math.MaxUint16, deleted: true},
- {first: 0, last: 0, deleted: false},
- },
- },
- {
- comment: "One fragment completing a packet. Expected holes: {}.",
- in: []updateHolesInput{{first: 0, last: 1, more: false}},
- want: []hole{
- {first: 0, last: math.MaxUint16, deleted: true},
- },
- },
- {
- comment: "Two non-overlapping fragments completing a packet. Expected holes: {}.",
- in: []updateHolesInput{
- {first: 0, last: 1, more: true},
- {first: 2, last: 3, more: false},
- },
- want: []hole{
- {first: 0, last: math.MaxUint16, deleted: true},
- {first: 2, last: math.MaxUint16, deleted: true},
- },
- },
- {
- comment: "Two overlapping fragments completing a packet. Expected holes: {}.",
- in: []updateHolesInput{
- {first: 0, last: 2, more: true},
- {first: 2, last: 3, more: false},
- },
- want: []hole{
- {first: 0, last: math.MaxUint16, deleted: true},
- {first: 3, last: math.MaxUint16, deleted: true},
- },
- },
-}
-
-func TestUpdateHoles(t *testing.T) {
- for _, c := range holesTestCases {
- r := newReassembler(0)
- for _, i := range c.in {
- r.updateHoles(i.first, i.last, i.more)
- }
- if !reflect.DeepEqual(r.holes, c.want) {
- t.Errorf("Test \"%s\" produced unexepetced holes. Got %v. Want %v", c.comment, r.holes, c.want)
- }
- }
-}
diff --git a/pkg/tcpip/network/hash/BUILD b/pkg/tcpip/network/hash/BUILD
deleted file mode 100644
index 872165866..000000000
--- a/pkg/tcpip/network/hash/BUILD
+++ /dev/null
@@ -1,13 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "hash",
- srcs = ["hash.go"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/rand",
- "//pkg/tcpip/header",
- ],
-)
diff --git a/pkg/tcpip/network/hash/hash_state_autogen.go b/pkg/tcpip/network/hash/hash_state_autogen.go
new file mode 100755
index 000000000..9467fe298
--- /dev/null
+++ b/pkg/tcpip/network/hash/hash_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package hash
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
deleted file mode 100644
index f4d78f8c6..000000000
--- a/pkg/tcpip/network/ip_test.go
+++ /dev/null
@@ -1,655 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package ip_test
-
-import (
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
-)
-
-const (
- localIpv4Addr = "\x0a\x00\x00\x01"
- localIpv4PrefixLen = 24
- remoteIpv4Addr = "\x0a\x00\x00\x02"
- ipv4SubnetAddr = "\x0a\x00\x00\x00"
- ipv4SubnetMask = "\xff\xff\xff\x00"
- ipv4Gateway = "\x0a\x00\x00\x03"
- localIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- localIpv6PrefixLen = 120
- remoteIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
- ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00"
- ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03"
-)
-
-// testObject implements two interfaces: LinkEndpoint and TransportDispatcher.
-// The former is used to pretend that it's a link endpoint so that we can
-// inspect packets written by the network endpoints. The latter is used to
-// pretend that it's the network stack so that it can inspect incoming packets
-// that have been handled by the network endpoints.
-//
-// Packets are checked by comparing their fields/values against the expected
-// values stored in the test object itself.
-type testObject struct {
- t *testing.T
- protocol tcpip.TransportProtocolNumber
- contents []byte
- srcAddr tcpip.Address
- dstAddr tcpip.Address
- v4 bool
- typ stack.ControlType
- extra uint32
-
- dataCalls int
- controlCalls int
-}
-
-// checkValues verifies that the transport protocol, data contents, src & dst
-// addresses of a packet match what's expected. If any field doesn't match, the
-// test fails.
-func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView, srcAddr, dstAddr tcpip.Address) {
- v := vv.ToView()
- if protocol != t.protocol {
- t.t.Errorf("protocol = %v, want %v", protocol, t.protocol)
- }
-
- if srcAddr != t.srcAddr {
- t.t.Errorf("srcAddr = %v, want %v", srcAddr, t.srcAddr)
- }
-
- if dstAddr != t.dstAddr {
- t.t.Errorf("dstAddr = %v, want %v", dstAddr, t.dstAddr)
- }
-
- if len(v) != len(t.contents) {
- t.t.Fatalf("len(payload) = %v, want %v", len(v), len(t.contents))
- }
-
- for i := range t.contents {
- if t.contents[i] != v[i] {
- t.t.Fatalf("payload[%v] = %v, want %v", i, v[i], t.contents[i])
- }
- }
-}
-
-// DeliverTransportPacket is called by network endpoints after parsing incoming
-// packets. This is used by the test object to verify that the results of the
-// parsing are expected.
-func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) {
- t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress)
- t.dataCalls++
-}
-
-// DeliverTransportControlPacket is called by network endpoints after parsing
-// incoming control (ICMP) packets. This is used by the test object to verify
-// that the results of the parsing are expected.
-func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) {
- t.checkValues(trans, pkt.Data, remote, local)
- if typ != t.typ {
- t.t.Errorf("typ = %v, want %v", typ, t.typ)
- }
- if extra != t.extra {
- t.t.Errorf("extra = %v, want %v", extra, t.extra)
- }
- t.controlCalls++
-}
-
-// Attach is only implemented to satisfy the LinkEndpoint interface.
-func (*testObject) Attach(stack.NetworkDispatcher) {}
-
-// IsAttached implements stack.LinkEndpoint.IsAttached.
-func (*testObject) IsAttached() bool {
- return true
-}
-
-// MTU implements stack.LinkEndpoint.MTU. It just returns a constant that
-// matches the linux loopback MTU.
-func (*testObject) MTU() uint32 {
- return 65536
-}
-
-// Capabilities implements stack.LinkEndpoint.Capabilities.
-func (*testObject) Capabilities() stack.LinkEndpointCapabilities {
- return 0
-}
-
-// MaxHeaderLength is only implemented to satisfy the LinkEndpoint interface.
-func (*testObject) MaxHeaderLength() uint16 {
- return 0
-}
-
-// LinkAddress returns the link address of this endpoint.
-func (*testObject) LinkAddress() tcpip.LinkAddress {
- return ""
-}
-
-// Wait implements stack.LinkEndpoint.Wait.
-func (*testObject) Wait() {}
-
-// WritePacket is called by network endpoints after producing a packet and
-// writing it to the link endpoint. This is used by the test object to verify
-// that the produced packet is as expected.
-func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
- var prot tcpip.TransportProtocolNumber
- var srcAddr tcpip.Address
- var dstAddr tcpip.Address
-
- if t.v4 {
- h := header.IPv4(pkt.Header.View())
- prot = tcpip.TransportProtocolNumber(h.Protocol())
- srcAddr = h.SourceAddress()
- dstAddr = h.DestinationAddress()
-
- } else {
- h := header.IPv6(pkt.Header.View())
- prot = tcpip.TransportProtocolNumber(h.NextHeader())
- srcAddr = h.SourceAddress()
- dstAddr = h.DestinationAddress()
- }
- t.checkValues(prot, pkt.Data, srcAddr, dstAddr)
- return nil
-}
-
-// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
- panic("not implemented")
-}
-
-func (t *testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error {
- return tcpip.ErrNotSupported
-}
-
-func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()},
- })
- s.CreateNIC(1, loopback.New())
- s.AddAddress(1, ipv4.ProtocolNumber, local)
- s.SetRouteTable([]tcpip.Route{{
- Destination: header.IPv4EmptySubnet,
- Gateway: ipv4Gateway,
- NIC: 1,
- }})
-
- return s.FindRoute(1, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */)
-}
-
-func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()},
- })
- s.CreateNIC(1, loopback.New())
- s.AddAddress(1, ipv6.ProtocolNumber, local)
- s.SetRouteTable([]tcpip.Route{{
- Destination: header.IPv6EmptySubnet,
- Gateway: ipv6Gateway,
- NIC: 1,
- }})
-
- return s.FindRoute(1, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */)
-}
-
-func buildDummyStack() *stack.Stack {
- return stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()},
- })
-}
-
-func TestIPv4Send(t *testing.T) {
- o := testObject{t: t, v4: true}
- proto := ipv4.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, nil, &o, buildDummyStack())
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- // Allocate and initialize the payload view.
- payload := buffer.NewView(100)
- for i := 0; i < len(payload); i++ {
- payload[i] = uint8(i)
- }
-
- // Allocate the header buffer.
- hdr := buffer.NewPrependable(int(ep.MaxHeaderLength()))
-
- // Issue the write.
- o.protocol = 123
- o.srcAddr = localIpv4Addr
- o.dstAddr = remoteIpv4Addr
- o.contents = payload
-
- r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr)
- if err != nil {
- t.Fatalf("could not find route: %v", err)
- }
- if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
- Header: hdr,
- Data: payload.ToVectorisedView(),
- }); err != nil {
- t.Fatalf("WritePacket failed: %v", err)
- }
-}
-
-func TestIPv4Receive(t *testing.T) {
- o := testObject{t: t, v4: true}
- proto := ipv4.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack())
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- totalLen := header.IPv4MinimumSize + 30
- view := buffer.NewView(totalLen)
- ip := header.IPv4(view)
- ip.Encode(&header.IPv4Fields{
- IHL: header.IPv4MinimumSize,
- TotalLength: uint16(totalLen),
- TTL: 20,
- Protocol: 10,
- SrcAddr: remoteIpv4Addr,
- DstAddr: localIpv4Addr,
- })
-
- // Make payload be non-zero.
- for i := header.IPv4MinimumSize; i < totalLen; i++ {
- view[i] = uint8(i)
- }
-
- // Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv4Addr
- o.dstAddr = localIpv4Addr
- o.contents = view[header.IPv4MinimumSize:totalLen]
-
- r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr)
- if err != nil {
- t.Fatalf("could not find route: %v", err)
- }
- ep.HandlePacket(&r, tcpip.PacketBuffer{
- Data: view.ToVectorisedView(),
- })
- if o.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
- }
-}
-
-func TestIPv4ReceiveControl(t *testing.T) {
- const mtu = 0xbeef - header.IPv4MinimumSize
- cases := []struct {
- name string
- expectedCount int
- fragmentOffset uint16
- code uint8
- expectedTyp stack.ControlType
- expectedExtra uint32
- trunc int
- }{
- {"FragmentationNeeded", 1, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 0},
- {"Truncated (10 bytes missing)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 10},
- {"Truncated (missing IPv4 header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.IPv4MinimumSize + 8},
- {"Truncated (missing 'extra info')", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 4 + header.IPv4MinimumSize + 8},
- {"Truncated (missing ICMP header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.ICMPv4MinimumSize + header.IPv4MinimumSize + 8},
- {"Port unreachable", 1, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0},
- {"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0},
- {"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + 8},
- }
- r, err := buildIPv4Route(localIpv4Addr, "\x0a\x00\x00\xbb")
- if err != nil {
- t.Fatal(err)
- }
- for _, c := range cases {
- t.Run(c.name, func(t *testing.T) {
- o := testObject{t: t}
- proto := ipv4.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack())
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
- defer ep.Close()
-
- const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize
- view := buffer.NewView(dataOffset + 8)
-
- // Create the outer IPv4 header.
- ip := header.IPv4(view)
- ip.Encode(&header.IPv4Fields{
- IHL: header.IPv4MinimumSize,
- TotalLength: uint16(len(view) - c.trunc),
- TTL: 20,
- Protocol: uint8(header.ICMPv4ProtocolNumber),
- SrcAddr: "\x0a\x00\x00\xbb",
- DstAddr: localIpv4Addr,
- })
-
- // Create the ICMP header.
- icmp := header.ICMPv4(view[header.IPv4MinimumSize:])
- icmp.SetType(header.ICMPv4DstUnreachable)
- icmp.SetCode(c.code)
- icmp.SetIdent(0xdead)
- icmp.SetSequence(0xbeef)
-
- // Create the inner IPv4 header.
- ip = header.IPv4(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize:])
- ip.Encode(&header.IPv4Fields{
- IHL: header.IPv4MinimumSize,
- TotalLength: 100,
- TTL: 20,
- Protocol: 10,
- FragmentOffset: c.fragmentOffset,
- SrcAddr: localIpv4Addr,
- DstAddr: remoteIpv4Addr,
- })
-
- // Make payload be non-zero.
- for i := dataOffset; i < len(view); i++ {
- view[i] = uint8(i)
- }
-
- // Give packet to IPv4 endpoint, dispatcher will validate that
- // it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv4Addr
- o.dstAddr = localIpv4Addr
- o.contents = view[dataOffset:]
- o.typ = c.expectedTyp
- o.extra = c.expectedExtra
-
- vv := view[:len(view)-c.trunc].ToVectorisedView()
- ep.HandlePacket(&r, tcpip.PacketBuffer{
- Data: vv,
- })
- if want := c.expectedCount; o.controlCalls != want {
- t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
- }
- })
- }
-}
-
-func TestIPv4FragmentationReceive(t *testing.T) {
- o := testObject{t: t, v4: true}
- proto := ipv4.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack())
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- totalLen := header.IPv4MinimumSize + 24
-
- frag1 := buffer.NewView(totalLen)
- ip1 := header.IPv4(frag1)
- ip1.Encode(&header.IPv4Fields{
- IHL: header.IPv4MinimumSize,
- TotalLength: uint16(totalLen),
- TTL: 20,
- Protocol: 10,
- FragmentOffset: 0,
- Flags: header.IPv4FlagMoreFragments,
- SrcAddr: remoteIpv4Addr,
- DstAddr: localIpv4Addr,
- })
- // Make payload be non-zero.
- for i := header.IPv4MinimumSize; i < totalLen; i++ {
- frag1[i] = uint8(i)
- }
-
- frag2 := buffer.NewView(totalLen)
- ip2 := header.IPv4(frag2)
- ip2.Encode(&header.IPv4Fields{
- IHL: header.IPv4MinimumSize,
- TotalLength: uint16(totalLen),
- TTL: 20,
- Protocol: 10,
- FragmentOffset: 24,
- SrcAddr: remoteIpv4Addr,
- DstAddr: localIpv4Addr,
- })
- // Make payload be non-zero.
- for i := header.IPv4MinimumSize; i < totalLen; i++ {
- frag2[i] = uint8(i)
- }
-
- // Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv4Addr
- o.dstAddr = localIpv4Addr
- o.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...)
-
- r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr)
- if err != nil {
- t.Fatalf("could not find route: %v", err)
- }
-
- // Send first segment.
- ep.HandlePacket(&r, tcpip.PacketBuffer{
- Data: frag1.ToVectorisedView(),
- })
- if o.dataCalls != 0 {
- t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls)
- }
-
- // Send second segment.
- ep.HandlePacket(&r, tcpip.PacketBuffer{
- Data: frag2.ToVectorisedView(),
- })
- if o.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
- }
-}
-
-func TestIPv6Send(t *testing.T) {
- o := testObject{t: t}
- proto := ipv6.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, nil, &o, buildDummyStack())
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- // Allocate and initialize the payload view.
- payload := buffer.NewView(100)
- for i := 0; i < len(payload); i++ {
- payload[i] = uint8(i)
- }
-
- // Allocate the header buffer.
- hdr := buffer.NewPrependable(int(ep.MaxHeaderLength()))
-
- // Issue the write.
- o.protocol = 123
- o.srcAddr = localIpv6Addr
- o.dstAddr = remoteIpv6Addr
- o.contents = payload
-
- r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr)
- if err != nil {
- t.Fatalf("could not find route: %v", err)
- }
- if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
- Header: hdr,
- Data: payload.ToVectorisedView(),
- }); err != nil {
- t.Fatalf("WritePacket failed: %v", err)
- }
-}
-
-func TestIPv6Receive(t *testing.T) {
- o := testObject{t: t}
- proto := ipv6.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil, buildDummyStack())
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- totalLen := header.IPv6MinimumSize + 30
- view := buffer.NewView(totalLen)
- ip := header.IPv6(view)
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(totalLen - header.IPv6MinimumSize),
- NextHeader: 10,
- HopLimit: 20,
- SrcAddr: remoteIpv6Addr,
- DstAddr: localIpv6Addr,
- })
-
- // Make payload be non-zero.
- for i := header.IPv6MinimumSize; i < totalLen; i++ {
- view[i] = uint8(i)
- }
-
- // Give packet to ipv6 endpoint, dispatcher will validate that it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv6Addr
- o.dstAddr = localIpv6Addr
- o.contents = view[header.IPv6MinimumSize:totalLen]
-
- r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr)
- if err != nil {
- t.Fatalf("could not find route: %v", err)
- }
-
- ep.HandlePacket(&r, tcpip.PacketBuffer{
- Data: view.ToVectorisedView(),
- })
- if o.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
- }
-}
-
-func TestIPv6ReceiveControl(t *testing.T) {
- newUint16 := func(v uint16) *uint16 { return &v }
-
- const mtu = 0xffff
- const outerSrcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa"
- cases := []struct {
- name string
- expectedCount int
- fragmentOffset *uint16
- typ header.ICMPv6Type
- code uint8
- expectedTyp stack.ControlType
- expectedExtra uint32
- trunc int
- }{
- {"PacketTooBig", 1, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 0},
- {"Truncated (10 bytes missing)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 10},
- {"Truncated (missing IPv6 header)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, header.IPv6MinimumSize + 8},
- {"Truncated PacketTooBig (missing 'extra info')", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 4 + header.IPv6MinimumSize + 8},
- {"Truncated (missing ICMP header)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize + 8},
- {"Port unreachable", 1, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0},
- {"Truncated DstUnreachable (missing 'extra info')", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 4 + header.IPv6MinimumSize + 8},
- {"Fragmented, zero offset", 1, newUint16(0), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0},
- {"Non-zero fragment offset", 0, newUint16(100), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0},
- {"Zero-length packet", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + 8},
- }
- r, err := buildIPv6Route(
- localIpv6Addr,
- "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa",
- )
- if err != nil {
- t.Fatal(err)
- }
- for _, c := range cases {
- t.Run(c.name, func(t *testing.T) {
- o := testObject{t: t}
- proto := ipv6.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil, buildDummyStack())
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- defer ep.Close()
-
- dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize
- if c.fragmentOffset != nil {
- dataOffset += header.IPv6FragmentHeaderSize
- }
- view := buffer.NewView(dataOffset + 8)
-
- // Create the outer IPv6 header.
- ip := header.IPv6(view)
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 20,
- SrcAddr: outerSrcAddr,
- DstAddr: localIpv6Addr,
- })
-
- // Create the ICMP header.
- icmp := header.ICMPv6(view[header.IPv6MinimumSize:])
- icmp.SetType(c.typ)
- icmp.SetCode(c.code)
- icmp.SetIdent(0xdead)
- icmp.SetSequence(0xbeef)
-
- // Create the inner IPv6 header.
- ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:])
- ip.Encode(&header.IPv6Fields{
- PayloadLength: 100,
- NextHeader: 10,
- HopLimit: 20,
- SrcAddr: localIpv6Addr,
- DstAddr: remoteIpv6Addr,
- })
-
- // Build the fragmentation header if needed.
- if c.fragmentOffset != nil {
- ip.SetNextHeader(header.IPv6FragmentHeader)
- frag := header.IPv6Fragment(view[2*header.IPv6MinimumSize+header.ICMPv6MinimumSize:])
- frag.Encode(&header.IPv6FragmentFields{
- NextHeader: 10,
- FragmentOffset: *c.fragmentOffset,
- M: true,
- Identification: 0x12345678,
- })
- }
-
- // Make payload be non-zero.
- for i := dataOffset; i < len(view); i++ {
- view[i] = uint8(i)
- }
-
- // Give packet to IPv6 endpoint, dispatcher will validate that
- // it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv6Addr
- o.dstAddr = localIpv6Addr
- o.contents = view[dataOffset:]
- o.typ = c.expectedTyp
- o.extra = c.expectedExtra
-
- // Set ICMPv6 checksum.
- icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{}))
-
- ep.HandlePacket(&r, tcpip.PacketBuffer{
- Data: view[:len(view)-c.trunc].ToVectorisedView(),
- })
- if want := c.expectedCount; o.controlCalls != want {
- t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
- }
- })
- }
-}
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
deleted file mode 100644
index 0fef2b1f1..000000000
--- a/pkg/tcpip/network/ipv4/BUILD
+++ /dev/null
@@ -1,39 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "ipv4",
- srcs = [
- "icmp.go",
- "ipv4.go",
- ],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
- "//pkg/tcpip/network/fragmentation",
- "//pkg/tcpip/network/hash",
- "//pkg/tcpip/stack",
- ],
-)
-
-go_test(
- name = "ipv4_test",
- size = "small",
- srcs = ["ipv4_test.go"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/link/sniffer",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/transport/tcp",
- "//pkg/tcpip/transport/udp",
- "//pkg/waiter",
- ],
-)
diff --git a/pkg/tcpip/network/ipv4/ipv4_state_autogen.go b/pkg/tcpip/network/ipv4/ipv4_state_autogen.go
new file mode 100755
index 000000000..250b2128e
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/ipv4_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package ipv4
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
deleted file mode 100644
index e900f1b45..000000000
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ /dev/null
@@ -1,475 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package ipv4_test
-
-import (
- "bytes"
- "encoding/hex"
- "math/rand"
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-func TestExcludeBroadcast(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
- })
-
- const defaultMTU = 65536
- ep := stack.LinkEndpoint(channel.New(256, defaultMTU, ""))
- if testing.Verbose() {
- ep = sniffer.New(ep)
- }
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
-
- s.SetRouteTable([]tcpip.Route{{
- Destination: header.IPv4EmptySubnet,
- NIC: 1,
- }})
-
- randomAddr := tcpip.FullAddress{NIC: 1, Addr: "\x0a\x00\x00\x01", Port: 53}
-
- var wq waiter.Queue
- t.Run("WithoutPrimaryAddress", func(t *testing.T) {
- ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
- if err != nil {
- t.Fatal(err)
- }
- defer ep.Close()
-
- // Cannot connect using a broadcast address as the source.
- if err := ep.Connect(randomAddr); err != tcpip.ErrNoRoute {
- t.Errorf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrNoRoute)
- }
-
- // However, we can bind to a broadcast address to listen.
- if err := ep.Bind(tcpip.FullAddress{Addr: header.IPv4Broadcast, Port: 53, NIC: 1}); err != nil {
- t.Errorf("Bind failed: %v", err)
- }
- })
-
- t.Run("WithPrimaryAddress", func(t *testing.T) {
- ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
- if err != nil {
- t.Fatal(err)
- }
- defer ep.Close()
-
- // Add a valid primary endpoint address, now we can connect.
- if err := s.AddAddress(1, ipv4.ProtocolNumber, "\x0a\x00\x00\x02"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
- if err := ep.Connect(randomAddr); err != nil {
- t.Errorf("Connect failed: %v", err)
- }
- })
-}
-
-// makeHdrAndPayload generates a randomize packet. hdrLength indicates how much
-// data should already be in the header before WritePacket. extraLength
-// indicates how much extra space should be in the header. The payload is made
-// from many Views of the sizes listed in viewSizes.
-func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer.Prependable, buffer.VectorisedView) {
- hdr := buffer.NewPrependable(hdrLength + extraLength)
- hdr.Prepend(hdrLength)
- rand.Read(hdr.View())
-
- var views []buffer.View
- totalLength := 0
- for _, s := range viewSizes {
- newView := buffer.NewView(s)
- rand.Read(newView)
- views = append(views, newView)
- totalLength += s
- }
- payload := buffer.NewVectorisedView(totalLength, views)
- return hdr, payload
-}
-
-// comparePayloads compared the contents of all the packets against the contents
-// of the source packet.
-func compareFragments(t *testing.T, packets []tcpip.PacketBuffer, sourcePacketInfo tcpip.PacketBuffer, mtu uint32) {
- t.Helper()
- // Make a complete array of the sourcePacketInfo packet.
- source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize])
- source = append(source, sourcePacketInfo.Header.View()...)
- source = append(source, sourcePacketInfo.Data.ToView()...)
-
- // Make a copy of the IP header, which will be modified in some fields to make
- // an expected header.
- sourceCopy := header.IPv4(append(buffer.View(nil), source[:source.HeaderLength()]...))
- sourceCopy.SetChecksum(0)
- sourceCopy.SetFlagsFragmentOffset(0, 0)
- sourceCopy.SetTotalLength(0)
- var offset uint16
- // Build up an array of the bytes sent.
- var reassembledPayload []byte
- for i, packet := range packets {
- // Confirm that the packet is valid.
- allBytes := packet.Header.View().ToVectorisedView()
- allBytes.Append(packet.Data)
- ip := header.IPv4(allBytes.ToView())
- if !ip.IsValid(len(ip)) {
- t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip))
- }
- if got, want := ip.CalculateChecksum(), uint16(0xffff); got != want {
- t.Errorf("ip.CalculateChecksum() got %#x, want %#x", got, want)
- }
- if got, want := len(ip), int(mtu); got > want {
- t.Errorf("fragment is too large, got %d want %d", got, want)
- }
- if got, want := packet.Header.UsedLength(), sourcePacketInfo.Header.UsedLength()+header.IPv4MinimumSize; i == 0 && want < int(mtu) && got != want {
- t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want)
- }
- if got, want := packet.Header.AvailableLength(), sourcePacketInfo.Header.AvailableLength()-header.IPv4MinimumSize; got != want {
- t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want)
- }
- if i < len(packets)-1 {
- sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, offset)
- } else {
- sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, offset)
- }
- reassembledPayload = append(reassembledPayload, ip.Payload()...)
- offset += ip.TotalLength() - uint16(ip.HeaderLength())
- // Clear out the checksum and length from the ip because we can't compare
- // it.
- sourceCopy.SetTotalLength(uint16(len(ip)))
- sourceCopy.SetChecksum(0)
- sourceCopy.SetChecksum(^sourceCopy.CalculateChecksum())
- if !bytes.Equal(ip[:ip.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]) {
- t.Errorf("ip[:ip.HeaderLength()] got:\n%s\nwant:\n%s", hex.Dump(ip[:ip.HeaderLength()]), hex.Dump(sourceCopy[:sourceCopy.HeaderLength()]))
- }
- }
- expected := source[source.HeaderLength():]
- if !bytes.Equal(reassembledPayload, expected) {
- t.Errorf("reassembledPayload got:\n%s\nwant:\n%s", hex.Dump(reassembledPayload), hex.Dump(expected))
- }
-}
-
-type errorChannel struct {
- *channel.Endpoint
- Ch chan tcpip.PacketBuffer
- packetCollectorErrors []*tcpip.Error
-}
-
-// newErrorChannel creates a new errorChannel endpoint. Each call to WritePacket
-// will return successive errors from packetCollectorErrors until the list is
-// empty and then return nil each time.
-func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel {
- return &errorChannel{
- Endpoint: channel.New(size, mtu, linkAddr),
- Ch: make(chan tcpip.PacketBuffer, size),
- packetCollectorErrors: packetCollectorErrors,
- }
-}
-
-// Drain removes all outbound packets from the channel and counts them.
-func (e *errorChannel) Drain() int {
- c := 0
- for {
- select {
- case <-e.Ch:
- c++
- default:
- return c
- }
- }
-}
-
-// WritePacket stores outbound packets into the channel.
-func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
- select {
- case e.Ch <- pkt:
- default:
- }
-
- nextError := (*tcpip.Error)(nil)
- if len(e.packetCollectorErrors) > 0 {
- nextError = e.packetCollectorErrors[0]
- e.packetCollectorErrors = e.packetCollectorErrors[1:]
- }
- return nextError
-}
-
-type context struct {
- stack.Route
- linkEP *errorChannel
-}
-
-func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context {
- // Make the packet and write it.
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- })
- ep := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors)
- s.CreateNIC(1, ep)
- const (
- src = "\x10\x00\x00\x01"
- dst = "\x10\x00\x00\x02"
- )
- s.AddAddress(1, ipv4.ProtocolNumber, src)
- {
- subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast))
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{
- Destination: subnet,
- NIC: 1,
- }})
- }
- r, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("s.FindRoute got %v, want %v", err, nil)
- }
- return context{
- Route: r,
- linkEP: ep,
- }
-}
-
-func TestFragmentation(t *testing.T) {
- var manyPayloadViewsSizes [1000]int
- for i := range manyPayloadViewsSizes {
- manyPayloadViewsSizes[i] = 7
- }
- fragTests := []struct {
- description string
- mtu uint32
- gso *stack.GSO
- hdrLength int
- extraLength int
- payloadViewsSizes []int
- expectedFrags int
- }{
- {"NoFragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1},
- {"NoFragmentationWithBigHeader", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1},
- {"Fragmented", 800, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 2},
- {"FragmentedWithGsoNil", 800, nil, 0, header.IPv4MinimumSize, []int{1000}, 2},
- {"FragmentedWithManyViews", 300, &stack.GSO{}, 0, header.IPv4MinimumSize, manyPayloadViewsSizes[:], 25},
- {"FragmentedWithManyViewsAndPrependableBytes", 300, &stack.GSO{}, 0, header.IPv4MinimumSize + 55, manyPayloadViewsSizes[:], 25},
- {"FragmentedWithBigHeader", 800, &stack.GSO{}, 20, header.IPv4MinimumSize, []int{1000}, 2},
- {"FragmentedWithBigHeaderAndPrependableBytes", 800, &stack.GSO{}, 20, header.IPv4MinimumSize + 66, []int{1000}, 2},
- {"FragmentedWithMTUSmallerThanHeaderAndPrependableBytes", 300, &stack.GSO{}, 1000, header.IPv4MinimumSize + 77, []int{500}, 6},
- }
-
- for _, ft := range fragTests {
- t.Run(ft.description, func(t *testing.T) {
- hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes)
- source := tcpip.PacketBuffer{
- Header: hdr,
- // Save the source payload because WritePacket will modify it.
- Data: payload.Clone(nil),
- }
- c := buildContext(t, nil, ft.mtu)
- err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
- Header: hdr,
- Data: payload,
- })
- if err != nil {
- t.Errorf("err got %v, want %v", err, nil)
- }
-
- var results []tcpip.PacketBuffer
- L:
- for {
- select {
- case pi := <-c.linkEP.Ch:
- results = append(results, pi)
- default:
- break L
- }
- }
-
- if got, want := len(results), ft.expectedFrags; got != want {
- t.Errorf("len(result) got %d, want %d", got, want)
- }
- if got, want := len(results), int(c.Route.Stats().IP.PacketsSent.Value()); got != want {
- t.Errorf("no errors yet len(result) got %d, want %d", got, want)
- }
- compareFragments(t, results, source, ft.mtu)
- })
- }
-}
-
-// TestFragmentationErrors checks that errors are returned from write packet
-// correctly.
-func TestFragmentationErrors(t *testing.T) {
- fragTests := []struct {
- description string
- mtu uint32
- hdrLength int
- payloadViewsSizes []int
- packetCollectorErrors []*tcpip.Error
- }{
- {"NoFrag", 2000, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}},
- {"ErrorOnFirstFrag", 500, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}},
- {"ErrorOnSecondFrag", 500, 0, []int{1000}, []*tcpip.Error{nil, tcpip.ErrAborted}},
- {"ErrorOnFirstFragMTUSmallerThanHdr", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}},
- }
-
- for _, ft := range fragTests {
- t.Run(ft.description, func(t *testing.T) {
- hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes)
- c := buildContext(t, ft.packetCollectorErrors, ft.mtu)
- err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
- Header: hdr,
- Data: payload,
- })
- for i := 0; i < len(ft.packetCollectorErrors)-1; i++ {
- if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want {
- t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want)
- }
- }
- // We only need to check that last error because all the ones before are
- // nil.
- if got, want := err, ft.packetCollectorErrors[len(ft.packetCollectorErrors)-1]; got != want {
- t.Errorf("err got %v, want %v", got, want)
- }
- if got, want := c.linkEP.Drain(), int(c.Route.Stats().IP.PacketsSent.Value())+1; err != nil && got != want {
- t.Errorf("after linkEP error len(result) got %d, want %d", got, want)
- }
- })
- }
-}
-
-func TestInvalidFragments(t *testing.T) {
- // These packets have both IHL and TotalLength set to 0.
- testCases := []struct {
- name string
- packets [][]byte
- wantMalformedIPPackets uint64
- wantMalformedFragments uint64
- }{
- {
- "ihl_totallen_zero_valid_frag_offset",
- [][]byte{
- {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x7d, 0x30, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
- },
- 1,
- 0,
- },
- {
- "ihl_totallen_zero_invalid_frag_offset",
- [][]byte{
- {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x20, 0x00, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
- },
- 1,
- 0,
- },
- {
- // Total Length of 37(20 bytes IP header + 17 bytes of
- // payload)
- // Frag Offset of 0x1ffe = 8190*8 = 65520
- // Leading to the fragment end to be past 65535.
- "ihl_totallen_valid_invalid_frag_offset_1",
- [][]byte{
- {0x45, 0x30, 0x00, 0x25, 0x6c, 0x74, 0x1f, 0xfe, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
- },
- 1,
- 1,
- },
- // The following 3 tests were found by running a fuzzer and were
- // triggering a panic in the IPv4 reassembler code.
- {
- "ihl_less_than_ipv4_minimum_size_1",
- [][]byte{
- {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0x0, 0xf3, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
- {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
- },
- 2,
- 0,
- },
- {
- "ihl_less_than_ipv4_minimum_size_2",
- [][]byte{
- {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x12, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
- {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
- },
- 2,
- 0,
- },
- {
- "ihl_less_than_ipv4_minimum_size_3",
- [][]byte{
- {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x30, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
- {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
- },
- 2,
- 0,
- },
- {
- "fragment_with_short_total_len_extra_payload",
- [][]byte{
- {0x46, 0x30, 0x00, 0x30, 0x30, 0x40, 0x0e, 0x12, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
- {0x46, 0x30, 0x00, 0x18, 0x30, 0x40, 0x20, 0x00, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
- },
- 1,
- 1,
- },
- {
- "multiple_fragments_with_more_fragments_set_to_false",
- [][]byte{
- {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x10, 0x00, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
- {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x01, 0x61, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
- {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x20, 0x00, 0x00, 0x06, 0x34, 0x1e, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
- },
- 1,
- 1,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- const nicID tcpip.NICID = 42
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{
- ipv4.NewProtocol(),
- },
- })
-
- var linkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x30})
- var remoteLinkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x31})
- ep := channel.New(10, 1500, linkAddr)
- s.CreateNIC(nicID, sniffer.New(ep))
-
- for _, pkt := range tc.packets {
- ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, tcpip.PacketBuffer{
- Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}),
- })
- }
-
- if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), tc.wantMalformedIPPackets; got != want {
- t.Errorf("incorrect Stats.IP.MalformedPacketsReceived, got: %d, want: %d", got, want)
- }
- if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), tc.wantMalformedFragments; got != want {
- t.Errorf("incorrect Stats.IP.MalformedFragmentsReceived, got: %d, want: %d", got, want)
- }
- })
- }
-}
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
deleted file mode 100644
index fb11874c6..000000000
--- a/pkg/tcpip/network/ipv6/BUILD
+++ /dev/null
@@ -1,40 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "ipv6",
- srcs = [
- "icmp.go",
- "ipv6.go",
- ],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/stack",
- ],
-)
-
-go_test(
- name = "ipv6_test",
- size = "small",
- srcs = [
- "icmp_test.go",
- "ipv6_test.go",
- "ndp_test.go",
- ],
- library = ":ipv6",
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/link/sniffer",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/transport/icmp",
- "//pkg/tcpip/transport/udp",
- "//pkg/waiter",
- ],
-)
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
deleted file mode 100644
index 50c4b6474..000000000
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ /dev/null
@@ -1,958 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package ipv6
-
-import (
- "context"
- "reflect"
- "strings"
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-const (
- linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
- linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f")
-)
-
-var (
- lladdr0 = header.LinkLocalAddr(linkAddr0)
- lladdr1 = header.LinkLocalAddr(linkAddr1)
-)
-
-type stubLinkEndpoint struct {
- stack.LinkEndpoint
-}
-
-func (*stubLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
- return 0
-}
-
-func (*stubLinkEndpoint) MaxHeaderLength() uint16 {
- return 0
-}
-
-func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress {
- return ""
-}
-
-func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, tcpip.PacketBuffer) *tcpip.Error {
- return nil
-}
-
-func (*stubLinkEndpoint) Attach(stack.NetworkDispatcher) {}
-
-type stubDispatcher struct {
- stack.TransportDispatcher
-}
-
-func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, tcpip.PacketBuffer) {
-}
-
-type stubLinkAddressCache struct {
- stack.LinkAddressCache
-}
-
-func (*stubLinkAddressCache) CheckLocalAddress(tcpip.NICID, tcpip.NetworkProtocolNumber, tcpip.Address) tcpip.NICID {
- return 0
-}
-
-func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.LinkAddress) {
-}
-
-func TestICMPCounts(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
- })
- {
- if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
- }
- if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
- }
- }
- {
- subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable(
- []tcpip.Route{{
- Destination: subnet,
- NIC: 1,
- }},
- )
- }
-
- netProto := s.NetworkProtocolInstance(ProtocolNumber)
- if netProto == nil {
- t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
- }
- ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{lladdr1, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s)
- if err != nil {
- t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err)
- }
-
- r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err)
- }
- defer r.Release()
-
- var tllData [header.NDPLinkLayerAddressSize]byte
- header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
- header.NDPTargetLinkLayerAddressOption(linkAddr1),
- })
-
- types := []struct {
- typ header.ICMPv6Type
- size int
- extraData []byte
- }{
- {
- typ: header.ICMPv6DstUnreachable,
- size: header.ICMPv6DstUnreachableMinimumSize,
- },
- {
- typ: header.ICMPv6PacketTooBig,
- size: header.ICMPv6PacketTooBigMinimumSize,
- },
- {
- typ: header.ICMPv6TimeExceeded,
- size: header.ICMPv6MinimumSize,
- },
- {
- typ: header.ICMPv6ParamProblem,
- size: header.ICMPv6MinimumSize,
- },
- {
- typ: header.ICMPv6EchoRequest,
- size: header.ICMPv6EchoMinimumSize,
- },
- {
- typ: header.ICMPv6EchoReply,
- size: header.ICMPv6EchoMinimumSize,
- },
- {
- typ: header.ICMPv6RouterSolicit,
- size: header.ICMPv6MinimumSize,
- },
- {
- typ: header.ICMPv6RouterAdvert,
- size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
- },
- {
- typ: header.ICMPv6NeighborSolicit,
- size: header.ICMPv6NeighborSolicitMinimumSize},
- {
- typ: header.ICMPv6NeighborAdvert,
- size: header.ICMPv6NeighborAdvertMinimumSize,
- extraData: tllData[:],
- },
- {
- typ: header.ICMPv6RedirectMsg,
- size: header.ICMPv6MinimumSize,
- },
- }
-
- handleIPv6Payload := func(hdr buffer.Prependable) {
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: r.LocalAddress,
- DstAddr: r.RemoteAddress,
- })
- ep.HandlePacket(&r, tcpip.PacketBuffer{
- Data: hdr.View().ToVectorisedView(),
- })
- }
-
- for _, typ := range types {
- extraDataLen := len(typ.extraData)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen)
- extraData := buffer.View(hdr.Prepend(extraDataLen))
- copy(extraData, typ.extraData)
- pkt := header.ICMPv6(hdr.Prepend(typ.size))
- pkt.SetType(typ.typ)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView()))
-
- handleIPv6Payload(hdr)
- }
-
- // Construct an empty ICMP packet so that
- // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented.
- handleIPv6Payload(buffer.NewPrependable(header.IPv6MinimumSize))
-
- icmpv6Stats := s.Stats().ICMP.V6PacketsReceived
- visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) {
- if got, want := s.Value(), uint64(1); got != want {
- t.Errorf("got %s = %d, want = %d", name, got, want)
- }
- })
- if t.Failed() {
- t.Logf("stats:\n%+v", s.Stats())
- }
-}
-
-func visitStats(v reflect.Value, f func(string, *tcpip.StatCounter)) {
- t := v.Type()
- for i := 0; i < v.NumField(); i++ {
- v := v.Field(i)
- if s, ok := v.Interface().(*tcpip.StatCounter); ok {
- f(t.Field(i).Name, s)
- } else {
- visitStats(v, f)
- }
- }
-}
-
-type testContext struct {
- s0 *stack.Stack
- s1 *stack.Stack
-
- linkEP0 *channel.Endpoint
- linkEP1 *channel.Endpoint
-}
-
-type endpointWithResolutionCapability struct {
- stack.LinkEndpoint
-}
-
-func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapabilities {
- return e.LinkEndpoint.Capabilities() | stack.CapabilityResolutionRequired
-}
-
-func newTestContext(t *testing.T) *testContext {
- c := &testContext{
- s0: stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
- }),
- s1: stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
- }),
- }
-
- const defaultMTU = 65536
- c.linkEP0 = channel.New(256, defaultMTU, linkAddr0)
-
- wrappedEP0 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP0})
- if testing.Verbose() {
- wrappedEP0 = sniffer.New(wrappedEP0)
- }
- if err := c.s0.CreateNIC(1, wrappedEP0); err != nil {
- t.Fatalf("CreateNIC s0: %v", err)
- }
- if err := c.s0.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress lladdr0: %v", err)
- }
-
- c.linkEP1 = channel.New(256, defaultMTU, linkAddr1)
- wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1})
- if err := c.s1.CreateNIC(1, wrappedEP1); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
- if err := c.s1.AddAddress(1, ProtocolNumber, lladdr1); err != nil {
- t.Fatalf("AddAddress lladdr1: %v", err)
- }
-
- subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
- if err != nil {
- t.Fatal(err)
- }
- c.s0.SetRouteTable(
- []tcpip.Route{{
- Destination: subnet0,
- NIC: 1,
- }},
- )
- subnet1, err := tcpip.NewSubnet(lladdr0, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr0))))
- if err != nil {
- t.Fatal(err)
- }
- c.s1.SetRouteTable(
- []tcpip.Route{{
- Destination: subnet1,
- NIC: 1,
- }},
- )
-
- return c
-}
-
-func (c *testContext) cleanup() {
- c.linkEP0.Close()
- c.linkEP1.Close()
-}
-
-type routeArgs struct {
- src, dst *channel.Endpoint
- typ header.ICMPv6Type
- remoteLinkAddr tcpip.LinkAddress
-}
-
-func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) {
- t.Helper()
-
- pi, _ := args.src.ReadContext(context.Background())
-
- {
- views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()}
- size := pi.Pkt.Header.UsedLength() + pi.Pkt.Data.Size()
- vv := buffer.NewVectorisedView(size, views)
- args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), tcpip.PacketBuffer{
- Data: vv,
- })
- }
-
- if pi.Proto != ProtocolNumber {
- t.Errorf("unexpected protocol number %d", pi.Proto)
- return
- }
-
- if len(args.remoteLinkAddr) != 0 && args.remoteLinkAddr != pi.Route.RemoteLinkAddress {
- t.Errorf("got remote link address = %s, want = %s", pi.Route.RemoteLinkAddress, args.remoteLinkAddr)
- }
-
- ipv6 := header.IPv6(pi.Pkt.Header.View())
- transProto := tcpip.TransportProtocolNumber(ipv6.NextHeader())
- if transProto != header.ICMPv6ProtocolNumber {
- t.Errorf("unexpected transport protocol number %d", transProto)
- return
- }
- icmpv6 := header.ICMPv6(ipv6.Payload())
- if got, want := icmpv6.Type(), args.typ; got != want {
- t.Errorf("got ICMPv6 type = %d, want = %d", got, want)
- return
- }
- if fn != nil {
- fn(t, icmpv6)
- }
-}
-
-func TestLinkResolution(t *testing.T) {
- c := newTestContext(t)
- defer c.cleanup()
-
- r, err := c.s0.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err)
- }
- defer r.Release()
-
- hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6EchoMinimumSize)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
- pkt.SetType(header.ICMPv6EchoRequest)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
- payload := tcpip.SlicePayload(hdr.View())
-
- // We can't send our payload directly over the route because that
- // doesn't provoke NDP discovery.
- var wq waiter.Queue
- ep, err := c.s0.NewEndpoint(header.ICMPv6ProtocolNumber, ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err)
- }
-
- for {
- _, resCh, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: 1, Addr: lladdr1}})
- if resCh != nil {
- if err != tcpip.ErrNoLinkAddress {
- t.Fatalf("ep.Write(_) = _, <non-nil>, %s, want = _, <non-nil>, tcpip.ErrNoLinkAddress", err)
- }
- for _, args := range []routeArgs{
- {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))},
- {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6NeighborAdvert},
- } {
- routeICMPv6Packet(t, args, func(t *testing.T, icmpv6 header.ICMPv6) {
- if got, want := tcpip.Address(icmpv6[8:][:16]), lladdr1; got != want {
- t.Errorf("%d: got target = %s, want = %s", icmpv6.Type(), got, want)
- }
- })
- }
- <-resCh
- continue
- }
- if err != nil {
- t.Fatalf("ep.Write(_) = _, _, %s", err)
- }
- break
- }
-
- for _, args := range []routeArgs{
- {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6EchoRequest},
- {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6EchoReply},
- } {
- routeICMPv6Packet(t, args, nil)
- }
-}
-
-func TestICMPChecksumValidationSimple(t *testing.T) {
- var tllData [header.NDPLinkLayerAddressSize]byte
- header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
- header.NDPTargetLinkLayerAddressOption(linkAddr1),
- })
-
- types := []struct {
- name string
- typ header.ICMPv6Type
- size int
- extraData []byte
- statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
- }{
- {
- name: "DstUnreachable",
- typ: header.ICMPv6DstUnreachable,
- size: header.ICMPv6DstUnreachableMinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.DstUnreachable
- },
- },
- {
- name: "PacketTooBig",
- typ: header.ICMPv6PacketTooBig,
- size: header.ICMPv6PacketTooBigMinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.PacketTooBig
- },
- },
- {
- name: "TimeExceeded",
- typ: header.ICMPv6TimeExceeded,
- size: header.ICMPv6MinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.TimeExceeded
- },
- },
- {
- name: "ParamProblem",
- typ: header.ICMPv6ParamProblem,
- size: header.ICMPv6MinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.ParamProblem
- },
- },
- {
- name: "EchoRequest",
- typ: header.ICMPv6EchoRequest,
- size: header.ICMPv6EchoMinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.EchoRequest
- },
- },
- {
- name: "EchoReply",
- typ: header.ICMPv6EchoReply,
- size: header.ICMPv6EchoMinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.EchoReply
- },
- },
- {
- name: "RouterSolicit",
- typ: header.ICMPv6RouterSolicit,
- size: header.ICMPv6MinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RouterSolicit
- },
- },
- {
- name: "RouterAdvert",
- typ: header.ICMPv6RouterAdvert,
- size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RouterAdvert
- },
- },
- {
- name: "NeighborSolicit",
- typ: header.ICMPv6NeighborSolicit,
- size: header.ICMPv6NeighborSolicitMinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.NeighborSolicit
- },
- },
- {
- name: "NeighborAdvert",
- typ: header.ICMPv6NeighborAdvert,
- size: header.ICMPv6NeighborAdvertMinimumSize,
- extraData: tllData[:],
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.NeighborAdvert
- },
- },
- {
- name: "RedirectMsg",
- typ: header.ICMPv6RedirectMsg,
- size: header.ICMPv6MinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RedirectMsg
- },
- },
- }
-
- for _, typ := range types {
- t.Run(typ.name, func(t *testing.T) {
- e := channel.New(10, 1280, linkAddr0)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- })
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
- }
-
- if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
- }
- {
- subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable(
- []tcpip.Route{{
- Destination: subnet,
- NIC: 1,
- }},
- )
- }
-
- handleIPv6Payload := func(checksum bool) {
- extraDataLen := len(typ.extraData)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen)
- extraData := buffer.View(hdr.Prepend(extraDataLen))
- copy(extraData, typ.extraData)
- pkt := header.ICMPv6(hdr.Prepend(typ.size))
- pkt.SetType(typ.typ)
- if checksum {
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, extraData.ToVectorisedView()))
- }
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(typ.size + extraDataLen),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
- })
- e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
- Data: hdr.View().ToVectorisedView(),
- })
- }
-
- stats := s.Stats().ICMP.V6PacketsReceived
- invalid := stats.Invalid
- typStat := typ.statCounter(stats)
-
- // Initial stat counts should be 0.
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
- if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
- }
-
- // Without setting checksum, the incoming packet should
- // be invalid.
- handleIPv6Payload(false)
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
- // Rx count of type typ.typ should not have increased.
- if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
- }
-
- // When checksum is set, it should be received.
- handleIPv6Payload(true)
- if got := typStat.Value(); got != 1 {
- t.Fatalf("got %s = %d, want = 1", typ.name, got)
- }
- // Invalid count should not have increased again.
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
- })
- }
-}
-
-func TestICMPChecksumValidationWithPayload(t *testing.T) {
- const simpleBodySize = 64
- simpleBody := func(view buffer.View) {
- for i := 0; i < simpleBodySize; i++ {
- view[i] = uint8(i)
- }
- }
-
- const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize
- errorICMPBody := func(view buffer.View) {
- ip := header.IPv6(view)
- ip.Encode(&header.IPv6Fields{
- PayloadLength: simpleBodySize,
- NextHeader: 10,
- HopLimit: 20,
- SrcAddr: lladdr0,
- DstAddr: lladdr1,
- })
- simpleBody(view[header.IPv6MinimumSize:])
- }
-
- types := []struct {
- name string
- typ header.ICMPv6Type
- size int
- statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
- payloadSize int
- payload func(buffer.View)
- }{
- {
- "DstUnreachable",
- header.ICMPv6DstUnreachable,
- header.ICMPv6DstUnreachableMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.DstUnreachable
- },
- errorICMPBodySize,
- errorICMPBody,
- },
- {
- "PacketTooBig",
- header.ICMPv6PacketTooBig,
- header.ICMPv6PacketTooBigMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.PacketTooBig
- },
- errorICMPBodySize,
- errorICMPBody,
- },
- {
- "TimeExceeded",
- header.ICMPv6TimeExceeded,
- header.ICMPv6MinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.TimeExceeded
- },
- errorICMPBodySize,
- errorICMPBody,
- },
- {
- "ParamProblem",
- header.ICMPv6ParamProblem,
- header.ICMPv6MinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.ParamProblem
- },
- errorICMPBodySize,
- errorICMPBody,
- },
- {
- "EchoRequest",
- header.ICMPv6EchoRequest,
- header.ICMPv6EchoMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.EchoRequest
- },
- simpleBodySize,
- simpleBody,
- },
- {
- "EchoReply",
- header.ICMPv6EchoReply,
- header.ICMPv6EchoMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.EchoReply
- },
- simpleBodySize,
- simpleBody,
- },
- }
-
- for _, typ := range types {
- t.Run(typ.name, func(t *testing.T) {
- e := channel.New(10, 1280, linkAddr0)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- })
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
- }
-
- if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
- }
- {
- subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable(
- []tcpip.Route{{
- Destination: subnet,
- NIC: 1,
- }},
- )
- }
-
- handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) {
- icmpSize := size + payloadSize
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
- pkt := header.ICMPv6(hdr.Prepend(icmpSize))
- pkt.SetType(typ)
- payloadFn(pkt.Payload())
-
- if checksum {
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
- }
-
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(icmpSize),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
- })
- e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
- Data: hdr.View().ToVectorisedView(),
- })
- }
-
- stats := s.Stats().ICMP.V6PacketsReceived
- invalid := stats.Invalid
- typStat := typ.statCounter(stats)
-
- // Initial stat counts should be 0.
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
- if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
- }
-
- // Without setting checksum, the incoming packet should
- // be invalid.
- handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false)
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
- // Rx count of type typ.typ should not have increased.
- if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
- }
-
- // When checksum is set, it should be received.
- handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true)
- if got := typStat.Value(); got != 1 {
- t.Fatalf("got %s = %d, want = 1", typ.name, got)
- }
- // Invalid count should not have increased again.
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
- })
- }
-}
-
-func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
- const simpleBodySize = 64
- simpleBody := func(view buffer.View) {
- for i := 0; i < simpleBodySize; i++ {
- view[i] = uint8(i)
- }
- }
-
- const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize
- errorICMPBody := func(view buffer.View) {
- ip := header.IPv6(view)
- ip.Encode(&header.IPv6Fields{
- PayloadLength: simpleBodySize,
- NextHeader: 10,
- HopLimit: 20,
- SrcAddr: lladdr0,
- DstAddr: lladdr1,
- })
- simpleBody(view[header.IPv6MinimumSize:])
- }
-
- types := []struct {
- name string
- typ header.ICMPv6Type
- size int
- statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
- payloadSize int
- payload func(buffer.View)
- }{
- {
- "DstUnreachable",
- header.ICMPv6DstUnreachable,
- header.ICMPv6DstUnreachableMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.DstUnreachable
- },
- errorICMPBodySize,
- errorICMPBody,
- },
- {
- "PacketTooBig",
- header.ICMPv6PacketTooBig,
- header.ICMPv6PacketTooBigMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.PacketTooBig
- },
- errorICMPBodySize,
- errorICMPBody,
- },
- {
- "TimeExceeded",
- header.ICMPv6TimeExceeded,
- header.ICMPv6MinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.TimeExceeded
- },
- errorICMPBodySize,
- errorICMPBody,
- },
- {
- "ParamProblem",
- header.ICMPv6ParamProblem,
- header.ICMPv6MinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.ParamProblem
- },
- errorICMPBodySize,
- errorICMPBody,
- },
- {
- "EchoRequest",
- header.ICMPv6EchoRequest,
- header.ICMPv6EchoMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.EchoRequest
- },
- simpleBodySize,
- simpleBody,
- },
- {
- "EchoReply",
- header.ICMPv6EchoReply,
- header.ICMPv6EchoMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.EchoReply
- },
- simpleBodySize,
- simpleBody,
- },
- }
-
- for _, typ := range types {
- t.Run(typ.name, func(t *testing.T) {
- e := channel.New(10, 1280, linkAddr0)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- })
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
- }
-
- if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
- }
- {
- subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable(
- []tcpip.Route{{
- Destination: subnet,
- NIC: 1,
- }},
- )
- }
-
- handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + size)
- pkt := header.ICMPv6(hdr.Prepend(size))
- pkt.SetType(typ)
-
- payload := buffer.NewView(payloadSize)
- payloadFn(payload)
-
- if checksum {
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, payload.ToVectorisedView()))
- }
-
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(size + payloadSize),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
- })
- e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
- Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}),
- })
- }
-
- stats := s.Stats().ICMP.V6PacketsReceived
- invalid := stats.Invalid
- typStat := typ.statCounter(stats)
-
- // Initial stat counts should be 0.
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
- if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
- }
-
- // Without setting checksum, the incoming packet should
- // be invalid.
- handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false)
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
- // Rx count of type typ.typ should not have increased.
- if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
- }
-
- // When checksum is set, it should be received.
- handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true)
- if got := typStat.Value(); got != 1 {
- t.Fatalf("got %s = %d, want = 1", typ.name, got)
- }
- // Invalid count should not have increased again.
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
- })
- }
-}
diff --git a/pkg/tcpip/network/ipv6/ipv6_state_autogen.go b/pkg/tcpip/network/ipv6/ipv6_state_autogen.go
new file mode 100755
index 000000000..40c67d440
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/ipv6_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package ipv6
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
deleted file mode 100644
index 1cbfa7278..000000000
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ /dev/null
@@ -1,270 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package ipv6
-
-import (
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-const (
- addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- // The least significant 3 bytes are the same as addr2 so both addr2 and
- // addr3 will have the same solicited-node address.
- addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02"
-)
-
-// testReceiveICMP tests receiving an ICMP packet from src to dst. want is the
-// expected Neighbor Advertisement received count after receiving the packet.
-func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) {
- t.Helper()
-
- // Receive ICMP packet.
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
- pkt.SetType(header.ICMPv6NeighborAdvert)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, dst, buffer.VectorisedView{}))
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: src,
- DstAddr: dst,
- })
-
- e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
- Data: hdr.View().ToVectorisedView(),
- })
-
- stats := s.Stats().ICMP.V6PacketsReceived
-
- if got := stats.NeighborAdvert.Value(); got != want {
- t.Fatalf("got NeighborAdvert = %d, want = %d", got, want)
- }
-}
-
-// testReceiveUDP tests receiving a UDP packet from src to dst. want is the
-// expected UDP received count after receiving the packet.
-func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) {
- t.Helper()
-
- wq := waiter.Queue{}
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.EventIn)
- defer wq.EventUnregister(&we)
- defer close(ch)
-
- ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
- defer ep.Close()
-
- if err := ep.Bind(tcpip.FullAddress{Addr: dst, Port: 80}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %v", err)
- }
-
- // Receive UDP Packet.
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize)
- u := header.UDP(hdr.Prepend(header.UDPMinimumSize))
- u.Encode(&header.UDPFields{
- SrcPort: 5555,
- DstPort: 80,
- Length: header.UDPMinimumSize,
- })
-
- // UDP pseudo-header checksum.
- sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, header.UDPMinimumSize)
-
- // UDP checksum
- sum = header.Checksum(header.UDP([]byte{}), sum)
- u.SetChecksum(^u.CalculateChecksum(sum))
-
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(udp.ProtocolNumber),
- HopLimit: 255,
- SrcAddr: src,
- DstAddr: dst,
- })
-
- e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
- Data: hdr.View().ToVectorisedView(),
- })
-
- stat := s.Stats().UDP.PacketsReceived
-
- if got := stat.Value(); got != want {
- t.Fatalf("got UDPPacketsReceived = %d, want = %d", got, want)
- }
-}
-
-// TestReceiveOnAllNodesMulticastAddr tests that IPv6 endpoints receive ICMP and
-// UDP packets destined to the IPv6 link-local all-nodes multicast address.
-func TestReceiveOnAllNodesMulticastAddr(t *testing.T) {
- tests := []struct {
- name string
- protocolFactory stack.TransportProtocol
- rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64)
- }{
- {"ICMP", icmp.NewProtocol6(), testReceiveICMP},
- {"UDP", udp.NewProtocol(), testReceiveUDP},
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{test.protocolFactory},
- })
- e := channel.New(10, 1280, linkAddr1)
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
- }
-
- // Should receive a packet destined to the all-nodes
- // multicast address.
- test.rxf(t, s, e, addr1, header.IPv6AllNodesMulticastAddress, 1)
- })
- }
-}
-
-// TestReceiveOnSolicitedNodeAddr tests that IPv6 endpoints receive ICMP and UDP
-// packets destined to the IPv6 solicited-node address of an assigned IPv6
-// address.
-func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
- tests := []struct {
- name string
- protocolFactory stack.TransportProtocol
- rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64)
- }{
- {"ICMP", icmp.NewProtocol6(), testReceiveICMP},
- {"UDP", udp.NewProtocol(), testReceiveUDP},
- }
-
- snmc := header.SolicitedNodeAddr(addr2)
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{test.protocolFactory},
- })
- e := channel.New(10, 1280, linkAddr1)
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
- }
-
- // Should not receive a packet destined to the solicited
- // node address of addr2/addr3 yet as we haven't added
- // those addresses.
- test.rxf(t, s, e, addr1, snmc, 0)
-
- if err := s.AddAddress(1, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr2, err)
- }
-
- // Should receive a packet destined to the solicited
- // node address of addr2/addr3 now that we have added
- // added addr2.
- test.rxf(t, s, e, addr1, snmc, 1)
-
- if err := s.AddAddress(1, ProtocolNumber, addr3); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr3, err)
- }
-
- // Should still receive a packet destined to the
- // solicited node address of addr2/addr3 now that we
- // have added addr3.
- test.rxf(t, s, e, addr1, snmc, 2)
-
- if err := s.RemoveAddress(1, addr2); err != nil {
- t.Fatalf("RemoveAddress(_, %s) = %s", addr2, err)
- }
-
- // Should still receive a packet destined to the
- // solicited node address of addr2/addr3 now that we
- // have removed addr2.
- test.rxf(t, s, e, addr1, snmc, 3)
-
- if err := s.RemoveAddress(1, addr3); err != nil {
- t.Fatalf("RemoveAddress(_, %s) = %s", addr3, err)
- }
-
- // Should not receive a packet destined to the solicited
- // node address of addr2/addr3 yet as both of them got
- // removed.
- test.rxf(t, s, e, addr1, snmc, 3)
- })
- }
-}
-
-// TestAddIpv6Address tests adding IPv6 addresses.
-func TestAddIpv6Address(t *testing.T) {
- tests := []struct {
- name string
- addr tcpip.Address
- }{
- // This test is in response to b/140943433.
- {
- "Nil",
- tcpip.Address([]byte(nil)),
- },
- {
- "ValidUnicast",
- addr1,
- },
- {
- "ValidLinkLocalUnicast",
- lladdr0,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- })
- if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
- }
-
- if err := s.AddAddress(1, ProtocolNumber, test.addr); err != nil {
- t.Fatalf("AddAddress(_, %d, nil) = %s", ProtocolNumber, err)
- }
-
- addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
- }
- if addr.Address != test.addr {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr.Address, test.addr)
- }
- })
- }
-}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
deleted file mode 100644
index c9395de52..000000000
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ /dev/null
@@ -1,613 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package ipv6
-
-import (
- "strings"
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
-)
-
-// setupStackAndEndpoint creates a stack with a single NIC with a link-local
-// address llladdr and an IPv6 endpoint to a remote with link-local address
-// rlladdr
-func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) {
- t.Helper()
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
- })
-
- if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
- }
- if err := s.AddAddress(1, ProtocolNumber, llladdr); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, llladdr, err)
- }
-
- {
- subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr))))
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable(
- []tcpip.Route{{
- Destination: subnet,
- NIC: 1,
- }},
- )
- }
-
- netProto := s.NetworkProtocolInstance(ProtocolNumber)
- if netProto == nil {
- t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
- }
-
- ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{rlladdr, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s)
- if err != nil {
- t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err)
- }
-
- return s, ep
-}
-
-// TestNeighorSolicitationWithSourceLinkLayerOption tests that receiving a
-// valid NDP NS message with the Source Link Layer Address option results in a
-// new entry in the link address cache for the sender of the message.
-func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
- const nicID = 1
-
- tests := []struct {
- name string
- optsBuf []byte
- expectedLinkAddr tcpip.LinkAddress
- }{
- {
- name: "Valid",
- optsBuf: []byte{1, 1, 2, 3, 4, 5, 6, 7},
- expectedLinkAddr: "\x02\x03\x04\x05\x06\x07",
- },
- {
- name: "Too Small",
- optsBuf: []byte{1, 1, 2, 3, 4, 5, 6},
- },
- {
- name: "Invalid Length",
- optsBuf: []byte{1, 2, 2, 3, 4, 5, 6, 7},
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- })
- e := channel.New(0, 1280, linkAddr0)
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
- }
-
- ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + len(test.optsBuf)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
- pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
- pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
- ns.SetTargetAddress(lladdr0)
- opts := ns.Options()
- copy(opts, test.optsBuf)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
- })
-
- invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
-
- // Invalid count should initially be 0.
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
-
- e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
- Data: hdr.View().ToVectorisedView(),
- })
-
- linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil)
- if linkAddr != test.expectedLinkAddr {
- t.Errorf("got link address = %s, want = %s", linkAddr, test.expectedLinkAddr)
- }
-
- if test.expectedLinkAddr != "" {
- if err != nil {
- t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err)
- }
- if c != nil {
- t.Errorf("got unexpected channel")
- }
-
- // Invalid count should not have increased.
- if got := invalid.Value(); got != 0 {
- t.Errorf("got invalid = %d, want = 0", got)
- }
- } else {
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock)
- }
- if c == nil {
- t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber)
- }
-
- // Invalid count should have increased.
- if got := invalid.Value(); got != 1 {
- t.Errorf("got invalid = %d, want = 1", got)
- }
- }
- })
- }
-}
-
-// TestNeighorAdvertisementWithTargetLinkLayerOption tests that receiving a
-// valid NDP NA message with the Target Link Layer Address option results in a
-// new entry in the link address cache for the target of the message.
-func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
- const nicID = 1
-
- tests := []struct {
- name string
- optsBuf []byte
- expectedLinkAddr tcpip.LinkAddress
- }{
- {
- name: "Valid",
- optsBuf: []byte{2, 1, 2, 3, 4, 5, 6, 7},
- expectedLinkAddr: "\x02\x03\x04\x05\x06\x07",
- },
- {
- name: "Too Small",
- optsBuf: []byte{2, 1, 2, 3, 4, 5, 6},
- },
- {
- name: "Invalid Length",
- optsBuf: []byte{2, 2, 2, 3, 4, 5, 6, 7},
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- })
- e := channel.New(0, 1280, linkAddr0)
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
- }
-
- ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + len(test.optsBuf)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
- pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
- pkt.SetType(header.ICMPv6NeighborAdvert)
- ns := header.NDPNeighborAdvert(pkt.NDPPayload())
- ns.SetTargetAddress(lladdr1)
- opts := ns.Options()
- copy(opts, test.optsBuf)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
- })
-
- invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
-
- // Invalid count should initially be 0.
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
-
- e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
- Data: hdr.View().ToVectorisedView(),
- })
-
- linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil)
- if linkAddr != test.expectedLinkAddr {
- t.Errorf("got link address = %s, want = %s", linkAddr, test.expectedLinkAddr)
- }
-
- if test.expectedLinkAddr != "" {
- if err != nil {
- t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err)
- }
- if c != nil {
- t.Errorf("got unexpected channel")
- }
-
- // Invalid count should not have increased.
- if got := invalid.Value(); got != 0 {
- t.Errorf("got invalid = %d, want = 0", got)
- }
- } else {
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock)
- }
- if c == nil {
- t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber)
- }
-
- // Invalid count should have increased.
- if got := invalid.Value(); got != 1 {
- t.Errorf("got invalid = %d, want = 1", got)
- }
- }
- })
- }
-}
-
-// TestHopLimitValidation is a test that makes sure that NDP packets are only
-// received if their IP header's hop limit is set to 255.
-func TestHopLimitValidation(t *testing.T) {
- setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) {
- t.Helper()
-
- // Create a stack with the assigned link-local address lladdr0
- // and an endpoint to lladdr1.
- s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1)
-
- r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err)
- }
-
- return s, ep, r
- }
-
- handleIPv6Payload := func(hdr buffer.Prependable, hopLimit uint8, ep stack.NetworkEndpoint, r *stack.Route) {
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: hopLimit,
- SrcAddr: r.LocalAddress,
- DstAddr: r.RemoteAddress,
- })
- ep.HandlePacket(r, tcpip.PacketBuffer{
- Data: hdr.View().ToVectorisedView(),
- })
- }
-
- var tllData [header.NDPLinkLayerAddressSize]byte
- header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
- header.NDPTargetLinkLayerAddressOption(linkAddr1),
- })
-
- types := []struct {
- name string
- typ header.ICMPv6Type
- size int
- extraData []byte
- statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
- }{
- {
- name: "RouterSolicit",
- typ: header.ICMPv6RouterSolicit,
- size: header.ICMPv6MinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RouterSolicit
- },
- },
- {
- name: "RouterAdvert",
- typ: header.ICMPv6RouterAdvert,
- size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RouterAdvert
- },
- },
- {
- name: "NeighborSolicit",
- typ: header.ICMPv6NeighborSolicit,
- size: header.ICMPv6NeighborSolicitMinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.NeighborSolicit
- },
- },
- {
- name: "NeighborAdvert",
- typ: header.ICMPv6NeighborAdvert,
- size: header.ICMPv6NeighborAdvertMinimumSize,
- extraData: tllData[:],
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.NeighborAdvert
- },
- },
- {
- name: "RedirectMsg",
- typ: header.ICMPv6RedirectMsg,
- size: header.ICMPv6MinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RedirectMsg
- },
- },
- }
-
- for _, typ := range types {
- t.Run(typ.name, func(t *testing.T) {
- s, ep, r := setup(t)
- defer r.Release()
-
- stats := s.Stats().ICMP.V6PacketsReceived
- invalid := stats.Invalid
- typStat := typ.statCounter(stats)
-
- extraDataLen := len(typ.extraData)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen)
- extraData := buffer.View(hdr.Prepend(extraDataLen))
- copy(extraData, typ.extraData)
- pkt := header.ICMPv6(hdr.Prepend(typ.size))
- pkt.SetType(typ.typ)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView()))
-
- // Invalid count should initially be 0.
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
-
- // Should not have received any ICMPv6 packets with
- // type = typ.typ.
- if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
- }
-
- // Receive the NDP packet with an invalid hop limit
- // value.
- handleIPv6Payload(hdr, header.NDPHopLimit-1, ep, &r)
-
- // Invalid count should have increased.
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
-
- // Rx count of NDP packet of type typ.typ should not
- // have increased.
- if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
- }
-
- // Receive the NDP packet with a valid hop limit value.
- handleIPv6Payload(hdr, header.NDPHopLimit, ep, &r)
-
- // Rx count of NDP packet of type typ.typ should have
- // increased.
- if got := typStat.Value(); got != 1 {
- t.Fatalf("got %s = %d, want = 1", typ.name, got)
- }
-
- // Invalid count should not have increased again.
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
- })
- }
-}
-
-// TestRouterAdvertValidation tests that when the NIC is configured to handle
-// NDP Router Advertisement packets, it validates the Router Advertisement
-// properly before handling them.
-func TestRouterAdvertValidation(t *testing.T) {
- tests := []struct {
- name string
- src tcpip.Address
- hopLimit uint8
- code uint8
- ndpPayload []byte
- expectedSuccess bool
- }{
- {
- "OK",
- lladdr0,
- 255,
- 0,
- []byte{
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- },
- true,
- },
- {
- "NonLinkLocalSourceAddr",
- addr1,
- 255,
- 0,
- []byte{
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- },
- false,
- },
- {
- "HopLimitNot255",
- lladdr0,
- 254,
- 0,
- []byte{
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- },
- false,
- },
- {
- "NonZeroCode",
- lladdr0,
- 255,
- 1,
- []byte{
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- },
- false,
- },
- {
- "NDPPayloadTooSmall",
- lladdr0,
- 255,
- 0,
- []byte{
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- 0, 0, 0,
- },
- false,
- },
- {
- "OKWithOptions",
- lladdr0,
- 255,
- 0,
- []byte{
- // RA payload
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- 0, 0, 0, 0,
-
- // Option #1 (TargetLinkLayerAddress)
- 2, 1, 0, 0, 0, 0, 0, 0,
-
- // Option #2 (unrecognized)
- 255, 1, 0, 0, 0, 0, 0, 0,
-
- // Option #3 (PrefixInformation)
- 3, 4, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- },
- true,
- },
- {
- "OptionWithZeroLength",
- lladdr0,
- 255,
- 0,
- []byte{
- // RA payload
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- 0, 0, 0, 0,
-
- // Option #1 (TargetLinkLayerAddress)
- // Invalid as it has 0 length.
- 2, 0, 0, 0, 0, 0, 0, 0,
-
- // Option #2 (unrecognized)
- 255, 1, 0, 0, 0, 0, 0, 0,
-
- // Option #3 (PrefixInformation)
- 3, 4, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- 0, 0, 0, 0, 0, 0, 0, 0,
- },
- false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- e := channel.New(10, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- })
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
- }
-
- icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
- pkt := header.ICMPv6(hdr.Prepend(icmpSize))
- pkt.SetType(header.ICMPv6RouterAdvert)
- pkt.SetCode(test.code)
- copy(pkt.NDPPayload(), test.ndpPayload)
- payloadLength := hdr.UsedLength()
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: test.hopLimit,
- SrcAddr: test.src,
- DstAddr: header.IPv6AllNodesMulticastAddress,
- })
-
- stats := s.Stats().ICMP.V6PacketsReceived
- invalid := stats.Invalid
- rxRA := stats.RouterAdvert
-
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
- if got := rxRA.Value(); got != 0 {
- t.Fatalf("got rxRA = %d, want = 0", got)
- }
-
- e.InjectInbound(header.IPv6ProtocolNumber, tcpip.PacketBuffer{
- Data: hdr.View().ToVectorisedView(),
- })
-
- if test.expectedSuccess {
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
- if got := rxRA.Value(); got != 1 {
- t.Fatalf("got rxRA = %d, want = 1", got)
- }
-
- } else {
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
- if got := rxRA.Value(); got != 0 {
- t.Fatalf("got rxRA = %d, want = 0", got)
- }
- }
- })
- }
-}
diff --git a/pkg/tcpip/packet_buffer.go b/pkg/tcpip/packet_buffer.go
index ab24372e7..ab24372e7 100644..100755
--- a/pkg/tcpip/packet_buffer.go
+++ b/pkg/tcpip/packet_buffer.go
diff --git a/pkg/tcpip/packet_buffer_state.go b/pkg/tcpip/packet_buffer_state.go
index ad3cc24fa..ad3cc24fa 100644..100755
--- a/pkg/tcpip/packet_buffer_state.go
+++ b/pkg/tcpip/packet_buffer_state.go
diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD
deleted file mode 100644
index 2bad05a2e..000000000
--- a/pkg/tcpip/ports/BUILD
+++ /dev/null
@@ -1,22 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "ports",
- srcs = ["ports.go"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/sync",
- "//pkg/tcpip",
- ],
-)
-
-go_test(
- name = "ports_test",
- srcs = ["ports_test.go"],
- library = ":ports",
- deps = [
- "//pkg/tcpip",
- ],
-)
diff --git a/pkg/tcpip/ports/ports_state_autogen.go b/pkg/tcpip/ports/ports_state_autogen.go
new file mode 100755
index 000000000..f0ee1bb11
--- /dev/null
+++ b/pkg/tcpip/ports/ports_state_autogen.go
@@ -0,0 +1,24 @@
+// automatically generated by stateify.
+
+package ports
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+func (x *Flags) beforeSave() {}
+func (x *Flags) save(m state.Map) {
+ x.beforeSave()
+ m.Save("MostRecent", &x.MostRecent)
+ m.Save("LoadBalanced", &x.LoadBalanced)
+}
+
+func (x *Flags) afterLoad() {}
+func (x *Flags) load(m state.Map) {
+ m.Load("MostRecent", &x.MostRecent)
+ m.Load("LoadBalanced", &x.LoadBalanced)
+}
+
+func init() {
+ state.Register("pkg/tcpip/ports.Flags", (*Flags)(nil), state.Fns{Save: (*Flags).save, Load: (*Flags).load})
+}
diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go
deleted file mode 100644
index d6969d050..000000000
--- a/pkg/tcpip/ports/ports_test.go
+++ /dev/null
@@ -1,402 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package ports
-
-import (
- "math/rand"
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip"
-)
-
-const (
- fakeTransNumber tcpip.TransportProtocolNumber = 1
- fakeNetworkNumber tcpip.NetworkProtocolNumber = 2
-
- fakeIPAddress = tcpip.Address("\x08\x08\x08\x08")
- fakeIPAddress1 = tcpip.Address("\x08\x08\x08\x09")
-)
-
-type portReserveTestAction struct {
- port uint16
- ip tcpip.Address
- want *tcpip.Error
- flags Flags
- release bool
- device tcpip.NICID
-}
-
-func TestPortReservation(t *testing.T) {
- for _, test := range []struct {
- tname string
- actions []portReserveTestAction
- }{
- {
- tname: "bind to ip",
- actions: []portReserveTestAction{
- {port: 80, ip: fakeIPAddress, want: nil},
- {port: 80, ip: fakeIPAddress1, want: nil},
- /* N.B. Order of tests matters! */
- {port: 80, ip: anyIPAddress, want: tcpip.ErrPortInUse},
- {port: 80, ip: fakeIPAddress, want: tcpip.ErrPortInUse, flags: Flags{LoadBalanced: true}},
- },
- },
- {
- tname: "bind to inaddr any",
- actions: []portReserveTestAction{
- {port: 22, ip: anyIPAddress, want: nil},
- {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
- /* release fakeIPAddress, but anyIPAddress is still inuse */
- {port: 22, ip: fakeIPAddress, release: true},
- {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
- {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse, flags: Flags{LoadBalanced: true}},
- /* Release port 22 from any IP address, then try to reserve fake IP address on 22 */
- {port: 22, ip: anyIPAddress, want: nil, release: true},
- {port: 22, ip: fakeIPAddress, want: nil},
- },
- }, {
- tname: "bind to zero port",
- actions: []portReserveTestAction{
- {port: 00, ip: fakeIPAddress, want: nil},
- {port: 00, ip: fakeIPAddress, want: nil},
- {port: 00, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
- },
- }, {
- tname: "bind to ip with reuseport",
- actions: []portReserveTestAction{
- {port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
-
- {port: 25, ip: fakeIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse},
- {port: 25, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse},
-
- {port: 25, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
- },
- }, {
- tname: "bind to inaddr any with reuseport",
- actions: []portReserveTestAction{
- {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
-
- {port: 24, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse},
-
- {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, release: true, want: nil},
-
- {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true},
- {port: 24, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse},
-
- {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true},
- {port: 24, ip: anyIPAddress, flags: Flags{}, want: nil},
- },
- }, {
- tname: "bind twice with device fails",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 3, want: nil},
- {port: 24, ip: fakeIPAddress, device: 3, want: tcpip.ErrPortInUse},
- },
- }, {
- tname: "bind to device",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 1, want: nil},
- {port: 24, ip: fakeIPAddress, device: 2, want: nil},
- },
- }, {
- tname: "bind to device and then without device",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 123, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
- },
- }, {
- tname: "bind without device",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, want: nil},
- {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
- },
- }, {
- tname: "bind with device",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 123, want: nil},
- {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 789, want: nil},
- {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
- },
- }, {
- tname: "bind with reuseport",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil},
- },
- }, {
- tname: "binding with reuseport and device",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 999, want: tcpip.ErrPortInUse},
- },
- }, {
- tname: "mixing reuseport and not reuseport by binding to device",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 456, want: nil},
- {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 999, want: nil},
- },
- }, {
- tname: "can't bind to 0 after mixing reuseport and not reuseport",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 456, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
- },
- }, {
- tname: "bind and release",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil},
-
- // Release the bind to device 0 and try again.
- {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil, release: true},
- {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: nil},
- },
- }, {
- tname: "bind twice with reuseport once",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
- },
- }, {
- tname: "release an unreserved device",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{}, want: nil},
- // The below don't exist.
- {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: nil, release: true},
- {port: 9999, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil, release: true},
- // Release all.
- {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil, release: true},
- {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{}, want: nil, release: true},
- },
- }, {
- tname: "bind with reuseaddr",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{MostRecent: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: nil},
- },
- }, {
- tname: "bind twice with reuseaddr once",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse},
- },
- }, {
- tname: "bind with reuseaddr and reuseport",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
- },
- }, {
- tname: "bind with reuseaddr and reuseport, and then reuseaddr",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
- },
- }, {
- tname: "bind with reuseaddr and reuseport, and then reuseport",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse},
- },
- }, {
- tname: "bind with reuseaddr and reuseport twice, and then reuseaddr",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil},
- },
- }, {
- tname: "bind with reuseaddr and reuseport twice, and then reuseport",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
- },
- }, {
- tname: "bind with reuseaddr, and then reuseaddr and reuseport",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
- },
- }, {
- tname: "bind with reuseport, and then reuseaddr and reuseport",
- actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse},
- },
- },
- } {
- t.Run(test.tname, func(t *testing.T) {
- pm := NewPortManager()
- net := []tcpip.NetworkProtocolNumber{fakeNetworkNumber}
-
- for _, test := range test.actions {
- if test.release {
- pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device)
- continue
- }
- gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device)
- if err != test.want {
- t.Fatalf("ReservePort(.., .., %s, %d, %+v, %d) = %v, want %v", test.ip, test.port, test.flags, test.device, err, test.want)
- }
- if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) {
- t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral)
- }
- }
- })
-
- }
-}
-
-func TestPickEphemeralPort(t *testing.T) {
- customErr := &tcpip.Error{}
- for _, test := range []struct {
- name string
- f func(port uint16) (bool, *tcpip.Error)
- wantErr *tcpip.Error
- wantPort uint16
- }{
- {
- name: "no-port-available",
- f: func(port uint16) (bool, *tcpip.Error) {
- return false, nil
- },
- wantErr: tcpip.ErrNoPortAvailable,
- },
- {
- name: "port-tester-error",
- f: func(port uint16) (bool, *tcpip.Error) {
- return false, customErr
- },
- wantErr: customErr,
- },
- {
- name: "only-port-16042-available",
- f: func(port uint16) (bool, *tcpip.Error) {
- if port == FirstEphemeral+42 {
- return true, nil
- }
- return false, nil
- },
- wantPort: FirstEphemeral + 42,
- },
- {
- name: "only-port-under-16000-available",
- f: func(port uint16) (bool, *tcpip.Error) {
- if port < FirstEphemeral {
- return true, nil
- }
- return false, nil
- },
- wantErr: tcpip.ErrNoPortAvailable,
- },
- } {
- t.Run(test.name, func(t *testing.T) {
- pm := NewPortManager()
- if port, err := pm.PickEphemeralPort(test.f); port != test.wantPort || err != test.wantErr {
- t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr)
- }
- })
- }
-}
-
-func TestPickEphemeralPortStable(t *testing.T) {
- customErr := &tcpip.Error{}
- for _, test := range []struct {
- name string
- f func(port uint16) (bool, *tcpip.Error)
- wantErr *tcpip.Error
- wantPort uint16
- }{
- {
- name: "no-port-available",
- f: func(port uint16) (bool, *tcpip.Error) {
- return false, nil
- },
- wantErr: tcpip.ErrNoPortAvailable,
- },
- {
- name: "port-tester-error",
- f: func(port uint16) (bool, *tcpip.Error) {
- return false, customErr
- },
- wantErr: customErr,
- },
- {
- name: "only-port-16042-available",
- f: func(port uint16) (bool, *tcpip.Error) {
- if port == FirstEphemeral+42 {
- return true, nil
- }
- return false, nil
- },
- wantPort: FirstEphemeral + 42,
- },
- {
- name: "only-port-under-16000-available",
- f: func(port uint16) (bool, *tcpip.Error) {
- if port < FirstEphemeral {
- return true, nil
- }
- return false, nil
- },
- wantErr: tcpip.ErrNoPortAvailable,
- },
- } {
- t.Run(test.name, func(t *testing.T) {
- pm := NewPortManager()
- portOffset := uint32(rand.Int31n(int32(numEphemeralPorts)))
- if port, err := pm.PickEphemeralPortStable(portOffset, test.f); port != test.wantPort || err != test.wantErr {
- t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr)
- }
- })
- }
-}
diff --git a/pkg/tcpip/sample/tun_tcp_connect/BUILD b/pkg/tcpip/sample/tun_tcp_connect/BUILD
deleted file mode 100644
index cf0a5fefe..000000000
--- a/pkg/tcpip/sample/tun_tcp_connect/BUILD
+++ /dev/null
@@ -1,22 +0,0 @@
-load("//tools:defs.bzl", "go_binary")
-
-package(licenses = ["notice"])
-
-go_binary(
- name = "tun_tcp_connect",
- srcs = ["main.go"],
- visibility = ["//:sandbox"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/fdbased",
- "//pkg/tcpip/link/rawfile",
- "//pkg/tcpip/link/sniffer",
- "//pkg/tcpip/link/tun",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/transport/tcp",
- "//pkg/waiter",
- ],
-)
diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go
deleted file mode 100644
index 0ab089208..000000000
--- a/pkg/tcpip/sample/tun_tcp_connect/main.go
+++ /dev/null
@@ -1,225 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// +build linux
-
-// This sample creates a stack with TCP and IPv4 protocols on top of a TUN
-// device, and connects to a peer. Similar to "nc <address> <port>". While the
-// sample is running, attempts to connect to its IPv4 address will result in
-// a RST segment.
-//
-// As an example of how to run it, a TUN device can be created and enabled on
-// a linux host as follows (this only needs to be done once per boot):
-//
-// [sudo] ip tuntap add user <username> mode tun <device-name>
-// [sudo] ip link set <device-name> up
-// [sudo] ip addr add <ipv4-address>/<mask-length> dev <device-name>
-//
-// A concrete example:
-//
-// $ sudo ip tuntap add user wedsonaf mode tun tun0
-// $ sudo ip link set tun0 up
-// $ sudo ip addr add 192.168.1.1/24 dev tun0
-//
-// Then one can run tun_tcp_connect as such:
-//
-// $ ./tun/tun_tcp_connect tun0 192.168.1.2 0 192.168.1.1 1234
-//
-// This will attempt to connect to the linux host's stack. One can run nc in
-// listen mode to accept a connect from tun_tcp_connect and exchange data.
-package main
-
-import (
- "bufio"
- "fmt"
- "log"
- "math/rand"
- "net"
- "os"
- "strconv"
- "time"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/fdbased"
- "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
- "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
- "gvisor.dev/gvisor/pkg/tcpip/link/tun"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-// writer reads from standard input and writes to the endpoint until standard
-// input is closed. It signals that it's done by closing the provided channel.
-func writer(ch chan struct{}, ep tcpip.Endpoint) {
- defer func() {
- ep.Shutdown(tcpip.ShutdownWrite)
- close(ch)
- }()
-
- r := bufio.NewReader(os.Stdin)
- for {
- v := buffer.NewView(1024)
- n, err := r.Read(v)
- if err != nil {
- return
- }
-
- v.CapLength(n)
- for len(v) > 0 {
- n, _, err := ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
- if err != nil {
- fmt.Println("Write failed:", err)
- return
- }
-
- v.TrimFront(int(n))
- }
- }
-}
-
-func main() {
- if len(os.Args) != 6 {
- log.Fatal("Usage: ", os.Args[0], " <tun-device> <local-ipv4-address> <local-port> <remote-ipv4-address> <remote-port>")
- }
-
- tunName := os.Args[1]
- addrName := os.Args[2]
- portName := os.Args[3]
- remoteAddrName := os.Args[4]
- remotePortName := os.Args[5]
-
- rand.Seed(time.Now().UnixNano())
-
- addr := tcpip.Address(net.ParseIP(addrName).To4())
- remote := tcpip.FullAddress{
- NIC: 1,
- Addr: tcpip.Address(net.ParseIP(remoteAddrName).To4()),
- }
-
- var localPort uint16
- if v, err := strconv.Atoi(portName); err != nil {
- log.Fatalf("Unable to convert port %v: %v", portName, err)
- } else {
- localPort = uint16(v)
- }
-
- if v, err := strconv.Atoi(remotePortName); err != nil {
- log.Fatalf("Unable to convert port %v: %v", remotePortName, err)
- } else {
- remote.Port = uint16(v)
- }
-
- // Create the stack with ipv4 and tcp protocols, then add a tun-based
- // NIC and ipv4 address.
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
- })
-
- mtu, err := rawfile.GetMTU(tunName)
- if err != nil {
- log.Fatal(err)
- }
-
- fd, err := tun.Open(tunName)
- if err != nil {
- log.Fatal(err)
- }
-
- linkEP, err := fdbased.New(&fdbased.Options{FDs: []int{fd}, MTU: mtu})
- if err != nil {
- log.Fatal(err)
- }
- if err := s.CreateNIC(1, sniffer.New(linkEP)); err != nil {
- log.Fatal(err)
- }
-
- if err := s.AddAddress(1, ipv4.ProtocolNumber, addr); err != nil {
- log.Fatal(err)
- }
-
- // Add default route.
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv4EmptySubnet,
- NIC: 1,
- },
- })
-
- // Create TCP endpoint.
- var wq waiter.Queue
- ep, e := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
- if e != nil {
- log.Fatal(e)
- }
-
- // Bind if a port is specified.
- if localPort != 0 {
- if err := ep.Bind(tcpip.FullAddress{0, "", localPort}); err != nil {
- log.Fatal("Bind failed: ", err)
- }
- }
-
- // Issue connect request and wait for it to complete.
- waitEntry, notifyCh := waiter.NewChannelEntry(nil)
- wq.EventRegister(&waitEntry, waiter.EventOut)
- terr := ep.Connect(remote)
- if terr == tcpip.ErrConnectStarted {
- fmt.Println("Connect is pending...")
- <-notifyCh
- terr = ep.GetSockOpt(tcpip.ErrorOption{})
- }
- wq.EventUnregister(&waitEntry)
-
- if terr != nil {
- log.Fatal("Unable to connect: ", terr)
- }
-
- fmt.Println("Connected")
-
- // Start the writer in its own goroutine.
- writerCompletedCh := make(chan struct{})
- go writer(writerCompletedCh, ep) // S/R-SAFE: sample code.
-
- // Read data and write to standard output until the peer closes the
- // connection from its side.
- wq.EventRegister(&waitEntry, waiter.EventIn)
- for {
- v, _, err := ep.Read(nil)
- if err != nil {
- if err == tcpip.ErrClosedForReceive {
- break
- }
-
- if err == tcpip.ErrWouldBlock {
- <-notifyCh
- continue
- }
-
- log.Fatal("Read() failed:", err)
- }
-
- os.Stdout.Write(v)
- }
- wq.EventUnregister(&waitEntry)
-
- // The reader has completed. Now wait for the writer as well.
- <-writerCompletedCh
-
- ep.Close()
-}
diff --git a/pkg/tcpip/sample/tun_tcp_echo/BUILD b/pkg/tcpip/sample/tun_tcp_echo/BUILD
deleted file mode 100644
index 43264b76d..000000000
--- a/pkg/tcpip/sample/tun_tcp_echo/BUILD
+++ /dev/null
@@ -1,21 +0,0 @@
-load("//tools:defs.bzl", "go_binary")
-
-package(licenses = ["notice"])
-
-go_binary(
- name = "tun_tcp_echo",
- srcs = ["main.go"],
- visibility = ["//:sandbox"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/link/fdbased",
- "//pkg/tcpip/link/rawfile",
- "//pkg/tcpip/link/tun",
- "//pkg/tcpip/network/arp",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/transport/tcp",
- "//pkg/waiter",
- ],
-)
diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go
deleted file mode 100644
index 9e37cab18..000000000
--- a/pkg/tcpip/sample/tun_tcp_echo/main.go
+++ /dev/null
@@ -1,203 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// +build linux
-
-// This sample creates a stack with TCP and IPv4 protocols on top of a TUN
-// device, and listens on a port. Data received by the server in the accepted
-// connections is echoed back to the clients.
-package main
-
-import (
- "flag"
- "log"
- "math/rand"
- "net"
- "os"
- "strconv"
- "strings"
- "time"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/link/fdbased"
- "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
- "gvisor.dev/gvisor/pkg/tcpip/link/tun"
- "gvisor.dev/gvisor/pkg/tcpip/network/arp"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-var tap = flag.Bool("tap", false, "use tap istead of tun")
-var mac = flag.String("mac", "aa:00:01:01:01:01", "mac address to use in tap device")
-
-func echo(wq *waiter.Queue, ep tcpip.Endpoint) {
- defer ep.Close()
-
- // Create wait queue entry that notifies a channel.
- waitEntry, notifyCh := waiter.NewChannelEntry(nil)
-
- wq.EventRegister(&waitEntry, waiter.EventIn)
- defer wq.EventUnregister(&waitEntry)
-
- for {
- v, _, err := ep.Read(nil)
- if err != nil {
- if err == tcpip.ErrWouldBlock {
- <-notifyCh
- continue
- }
-
- return
- }
-
- ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
- }
-}
-
-func main() {
- flag.Parse()
- if len(flag.Args()) != 3 {
- log.Fatal("Usage: ", os.Args[0], " <tun-device> <local-address> <local-port>")
- }
-
- tunName := flag.Arg(0)
- addrName := flag.Arg(1)
- portName := flag.Arg(2)
-
- rand.Seed(time.Now().UnixNano())
-
- // Parse the mac address.
- maddr, err := net.ParseMAC(*mac)
- if err != nil {
- log.Fatalf("Bad MAC address: %v", *mac)
- }
-
- // Parse the IP address. Support both ipv4 and ipv6.
- parsedAddr := net.ParseIP(addrName)
- if parsedAddr == nil {
- log.Fatalf("Bad IP address: %v", addrName)
- }
-
- var addr tcpip.Address
- var proto tcpip.NetworkProtocolNumber
- if parsedAddr.To4() != nil {
- addr = tcpip.Address(parsedAddr.To4())
- proto = ipv4.ProtocolNumber
- } else if parsedAddr.To16() != nil {
- addr = tcpip.Address(parsedAddr.To16())
- proto = ipv6.ProtocolNumber
- } else {
- log.Fatalf("Unknown IP type: %v", addrName)
- }
-
- localPort, err := strconv.Atoi(portName)
- if err != nil {
- log.Fatalf("Unable to convert port %v: %v", portName, err)
- }
-
- // Create the stack with ip and tcp protocols, then add a tun-based
- // NIC and address.
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
- })
-
- mtu, err := rawfile.GetMTU(tunName)
- if err != nil {
- log.Fatal(err)
- }
-
- var fd int
- if *tap {
- fd, err = tun.OpenTAP(tunName)
- } else {
- fd, err = tun.Open(tunName)
- }
- if err != nil {
- log.Fatal(err)
- }
-
- linkEP, err := fdbased.New(&fdbased.Options{
- FDs: []int{fd},
- MTU: mtu,
- EthernetHeader: *tap,
- Address: tcpip.LinkAddress(maddr),
- })
- if err != nil {
- log.Fatal(err)
- }
- if err := s.CreateNIC(1, linkEP); err != nil {
- log.Fatal(err)
- }
-
- if err := s.AddAddress(1, proto, addr); err != nil {
- log.Fatal(err)
- }
-
- if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
- log.Fatal(err)
- }
-
- subnet, err := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", len(addr))), tcpip.AddressMask(strings.Repeat("\x00", len(addr))))
- if err != nil {
- log.Fatal(err)
- }
-
- // Add default route.
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: subnet,
- NIC: 1,
- },
- })
-
- // Create TCP endpoint, bind it, then start listening.
- var wq waiter.Queue
- ep, e := s.NewEndpoint(tcp.ProtocolNumber, proto, &wq)
- if e != nil {
- log.Fatal(e)
- }
-
- defer ep.Close()
-
- if err := ep.Bind(tcpip.FullAddress{0, "", uint16(localPort)}); err != nil {
- log.Fatal("Bind failed: ", err)
- }
-
- if err := ep.Listen(10); err != nil {
- log.Fatal("Listen failed: ", err)
- }
-
- // Wait for connections to appear.
- waitEntry, notifyCh := waiter.NewChannelEntry(nil)
- wq.EventRegister(&waitEntry, waiter.EventIn)
- defer wq.EventUnregister(&waitEntry)
-
- for {
- n, wq, err := ep.Accept()
- if err != nil {
- if err == tcpip.ErrWouldBlock {
- <-notifyCh
- continue
- }
-
- log.Fatal("Accept() failed:", err)
- }
-
- go echo(wq, n) // S/R-SAFE: sample code.
- }
-}
diff --git a/pkg/tcpip/seqnum/BUILD b/pkg/tcpip/seqnum/BUILD
deleted file mode 100644
index 45f503845..000000000
--- a/pkg/tcpip/seqnum/BUILD
+++ /dev/null
@@ -1,9 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "seqnum",
- srcs = ["seqnum.go"],
- visibility = ["//visibility:public"],
-)
diff --git a/pkg/tcpip/seqnum/seqnum_state_autogen.go b/pkg/tcpip/seqnum/seqnum_state_autogen.go
new file mode 100755
index 000000000..23e79811d
--- /dev/null
+++ b/pkg/tcpip/seqnum/seqnum_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package seqnum
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
deleted file mode 100644
index 705cf01ee..000000000
--- a/pkg/tcpip/stack/BUILD
+++ /dev/null
@@ -1,92 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
-package(licenses = ["notice"])
-
-go_template_instance(
- name = "linkaddrentry_list",
- out = "linkaddrentry_list.go",
- package = "stack",
- prefix = "linkAddrEntry",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*linkAddrEntry",
- "Linker": "*linkAddrEntry",
- },
-)
-
-go_library(
- name = "stack",
- srcs = [
- "icmp_rate_limit.go",
- "linkaddrcache.go",
- "linkaddrentry_list.go",
- "ndp.go",
- "nic.go",
- "registration.go",
- "route.go",
- "stack.go",
- "stack_global_state.go",
- "transport_demuxer.go",
- ],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/ilist",
- "//pkg/rand",
- "//pkg/sleep",
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/hash/jenkins",
- "//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
- "//pkg/tcpip/ports",
- "//pkg/tcpip/seqnum",
- "//pkg/waiter",
- "@org_golang_x_time//rate:go_default_library",
- ],
-)
-
-go_test(
- name = "stack_x_test",
- size = "medium",
- srcs = [
- "ndp_test.go",
- "stack_test.go",
- "transport_demuxer_test.go",
- "transport_test.go",
- ],
- deps = [
- ":stack",
- "//pkg/rand",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/checker",
- "//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/link/loopback",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/transport/icmp",
- "//pkg/tcpip/transport/udp",
- "//pkg/waiter",
- "@com_github_google_go-cmp//cmp:go_default_library",
- ],
-)
-
-go_test(
- name = "stack_test",
- size = "small",
- srcs = [
- "linkaddrcache_test.go",
- "nic_test.go",
- ],
- library = ":stack",
- deps = [
- "//pkg/sleep",
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- ],
-)
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
deleted file mode 100644
index 1baa498d0..000000000
--- a/pkg/tcpip/stack/linkaddrcache_test.go
+++ /dev/null
@@ -1,277 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package stack
-
-import (
- "fmt"
- "sync/atomic"
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/sleep"
- "gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/tcpip"
-)
-
-type testaddr struct {
- addr tcpip.FullAddress
- linkAddr tcpip.LinkAddress
-}
-
-var testAddrs = func() []testaddr {
- var addrs []testaddr
- for i := 0; i < 4*linkAddrCacheSize; i++ {
- addr := fmt.Sprintf("Addr%06d", i)
- addrs = append(addrs, testaddr{
- addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)},
- linkAddr: tcpip.LinkAddress("Link" + addr),
- })
- }
- return addrs
-}()
-
-type testLinkAddressResolver struct {
- cache *linkAddrCache
- delay time.Duration
- onLinkAddressRequest func()
-}
-
-func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error {
- time.AfterFunc(r.delay, func() { r.fakeRequest(addr) })
- if f := r.onLinkAddressRequest; f != nil {
- f()
- }
- return nil
-}
-
-func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) {
- for _, ta := range testAddrs {
- if ta.addr.Addr == addr {
- r.cache.add(ta.addr, ta.linkAddr)
- break
- }
- }
-}
-
-func (*testLinkAddressResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
- if addr == "broadcast" {
- return "mac_broadcast", true
- }
- return "", false
-}
-
-func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
- return 1
-}
-
-func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressResolver) (tcpip.LinkAddress, *tcpip.Error) {
- w := sleep.Waker{}
- s := sleep.Sleeper{}
- s.AddWaker(&w, 123)
- defer s.Done()
-
- for {
- if got, _, err := c.get(addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock {
- return got, err
- }
- s.Fetch(true)
- }
-}
-
-func TestCacheOverflow(t *testing.T) {
- c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
- for i := len(testAddrs) - 1; i >= 0; i-- {
- e := testAddrs[i]
- c.add(e.addr, e.linkAddr)
- got, _, err := c.get(e.addr, nil, "", nil, nil)
- if err != nil {
- t.Errorf("insert %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err)
- }
- if got != e.linkAddr {
- t.Errorf("insert %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr)
- }
- }
- // Expect to find at least half of the most recent entries.
- for i := 0; i < linkAddrCacheSize/2; i++ {
- e := testAddrs[i]
- got, _, err := c.get(e.addr, nil, "", nil, nil)
- if err != nil {
- t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err)
- }
- if got != e.linkAddr {
- t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr)
- }
- }
- // The earliest entries should no longer be in the cache.
- for i := len(testAddrs) - 1; i >= len(testAddrs)-linkAddrCacheSize; i-- {
- e := testAddrs[i]
- if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
- t.Errorf("check %d, c.get(%q), got error: %v, want: error ErrNoLinkAddress", i, string(e.addr.Addr), err)
- }
- }
-}
-
-func TestCacheConcurrent(t *testing.T) {
- c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
-
- var wg sync.WaitGroup
- for r := 0; r < 16; r++ {
- wg.Add(1)
- go func() {
- for _, e := range testAddrs {
- c.add(e.addr, e.linkAddr)
- c.get(e.addr, nil, "", nil, nil) // make work for gotsan
- }
- wg.Done()
- }()
- }
- wg.Wait()
-
- // All goroutines add in the same order and add more values than
- // can fit in the cache, so our eviction strategy requires that
- // the last entry be present and the first be missing.
- e := testAddrs[len(testAddrs)-1]
- got, _, err := c.get(e.addr, nil, "", nil, nil)
- if err != nil {
- t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
- }
- if got != e.linkAddr {
- t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
- }
-
- e = testAddrs[0]
- if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
- t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
- }
-}
-
-func TestCacheAgeLimit(t *testing.T) {
- c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3)
- e := testAddrs[0]
- c.add(e.addr, e.linkAddr)
- time.Sleep(50 * time.Millisecond)
- if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
- t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
- }
-}
-
-func TestCacheReplace(t *testing.T) {
- c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
- e := testAddrs[0]
- l2 := e.linkAddr + "2"
- c.add(e.addr, e.linkAddr)
- got, _, err := c.get(e.addr, nil, "", nil, nil)
- if err != nil {
- t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
- }
- if got != e.linkAddr {
- t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
- }
-
- c.add(e.addr, l2)
- got, _, err = c.get(e.addr, nil, "", nil, nil)
- if err != nil {
- t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
- }
- if got != l2 {
- t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, l2)
- }
-}
-
-func TestCacheResolution(t *testing.T) {
- c := newLinkAddrCache(1<<63-1, 250*time.Millisecond, 1)
- linkRes := &testLinkAddressResolver{cache: c}
- for i, ta := range testAddrs {
- got, err := getBlocking(c, ta.addr, linkRes)
- if err != nil {
- t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(ta.addr.Addr), got, err)
- }
- if got != ta.linkAddr {
- t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(ta.addr.Addr), got, ta.linkAddr)
- }
- }
-
- // Check that after resolved, address stays in the cache and never returns WouldBlock.
- for i := 0; i < 10; i++ {
- e := testAddrs[len(testAddrs)-1]
- got, _, err := c.get(e.addr, linkRes, "", nil, nil)
- if err != nil {
- t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
- }
- if got != e.linkAddr {
- t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
- }
- }
-}
-
-func TestCacheResolutionFailed(t *testing.T) {
- c := newLinkAddrCache(1<<63-1, 10*time.Millisecond, 5)
- linkRes := &testLinkAddressResolver{cache: c}
-
- var requestCount uint32
- linkRes.onLinkAddressRequest = func() {
- atomic.AddUint32(&requestCount, 1)
- }
-
- // First, sanity check that resolution is working...
- e := testAddrs[0]
- got, err := getBlocking(c, e.addr, linkRes)
- if err != nil {
- t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
- }
- if got != e.linkAddr {
- t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
- }
-
- before := atomic.LoadUint32(&requestCount)
-
- e.addr.Addr += "2"
- if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress {
- t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
- }
-
- if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want {
- t.Errorf("got link address request count = %d, want = %d", got, want)
- }
-}
-
-func TestCacheResolutionTimeout(t *testing.T) {
- resolverDelay := 500 * time.Millisecond
- expiration := resolverDelay / 10
- c := newLinkAddrCache(expiration, 1*time.Millisecond, 3)
- linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay}
-
- e := testAddrs[0]
- if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress {
- t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
- }
-}
-
-// TestStaticResolution checks that static link addresses are resolved immediately and don't
-// send resolution requests.
-func TestStaticResolution(t *testing.T) {
- c := newLinkAddrCache(1<<63-1, time.Millisecond, 1)
- linkRes := &testLinkAddressResolver{cache: c, delay: time.Minute}
-
- addr := tcpip.Address("broadcast")
- want := tcpip.LinkAddress("mac_broadcast")
- got, _, err := c.get(tcpip.FullAddress{Addr: addr}, linkRes, "", nil, nil)
- if err != nil {
- t.Errorf("c.get(%q)=%q, got error: %v", string(addr), string(got), err)
- }
- if got != want {
- t.Errorf("c.get(%q)=%q, want %q", string(addr), string(got), string(want))
- }
-}
diff --git a/pkg/tcpip/stack/linkaddrentry_list.go b/pkg/tcpip/stack/linkaddrentry_list.go
new file mode 100755
index 000000000..61a45ddcb
--- /dev/null
+++ b/pkg/tcpip/stack/linkaddrentry_list.go
@@ -0,0 +1,173 @@
+package stack
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type linkAddrEntryElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (linkAddrEntryElementMapper) linkerFor(elem *linkAddrEntry) *linkAddrEntry { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type linkAddrEntryList struct {
+ head *linkAddrEntry
+ tail *linkAddrEntry
+}
+
+// Reset resets list l to the empty state.
+func (l *linkAddrEntryList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *linkAddrEntryList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *linkAddrEntryList) Front() *linkAddrEntry {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *linkAddrEntryList) Back() *linkAddrEntry {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *linkAddrEntryList) PushFront(e *linkAddrEntry) {
+ linkAddrEntryElementMapper{}.linkerFor(e).SetNext(l.head)
+ linkAddrEntryElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ linkAddrEntryElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *linkAddrEntryList) PushBack(e *linkAddrEntry) {
+ linkAddrEntryElementMapper{}.linkerFor(e).SetNext(nil)
+ linkAddrEntryElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ linkAddrEntryElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *linkAddrEntryList) PushBackList(m *linkAddrEntryList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ linkAddrEntryElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ linkAddrEntryElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *linkAddrEntryList) InsertAfter(b, e *linkAddrEntry) {
+ a := linkAddrEntryElementMapper{}.linkerFor(b).Next()
+ linkAddrEntryElementMapper{}.linkerFor(e).SetNext(a)
+ linkAddrEntryElementMapper{}.linkerFor(e).SetPrev(b)
+ linkAddrEntryElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ linkAddrEntryElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *linkAddrEntryList) InsertBefore(a, e *linkAddrEntry) {
+ b := linkAddrEntryElementMapper{}.linkerFor(a).Prev()
+ linkAddrEntryElementMapper{}.linkerFor(e).SetNext(a)
+ linkAddrEntryElementMapper{}.linkerFor(e).SetPrev(b)
+ linkAddrEntryElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ linkAddrEntryElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *linkAddrEntryList) Remove(e *linkAddrEntry) {
+ prev := linkAddrEntryElementMapper{}.linkerFor(e).Prev()
+ next := linkAddrEntryElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ linkAddrEntryElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ linkAddrEntryElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type linkAddrEntryEntry struct {
+ next *linkAddrEntry
+ prev *linkAddrEntry
+}
+
+// Next returns the entry that follows e in the list.
+func (e *linkAddrEntryEntry) Next() *linkAddrEntry {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *linkAddrEntryEntry) Prev() *linkAddrEntry {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *linkAddrEntryEntry) SetNext(elem *linkAddrEntry) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *linkAddrEntryEntry) SetPrev(elem *linkAddrEntry) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index 19bd05aa3..19bd05aa3 100644..100755
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
deleted file mode 100644
index f7b75b74e..000000000
--- a/pkg/tcpip/stack/ndp_test.go
+++ /dev/null
@@ -1,3612 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package stack_test
-
-import (
- "context"
- "encoding/binary"
- "fmt"
- "testing"
- "time"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/rand"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-const (
- addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
- linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
- linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07")
- linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08")
- linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09")
- defaultTimeout = 100 * time.Millisecond
- defaultAsyncEventTimeout = time.Second
-)
-
-var (
- llAddr1 = header.LinkLocalAddr(linkAddr1)
- llAddr2 = header.LinkLocalAddr(linkAddr2)
- llAddr3 = header.LinkLocalAddr(linkAddr3)
- llAddr4 = header.LinkLocalAddr(linkAddr4)
- dstAddr = tcpip.FullAddress{
- Addr: "\x0a\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
- Port: 25,
- }
-)
-
-func addrForSubnet(subnet tcpip.Subnet, linkAddr tcpip.LinkAddress) tcpip.AddressWithPrefix {
- if !header.IsValidUnicastEthernetAddress(linkAddr) {
- return tcpip.AddressWithPrefix{}
- }
-
- addrBytes := []byte(subnet.ID())
- header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:])
- return tcpip.AddressWithPrefix{
- Address: tcpip.Address(addrBytes),
- PrefixLen: 64,
- }
-}
-
-// prefixSubnetAddr returns a prefix (Address + Length), the prefix's equivalent
-// tcpip.Subnet, and an address where the lower half of the address is composed
-// of the EUI-64 of linkAddr if it is a valid unicast ethernet address.
-func prefixSubnetAddr(offset uint8, linkAddr tcpip.LinkAddress) (tcpip.AddressWithPrefix, tcpip.Subnet, tcpip.AddressWithPrefix) {
- prefixBytes := []byte{1, 2, 3, 4, 5, 6, 7, 8 + offset, 0, 0, 0, 0, 0, 0, 0, 0}
- prefix := tcpip.AddressWithPrefix{
- Address: tcpip.Address(prefixBytes),
- PrefixLen: 64,
- }
-
- subnet := prefix.Subnet()
-
- return prefix, subnet, addrForSubnet(subnet, linkAddr)
-}
-
-// ndpDADEvent is a set of parameters that was passed to
-// ndpDispatcher.OnDuplicateAddressDetectionStatus.
-type ndpDADEvent struct {
- nicID tcpip.NICID
- addr tcpip.Address
- resolved bool
- err *tcpip.Error
-}
-
-type ndpRouterEvent struct {
- nicID tcpip.NICID
- addr tcpip.Address
- // true if router was discovered, false if invalidated.
- discovered bool
-}
-
-type ndpPrefixEvent struct {
- nicID tcpip.NICID
- prefix tcpip.Subnet
- // true if prefix was discovered, false if invalidated.
- discovered bool
-}
-
-type ndpAutoGenAddrEventType int
-
-const (
- newAddr ndpAutoGenAddrEventType = iota
- deprecatedAddr
- invalidatedAddr
-)
-
-type ndpAutoGenAddrEvent struct {
- nicID tcpip.NICID
- addr tcpip.AddressWithPrefix
- eventType ndpAutoGenAddrEventType
-}
-
-type ndpRDNSS struct {
- addrs []tcpip.Address
- lifetime time.Duration
-}
-
-type ndpRDNSSEvent struct {
- nicID tcpip.NICID
- rdnss ndpRDNSS
-}
-
-type ndpDHCPv6Event struct {
- nicID tcpip.NICID
- configuration stack.DHCPv6ConfigurationFromNDPRA
-}
-
-var _ stack.NDPDispatcher = (*ndpDispatcher)(nil)
-
-// ndpDispatcher implements NDPDispatcher so tests can know when various NDP
-// related events happen for test purposes.
-type ndpDispatcher struct {
- dadC chan ndpDADEvent
- routerC chan ndpRouterEvent
- rememberRouter bool
- prefixC chan ndpPrefixEvent
- rememberPrefix bool
- autoGenAddrC chan ndpAutoGenAddrEvent
- rdnssC chan ndpRDNSSEvent
- dhcpv6ConfigurationC chan ndpDHCPv6Event
-}
-
-// Implements stack.NDPDispatcher.OnDuplicateAddressDetectionStatus.
-func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) {
- if n.dadC != nil {
- n.dadC <- ndpDADEvent{
- nicID,
- addr,
- resolved,
- err,
- }
- }
-}
-
-// Implements stack.NDPDispatcher.OnDefaultRouterDiscovered.
-func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool {
- if c := n.routerC; c != nil {
- c <- ndpRouterEvent{
- nicID,
- addr,
- true,
- }
- }
-
- return n.rememberRouter
-}
-
-// Implements stack.NDPDispatcher.OnDefaultRouterInvalidated.
-func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) {
- if c := n.routerC; c != nil {
- c <- ndpRouterEvent{
- nicID,
- addr,
- false,
- }
- }
-}
-
-// Implements stack.NDPDispatcher.OnOnLinkPrefixDiscovered.
-func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool {
- if c := n.prefixC; c != nil {
- c <- ndpPrefixEvent{
- nicID,
- prefix,
- true,
- }
- }
-
- return n.rememberPrefix
-}
-
-// Implements stack.NDPDispatcher.OnOnLinkPrefixInvalidated.
-func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) {
- if c := n.prefixC; c != nil {
- c <- ndpPrefixEvent{
- nicID,
- prefix,
- false,
- }
- }
-}
-
-func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) bool {
- if c := n.autoGenAddrC; c != nil {
- c <- ndpAutoGenAddrEvent{
- nicID,
- addr,
- newAddr,
- }
- }
- return true
-}
-
-func (n *ndpDispatcher) OnAutoGenAddressDeprecated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) {
- if c := n.autoGenAddrC; c != nil {
- c <- ndpAutoGenAddrEvent{
- nicID,
- addr,
- deprecatedAddr,
- }
- }
-}
-
-func (n *ndpDispatcher) OnAutoGenAddressInvalidated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) {
- if c := n.autoGenAddrC; c != nil {
- c <- ndpAutoGenAddrEvent{
- nicID,
- addr,
- invalidatedAddr,
- }
- }
-}
-
-// Implements stack.NDPDispatcher.OnRecursiveDNSServerOption.
-func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) {
- if c := n.rdnssC; c != nil {
- c <- ndpRDNSSEvent{
- nicID,
- ndpRDNSS{
- addrs,
- lifetime,
- },
- }
- }
-}
-
-// Implements stack.NDPDispatcher.OnDHCPv6Configuration.
-func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration stack.DHCPv6ConfigurationFromNDPRA) {
- if c := n.dhcpv6ConfigurationC; c != nil {
- c <- ndpDHCPv6Event{
- nicID,
- configuration,
- }
- }
-}
-
-// Check e to make sure that the event is for addr on nic with ID 1, and the
-// resolved flag set to resolved with the specified err.
-func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) string {
- return cmp.Diff(ndpDADEvent{nicID: nicID, addr: addr, resolved: resolved, err: err}, e, cmp.AllowUnexported(e))
-}
-
-// TestDADDisabled tests that an address successfully resolves immediately
-// when DAD is not enabled (the default for an empty stack.Options).
-func TestDADDisabled(t *testing.T) {
- const nicID = 1
- ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent, 1),
- }
- opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPDisp: &ndpDisp,
- }
-
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(opts)
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
- }
-
- // Should get the address immediately since we should not have performed
- // DAD on it.
- select {
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected DAD event")
- }
- addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("stack.GetMainNICAddress(%d, %d) err = %s", nicID, header.IPv6ProtocolNumber, err)
- }
- if addr.Address != addr1 {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1)
- }
-
- // We should not have sent any NDP NS messages.
- if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != 0 {
- t.Fatalf("got NeighborSolicit = %d, want = 0", got)
- }
-}
-
-// TestDADResolve tests that an address successfully resolves after performing
-// DAD for various values of DupAddrDetectTransmits and RetransmitTimer.
-// Included in the subtests is a test to make sure that an invalid
-// RetransmitTimer (<1ms) values get fixed to the default RetransmitTimer of 1s.
-func TestDADResolve(t *testing.T) {
- const nicID = 1
-
- tests := []struct {
- name string
- dupAddrDetectTransmits uint8
- retransTimer time.Duration
- expectedRetransmitTimer time.Duration
- }{
- {"1:1s:1s", 1, time.Second, time.Second},
- {"2:1s:1s", 2, time.Second, time.Second},
- {"1:2s:2s", 1, 2 * time.Second, 2 * time.Second},
- // 0s is an invalid RetransmitTimer timer and will be fixed to
- // the default RetransmitTimer value of 1s.
- {"1:0s:1s", 1, 0, time.Second},
- }
-
- for _, test := range tests {
- test := test
-
- t.Run(test.name, func(t *testing.T) {
- t.Parallel()
-
- ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent),
- }
- opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPDisp: &ndpDisp,
- }
- opts.NDPConfigs.RetransmitTimer = test.retransTimer
- opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits
-
- e := channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- s := stack.New(opts)
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
- }
-
- // Address should not be considered bound to the NIC yet
- // (DAD ongoing).
- addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
- }
-
- // Wait for the remaining time - some delta (500ms), to
- // make sure the address is still not resolved.
- const delta = 500 * time.Millisecond
- time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - delta)
- addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
- }
-
- // Wait for DAD to resolve.
- select {
- case <-time.After(2 * delta):
- // We should get a resolution event after 500ms
- // (delta) since we wait for 500ms less than the
- // expected resolution time above to make sure
- // that the address did not yet resolve. Waiting
- // for 1s (2x delta) without a resolution event
- // means something is wrong.
- t.Fatal("timed out waiting for DAD resolution")
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
- }
- }
- addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if addr.Address != addr1 {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1)
- }
-
- // Should not have sent any more NS messages.
- if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != uint64(test.dupAddrDetectTransmits) {
- t.Fatalf("got NeighborSolicit = %d, want = %d", got, test.dupAddrDetectTransmits)
- }
-
- // Validate the sent Neighbor Solicitation messages.
- for i := uint8(0); i < test.dupAddrDetectTransmits; i++ {
- p, _ := e.ReadContext(context.Background())
-
- // Make sure its an IPv6 packet.
- if p.Proto != header.IPv6ProtocolNumber {
- t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
- }
-
- // Make sure the right remote link address is used.
- snmc := header.SolicitedNodeAddr(addr1)
- if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want {
- t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
- }
-
- // Check NDP NS packet.
- //
- // As per RFC 4861 section 4.3, a possible option is the Source Link
- // Layer option, but this option MUST NOT be included when the source
- // address of the packet is the unspecified address.
- checker.IPv6(t, p.Pkt.Header.View().ToVectorisedView().First(),
- checker.SrcAddr(header.IPv6Any),
- checker.DstAddr(snmc),
- checker.TTL(header.NDPHopLimit),
- checker.NDPNS(
- checker.NDPNSTargetAddress(addr1),
- checker.NDPNSOptions(nil),
- ))
- }
- })
- }
-}
-
-// TestDADFail tests to make sure that the DAD process fails if another node is
-// detected to be performing DAD on the same address (receive an NS message from
-// a node doing DAD for the same address), or if another node is detected to own
-// the address already (receive an NA message for the tentative address).
-func TestDADFail(t *testing.T) {
- const nicID = 1
-
- tests := []struct {
- name string
- makeBuf func(tgt tcpip.Address) buffer.Prependable
- getStat func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
- }{
- {
- "RxSolicit",
- func(tgt tcpip.Address) buffer.Prependable {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize))
- pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
- ns.SetTargetAddress(tgt)
- snmc := header.SolicitedNodeAddr(tgt)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{}))
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: 255,
- SrcAddr: header.IPv6Any,
- DstAddr: snmc,
- })
-
- return hdr
-
- },
- func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return s.NeighborSolicit
- },
- },
- {
- "RxAdvert",
- func(tgt tcpip.Address) buffer.Prependable {
- naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize)
- pkt := header.ICMPv6(hdr.Prepend(naSize))
- pkt.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(pkt.NDPPayload())
- na.SetSolicitedFlag(true)
- na.SetOverrideFlag(true)
- na.SetTargetAddress(tgt)
- na.Options().Serialize(header.NDPOptionsSerializer{
- header.NDPTargetLinkLayerAddressOption(linkAddr1),
- })
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, tgt, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: 255,
- SrcAddr: tgt,
- DstAddr: header.IPv6AllNodesMulticastAddress,
- })
-
- return hdr
-
- },
- func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return s.NeighborAdvert
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent, 1),
- }
- ndpConfigs := stack.DefaultNDPConfigurations()
- opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: ndpConfigs,
- NDPDisp: &ndpDisp,
- }
- opts.NDPConfigs.RetransmitTimer = time.Second * 2
-
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(opts)
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
- }
-
- // Address should not be considered bound to the NIC yet
- // (DAD ongoing).
- addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
- }
-
- // Receive a packet to simulate multiple nodes owning or
- // attempting to own the same address.
- hdr := test.makeBuf(addr1)
- e.InjectInbound(header.IPv6ProtocolNumber, tcpip.PacketBuffer{
- Data: hdr.View().ToVectorisedView(),
- })
-
- stat := test.getStat(s.Stats().ICMP.V6PacketsReceived)
- if got := stat.Value(); got != 1 {
- t.Fatalf("got stat = %d, want = 1", got)
- }
-
- // Wait for DAD to fail and make sure the address did
- // not get resolved.
- select {
- case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
- // If we don't get a failure event after the
- // expected resolution time + extra 1s buffer,
- // something is wrong.
- t.Fatal("timed out waiting for DAD failure")
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
- }
- }
- addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
- }
- })
- }
-}
-
-func TestDADStop(t *testing.T) {
- const nicID = 1
-
- tests := []struct {
- name string
- stopFn func(t *testing.T, s *stack.Stack)
- }{
- // Tests to make sure that DAD stops when an address is removed.
- {
- name: "Remove address",
- stopFn: func(t *testing.T, s *stack.Stack) {
- if err := s.RemoveAddress(nicID, addr1); err != nil {
- t.Fatalf("RemoveAddress(%d, %s): %s", nicID, addr1, err)
- }
- },
- },
-
- // Tests to make sure that DAD stops when the NIC is disabled.
- {
- name: "Disable NIC",
- stopFn: func(t *testing.T, s *stack.Stack) {
- if err := s.DisableNIC(nicID); err != nil {
- t.Fatalf("DisableNIC(%d): %s", nicID, err)
- }
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent, 1),
- }
- ndpConfigs := stack.NDPConfigurations{
- RetransmitTimer: time.Second,
- DupAddrDetectTransmits: 2,
- }
- opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPDisp: &ndpDisp,
- NDPConfigs: ndpConfigs,
- }
-
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(opts)
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
-
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, header.IPv6ProtocolNumber, addr1, err)
- }
-
- // Address should not be considered bound to the NIC yet (DAD ongoing).
- addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
- }
-
- test.stopFn(t, s)
-
- // Wait for DAD to fail (since the address was removed during DAD).
- select {
- case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
- // If we don't get a failure event after the expected resolution
- // time + extra 1s buffer, something is wrong.
- t.Fatal("timed out waiting for DAD failure")
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
- }
- }
- addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
- }
-
- // Should not have sent more than 1 NS message.
- if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 {
- t.Errorf("got NeighborSolicit = %d, want <= 1", got)
- }
- })
- }
-}
-
-// TestSetNDPConfigurationFailsForBadNICID tests to make sure we get an error if
-// we attempt to update NDP configurations using an invalid NICID.
-func TestSetNDPConfigurationFailsForBadNICID(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- })
-
- // No NIC with ID 1 yet.
- if got := s.SetNDPConfigurations(1, stack.NDPConfigurations{}); got != tcpip.ErrUnknownNICID {
- t.Fatalf("got s.SetNDPConfigurations = %v, want = %s", got, tcpip.ErrUnknownNICID)
- }
-}
-
-// TestSetNDPConfigurations tests that we can update and use per-interface NDP
-// configurations without affecting the default NDP configurations or other
-// interfaces' configurations.
-func TestSetNDPConfigurations(t *testing.T) {
- const nicID1 = 1
- const nicID2 = 2
- const nicID3 = 3
-
- tests := []struct {
- name string
- dupAddrDetectTransmits uint8
- retransmitTimer time.Duration
- expectedRetransmitTimer time.Duration
- }{
- {
- "OK",
- 1,
- time.Second,
- time.Second,
- },
- {
- "Invalid Retransmit Timer",
- 1,
- 0,
- time.Second,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPDisp: &ndpDisp,
- })
-
- expectDADEvent := func(nicID tcpip.NICID, addr tcpip.Address) {
- select {
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatalf("expected DAD event for %s", addr)
- }
- }
-
- // This NIC(1)'s NDP configurations will be updated to
- // be different from the default.
- if err := s.CreateNIC(nicID1, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err)
- }
-
- // Created before updating NIC(1)'s NDP configurations
- // but updating NIC(1)'s NDP configurations should not
- // affect other existing NICs.
- if err := s.CreateNIC(nicID2, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err)
- }
-
- // Update the NDP configurations on NIC(1) to use DAD.
- configs := stack.NDPConfigurations{
- DupAddrDetectTransmits: test.dupAddrDetectTransmits,
- RetransmitTimer: test.retransmitTimer,
- }
- if err := s.SetNDPConfigurations(nicID1, configs); err != nil {
- t.Fatalf("got SetNDPConfigurations(%d, _) = %s", nicID1, err)
- }
-
- // Created after updating NIC(1)'s NDP configurations
- // but the stack's default NDP configurations should not
- // have been updated.
- if err := s.CreateNIC(nicID3, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID3, err)
- }
-
- // Add addresses for each NIC.
- if err := s.AddAddress(nicID1, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID1, header.IPv6ProtocolNumber, addr1, err)
- }
- if err := s.AddAddress(nicID2, header.IPv6ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID2, header.IPv6ProtocolNumber, addr2, err)
- }
- expectDADEvent(nicID2, addr2)
- if err := s.AddAddress(nicID3, header.IPv6ProtocolNumber, addr3); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID3, header.IPv6ProtocolNumber, addr3, err)
- }
- expectDADEvent(nicID3, addr3)
-
- // Address should not be considered bound to NIC(1) yet
- // (DAD ongoing).
- addr, err := s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID1, header.IPv6ProtocolNumber, addr, want)
- }
-
- // Should get the address on NIC(2) and NIC(3)
- // immediately since we should not have performed DAD on
- // it as the stack was configured to not do DAD by
- // default and we only updated the NDP configurations on
- // NIC(1).
- addr, err = s.GetMainNICAddress(nicID2, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID2, header.IPv6ProtocolNumber, err)
- }
- if addr.Address != addr2 {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID2, header.IPv6ProtocolNumber, addr, addr2)
- }
- addr, err = s.GetMainNICAddress(nicID3, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID3, header.IPv6ProtocolNumber, err)
- }
- if addr.Address != addr3 {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID3, header.IPv6ProtocolNumber, addr, addr3)
- }
-
- // Sleep until right (500ms before) before resolution to
- // make sure the address didn't resolve on NIC(1) yet.
- const delta = 500 * time.Millisecond
- time.Sleep(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta)
- addr, err = s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID1, header.IPv6ProtocolNumber, addr, want)
- }
-
- // Wait for DAD to resolve.
- select {
- case <-time.After(2 * delta):
- // We should get a resolution event after 500ms
- // (delta) since we wait for 500ms less than the
- // expected resolution time above to make sure
- // that the address did not yet resolve. Waiting
- // for 1s (2x delta) without a resolution event
- // means something is wrong.
- t.Fatal("timed out waiting for DAD resolution")
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID1, addr1, true, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
- }
- }
- addr, err = s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err)
- }
- if addr.Address != addr1 {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID1, header.IPv6ProtocolNumber, addr, addr1)
- }
- })
- }
-}
-
-// raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options
-// and DHCPv6 configurations specified.
-func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer {
- icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length())
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
- pkt := header.ICMPv6(hdr.Prepend(icmpSize))
- pkt.SetType(header.ICMPv6RouterAdvert)
- pkt.SetCode(0)
- raPayload := pkt.NDPPayload()
- ra := header.NDPRouterAdvert(raPayload)
- // Populate the Router Lifetime.
- binary.BigEndian.PutUint16(raPayload[2:], rl)
- // Populate the Managed Address flag field.
- if managedAddress {
- // The Managed Addresses flag field is the 7th bit of byte #1 (0-indexing)
- // of the RA payload.
- raPayload[1] |= (1 << 7)
- }
- // Populate the Other Configurations flag field.
- if otherConfigurations {
- // The Other Configurations flag field is the 6th bit of byte #1
- // (0-indexing) of the RA payload.
- raPayload[1] |= (1 << 6)
- }
- opts := ra.Options()
- opts.Serialize(optSer)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, ip, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
- payloadLength := hdr.UsedLength()
- iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- iph.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(icmp.ProtocolNumber6),
- HopLimit: header.NDPHopLimit,
- SrcAddr: ip,
- DstAddr: header.IPv6AllNodesMulticastAddress,
- })
-
- return tcpip.PacketBuffer{Data: hdr.View().ToVectorisedView()}
-}
-
-// raBufWithOpts returns a valid NDP Router Advertisement with options.
-//
-// Note, raBufWithOpts does not populate any of the RA fields other than the
-// Router Lifetime.
-func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer {
- return raBufWithOptsAndDHCPv6(ip, rl, false, false, optSer)
-}
-
-// raBufWithDHCPv6 returns a valid NDP Router Advertisement with DHCPv6 related
-// fields set.
-//
-// Note, raBufWithDHCPv6 does not populate any of the RA fields other than the
-// DHCPv6 related ones.
-func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) tcpip.PacketBuffer {
- return raBufWithOptsAndDHCPv6(ip, 0, managedAddresses, otherConfiguratiosns, header.NDPOptionsSerializer{})
-}
-
-// raBuf returns a valid NDP Router Advertisement.
-//
-// Note, raBuf does not populate any of the RA fields other than the
-// Router Lifetime.
-func raBuf(ip tcpip.Address, rl uint16) tcpip.PacketBuffer {
- return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{})
-}
-
-// raBufWithPI returns a valid NDP Router Advertisement with a single Prefix
-// Information option.
-//
-// Note, raBufWithPI does not populate any of the RA fields other than the
-// Router Lifetime.
-func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) tcpip.PacketBuffer {
- flags := uint8(0)
- if onLink {
- // The OnLink flag is the 7th bit in the flags byte.
- flags |= 1 << 7
- }
- if auto {
- // The Address Auto-Configuration flag is the 6th bit in the
- // flags byte.
- flags |= 1 << 6
- }
-
- // A valid header.NDPPrefixInformation must be 30 bytes.
- buf := [30]byte{}
- // The first byte in a header.NDPPrefixInformation is the Prefix Length
- // field.
- buf[0] = uint8(prefix.PrefixLen)
- // The 2nd byte within a header.NDPPrefixInformation is the Flags field.
- buf[1] = flags
- // The Valid Lifetime field starts after the 2nd byte within a
- // header.NDPPrefixInformation.
- binary.BigEndian.PutUint32(buf[2:], vl)
- // The Preferred Lifetime field starts after the 6th byte within a
- // header.NDPPrefixInformation.
- binary.BigEndian.PutUint32(buf[6:], pl)
- // The Prefix Address field starts after the 14th byte within a
- // header.NDPPrefixInformation.
- copy(buf[14:], prefix.Address)
- return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{
- header.NDPPrefixInformation(buf[:]),
- })
-}
-
-// TestNoRouterDiscovery tests that router discovery will not be performed if
-// configured not to.
-func TestNoRouterDiscovery(t *testing.T) {
- // Being configured to discover routers means handle and
- // discover are set to true and forwarding is set to false.
- // This tests all possible combinations of the configurations,
- // except for the configuration where handle = true, discover =
- // true and forwarding = false (the required configuration to do
- // router discovery) - that will done in other tests.
- for i := 0; i < 7; i++ {
- handle := i&1 != 0
- discover := i&2 != 0
- forwarding := i&4 == 0
-
- t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverDefaultRouters(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) {
- t.Parallel()
-
- ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: handle,
- DiscoverDefaultRouters: discover,
- },
- NDPDisp: &ndpDisp,
- })
- s.SetForwarding(forwarding)
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- // Rx an RA with non-zero lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
- select {
- case <-ndpDisp.routerC:
- t.Fatal("unexpectedly discovered a router when configured not to")
- default:
- }
- })
- }
-}
-
-// Check e to make sure that the event is for addr on nic with ID 1, and the
-// discovered flag set to discovered.
-func checkRouterEvent(e ndpRouterEvent, addr tcpip.Address, discovered bool) string {
- return cmp.Diff(ndpRouterEvent{nicID: 1, addr: addr, discovered: discovered}, e, cmp.AllowUnexported(e))
-}
-
-// TestRouterDiscoveryDispatcherNoRemember tests that the stack does not
-// remember a discovered router when the dispatcher asks it not to.
-func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) {
- t.Parallel()
-
- ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- DiscoverDefaultRouters: true,
- },
- NDPDisp: &ndpDisp,
- })
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- // Receive an RA for a router we should not remember.
- const lifetimeSeconds = 1
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, lifetimeSeconds))
- select {
- case e := <-ndpDisp.routerC:
- if diff := checkRouterEvent(e, llAddr2, true); diff != "" {
- t.Errorf("router event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected router discovery event")
- }
-
- // Wait for the invalidation time plus some buffer to make sure we do
- // not actually receive any invalidation events as we should not have
- // remembered the router in the first place.
- select {
- case <-ndpDisp.routerC:
- t.Fatal("should not have received any router events")
- case <-time.After(lifetimeSeconds*time.Second + defaultTimeout):
- }
-}
-
-func TestRouterDiscovery(t *testing.T) {
- t.Parallel()
-
- ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, 1),
- rememberRouter: true,
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- DiscoverDefaultRouters: true,
- },
- NDPDisp: &ndpDisp,
- })
-
- expectRouterEvent := func(addr tcpip.Address, discovered bool) {
- t.Helper()
-
- select {
- case e := <-ndpDisp.routerC:
- if diff := checkRouterEvent(e, addr, discovered); diff != "" {
- t.Errorf("router event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected router discovery event")
- }
- }
-
- expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) {
- t.Helper()
-
- select {
- case e := <-ndpDisp.routerC:
- if diff := checkRouterEvent(e, addr, false); diff != "" {
- t.Errorf("router event mismatch (-want +got):\n%s", diff)
- }
- case <-time.After(timeout):
- t.Fatal("timed out waiting for router discovery event")
- }
- }
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- // Rx an RA from lladdr2 with zero lifetime. It should not be
- // remembered.
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
- select {
- case <-ndpDisp.routerC:
- t.Fatal("unexpectedly discovered a router with 0 lifetime")
- default:
- }
-
- // Rx an RA from lladdr2 with a huge lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
- expectRouterEvent(llAddr2, true)
-
- // Rx an RA from another router (lladdr3) with non-zero lifetime.
- const l3LifetimeSeconds = 6
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds))
- expectRouterEvent(llAddr3, true)
-
- // Rx an RA from lladdr2 with lesser lifetime.
- const l2LifetimeSeconds = 2
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds))
- select {
- case <-ndpDisp.routerC:
- t.Fatal("Should not receive a router event when updating lifetimes for known routers")
- default:
- }
-
- // Wait for lladdr2's router invalidation timer to fire. The lifetime
- // of the router should have been updated to the most recent (smaller)
- // lifetime.
- //
- // Wait for the normal lifetime plus an extra bit for the
- // router to get invalidated. If we don't get an invalidation
- // event after this time, then something is wrong.
- expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncEventTimeout)
-
- // Rx an RA from lladdr2 with huge lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
- expectRouterEvent(llAddr2, true)
-
- // Rx an RA from lladdr2 with zero lifetime. It should be invalidated.
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
- expectRouterEvent(llAddr2, false)
-
- // Wait for lladdr3's router invalidation timer to fire. The lifetime
- // of the router should have been updated to the most recent (smaller)
- // lifetime.
- //
- // Wait for the normal lifetime plus an extra bit for the
- // router to get invalidated. If we don't get an invalidation
- // event after this time, then something is wrong.
- expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncEventTimeout)
-}
-
-// TestRouterDiscoveryMaxRouters tests that only
-// stack.MaxDiscoveredDefaultRouters discovered routers are remembered.
-func TestRouterDiscoveryMaxRouters(t *testing.T) {
- t.Parallel()
-
- ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, 1),
- rememberRouter: true,
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- DiscoverDefaultRouters: true,
- },
- NDPDisp: &ndpDisp,
- })
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- // Receive an RA from 2 more than the max number of discovered routers.
- for i := 1; i <= stack.MaxDiscoveredDefaultRouters+2; i++ {
- linkAddr := []byte{2, 2, 3, 4, 5, 0}
- linkAddr[5] = byte(i)
- llAddr := header.LinkLocalAddr(tcpip.LinkAddress(linkAddr))
-
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr, 5))
-
- if i <= stack.MaxDiscoveredDefaultRouters {
- select {
- case e := <-ndpDisp.routerC:
- if diff := checkRouterEvent(e, llAddr, true); diff != "" {
- t.Errorf("router event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected router discovery event")
- }
-
- } else {
- select {
- case <-ndpDisp.routerC:
- t.Fatal("should not have discovered a new router after we already discovered the max number of routers")
- default:
- }
- }
- }
-}
-
-// TestNoPrefixDiscovery tests that prefix discovery will not be performed if
-// configured not to.
-func TestNoPrefixDiscovery(t *testing.T) {
- prefix := tcpip.AddressWithPrefix{
- Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"),
- PrefixLen: 64,
- }
-
- // Being configured to discover prefixes means handle and
- // discover are set to true and forwarding is set to false.
- // This tests all possible combinations of the configurations,
- // except for the configuration where handle = true, discover =
- // true and forwarding = false (the required configuration to do
- // prefix discovery) - that will done in other tests.
- for i := 0; i < 7; i++ {
- handle := i&1 != 0
- discover := i&2 != 0
- forwarding := i&4 == 0
-
- t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverOnLinkPrefixes(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) {
- t.Parallel()
-
- ndpDisp := ndpDispatcher{
- prefixC: make(chan ndpPrefixEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: handle,
- DiscoverOnLinkPrefixes: discover,
- },
- NDPDisp: &ndpDisp,
- })
- s.SetForwarding(forwarding)
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- // Rx an RA with prefix with non-zero lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 10, 0))
-
- select {
- case <-ndpDisp.prefixC:
- t.Fatal("unexpectedly discovered a prefix when configured not to")
- default:
- }
- })
- }
-}
-
-// Check e to make sure that the event is for prefix on nic with ID 1, and the
-// discovered flag set to discovered.
-func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) string {
- return cmp.Diff(ndpPrefixEvent{nicID: 1, prefix: prefix, discovered: discovered}, e, cmp.AllowUnexported(e))
-}
-
-// TestPrefixDiscoveryDispatcherNoRemember tests that the stack does not
-// remember a discovered on-link prefix when the dispatcher asks it not to.
-func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) {
- t.Parallel()
-
- prefix, subnet, _ := prefixSubnetAddr(0, "")
-
- ndpDisp := ndpDispatcher{
- prefixC: make(chan ndpPrefixEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- DiscoverDefaultRouters: false,
- DiscoverOnLinkPrefixes: true,
- },
- NDPDisp: &ndpDisp,
- })
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- // Receive an RA with prefix that we should not remember.
- const lifetimeSeconds = 1
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, lifetimeSeconds, 0))
- select {
- case e := <-ndpDisp.prefixC:
- if diff := checkPrefixEvent(e, subnet, true); diff != "" {
- t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected prefix discovery event")
- }
-
- // Wait for the invalidation time plus some buffer to make sure we do
- // not actually receive any invalidation events as we should not have
- // remembered the prefix in the first place.
- select {
- case <-ndpDisp.prefixC:
- t.Fatal("should not have received any prefix events")
- case <-time.After(lifetimeSeconds*time.Second + defaultTimeout):
- }
-}
-
-func TestPrefixDiscovery(t *testing.T) {
- t.Parallel()
-
- prefix1, subnet1, _ := prefixSubnetAddr(0, "")
- prefix2, subnet2, _ := prefixSubnetAddr(1, "")
- prefix3, subnet3, _ := prefixSubnetAddr(2, "")
-
- ndpDisp := ndpDispatcher{
- prefixC: make(chan ndpPrefixEvent, 1),
- rememberPrefix: true,
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- DiscoverOnLinkPrefixes: true,
- },
- NDPDisp: &ndpDisp,
- })
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) {
- t.Helper()
-
- select {
- case e := <-ndpDisp.prefixC:
- if diff := checkPrefixEvent(e, prefix, discovered); diff != "" {
- t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected prefix discovery event")
- }
- }
-
- // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
- // with zero valid lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0))
- select {
- case <-ndpDisp.prefixC:
- t.Fatal("unexpectedly discovered a prefix with 0 lifetime")
- default:
- }
-
- // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
- // with non-zero lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0))
- expectPrefixEvent(subnet1, true)
-
- // Receive an RA with prefix2 in a PI.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0))
- expectPrefixEvent(subnet2, true)
-
- // Receive an RA with prefix3 in a PI.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0))
- expectPrefixEvent(subnet3, true)
-
- // Receive an RA with prefix1 in a PI with lifetime = 0.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0))
- expectPrefixEvent(subnet1, false)
-
- // Receive an RA with prefix2 in a PI with lesser lifetime.
- lifetime := uint32(2)
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, lifetime, 0))
- select {
- case <-ndpDisp.prefixC:
- t.Fatal("unexpectedly received prefix event when updating lifetime")
- default:
- }
-
- // Wait for prefix2's most recent invalidation timer plus some buffer to
- // expire.
- select {
- case e := <-ndpDisp.prefixC:
- if diff := checkPrefixEvent(e, subnet2, false); diff != "" {
- t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
- }
- case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncEventTimeout):
- t.Fatal("timed out waiting for prefix discovery event")
- }
-
- // Receive RA to invalidate prefix3.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0))
- expectPrefixEvent(subnet3, false)
-}
-
-func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
- // Update the infinite lifetime value to a smaller value so we can test
- // that when we receive a PI with such a lifetime value, we do not
- // invalidate the prefix.
- const testInfiniteLifetimeSeconds = 2
- const testInfiniteLifetime = testInfiniteLifetimeSeconds * time.Second
- saved := header.NDPInfiniteLifetime
- header.NDPInfiniteLifetime = testInfiniteLifetime
- defer func() {
- header.NDPInfiniteLifetime = saved
- }()
-
- prefix := tcpip.AddressWithPrefix{
- Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"),
- PrefixLen: 64,
- }
- subnet := prefix.Subnet()
-
- ndpDisp := ndpDispatcher{
- prefixC: make(chan ndpPrefixEvent, 1),
- rememberPrefix: true,
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- DiscoverOnLinkPrefixes: true,
- },
- NDPDisp: &ndpDisp,
- })
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) {
- t.Helper()
-
- select {
- case e := <-ndpDisp.prefixC:
- if diff := checkPrefixEvent(e, prefix, discovered); diff != "" {
- t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected prefix discovery event")
- }
- }
-
- // Receive an RA with prefix in an NDP Prefix Information option (PI)
- // with infinite valid lifetime which should not get invalidated.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0))
- expectPrefixEvent(subnet, true)
- select {
- case <-ndpDisp.prefixC:
- t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
- case <-time.After(testInfiniteLifetime + defaultTimeout):
- }
-
- // Receive an RA with finite lifetime.
- // The prefix should get invalidated after 1s.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0))
- select {
- case e := <-ndpDisp.prefixC:
- if diff := checkPrefixEvent(e, subnet, false); diff != "" {
- t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
- }
- case <-time.After(testInfiniteLifetime):
- t.Fatal("timed out waiting for prefix discovery event")
- }
-
- // Receive an RA with finite lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0))
- expectPrefixEvent(subnet, true)
-
- // Receive an RA with prefix with an infinite lifetime.
- // The prefix should not be invalidated.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0))
- select {
- case <-ndpDisp.prefixC:
- t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
- case <-time.After(testInfiniteLifetime + defaultTimeout):
- }
-
- // Receive an RA with a prefix with a lifetime value greater than the
- // set infinite lifetime value.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds+1, 0))
- select {
- case <-ndpDisp.prefixC:
- t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
- case <-time.After((testInfiniteLifetimeSeconds+1)*time.Second + defaultTimeout):
- }
-
- // Receive an RA with 0 lifetime.
- // The prefix should get invalidated.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 0, 0))
- expectPrefixEvent(subnet, false)
-}
-
-// TestPrefixDiscoveryMaxRouters tests that only
-// stack.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered.
-func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) {
- t.Parallel()
-
- ndpDisp := ndpDispatcher{
- prefixC: make(chan ndpPrefixEvent, stack.MaxDiscoveredOnLinkPrefixes+3),
- rememberPrefix: true,
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- DiscoverDefaultRouters: false,
- DiscoverOnLinkPrefixes: true,
- },
- NDPDisp: &ndpDisp,
- })
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- optSer := make(header.NDPOptionsSerializer, stack.MaxDiscoveredOnLinkPrefixes+2)
- prefixes := [stack.MaxDiscoveredOnLinkPrefixes + 2]tcpip.Subnet{}
-
- // Receive an RA with 2 more than the max number of discovered on-link
- // prefixes.
- for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ {
- prefixAddr := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0}
- prefixAddr[7] = byte(i)
- prefix := tcpip.AddressWithPrefix{
- Address: tcpip.Address(prefixAddr[:]),
- PrefixLen: 64,
- }
- prefixes[i] = prefix.Subnet()
- buf := [30]byte{}
- buf[0] = uint8(prefix.PrefixLen)
- buf[1] = 128
- binary.BigEndian.PutUint32(buf[2:], 10)
- copy(buf[14:], prefix.Address)
-
- optSer[i] = header.NDPPrefixInformation(buf[:])
- }
-
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, optSer))
- for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ {
- if i < stack.MaxDiscoveredOnLinkPrefixes {
- select {
- case e := <-ndpDisp.prefixC:
- if diff := checkPrefixEvent(e, prefixes[i], true); diff != "" {
- t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected prefix discovery event")
- }
- } else {
- select {
- case <-ndpDisp.prefixC:
- t.Fatal("should not have discovered a new prefix after we already discovered the max number of prefixes")
- default:
- }
- }
- }
-}
-
-// Checks to see if list contains an IPv6 address, item.
-func containsV6Addr(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) bool {
- protocolAddress := tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: item,
- }
-
- for _, i := range list {
- if i == protocolAddress {
- return true
- }
- }
-
- return false
-}
-
-// TestNoAutoGenAddr tests that SLAAC is not performed when configured not to.
-func TestNoAutoGenAddr(t *testing.T) {
- prefix, _, _ := prefixSubnetAddr(0, "")
-
- // Being configured to auto-generate addresses means handle and
- // autogen are set to true and forwarding is set to false.
- // This tests all possible combinations of the configurations,
- // except for the configuration where handle = true, autogen =
- // true and forwarding = false (the required configuration to do
- // SLAAC) - that will done in other tests.
- for i := 0; i < 7; i++ {
- handle := i&1 != 0
- autogen := i&2 != 0
- forwarding := i&4 == 0
-
- t.Run(fmt.Sprintf("HandleRAs(%t), AutoGenAddr(%t), Forwarding(%t)", handle, autogen, forwarding), func(t *testing.T) {
- t.Parallel()
-
- ndpDisp := ndpDispatcher{
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: handle,
- AutoGenGlobalAddresses: autogen,
- },
- NDPDisp: &ndpDisp,
- })
- s.SetForwarding(forwarding)
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- // Rx an RA with prefix with non-zero lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, false, true, 10, 0))
-
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly auto-generated an address when configured not to")
- default:
- }
- })
- }
-}
-
-// Check e to make sure that the event is for addr on nic with ID 1, and the
-// event type is set to eventType.
-func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) string {
- return cmp.Diff(ndpAutoGenAddrEvent{nicID: 1, addr: addr, eventType: eventType}, e, cmp.AllowUnexported(e))
-}
-
-// TestAutoGenAddr tests that an address is properly generated and invalidated
-// when configured to do so.
-func TestAutoGenAddr(t *testing.T) {
- const newMinVL = 2
- newMinVLDuration := newMinVL * time.Second
- saved := stack.MinPrefixInformationValidLifetimeForUpdate
- defer func() {
- stack.MinPrefixInformationValidLifetimeForUpdate = saved
- }()
- stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
-
- prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
- prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
-
- ndpDisp := ndpDispatcher{
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- })
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
-
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
- }
- }
-
- // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
- // with zero valid lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly auto-generated an address with 0 lifetime")
- default:
- }
-
- // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
- // with non-zero lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
- expectAutoGenAddrEvent(addr1, newAddr)
- if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
- t.Fatalf("Should have %s in the list of addresses", addr1)
- }
-
- // Receive an RA with prefix2 in an NDP Prefix Information option (PI)
- // with preferred lifetime > valid lifetime
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime")
- default:
- }
-
- // Receive an RA with prefix2 in a PI.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
- expectAutoGenAddrEvent(addr2, newAddr)
- if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
- t.Fatalf("Should have %s in the list of addresses", addr1)
- }
- if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) {
- t.Fatalf("Should have %s in the list of addresses", addr2)
- }
-
- // Refresh valid lifetime for addr of prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix")
- default:
- }
-
- // Wait for addr of prefix1 to be invalidated.
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- case <-time.After(newMinVLDuration + defaultAsyncEventTimeout):
- t.Fatal("timed out waiting for addr auto gen event")
- }
- if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
- t.Fatalf("Should not have %s in the list of addresses", addr1)
- }
- if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) {
- t.Fatalf("Should have %s in the list of addresses", addr2)
- }
-}
-
-// stackAndNdpDispatcherWithDefaultRoute returns an ndpDispatcher,
-// channel.Endpoint and stack.Stack.
-//
-// stack.Stack will have a default route through the router (llAddr3) installed
-// and a static link-address (linkAddr3) added to the link address cache for the
-// router.
-func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) {
- t.Helper()
- ndpDisp := &ndpDispatcher{
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: ndpDisp,
- })
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- s.SetRouteTable([]tcpip.Route{{
- Destination: header.IPv6EmptySubnet,
- Gateway: llAddr3,
- NIC: nicID,
- }})
- s.AddLinkAddress(nicID, llAddr3, linkAddr3)
- return ndpDisp, e, s
-}
-
-// addrForNewConnectionTo returns the local address used when creating a new
-// connection to addr.
-func addrForNewConnectionTo(t *testing.T, s *stack.Stack, addr tcpip.FullAddress) tcpip.Address {
- t.Helper()
-
- wq := waiter.Queue{}
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.EventIn)
- defer wq.EventUnregister(&we)
- defer close(ch)
- ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
- }
- defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
- }
- if err := ep.Connect(addr); err != nil {
- t.Fatalf("ep.Connect(%+v): %s", addr, err)
- }
- got, err := ep.GetLocalAddress()
- if err != nil {
- t.Fatalf("ep.GetLocalAddress(): %s", err)
- }
- return got.Addr
-}
-
-// addrForNewConnection returns the local address used when creating a new
-// connection.
-func addrForNewConnection(t *testing.T, s *stack.Stack) tcpip.Address {
- t.Helper()
-
- return addrForNewConnectionTo(t, s, dstAddr)
-}
-
-// addrForNewConnectionWithAddr returns the local address used when creating a
-// new connection with a specific local address.
-func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullAddress) tcpip.Address {
- t.Helper()
-
- wq := waiter.Queue{}
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.EventIn)
- defer wq.EventUnregister(&we)
- defer close(ch)
- ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
- }
- defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
- }
- if err := ep.Bind(addr); err != nil {
- t.Fatalf("ep.Bind(%+v): %s", addr, err)
- }
- if err := ep.Connect(dstAddr); err != nil {
- t.Fatalf("ep.Connect(%+v): %s", dstAddr, err)
- }
- got, err := ep.GetLocalAddress()
- if err != nil {
- t.Fatalf("ep.GetLocalAddress(): %s", err)
- }
- return got.Addr
-}
-
-// TestAutoGenAddrDeprecateFromPI tests deprecating a SLAAC address when
-// receiving a PI with 0 preferred lifetime.
-func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
- const nicID = 1
-
- prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
- prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
-
- ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
-
- expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
-
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
- }
- }
-
- expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
- t.Helper()
-
- if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
- } else if got != addr {
- t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
- }
-
- if got := addrForNewConnection(t, s); got != addr.Address {
- t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
- }
- }
-
- // Receive PI for prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
- expectAutoGenAddrEvent(addr1, newAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should have %s in the list of addresses", addr1)
- }
- expectPrimaryAddr(addr1)
-
- // Deprecate addr for prefix1 immedaitely.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
- expectAutoGenAddrEvent(addr1, deprecatedAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should have %s in the list of addresses", addr1)
- }
- // addr should still be the primary endpoint as there are no other addresses.
- expectPrimaryAddr(addr1)
-
- // Refresh lifetimes of addr generated from prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
- expectPrimaryAddr(addr1)
-
- // Receive PI for prefix2.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
- expectAutoGenAddrEvent(addr2, newAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- expectPrimaryAddr(addr2)
-
- // Deprecate addr for prefix2 immedaitely.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
- expectAutoGenAddrEvent(addr2, deprecatedAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- // addr1 should be the primary endpoint now since addr2 is deprecated but
- // addr1 is not.
- expectPrimaryAddr(addr1)
- // addr2 is deprecated but if explicitly requested, it should be used.
- fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID}
- if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr2.Address)
- }
-
- // Another PI w/ 0 preferred lifetime should not result in a deprecation
- // event.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
- expectPrimaryAddr(addr1)
- if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr2.Address)
- }
-
- // Refresh lifetimes of addr generated from prefix2.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
- expectPrimaryAddr(addr2)
-}
-
-// TestAutoGenAddrTimerDeprecation tests that an address is properly deprecated
-// when its preferred lifetime expires.
-func TestAutoGenAddrTimerDeprecation(t *testing.T) {
- const nicID = 1
- const newMinVL = 2
- newMinVLDuration := newMinVL * time.Second
- saved := stack.MinPrefixInformationValidLifetimeForUpdate
- defer func() {
- stack.MinPrefixInformationValidLifetimeForUpdate = saved
- }()
- stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
-
- prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
- prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
-
- ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
-
- expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
-
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
- }
- }
-
- expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) {
- t.Helper()
-
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- case <-time.After(timeout):
- t.Fatal("timed out waiting for addr auto gen event")
- }
- }
-
- expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
- t.Helper()
-
- if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
- } else if got != addr {
- t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
- }
-
- if got := addrForNewConnection(t, s); got != addr.Address {
- t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
- }
- }
-
- // Receive PI for prefix2.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
- expectAutoGenAddrEvent(addr2, newAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- expectPrimaryAddr(addr2)
-
- // Receive a PI for prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90))
- expectAutoGenAddrEvent(addr1, newAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should have %s in the list of addresses", addr1)
- }
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- expectPrimaryAddr(addr1)
-
- // Refresh lifetime for addr of prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
- expectPrimaryAddr(addr1)
-
- // Wait for addr of prefix1 to be deprecated.
- expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncEventTimeout)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should not have %s in the list of addresses", addr1)
- }
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- // addr2 should be the primary endpoint now since addr1 is deprecated but
- // addr2 is not.
- expectPrimaryAddr(addr2)
- // addr1 is deprecated but if explicitly requested, it should be used.
- fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID}
- if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address)
- }
-
- // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make
- // sure we do not get a deprecation event again.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
- expectPrimaryAddr(addr2)
- if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address)
- }
-
- // Refresh lifetimes for addr of prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
- // addr1 is the primary endpoint again since it is non-deprecated now.
- expectPrimaryAddr(addr1)
-
- // Wait for addr of prefix1 to be deprecated.
- expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncEventTimeout)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should not have %s in the list of addresses", addr1)
- }
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- // addr2 should be the primary endpoint now since it is not deprecated.
- expectPrimaryAddr(addr2)
- if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address)
- }
-
- // Wait for addr of prefix1 to be invalidated.
- expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncEventTimeout)
- if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should not have %s in the list of addresses", addr1)
- }
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- expectPrimaryAddr(addr2)
-
- // Refresh both lifetimes for addr of prefix2 to the same value.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
-
- // Wait for a deprecation then invalidation events, or just an invalidation
- // event. We need to cover both cases but cannot deterministically hit both
- // cases because the deprecation and invalidation handlers could be handled in
- // either deprecation then invalidation, or invalidation then deprecation
- // (which should be cancelled by the invalidation handler).
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" {
- // If we get a deprecation event first, we should get an invalidation
- // event almost immediately after.
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- case <-time.After(defaultAsyncEventTimeout):
- t.Fatal("timed out waiting for addr auto gen event")
- }
- } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" {
- // If we get an invalidation event first, we should not get a deprecation
- // event after.
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- case <-time.After(defaultTimeout):
- }
- } else {
- t.Fatalf("got unexpected auto-generated event")
- }
-
- case <-time.After(newMinVLDuration + defaultAsyncEventTimeout):
- t.Fatal("timed out waiting for addr auto gen event")
- }
- if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should not have %s in the list of addresses", addr1)
- }
- if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should not have %s in the list of addresses", addr2)
- }
- // Should not have any primary endpoints.
- if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
- } else if want := (tcpip.AddressWithPrefix{}); got != want {
- t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, want)
- }
- wq := waiter.Queue{}
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.EventIn)
- defer wq.EventUnregister(&we)
- defer close(ch)
- ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
- }
- defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
- }
-
- if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute {
- t.Errorf("got ep.Connect(%+v) = %v, want = %s", dstAddr, err, tcpip.ErrNoRoute)
- }
-}
-
-// Tests transitioning a SLAAC address's valid lifetime between finite and
-// infinite values.
-func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) {
- const infiniteVLSeconds = 2
- const minVLSeconds = 1
- savedIL := header.NDPInfiniteLifetime
- savedMinVL := stack.MinPrefixInformationValidLifetimeForUpdate
- defer func() {
- stack.MinPrefixInformationValidLifetimeForUpdate = savedMinVL
- header.NDPInfiniteLifetime = savedIL
- }()
- stack.MinPrefixInformationValidLifetimeForUpdate = minVLSeconds * time.Second
- header.NDPInfiniteLifetime = infiniteVLSeconds * time.Second
-
- prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
-
- tests := []struct {
- name string
- infiniteVL uint32
- }{
- {
- name: "EqualToInfiniteVL",
- infiniteVL: infiniteVLSeconds,
- },
- // Our implementation supports changing header.NDPInfiniteLifetime for tests
- // such that a packet can be received where the lifetime field has a value
- // greater than header.NDPInfiniteLifetime. Because of this, we test to make
- // sure that receiving a value greater than header.NDPInfiniteLifetime is
- // handled the same as when receiving a value equal to
- // header.NDPInfiniteLifetime.
- {
- name: "MoreThanInfiniteVL",
- infiniteVL: infiniteVLSeconds + 1,
- },
- }
-
- // This Run will not return until the parallel tests finish.
- //
- // We need this because we need to do some teardown work after the
- // parallel tests complete.
- //
- // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
- // more details.
- t.Run("group", func(t *testing.T) {
- for _, test := range tests {
- test := test
-
- t.Run(test.name, func(t *testing.T) {
- t.Parallel()
-
- ndpDisp := ndpDispatcher{
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- })
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- // Receive an RA with finite prefix.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0))
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
-
- default:
- t.Fatal("expected addr auto gen event")
- }
-
- // Receive an new RA with prefix with infinite VL.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.infiniteVL, 0))
-
- // Receive a new RA with prefix with finite VL.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0))
-
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
-
- case <-time.After(minVLSeconds*time.Second + defaultAsyncEventTimeout):
- t.Fatal("timeout waiting for addr auto gen event")
- }
- })
- }
- })
-}
-
-// TestAutoGenAddrValidLifetimeUpdates tests that the valid lifetime of an
-// auto-generated address only gets updated when required to, as specified in
-// RFC 4862 section 5.5.3.e.
-func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
- const infiniteVL = 4294967295
- const newMinVL = 4
- saved := stack.MinPrefixInformationValidLifetimeForUpdate
- defer func() {
- stack.MinPrefixInformationValidLifetimeForUpdate = saved
- }()
- stack.MinPrefixInformationValidLifetimeForUpdate = newMinVL * time.Second
-
- prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
-
- tests := []struct {
- name string
- ovl uint32
- nvl uint32
- evl uint32
- }{
- // Should update the VL to the minimum VL for updating if the
- // new VL is less than newMinVL but was originally greater than
- // it.
- {
- "LargeVLToVLLessThanMinVLForUpdate",
- 9999,
- 1,
- newMinVL,
- },
- {
- "LargeVLTo0",
- 9999,
- 0,
- newMinVL,
- },
- {
- "InfiniteVLToVLLessThanMinVLForUpdate",
- infiniteVL,
- 1,
- newMinVL,
- },
- {
- "InfiniteVLTo0",
- infiniteVL,
- 0,
- newMinVL,
- },
-
- // Should not update VL if original VL was less than newMinVL
- // and the new VL is also less than newMinVL.
- {
- "ShouldNotUpdateWhenBothOldAndNewAreLessThanMinVLForUpdate",
- newMinVL - 1,
- newMinVL - 3,
- newMinVL - 1,
- },
-
- // Should take the new VL if the new VL is greater than the
- // remaining time or is greater than newMinVL.
- {
- "MorethanMinVLToLesserButStillMoreThanMinVLForUpdate",
- newMinVL + 5,
- newMinVL + 3,
- newMinVL + 3,
- },
- {
- "SmallVLToGreaterVLButStillLessThanMinVLForUpdate",
- newMinVL - 3,
- newMinVL - 1,
- newMinVL - 1,
- },
- {
- "SmallVLToGreaterVLThatIsMoreThaMinVLForUpdate",
- newMinVL - 3,
- newMinVL + 1,
- newMinVL + 1,
- },
- }
-
- const delta = 500 * time.Millisecond
-
- // This Run will not return until the parallel tests finish.
- //
- // We need this because we need to do some teardown work after the
- // parallel tests complete.
- //
- // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
- // more details.
- t.Run("group", func(t *testing.T) {
- for _, test := range tests {
- test := test
-
- t.Run(test.name, func(t *testing.T) {
- t.Parallel()
-
- ndpDisp := ndpDispatcher{
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10),
- }
- e := channel.New(10, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- })
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- // Receive an RA with prefix with initial VL,
- // test.ovl.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0))
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
- }
-
- // Receive an new RA with prefix with new VL,
- // test.nvl.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0))
-
- //
- // Validate that the VL for the address got set
- // to test.evl.
- //
-
- // Make sure we do not get any invalidation
- // events until atleast 500ms (delta) before
- // test.evl.
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly received an auto gen addr event")
- case <-time.After(time.Duration(test.evl)*time.Second - delta):
- }
-
- // Wait for another second (2x delta), but now
- // we expect the invalidation event.
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
-
- case <-time.After(2 * delta):
- t.Fatal("timeout waiting for addr auto gen event")
- }
- })
- }
- })
-}
-
-// TestAutoGenAddrRemoval tests that when auto-generated addresses are removed
-// by the user, its resources will be cleaned up and an invalidation event will
-// be sent to the integrator.
-func TestAutoGenAddrRemoval(t *testing.T) {
- t.Parallel()
-
- prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
-
- ndpDisp := ndpDispatcher{
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- })
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
-
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
- }
- }
-
- // Receive a PI to auto-generate an address.
- const lifetimeSeconds = 1
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, 0))
- expectAutoGenAddrEvent(addr, newAddr)
-
- // Removing the address should result in an invalidation event
- // immediately.
- if err := s.RemoveAddress(1, addr.Address); err != nil {
- t.Fatalf("RemoveAddress(_, %s) = %s", addr.Address, err)
- }
- expectAutoGenAddrEvent(addr, invalidatedAddr)
-
- // Wait for the original valid lifetime to make sure the original timer
- // got stopped/cleaned up.
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly received an auto gen addr event")
- case <-time.After(lifetimeSeconds*time.Second + defaultTimeout):
- }
-}
-
-// TestAutoGenAddrAfterRemoval tests adding a SLAAC address that was previously
-// assigned to the NIC but is in the permanentExpired state.
-func TestAutoGenAddrAfterRemoval(t *testing.T) {
- t.Parallel()
-
- const nicID = 1
-
- prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
- prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
- ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
-
- expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
-
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
- }
- }
-
- expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
- t.Helper()
-
- if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
- } else if got != addr {
- t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
- }
-
- if got := addrForNewConnection(t, s); got != addr.Address {
- t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
- }
- }
-
- // Receive a PI to auto-generate addr1 with a large valid and preferred
- // lifetime.
- const largeLifetimeSeconds = 999
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix1, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
- expectAutoGenAddrEvent(addr1, newAddr)
- expectPrimaryAddr(addr1)
-
- // Add addr2 as a static address.
- protoAddr2 := tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: addr2,
- }
- if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil {
- t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d, %s) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err)
- }
- // addr2 should be more preferred now since it is at the front of the primary
- // list.
- expectPrimaryAddr(addr2)
-
- // Get a route using addr2 to increment its reference count then remove it
- // to leave it in the permanentExpired state.
- r, err := s.FindRoute(nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, false)
- if err != nil {
- t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, err)
- }
- defer r.Release()
- if err := s.RemoveAddress(nicID, addr2.Address); err != nil {
- t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, addr2.Address, err)
- }
- // addr1 should be preferred again since addr2 is in the expired state.
- expectPrimaryAddr(addr1)
-
- // Receive a PI to auto-generate addr2 as valid and preferred.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
- expectAutoGenAddrEvent(addr2, newAddr)
- // addr2 should be more preferred now that it is closer to the front of the
- // primary list and not deprecated.
- expectPrimaryAddr(addr2)
-
- // Removing the address should result in an invalidation event immediately.
- // It should still be in the permanentExpired state because r is still held.
- //
- // We remove addr2 here to make sure addr2 was marked as a SLAAC address
- // (it was previously marked as a static address).
- if err := s.RemoveAddress(1, addr2.Address); err != nil {
- t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err)
- }
- expectAutoGenAddrEvent(addr2, invalidatedAddr)
- // addr1 should be more preferred since addr2 is in the expired state.
- expectPrimaryAddr(addr1)
-
- // Receive a PI to auto-generate addr2 as valid and deprecated.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, 0))
- expectAutoGenAddrEvent(addr2, newAddr)
- // addr1 should still be more preferred since addr2 is deprecated, even though
- // it is closer to the front of the primary list.
- expectPrimaryAddr(addr1)
-
- // Receive a PI to refresh addr2's preferred lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto gen addr event")
- default:
- }
- // addr2 should be more preferred now that it is not deprecated.
- expectPrimaryAddr(addr2)
-
- if err := s.RemoveAddress(1, addr2.Address); err != nil {
- t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err)
- }
- expectAutoGenAddrEvent(addr2, invalidatedAddr)
- expectPrimaryAddr(addr1)
-}
-
-// TestAutoGenAddrStaticConflict tests that if SLAAC generates an address that
-// is already assigned to the NIC, the static address remains.
-func TestAutoGenAddrStaticConflict(t *testing.T) {
- t.Parallel()
-
- prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
-
- ndpDisp := ndpDispatcher{
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- })
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- // Add the address as a static address before SLAAC tries to add it.
- if err := s.AddProtocolAddress(1, tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr.Address, err)
- }
- if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) {
- t.Fatalf("Should have %s in the list of addresses", addr1)
- }
-
- // Receive a PI where the generated address will be the same as the one
- // that we already have assigned statically.
- const lifetimeSeconds = 1
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, 0))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly received an auto gen addr event for an address we already have statically")
- default:
- }
- if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) {
- t.Fatalf("Should have %s in the list of addresses", addr1)
- }
-
- // Should not get an invalidation event after the PI's invalidation
- // time.
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly received an auto gen addr event")
- case <-time.After(lifetimeSeconds*time.Second + defaultTimeout):
- }
- if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) {
- t.Fatalf("Should have %s in the list of addresses", addr1)
- }
-}
-
-// TestAutoGenAddrWithOpaqueIID tests that SLAAC generated addresses will use
-// opaque interface identifiers when configured to do so.
-func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
- t.Parallel()
-
- const nicID = 1
- const nicName = "nic1"
- var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte
- secretKey := secretKeyBuf[:]
- n, err := rand.Read(secretKey)
- if err != nil {
- t.Fatalf("rand.Read(_): %s", err)
- }
- if n != header.OpaqueIIDSecretKeyMinBytes {
- t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes)
- }
-
- prefix1, subnet1, _ := prefixSubnetAddr(0, linkAddr1)
- prefix2, subnet2, _ := prefixSubnetAddr(1, linkAddr1)
- // addr1 and addr2 are the addresses that are expected to be generated when
- // stack.Stack is configured to generate opaque interface identifiers as
- // defined by RFC 7217.
- addrBytes := []byte(subnet1.ID())
- addr1 := tcpip.AddressWithPrefix{
- Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet1, nicName, 0, secretKey)),
- PrefixLen: 64,
- }
- addrBytes = []byte(subnet2.ID())
- addr2 := tcpip.AddressWithPrefix{
- Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet2, nicName, 0, secretKey)),
- PrefixLen: 64,
- }
-
- ndpDisp := ndpDispatcher{
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
- NICNameFromID: func(_ tcpip.NICID, nicName string) string {
- return nicName
- },
- SecretKey: secretKey,
- },
- })
- opts := stack.NICOptions{Name: nicName}
- if err := s.CreateNICWithOptions(nicID, e, opts); err != nil {
- t.Fatalf("CreateNICWithOptions(%d, _, %+v, _) = %s", nicID, opts, err)
- }
-
- expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
-
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
- }
- }
-
- // Receive an RA with prefix1 in a PI.
- const validLifetimeSecondPrefix1 = 1
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, validLifetimeSecondPrefix1, 0))
- expectAutoGenAddrEvent(addr1, newAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should have %s in the list of addresses", addr1)
- }
-
- // Receive an RA with prefix2 in a PI with a large valid lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
- expectAutoGenAddrEvent(addr2, newAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should have %s in the list of addresses", addr1)
- }
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
-
- // Wait for addr of prefix1 to be invalidated.
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultAsyncEventTimeout):
- t.Fatal("timed out waiting for addr auto gen event")
- }
- if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should not have %s in the list of addresses", addr1)
- }
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
-}
-
-// TestNDPRecursiveDNSServerDispatch tests that we properly dispatch an event
-// to the integrator when an RA is received with the NDP Recursive DNS Server
-// option with at least one valid address.
-func TestNDPRecursiveDNSServerDispatch(t *testing.T) {
- t.Parallel()
-
- tests := []struct {
- name string
- opt header.NDPRecursiveDNSServer
- expected *ndpRDNSS
- }{
- {
- "Unspecified",
- header.NDPRecursiveDNSServer([]byte{
- 0, 0,
- 0, 0, 0, 2,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- }),
- nil,
- },
- {
- "Multicast",
- header.NDPRecursiveDNSServer([]byte{
- 0, 0,
- 0, 0, 0, 2,
- 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
- }),
- nil,
- },
- {
- "OptionTooSmall",
- header.NDPRecursiveDNSServer([]byte{
- 0, 0,
- 0, 0, 0, 2,
- 1, 2, 3, 4, 5, 6, 7, 8,
- }),
- nil,
- },
- {
- "0Addresses",
- header.NDPRecursiveDNSServer([]byte{
- 0, 0,
- 0, 0, 0, 2,
- }),
- nil,
- },
- {
- "Valid1Address",
- header.NDPRecursiveDNSServer([]byte{
- 0, 0,
- 0, 0, 0, 2,
- 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1,
- }),
- &ndpRDNSS{
- []tcpip.Address{
- "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01",
- },
- 2 * time.Second,
- },
- },
- {
- "Valid2Addresses",
- header.NDPRecursiveDNSServer([]byte{
- 0, 0,
- 0, 0, 0, 1,
- 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1,
- 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 2,
- }),
- &ndpRDNSS{
- []tcpip.Address{
- "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01",
- "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x02",
- },
- time.Second,
- },
- },
- {
- "Valid3Addresses",
- header.NDPRecursiveDNSServer([]byte{
- 0, 0,
- 0, 0, 0, 0,
- 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1,
- 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 2,
- 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 3,
- }),
- &ndpRDNSS{
- []tcpip.Address{
- "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01",
- "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x02",
- "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x03",
- },
- 0,
- },
- },
- }
-
- for _, test := range tests {
- test := test
-
- t.Run(test.name, func(t *testing.T) {
- t.Parallel()
-
- ndpDisp := ndpDispatcher{
- // We do not expect more than a single RDNSS
- // event at any time for this test.
- rdnssC: make(chan ndpRDNSSEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- },
- NDPDisp: &ndpDisp,
- })
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, header.NDPOptionsSerializer{test.opt}))
-
- if test.expected != nil {
- select {
- case e := <-ndpDisp.rdnssC:
- if e.nicID != 1 {
- t.Errorf("got rdnss nicID = %d, want = 1", e.nicID)
- }
- if diff := cmp.Diff(e.rdnss.addrs, test.expected.addrs); diff != "" {
- t.Errorf("rdnss addrs mismatch (-want +got):\n%s", diff)
- }
- if e.rdnss.lifetime != test.expected.lifetime {
- t.Errorf("got rdnss lifetime = %s, want = %s", e.rdnss.lifetime, test.expected.lifetime)
- }
- default:
- t.Fatal("expected an RDNSS option event")
- }
- }
-
- // Should have no more RDNSS options.
- select {
- case e := <-ndpDisp.rdnssC:
- t.Fatalf("unexpectedly got a new RDNSS option event: %+v", e)
- default:
- }
- })
- }
-}
-
-// TestCleanupNDPState tests that all discovered routers and prefixes, and
-// auto-generated addresses are invalidated when a NIC becomes a router.
-func TestCleanupNDPState(t *testing.T) {
- t.Parallel()
-
- const (
- lifetimeSeconds = 5
- maxRouterAndPrefixEvents = 4
- nicID1 = 1
- nicID2 = 2
- )
-
- prefix1, subnet1, e1Addr1 := prefixSubnetAddr(0, linkAddr1)
- prefix2, subnet2, e1Addr2 := prefixSubnetAddr(1, linkAddr1)
- e2Addr1 := addrForSubnet(subnet1, linkAddr2)
- e2Addr2 := addrForSubnet(subnet2, linkAddr2)
- llAddrWithPrefix1 := tcpip.AddressWithPrefix{
- Address: llAddr1,
- PrefixLen: 64,
- }
- llAddrWithPrefix2 := tcpip.AddressWithPrefix{
- Address: llAddr2,
- PrefixLen: 64,
- }
-
- tests := []struct {
- name string
- cleanupFn func(t *testing.T, s *stack.Stack)
- keepAutoGenLinkLocal bool
- maxAutoGenAddrEvents int
- }{
- // A NIC should still keep its auto-generated link-local address when
- // becoming a router.
- {
- name: "Forwarding Enable",
- cleanupFn: func(t *testing.T, s *stack.Stack) {
- t.Helper()
- s.SetForwarding(true)
- },
- keepAutoGenLinkLocal: true,
- maxAutoGenAddrEvents: 4,
- },
-
- // A NIC should cleanup all NDP state when it is disabled.
- {
- name: "NIC Disable",
- cleanupFn: func(t *testing.T, s *stack.Stack) {
- t.Helper()
-
- if err := s.DisableNIC(nicID1); err != nil {
- t.Fatalf("s.DisableNIC(%d): %s", nicID1, err)
- }
- if err := s.DisableNIC(nicID2); err != nil {
- t.Fatalf("s.DisableNIC(%d): %s", nicID2, err)
- }
- },
- keepAutoGenLinkLocal: false,
- maxAutoGenAddrEvents: 6,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, maxRouterAndPrefixEvents),
- rememberRouter: true,
- prefixC: make(chan ndpPrefixEvent, maxRouterAndPrefixEvents),
- rememberPrefix: true,
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents),
- }
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- AutoGenIPv6LinkLocal: true,
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- DiscoverDefaultRouters: true,
- DiscoverOnLinkPrefixes: true,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- })
-
- expectRouterEvent := func() (bool, ndpRouterEvent) {
- select {
- case e := <-ndpDisp.routerC:
- return true, e
- default:
- }
-
- return false, ndpRouterEvent{}
- }
-
- expectPrefixEvent := func() (bool, ndpPrefixEvent) {
- select {
- case e := <-ndpDisp.prefixC:
- return true, e
- default:
- }
-
- return false, ndpPrefixEvent{}
- }
-
- expectAutoGenAddrEvent := func() (bool, ndpAutoGenAddrEvent) {
- select {
- case e := <-ndpDisp.autoGenAddrC:
- return true, e
- default:
- }
-
- return false, ndpAutoGenAddrEvent{}
- }
-
- e1 := channel.New(0, 1280, linkAddr1)
- if err := s.CreateNIC(nicID1, e1); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err)
- }
- // We have other tests that make sure we receive the *correct* events
- // on normal discovery of routers/prefixes, and auto-generated
- // addresses. Here we just make sure we get an event and let other tests
- // handle the correctness check.
- expectAutoGenAddrEvent()
-
- e2 := channel.New(0, 1280, linkAddr2)
- if err := s.CreateNIC(nicID2, e2); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err)
- }
- expectAutoGenAddrEvent()
-
- // Receive RAs on NIC(1) and NIC(2) from default routers (llAddr3 and
- // llAddr4) w/ PI (for prefix1 in RA from llAddr3 and prefix2 in RA from
- // llAddr4) to discover multiple routers and prefixes, and auto-gen
- // multiple addresses.
-
- e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID1)
- }
- if ok, _ := expectPrefixEvent(); !ok {
- t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1)
- }
- if ok, _ := expectAutoGenAddrEvent(); !ok {
- t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr1, nicID1)
- }
-
- e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID1)
- }
- if ok, _ := expectPrefixEvent(); !ok {
- t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1)
- }
- if ok, _ := expectAutoGenAddrEvent(); !ok {
- t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID1)
- }
-
- e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID2)
- }
- if ok, _ := expectPrefixEvent(); !ok {
- t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2)
- }
- if ok, _ := expectAutoGenAddrEvent(); !ok {
- t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID2)
- }
-
- e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID2)
- }
- if ok, _ := expectPrefixEvent(); !ok {
- t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2)
- }
- if ok, _ := expectAutoGenAddrEvent(); !ok {
- t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e2Addr2, nicID2)
- }
-
- // We should have the auto-generated addresses added.
- nicinfo := s.NICInfo()
- nic1Addrs := nicinfo[nicID1].ProtocolAddresses
- nic2Addrs := nicinfo[nicID2].ProtocolAddresses
- if !containsV6Addr(nic1Addrs, llAddrWithPrefix1) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
- }
- if !containsV6Addr(nic1Addrs, e1Addr1) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
- }
- if !containsV6Addr(nic1Addrs, e1Addr2) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
- }
- if !containsV6Addr(nic2Addrs, llAddrWithPrefix2) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
- }
- if !containsV6Addr(nic2Addrs, e2Addr1) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
- }
- if !containsV6Addr(nic2Addrs, e2Addr2) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
- }
-
- // We can't proceed any further if we already failed the test (missing
- // some discovery/auto-generated address events or addresses).
- if t.Failed() {
- t.FailNow()
- }
-
- test.cleanupFn(t, s)
-
- // Collect invalidation events after having NDP state cleaned up.
- gotRouterEvents := make(map[ndpRouterEvent]int)
- for i := 0; i < maxRouterAndPrefixEvents; i++ {
- ok, e := expectRouterEvent()
- if !ok {
- t.Errorf("expected %d router events after becoming a router; got = %d", maxRouterAndPrefixEvents, i)
- break
- }
- gotRouterEvents[e]++
- }
- gotPrefixEvents := make(map[ndpPrefixEvent]int)
- for i := 0; i < maxRouterAndPrefixEvents; i++ {
- ok, e := expectPrefixEvent()
- if !ok {
- t.Errorf("expected %d prefix events after becoming a router; got = %d", maxRouterAndPrefixEvents, i)
- break
- }
- gotPrefixEvents[e]++
- }
- gotAutoGenAddrEvents := make(map[ndpAutoGenAddrEvent]int)
- for i := 0; i < test.maxAutoGenAddrEvents; i++ {
- ok, e := expectAutoGenAddrEvent()
- if !ok {
- t.Errorf("expected %d auto-generated address events after becoming a router; got = %d", test.maxAutoGenAddrEvents, i)
- break
- }
- gotAutoGenAddrEvents[e]++
- }
-
- // No need to proceed any further if we already failed the test (missing
- // some invalidation events).
- if t.Failed() {
- t.FailNow()
- }
-
- expectedRouterEvents := map[ndpRouterEvent]int{
- {nicID: nicID1, addr: llAddr3, discovered: false}: 1,
- {nicID: nicID1, addr: llAddr4, discovered: false}: 1,
- {nicID: nicID2, addr: llAddr3, discovered: false}: 1,
- {nicID: nicID2, addr: llAddr4, discovered: false}: 1,
- }
- if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" {
- t.Errorf("router events mismatch (-want +got):\n%s", diff)
- }
- expectedPrefixEvents := map[ndpPrefixEvent]int{
- {nicID: nicID1, prefix: subnet1, discovered: false}: 1,
- {nicID: nicID1, prefix: subnet2, discovered: false}: 1,
- {nicID: nicID2, prefix: subnet1, discovered: false}: 1,
- {nicID: nicID2, prefix: subnet2, discovered: false}: 1,
- }
- if diff := cmp.Diff(expectedPrefixEvents, gotPrefixEvents); diff != "" {
- t.Errorf("prefix events mismatch (-want +got):\n%s", diff)
- }
- expectedAutoGenAddrEvents := map[ndpAutoGenAddrEvent]int{
- {nicID: nicID1, addr: e1Addr1, eventType: invalidatedAddr}: 1,
- {nicID: nicID1, addr: e1Addr2, eventType: invalidatedAddr}: 1,
- {nicID: nicID2, addr: e2Addr1, eventType: invalidatedAddr}: 1,
- {nicID: nicID2, addr: e2Addr2, eventType: invalidatedAddr}: 1,
- }
-
- if !test.keepAutoGenLinkLocal {
- expectedAutoGenAddrEvents[ndpAutoGenAddrEvent{nicID: nicID1, addr: llAddrWithPrefix1, eventType: invalidatedAddr}] = 1
- expectedAutoGenAddrEvents[ndpAutoGenAddrEvent{nicID: nicID2, addr: llAddrWithPrefix2, eventType: invalidatedAddr}] = 1
- }
-
- if diff := cmp.Diff(expectedAutoGenAddrEvents, gotAutoGenAddrEvents); diff != "" {
- t.Errorf("auto-generated address events mismatch (-want +got):\n%s", diff)
- }
-
- // Make sure the auto-generated addresses got removed.
- nicinfo = s.NICInfo()
- nic1Addrs = nicinfo[nicID1].ProtocolAddresses
- nic2Addrs = nicinfo[nicID2].ProtocolAddresses
- if containsV6Addr(nic1Addrs, llAddrWithPrefix1) != test.keepAutoGenLinkLocal {
- if test.keepAutoGenLinkLocal {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
- } else {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
- }
- }
- if containsV6Addr(nic1Addrs, e1Addr1) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
- }
- if containsV6Addr(nic1Addrs, e1Addr2) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
- }
- if containsV6Addr(nic2Addrs, llAddrWithPrefix2) != test.keepAutoGenLinkLocal {
- if test.keepAutoGenLinkLocal {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
- } else {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
- }
- }
- if containsV6Addr(nic2Addrs, e2Addr1) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
- }
- if containsV6Addr(nic2Addrs, e2Addr2) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
- }
-
- // Should not get any more events (invalidation timers should have been
- // cancelled when the NDP state was cleaned up).
- time.Sleep(lifetimeSeconds*time.Second + defaultTimeout)
- select {
- case <-ndpDisp.routerC:
- t.Error("unexpected router event")
- default:
- }
- select {
- case <-ndpDisp.prefixC:
- t.Error("unexpected prefix event")
- default:
- }
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Error("unexpected auto-generated address event")
- default:
- }
- })
- }
-}
-
-// TestDHCPv6ConfigurationFromNDPDA tests that the NDPDispatcher is properly
-// informed when new information about what configurations are available via
-// DHCPv6 is learned.
-func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
- const nicID = 1
-
- ndpDisp := ndpDispatcher{
- dhcpv6ConfigurationC: make(chan ndpDHCPv6Event, 1),
- rememberRouter: true,
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- },
- NDPDisp: &ndpDisp,
- })
-
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- expectDHCPv6Event := func(configuration stack.DHCPv6ConfigurationFromNDPRA) {
- t.Helper()
- select {
- case e := <-ndpDisp.dhcpv6ConfigurationC:
- if diff := cmp.Diff(ndpDHCPv6Event{nicID: nicID, configuration: configuration}, e, cmp.AllowUnexported(e)); diff != "" {
- t.Errorf("dhcpv6 event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected DHCPv6 configuration event")
- }
- }
-
- expectNoDHCPv6Event := func() {
- t.Helper()
- select {
- case <-ndpDisp.dhcpv6ConfigurationC:
- t.Fatal("unexpected DHCPv6 configuration event")
- default:
- }
- }
-
- // The initial DHCPv6 configuration should be stack.DHCPv6NoConfiguration.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false))
- expectNoDHCPv6Event()
-
- // Receive an RA that updates the DHCPv6 configuration to Other
- // Configurations.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
- expectDHCPv6Event(stack.DHCPv6OtherConfigurations)
- // Receiving the same update again should not result in an event to the
- // NDPDispatcher.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
- expectNoDHCPv6Event()
-
- // Receive an RA that updates the DHCPv6 configuration to Managed Address.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false))
- expectDHCPv6Event(stack.DHCPv6ManagedAddress)
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false))
- expectNoDHCPv6Event()
-
- // Receive an RA that updates the DHCPv6 configuration to none.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false))
- expectDHCPv6Event(stack.DHCPv6NoConfiguration)
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false))
- expectNoDHCPv6Event()
-
- // Receive an RA that updates the DHCPv6 configuration to Managed Address.
- //
- // Note, when the M flag is set, the O flag is redundant.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true))
- expectDHCPv6Event(stack.DHCPv6ManagedAddress)
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true))
- expectNoDHCPv6Event()
- // Even though the DHCPv6 flags are different, the effective configuration is
- // the same so we should not receive a new event.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false))
- expectNoDHCPv6Event()
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true))
- expectNoDHCPv6Event()
-
- // Receive an RA that updates the DHCPv6 configuration to Other
- // Configurations.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
- expectDHCPv6Event(stack.DHCPv6OtherConfigurations)
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
- expectNoDHCPv6Event()
-}
-
-// TestRouterSolicitation tests the initial Router Solicitations that are sent
-// when a NIC newly becomes enabled.
-func TestRouterSolicitation(t *testing.T) {
- t.Parallel()
-
- tests := []struct {
- name string
- maxRtrSolicit uint8
- rtrSolicitInt time.Duration
- effectiveRtrSolicitInt time.Duration
- maxRtrSolicitDelay time.Duration
- effectiveMaxRtrSolicitDelay time.Duration
- }{
- {
- name: "Single RS with delay",
- maxRtrSolicit: 1,
- rtrSolicitInt: time.Second,
- effectiveRtrSolicitInt: time.Second,
- maxRtrSolicitDelay: time.Second,
- effectiveMaxRtrSolicitDelay: time.Second,
- },
- {
- name: "Two RS with delay",
- maxRtrSolicit: 2,
- rtrSolicitInt: time.Second,
- effectiveRtrSolicitInt: time.Second,
- maxRtrSolicitDelay: 500 * time.Millisecond,
- effectiveMaxRtrSolicitDelay: 500 * time.Millisecond,
- },
- {
- name: "Single RS without delay",
- maxRtrSolicit: 1,
- rtrSolicitInt: time.Second,
- effectiveRtrSolicitInt: time.Second,
- maxRtrSolicitDelay: 0,
- effectiveMaxRtrSolicitDelay: 0,
- },
- {
- name: "Two RS without delay and invalid zero interval",
- maxRtrSolicit: 2,
- rtrSolicitInt: 0,
- effectiveRtrSolicitInt: 4 * time.Second,
- maxRtrSolicitDelay: 0,
- effectiveMaxRtrSolicitDelay: 0,
- },
- {
- name: "Three RS without delay",
- maxRtrSolicit: 3,
- rtrSolicitInt: 500 * time.Millisecond,
- effectiveRtrSolicitInt: 500 * time.Millisecond,
- maxRtrSolicitDelay: 0,
- effectiveMaxRtrSolicitDelay: 0,
- },
- {
- name: "Two RS with invalid negative delay",
- maxRtrSolicit: 2,
- rtrSolicitInt: time.Second,
- effectiveRtrSolicitInt: time.Second,
- maxRtrSolicitDelay: -3 * time.Second,
- effectiveMaxRtrSolicitDelay: time.Second,
- },
- }
-
- // This Run will not return until the parallel tests finish.
- //
- // We need this because we need to do some teardown work after the
- // parallel tests complete.
- //
- // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
- // more details.
- t.Run("group", func(t *testing.T) {
- for _, test := range tests {
- test := test
-
- t.Run(test.name, func(t *testing.T) {
- t.Parallel()
- e := channel.New(int(test.maxRtrSolicit), 1280, linkAddr1)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- waitForPkt := func(timeout time.Duration) {
- t.Helper()
- ctx, _ := context.WithTimeout(context.Background(), timeout)
- p, ok := e.ReadContext(ctx)
- if !ok {
- t.Fatal("timed out waiting for packet")
- return
- }
-
- if p.Proto != header.IPv6ProtocolNumber {
- t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
- }
-
- // Make sure the right remote link address is used.
- if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want {
- t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
- }
-
- checker.IPv6(t,
- p.Pkt.Header.View(),
- checker.SrcAddr(header.IPv6Any),
- checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
- checker.TTL(header.NDPHopLimit),
- checker.NDPRS(),
- )
- }
- waitForNothing := func(timeout time.Duration) {
- t.Helper()
- ctx, _ := context.WithTimeout(context.Background(), timeout)
- if _, ok := e.ReadContext(ctx); ok {
- t.Fatal("unexpectedly got a packet")
- }
- }
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- MaxRtrSolicitations: test.maxRtrSolicit,
- RtrSolicitationInterval: test.rtrSolicitInt,
- MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
- },
- })
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- // Make sure each RS got sent at the right
- // times.
- remaining := test.maxRtrSolicit
- if remaining > 0 {
- waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncEventTimeout)
- remaining--
- }
- for ; remaining > 0; remaining-- {
- waitForNothing(test.effectiveRtrSolicitInt - defaultTimeout)
- waitForPkt(defaultAsyncEventTimeout)
- }
-
- // Make sure no more RS.
- if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay {
- waitForNothing(test.effectiveRtrSolicitInt + defaultTimeout)
- } else {
- waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultTimeout)
- }
-
- // Make sure the counter got properly
- // incremented.
- if got, want := s.Stats().ICMP.V6PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want {
- t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want)
- }
- })
- }
- })
-}
-
-func TestStopStartSolicitingRouters(t *testing.T) {
- t.Parallel()
-
- const nicID = 1
- const interval = 500 * time.Millisecond
- const delay = time.Second
- const maxRtrSolicitations = 3
-
- tests := []struct {
- name string
- startFn func(t *testing.T, s *stack.Stack)
- stopFn func(t *testing.T, s *stack.Stack)
- }{
- // Tests that when forwarding is enabled or disabled, router solicitations
- // are stopped or started, respectively.
- {
- name: "Forwarding enabled and disabled",
- startFn: func(t *testing.T, s *stack.Stack) {
- t.Helper()
- s.SetForwarding(false)
- },
- stopFn: func(t *testing.T, s *stack.Stack) {
- t.Helper()
- s.SetForwarding(true)
- },
- },
-
- // Tests that when a NIC is enabled or disabled, router solicitations
- // are started or stopped, respectively.
- {
- name: "NIC disabled and enabled",
- startFn: func(t *testing.T, s *stack.Stack) {
- t.Helper()
-
- if err := s.EnableNIC(nicID); err != nil {
- t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
- }
- },
- stopFn: func(t *testing.T, s *stack.Stack) {
- t.Helper()
-
- if err := s.DisableNIC(nicID); err != nil {
- t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
- }
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- e := channel.New(maxRtrSolicitations, 1280, linkAddr1)
- waitForPkt := func(timeout time.Duration) {
- t.Helper()
-
- ctx, cancel := context.WithTimeout(context.Background(), timeout)
- defer cancel()
- p, ok := e.ReadContext(ctx)
- if !ok {
- t.Fatal("timed out waiting for packet")
- return
- }
-
- if p.Proto != header.IPv6ProtocolNumber {
- t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
- }
- checker.IPv6(t, p.Pkt.Header.View(),
- checker.SrcAddr(header.IPv6Any),
- checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
- checker.TTL(header.NDPHopLimit),
- checker.NDPRS())
- }
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- MaxRtrSolicitations: maxRtrSolicitations,
- RtrSolicitationInterval: interval,
- MaxRtrSolicitationDelay: delay,
- },
- })
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- // Stop soliciting routers.
- test.stopFn(t, s)
- ctx, cancel := context.WithTimeout(context.Background(), delay+defaultTimeout)
- defer cancel()
- if _, ok := e.ReadContext(ctx); ok {
- // A single RS may have been sent before forwarding was enabled.
- ctx, cancel := context.WithTimeout(context.Background(), interval+defaultTimeout)
- defer cancel()
- if _, ok = e.ReadContext(ctx); ok {
- t.Fatal("should not have sent more than one RS message")
- }
- }
-
- // Stopping router solicitations after it has already been stopped should
- // do nothing.
- test.stopFn(t, s)
- ctx, cancel = context.WithTimeout(context.Background(), delay+defaultTimeout)
- defer cancel()
- if _, ok := e.ReadContext(ctx); ok {
- t.Fatal("unexpectedly got a packet after router solicitation has been stopepd")
- }
-
- // Start soliciting routers.
- test.startFn(t, s)
- waitForPkt(delay + defaultAsyncEventTimeout)
- waitForPkt(interval + defaultAsyncEventTimeout)
- waitForPkt(interval + defaultAsyncEventTimeout)
- ctx, cancel = context.WithTimeout(context.Background(), interval+defaultTimeout)
- defer cancel()
- if _, ok := e.ReadContext(ctx); ok {
- t.Fatal("unexpectedly got an extra packet after sending out the expected RSs")
- }
-
- // Starting router solicitations after it has already completed should do
- // nothing.
- test.startFn(t, s)
- ctx, cancel = context.WithTimeout(context.Background(), delay+defaultTimeout)
- defer cancel()
- if _, ok := e.ReadContext(ctx); ok {
- t.Fatal("unexpectedly got a packet after finishing router solicitations")
- }
- })
- }
-}
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
deleted file mode 100644
index edaee3b86..000000000
--- a/pkg/tcpip/stack/nic_test.go
+++ /dev/null
@@ -1,62 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package stack
-
-import (
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
-)
-
-func TestDisabledRxStatsWhenNICDisabled(t *testing.T) {
- // When the NIC is disabled, the only field that matters is the stats field.
- // This test is limited to stats counter checks.
- nic := NIC{
- stats: makeNICStats(),
- }
-
- if got := nic.stats.DisabledRx.Packets.Value(); got != 0 {
- t.Errorf("got DisabledRx.Packets = %d, want = 0", got)
- }
- if got := nic.stats.DisabledRx.Bytes.Value(); got != 0 {
- t.Errorf("got DisabledRx.Bytes = %d, want = 0", got)
- }
- if got := nic.stats.Rx.Packets.Value(); got != 0 {
- t.Errorf("got Rx.Packets = %d, want = 0", got)
- }
- if got := nic.stats.Rx.Bytes.Value(); got != 0 {
- t.Errorf("got Rx.Bytes = %d, want = 0", got)
- }
-
- if t.Failed() {
- t.FailNow()
- }
-
- nic.DeliverNetworkPacket(nil, "", "", 0, tcpip.PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()})
-
- if got := nic.stats.DisabledRx.Packets.Value(); got != 1 {
- t.Errorf("got DisabledRx.Packets = %d, want = 1", got)
- }
- if got := nic.stats.DisabledRx.Bytes.Value(); got != 4 {
- t.Errorf("got DisabledRx.Bytes = %d, want = 4", got)
- }
- if got := nic.stats.Rx.Packets.Value(); got != 0 {
- t.Errorf("got Rx.Packets = %d, want = 0", got)
- }
- if got := nic.stats.Rx.Bytes.Value(); got != 0 {
- t.Errorf("got Rx.Bytes = %d, want = 0", got)
- }
-}
diff --git a/pkg/tcpip/stack/stack_state_autogen.go b/pkg/tcpip/stack/stack_state_autogen.go
new file mode 100755
index 000000000..f66a8dd75
--- /dev/null
+++ b/pkg/tcpip/stack/stack_state_autogen.go
@@ -0,0 +1,131 @@
+// automatically generated by stateify.
+
+package stack
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+func (x *linkAddrEntryList) beforeSave() {}
+func (x *linkAddrEntryList) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *linkAddrEntryList) afterLoad() {}
+func (x *linkAddrEntryList) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *linkAddrEntryEntry) beforeSave() {}
+func (x *linkAddrEntryEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *linkAddrEntryEntry) afterLoad() {}
+func (x *linkAddrEntryEntry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
+func (x *TransportEndpointID) beforeSave() {}
+func (x *TransportEndpointID) save(m state.Map) {
+ x.beforeSave()
+ m.Save("LocalPort", &x.LocalPort)
+ m.Save("LocalAddress", &x.LocalAddress)
+ m.Save("RemotePort", &x.RemotePort)
+ m.Save("RemoteAddress", &x.RemoteAddress)
+}
+
+func (x *TransportEndpointID) afterLoad() {}
+func (x *TransportEndpointID) load(m state.Map) {
+ m.Load("LocalPort", &x.LocalPort)
+ m.Load("LocalAddress", &x.LocalAddress)
+ m.Load("RemotePort", &x.RemotePort)
+ m.Load("RemoteAddress", &x.RemoteAddress)
+}
+
+func (x *GSOType) save(m state.Map) {
+ m.SaveValue("", (int)(*x))
+}
+
+func (x *GSOType) load(m state.Map) {
+ m.LoadValue("", new(int), func(y interface{}) { *x = (GSOType)(y.(int)) })
+}
+
+func (x *GSO) beforeSave() {}
+func (x *GSO) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Type", &x.Type)
+ m.Save("NeedsCsum", &x.NeedsCsum)
+ m.Save("CsumOffset", &x.CsumOffset)
+ m.Save("MSS", &x.MSS)
+ m.Save("L3HdrLen", &x.L3HdrLen)
+ m.Save("MaxSize", &x.MaxSize)
+}
+
+func (x *GSO) afterLoad() {}
+func (x *GSO) load(m state.Map) {
+ m.Load("Type", &x.Type)
+ m.Load("NeedsCsum", &x.NeedsCsum)
+ m.Load("CsumOffset", &x.CsumOffset)
+ m.Load("MSS", &x.MSS)
+ m.Load("L3HdrLen", &x.L3HdrLen)
+ m.Load("MaxSize", &x.MaxSize)
+}
+
+func (x *TransportEndpointInfo) beforeSave() {}
+func (x *TransportEndpointInfo) save(m state.Map) {
+ x.beforeSave()
+ m.Save("NetProto", &x.NetProto)
+ m.Save("TransProto", &x.TransProto)
+ m.Save("ID", &x.ID)
+ m.Save("BindNICID", &x.BindNICID)
+ m.Save("BindAddr", &x.BindAddr)
+ m.Save("RegisterNICID", &x.RegisterNICID)
+}
+
+func (x *TransportEndpointInfo) afterLoad() {}
+func (x *TransportEndpointInfo) load(m state.Map) {
+ m.Load("NetProto", &x.NetProto)
+ m.Load("TransProto", &x.TransProto)
+ m.Load("ID", &x.ID)
+ m.Load("BindNICID", &x.BindNICID)
+ m.Load("BindAddr", &x.BindAddr)
+ m.Load("RegisterNICID", &x.RegisterNICID)
+}
+
+func (x *multiPortEndpoint) beforeSave() {}
+func (x *multiPortEndpoint) save(m state.Map) {
+ x.beforeSave()
+ m.Save("demux", &x.demux)
+ m.Save("netProto", &x.netProto)
+ m.Save("transProto", &x.transProto)
+ m.Save("endpointsArr", &x.endpointsArr)
+ m.Save("endpointsMap", &x.endpointsMap)
+ m.Save("reuse", &x.reuse)
+}
+
+func (x *multiPortEndpoint) afterLoad() {}
+func (x *multiPortEndpoint) load(m state.Map) {
+ m.Load("demux", &x.demux)
+ m.Load("netProto", &x.netProto)
+ m.Load("transProto", &x.transProto)
+ m.Load("endpointsArr", &x.endpointsArr)
+ m.Load("endpointsMap", &x.endpointsMap)
+ m.Load("reuse", &x.reuse)
+}
+
+func init() {
+ state.Register("pkg/tcpip/stack.linkAddrEntryList", (*linkAddrEntryList)(nil), state.Fns{Save: (*linkAddrEntryList).save, Load: (*linkAddrEntryList).load})
+ state.Register("pkg/tcpip/stack.linkAddrEntryEntry", (*linkAddrEntryEntry)(nil), state.Fns{Save: (*linkAddrEntryEntry).save, Load: (*linkAddrEntryEntry).load})
+ state.Register("pkg/tcpip/stack.TransportEndpointID", (*TransportEndpointID)(nil), state.Fns{Save: (*TransportEndpointID).save, Load: (*TransportEndpointID).load})
+ state.Register("pkg/tcpip/stack.GSOType", (*GSOType)(nil), state.Fns{Save: (*GSOType).save, Load: (*GSOType).load})
+ state.Register("pkg/tcpip/stack.GSO", (*GSO)(nil), state.Fns{Save: (*GSO).save, Load: (*GSO).load})
+ state.Register("pkg/tcpip/stack.TransportEndpointInfo", (*TransportEndpointInfo)(nil), state.Fns{Save: (*TransportEndpointInfo).save, Load: (*TransportEndpointInfo).load})
+ state.Register("pkg/tcpip/stack.multiPortEndpoint", (*multiPortEndpoint)(nil), state.Fns{Save: (*multiPortEndpoint).save, Load: (*multiPortEndpoint).load})
+}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
deleted file mode 100644
index eb6f7d1fc..000000000
--- a/pkg/tcpip/stack/stack_test.go
+++ /dev/null
@@ -1,3089 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package stack_test contains tests for the stack. It is in its own package so
-// that the tests can also validate that all definitions needed to implement
-// transport and network protocols are properly exported by the stack package.
-package stack_test
-
-import (
- "bytes"
- "fmt"
- "math"
- "sort"
- "strings"
- "testing"
- "time"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/rand"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
-)
-
-const (
- fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32
- fakeNetHeaderLen = 12
- fakeDefaultPrefixLen = 8
-
- // fakeControlProtocol is used for control packets that represent
- // destination port unreachable.
- fakeControlProtocol tcpip.TransportProtocolNumber = 2
-
- // defaultMTU is the MTU, in bytes, used throughout the tests, except
- // where another value is explicitly used. It is chosen to match the MTU
- // of loopback interfaces on linux systems.
- defaultMTU = 65536
-)
-
-// fakeNetworkEndpoint is a network-layer protocol endpoint. It counts sent and
-// received packets; the counts of all endpoints are aggregated in the protocol
-// descriptor.
-//
-// Headers of this protocol are fakeNetHeaderLen bytes, but we currently only
-// use the first three: destination address, source address, and transport
-// protocol. They're all one byte fields to simplify parsing.
-type fakeNetworkEndpoint struct {
- nicID tcpip.NICID
- id stack.NetworkEndpointID
- prefixLen int
- proto *fakeNetworkProtocol
- dispatcher stack.TransportDispatcher
- ep stack.LinkEndpoint
-}
-
-func (f *fakeNetworkEndpoint) MTU() uint32 {
- return f.ep.MTU() - uint32(f.MaxHeaderLength())
-}
-
-func (f *fakeNetworkEndpoint) NICID() tcpip.NICID {
- return f.nicID
-}
-
-func (f *fakeNetworkEndpoint) PrefixLen() int {
- return f.prefixLen
-}
-
-func (*fakeNetworkEndpoint) DefaultTTL() uint8 {
- return 123
-}
-
-func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID {
- return &f.id
-}
-
-func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) {
- // Increment the received packet count in the protocol descriptor.
- f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++
-
- // Consume the network header.
- b := pkt.Data.First()
- pkt.Data.TrimFront(fakeNetHeaderLen)
-
- // Handle control packets.
- if b[2] == uint8(fakeControlProtocol) {
- nb := pkt.Data.First()
- if len(nb) < fakeNetHeaderLen {
- return
- }
-
- pkt.Data.TrimFront(fakeNetHeaderLen)
- f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt)
- return
- }
-
- // Dispatch the packet to the transport protocol.
- f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt)
-}
-
-func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
- return f.ep.MaxHeaderLength() + fakeNetHeaderLen
-}
-
-func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
- return 0
-}
-
-func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
- return f.ep.Capabilities()
-}
-
-func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error {
- // Increment the sent packet count in the protocol descriptor.
- f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++
-
- // Add the protocol's header to the packet and send it to the link
- // endpoint.
- b := pkt.Header.Prepend(fakeNetHeaderLen)
- b[0] = r.RemoteAddress[0]
- b[1] = f.id.LocalAddress[0]
- b[2] = byte(params.Protocol)
-
- if r.Loop&stack.PacketLoop != 0 {
- views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
- views[0] = pkt.Header.View()
- views = append(views, pkt.Data.Views()...)
- f.HandlePacket(r, tcpip.PacketBuffer{
- Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
- })
- }
- if r.Loop&stack.PacketOut == 0 {
- return nil
- }
-
- return f.ep.WritePacket(r, gso, fakeNetNumber, pkt)
-}
-
-// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
- panic("not implemented")
-}
-
-func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error {
- return tcpip.ErrNotSupported
-}
-
-func (*fakeNetworkEndpoint) Close() {}
-
-type fakeNetGoodOption bool
-
-type fakeNetBadOption bool
-
-type fakeNetInvalidValueOption int
-
-type fakeNetOptions struct {
- good bool
-}
-
-// fakeNetworkProtocol is a network-layer protocol descriptor. It aggregates the
-// number of packets sent and received via endpoints of this protocol. The index
-// where packets are added is given by the packet's destination address MOD 10.
-type fakeNetworkProtocol struct {
- packetCount [10]int
- sendPacketCount [10]int
- opts fakeNetOptions
-}
-
-func (f *fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
- return fakeNetNumber
-}
-
-func (f *fakeNetworkProtocol) MinimumPacketSize() int {
- return fakeNetHeaderLen
-}
-
-func (f *fakeNetworkProtocol) DefaultPrefixLen() int {
- return fakeDefaultPrefixLen
-}
-
-func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int {
- return f.packetCount[int(intfAddr)%len(f.packetCount)]
-}
-
-func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
- return tcpip.Address(v[1:2]), tcpip.Address(v[0:1])
-}
-
-func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) {
- return &fakeNetworkEndpoint{
- nicID: nicID,
- id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
- prefixLen: addrWithPrefix.PrefixLen,
- proto: f,
- dispatcher: dispatcher,
- ep: ep,
- }, nil
-}
-
-func (f *fakeNetworkProtocol) SetOption(option interface{}) *tcpip.Error {
- switch v := option.(type) {
- case fakeNetGoodOption:
- f.opts.good = bool(v)
- return nil
- case fakeNetInvalidValueOption:
- return tcpip.ErrInvalidOptionValue
- default:
- return tcpip.ErrUnknownProtocolOption
- }
-}
-
-func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error {
- switch v := option.(type) {
- case *fakeNetGoodOption:
- *v = fakeNetGoodOption(f.opts.good)
- return nil
- default:
- return tcpip.ErrUnknownProtocolOption
- }
-}
-
-func fakeNetFactory() stack.NetworkProtocol {
- return &fakeNetworkProtocol{}
-}
-
-func TestNetworkReceive(t *testing.T) {
- // Create a stack with the fake network protocol, one nic, and two
- // addresses attached to it: 1 & 2.
- ep := channel.New(10, defaultMTU, "")
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress failed:", err)
- }
-
- if err := s.AddAddress(1, fakeNetNumber, "\x02"); err != nil {
- t.Fatal("AddAddress failed:", err)
- }
-
- fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
-
- buf := buffer.NewView(30)
-
- // Make sure packet with wrong address is not delivered.
- buf[0] = 3
- ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
- Data: buf.ToVectorisedView(),
- })
- if fakeNet.packetCount[1] != 0 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0)
- }
- if fakeNet.packetCount[2] != 0 {
- t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 0)
- }
-
- // Make sure packet is delivered to first endpoint.
- buf[0] = 1
- ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
- Data: buf.ToVectorisedView(),
- })
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
- if fakeNet.packetCount[2] != 0 {
- t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 0)
- }
-
- // Make sure packet is delivered to second endpoint.
- buf[0] = 2
- ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
- Data: buf.ToVectorisedView(),
- })
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
- if fakeNet.packetCount[2] != 1 {
- t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1)
- }
-
- // Make sure packet is not delivered if protocol number is wrong.
- ep.InjectInbound(fakeNetNumber-1, tcpip.PacketBuffer{
- Data: buf.ToVectorisedView(),
- })
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
- if fakeNet.packetCount[2] != 1 {
- t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1)
- }
-
- // Make sure packet that is too small is dropped.
- buf.CapLength(2)
- ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
- Data: buf.ToVectorisedView(),
- })
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
- if fakeNet.packetCount[2] != 1 {
- t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1)
- }
-}
-
-func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Error {
- r, err := s.FindRoute(0, "", addr, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- return err
- }
- defer r.Release()
- return send(r, payload)
-}
-
-func send(r stack.Route, payload buffer.View) *tcpip.Error {
- hdr := buffer.NewPrependable(int(r.MaxHeaderLength()))
- return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
- Header: hdr,
- Data: payload.ToVectorisedView(),
- })
-}
-
-func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) {
- t.Helper()
- ep.Drain()
- if err := sendTo(s, addr, payload); err != nil {
- t.Error("sendTo failed:", err)
- }
- if got, want := ep.Drain(), 1; got != want {
- t.Errorf("sendTo packet count: got = %d, want %d", got, want)
- }
-}
-
-func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View) {
- t.Helper()
- ep.Drain()
- if err := send(r, payload); err != nil {
- t.Error("send failed:", err)
- }
- if got, want := ep.Drain(), 1; got != want {
- t.Errorf("send packet count: got = %d, want %d", got, want)
- }
-}
-
-func testFailingSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
- t.Helper()
- if gotErr := send(r, payload); gotErr != wantErr {
- t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr)
- }
-}
-
-func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
- t.Helper()
- if gotErr := sendTo(s, addr, payload); gotErr != wantErr {
- t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr)
- }
-}
-
-func testRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) {
- t.Helper()
- // testRecvInternal injects one packet, and we expect to receive it.
- want := fakeNet.PacketCount(localAddrByte) + 1
- testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want)
-}
-
-func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) {
- t.Helper()
- // testRecvInternal injects one packet, and we do NOT expect to receive it.
- want := fakeNet.PacketCount(localAddrByte)
- testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want)
-}
-
-func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) {
- t.Helper()
- ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
- Data: buf.ToVectorisedView(),
- })
- if got := fakeNet.PacketCount(localAddrByte); got != want {
- t.Errorf("receive packet count: got = %d, want %d", got, want)
- }
-}
-
-func TestNetworkSend(t *testing.T) {
- // Create a stack with the fake network protocol, one nic, and one
- // address: 1. The route table sends all packets through the only
- // existing nic.
- ep := channel.New(10, defaultMTU, "")
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatal("NewNIC failed:", err)
- }
-
- {
- subnet, err := tcpip.NewSubnet("\x00", "\x00")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
- }
-
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress failed:", err)
- }
-
- // Make sure that the link-layer endpoint received the outbound packet.
- testSendTo(t, s, "\x03", ep, nil)
-}
-
-func TestNetworkSendMultiRoute(t *testing.T) {
- // Create a stack with the fake network protocol, two nics, and two
- // addresses per nic, the first nic has odd address, the second one has
- // even addresses.
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
-
- ep1 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep1); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress failed:", err)
- }
-
- if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil {
- t.Fatal("AddAddress failed:", err)
- }
-
- ep2 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(2, ep2); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
- t.Fatal("AddAddress failed:", err)
- }
-
- if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil {
- t.Fatal("AddAddress failed:", err)
- }
-
- // Set a route table that sends all packets with odd destination
- // addresses through the first NIC, and all even destination address
- // through the second one.
- {
- subnet0, err := tcpip.NewSubnet("\x00", "\x01")
- if err != nil {
- t.Fatal(err)
- }
- subnet1, err := tcpip.NewSubnet("\x01", "\x01")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{
- {Destination: subnet1, Gateway: "\x00", NIC: 1},
- {Destination: subnet0, Gateway: "\x00", NIC: 2},
- })
- }
-
- // Send a packet to an odd destination.
- testSendTo(t, s, "\x05", ep1, nil)
-
- // Send a packet to an even destination.
- testSendTo(t, s, "\x06", ep2, nil)
-}
-
-func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) {
- r, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatal("FindRoute failed:", err)
- }
-
- defer r.Release()
-
- if r.LocalAddress != expectedSrcAddr {
- t.Fatalf("Bad source address: expected %v, got %v", expectedSrcAddr, r.LocalAddress)
- }
-
- if r.RemoteAddress != dstAddr {
- t.Fatalf("Bad destination address: expected %v, got %v", dstAddr, r.RemoteAddress)
- }
-}
-
-func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr tcpip.Address) {
- _, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
- if err != tcpip.ErrNoRoute {
- t.Fatalf("FindRoute returned unexpected error, got = %v, want = %s", err, tcpip.ErrNoRoute)
- }
-}
-
-func TestDisableUnknownNIC(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
-
- if err := s.DisableNIC(1); err != tcpip.ErrUnknownNICID {
- t.Fatalf("got s.DisableNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID)
- }
-}
-
-func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) {
- const nicID = 1
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
-
- e := loopback.New()
- nicOpts := stack.NICOptions{Disabled: true}
- if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
- t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err)
- }
-
- checkNIC := func(enabled bool) {
- t.Helper()
-
- allNICInfo := s.NICInfo()
- nicInfo, ok := allNICInfo[nicID]
- if !ok {
- t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo)
- } else if nicInfo.Flags.Running != enabled {
- t.Errorf("got nicInfo.Flags.Running = %t, want = %t", nicInfo.Flags.Running, enabled)
- }
-
- if got := s.CheckNIC(nicID); got != enabled {
- t.Errorf("got s.CheckNIC(%d) = %t, want = %t", nicID, got, enabled)
- }
- }
-
- // NIC should initially report itself as disabled.
- checkNIC(false)
-
- if err := s.EnableNIC(nicID); err != nil {
- t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
- }
- checkNIC(true)
-
- // If the NIC is not reporting a correct enabled status, we cannot trust the
- // next check so end the test here.
- if t.Failed() {
- t.FailNow()
- }
-
- if err := s.DisableNIC(nicID); err != nil {
- t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
- }
- checkNIC(false)
-}
-
-func TestRoutesWithDisabledNIC(t *testing.T) {
- const unspecifiedNIC = 0
- const nicID1 = 1
- const nicID2 = 2
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
-
- ep1 := channel.New(0, defaultMTU, "")
- if err := s.CreateNIC(nicID1, ep1); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
- }
-
- addr1 := tcpip.Address("\x01")
- if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err)
- }
-
- ep2 := channel.New(0, defaultMTU, "")
- if err := s.CreateNIC(nicID2, ep2); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
- }
-
- addr2 := tcpip.Address("\x02")
- if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err)
- }
-
- // Set a route table that sends all packets with odd destination
- // addresses through the first NIC, and all even destination address
- // through the second one.
- {
- subnet0, err := tcpip.NewSubnet("\x00", "\x01")
- if err != nil {
- t.Fatal(err)
- }
- subnet1, err := tcpip.NewSubnet("\x01", "\x01")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{
- {Destination: subnet1, Gateway: "\x00", NIC: nicID1},
- {Destination: subnet0, Gateway: "\x00", NIC: nicID2},
- })
- }
-
- // Test routes to odd address.
- testRoute(t, s, unspecifiedNIC, "", "\x05", addr1)
- testRoute(t, s, unspecifiedNIC, addr1, "\x05", addr1)
- testRoute(t, s, nicID1, addr1, "\x05", addr1)
-
- // Test routes to even address.
- testRoute(t, s, unspecifiedNIC, "", "\x06", addr2)
- testRoute(t, s, unspecifiedNIC, addr2, "\x06", addr2)
- testRoute(t, s, nicID2, addr2, "\x06", addr2)
-
- // Disabling NIC1 should result in no routes to odd addresses. Routes to even
- // addresses should continue to be available as NIC2 is still enabled.
- if err := s.DisableNIC(nicID1); err != nil {
- t.Fatalf("s.DisableNIC(%d): %s", nicID1, err)
- }
- nic1Dst := tcpip.Address("\x05")
- testNoRoute(t, s, unspecifiedNIC, "", nic1Dst)
- testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst)
- testNoRoute(t, s, nicID1, addr1, nic1Dst)
- nic2Dst := tcpip.Address("\x06")
- testRoute(t, s, unspecifiedNIC, "", nic2Dst, addr2)
- testRoute(t, s, unspecifiedNIC, addr2, nic2Dst, addr2)
- testRoute(t, s, nicID2, addr2, nic2Dst, addr2)
-
- // Disabling NIC2 should result in no routes to even addresses. No route
- // should be available to any address as routes to odd addresses were made
- // unavailable by disabling NIC1 above.
- if err := s.DisableNIC(nicID2); err != nil {
- t.Fatalf("s.DisableNIC(%d): %s", nicID2, err)
- }
- testNoRoute(t, s, unspecifiedNIC, "", nic1Dst)
- testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst)
- testNoRoute(t, s, nicID1, addr1, nic1Dst)
- testNoRoute(t, s, unspecifiedNIC, "", nic2Dst)
- testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst)
- testNoRoute(t, s, nicID2, addr2, nic2Dst)
-
- // Enabling NIC1 should make routes to odd addresses available again. Routes
- // to even addresses should continue to be unavailable as NIC2 is still
- // disabled.
- if err := s.EnableNIC(nicID1); err != nil {
- t.Fatalf("s.EnableNIC(%d): %s", nicID1, err)
- }
- testRoute(t, s, unspecifiedNIC, "", nic1Dst, addr1)
- testRoute(t, s, unspecifiedNIC, addr1, nic1Dst, addr1)
- testRoute(t, s, nicID1, addr1, nic1Dst, addr1)
- testNoRoute(t, s, unspecifiedNIC, "", nic2Dst)
- testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst)
- testNoRoute(t, s, nicID2, addr2, nic2Dst)
-}
-
-func TestRouteWritePacketWithDisabledNIC(t *testing.T) {
- const unspecifiedNIC = 0
- const nicID1 = 1
- const nicID2 = 2
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
-
- ep1 := channel.New(1, defaultMTU, "")
- if err := s.CreateNIC(nicID1, ep1); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
- }
-
- addr1 := tcpip.Address("\x01")
- if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err)
- }
-
- ep2 := channel.New(1, defaultMTU, "")
- if err := s.CreateNIC(nicID2, ep2); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
- }
-
- addr2 := tcpip.Address("\x02")
- if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err)
- }
-
- // Set a route table that sends all packets with odd destination
- // addresses through the first NIC, and all even destination address
- // through the second one.
- {
- subnet0, err := tcpip.NewSubnet("\x00", "\x01")
- if err != nil {
- t.Fatal(err)
- }
- subnet1, err := tcpip.NewSubnet("\x01", "\x01")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{
- {Destination: subnet1, Gateway: "\x00", NIC: nicID1},
- {Destination: subnet0, Gateway: "\x00", NIC: nicID2},
- })
- }
-
- nic1Dst := tcpip.Address("\x05")
- r1, err := s.FindRoute(nicID1, addr1, nic1Dst, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID1, addr1, nic1Dst, fakeNetNumber, err)
- }
- defer r1.Release()
-
- nic2Dst := tcpip.Address("\x06")
- r2, err := s.FindRoute(nicID2, addr2, nic2Dst, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID2, addr2, nic2Dst, fakeNetNumber, err)
- }
- defer r2.Release()
-
- // If we failed to get routes r1 or r2, we cannot proceed with the test.
- if t.Failed() {
- t.FailNow()
- }
-
- buf := buffer.View([]byte{1})
- testSend(t, r1, ep1, buf)
- testSend(t, r2, ep2, buf)
-
- // Writes with Routes that use the disabled NIC1 should fail.
- if err := s.DisableNIC(nicID1); err != nil {
- t.Fatalf("s.DisableNIC(%d): %s", nicID1, err)
- }
- testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState)
- testSend(t, r2, ep2, buf)
-
- // Writes with Routes that use the disabled NIC2 should fail.
- if err := s.DisableNIC(nicID2); err != nil {
- t.Fatalf("s.DisableNIC(%d): %s", nicID2, err)
- }
- testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState)
- testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState)
-
- // Writes with Routes that use the re-enabled NIC1 should succeed.
- // TODO(b/147015577): Should we instead completely invalidate all Routes that
- // were bound to a disabled NIC at some point?
- if err := s.EnableNIC(nicID1); err != nil {
- t.Fatalf("s.EnableNIC(%d): %s", nicID1, err)
- }
- testSend(t, r1, ep1, buf)
- testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState)
-}
-
-func TestRoutes(t *testing.T) {
- // Create a stack with the fake network protocol, two nics, and two
- // addresses per nic, the first nic has odd address, the second one has
- // even addresses.
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
-
- ep1 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep1); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress failed:", err)
- }
-
- if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil {
- t.Fatal("AddAddress failed:", err)
- }
-
- ep2 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(2, ep2); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
- t.Fatal("AddAddress failed:", err)
- }
-
- if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil {
- t.Fatal("AddAddress failed:", err)
- }
-
- // Set a route table that sends all packets with odd destination
- // addresses through the first NIC, and all even destination address
- // through the second one.
- {
- subnet0, err := tcpip.NewSubnet("\x00", "\x01")
- if err != nil {
- t.Fatal(err)
- }
- subnet1, err := tcpip.NewSubnet("\x01", "\x01")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{
- {Destination: subnet1, Gateway: "\x00", NIC: 1},
- {Destination: subnet0, Gateway: "\x00", NIC: 2},
- })
- }
-
- // Test routes to odd address.
- testRoute(t, s, 0, "", "\x05", "\x01")
- testRoute(t, s, 0, "\x01", "\x05", "\x01")
- testRoute(t, s, 1, "\x01", "\x05", "\x01")
- testRoute(t, s, 0, "\x03", "\x05", "\x03")
- testRoute(t, s, 1, "\x03", "\x05", "\x03")
-
- // Test routes to even address.
- testRoute(t, s, 0, "", "\x06", "\x02")
- testRoute(t, s, 0, "\x02", "\x06", "\x02")
- testRoute(t, s, 2, "\x02", "\x06", "\x02")
- testRoute(t, s, 0, "\x04", "\x06", "\x04")
- testRoute(t, s, 2, "\x04", "\x06", "\x04")
-
- // Try to send to odd numbered address from even numbered ones, then
- // vice-versa.
- testNoRoute(t, s, 0, "\x02", "\x05")
- testNoRoute(t, s, 2, "\x02", "\x05")
- testNoRoute(t, s, 0, "\x04", "\x05")
- testNoRoute(t, s, 2, "\x04", "\x05")
-
- testNoRoute(t, s, 0, "\x01", "\x06")
- testNoRoute(t, s, 1, "\x01", "\x06")
- testNoRoute(t, s, 0, "\x03", "\x06")
- testNoRoute(t, s, 1, "\x03", "\x06")
-}
-
-func TestAddressRemoval(t *testing.T) {
- const localAddrByte byte = 0x01
- localAddr := tcpip.Address([]byte{localAddrByte})
- remoteAddr := tcpip.Address("\x02")
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
-
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
- }
- {
- subnet, err := tcpip.NewSubnet("\x00", "\x00")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
- }
-
- fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
-
- buf := buffer.NewView(30)
-
- // Send and receive packets, and verify they are received.
- buf[0] = localAddrByte
- testRecv(t, fakeNet, localAddrByte, ep, buf)
- testSendTo(t, s, remoteAddr, ep, nil)
-
- // Remove the address, then check that send/receive doesn't work anymore.
- if err := s.RemoveAddress(1, localAddr); err != nil {
- t.Fatal("RemoveAddress failed:", err)
- }
- testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
- testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
-
- // Check that removing the same address fails.
- if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress {
- t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress)
- }
-}
-
-func TestAddressRemovalWithRouteHeld(t *testing.T) {
- const localAddrByte byte = 0x01
- localAddr := tcpip.Address([]byte{localAddrByte})
- remoteAddr := tcpip.Address("\x02")
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
-
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
- fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
- buf := buffer.NewView(30)
-
- if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
- }
- {
- subnet, err := tcpip.NewSubnet("\x00", "\x00")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
- }
-
- r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatal("FindRoute failed:", err)
- }
-
- // Send and receive packets, and verify they are received.
- buf[0] = localAddrByte
- testRecv(t, fakeNet, localAddrByte, ep, buf)
- testSend(t, r, ep, nil)
- testSendTo(t, s, remoteAddr, ep, nil)
-
- // Remove the address, then check that send/receive doesn't work anymore.
- if err := s.RemoveAddress(1, localAddr); err != nil {
- t.Fatal("RemoveAddress failed:", err)
- }
- testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
- testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState)
- testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
-
- // Check that removing the same address fails.
- if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress {
- t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress)
- }
-}
-
-func verifyAddress(t *testing.T, s *stack.Stack, nicID tcpip.NICID, addr tcpip.Address) {
- t.Helper()
- info, ok := s.NICInfo()[nicID]
- if !ok {
- t.Fatalf("NICInfo() failed to find nicID=%d", nicID)
- }
- if len(addr) == 0 {
- // No address given, verify that there is no address assigned to the NIC.
- for _, a := range info.ProtocolAddresses {
- if a.Protocol == fakeNetNumber && a.AddressWithPrefix != (tcpip.AddressWithPrefix{}) {
- t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, (tcpip.AddressWithPrefix{}))
- }
- }
- return
- }
- // Address given, verify the address is assigned to the NIC and no other
- // address is.
- found := false
- for _, a := range info.ProtocolAddresses {
- if a.Protocol == fakeNetNumber {
- if a.AddressWithPrefix.Address == addr {
- found = true
- } else {
- t.Errorf("verify address: got = %s, want = %s", a.AddressWithPrefix.Address, addr)
- }
- }
- }
- if !found {
- t.Errorf("verify address: couldn't find %s on the NIC", addr)
- }
-}
-
-func TestEndpointExpiration(t *testing.T) {
- const (
- localAddrByte byte = 0x01
- remoteAddr tcpip.Address = "\x03"
- noAddr tcpip.Address = ""
- nicID tcpip.NICID = 1
- )
- localAddr := tcpip.Address([]byte{localAddrByte})
-
- for _, promiscuous := range []bool{true, false} {
- for _, spoofing := range []bool{true, false} {
- t.Run(fmt.Sprintf("promiscuous=%t spoofing=%t", promiscuous, spoofing), func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
-
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- {
- subnet, err := tcpip.NewSubnet("\x00", "\x00")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
- }
-
- fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
- buf := buffer.NewView(30)
- buf[0] = localAddrByte
-
- if promiscuous {
- if err := s.SetPromiscuousMode(nicID, true); err != nil {
- t.Fatal("SetPromiscuousMode failed:", err)
- }
- }
-
- if spoofing {
- if err := s.SetSpoofing(nicID, true); err != nil {
- t.Fatal("SetSpoofing failed:", err)
- }
- }
-
- // 1. No Address yet, send should only work for spoofing, receive for
- // promiscuous mode.
- //-----------------------
- verifyAddress(t, s, nicID, noAddr)
- if promiscuous {
- testRecv(t, fakeNet, localAddrByte, ep, buf)
- } else {
- testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
- }
- if spoofing {
- // FIXME(b/139841518):Spoofing doesn't work if there is no primary address.
- // testSendTo(t, s, remoteAddr, ep, nil)
- } else {
- testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
- }
-
- // 2. Add Address, everything should work.
- //-----------------------
- if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
- }
- verifyAddress(t, s, nicID, localAddr)
- testRecv(t, fakeNet, localAddrByte, ep, buf)
- testSendTo(t, s, remoteAddr, ep, nil)
-
- // 3. Remove the address, send should only work for spoofing, receive
- // for promiscuous mode.
- //-----------------------
- if err := s.RemoveAddress(nicID, localAddr); err != nil {
- t.Fatal("RemoveAddress failed:", err)
- }
- verifyAddress(t, s, nicID, noAddr)
- if promiscuous {
- testRecv(t, fakeNet, localAddrByte, ep, buf)
- } else {
- testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
- }
- if spoofing {
- // FIXME(b/139841518):Spoofing doesn't work if there is no primary address.
- // testSendTo(t, s, remoteAddr, ep, nil)
- } else {
- testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
- }
-
- // 4. Add Address back, everything should work again.
- //-----------------------
- if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
- }
- verifyAddress(t, s, nicID, localAddr)
- testRecv(t, fakeNet, localAddrByte, ep, buf)
- testSendTo(t, s, remoteAddr, ep, nil)
-
- // 5. Take a reference to the endpoint by getting a route. Verify that
- // we can still send/receive, including sending using the route.
- //-----------------------
- r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatal("FindRoute failed:", err)
- }
- testRecv(t, fakeNet, localAddrByte, ep, buf)
- testSendTo(t, s, remoteAddr, ep, nil)
- testSend(t, r, ep, nil)
-
- // 6. Remove the address. Send should only work for spoofing, receive
- // for promiscuous mode.
- //-----------------------
- if err := s.RemoveAddress(nicID, localAddr); err != nil {
- t.Fatal("RemoveAddress failed:", err)
- }
- verifyAddress(t, s, nicID, noAddr)
- if promiscuous {
- testRecv(t, fakeNet, localAddrByte, ep, buf)
- } else {
- testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
- }
- if spoofing {
- testSend(t, r, ep, nil)
- testSendTo(t, s, remoteAddr, ep, nil)
- } else {
- testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState)
- testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
- }
-
- // 7. Add Address back, everything should work again.
- //-----------------------
- if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
- }
- verifyAddress(t, s, nicID, localAddr)
- testRecv(t, fakeNet, localAddrByte, ep, buf)
- testSendTo(t, s, remoteAddr, ep, nil)
- testSend(t, r, ep, nil)
-
- // 8. Remove the route, sendTo/recv should still work.
- //-----------------------
- r.Release()
- verifyAddress(t, s, nicID, localAddr)
- testRecv(t, fakeNet, localAddrByte, ep, buf)
- testSendTo(t, s, remoteAddr, ep, nil)
-
- // 9. Remove the address. Send should only work for spoofing, receive
- // for promiscuous mode.
- //-----------------------
- if err := s.RemoveAddress(nicID, localAddr); err != nil {
- t.Fatal("RemoveAddress failed:", err)
- }
- verifyAddress(t, s, nicID, noAddr)
- if promiscuous {
- testRecv(t, fakeNet, localAddrByte, ep, buf)
- } else {
- testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
- }
- if spoofing {
- // FIXME(b/139841518):Spoofing doesn't work if there is no primary address.
- // testSendTo(t, s, remoteAddr, ep, nil)
- } else {
- testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
- }
- })
- }
- }
-}
-
-func TestPromiscuousMode(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
-
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- {
- subnet, err := tcpip.NewSubnet("\x00", "\x00")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
- }
-
- fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
-
- buf := buffer.NewView(30)
-
- // Write a packet, and check that it doesn't get delivered as we don't
- // have a matching endpoint.
- const localAddrByte byte = 0x01
- buf[0] = localAddrByte
- testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
-
- // Set promiscuous mode, then check that packet is delivered.
- if err := s.SetPromiscuousMode(1, true); err != nil {
- t.Fatal("SetPromiscuousMode failed:", err)
- }
- testRecv(t, fakeNet, localAddrByte, ep, buf)
-
- // Check that we can't get a route as there is no local address.
- _, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */)
- if err != tcpip.ErrNoRoute {
- t.Fatalf("FindRoute returned unexpected error: got = %v, want = %s", err, tcpip.ErrNoRoute)
- }
-
- // Set promiscuous mode to false, then check that packet can't be
- // delivered anymore.
- if err := s.SetPromiscuousMode(1, false); err != nil {
- t.Fatal("SetPromiscuousMode failed:", err)
- }
- testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
-}
-
-func TestSpoofingWithAddress(t *testing.T) {
- localAddr := tcpip.Address("\x01")
- nonExistentLocalAddr := tcpip.Address("\x02")
- dstAddr := tcpip.Address("\x03")
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
-
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
- }
-
- {
- subnet, err := tcpip.NewSubnet("\x00", "\x00")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
- }
-
- // With address spoofing disabled, FindRoute does not permit an address
- // that was not added to the NIC to be used as the source.
- r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
- if err == nil {
- t.Errorf("FindRoute succeeded with route %+v when it should have failed", r)
- }
-
- // With address spoofing enabled, FindRoute permits any address to be used
- // as the source.
- if err := s.SetSpoofing(1, true); err != nil {
- t.Fatal("SetSpoofing failed:", err)
- }
- r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatal("FindRoute failed:", err)
- }
- if r.LocalAddress != nonExistentLocalAddr {
- t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr)
- }
- if r.RemoteAddress != dstAddr {
- t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr)
- }
- // Sending a packet works.
- testSendTo(t, s, dstAddr, ep, nil)
- testSend(t, r, ep, nil)
-
- // FindRoute should also work with a local address that exists on the NIC.
- r, err = s.FindRoute(0, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatal("FindRoute failed:", err)
- }
- if r.LocalAddress != localAddr {
- t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr)
- }
- if r.RemoteAddress != dstAddr {
- t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr)
- }
- // Sending a packet using the route works.
- testSend(t, r, ep, nil)
-}
-
-func TestSpoofingNoAddress(t *testing.T) {
- nonExistentLocalAddr := tcpip.Address("\x01")
- dstAddr := tcpip.Address("\x02")
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
-
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- {
- subnet, err := tcpip.NewSubnet("\x00", "\x00")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
- }
-
- // With address spoofing disabled, FindRoute does not permit an address
- // that was not added to the NIC to be used as the source.
- r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
- if err == nil {
- t.Errorf("FindRoute succeeded with route %+v when it should have failed", r)
- }
- // Sending a packet fails.
- testFailingSendTo(t, s, dstAddr, ep, nil, tcpip.ErrNoRoute)
-
- // With address spoofing enabled, FindRoute permits any address to be used
- // as the source.
- if err := s.SetSpoofing(1, true); err != nil {
- t.Fatal("SetSpoofing failed:", err)
- }
- r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatal("FindRoute failed:", err)
- }
- if r.LocalAddress != nonExistentLocalAddr {
- t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr)
- }
- if r.RemoteAddress != dstAddr {
- t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr)
- }
- // Sending a packet works.
- // FIXME(b/139841518):Spoofing doesn't work if there is no primary address.
- // testSendTo(t, s, remoteAddr, ep, nil)
-}
-
-func verifyRoute(gotRoute, wantRoute stack.Route) error {
- if gotRoute.LocalAddress != wantRoute.LocalAddress {
- return fmt.Errorf("bad local address: got %s, want = %s", gotRoute.LocalAddress, wantRoute.LocalAddress)
- }
- if gotRoute.RemoteAddress != wantRoute.RemoteAddress {
- return fmt.Errorf("bad remote address: got %s, want = %s", gotRoute.RemoteAddress, wantRoute.RemoteAddress)
- }
- if gotRoute.RemoteLinkAddress != wantRoute.RemoteLinkAddress {
- return fmt.Errorf("bad remote link address: got %s, want = %s", gotRoute.RemoteLinkAddress, wantRoute.RemoteLinkAddress)
- }
- if gotRoute.NextHop != wantRoute.NextHop {
- return fmt.Errorf("bad next-hop address: got %s, want = %s", gotRoute.NextHop, wantRoute.NextHop)
- }
- return nil
-}
-
-func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
-
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
- s.SetRouteTable([]tcpip.Route{})
-
- // If there is no endpoint, it won't work.
- if _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable {
- t.Fatalf("got FindRoute(1, %s, %s, %d) = %s, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable)
- }
-
- protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}}
- if err := s.AddProtocolAddress(1, protoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(1, %s) failed: %s", protoAddr, err)
- }
- r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
- }
- if err := verifyRoute(r, stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil {
- t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
- }
-
- // If the NIC doesn't exist, it won't work.
- if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable {
- t.Fatalf("got FindRoute(2, %s, %s, %d) = %s want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable)
- }
-}
-
-func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
- defaultAddr := tcpip.AddressWithPrefix{header.IPv4Any, 0}
- // Local subnet on NIC1: 192.168.1.58/24, gateway 192.168.1.1.
- nic1Addr := tcpip.AddressWithPrefix{"\xc0\xa8\x01\x3a", 24}
- nic1Gateway := tcpip.Address("\xc0\xa8\x01\x01")
- // Local subnet on NIC2: 10.10.10.5/24, gateway 10.10.10.1.
- nic2Addr := tcpip.AddressWithPrefix{"\x0a\x0a\x0a\x05", 24}
- nic2Gateway := tcpip.Address("\x0a\x0a\x0a\x01")
-
- // Create a new stack with two NICs.
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatalf("CreateNIC failed: %s", err)
- }
- if err := s.CreateNIC(2, ep); err != nil {
- t.Fatalf("CreateNIC failed: %s", err)
- }
- nic1ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic1Addr}
- if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(1, %s) failed: %s", nic1ProtoAddr, err)
- }
-
- nic2ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic2Addr}
- if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil {
- t.Fatalf("AddAddress(2, %s) failed: %s", nic2ProtoAddr, err)
- }
-
- // Set the initial route table.
- rt := []tcpip.Route{
- {Destination: nic1Addr.Subnet(), NIC: 1},
- {Destination: nic2Addr.Subnet(), NIC: 2},
- {Destination: defaultAddr.Subnet(), Gateway: nic2Gateway, NIC: 2},
- {Destination: defaultAddr.Subnet(), Gateway: nic1Gateway, NIC: 1},
- }
- s.SetRouteTable(rt)
-
- // When an interface is given, the route for a broadcast goes through it.
- r, err := s.FindRoute(1, nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
- }
- if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
- t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
- }
-
- // When an interface is not given, it consults the route table.
- // 1. Case: Using the default route.
- r, err = s.FindRoute(0, "", header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err)
- }
- if err := verifyRoute(r, stack.Route{LocalAddress: nic2Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
- t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err)
- }
-
- // 2. Case: Having an explicit route for broadcast will select that one.
- rt = append(
- []tcpip.Route{
- {Destination: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}.Subnet(), NIC: 1},
- },
- rt...,
- )
- s.SetRouteTable(rt)
- r, err = s.FindRoute(0, "", header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err)
- }
- if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
- t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err)
- }
-}
-
-func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
- for _, tc := range []struct {
- name string
- routeNeeded bool
- address tcpip.Address
- }{
- // IPv4 multicast address range: 224.0.0.0 - 239.255.255.255
- // <=> 0xe0.0x00.0x00.0x00 - 0xef.0xff.0xff.0xff
- {"IPv4 Multicast 1", false, "\xe0\x00\x00\x00"},
- {"IPv4 Multicast 2", false, "\xef\xff\xff\xff"},
- {"IPv4 Unicast 1", true, "\xdf\xff\xff\xff"},
- {"IPv4 Unicast 2", true, "\xf0\x00\x00\x00"},
- {"IPv4 Unicast 3", true, "\x00\x00\x00\x00"},
-
- // IPv6 multicast address is 0xff[8] + flags[4] + scope[4] + groupId[112]
- {"IPv6 Multicast 1", false, "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
- {"IPv6 Multicast 2", false, "\xff\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
- {"IPv6 Multicast 3", false, "\xff\x0f\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"},
-
- // IPv6 link-local address starts with fe80::/10.
- {"IPv6 Unicast Link-Local 1", false, "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
- {"IPv6 Unicast Link-Local 2", false, "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"},
- {"IPv6 Unicast Link-Local 3", false, "\xfe\x80\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff"},
- {"IPv6 Unicast Link-Local 4", false, "\xfe\xbf\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
- {"IPv6 Unicast Link-Local 5", false, "\xfe\xbf\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"},
-
- // IPv6 addresses that are neither multicast nor link-local.
- {"IPv6 Unicast Not Link-Local 1", true, "\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
- {"IPv6 Unicast Not Link-Local 2", true, "\xf0\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"},
- {"IPv6 Unicast Not Link-local 3", true, "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
- {"IPv6 Unicast Not Link-Local 4", true, "\xfe\xc0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
- {"IPv6 Unicast Not Link-Local 5", true, "\xfe\xdf\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
- {"IPv6 Unicast Not Link-Local 6", true, "\xfd\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
- {"IPv6 Unicast Not Link-Local 7", true, "\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
- } {
- t.Run(tc.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
-
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- s.SetRouteTable([]tcpip.Route{})
-
- var anyAddr tcpip.Address
- if len(tc.address) == header.IPv4AddressSize {
- anyAddr = header.IPv4Any
- } else {
- anyAddr = header.IPv6Any
- }
-
- want := tcpip.ErrNetworkUnreachable
- if tc.routeNeeded {
- want = tcpip.ErrNoRoute
- }
-
- // If there is no endpoint, it won't work.
- if _, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); err != want {
- t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, want)
- }
-
- if err := s.AddAddress(1, fakeNetNumber, anyAddr); err != nil {
- t.Fatalf("AddAddress(%v, %v) failed: %v", fakeNetNumber, anyAddr, err)
- }
-
- if r, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); tc.routeNeeded {
- // Route table is empty but we need a route, this should cause an error.
- if err != tcpip.ErrNoRoute {
- t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, tcpip.ErrNoRoute)
- }
- } else {
- if err != nil {
- t.Fatalf("FindRoute(1, %v, %v, %v) failed: %v", anyAddr, tc.address, fakeNetNumber, err)
- }
- if r.LocalAddress != anyAddr {
- t.Errorf("Bad local address: got %v, want = %v", r.LocalAddress, anyAddr)
- }
- if r.RemoteAddress != tc.address {
- t.Errorf("Bad remote address: got %v, want = %v", r.RemoteAddress, tc.address)
- }
- }
- // If the NIC doesn't exist, it won't work.
- if _, err := s.FindRoute(2, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); err != want {
- t.Fatalf("got FindRoute(2, %v, %v, %v) = %v want = %v", anyAddr, tc.address, fakeNetNumber, err, want)
- }
- })
- }
-}
-
-// Add a range of addresses, then check that a packet is delivered.
-func TestAddressRangeAcceptsMatchingPacket(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
-
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- {
- subnet, err := tcpip.NewSubnet("\x00", "\x00")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
- }
-
- fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
-
- buf := buffer.NewView(30)
-
- const localAddrByte byte = 0x01
- buf[0] = localAddrByte
- subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0"))
- if err != nil {
- t.Fatal("NewSubnet failed:", err)
- }
- if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil {
- t.Fatal("AddAddressRange failed:", err)
- }
-
- testRecv(t, fakeNet, localAddrByte, ep, buf)
-}
-
-func testNicForAddressRange(t *testing.T, nicID tcpip.NICID, s *stack.Stack, subnet tcpip.Subnet, rangeExists bool) {
- t.Helper()
-
- // Loop over all addresses and check them.
- numOfAddresses := 1 << uint(8-subnet.Prefix())
- if numOfAddresses < 1 || numOfAddresses > 255 {
- t.Fatalf("got numOfAddresses = %d, want = [1 .. 255] (subnet=%s)", numOfAddresses, subnet)
- }
-
- addrBytes := []byte(subnet.ID())
- for i := 0; i < numOfAddresses; i++ {
- addr := tcpip.Address(addrBytes)
- wantNicID := nicID
- // The subnet and broadcast addresses are skipped.
- if !rangeExists || addr == subnet.ID() || addr == subnet.Broadcast() {
- wantNicID = 0
- }
- if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, addr); gotNicID != wantNicID {
- t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, addr, gotNicID, wantNicID)
- }
- addrBytes[0]++
- }
-
- // Trying the next address should always fail since it is outside the range.
- if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addrBytes)); gotNicID != 0 {
- t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addrBytes), gotNicID, 0)
- }
-}
-
-// Set a range of addresses, then remove it again, and check at each step that
-// CheckLocalAddress returns the correct NIC for each address or zero if not
-// existent.
-func TestCheckLocalAddressForSubnet(t *testing.T) {
- const nicID tcpip.NICID = 1
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
-
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- {
- subnet, err := tcpip.NewSubnet("\x00", "\x00")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID}})
- }
-
- subnet, err := tcpip.NewSubnet(tcpip.Address("\xa0"), tcpip.AddressMask("\xf0"))
- if err != nil {
- t.Fatal("NewSubnet failed:", err)
- }
-
- testNicForAddressRange(t, nicID, s, subnet, false /* rangeExists */)
-
- if err := s.AddAddressRange(nicID, fakeNetNumber, subnet); err != nil {
- t.Fatal("AddAddressRange failed:", err)
- }
-
- testNicForAddressRange(t, nicID, s, subnet, true /* rangeExists */)
-
- if err := s.RemoveAddressRange(nicID, subnet); err != nil {
- t.Fatal("RemoveAddressRange failed:", err)
- }
-
- testNicForAddressRange(t, nicID, s, subnet, false /* rangeExists */)
-}
-
-// Set a range of addresses, then send a packet to a destination outside the
-// range and then check it doesn't get delivered.
-func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
-
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- {
- subnet, err := tcpip.NewSubnet("\x00", "\x00")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
- }
-
- fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
-
- buf := buffer.NewView(30)
-
- const localAddrByte byte = 0x01
- buf[0] = localAddrByte
- subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0"))
- if err != nil {
- t.Fatal("NewSubnet failed:", err)
- }
- if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil {
- t.Fatal("AddAddressRange failed:", err)
- }
- testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
-}
-
-func TestNetworkOptions(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- TransportProtocols: []stack.TransportProtocol{},
- })
-
- // Try an unsupported network protocol.
- if err := s.SetNetworkProtocolOption(tcpip.NetworkProtocolNumber(99999), fakeNetGoodOption(false)); err != tcpip.ErrUnknownProtocol {
- t.Fatalf("SetNetworkProtocolOption(fakeNet2, blah, false) = %v, want = tcpip.ErrUnknownProtocol", err)
- }
-
- testCases := []struct {
- option interface{}
- wantErr *tcpip.Error
- verifier func(t *testing.T, p stack.NetworkProtocol)
- }{
- {fakeNetGoodOption(true), nil, func(t *testing.T, p stack.NetworkProtocol) {
- t.Helper()
- fakeNet := p.(*fakeNetworkProtocol)
- if fakeNet.opts.good != true {
- t.Fatalf("fakeNet.opts.good = false, want = true")
- }
- var v fakeNetGoodOption
- if err := s.NetworkProtocolOption(fakeNetNumber, &v); err != nil {
- t.Fatalf("s.NetworkProtocolOption(fakeNetNumber, &v) = %v, want = nil, where v is option %T", v, err)
- }
- if v != true {
- t.Fatalf("s.NetworkProtocolOption(fakeNetNumber, &v) returned v = %v, want = true", v)
- }
- }},
- {fakeNetBadOption(true), tcpip.ErrUnknownProtocolOption, nil},
- {fakeNetInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil},
- }
- for _, tc := range testCases {
- if got := s.SetNetworkProtocolOption(fakeNetNumber, tc.option); got != tc.wantErr {
- t.Errorf("s.SetNetworkProtocolOption(fakeNet, %v) = %v, want = %v", tc.option, got, tc.wantErr)
- }
- if tc.verifier != nil {
- tc.verifier(t, s.NetworkProtocolInstance(fakeNetNumber))
- }
- }
-}
-
-func stackContainsAddressRange(s *stack.Stack, id tcpip.NICID, addrRange tcpip.Subnet) bool {
- ranges, ok := s.NICAddressRanges()[id]
- if !ok {
- return false
- }
- for _, r := range ranges {
- if r == addrRange {
- return true
- }
- }
- return false
-}
-
-func TestAddresRangeAddRemove(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- addr := tcpip.Address("\x01\x01\x01\x01")
- mask := tcpip.AddressMask(strings.Repeat("\xff", len(addr)))
- addrRange, err := tcpip.NewSubnet(addr, mask)
- if err != nil {
- t.Fatal("NewSubnet failed:", err)
- }
-
- if got, want := stackContainsAddressRange(s, 1, addrRange), false; got != want {
- t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want)
- }
-
- if err := s.AddAddressRange(1, fakeNetNumber, addrRange); err != nil {
- t.Fatal("AddAddressRange failed:", err)
- }
-
- if got, want := stackContainsAddressRange(s, 1, addrRange), true; got != want {
- t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want)
- }
-
- if err := s.RemoveAddressRange(1, addrRange); err != nil {
- t.Fatal("RemoveAddressRange failed:", err)
- }
-
- if got, want := stackContainsAddressRange(s, 1, addrRange), false; got != want {
- t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want)
- }
-}
-
-func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
- for _, addrLen := range []int{4, 16} {
- t.Run(fmt.Sprintf("addrLen=%d", addrLen), func(t *testing.T) {
- for canBe := 0; canBe < 3; canBe++ {
- t.Run(fmt.Sprintf("canBe=%d", canBe), func(t *testing.T) {
- for never := 0; never < 3; never++ {
- t.Run(fmt.Sprintf("never=%d", never), func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
- // Insert <canBe> primary and <never> never-primary addresses.
- // Each one will add a network endpoint to the NIC.
- primaryAddrAdded := make(map[tcpip.AddressWithPrefix]struct{})
- for i := 0; i < canBe+never; i++ {
- var behavior stack.PrimaryEndpointBehavior
- if i < canBe {
- behavior = stack.CanBePrimaryEndpoint
- } else {
- behavior = stack.NeverPrimaryEndpoint
- }
- // Add an address and in case of a primary one include a
- // prefixLen.
- address := tcpip.Address(bytes.Repeat([]byte{byte(i)}, addrLen))
- if behavior == stack.CanBePrimaryEndpoint {
- protocolAddress := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: address,
- PrefixLen: addrLen * 8,
- },
- }
- if err := s.AddProtocolAddressWithOptions(1, protocolAddress, behavior); err != nil {
- t.Fatal("AddProtocolAddressWithOptions failed:", err)
- }
- // Remember the address/prefix.
- primaryAddrAdded[protocolAddress.AddressWithPrefix] = struct{}{}
- } else {
- if err := s.AddAddressWithOptions(1, fakeNetNumber, address, behavior); err != nil {
- t.Fatal("AddAddressWithOptions failed:", err)
- }
- }
- }
- // Check that GetMainNICAddress returns an address if at least
- // one primary address was added. In that case make sure the
- // address/prefixLen matches what we added.
- gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber)
- if err != nil {
- t.Fatal("GetMainNICAddress failed:", err)
- }
- if len(primaryAddrAdded) == 0 {
- // No primary addresses present.
- if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr {
- t.Fatalf("GetMainNICAddress: got addr = %s, want = %s", gotAddr, wantAddr)
- }
- } else {
- // At least one primary address was added, verify the returned
- // address is in the list of primary addresses we added.
- if _, ok := primaryAddrAdded[gotAddr]; !ok {
- t.Fatalf("GetMainNICAddress: got = %s, want any in {%v}", gotAddr, primaryAddrAdded)
- }
- }
- })
- }
- })
- }
- })
- }
-}
-
-func TestGetMainNICAddressAddRemove(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- for _, tc := range []struct {
- name string
- address tcpip.Address
- prefixLen int
- }{
- {"IPv4", "\x01\x01\x01\x01", 24},
- {"IPv6", "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", 116},
- } {
- t.Run(tc.name, func(t *testing.T) {
- protocolAddress := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: tc.address,
- PrefixLen: tc.prefixLen,
- },
- }
- if err := s.AddProtocolAddress(1, protocolAddress); err != nil {
- t.Fatal("AddProtocolAddress failed:", err)
- }
-
- // Check that we get the right initial address and prefix length.
- gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber)
- if err != nil {
- t.Fatal("GetMainNICAddress failed:", err)
- }
- if wantAddr := protocolAddress.AddressWithPrefix; gotAddr != wantAddr {
- t.Fatalf("got s.GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr)
- }
-
- if err := s.RemoveAddress(1, protocolAddress.AddressWithPrefix.Address); err != nil {
- t.Fatal("RemoveAddress failed:", err)
- }
-
- // Check that we get no address after removal.
- gotAddr, err = s.GetMainNICAddress(1, fakeNetNumber)
- if err != nil {
- t.Fatal("GetMainNICAddress failed:", err)
- }
- if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr {
- t.Fatalf("got GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr)
- }
- })
- }
-}
-
-// Simple network address generator. Good for 255 addresses.
-type addressGenerator struct{ cnt byte }
-
-func (g *addressGenerator) next(addrLen int) tcpip.Address {
- g.cnt++
- return tcpip.Address(bytes.Repeat([]byte{g.cnt}, addrLen))
-}
-
-func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.ProtocolAddress) {
- t.Helper()
-
- if len(gotAddresses) != len(expectedAddresses) {
- t.Fatalf("got len(addresses) = %d, want = %d", len(gotAddresses), len(expectedAddresses))
- }
-
- sort.Slice(gotAddresses, func(i, j int) bool {
- return gotAddresses[i].AddressWithPrefix.Address < gotAddresses[j].AddressWithPrefix.Address
- })
- sort.Slice(expectedAddresses, func(i, j int) bool {
- return expectedAddresses[i].AddressWithPrefix.Address < expectedAddresses[j].AddressWithPrefix.Address
- })
-
- for i, gotAddr := range gotAddresses {
- expectedAddr := expectedAddresses[i]
- if gotAddr != expectedAddr {
- t.Errorf("got address = %+v, wanted = %+v", gotAddr, expectedAddr)
- }
- }
-}
-
-func TestAddAddress(t *testing.T) {
- const nicID = 1
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- var addrGen addressGenerator
- expectedAddresses := make([]tcpip.ProtocolAddress, 0, 2)
- for _, addrLen := range []int{4, 16} {
- address := addrGen.next(addrLen)
- if err := s.AddAddress(nicID, fakeNetNumber, address); err != nil {
- t.Fatalf("AddAddress(address=%s) failed: %s", address, err)
- }
- expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{address, fakeDefaultPrefixLen},
- })
- }
-
- gotAddresses := s.AllAddresses()[nicID]
- verifyAddresses(t, expectedAddresses, gotAddresses)
-}
-
-func TestAddProtocolAddress(t *testing.T) {
- const nicID = 1
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- var addrGen addressGenerator
- addrLenRange := []int{4, 16}
- prefixLenRange := []int{8, 13, 20, 32}
- expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange))
- for _, addrLen := range addrLenRange {
- for _, prefixLen := range prefixLenRange {
- protocolAddress := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: addrGen.next(addrLen),
- PrefixLen: prefixLen,
- },
- }
- if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil {
- t.Errorf("AddProtocolAddress(%+v) failed: %s", protocolAddress, err)
- }
- expectedAddresses = append(expectedAddresses, protocolAddress)
- }
- }
-
- gotAddresses := s.AllAddresses()[nicID]
- verifyAddresses(t, expectedAddresses, gotAddresses)
-}
-
-func TestAddAddressWithOptions(t *testing.T) {
- const nicID = 1
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- addrLenRange := []int{4, 16}
- behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint}
- expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(behaviorRange))
- var addrGen addressGenerator
- for _, addrLen := range addrLenRange {
- for _, behavior := range behaviorRange {
- address := addrGen.next(addrLen)
- if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil {
- t.Fatalf("AddAddressWithOptions(address=%s, behavior=%d) failed: %s", address, behavior, err)
- }
- expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{address, fakeDefaultPrefixLen},
- })
- }
- }
-
- gotAddresses := s.AllAddresses()[nicID]
- verifyAddresses(t, expectedAddresses, gotAddresses)
-}
-
-func TestAddProtocolAddressWithOptions(t *testing.T) {
- const nicID = 1
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- addrLenRange := []int{4, 16}
- prefixLenRange := []int{8, 13, 20, 32}
- behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint}
- expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange)*len(behaviorRange))
- var addrGen addressGenerator
- for _, addrLen := range addrLenRange {
- for _, prefixLen := range prefixLenRange {
- for _, behavior := range behaviorRange {
- protocolAddress := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: addrGen.next(addrLen),
- PrefixLen: prefixLen,
- },
- }
- if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil {
- t.Fatalf("AddProtocolAddressWithOptions(%+v, %d) failed: %s", protocolAddress, behavior, err)
- }
- expectedAddresses = append(expectedAddresses, protocolAddress)
- }
- }
- }
-
- gotAddresses := s.AllAddresses()[nicID]
- verifyAddresses(t, expectedAddresses, gotAddresses)
-}
-
-func TestCreateNICWithOptions(t *testing.T) {
- type callArgsAndExpect struct {
- nicID tcpip.NICID
- opts stack.NICOptions
- err *tcpip.Error
- }
-
- tests := []struct {
- desc string
- calls []callArgsAndExpect
- }{
- {
- desc: "DuplicateNICID",
- calls: []callArgsAndExpect{
- {
- nicID: tcpip.NICID(1),
- opts: stack.NICOptions{Name: "eth1"},
- err: nil,
- },
- {
- nicID: tcpip.NICID(1),
- opts: stack.NICOptions{Name: "eth2"},
- err: tcpip.ErrDuplicateNICID,
- },
- },
- },
- {
- desc: "DuplicateName",
- calls: []callArgsAndExpect{
- {
- nicID: tcpip.NICID(1),
- opts: stack.NICOptions{Name: "lo"},
- err: nil,
- },
- {
- nicID: tcpip.NICID(2),
- opts: stack.NICOptions{Name: "lo"},
- err: tcpip.ErrDuplicateNICID,
- },
- },
- },
- {
- desc: "Unnamed",
- calls: []callArgsAndExpect{
- {
- nicID: tcpip.NICID(1),
- opts: stack.NICOptions{},
- err: nil,
- },
- {
- nicID: tcpip.NICID(2),
- opts: stack.NICOptions{},
- err: nil,
- },
- },
- },
- {
- desc: "UnnamedDuplicateNICID",
- calls: []callArgsAndExpect{
- {
- nicID: tcpip.NICID(1),
- opts: stack.NICOptions{},
- err: nil,
- },
- {
- nicID: tcpip.NICID(1),
- opts: stack.NICOptions{},
- err: tcpip.ErrDuplicateNICID,
- },
- },
- },
- }
- for _, test := range tests {
- t.Run(test.desc, func(t *testing.T) {
- s := stack.New(stack.Options{})
- ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"))
- for _, call := range test.calls {
- if got, want := s.CreateNICWithOptions(call.nicID, ep, call.opts), call.err; got != want {
- t.Fatalf("CreateNICWithOptions(%v, _, %+v) = %v, want %v", call.nicID, call.opts, got, want)
- }
- }
- })
- }
-}
-
-func TestNICStats(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
- ep1 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep1); err != nil {
- t.Fatal("CreateNIC failed: ", err)
- }
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress failed:", err)
- }
- // Route all packets for address \x01 to NIC 1.
- {
- subnet, err := tcpip.NewSubnet("\x01", "\xff")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
- }
-
- // Send a packet to address 1.
- buf := buffer.NewView(30)
- ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
- Data: buf.ToVectorisedView(),
- })
- if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want {
- t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want)
- }
-
- if got, want := s.NICInfo()[1].Stats.Rx.Bytes.Value(), uint64(len(buf)); got != want {
- t.Errorf("got Rx.Bytes.Value() = %d, want = %d", got, want)
- }
-
- payload := buffer.NewView(10)
- // Write a packet out via the address for NIC 1
- if err := sendTo(s, "\x01", payload); err != nil {
- t.Fatal("sendTo failed: ", err)
- }
- want := uint64(ep1.Drain())
- if got := s.NICInfo()[1].Stats.Tx.Packets.Value(); got != want {
- t.Errorf("got Tx.Packets.Value() = %d, ep1.Drain() = %d", got, want)
- }
-
- if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)); got != want {
- t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
- }
-}
-
-func TestNICForwarding(t *testing.T) {
- // Create a stack with the fake network protocol, two NICs, each with
- // an address.
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
- s.SetForwarding(true)
-
- ep1 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep1); err != nil {
- t.Fatal("CreateNIC #1 failed:", err)
- }
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress #1 failed:", err)
- }
-
- ep2 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(2, ep2); err != nil {
- t.Fatal("CreateNIC #2 failed:", err)
- }
- if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
- t.Fatal("AddAddress #2 failed:", err)
- }
-
- // Route all packets to address 3 to NIC 2.
- {
- subnet, err := tcpip.NewSubnet("\x03", "\xff")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 2}})
- }
-
- // Send a packet to address 3.
- buf := buffer.NewView(30)
- buf[0] = 3
- ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
- Data: buf.ToVectorisedView(),
- })
-
- if _, ok := ep2.Read(); !ok {
- t.Fatal("Packet not forwarded")
- }
-
- // Test that forwarding increments Tx stats correctly.
- if got, want := s.NICInfo()[2].Stats.Tx.Packets.Value(), uint64(1); got != want {
- t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want)
- }
-
- if got, want := s.NICInfo()[2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want {
- t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
- }
-}
-
-// TestNICContextPreservation tests that you can read out via stack.NICInfo the
-// Context data you pass via NICContext.Context in stack.CreateNICWithOptions.
-func TestNICContextPreservation(t *testing.T) {
- var ctx *int
- tests := []struct {
- name string
- opts stack.NICOptions
- want stack.NICContext
- }{
- {
- "context_set",
- stack.NICOptions{Context: ctx},
- ctx,
- },
- {
- "context_not_set",
- stack.NICOptions{},
- nil,
- },
- }
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{})
- id := tcpip.NICID(1)
- ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"))
- if err := s.CreateNICWithOptions(id, ep, test.opts); err != nil {
- t.Fatalf("got stack.CreateNICWithOptions(%d, %+v, %+v) = %s, want nil", id, ep, test.opts, err)
- }
- nicinfos := s.NICInfo()
- nicinfo, ok := nicinfos[id]
- if !ok {
- t.Fatalf("got nicinfos[%d] = _, %t, want _, true; nicinfos = %+v", id, ok, nicinfos)
- }
- if got, want := nicinfo.Context == test.want, true; got != want {
- t.Fatal("got nicinfo.Context == ctx = %t, want %t; nicinfo.Context = %p, ctx = %p", got, want, nicinfo.Context, test.want)
- }
- })
- }
-}
-
-// TestNICAutoGenLinkLocalAddr tests the auto-generation of IPv6 link-local
-// addresses.
-func TestNICAutoGenLinkLocalAddr(t *testing.T) {
- const nicID = 1
-
- var secretKey [header.OpaqueIIDSecretKeyMinBytes]byte
- n, err := rand.Read(secretKey[:])
- if err != nil {
- t.Fatalf("rand.Read(_): %s", err)
- }
- if n != header.OpaqueIIDSecretKeyMinBytes {
- t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", header.OpaqueIIDSecretKeyMinBytes, n)
- }
-
- nicNameFunc := func(_ tcpip.NICID, name string) string {
- return name
- }
-
- tests := []struct {
- name string
- nicName string
- autoGen bool
- linkAddr tcpip.LinkAddress
- iidOpts stack.OpaqueInterfaceIdentifierOptions
- shouldGen bool
- expectedAddr tcpip.Address
- }{
- {
- name: "Disabled",
- nicName: "nic1",
- autoGen: false,
- linkAddr: linkAddr1,
- shouldGen: false,
- },
- {
- name: "Disabled without OIID options",
- nicName: "nic1",
- autoGen: false,
- linkAddr: linkAddr1,
- iidOpts: stack.OpaqueInterfaceIdentifierOptions{
- NICNameFromID: nicNameFunc,
- SecretKey: secretKey[:],
- },
- shouldGen: false,
- },
-
- // Tests for EUI64 based addresses.
- {
- name: "EUI64 Enabled",
- autoGen: true,
- linkAddr: linkAddr1,
- shouldGen: true,
- expectedAddr: header.LinkLocalAddr(linkAddr1),
- },
- {
- name: "EUI64 Empty MAC",
- autoGen: true,
- shouldGen: false,
- },
- {
- name: "EUI64 Invalid MAC",
- autoGen: true,
- linkAddr: "\x01\x02\x03",
- shouldGen: false,
- },
- {
- name: "EUI64 Multicast MAC",
- autoGen: true,
- linkAddr: "\x01\x02\x03\x04\x05\x06",
- shouldGen: false,
- },
- {
- name: "EUI64 Unspecified MAC",
- autoGen: true,
- linkAddr: "\x00\x00\x00\x00\x00\x00",
- shouldGen: false,
- },
-
- // Tests for Opaque IID based addresses.
- {
- name: "OIID Enabled",
- nicName: "nic1",
- autoGen: true,
- linkAddr: linkAddr1,
- iidOpts: stack.OpaqueInterfaceIdentifierOptions{
- NICNameFromID: nicNameFunc,
- SecretKey: secretKey[:],
- },
- shouldGen: true,
- expectedAddr: header.LinkLocalAddrWithOpaqueIID("nic1", 0, secretKey[:]),
- },
- // These are all cases where we would not have generated a
- // link-local address if opaque IIDs were disabled.
- {
- name: "OIID Empty MAC and empty nicName",
- autoGen: true,
- iidOpts: stack.OpaqueInterfaceIdentifierOptions{
- NICNameFromID: nicNameFunc,
- SecretKey: secretKey[:1],
- },
- shouldGen: true,
- expectedAddr: header.LinkLocalAddrWithOpaqueIID("", 0, secretKey[:1]),
- },
- {
- name: "OIID Invalid MAC",
- nicName: "test",
- autoGen: true,
- linkAddr: "\x01\x02\x03",
- iidOpts: stack.OpaqueInterfaceIdentifierOptions{
- NICNameFromID: nicNameFunc,
- SecretKey: secretKey[:2],
- },
- shouldGen: true,
- expectedAddr: header.LinkLocalAddrWithOpaqueIID("test", 0, secretKey[:2]),
- },
- {
- name: "OIID Multicast MAC",
- nicName: "test2",
- autoGen: true,
- linkAddr: "\x01\x02\x03\x04\x05\x06",
- iidOpts: stack.OpaqueInterfaceIdentifierOptions{
- NICNameFromID: nicNameFunc,
- SecretKey: secretKey[:3],
- },
- shouldGen: true,
- expectedAddr: header.LinkLocalAddrWithOpaqueIID("test2", 0, secretKey[:3]),
- },
- {
- name: "OIID Unspecified MAC and nil SecretKey",
- nicName: "test3",
- autoGen: true,
- linkAddr: "\x00\x00\x00\x00\x00\x00",
- iidOpts: stack.OpaqueInterfaceIdentifierOptions{
- NICNameFromID: nicNameFunc,
- },
- shouldGen: true,
- expectedAddr: header.LinkLocalAddrWithOpaqueIID("test3", 0, nil),
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- ndpDisp := ndpDispatcher{
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
- }
- opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- AutoGenIPv6LinkLocal: test.autoGen,
- NDPDisp: &ndpDisp,
- OpaqueIIDOpts: test.iidOpts,
- }
-
- e := channel.New(0, 1280, test.linkAddr)
- s := stack.New(opts)
- nicOpts := stack.NICOptions{Name: test.nicName, Disabled: true}
- if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
- t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err)
- }
-
- // A new disabled NIC should not have any address, even if auto generation
- // was enabled.
- allStackAddrs := s.AllAddresses()
- allNICAddrs, ok := allStackAddrs[nicID]
- if !ok {
- t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
- }
- if l := len(allNICAddrs); l != 0 {
- t.Fatalf("got len(allNICAddrs) = %d, want = 0", l)
- }
-
- // Enabling the NIC should attempt auto-generation of a link-local
- // address.
- if err := s.EnableNIC(nicID); err != nil {
- t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
- }
-
- var expectedMainAddr tcpip.AddressWithPrefix
- if test.shouldGen {
- expectedMainAddr = tcpip.AddressWithPrefix{
- Address: test.expectedAddr,
- PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen,
- }
-
- // Should have auto-generated an address and resolved immediately (DAD
- // is disabled).
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, expectedMainAddr, newAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
- }
- } else {
- // Should not have auto-generated an address.
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly auto-generated an address")
- default:
- }
- }
-
- gotMainAddr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
- }
- if gotMainAddr != expectedMainAddr {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", gotMainAddr, expectedMainAddr)
- }
- })
- }
-}
-
-// TestNoLinkLocalAutoGenForLoopbackNIC tests that IPv6 link-local addresses are
-// not auto-generated for loopback NICs.
-func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) {
- const nicID = 1
- const nicName = "nicName"
-
- tests := []struct {
- name string
- opaqueIIDOpts stack.OpaqueInterfaceIdentifierOptions
- }{
- {
- name: "IID From MAC",
- opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{},
- },
- {
- name: "Opaque IID",
- opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
- NICNameFromID: func(_ tcpip.NICID, nicName string) string {
- return nicName
- },
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- AutoGenIPv6LinkLocal: true,
- OpaqueIIDOpts: test.opaqueIIDOpts,
- }
-
- e := loopback.New()
- s := stack.New(opts)
- nicOpts := stack.NICOptions{Name: nicName}
- if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
- t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err)
- }
-
- addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Errorf("got stack.GetMainNICAddress(%d, _) = %s, want = %s", nicID, addr, want)
- }
- })
- }
-}
-
-// TestNICAutoGenAddrDoesDAD tests that the successful auto-generation of IPv6
-// link-local addresses will only be assigned after the DAD process resolves.
-func TestNICAutoGenAddrDoesDAD(t *testing.T) {
- const nicID = 1
-
- ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent),
- }
- ndpConfigs := stack.DefaultNDPConfigurations()
- opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: ndpConfigs,
- AutoGenIPv6LinkLocal: true,
- NDPDisp: &ndpDisp,
- }
-
- e := channel.New(int(ndpConfigs.DupAddrDetectTransmits), 1280, linkAddr1)
- s := stack.New(opts)
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- // Address should not be considered bound to the
- // NIC yet (DAD ongoing).
- addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
- }
-
- linkLocalAddr := header.LinkLocalAddr(linkAddr1)
-
- // Wait for DAD to resolve.
- select {
- case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
- // We should get a resolution event after 1s (default time to
- // resolve as per default NDP configurations). Waiting for that
- // resolution time + an extra 1s without a resolution event
- // means something is wrong.
- t.Fatal("timed out waiting for DAD resolution")
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, linkLocalAddr, true, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
- }
- }
- addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
- }
-}
-
-// TestNewPEB tests that a new PrimaryEndpointBehavior value (peb) is respected
-// when an address's kind gets "promoted" to permanent from permanentExpired.
-func TestNewPEBOnPromotionToPermanent(t *testing.T) {
- pebs := []stack.PrimaryEndpointBehavior{
- stack.NeverPrimaryEndpoint,
- stack.CanBePrimaryEndpoint,
- stack.FirstPrimaryEndpoint,
- }
-
- for _, pi := range pebs {
- for _, ps := range pebs {
- t.Run(fmt.Sprintf("%d-to-%d", pi, ps), func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
- ep1 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep1); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- // Add a permanent address with initial
- // PrimaryEndpointBehavior (peb), pi. If pi is
- // NeverPrimaryEndpoint, the address should not
- // be returned by a call to GetMainNICAddress;
- // else, it should.
- if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", pi); err != nil {
- t.Fatal("AddAddressWithOptions failed:", err)
- }
- addr, err := s.GetMainNICAddress(1, fakeNetNumber)
- if err != nil {
- t.Fatal("s.GetMainNICAddress failed:", err)
- }
- if pi == stack.NeverPrimaryEndpoint {
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got GetMainNICAddress = %s, want = %s", addr, want)
-
- }
- } else if addr.Address != "\x01" {
- t.Fatalf("got GetMainNICAddress = %s, want = 1", addr.Address)
- }
-
- {
- subnet, err := tcpip.NewSubnet("\x00", "\x00")
- if err != nil {
- t.Fatalf("NewSubnet failed:", err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
- }
-
- // Take a route through the address so its ref
- // count gets incremented and does not actually
- // get deleted when RemoveAddress is called
- // below. This is because we want to test that a
- // new peb is respected when an address gets
- // "promoted" to permanent from a
- // permanentExpired kind.
- r, err := s.FindRoute(1, "\x01", "\x02", fakeNetNumber, false)
- if err != nil {
- t.Fatal("FindRoute failed:", err)
- }
- defer r.Release()
- if err := s.RemoveAddress(1, "\x01"); err != nil {
- t.Fatalf("RemoveAddress failed:", err)
- }
-
- //
- // At this point, the address should still be
- // known by the NIC, but have its
- // kind = permanentExpired.
- //
-
- // Add some other address with peb set to
- // FirstPrimaryEndpoint.
- if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x03", stack.FirstPrimaryEndpoint); err != nil {
- t.Fatal("AddAddressWithOptions failed:", err)
-
- }
-
- // Add back the address we removed earlier and
- // make sure the new peb was respected.
- // (The address should just be promoted now).
- if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", ps); err != nil {
- t.Fatal("AddAddressWithOptions failed:", err)
- }
- var primaryAddrs []tcpip.Address
- for _, pa := range s.NICInfo()[1].ProtocolAddresses {
- primaryAddrs = append(primaryAddrs, pa.AddressWithPrefix.Address)
- }
- var expectedList []tcpip.Address
- switch ps {
- case stack.FirstPrimaryEndpoint:
- expectedList = []tcpip.Address{
- "\x01",
- "\x03",
- }
- case stack.CanBePrimaryEndpoint:
- expectedList = []tcpip.Address{
- "\x03",
- "\x01",
- }
- case stack.NeverPrimaryEndpoint:
- expectedList = []tcpip.Address{
- "\x03",
- }
- }
- if !cmp.Equal(primaryAddrs, expectedList) {
- t.Fatalf("got NIC's primary addresses = %v, want = %v", primaryAddrs, expectedList)
- }
-
- // Once we remove the other address, if the new
- // peb, ps, was NeverPrimaryEndpoint, no address
- // should be returned by a call to
- // GetMainNICAddress; else, our original address
- // should be returned.
- if err := s.RemoveAddress(1, "\x03"); err != nil {
- t.Fatalf("RemoveAddress failed:", err)
- }
- addr, err = s.GetMainNICAddress(1, fakeNetNumber)
- if err != nil {
- t.Fatal("s.GetMainNICAddress failed:", err)
- }
- if ps == stack.NeverPrimaryEndpoint {
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got GetMainNICAddress = %s, want = %s", addr, want)
-
- }
- } else {
- if addr.Address != "\x01" {
- t.Fatalf("got GetMainNICAddress = %s, want = 1", addr.Address)
- }
- }
- })
- }
- }
-}
-
-func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
- const (
- linkLocalAddr1 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- linkLocalAddr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- nicID = 1
- )
-
- // Rule 3 is not tested here, and is instead tested by NDP's AutoGenAddr test.
- tests := []struct {
- name string
- nicAddrs []tcpip.Address
- connectAddr tcpip.Address
- expectedLocalAddr tcpip.Address
- }{
- // Test Rule 1 of RFC 6724 section 5.
- {
- name: "Same Global most preferred (last address)",
- nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
- connectAddr: globalAddr1,
- expectedLocalAddr: globalAddr1,
- },
- {
- name: "Same Global most preferred (first address)",
- nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1},
- connectAddr: globalAddr1,
- expectedLocalAddr: globalAddr1,
- },
- {
- name: "Same Link Local most preferred (last address)",
- nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1},
- connectAddr: linkLocalAddr1,
- expectedLocalAddr: linkLocalAddr1,
- },
- {
- name: "Same Link Local most preferred (first address)",
- nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
- connectAddr: linkLocalAddr1,
- expectedLocalAddr: linkLocalAddr1,
- },
- {
- name: "Same Unique Local most preferred (last address)",
- nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1},
- connectAddr: uniqueLocalAddr1,
- expectedLocalAddr: uniqueLocalAddr1,
- },
- {
- name: "Same Unique Local most preferred (first address)",
- nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1},
- connectAddr: uniqueLocalAddr1,
- expectedLocalAddr: uniqueLocalAddr1,
- },
-
- // Test Rule 2 of RFC 6724 section 5.
- {
- name: "Global most preferred (last address)",
- nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
- connectAddr: globalAddr2,
- expectedLocalAddr: globalAddr1,
- },
- {
- name: "Global most preferred (first address)",
- nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1},
- connectAddr: globalAddr2,
- expectedLocalAddr: globalAddr1,
- },
- {
- name: "Link Local most preferred (last address)",
- nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1},
- connectAddr: linkLocalAddr2,
- expectedLocalAddr: linkLocalAddr1,
- },
- {
- name: "Link Local most preferred (first address)",
- nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
- connectAddr: linkLocalAddr2,
- expectedLocalAddr: linkLocalAddr1,
- },
- {
- name: "Unique Local most preferred (last address)",
- nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1},
- connectAddr: uniqueLocalAddr2,
- expectedLocalAddr: uniqueLocalAddr1,
- },
- {
- name: "Unique Local most preferred (first address)",
- nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1},
- connectAddr: uniqueLocalAddr2,
- expectedLocalAddr: uniqueLocalAddr1,
- },
-
- // Test returning the endpoint that is closest to the front when
- // candidate addresses are "equal" from the perspective of RFC 6724
- // section 5.
- {
- name: "Unique Local for Global",
- nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, uniqueLocalAddr2},
- connectAddr: globalAddr2,
- expectedLocalAddr: uniqueLocalAddr1,
- },
- {
- name: "Link Local for Global",
- nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2},
- connectAddr: globalAddr2,
- expectedLocalAddr: linkLocalAddr1,
- },
- {
- name: "Link Local for Unique Local",
- nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2},
- connectAddr: uniqueLocalAddr2,
- expectedLocalAddr: linkLocalAddr1,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
- })
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- s.SetRouteTable([]tcpip.Route{{
- Destination: header.IPv6EmptySubnet,
- Gateway: llAddr3,
- NIC: nicID,
- }})
- s.AddLinkAddress(nicID, llAddr3, linkAddr3)
-
- for _, a := range test.nicAddrs {
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, a); err != nil {
- t.Errorf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, a, err)
- }
- }
-
- if t.Failed() {
- t.FailNow()
- }
-
- if got := addrForNewConnectionTo(t, s, tcpip.FullAddress{Addr: test.connectAddr, NIC: nicID, Port: 1234}); got != test.expectedLocalAddr {
- t.Errorf("got local address = %s, want = %s", got, test.expectedLocalAddr)
- }
- })
- }
-}
-
-func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) {
- const nicID = 1
-
- e := loopback.New()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- })
- nicOpts := stack.NICOptions{Disabled: true}
- if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
- t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err)
- }
-
- allStackAddrs := s.AllAddresses()
- allNICAddrs, ok := allStackAddrs[nicID]
- if !ok {
- t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
- }
- if l := len(allNICAddrs); l != 0 {
- t.Fatalf("got len(allNICAddrs) = %d, want = 0", l)
- }
-
- // Enabling the NIC should add the IPv4 broadcast address.
- if err := s.EnableNIC(nicID); err != nil {
- t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
- }
- allStackAddrs = s.AllAddresses()
- allNICAddrs, ok = allStackAddrs[nicID]
- if !ok {
- t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
- }
- if l := len(allNICAddrs); l != 1 {
- t.Fatalf("got len(allNICAddrs) = %d, want = 1", l)
- }
- want := tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: header.IPv4Broadcast,
- PrefixLen: 32,
- },
- }
- if allNICAddrs[0] != want {
- t.Fatalf("got allNICAddrs[0] = %+v, want = %+v", allNICAddrs[0], want)
- }
-
- // Disabling the NIC should remove the IPv4 broadcast address.
- if err := s.DisableNIC(nicID); err != nil {
- t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
- }
- allStackAddrs = s.AllAddresses()
- allNICAddrs, ok = allStackAddrs[nicID]
- if !ok {
- t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
- }
- if l := len(allNICAddrs); l != 0 {
- t.Fatalf("got len(allNICAddrs) = %d, want = 0", l)
- }
-}
-
-func TestJoinLeaveAllNodesMulticastOnNICEnableDisable(t *testing.T) {
- const nicID = 1
-
- e := loopback.New()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- })
- nicOpts := stack.NICOptions{Disabled: true}
- if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
- t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err)
- }
-
- // Should not be in the IPv6 all-nodes multicast group yet because the NIC has
- // not been enabled yet.
- isInGroup, err := s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress)
- if err != nil {
- t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err)
- }
- if isInGroup {
- t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress)
- }
-
- // The all-nodes multicast group should be joined when the NIC is enabled.
- if err := s.EnableNIC(nicID); err != nil {
- t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
- }
- isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress)
- if err != nil {
- t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err)
- }
- if !isInGroup {
- t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, header.IPv6AllNodesMulticastAddress)
- }
-
- // The all-nodes multicast group should be left when the NIC is disabled.
- if err := s.DisableNIC(nicID); err != nil {
- t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
- }
- isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress)
- if err != nil {
- t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err)
- }
- if isInGroup {
- t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress)
- }
-}
-
-// TestDoDADWhenNICEnabled tests that IPv6 endpoints that were added while a NIC
-// was disabled have DAD performed on them when the NIC is enabled.
-func TestDoDADWhenNICEnabled(t *testing.T) {
- t.Parallel()
-
- const dadTransmits = 1
- const retransmitTimer = time.Second
- const nicID = 1
-
- ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent),
- }
- opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- DupAddrDetectTransmits: dadTransmits,
- RetransmitTimer: retransmitTimer,
- },
- NDPDisp: &ndpDisp,
- }
-
- e := channel.New(dadTransmits, 1280, linkAddr1)
- s := stack.New(opts)
- nicOpts := stack.NICOptions{Disabled: true}
- if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
- t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err)
- }
-
- addr := tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: llAddr1,
- PrefixLen: 128,
- },
- }
- if err := s.AddProtocolAddress(nicID, addr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err)
- }
-
- // Address should be in the list of all addresses.
- if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) {
- t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr)
- }
-
- // Address should be tentative so it should not be a main address.
- got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); got != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, want)
- }
-
- // Enabling the NIC should start DAD for the address.
- if err := s.EnableNIC(nicID); err != nil {
- t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
- }
- if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) {
- t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr)
- }
-
- // Address should not be considered bound to the NIC yet (DAD ongoing).
- got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); got != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, want)
- }
-
- // Wait for DAD to resolve.
- select {
- case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout):
- t.Fatal("timed out waiting for DAD resolution")
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr.AddressWithPrefix.Address, true, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
- }
- }
- if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) {
- t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr)
- }
- got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if got != addr.AddressWithPrefix {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix)
- }
-
- // Enabling the NIC again should be a no-op.
- if err := s.EnableNIC(nicID); err != nil {
- t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
- }
- if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) {
- t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr)
- }
- got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if got != addr.AddressWithPrefix {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix)
- }
-}
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
deleted file mode 100644
index 5e9237de9..000000000
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ /dev/null
@@ -1,348 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package stack_test
-
-import (
- "math"
- "math/rand"
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-const (
- stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
-
- stackAddr = "\x0a\x00\x00\x01"
- stackPort = 1234
- testPort = 4096
-)
-
-type testContext struct {
- t *testing.T
- linkEps map[tcpip.NICID]*channel.Endpoint
- s *stack.Stack
-
- ep tcpip.Endpoint
- wq waiter.Queue
-}
-
-func (c *testContext) cleanup() {
- if c.ep != nil {
- c.ep.Close()
- }
-}
-
-func (c *testContext) createV6Endpoint(v6only bool) {
- var err *tcpip.Error
- c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, v6only); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
- }
-}
-
-// newDualTestContextMultiNIC creates the testing context and also linkEpIDs NICs.
-func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
- linkEps := make(map[tcpip.NICID]*channel.Endpoint)
- for _, linkEpID := range linkEpIDs {
- channelEp := channel.New(256, mtu, "")
- if err := s.CreateNIC(linkEpID, channelEp); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
- linkEps[linkEpID] = channelEp
-
- if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, stackAddr); err != nil {
- t.Fatalf("AddAddress IPv4 failed: %v", err)
- }
-
- if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, stackV6Addr); err != nil {
- t.Fatalf("AddAddress IPv6 failed: %v", err)
- }
- }
-
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv4EmptySubnet,
- NIC: 1,
- },
- {
- Destination: header.IPv6EmptySubnet,
- NIC: 1,
- },
- })
-
- return &testContext{
- t: t,
- s: s,
- linkEps: linkEps,
- }
-}
-
-type headers struct {
- srcPort uint16
- dstPort uint16
-}
-
-func newPayload() []byte {
- b := make([]byte, 30+rand.Intn(100))
- for i := range b {
- b[i] = byte(rand.Intn(256))
- }
- return b
-}
-
-func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) {
- // Allocate a buffer for data and headers.
- buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
- copy(buf[len(buf)-len(payload):], payload)
-
- // Initialize the IP header.
- ip := header.IPv6(buf)
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
- NextHeader: uint8(udp.ProtocolNumber),
- HopLimit: 65,
- SrcAddr: testV6Addr,
- DstAddr: stackV6Addr,
- })
-
- // Initialize the UDP header.
- u := header.UDP(buf[header.IPv6MinimumSize:])
- u.Encode(&header.UDPFields{
- SrcPort: h.srcPort,
- DstPort: h.dstPort,
- Length: uint16(header.UDPMinimumSize + len(payload)),
- })
-
- // Calculate the UDP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u)))
-
- // Calculate the UDP checksum and set it.
- xsum = header.Checksum(payload, xsum)
- u.SetChecksum(^u.CalculateChecksum(xsum))
-
- // Inject packet.
- c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{
- Data: buf.ToVectorisedView(),
- })
-}
-
-func TestTransportDemuxerRegister(t *testing.T) {
- for _, test := range []struct {
- name string
- proto tcpip.NetworkProtocolNumber
- want *tcpip.Error
- }{
- {"failure", ipv6.ProtocolNumber, tcpip.ErrUnknownProtocol},
- {"success", ipv4.ProtocolNumber, nil},
- } {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
- if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, nil, false, 0), test.want; got != want {
- t.Fatalf("s.RegisterTransportEndpoint(...) = %v, want %v", got, want)
- }
- })
- }
-}
-
-// TestReuseBindToDevice injects varied packets on input devices and checks that
-// the distribution of packets received matches expectations.
-func TestDistribution(t *testing.T) {
- type endpointSockopts struct {
- reuse int
- bindToDevice tcpip.NICID
- }
- for _, test := range []struct {
- name string
- // endpoints will received the inject packets.
- endpoints []endpointSockopts
- // wantedDistribution is the wanted ratio of packets received on each
- // endpoint for each NIC on which packets are injected.
- wantedDistributions map[tcpip.NICID][]float64
- }{
- {
- "BindPortReuse",
- // 5 endpoints that all have reuse set.
- []endpointSockopts{
- {1, 0},
- {1, 0},
- {1, 0},
- {1, 0},
- {1, 0},
- },
- map[tcpip.NICID][]float64{
- // Injected packets on dev0 get distributed evenly.
- 1: {0.2, 0.2, 0.2, 0.2, 0.2},
- },
- },
- {
- "BindToDevice",
- // 3 endpoints with various bindings.
- []endpointSockopts{
- {0, 1},
- {0, 2},
- {0, 3},
- },
- map[tcpip.NICID][]float64{
- // Injected packets on dev0 go only to the endpoint bound to dev0.
- 1: {1, 0, 0},
- // Injected packets on dev1 go only to the endpoint bound to dev1.
- 2: {0, 1, 0},
- // Injected packets on dev2 go only to the endpoint bound to dev2.
- 3: {0, 0, 1},
- },
- },
- {
- "ReuseAndBindToDevice",
- // 6 endpoints with various bindings.
- []endpointSockopts{
- {1, 1},
- {1, 1},
- {1, 2},
- {1, 2},
- {1, 2},
- {1, 0},
- },
- map[tcpip.NICID][]float64{
- // Injected packets on dev0 get distributed among endpoints bound to
- // dev0.
- 1: {0.5, 0.5, 0, 0, 0, 0},
- // Injected packets on dev1 get distributed among endpoints bound to
- // dev1 or unbound.
- 2: {0, 0, 1. / 3, 1. / 3, 1. / 3, 0},
- // Injected packets on dev999 go only to the unbound.
- 1000: {0, 0, 0, 0, 0, 1},
- },
- },
- } {
- t.Run(test.name, func(t *testing.T) {
- for device, wantedDistribution := range test.wantedDistributions {
- t.Run(string(device), func(t *testing.T) {
- var devices []tcpip.NICID
- for d := range test.wantedDistributions {
- devices = append(devices, d)
- }
- c := newDualTestContextMultiNIC(t, defaultMTU, devices)
- defer c.cleanup()
-
- c.createV6Endpoint(false)
-
- eps := make(map[tcpip.Endpoint]int)
-
- pollChannel := make(chan tcpip.Endpoint)
- for i, endpoint := range test.endpoints {
- // Try to receive the data.
- wq := waiter.Queue{}
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.EventIn)
- defer wq.EventUnregister(&we)
- defer close(ch)
-
- var err *tcpip.Error
- ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
- eps[ep] = i
-
- go func(ep tcpip.Endpoint) {
- for range ch {
- pollChannel <- ep
- }
- }(ep)
-
- defer ep.Close()
- reusePortOption := tcpip.ReusePortOption(endpoint.reuse)
- if err := ep.SetSockOpt(reusePortOption); err != nil {
- c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", reusePortOption, i, err)
- }
- bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice)
- if err := ep.SetSockOpt(bindToDeviceOption); err != nil {
- c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", bindToDeviceOption, i, err)
- }
- if err := ep.Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil {
- t.Fatalf("ep.Bind(...) on endpoint %d failed: %v", i, err)
- }
- }
-
- npackets := 100000
- nports := 10000
- if got, want := len(test.endpoints), len(wantedDistribution); got != want {
- t.Fatalf("got len(test.endpoints) = %d, want %d", got, want)
- }
- ports := make(map[uint16]tcpip.Endpoint)
- stats := make(map[tcpip.Endpoint]int)
- for i := 0; i < npackets; i++ {
- // Send a packet.
- port := uint16(i % nports)
- payload := newPayload()
- c.sendV6Packet(payload,
- &headers{
- srcPort: testPort + port,
- dstPort: stackPort},
- device)
-
- var addr tcpip.FullAddress
- ep := <-pollChannel
- _, _, err := ep.Read(&addr)
- if err != nil {
- c.t.Fatalf("Read on endpoint %d failed: %v", eps[ep], err)
- }
- stats[ep]++
- if i < nports {
- ports[uint16(i)] = ep
- } else {
- // Check that all packets from one client are handled by the same
- // socket.
- if want, got := ports[port], ep; want != got {
- t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got])
- }
- }
- }
-
- // Check that a packet distribution is as expected.
- for ep, i := range eps {
- wantedRatio := wantedDistribution[i]
- wantedRecv := wantedRatio * float64(npackets)
- actualRecv := stats[ep]
- actualRatio := float64(stats[ep]) / float64(npackets)
- // The deviation is less than 10%.
- if math.Abs(actualRatio-wantedRatio) > 0.05 {
- t.Errorf("wanted about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantedRatio*100, wantedRecv, npackets, i, actualRatio*100, actualRecv, npackets)
- }
- }
- })
- }
- })
- }
-}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
deleted file mode 100644
index 869c69a6d..000000000
--- a/pkg/tcpip/stack/transport_test.go
+++ /dev/null
@@ -1,637 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package stack_test
-
-import (
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-const (
- fakeTransNumber tcpip.TransportProtocolNumber = 1
- fakeTransHeaderLen = 3
-)
-
-// fakeTransportEndpoint is a transport-layer protocol endpoint. It counts
-// received packets; the counts of all endpoints are aggregated in the protocol
-// descriptor.
-//
-// Headers of this protocol are fakeTransHeaderLen bytes, but we currently don't
-// use it.
-type fakeTransportEndpoint struct {
- stack.TransportEndpointInfo
- stack *stack.Stack
- proto *fakeTransportProtocol
- peerAddr tcpip.Address
- route stack.Route
- uniqueID uint64
-
- // acceptQueue is non-nil iff bound.
- acceptQueue []fakeTransportEndpoint
-}
-
-func (f *fakeTransportEndpoint) Info() tcpip.EndpointInfo {
- return &f.TransportEndpointInfo
-}
-
-func (f *fakeTransportEndpoint) Stats() tcpip.EndpointStats {
- return nil
-}
-
-func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint {
- return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
-}
-
-func (f *fakeTransportEndpoint) Close() {
- f.route.Release()
-}
-
-func (*fakeTransportEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
- return mask
-}
-
-func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
- return buffer.View{}, tcpip.ControlMessages{}, nil
-}
-
-func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
- if len(f.route.RemoteAddress) == 0 {
- return 0, nil, tcpip.ErrNoRoute
- }
-
- hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength()))
- v, err := p.FullPayload()
- if err != nil {
- return 0, nil, err
- }
- if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
- Header: hdr,
- Data: buffer.View(v).ToVectorisedView(),
- }); err != nil {
- return 0, nil, err
- }
-
- return int64(len(v)), nil, nil
-}
-
-func (f *fakeTransportEndpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
- return 0, tcpip.ControlMessages{}, nil
-}
-
-// SetSockOpt sets a socket option. Currently not supported.
-func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error {
- return tcpip.ErrInvalidEndpointState
-}
-
-// SetSockOptBool sets a socket option. Currently not supported.
-func (*fakeTransportEndpoint) SetSockOptBool(tcpip.SockOptBool, bool) *tcpip.Error {
- return tcpip.ErrInvalidEndpointState
-}
-
-// SetSockOptInt sets a socket option. Currently not supported.
-func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) *tcpip.Error {
- return tcpip.ErrInvalidEndpointState
-}
-
-// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
-func (*fakeTransportEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- return false, tcpip.ErrUnknownProtocolOption
-}
-
-// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
- return -1, tcpip.ErrUnknownProtocolOption
-}
-
-// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (*fakeTransportEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
- switch opt.(type) {
- case tcpip.ErrorOption:
- return nil
- }
- return tcpip.ErrInvalidEndpointState
-}
-
-// Disconnect implements tcpip.Endpoint.Disconnect.
-func (*fakeTransportEndpoint) Disconnect() *tcpip.Error {
- return tcpip.ErrNotSupported
-}
-
-func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
- f.peerAddr = addr.Addr
-
- // Find the route.
- r, err := f.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- return tcpip.ErrNoRoute
- }
- defer r.Release()
-
- // Try to register so that we can start receiving packets.
- f.ID.RemoteAddress = addr.Addr
- err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, false /* reuse */, 0 /* bindToDevice */)
- if err != nil {
- return err
- }
-
- f.route = r.Clone()
-
- return nil
-}
-
-func (f *fakeTransportEndpoint) UniqueID() uint64 {
- return f.uniqueID
-}
-
-func (f *fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error {
- return nil
-}
-
-func (*fakeTransportEndpoint) Shutdown(tcpip.ShutdownFlags) *tcpip.Error {
- return nil
-}
-
-func (*fakeTransportEndpoint) Reset() {
-}
-
-func (*fakeTransportEndpoint) Listen(int) *tcpip.Error {
- return nil
-}
-
-func (f *fakeTransportEndpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
- if len(f.acceptQueue) == 0 {
- return nil, nil, nil
- }
- a := f.acceptQueue[0]
- f.acceptQueue = f.acceptQueue[1:]
- return &a, nil, nil
-}
-
-func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error {
- if err := f.stack.RegisterTransportEndpoint(
- a.NIC,
- []tcpip.NetworkProtocolNumber{fakeNetNumber},
- fakeTransNumber,
- stack.TransportEndpointID{LocalAddress: a.Addr},
- f,
- false, /* reuse */
- 0, /* bindtoDevice */
- ); err != nil {
- return err
- }
- f.acceptQueue = []fakeTransportEndpoint{}
- return nil
-}
-
-func (*fakeTransportEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
- return tcpip.FullAddress{}, nil
-}
-
-func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
- return tcpip.FullAddress{}, nil
-}
-
-func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ tcpip.PacketBuffer) {
- // Increment the number of received packets.
- f.proto.packetCount++
- if f.acceptQueue != nil {
- f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{
- stack: f.stack,
- TransportEndpointInfo: stack.TransportEndpointInfo{
- ID: f.ID,
- NetProto: f.NetProto,
- },
- proto: f.proto,
- peerAddr: r.RemoteAddress,
- route: r.Clone(),
- })
- }
-}
-
-func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, tcpip.PacketBuffer) {
- // Increment the number of received control packets.
- f.proto.controlCount++
-}
-
-func (f *fakeTransportEndpoint) State() uint32 {
- return 0
-}
-
-func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {}
-
-func (f *fakeTransportEndpoint) IPTables() (iptables.IPTables, error) {
- return iptables.IPTables{}, nil
-}
-
-func (f *fakeTransportEndpoint) Resume(*stack.Stack) {}
-
-func (f *fakeTransportEndpoint) Wait() {}
-
-type fakeTransportGoodOption bool
-
-type fakeTransportBadOption bool
-
-type fakeTransportInvalidValueOption int
-
-type fakeTransportProtocolOptions struct {
- good bool
-}
-
-// fakeTransportProtocol is a transport-layer protocol descriptor. It
-// aggregates the number of packets received via endpoints of this protocol.
-type fakeTransportProtocol struct {
- packetCount int
- controlCount int
- opts fakeTransportProtocolOptions
-}
-
-func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber {
- return fakeTransNumber
-}
-
-func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil
-}
-
-func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return nil, tcpip.ErrUnknownProtocol
-}
-
-func (*fakeTransportProtocol) MinimumPacketSize() int {
- return fakeTransHeaderLen
-}
-
-func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcpip.Error) {
- return 0, 0, nil
-}
-
-func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool {
- return true
-}
-
-func (f *fakeTransportProtocol) SetOption(option interface{}) *tcpip.Error {
- switch v := option.(type) {
- case fakeTransportGoodOption:
- f.opts.good = bool(v)
- return nil
- case fakeTransportInvalidValueOption:
- return tcpip.ErrInvalidOptionValue
- default:
- return tcpip.ErrUnknownProtocolOption
- }
-}
-
-func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error {
- switch v := option.(type) {
- case *fakeTransportGoodOption:
- *v = fakeTransportGoodOption(f.opts.good)
- return nil
- default:
- return tcpip.ErrUnknownProtocolOption
- }
-}
-
-func fakeTransFactory() stack.TransportProtocol {
- return &fakeTransportProtocol{}
-}
-
-func TestTransportReceive(t *testing.T) {
- linkEP := channel.New(10, defaultMTU, "")
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
- })
- if err := s.CreateNIC(1, linkEP); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
-
- {
- subnet, err := tcpip.NewSubnet("\x00", "\x00")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
- }
-
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
-
- // Create endpoint and connect to remote address.
- wq := waiter.Queue{}
- ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil {
- t.Fatalf("Connect failed: %v", err)
- }
-
- fakeTrans := s.TransportProtocolInstance(fakeTransNumber).(*fakeTransportProtocol)
-
- // Create buffer that will hold the packet.
- buf := buffer.NewView(30)
-
- // Make sure packet with wrong protocol is not delivered.
- buf[0] = 1
- buf[2] = 0
- linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
- Data: buf.ToVectorisedView(),
- })
- if fakeTrans.packetCount != 0 {
- t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0)
- }
-
- // Make sure packet from the wrong source is not delivered.
- buf[0] = 1
- buf[1] = 3
- buf[2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
- Data: buf.ToVectorisedView(),
- })
- if fakeTrans.packetCount != 0 {
- t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0)
- }
-
- // Make sure packet is delivered.
- buf[0] = 1
- buf[1] = 2
- buf[2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
- Data: buf.ToVectorisedView(),
- })
- if fakeTrans.packetCount != 1 {
- t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 1)
- }
-}
-
-func TestTransportControlReceive(t *testing.T) {
- linkEP := channel.New(10, defaultMTU, "")
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
- })
- if err := s.CreateNIC(1, linkEP); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
-
- {
- subnet, err := tcpip.NewSubnet("\x00", "\x00")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
- }
-
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
-
- // Create endpoint and connect to remote address.
- wq := waiter.Queue{}
- ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil {
- t.Fatalf("Connect failed: %v", err)
- }
-
- fakeTrans := s.TransportProtocolInstance(fakeTransNumber).(*fakeTransportProtocol)
-
- // Create buffer that will hold the control packet.
- buf := buffer.NewView(2*fakeNetHeaderLen + 30)
-
- // Outer packet contains the control protocol number.
- buf[0] = 1
- buf[1] = 0xfe
- buf[2] = uint8(fakeControlProtocol)
-
- // Make sure packet with wrong protocol is not delivered.
- buf[fakeNetHeaderLen+0] = 0
- buf[fakeNetHeaderLen+1] = 1
- buf[fakeNetHeaderLen+2] = 0
- linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
- Data: buf.ToVectorisedView(),
- })
- if fakeTrans.controlCount != 0 {
- t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0)
- }
-
- // Make sure packet from the wrong source is not delivered.
- buf[fakeNetHeaderLen+0] = 3
- buf[fakeNetHeaderLen+1] = 1
- buf[fakeNetHeaderLen+2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
- Data: buf.ToVectorisedView(),
- })
- if fakeTrans.controlCount != 0 {
- t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0)
- }
-
- // Make sure packet is delivered.
- buf[fakeNetHeaderLen+0] = 2
- buf[fakeNetHeaderLen+1] = 1
- buf[fakeNetHeaderLen+2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
- Data: buf.ToVectorisedView(),
- })
- if fakeTrans.controlCount != 1 {
- t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 1)
- }
-}
-
-func TestTransportSend(t *testing.T) {
- linkEP := channel.New(10, defaultMTU, "")
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
- })
- if err := s.CreateNIC(1, linkEP); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
-
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
-
- {
- subnet, err := tcpip.NewSubnet("\x00", "\x00")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
- }
-
- // Create endpoint and bind it.
- wq := waiter.Queue{}
- ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil {
- t.Fatalf("Connect failed: %v", err)
- }
-
- // Create buffer that will hold the payload.
- view := buffer.NewView(30)
- _, _, err = ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{})
- if err != nil {
- t.Fatalf("write failed: %v", err)
- }
-
- fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
-
- if fakeNet.sendPacketCount[2] != 1 {
- t.Errorf("sendPacketCount = %d, want %d", fakeNet.sendPacketCount[2], 1)
- }
-}
-
-func TestTransportOptions(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
- })
-
- // Try an unsupported transport protocol.
- if err := s.SetTransportProtocolOption(tcpip.TransportProtocolNumber(99999), fakeTransportGoodOption(false)); err != tcpip.ErrUnknownProtocol {
- t.Fatalf("SetTransportProtocolOption(fakeTrans2, blah, false) = %v, want = tcpip.ErrUnknownProtocol", err)
- }
-
- testCases := []struct {
- option interface{}
- wantErr *tcpip.Error
- verifier func(t *testing.T, p stack.TransportProtocol)
- }{
- {fakeTransportGoodOption(true), nil, func(t *testing.T, p stack.TransportProtocol) {
- t.Helper()
- fakeTrans := p.(*fakeTransportProtocol)
- if fakeTrans.opts.good != true {
- t.Fatalf("fakeTrans.opts.good = false, want = true")
- }
- var v fakeTransportGoodOption
- if err := s.TransportProtocolOption(fakeTransNumber, &v); err != nil {
- t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) = %v, want = nil, where v is option %T", v, err)
- }
- if v != true {
- t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) returned v = %v, want = true", v)
- }
-
- }},
- {fakeTransportBadOption(true), tcpip.ErrUnknownProtocolOption, nil},
- {fakeTransportInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil},
- }
- for _, tc := range testCases {
- if got := s.SetTransportProtocolOption(fakeTransNumber, tc.option); got != tc.wantErr {
- t.Errorf("s.SetTransportProtocolOption(fakeTrans, %v) = %v, want = %v", tc.option, got, tc.wantErr)
- }
- if tc.verifier != nil {
- tc.verifier(t, s.TransportProtocolInstance(fakeTransNumber))
- }
- }
-}
-
-func TestTransportForwarding(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
- })
- s.SetForwarding(true)
-
- // TODO(b/123449044): Change this to a channel NIC.
- ep1 := loopback.New()
- if err := s.CreateNIC(1, ep1); err != nil {
- t.Fatalf("CreateNIC #1 failed: %v", err)
- }
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress #1 failed: %v", err)
- }
-
- ep2 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(2, ep2); err != nil {
- t.Fatalf("CreateNIC #2 failed: %v", err)
- }
- if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
- t.Fatalf("AddAddress #2 failed: %v", err)
- }
-
- // Route all packets to address 3 to NIC 2 and all packets to address
- // 1 to NIC 1.
- {
- subnet0, err := tcpip.NewSubnet("\x03", "\xff")
- if err != nil {
- t.Fatal(err)
- }
- subnet1, err := tcpip.NewSubnet("\x01", "\xff")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{
- {Destination: subnet0, Gateway: "\x00", NIC: 2},
- {Destination: subnet1, Gateway: "\x00", NIC: 1},
- })
- }
-
- wq := waiter.Queue{}
- ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- if err := ep.Bind(tcpip.FullAddress{Addr: "\x01", NIC: 1}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Send a packet to address 1 from address 3.
- req := buffer.NewView(30)
- req[0] = 1
- req[1] = 3
- req[2] = byte(fakeTransNumber)
- ep2.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
- Data: req.ToVectorisedView(),
- })
-
- aep, _, err := ep.Accept()
- if err != nil || aep == nil {
- t.Fatalf("Accept failed: %v, %v", aep, err)
- }
-
- resp := buffer.NewView(30)
- if _, _, err := aep.Write(tcpip.SlicePayload(resp), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
- }
-
- p, ok := ep2.Read()
- if !ok {
- t.Fatal("Response packet not forwarded")
- }
-
- if dst := p.Pkt.Header.View()[0]; dst != 3 {
- t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst)
- }
- if src := p.Pkt.Header.View()[1]; src != 1 {
- t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src)
- }
-}
diff --git a/pkg/tcpip/tcpip_state_autogen.go b/pkg/tcpip/tcpip_state_autogen.go
new file mode 100755
index 000000000..6753503f0
--- /dev/null
+++ b/pkg/tcpip/tcpip_state_autogen.go
@@ -0,0 +1,108 @@
+// automatically generated by stateify.
+
+// +build go1.9
+// +build !go1.15
+
+package tcpip
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+func (x *PacketBuffer) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Data", &x.Data)
+ m.Save("DataOffset", &x.DataOffset)
+ m.Save("DataSize", &x.DataSize)
+ m.Save("Header", &x.Header)
+ m.Save("LinkHeader", &x.LinkHeader)
+ m.Save("NetworkHeader", &x.NetworkHeader)
+ m.Save("TransportHeader", &x.TransportHeader)
+}
+
+func (x *PacketBuffer) afterLoad() {}
+func (x *PacketBuffer) load(m state.Map) {
+ m.Load("Data", &x.Data)
+ m.Load("DataOffset", &x.DataOffset)
+ m.Load("DataSize", &x.DataSize)
+ m.Load("Header", &x.Header)
+ m.Load("LinkHeader", &x.LinkHeader)
+ m.Load("NetworkHeader", &x.NetworkHeader)
+ m.Load("TransportHeader", &x.TransportHeader)
+}
+
+func (x *FullAddress) beforeSave() {}
+func (x *FullAddress) save(m state.Map) {
+ x.beforeSave()
+ m.Save("NIC", &x.NIC)
+ m.Save("Addr", &x.Addr)
+ m.Save("Port", &x.Port)
+}
+
+func (x *FullAddress) afterLoad() {}
+func (x *FullAddress) load(m state.Map) {
+ m.Load("NIC", &x.NIC)
+ m.Load("Addr", &x.Addr)
+ m.Load("Port", &x.Port)
+}
+
+func (x *ControlMessages) beforeSave() {}
+func (x *ControlMessages) save(m state.Map) {
+ x.beforeSave()
+ m.Save("HasTimestamp", &x.HasTimestamp)
+ m.Save("Timestamp", &x.Timestamp)
+ m.Save("HasInq", &x.HasInq)
+ m.Save("Inq", &x.Inq)
+ m.Save("HasTOS", &x.HasTOS)
+ m.Save("TOS", &x.TOS)
+ m.Save("HasTClass", &x.HasTClass)
+ m.Save("TClass", &x.TClass)
+ m.Save("HasIPPacketInfo", &x.HasIPPacketInfo)
+ m.Save("PacketInfo", &x.PacketInfo)
+}
+
+func (x *ControlMessages) afterLoad() {}
+func (x *ControlMessages) load(m state.Map) {
+ m.Load("HasTimestamp", &x.HasTimestamp)
+ m.Load("Timestamp", &x.Timestamp)
+ m.Load("HasInq", &x.HasInq)
+ m.Load("Inq", &x.Inq)
+ m.Load("HasTOS", &x.HasTOS)
+ m.Load("TOS", &x.TOS)
+ m.Load("HasTClass", &x.HasTClass)
+ m.Load("TClass", &x.TClass)
+ m.Load("HasIPPacketInfo", &x.HasIPPacketInfo)
+ m.Load("PacketInfo", &x.PacketInfo)
+}
+
+func (x *IPPacketInfo) beforeSave() {}
+func (x *IPPacketInfo) save(m state.Map) {
+ x.beforeSave()
+ m.Save("NIC", &x.NIC)
+ m.Save("LocalAddr", &x.LocalAddr)
+ m.Save("DestinationAddr", &x.DestinationAddr)
+}
+
+func (x *IPPacketInfo) afterLoad() {}
+func (x *IPPacketInfo) load(m state.Map) {
+ m.Load("NIC", &x.NIC)
+ m.Load("LocalAddr", &x.LocalAddr)
+ m.Load("DestinationAddr", &x.DestinationAddr)
+}
+
+func (x *StdClock) beforeSave() {}
+func (x *StdClock) save(m state.Map) {
+ x.beforeSave()
+}
+
+func (x *StdClock) afterLoad() {}
+func (x *StdClock) load(m state.Map) {
+}
+
+func init() {
+ state.Register("pkg/tcpip.PacketBuffer", (*PacketBuffer)(nil), state.Fns{Save: (*PacketBuffer).save, Load: (*PacketBuffer).load})
+ state.Register("pkg/tcpip.FullAddress", (*FullAddress)(nil), state.Fns{Save: (*FullAddress).save, Load: (*FullAddress).load})
+ state.Register("pkg/tcpip.ControlMessages", (*ControlMessages)(nil), state.Fns{Save: (*ControlMessages).save, Load: (*ControlMessages).load})
+ state.Register("pkg/tcpip.IPPacketInfo", (*IPPacketInfo)(nil), state.Fns{Save: (*IPPacketInfo).save, Load: (*IPPacketInfo).load})
+ state.Register("pkg/tcpip.StdClock", (*StdClock)(nil), state.Fns{Save: (*StdClock).save, Load: (*StdClock).load})
+}
diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go
deleted file mode 100644
index 8c0aacffa..000000000
--- a/pkg/tcpip/tcpip_test.go
+++ /dev/null
@@ -1,228 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tcpip
-
-import (
- "fmt"
- "net"
- "strings"
- "testing"
-)
-
-func TestSubnetContains(t *testing.T) {
- tests := []struct {
- s Address
- m AddressMask
- a Address
- want bool
- }{
- {"\xa0", "\xf0", "\x90", false},
- {"\xa0", "\xf0", "\xa0", true},
- {"\xa0", "\xf0", "\xa5", true},
- {"\xa0", "\xf0", "\xaf", true},
- {"\xa0", "\xf0", "\xb0", false},
- {"\xa0", "\xf0", "", false},
- {"\xa0", "\xf0", "\xa0\x00", false},
- {"\xc2\x80", "\xff\xf0", "\xc2\x80", true},
- {"\xc2\x80", "\xff\xf0", "\xc2\x00", false},
- {"\xc2\x00", "\xff\xf0", "\xc2\x00", true},
- {"\xc2\x00", "\xff\xf0", "\xc2\x80", false},
- }
- for _, tt := range tests {
- s, err := NewSubnet(tt.s, tt.m)
- if err != nil {
- t.Errorf("NewSubnet(%v, %v) = %v", tt.s, tt.m, err)
- continue
- }
- if got := s.Contains(tt.a); got != tt.want {
- t.Errorf("Subnet(%v).Contains(%v) = %v, want %v", s, tt.a, got, tt.want)
- }
- }
-}
-
-func TestSubnetBits(t *testing.T) {
- tests := []struct {
- a AddressMask
- want1 int
- want0 int
- }{
- {"\x00", 0, 8},
- {"\x00\x00", 0, 16},
- {"\x36", 0, 8},
- {"\x5c", 0, 8},
- {"\x5c\x5c", 0, 16},
- {"\x5c\x36", 0, 16},
- {"\x36\x5c", 0, 16},
- {"\x36\x36", 0, 16},
- {"\xff", 8, 0},
- {"\xff\xff", 16, 0},
- }
- for _, tt := range tests {
- s := &Subnet{mask: tt.a}
- got1, got0 := s.Bits()
- if got1 != tt.want1 || got0 != tt.want0 {
- t.Errorf("Subnet{mask: %x}.Bits() = %d, %d, want %d, %d", tt.a, got1, got0, tt.want1, tt.want0)
- }
- }
-}
-
-func TestSubnetPrefix(t *testing.T) {
- tests := []struct {
- a AddressMask
- want int
- }{
- {"\x00", 0},
- {"\x00\x00", 0},
- {"\x36", 0},
- {"\x86", 1},
- {"\xc5", 2},
- {"\xff\x00", 8},
- {"\xff\x36", 8},
- {"\xff\x8c", 9},
- {"\xff\xc8", 10},
- {"\xff", 8},
- {"\xff\xff", 16},
- }
- for _, tt := range tests {
- s := &Subnet{mask: tt.a}
- got := s.Prefix()
- if got != tt.want {
- t.Errorf("Subnet{mask: %x}.Bits() = %d want %d", tt.a, got, tt.want)
- }
- }
-}
-
-func TestSubnetCreation(t *testing.T) {
- tests := []struct {
- a Address
- m AddressMask
- want error
- }{
- {"\xa0", "\xf0", nil},
- {"\xa0\xa0", "\xf0", errSubnetLengthMismatch},
- {"\xaa", "\xf0", errSubnetAddressMasked},
- {"", "", nil},
- }
- for _, tt := range tests {
- if _, err := NewSubnet(tt.a, tt.m); err != tt.want {
- t.Errorf("NewSubnet(%v, %v) = %v, want %v", tt.a, tt.m, err, tt.want)
- }
- }
-}
-
-func TestAddressString(t *testing.T) {
- for _, want := range []string{
- // Taken from stdlib.
- "2001:db8::123:12:1",
- "2001:db8::1",
- "2001:db8:0:1:0:1:0:1",
- "2001:db8:1:0:1:0:1:0",
- "2001::1:0:0:1",
- "2001:db8:0:0:1::",
- "2001:db8::1:0:0:1",
- "2001:db8::a:b:c:d",
-
- // Leading zeros.
- "::1",
- // Trailing zeros.
- "8::",
- // No zeros.
- "1:1:1:1:1:1:1:1",
- // Longer sequence is after other zeros, but not at the end.
- "1:0:0:1::1",
- // Longer sequence is at the beginning, shorter sequence is at
- // the end.
- "::1:1:1:0:0",
- // Longer sequence is not at the beginning, shorter sequence is
- // at the end.
- "1::1:1:0:0",
- // Longer sequence is at the beginning, shorter sequence is not
- // at the end.
- "::1:1:0:0:1",
- // Neither sequence is at an end, longer is after shorter.
- "1:0:0:1::1",
- // Shorter sequence is at the beginning, longer sequence is not
- // at the end.
- "0:0:1:1::1",
- // Shorter sequence is at the beginning, longer sequence is at
- // the end.
- "0:0:1:1:1::",
- // Short sequences at both ends, longer one in the middle.
- "0:1:1::1:1:0",
- // Short sequences at both ends, longer one in the middle.
- "0:1::1:0:0",
- // Short sequences at both ends, longer one in the middle.
- "0:0:1::1:0",
- // Longer sequence surrounded by shorter sequences, but none at
- // the end.
- "1:0:1::1:0:1",
- } {
- addr := Address(net.ParseIP(want))
- if got := addr.String(); got != want {
- t.Errorf("Address(%x).String() = '%s', want = '%s'", addr, got, want)
- }
- }
-}
-
-func TestStatsString(t *testing.T) {
- got := fmt.Sprintf("%+v", Stats{}.FillIn())
-
- matchers := []string{
- // Print root-level stats correctly.
- "UnknownProtocolRcvdPackets:0",
- // Print protocol-specific stats correctly.
- "TCP:{ActiveConnectionOpenings:0",
- }
-
- for _, m := range matchers {
- if !strings.Contains(got, m) {
- t.Errorf("string.Contains(got, %q) = false", m)
- }
- }
- if t.Failed() {
- t.Logf(`got = fmt.Sprintf("%%+v", Stats{}.FillIn()) = %q`, got)
- }
-}
-
-func TestAddressWithPrefixSubnet(t *testing.T) {
- tests := []struct {
- addr Address
- prefixLen int
- subnetAddr Address
- subnetMask AddressMask
- }{
- {"\xaa\x55\x33\x42", -1, "\x00\x00\x00\x00", "\x00\x00\x00\x00"},
- {"\xaa\x55\x33\x42", 0, "\x00\x00\x00\x00", "\x00\x00\x00\x00"},
- {"\xaa\x55\x33\x42", 1, "\x80\x00\x00\x00", "\x80\x00\x00\x00"},
- {"\xaa\x55\x33\x42", 7, "\xaa\x00\x00\x00", "\xfe\x00\x00\x00"},
- {"\xaa\x55\x33\x42", 8, "\xaa\x00\x00\x00", "\xff\x00\x00\x00"},
- {"\xaa\x55\x33\x42", 24, "\xaa\x55\x33\x00", "\xff\xff\xff\x00"},
- {"\xaa\x55\x33\x42", 31, "\xaa\x55\x33\x42", "\xff\xff\xff\xfe"},
- {"\xaa\x55\x33\x42", 32, "\xaa\x55\x33\x42", "\xff\xff\xff\xff"},
- {"\xaa\x55\x33\x42", 33, "\xaa\x55\x33\x42", "\xff\xff\xff\xff"},
- }
- for _, tt := range tests {
- ap := AddressWithPrefix{Address: tt.addr, PrefixLen: tt.prefixLen}
- gotSubnet := ap.Subnet()
- wantSubnet, err := NewSubnet(tt.subnetAddr, tt.subnetMask)
- if err != nil {
- t.Error("NewSubnet(%q, %q) failed: %s", tt.subnetAddr, tt.subnetMask, err)
- continue
- }
- if gotSubnet != wantSubnet {
- t.Errorf("got subnet = %q, want = %q", gotSubnet, wantSubnet)
- }
- }
-}
diff --git a/pkg/tcpip/time.s b/pkg/tcpip/time.s
deleted file mode 100644
index fb37360ac..000000000
--- a/pkg/tcpip/time.s
+++ /dev/null
@@ -1,15 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Empty assembly file so empty func definitions work.
diff --git a/pkg/tcpip/timer.go b/pkg/tcpip/timer.go
index f5f01f32f..f5f01f32f 100644..100755
--- a/pkg/tcpip/timer.go
+++ b/pkg/tcpip/timer.go
diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go
deleted file mode 100644
index 2d20f7ef3..000000000
--- a/pkg/tcpip/timer_test.go
+++ /dev/null
@@ -1,236 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tcpip_test
-
-import (
- "sync"
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/tcpip"
-)
-
-const (
- shortDuration = 1 * time.Nanosecond
- middleDuration = 100 * time.Millisecond
- longDuration = 1 * time.Second
-)
-
-func TestCancellableTimerFire(t *testing.T) {
- t.Parallel()
-
- ch := make(chan struct{})
- var lock sync.Mutex
-
- timer := tcpip.MakeCancellableTimer(&lock, func() {
- ch <- struct{}{}
- })
- timer.Reset(shortDuration)
-
- // Wait for timer to fire.
- select {
- case <-ch:
- case <-time.After(middleDuration):
- t.Fatal("timed out waiting for timer to fire")
- }
-
- // The timer should have fired only once.
- select {
- case <-ch:
- t.Fatal("no other timers should have fired")
- case <-time.After(middleDuration):
- }
-}
-
-func TestCancellableTimerResetFromLongDuration(t *testing.T) {
- t.Parallel()
-
- ch := make(chan struct{})
- var lock sync.Mutex
-
- timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(middleDuration)
-
- lock.Lock()
- timer.StopLocked()
- lock.Unlock()
-
- timer.Reset(shortDuration)
-
- // Wait for timer to fire.
- select {
- case <-ch:
- case <-time.After(middleDuration):
- t.Fatal("timed out waiting for timer to fire")
- }
-
- // The timer should have fired only once.
- select {
- case <-ch:
- t.Fatal("no other timers should have fired")
- case <-time.After(middleDuration):
- }
-}
-
-func TestCancellableTimerResetFromShortDuration(t *testing.T) {
- t.Parallel()
-
- ch := make(chan struct{})
- var lock sync.Mutex
-
- lock.Lock()
- timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(shortDuration)
- timer.StopLocked()
- lock.Unlock()
-
- // Wait for timer to fire if it wasn't correctly stopped.
- select {
- case <-ch:
- t.Fatal("timer fired after being stopped")
- case <-time.After(middleDuration):
- }
-
- timer.Reset(shortDuration)
-
- // Wait for timer to fire.
- select {
- case <-ch:
- case <-time.After(middleDuration):
- t.Fatal("timed out waiting for timer to fire")
- }
-
- // The timer should have fired only once.
- select {
- case <-ch:
- t.Fatal("no other timers should have fired")
- case <-time.After(middleDuration):
- }
-}
-
-func TestCancellableTimerImmediatelyStop(t *testing.T) {
- t.Parallel()
-
- ch := make(chan struct{})
- var lock sync.Mutex
-
- for i := 0; i < 1000; i++ {
- lock.Lock()
- timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(shortDuration)
- timer.StopLocked()
- lock.Unlock()
- }
-
- // Wait for timer to fire if it wasn't correctly stopped.
- select {
- case <-ch:
- t.Fatal("timer fired after being stopped")
- case <-time.After(middleDuration):
- }
-}
-
-func TestCancellableTimerStoppedResetWithoutLock(t *testing.T) {
- t.Parallel()
-
- ch := make(chan struct{})
- var lock sync.Mutex
-
- lock.Lock()
- timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(shortDuration)
- timer.StopLocked()
- lock.Unlock()
-
- for i := 0; i < 10; i++ {
- timer.Reset(middleDuration)
-
- lock.Lock()
- // Sleep until the timer fires and gets blocked trying to take the lock.
- time.Sleep(middleDuration * 2)
- timer.StopLocked()
- lock.Unlock()
- }
-
- // Wait for double the duration so timers that weren't correctly stopped can
- // fire.
- select {
- case <-ch:
- t.Fatal("timer fired after being stopped")
- case <-time.After(middleDuration * 2):
- }
-}
-
-func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) {
- t.Parallel()
-
- ch := make(chan struct{})
- var lock sync.Mutex
-
- lock.Lock()
- timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(shortDuration)
- for i := 0; i < 10; i++ {
- // Sleep until the timer fires and gets blocked trying to take the lock.
- time.Sleep(middleDuration)
- timer.StopLocked()
- timer.Reset(shortDuration)
- }
- lock.Unlock()
-
- // Wait for double the duration for the last timer to fire.
- select {
- case <-ch:
- case <-time.After(middleDuration):
- t.Fatal("timed out waiting for timer to fire")
- }
-
- // The timer should have fired only once.
- select {
- case <-ch:
- t.Fatal("no other timers should have fired")
- case <-time.After(middleDuration):
- }
-}
-
-func TestManyCancellableTimerResetUnderLock(t *testing.T) {
- t.Parallel()
-
- ch := make(chan struct{})
- var lock sync.Mutex
-
- lock.Lock()
- timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(shortDuration)
- for i := 0; i < 10; i++ {
- timer.StopLocked()
- timer.Reset(shortDuration)
- }
- lock.Unlock()
-
- // Wait for double the duration for the last timer to fire.
- select {
- case <-ch:
- case <-time.After(middleDuration):
- t.Fatal("timed out waiting for timer to fire")
- }
-
- // The timer should have fired only once.
- select {
- case <-ch:
- t.Fatal("no other timers should have fired")
- case <-time.After(middleDuration):
- }
-}
diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD
deleted file mode 100644
index ac18ec5b1..000000000
--- a/pkg/tcpip/transport/icmp/BUILD
+++ /dev/null
@@ -1,40 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
-package(licenses = ["notice"])
-
-go_template_instance(
- name = "icmp_packet_list",
- out = "icmp_packet_list.go",
- package = "icmp",
- prefix = "icmpPacket",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*icmpPacket",
- "Linker": "*icmpPacket",
- },
-)
-
-go_library(
- name = "icmp",
- srcs = [
- "endpoint.go",
- "endpoint_state.go",
- "icmp_packet_list.go",
- "protocol.go",
- ],
- imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/sleep",
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/transport/raw",
- "//pkg/tcpip/transport/tcp",
- "//pkg/waiter",
- ],
-)
diff --git a/pkg/tcpip/transport/icmp/icmp_packet_list.go b/pkg/tcpip/transport/icmp/icmp_packet_list.go
new file mode 100755
index 000000000..1b35e5b4a
--- /dev/null
+++ b/pkg/tcpip/transport/icmp/icmp_packet_list.go
@@ -0,0 +1,173 @@
+package icmp
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type icmpPacketElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (icmpPacketElementMapper) linkerFor(elem *icmpPacket) *icmpPacket { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type icmpPacketList struct {
+ head *icmpPacket
+ tail *icmpPacket
+}
+
+// Reset resets list l to the empty state.
+func (l *icmpPacketList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *icmpPacketList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *icmpPacketList) Front() *icmpPacket {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *icmpPacketList) Back() *icmpPacket {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *icmpPacketList) PushFront(e *icmpPacket) {
+ icmpPacketElementMapper{}.linkerFor(e).SetNext(l.head)
+ icmpPacketElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ icmpPacketElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *icmpPacketList) PushBack(e *icmpPacket) {
+ icmpPacketElementMapper{}.linkerFor(e).SetNext(nil)
+ icmpPacketElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ icmpPacketElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *icmpPacketList) PushBackList(m *icmpPacketList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ icmpPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ icmpPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *icmpPacketList) InsertAfter(b, e *icmpPacket) {
+ a := icmpPacketElementMapper{}.linkerFor(b).Next()
+ icmpPacketElementMapper{}.linkerFor(e).SetNext(a)
+ icmpPacketElementMapper{}.linkerFor(e).SetPrev(b)
+ icmpPacketElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ icmpPacketElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *icmpPacketList) InsertBefore(a, e *icmpPacket) {
+ b := icmpPacketElementMapper{}.linkerFor(a).Prev()
+ icmpPacketElementMapper{}.linkerFor(e).SetNext(a)
+ icmpPacketElementMapper{}.linkerFor(e).SetPrev(b)
+ icmpPacketElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ icmpPacketElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *icmpPacketList) Remove(e *icmpPacket) {
+ prev := icmpPacketElementMapper{}.linkerFor(e).Prev()
+ next := icmpPacketElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ icmpPacketElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ icmpPacketElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type icmpPacketEntry struct {
+ next *icmpPacket
+ prev *icmpPacket
+}
+
+// Next returns the entry that follows e in the list.
+func (e *icmpPacketEntry) Next() *icmpPacket {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *icmpPacketEntry) Prev() *icmpPacket {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *icmpPacketEntry) SetNext(elem *icmpPacket) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *icmpPacketEntry) SetPrev(elem *icmpPacket) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/transport/icmp/icmp_state_autogen.go b/pkg/tcpip/transport/icmp/icmp_state_autogen.go
new file mode 100755
index 000000000..b856a4b89
--- /dev/null
+++ b/pkg/tcpip/transport/icmp/icmp_state_autogen.go
@@ -0,0 +1,92 @@
+// automatically generated by stateify.
+
+package icmp
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+func (x *icmpPacket) beforeSave() {}
+func (x *icmpPacket) save(m state.Map) {
+ x.beforeSave()
+ var data buffer.VectorisedView = x.saveData()
+ m.SaveValue("data", data)
+ m.Save("icmpPacketEntry", &x.icmpPacketEntry)
+ m.Save("senderAddress", &x.senderAddress)
+ m.Save("timestamp", &x.timestamp)
+}
+
+func (x *icmpPacket) afterLoad() {}
+func (x *icmpPacket) load(m state.Map) {
+ m.Load("icmpPacketEntry", &x.icmpPacketEntry)
+ m.Load("senderAddress", &x.senderAddress)
+ m.Load("timestamp", &x.timestamp)
+ m.LoadValue("data", new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) })
+}
+
+func (x *endpoint) save(m state.Map) {
+ x.beforeSave()
+ var rcvBufSizeMax int = x.saveRcvBufSizeMax()
+ m.SaveValue("rcvBufSizeMax", rcvBufSizeMax)
+ m.Save("TransportEndpointInfo", &x.TransportEndpointInfo)
+ m.Save("waiterQueue", &x.waiterQueue)
+ m.Save("uniqueID", &x.uniqueID)
+ m.Save("rcvReady", &x.rcvReady)
+ m.Save("rcvList", &x.rcvList)
+ m.Save("rcvBufSize", &x.rcvBufSize)
+ m.Save("rcvClosed", &x.rcvClosed)
+ m.Save("sndBufSize", &x.sndBufSize)
+ m.Save("shutdownFlags", &x.shutdownFlags)
+ m.Save("state", &x.state)
+ m.Save("ttl", &x.ttl)
+}
+
+func (x *endpoint) load(m state.Map) {
+ m.Load("TransportEndpointInfo", &x.TransportEndpointInfo)
+ m.Load("waiterQueue", &x.waiterQueue)
+ m.Load("uniqueID", &x.uniqueID)
+ m.Load("rcvReady", &x.rcvReady)
+ m.Load("rcvList", &x.rcvList)
+ m.Load("rcvBufSize", &x.rcvBufSize)
+ m.Load("rcvClosed", &x.rcvClosed)
+ m.Load("sndBufSize", &x.sndBufSize)
+ m.Load("shutdownFlags", &x.shutdownFlags)
+ m.Load("state", &x.state)
+ m.Load("ttl", &x.ttl)
+ m.LoadValue("rcvBufSizeMax", new(int), func(y interface{}) { x.loadRcvBufSizeMax(y.(int)) })
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *icmpPacketList) beforeSave() {}
+func (x *icmpPacketList) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *icmpPacketList) afterLoad() {}
+func (x *icmpPacketList) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *icmpPacketEntry) beforeSave() {}
+func (x *icmpPacketEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *icmpPacketEntry) afterLoad() {}
+func (x *icmpPacketEntry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
+func init() {
+ state.Register("pkg/tcpip/transport/icmp.icmpPacket", (*icmpPacket)(nil), state.Fns{Save: (*icmpPacket).save, Load: (*icmpPacket).load})
+ state.Register("pkg/tcpip/transport/icmp.endpoint", (*endpoint)(nil), state.Fns{Save: (*endpoint).save, Load: (*endpoint).load})
+ state.Register("pkg/tcpip/transport/icmp.icmpPacketList", (*icmpPacketList)(nil), state.Fns{Save: (*icmpPacketList).save, Load: (*icmpPacketList).load})
+ state.Register("pkg/tcpip/transport/icmp.icmpPacketEntry", (*icmpPacketEntry)(nil), state.Fns{Save: (*icmpPacketEntry).save, Load: (*icmpPacketEntry).load})
+}
diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD
deleted file mode 100644
index d22de6b26..000000000
--- a/pkg/tcpip/transport/packet/BUILD
+++ /dev/null
@@ -1,38 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
-package(licenses = ["notice"])
-
-go_template_instance(
- name = "packet_list",
- out = "packet_list.go",
- package = "packet",
- prefix = "packet",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*packet",
- "Linker": "*packet",
- },
-)
-
-go_library(
- name = "packet",
- srcs = [
- "endpoint.go",
- "endpoint_state.go",
- "packet_list.go",
- ],
- imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/log",
- "//pkg/sleep",
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
- "//pkg/tcpip/stack",
- "//pkg/waiter",
- ],
-)
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index fc5bc69fa..fc5bc69fa 100644..100755
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go
index 9b88f17e4..9b88f17e4 100644..100755
--- a/pkg/tcpip/transport/packet/endpoint_state.go
+++ b/pkg/tcpip/transport/packet/endpoint_state.go
diff --git a/pkg/tcpip/transport/packet/packet_list.go b/pkg/tcpip/transport/packet/packet_list.go
new file mode 100755
index 000000000..0da0dfcb6
--- /dev/null
+++ b/pkg/tcpip/transport/packet/packet_list.go
@@ -0,0 +1,173 @@
+package packet
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type packetElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (packetElementMapper) linkerFor(elem *packet) *packet { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type packetList struct {
+ head *packet
+ tail *packet
+}
+
+// Reset resets list l to the empty state.
+func (l *packetList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *packetList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *packetList) Front() *packet {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *packetList) Back() *packet {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *packetList) PushFront(e *packet) {
+ packetElementMapper{}.linkerFor(e).SetNext(l.head)
+ packetElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ packetElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *packetList) PushBack(e *packet) {
+ packetElementMapper{}.linkerFor(e).SetNext(nil)
+ packetElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ packetElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *packetList) PushBackList(m *packetList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ packetElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ packetElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *packetList) InsertAfter(b, e *packet) {
+ a := packetElementMapper{}.linkerFor(b).Next()
+ packetElementMapper{}.linkerFor(e).SetNext(a)
+ packetElementMapper{}.linkerFor(e).SetPrev(b)
+ packetElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ packetElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *packetList) InsertBefore(a, e *packet) {
+ b := packetElementMapper{}.linkerFor(a).Prev()
+ packetElementMapper{}.linkerFor(e).SetNext(a)
+ packetElementMapper{}.linkerFor(e).SetPrev(b)
+ packetElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ packetElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *packetList) Remove(e *packet) {
+ prev := packetElementMapper{}.linkerFor(e).Prev()
+ next := packetElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ packetElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ packetElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type packetEntry struct {
+ next *packet
+ prev *packet
+}
+
+// Next returns the entry that follows e in the list.
+func (e *packetEntry) Next() *packet {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *packetEntry) Prev() *packet {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *packetEntry) SetNext(elem *packet) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *packetEntry) SetPrev(elem *packet) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/transport/packet/packet_state_autogen.go b/pkg/tcpip/transport/packet/packet_state_autogen.go
new file mode 100755
index 000000000..1d90a383f
--- /dev/null
+++ b/pkg/tcpip/transport/packet/packet_state_autogen.go
@@ -0,0 +1,88 @@
+// automatically generated by stateify.
+
+package packet
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+func (x *packet) beforeSave() {}
+func (x *packet) save(m state.Map) {
+ x.beforeSave()
+ var data buffer.VectorisedView = x.saveData()
+ m.SaveValue("data", data)
+ m.Save("packetEntry", &x.packetEntry)
+ m.Save("timestampNS", &x.timestampNS)
+ m.Save("senderAddr", &x.senderAddr)
+}
+
+func (x *packet) afterLoad() {}
+func (x *packet) load(m state.Map) {
+ m.Load("packetEntry", &x.packetEntry)
+ m.Load("timestampNS", &x.timestampNS)
+ m.Load("senderAddr", &x.senderAddr)
+ m.LoadValue("data", new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) })
+}
+
+func (x *endpoint) save(m state.Map) {
+ x.beforeSave()
+ var rcvBufSizeMax int = x.saveRcvBufSizeMax()
+ m.SaveValue("rcvBufSizeMax", rcvBufSizeMax)
+ m.Save("TransportEndpointInfo", &x.TransportEndpointInfo)
+ m.Save("netProto", &x.netProto)
+ m.Save("waiterQueue", &x.waiterQueue)
+ m.Save("cooked", &x.cooked)
+ m.Save("rcvList", &x.rcvList)
+ m.Save("rcvBufSize", &x.rcvBufSize)
+ m.Save("rcvClosed", &x.rcvClosed)
+ m.Save("sndBufSize", &x.sndBufSize)
+ m.Save("closed", &x.closed)
+}
+
+func (x *endpoint) load(m state.Map) {
+ m.Load("TransportEndpointInfo", &x.TransportEndpointInfo)
+ m.Load("netProto", &x.netProto)
+ m.Load("waiterQueue", &x.waiterQueue)
+ m.Load("cooked", &x.cooked)
+ m.Load("rcvList", &x.rcvList)
+ m.Load("rcvBufSize", &x.rcvBufSize)
+ m.Load("rcvClosed", &x.rcvClosed)
+ m.Load("sndBufSize", &x.sndBufSize)
+ m.Load("closed", &x.closed)
+ m.LoadValue("rcvBufSizeMax", new(int), func(y interface{}) { x.loadRcvBufSizeMax(y.(int)) })
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *packetList) beforeSave() {}
+func (x *packetList) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *packetList) afterLoad() {}
+func (x *packetList) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *packetEntry) beforeSave() {}
+func (x *packetEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *packetEntry) afterLoad() {}
+func (x *packetEntry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
+func init() {
+ state.Register("pkg/tcpip/transport/packet.packet", (*packet)(nil), state.Fns{Save: (*packet).save, Load: (*packet).load})
+ state.Register("pkg/tcpip/transport/packet.endpoint", (*endpoint)(nil), state.Fns{Save: (*endpoint).save, Load: (*endpoint).load})
+ state.Register("pkg/tcpip/transport/packet.packetList", (*packetList)(nil), state.Fns{Save: (*packetList).save, Load: (*packetList).load})
+ state.Register("pkg/tcpip/transport/packet.packetEntry", (*packetEntry)(nil), state.Fns{Save: (*packetEntry).save, Load: (*packetEntry).load})
+}
diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD
deleted file mode 100644
index c9baf4600..000000000
--- a/pkg/tcpip/transport/raw/BUILD
+++ /dev/null
@@ -1,40 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
-package(licenses = ["notice"])
-
-go_template_instance(
- name = "raw_packet_list",
- out = "raw_packet_list.go",
- package = "raw",
- prefix = "rawPacket",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*rawPacket",
- "Linker": "*rawPacket",
- },
-)
-
-go_library(
- name = "raw",
- srcs = [
- "endpoint.go",
- "endpoint_state.go",
- "protocol.go",
- "raw_packet_list.go",
- ],
- imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/log",
- "//pkg/sleep",
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/transport/packet",
- "//pkg/waiter",
- ],
-)
diff --git a/pkg/tcpip/transport/raw/raw_packet_list.go b/pkg/tcpip/transport/raw/raw_packet_list.go
new file mode 100755
index 000000000..12edb4334
--- /dev/null
+++ b/pkg/tcpip/transport/raw/raw_packet_list.go
@@ -0,0 +1,173 @@
+package raw
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type rawPacketElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (rawPacketElementMapper) linkerFor(elem *rawPacket) *rawPacket { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type rawPacketList struct {
+ head *rawPacket
+ tail *rawPacket
+}
+
+// Reset resets list l to the empty state.
+func (l *rawPacketList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *rawPacketList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *rawPacketList) Front() *rawPacket {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *rawPacketList) Back() *rawPacket {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *rawPacketList) PushFront(e *rawPacket) {
+ rawPacketElementMapper{}.linkerFor(e).SetNext(l.head)
+ rawPacketElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ rawPacketElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *rawPacketList) PushBack(e *rawPacket) {
+ rawPacketElementMapper{}.linkerFor(e).SetNext(nil)
+ rawPacketElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ rawPacketElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *rawPacketList) PushBackList(m *rawPacketList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ rawPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ rawPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *rawPacketList) InsertAfter(b, e *rawPacket) {
+ a := rawPacketElementMapper{}.linkerFor(b).Next()
+ rawPacketElementMapper{}.linkerFor(e).SetNext(a)
+ rawPacketElementMapper{}.linkerFor(e).SetPrev(b)
+ rawPacketElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ rawPacketElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *rawPacketList) InsertBefore(a, e *rawPacket) {
+ b := rawPacketElementMapper{}.linkerFor(a).Prev()
+ rawPacketElementMapper{}.linkerFor(e).SetNext(a)
+ rawPacketElementMapper{}.linkerFor(e).SetPrev(b)
+ rawPacketElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ rawPacketElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *rawPacketList) Remove(e *rawPacket) {
+ prev := rawPacketElementMapper{}.linkerFor(e).Prev()
+ next := rawPacketElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ rawPacketElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ rawPacketElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type rawPacketEntry struct {
+ next *rawPacket
+ prev *rawPacket
+}
+
+// Next returns the entry that follows e in the list.
+func (e *rawPacketEntry) Next() *rawPacket {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *rawPacketEntry) Prev() *rawPacket {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *rawPacketEntry) SetNext(elem *rawPacket) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *rawPacketEntry) SetPrev(elem *rawPacket) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/transport/raw/raw_state_autogen.go b/pkg/tcpip/transport/raw/raw_state_autogen.go
new file mode 100755
index 000000000..41b72cf93
--- /dev/null
+++ b/pkg/tcpip/transport/raw/raw_state_autogen.go
@@ -0,0 +1,90 @@
+// automatically generated by stateify.
+
+package raw
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+func (x *rawPacket) beforeSave() {}
+func (x *rawPacket) save(m state.Map) {
+ x.beforeSave()
+ var data buffer.VectorisedView = x.saveData()
+ m.SaveValue("data", data)
+ m.Save("rawPacketEntry", &x.rawPacketEntry)
+ m.Save("timestampNS", &x.timestampNS)
+ m.Save("senderAddr", &x.senderAddr)
+}
+
+func (x *rawPacket) afterLoad() {}
+func (x *rawPacket) load(m state.Map) {
+ m.Load("rawPacketEntry", &x.rawPacketEntry)
+ m.Load("timestampNS", &x.timestampNS)
+ m.Load("senderAddr", &x.senderAddr)
+ m.LoadValue("data", new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) })
+}
+
+func (x *endpoint) save(m state.Map) {
+ x.beforeSave()
+ var rcvBufSizeMax int = x.saveRcvBufSizeMax()
+ m.SaveValue("rcvBufSizeMax", rcvBufSizeMax)
+ m.Save("TransportEndpointInfo", &x.TransportEndpointInfo)
+ m.Save("waiterQueue", &x.waiterQueue)
+ m.Save("associated", &x.associated)
+ m.Save("rcvList", &x.rcvList)
+ m.Save("rcvBufSize", &x.rcvBufSize)
+ m.Save("rcvClosed", &x.rcvClosed)
+ m.Save("sndBufSize", &x.sndBufSize)
+ m.Save("closed", &x.closed)
+ m.Save("connected", &x.connected)
+ m.Save("bound", &x.bound)
+}
+
+func (x *endpoint) load(m state.Map) {
+ m.Load("TransportEndpointInfo", &x.TransportEndpointInfo)
+ m.Load("waiterQueue", &x.waiterQueue)
+ m.Load("associated", &x.associated)
+ m.Load("rcvList", &x.rcvList)
+ m.Load("rcvBufSize", &x.rcvBufSize)
+ m.Load("rcvClosed", &x.rcvClosed)
+ m.Load("sndBufSize", &x.sndBufSize)
+ m.Load("closed", &x.closed)
+ m.Load("connected", &x.connected)
+ m.Load("bound", &x.bound)
+ m.LoadValue("rcvBufSizeMax", new(int), func(y interface{}) { x.loadRcvBufSizeMax(y.(int)) })
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *rawPacketList) beforeSave() {}
+func (x *rawPacketList) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *rawPacketList) afterLoad() {}
+func (x *rawPacketList) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *rawPacketEntry) beforeSave() {}
+func (x *rawPacketEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *rawPacketEntry) afterLoad() {}
+func (x *rawPacketEntry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
+func init() {
+ state.Register("pkg/tcpip/transport/raw.rawPacket", (*rawPacket)(nil), state.Fns{Save: (*rawPacket).save, Load: (*rawPacket).load})
+ state.Register("pkg/tcpip/transport/raw.endpoint", (*endpoint)(nil), state.Fns{Save: (*endpoint).save, Load: (*endpoint).load})
+ state.Register("pkg/tcpip/transport/raw.rawPacketList", (*rawPacketList)(nil), state.Fns{Save: (*rawPacketList).save, Load: (*rawPacketList).load})
+ state.Register("pkg/tcpip/transport/raw.rawPacketEntry", (*rawPacketEntry)(nil), state.Fns{Save: (*rawPacketEntry).save, Load: (*rawPacketEntry).load})
+}
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
deleted file mode 100644
index 272e8f570..000000000
--- a/pkg/tcpip/transport/tcp/BUILD
+++ /dev/null
@@ -1,109 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
-package(licenses = ["notice"])
-
-go_template_instance(
- name = "tcp_segment_list",
- out = "tcp_segment_list.go",
- package = "tcp",
- prefix = "segment",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*segment",
- "Linker": "*segment",
- },
-)
-
-go_template_instance(
- name = "tcp_endpoint_list",
- out = "tcp_endpoint_list.go",
- package = "tcp",
- prefix = "endpoint",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*endpoint",
- "Linker": "*endpoint",
- },
-)
-
-go_library(
- name = "tcp",
- srcs = [
- "accept.go",
- "connect.go",
- "cubic.go",
- "cubic_state.go",
- "dispatcher.go",
- "endpoint.go",
- "endpoint_state.go",
- "forwarder.go",
- "protocol.go",
- "rcv.go",
- "rcv_state.go",
- "reno.go",
- "sack.go",
- "sack_scoreboard.go",
- "segment.go",
- "segment_heap.go",
- "segment_queue.go",
- "segment_state.go",
- "snd.go",
- "snd_state.go",
- "tcp_endpoint_list.go",
- "tcp_segment_list.go",
- "timer.go",
- ],
- imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/log",
- "//pkg/rand",
- "//pkg/sleep",
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/hash/jenkins",
- "//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
- "//pkg/tcpip/ports",
- "//pkg/tcpip/seqnum",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/transport/raw",
- "//pkg/tmutex",
- "//pkg/waiter",
- "@com_github_google_btree//:go_default_library",
- ],
-)
-
-go_test(
- name = "tcp_test",
- size = "medium",
- srcs = [
- "dual_stack_test.go",
- "sack_scoreboard_test.go",
- "tcp_noracedetector_test.go",
- "tcp_sack_test.go",
- "tcp_test.go",
- "tcp_timestamp_test.go",
- ],
- # FIXME(b/68809571)
- tags = ["flaky"],
- deps = [
- ":tcp",
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/checker",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/loopback",
- "//pkg/tcpip/link/sniffer",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/ports",
- "//pkg/tcpip/seqnum",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/transport/tcp/testing/context",
- "//pkg/waiter",
- ],
-)
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
index e18012ac0..e18012ac0 100644..100755
--- a/pkg/tcpip/transport/tcp/dispatcher.go
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
deleted file mode 100644
index 4f361b226..000000000
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ /dev/null
@@ -1,652 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tcp_test
-
-import (
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/seqnum"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-func TestV4MappedConnectOnV6Only(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(true)
-
- // Start connection attempt, it must fail.
- err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort})
- if err != tcpip.ErrNoRoute {
- t.Fatalf("Unexpected return value from Connect: %v", err)
- }
-}
-
-func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.NetworkChecker) {
- // Start connection attempt.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.EventOut)
- defer c.WQ.EventUnregister(&we)
-
- err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort})
- if err != tcpip.ErrConnectStarted {
- t.Fatalf("Unexpected return value from Connect: %v", err)
- }
-
- // Receive SYN packet.
- b := c.GetPacket()
- synCheckers := append(checkers, checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- ))
- checker.IPv4(t, b, synCheckers...)
-
- tcp := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcp.SequenceNumber())
-
- iss := seqnum.Value(789)
- c.SendPacket(nil, &context.Headers{
- SrcPort: tcp.DestinationPort(),
- DstPort: tcp.SourcePort(),
- Flags: header.TCPFlagSyn | header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Receive ACK packet.
- ackCheckers := append(checkers, checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(iss)+1),
- ))
- checker.IPv4(t, c.GetPacket(), ackCheckers...)
-
- // Wait for connection to be established.
- select {
- case <-ch:
- err = c.EP.GetSockOpt(tcpip.ErrorOption{})
- if err != nil {
- t.Fatalf("Unexpected error when connecting: %v", err)
- }
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for connection")
- }
-}
-
-func TestV4MappedConnect(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- // Test the connection request.
- testV4Connect(t, c)
-}
-
-func TestV4ConnectWhenBoundToWildcard(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- // Bind to wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Test the connection request.
- testV4Connect(t, c)
-}
-
-func TestV4ConnectWhenBoundToV4MappedWildcard(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- // Bind to v4 mapped wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Test the connection request.
- testV4Connect(t, c)
-}
-
-func TestV4ConnectWhenBoundToV4Mapped(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- // Bind to v4 mapped address.
- if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV4MappedAddr}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Test the connection request.
- testV4Connect(t, c)
-}
-
-func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.NetworkChecker) {
- // Start connection attempt to IPv6 address.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.EventOut)
- defer c.WQ.EventUnregister(&we)
-
- err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort})
- if err != tcpip.ErrConnectStarted {
- t.Fatalf("Unexpected return value from Connect: %v", err)
- }
-
- // Receive SYN packet.
- b := c.GetV6Packet()
- synCheckers := append(checkers, checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- ))
- checker.IPv6(t, b, synCheckers...)
-
- tcp := header.TCP(header.IPv6(b).Payload())
- c.IRS = seqnum.Value(tcp.SequenceNumber())
-
- iss := seqnum.Value(789)
- c.SendV6Packet(nil, &context.Headers{
- SrcPort: tcp.DestinationPort(),
- DstPort: tcp.SourcePort(),
- Flags: header.TCPFlagSyn | header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Receive ACK packet.
- ackCheckers := append(checkers, checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(iss)+1),
- ))
- checker.IPv6(t, c.GetV6Packet(), ackCheckers...)
-
- // Wait for connection to be established.
- select {
- case <-ch:
- err = c.EP.GetSockOpt(tcpip.ErrorOption{})
- if err != nil {
- t.Fatalf("Unexpected error when connecting: %v", err)
- }
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for connection")
- }
-}
-
-func TestV6Connect(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- // Test the connection request.
- testV6Connect(t, c)
-}
-
-func TestV6ConnectV6Only(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(true)
-
- // Test the connection request.
- testV6Connect(t, c)
-}
-
-func TestV6ConnectWhenBoundToWildcard(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- // Bind to wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Test the connection request.
- testV6Connect(t, c)
-}
-
-func TestV6ConnectWhenBoundToLocalAddress(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- // Bind to local address.
- if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV6Addr}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Test the connection request.
- testV6Connect(t, c)
-}
-
-func TestV4RefuseOnV6Only(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(true)
-
- // Bind to wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Start listening.
- if err := c.EP.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
- }
-
- // Send a SYN request.
- irs := seqnum.Value(789)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- })
-
- // Receive the RST reply.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
- checker.AckNum(uint32(irs)+1),
- ),
- )
-}
-
-func TestV6RefuseOnBoundToV4Mapped(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- // Bind and listen.
- if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr, Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- if err := c.EP.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
- }
-
- // Send a SYN request.
- irs := seqnum.Value(789)
- c.SendV6Packet(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- })
-
- // Receive the RST reply.
- checker.IPv6(t, c.GetV6Packet(),
- checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
- checker.AckNum(uint32(irs)+1),
- ),
- )
-}
-
-func testV4Accept(t *testing.T, c *context.Context) {
- c.SetGSOEnabled(true)
- defer c.SetGSOEnabled(false)
-
- // Start listening.
- if err := c.EP.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
- }
-
- // Send a SYN request.
- irs := seqnum.Value(789)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetPacket()
- tcp := header.TCP(header.IPv4(b).Payload())
- iss := seqnum.Value(tcp.SequenceNumber())
- checker.IPv4(t, b,
- checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
- checker.AckNum(uint32(irs)+1),
- ),
- )
-
- // Send ACK.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs + 1,
- AckNum: iss + 1,
- RcvWnd: 30000,
- })
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.EventIn)
- defer c.WQ.EventUnregister(&we)
-
- nep, _, err := c.EP.Accept()
- if err == tcpip.ErrWouldBlock {
- // Wait for connection to be established.
- select {
- case <-ch:
- nep, _, err = c.EP.Accept()
- if err != nil {
- t.Fatalf("Accept failed: %v", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- // Make sure we get the same error when calling the original ep and the
- // new one. This validates that v4-mapped endpoints are still able to
- // query the V6Only flag, whereas pure v4 endpoints are not.
- _, expected := c.EP.GetSockOptBool(tcpip.V6OnlyOption)
- if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != expected {
- t.Fatalf("GetSockOpt returned unexpected value: got %v, want %v", err, expected)
- }
-
- // Check the peer address.
- addr, err := nep.GetRemoteAddress()
- if err != nil {
- t.Fatalf("GetRemoteAddress failed failed: %v", err)
- }
-
- if addr.Addr != context.TestAddr {
- t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestAddr)
- }
-
- data := "Don't panic"
- nep.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{})
- b = c.GetPacket()
- tcp = header.TCP(header.IPv4(b).Payload())
- if string(tcp.Payload()) != data {
- t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data)
- }
-}
-
-func TestV4AcceptOnV6(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- // Bind to wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Test acceptance.
- testV4Accept(t, c)
-}
-
-func TestV4AcceptOnBoundToV4MappedWildcard(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- // Bind to v4 mapped wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr, Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Test acceptance.
- testV4Accept(t, c)
-}
-
-func TestV4AcceptOnBoundToV4Mapped(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- // Bind and listen.
- if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV4MappedAddr, Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Test acceptance.
- testV4Accept(t, c)
-}
-
-func TestV6AcceptOnV6(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- // Bind and listen.
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- if err := c.EP.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
- }
-
- // Send a SYN request.
- irs := seqnum.Value(789)
- c.SendV6Packet(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetV6Packet()
- tcp := header.TCP(header.IPv6(b).Payload())
- iss := seqnum.Value(tcp.SequenceNumber())
- checker.IPv6(t, b,
- checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
- checker.AckNum(uint32(irs)+1),
- ),
- )
-
- // Send ACK.
- c.SendV6Packet(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs + 1,
- AckNum: iss + 1,
- RcvWnd: 30000,
- })
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.EventIn)
- defer c.WQ.EventUnregister(&we)
-
- nep, _, err := c.EP.Accept()
- if err == tcpip.ErrWouldBlock {
- // Wait for connection to be established.
- select {
- case <-ch:
- nep, _, err = c.EP.Accept()
- if err != nil {
- t.Fatalf("Accept failed: %v", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- // Make sure we can still query the v6 only status of the new endpoint,
- // that is, that it is in fact a v6 socket.
- if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != nil {
- t.Fatalf("GetSockOpt failed failed: %v", err)
- }
-
- // Check the peer address.
- addr, err := nep.GetRemoteAddress()
- if err != nil {
- t.Fatalf("GetRemoteAddress failed failed: %v", err)
- }
-
- if addr.Addr != context.TestV6Addr {
- t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestV6Addr)
- }
-}
-
-func TestV4AcceptOnV4(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create TCP endpoint.
- var err *tcpip.Error
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- // Bind to wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Test acceptance.
- testV4Accept(t, c)
-}
-
-func testV4ListenClose(t *testing.T, c *context.Context) {
- // Set the SynRcvd threshold to zero to force a syn cookie based accept
- // to happen.
- saved := tcp.SynRcvdCountThreshold
- defer func() {
- tcp.SynRcvdCountThreshold = saved
- }()
- tcp.SynRcvdCountThreshold = 0
- const n = uint16(32)
-
- // Start listening.
- if err := c.EP.Listen(int(tcp.SynRcvdCountThreshold + 1)); err != nil {
- t.Fatalf("Listen failed: %v", err)
- }
-
- irs := seqnum.Value(789)
- for i := uint16(0); i < n; i++ {
- // Send a SYN request.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort + i,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- })
- }
-
- // Each of these ACK's will cause a syn-cookie based connection to be
- // accepted and delivered to the listening endpoint.
- for i := uint16(0); i < n; i++ {
- b := c.GetPacket()
- tcp := header.TCP(header.IPv4(b).Payload())
- iss := seqnum.Value(tcp.SequenceNumber())
- // Send ACK.
- c.SendPacket(nil, &context.Headers{
- SrcPort: tcp.DestinationPort(),
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs + 1,
- AckNum: iss + 1,
- RcvWnd: 30000,
- })
- }
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.EventIn)
- defer c.WQ.EventUnregister(&we)
- nep, _, err := c.EP.Accept()
- if err == tcpip.ErrWouldBlock {
- // Wait for connection to be established.
- select {
- case <-ch:
- nep, _, err = c.EP.Accept()
- if err != nil {
- t.Fatalf("Accept failed: %v", err)
- }
-
- case <-time.After(10 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
- nep.Close()
- c.EP.Close()
-}
-
-func TestV4ListenCloseOnV4(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create TCP endpoint.
- var err *tcpip.Error
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- // Bind to wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Test acceptance.
- testV4ListenClose(t, c)
-}
diff --git a/pkg/tcpip/transport/tcp/rcv_state.go b/pkg/tcpip/transport/tcp/rcv_state.go
index 2bf21a2e7..2bf21a2e7 100644..100755
--- a/pkg/tcpip/transport/tcp/rcv_state.go
+++ b/pkg/tcpip/transport/tcp/rcv_state.go
diff --git a/pkg/tcpip/transport/tcp/sack_scoreboard_test.go b/pkg/tcpip/transport/tcp/sack_scoreboard_test.go
deleted file mode 100644
index b4e5ba0df..000000000
--- a/pkg/tcpip/transport/tcp/sack_scoreboard_test.go
+++ /dev/null
@@ -1,249 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tcp_test
-
-import (
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/seqnum"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
-)
-
-const smss = 1500
-
-func initScoreboard(blocks []header.SACKBlock, iss seqnum.Value) *tcp.SACKScoreboard {
- s := tcp.NewSACKScoreboard(smss, iss)
- for _, blk := range blocks {
- s.Insert(blk)
- }
- return s
-}
-
-func TestSACKScoreboardIsSACKED(t *testing.T) {
- type blockTest struct {
- block header.SACKBlock
- sacked bool
- }
- testCases := []struct {
- comment string
- scoreboardBlocks []header.SACKBlock
- blockTests []blockTest
- iss seqnum.Value
- }{
- {
- "Test holes and unsacked SACK blocks in SACKed ranges and insertion of overlapping SACK blocks",
- []header.SACKBlock{{10, 20}, {10, 30}, {30, 40}, {41, 50}, {5, 10}, {1, 50}, {111, 120}, {101, 110}, {52, 120}},
- []blockTest{
- {header.SACKBlock{15, 21}, true},
- {header.SACKBlock{200, 201}, false},
- {header.SACKBlock{50, 51}, false},
- {header.SACKBlock{53, 120}, true},
- },
- 0,
- },
- {
- "Test disjoint SACKBlocks",
- []header.SACKBlock{{2288624809, 2288810057}, {2288811477, 2288838565}},
- []blockTest{
- {header.SACKBlock{2288624809, 2288810057}, true},
- {header.SACKBlock{2288811477, 2288838565}, true},
- {header.SACKBlock{2288810057, 2288811477}, false},
- },
- 2288624809,
- },
- {
- "Test sequence number wrap around",
- []header.SACKBlock{{4294254144, 225652}, {5340409, 5350509}},
- []blockTest{
- {header.SACKBlock{4294254144, 4294254145}, true},
- {header.SACKBlock{4294254143, 4294254144}, false},
- {header.SACKBlock{4294254144, 1}, true},
- {header.SACKBlock{225652, 5350509}, false},
- {header.SACKBlock{5340409, 5350509}, true},
- {header.SACKBlock{5350509, 5350609}, false},
- },
- 4294254144,
- },
- {
- "Test disjoint SACKBlocks out of order",
- []header.SACKBlock{{827450276, 827454536}, {827426028, 827428868}},
- []blockTest{
- {header.SACKBlock{827426028, 827428867}, true},
- {header.SACKBlock{827450168, 827450275}, false},
- },
- 827426000,
- },
- }
- for _, tc := range testCases {
- sb := initScoreboard(tc.scoreboardBlocks, tc.iss)
- for _, blkTest := range tc.blockTests {
- if want, got := blkTest.sacked, sb.IsSACKED(blkTest.block); got != want {
- t.Errorf("%s: s.IsSACKED(%v) = %v, want %v", tc.comment, blkTest.block, got, want)
- }
- }
- }
-}
-
-func TestSACKScoreboardIsRangeLost(t *testing.T) {
- s := tcp.NewSACKScoreboard(10, 0)
- s.Insert(header.SACKBlock{1, 25})
- s.Insert(header.SACKBlock{25, 50})
- s.Insert(header.SACKBlock{51, 100})
- s.Insert(header.SACKBlock{111, 120})
- s.Insert(header.SACKBlock{101, 110})
- s.Insert(header.SACKBlock{121, 141})
- s.Insert(header.SACKBlock{145, 146})
- s.Insert(header.SACKBlock{147, 148})
- s.Insert(header.SACKBlock{149, 150})
- s.Insert(header.SACKBlock{153, 154})
- s.Insert(header.SACKBlock{155, 156})
- testCases := []struct {
- block header.SACKBlock
- lost bool
- }{
- // Block not covered by SACK block and has more than
- // nDupAckThreshold discontiguous SACK blocks after it as well
- // as (nDupAckThreshold -1) * 10 (smss) bytes that have been
- // SACKED above the sequence number covered by this block.
- {block: header.SACKBlock{0, 1}, lost: true},
-
- // These blocks have all been SACKed and should not be
- // considered lost.
- {block: header.SACKBlock{1, 2}, lost: false},
- {block: header.SACKBlock{25, 26}, lost: false},
- {block: header.SACKBlock{1, 45}, lost: false},
-
- // Same as the first case above.
- {block: header.SACKBlock{50, 51}, lost: true},
-
- // This block has been SACKed and should not be considered lost.
- {block: header.SACKBlock{119, 120}, lost: false},
-
- // This one should return true because there are >
- // (nDupAckThreshold - 1) * 10 (smss) bytes that have been
- // sacked above this sequence number.
- {block: header.SACKBlock{120, 121}, lost: true},
-
- // This block has been SACKed and should not be considered lost.
- {block: header.SACKBlock{125, 126}, lost: false},
-
- // This block has not been SACKed and there are nDupAckThreshold
- // number of SACKed blocks after it.
- {block: header.SACKBlock{141, 145}, lost: true},
-
- // This block has not been SACKed and there are less than
- // nDupAckThreshold SACKed sequences after it.
- {block: header.SACKBlock{151, 152}, lost: false},
- }
- for _, tc := range testCases {
- if want, got := tc.lost, s.IsRangeLost(tc.block); got != want {
- t.Errorf("s.IsRangeLost(%v) = %v, want %v", tc.block, got, want)
- }
- }
-}
-
-func TestSACKScoreboardIsLost(t *testing.T) {
- s := tcp.NewSACKScoreboard(10, 0)
- s.Insert(header.SACKBlock{1, 25})
- s.Insert(header.SACKBlock{25, 50})
- s.Insert(header.SACKBlock{51, 100})
- s.Insert(header.SACKBlock{111, 120})
- s.Insert(header.SACKBlock{101, 110})
- s.Insert(header.SACKBlock{121, 141})
- s.Insert(header.SACKBlock{121, 141})
- s.Insert(header.SACKBlock{145, 146})
- s.Insert(header.SACKBlock{147, 148})
- s.Insert(header.SACKBlock{149, 150})
- s.Insert(header.SACKBlock{153, 154})
- s.Insert(header.SACKBlock{155, 156})
- testCases := []struct {
- seq seqnum.Value
- lost bool
- }{
- // Sequence number not covered by SACK block and has more than
- // nDupAckThreshold discontiguous SACK blocks after it as well
- // as (nDupAckThreshold -1) * 10 (smss) bytes that have been
- // SACKED above the sequence number.
- {seq: 0, lost: true},
-
- // These sequence numbers have all been SACKed and should not be
- // considered lost.
- {seq: 1, lost: false},
- {seq: 25, lost: false},
- {seq: 45, lost: false},
-
- // Same as first case above.
- {seq: 50, lost: true},
-
- // This block has been SACKed and should not be considered lost.
- {seq: 119, lost: false},
-
- // This one should return true because there are >
- // (nDupAckThreshold - 1) * 10 (smss) bytes that have been
- // sacked above this sequence number.
- {seq: 120, lost: true},
-
- // This sequence number has been SACKed and should not be
- // considered lost.
- {seq: 125, lost: false},
-
- // This sequence number has not been SACKed and there are
- // nDupAckThreshold number of SACKed blocks after it.
- {seq: 141, lost: true},
-
- // This sequence number has not been SACKed and there are less
- // than nDupAckThreshold SACKed sequences after it.
- {seq: 151, lost: false},
- }
- for _, tc := range testCases {
- if want, got := tc.lost, s.IsLost(tc.seq); got != want {
- t.Errorf("s.IsLost(%v) = %v, want %v", tc.seq, got, want)
- }
- }
-}
-
-func TestSACKScoreboardDelete(t *testing.T) {
- blocks := []header.SACKBlock{{4294254144, 225652}, {5340409, 5350509}}
- s := initScoreboard(blocks, 4294254143)
- s.Delete(5340408)
- if s.Empty() {
- t.Fatalf("s.Empty() = true, want false")
- }
- if got, want := s.Sacked(), blocks[1].Start.Size(blocks[1].End); got != want {
- t.Fatalf("incorrect sacked bytes in scoreboard got: %v, want: %v", got, want)
- }
- s.Delete(5340410)
- if s.Empty() {
- t.Fatal("s.Empty() = true, want false")
- }
- newSB := header.SACKBlock{5340410, 5350509}
- if !s.IsSACKED(newSB) {
- t.Fatalf("s.IsSACKED(%v) = false, want true, scoreboard: %v", newSB, s)
- }
- s.Delete(5350509)
- lastOctet := header.SACKBlock{5350508, 5350509}
- if s.IsSACKED(lastOctet) {
- t.Fatalf("s.IsSACKED(%v) = false, want true", lastOctet)
- }
-
- s.Delete(5350510)
- if !s.Empty() {
- t.Fatal("s.Empty() = false, want true")
- }
- if got, want := s.Sacked(), seqnum.Size(0); got != want {
- t.Fatalf("incorrect sacked bytes in scoreboard got: %v, want: %v", got, want)
- }
-}
diff --git a/pkg/tcpip/transport/tcp/tcp_endpoint_list.go b/pkg/tcpip/transport/tcp/tcp_endpoint_list.go
new file mode 100755
index 000000000..42e190060
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/tcp_endpoint_list.go
@@ -0,0 +1,173 @@
+package tcp
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type endpointElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (endpointElementMapper) linkerFor(elem *endpoint) *endpoint { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type endpointList struct {
+ head *endpoint
+ tail *endpoint
+}
+
+// Reset resets list l to the empty state.
+func (l *endpointList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *endpointList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *endpointList) Front() *endpoint {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *endpointList) Back() *endpoint {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *endpointList) PushFront(e *endpoint) {
+ endpointElementMapper{}.linkerFor(e).SetNext(l.head)
+ endpointElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ endpointElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *endpointList) PushBack(e *endpoint) {
+ endpointElementMapper{}.linkerFor(e).SetNext(nil)
+ endpointElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ endpointElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *endpointList) PushBackList(m *endpointList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ endpointElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ endpointElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *endpointList) InsertAfter(b, e *endpoint) {
+ a := endpointElementMapper{}.linkerFor(b).Next()
+ endpointElementMapper{}.linkerFor(e).SetNext(a)
+ endpointElementMapper{}.linkerFor(e).SetPrev(b)
+ endpointElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ endpointElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *endpointList) InsertBefore(a, e *endpoint) {
+ b := endpointElementMapper{}.linkerFor(a).Prev()
+ endpointElementMapper{}.linkerFor(e).SetNext(a)
+ endpointElementMapper{}.linkerFor(e).SetPrev(b)
+ endpointElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ endpointElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *endpointList) Remove(e *endpoint) {
+ prev := endpointElementMapper{}.linkerFor(e).Prev()
+ next := endpointElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ endpointElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ endpointElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type endpointEntry struct {
+ next *endpoint
+ prev *endpoint
+}
+
+// Next returns the entry that follows e in the list.
+func (e *endpointEntry) Next() *endpoint {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *endpointEntry) Prev() *endpoint {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *endpointEntry) SetNext(elem *endpoint) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *endpointEntry) SetPrev(elem *endpoint) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
deleted file mode 100644
index 782d7b42c..000000000
--- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
+++ /dev/null
@@ -1,527 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-// These tests are flaky when run under the go race detector due to some
-// iterations taking long enough that the retransmit timer can kick in causing
-// the congestion window measurements to fail due to extra packets etc.
-//
-// +build !race
-
-package tcp_test
-
-import (
- "fmt"
- "math"
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
-)
-
-func TestFastRecovery(t *testing.T) {
- maxPayload := 32
- c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- const iterations = 7
- data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
- for i := range data {
- data[i] = byte(i)
- }
-
- // Write all the data in one shot. Packets will only be written at the
- // MTU size though.
- if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
- }
-
- // Do slow start for a few iterations.
- expected := tcp.InitialCwnd
- bytesRead := 0
- for i := 0; i < iterations; i++ {
- expected = tcp.InitialCwnd << uint(i)
- if i > 0 {
- // Acknowledge all the data received so far if not on
- // first iteration.
- c.SendAck(790, bytesRead)
- }
-
- // Read all packets expected on this iteration. Don't
- // acknowledge any of them just yet, so that we can measure the
- // congestion window.
- for j := 0; j < expected; j++ {
- c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
- bytesRead += maxPayload
- }
-
- // Check we don't receive any more packets on this iteration.
- // The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
- }
-
- // Send 3 duplicate acks. This should force an immediate retransmit of
- // the pending packet and put the sender into fast recovery.
- rtxOffset := bytesRead - maxPayload*expected
- for i := 0; i < 3; i++ {
- c.SendAck(790, rtxOffset)
- }
-
- // Receive the retransmitted packet.
- c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
-
- if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want {
- t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want)
- }
-
- if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want {
- t.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want)
- }
-
- if got, want := c.Stack().Stats().TCP.FastRecovery.Value(), uint64(1); got != want {
- t.Errorf("got stats.TCP.FastRecovery.Value = %v, want = %v", got, want)
- }
-
- // Now send 7 mode duplicate acks. Each of these should cause a window
- // inflation by 1 and cause the sender to send an extra packet.
- for i := 0; i < 7; i++ {
- c.SendAck(790, rtxOffset)
- }
-
- recover := bytesRead
-
- // Ensure no new packets arrive.
- c.CheckNoPacketTimeout("More packets received than expected during recovery after dupacks for this cwnd.",
- 50*time.Millisecond)
-
- // Acknowledge half of the pending data.
- rtxOffset = bytesRead - expected*maxPayload/2
- c.SendAck(790, rtxOffset)
-
- // Receive the retransmit due to partial ack.
- c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
-
- if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(2); got != want {
- t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want)
- }
-
- if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(2); got != want {
- t.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want)
- }
-
- // Receive the 10 extra packets that should have been released due to
- // the congestion window inflation in recovery.
- for i := 0; i < 10; i++ {
- c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
- bytesRead += maxPayload
- }
-
- // A partial ACK during recovery should reduce congestion window by the
- // number acked. Since we had "expected" packets outstanding before sending
- // partial ack and we acked expected/2 , the cwnd and outstanding should
- // be expected/2 + 10 (7 dupAcks + 3 for the original 3 dupacks that triggered
- // fast recovery). Which means the sender should not send any more packets
- // till we ack this one.
- c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.",
- 50*time.Millisecond)
-
- // Acknowledge all pending data to recover point.
- c.SendAck(790, recover)
-
- // At this point, the cwnd should reset to expected/2 and there are 10
- // packets outstanding.
- //
- // NOTE: Technically netstack is incorrect in that we adjust the cwnd on
- // the same segment that takes us out of recovery. But because of that
- // the actual cwnd at exit of recovery will be expected/2 + 1 as we
- // acked a cwnd worth of packets which will increase the cwnd further by
- // 1 in congestion avoidance.
- //
- // Now in the first iteration since there are 10 packets outstanding.
- // We would expect to get expected/2 +1 - 10 packets. But subsequent
- // iterations will send us expected/2 + 1 + 1 (per iteration).
- expected = expected/2 + 1 - 10
- for i := 0; i < iterations; i++ {
- // Read all packets expected on this iteration. Don't
- // acknowledge any of them just yet, so that we can measure the
- // congestion window.
- for j := 0; j < expected; j++ {
- c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
- bytesRead += maxPayload
- }
-
- // Check we don't receive any more packets on this iteration.
- // The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout(fmt.Sprintf("More packets received(after deflation) than expected %d for this cwnd.", expected), 50*time.Millisecond)
-
- // Acknowledge all the data received so far.
- c.SendAck(790, bytesRead)
-
- // In cogestion avoidance, the packets trains increase by 1 in
- // each iteration.
- if i == 0 {
- // After the first iteration we expect to get the full
- // congestion window worth of packets in every
- // iteration.
- expected += 10
- }
- expected++
- }
-}
-
-func TestExponentialIncreaseDuringSlowStart(t *testing.T) {
- maxPayload := 32
- c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- const iterations = 7
- data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
- for i := range data {
- data[i] = byte(i)
- }
-
- // Write all the data in one shot. Packets will only be written at the
- // MTU size though.
- if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
- }
-
- expected := tcp.InitialCwnd
- bytesRead := 0
- for i := 0; i < iterations; i++ {
- // Read all packets expected on this iteration. Don't
- // acknowledge any of them just yet, so that we can measure the
- // congestion window.
- for j := 0; j < expected; j++ {
- c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
- bytesRead += maxPayload
- }
-
- // Check we don't receive any more packets on this iteration.
- // The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
-
- // Acknowledge all the data received so far.
- c.SendAck(790, bytesRead)
-
- // Double the number of expected packets for the next iteration.
- expected *= 2
- }
-}
-
-func TestCongestionAvoidance(t *testing.T) {
- maxPayload := 32
- c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- const iterations = 7
- data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
- for i := range data {
- data[i] = byte(i)
- }
-
- // Write all the data in one shot. Packets will only be written at the
- // MTU size though.
- if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
- }
-
- // Do slow start for a few iterations.
- expected := tcp.InitialCwnd
- bytesRead := 0
- for i := 0; i < iterations; i++ {
- expected = tcp.InitialCwnd << uint(i)
- if i > 0 {
- // Acknowledge all the data received so far if not on
- // first iteration.
- c.SendAck(790, bytesRead)
- }
-
- // Read all packets expected on this iteration. Don't
- // acknowledge any of them just yet, so that we can measure the
- // congestion window.
- for j := 0; j < expected; j++ {
- c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
- bytesRead += maxPayload
- }
-
- // Check we don't receive any more packets on this iteration.
- // The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd (slow start phase).", 50*time.Millisecond)
- }
-
- // Don't acknowledge the first packet of the last packet train. Let's
- // wait for them to time out, which will trigger a restart of slow
- // start, and initialization of ssthresh to cwnd/2.
- rtxOffset := bytesRead - maxPayload*expected
- c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
-
- // Acknowledge all the data received so far.
- c.SendAck(790, bytesRead)
-
- // This part is tricky: when the timeout happened, we had "expected"
- // packets pending, cwnd reset to 1, and ssthresh set to expected/2.
- // By acknowledging "expected" packets, the slow-start part will
- // increase cwnd to expected/2 (which "consumes" expected/2-1 of the
- // acknowledgements), then the congestion avoidance part will consume
- // an extra expected/2 acks to take cwnd to expected/2 + 1. One ack
- // remains in the "ack count" (which will cause cwnd to be incremented
- // once it reaches cwnd acks).
- //
- // So we're straight into congestion avoidance with cwnd set to
- // expected/2 + 1.
- //
- // Check that packets trains of cwnd packets are sent, and that cwnd is
- // incremented by 1 after we acknowledge each packet.
- expected = expected/2 + 1
- for i := 0; i < iterations; i++ {
- // Read all packets expected on this iteration. Don't
- // acknowledge any of them just yet, so that we can measure the
- // congestion window.
- for j := 0; j < expected; j++ {
- c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
- bytesRead += maxPayload
- }
-
- // Check we don't receive any more packets on this iteration.
- // The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd (congestion avoidance phase).", 50*time.Millisecond)
-
- // Acknowledge all the data received so far.
- c.SendAck(790, bytesRead)
-
- // In cogestion avoidance, the packets trains increase by 1 in
- // each iteration.
- expected++
- }
-}
-
-// cubicCwnd returns an estimate of a cubic window given the
-// originalCwnd, wMax, last congestion event time and sRTT.
-func cubicCwnd(origCwnd int, wMax int, congEventTime time.Time, sRTT time.Duration) int {
- cwnd := float64(origCwnd)
- // We wait 50ms between each iteration so sRTT as computed by cubic
- // should be close to 50ms.
- elapsed := (time.Since(congEventTime) + sRTT).Seconds()
- k := math.Cbrt(float64(wMax) * 0.3 / 0.7)
- wtRTT := 0.4*math.Pow(elapsed-k, 3) + float64(wMax)
- cwnd += (wtRTT - cwnd) / cwnd
- return int(cwnd)
-}
-
-func TestCubicCongestionAvoidance(t *testing.T) {
- maxPayload := 32
- c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
- defer c.Cleanup()
-
- enableCUBIC(t, c)
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- const iterations = 7
- data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
-
- for i := range data {
- data[i] = byte(i)
- }
-
- // Write all the data in one shot. Packets will only be written at the
- // MTU size though.
- if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
- }
-
- // Do slow start for a few iterations.
- expected := tcp.InitialCwnd
- bytesRead := 0
- for i := 0; i < iterations; i++ {
- expected = tcp.InitialCwnd << uint(i)
- if i > 0 {
- // Acknowledge all the data received so far if not on
- // first iteration.
- c.SendAck(790, bytesRead)
- }
-
- // Read all packets expected on this iteration. Don't
- // acknowledge any of them just yet, so that we can measure the
- // congestion window.
- for j := 0; j < expected; j++ {
- c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
- bytesRead += maxPayload
- }
-
- // Check we don't receive any more packets on this iteration.
- // The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd (during slow-start phase).", 50*time.Millisecond)
- }
-
- // Don't acknowledge the first packet of the last packet train. Let's
- // wait for them to time out, which will trigger a restart of slow
- // start, and initialization of ssthresh to cwnd * 0.7.
- rtxOffset := bytesRead - maxPayload*expected
- c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
-
- // Acknowledge all pending data.
- c.SendAck(790, bytesRead)
-
- // Store away the time we sent the ACK and assuming a 200ms RTO
- // we estimate that the sender will have an RTO 200ms from now
- // and go back into slow start.
- packetDropTime := time.Now().Add(200 * time.Millisecond)
-
- // This part is tricky: when the timeout happened, we had "expected"
- // packets pending, cwnd reset to 1, and ssthresh set to expected * 0.7.
- // By acknowledging "expected" packets, the slow-start part will
- // increase cwnd to expected/2 essentially putting the connection
- // straight into congestion avoidance.
- wMax := expected
- // Lower expected as per cubic spec after a congestion event.
- expected = int(float64(expected) * 0.7)
- cwnd := expected
- for i := 0; i < iterations; i++ {
- // Cubic grows window independent of ACKs. Cubic Window growth
- // is a function of time elapsed since last congestion event.
- // As a result the congestion window does not grow
- // deterministically in response to ACKs.
- //
- // We need to roughly estimate what the cwnd of the sender is
- // based on when we sent the dupacks.
- cwnd := cubicCwnd(cwnd, wMax, packetDropTime, 50*time.Millisecond)
-
- packetsExpected := cwnd
- for j := 0; j < packetsExpected; j++ {
- c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
- bytesRead += maxPayload
- }
- t.Logf("expected packets received, next trying to receive any extra packets that may come")
-
- // If our estimate was correct there should be no more pending packets.
- // We attempt to read a packet a few times with a short sleep in between
- // to ensure that we don't see the sender send any unexpected packets.
- unexpectedPackets := 0
- for {
- gotPacket := c.ReceiveNonBlockingAndCheckPacket(data, bytesRead, maxPayload)
- if !gotPacket {
- break
- }
- bytesRead += maxPayload
- unexpectedPackets++
- time.Sleep(1 * time.Millisecond)
- }
- if unexpectedPackets != 0 {
- t.Fatalf("received %d unexpected packets for iteration %d", unexpectedPackets, i)
- }
- // Check we don't receive any more packets on this iteration.
- // The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd(congestion avoidance)", 5*time.Millisecond)
-
- // Acknowledge all the data received so far.
- c.SendAck(790, bytesRead)
- }
-}
-
-func TestRetransmit(t *testing.T) {
- maxPayload := 32
- c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- const iterations = 7
- data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
- for i := range data {
- data[i] = byte(i)
- }
-
- // Write all the data in two shots. Packets will only be written at the
- // MTU size though.
- half := data[:len(data)/2]
- if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
- }
- half = data[len(data)/2:]
- if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
- }
-
- // Do slow start for a few iterations.
- expected := tcp.InitialCwnd
- bytesRead := 0
- for i := 0; i < iterations; i++ {
- expected = tcp.InitialCwnd << uint(i)
- if i > 0 {
- // Acknowledge all the data received so far if not on
- // first iteration.
- c.SendAck(790, bytesRead)
- }
-
- // Read all packets expected on this iteration. Don't
- // acknowledge any of them just yet, so that we can measure the
- // congestion window.
- for j := 0; j < expected; j++ {
- c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
- bytesRead += maxPayload
- }
-
- // Check we don't receive any more packets on this iteration.
- // The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
- }
-
- // Wait for a timeout and retransmit.
- rtxOffset := bytesRead - maxPayload*expected
- c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
-
- if got, want := c.Stack().Stats().TCP.Timeouts.Value(), uint64(1); got != want {
- t.Errorf("got stats.TCP.Timeouts.Value = %v, want = %v", got, want)
- }
-
- if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want {
- t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want)
- }
-
- if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Timeouts.Value(), uint64(1); got != want {
- t.Errorf("got EP SendErrors.Timeouts.Value = %v, want = %v", got, want)
- }
-
- if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(1); got != want {
- t.Errorf("got EP stats SendErrors.Retransmits.Value = %v, want = %v", got, want)
- }
-
- if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want {
- t.Errorf("got stats.TCP.SlowStartRetransmits.Value = %v, want = %v", got, want)
- }
-
- // Acknowledge half of the pending data.
- rtxOffset = bytesRead - expected*maxPayload/2
- c.SendAck(790, rtxOffset)
-
- // Receive the remaining data, making sure that acknowledged data is not
- // retransmitted.
- for offset := rtxOffset; offset < len(data); offset += maxPayload {
- c.ReceiveAndCheckPacket(data, offset, maxPayload)
- c.SendAck(790, offset+maxPayload)
- }
-
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
-}
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
deleted file mode 100644
index afea124ec..000000000
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ /dev/null
@@ -1,572 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tcp_test
-
-import (
- "fmt"
- "log"
- "reflect"
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/seqnum"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
-)
-
-// createConnectedWithSACKPermittedOption creates and connects c.ep with the
-// SACKPermitted option enabled if the stack in the context has the SACK support
-// enabled.
-func createConnectedWithSACKPermittedOption(c *context.Context) *context.RawEndpoint {
- return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled()})
-}
-
-// createConnectedWithSACKAndTS creates and connects c.ep with the SACK & TS
-// option enabled if the stack in the context has SACK and TS enabled.
-func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint {
- return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled(), TS: true})
-}
-
-func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) {
- t.Helper()
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enable)); err != nil {
- t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, SACKEnabled(%v) = %v", enable, err)
- }
-}
-
-// TestSackPermittedConnect establishes a connection with the SACK option
-// enabled.
-func TestSackPermittedConnect(t *testing.T) {
- for _, sackEnabled := range []bool{false, true} {
- t.Run(fmt.Sprintf("stack.sackEnabled: %v", sackEnabled), func(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- setStackSACKPermitted(t, c, sackEnabled)
- rep := createConnectedWithSACKPermittedOption(c)
- data := []byte{1, 2, 3}
-
- rep.SendPacket(data, nil)
- savedSeqNum := rep.NextSeqNum
- rep.VerifyACKNoSACK()
-
- // Make an out of order packet and send it.
- rep.NextSeqNum += 3
- sackBlocks := []header.SACKBlock{
- {rep.NextSeqNum, rep.NextSeqNum.Add(seqnum.Size(len(data)))},
- }
- rep.SendPacket(data, nil)
-
- // Restore the saved sequence number so that the
- // VerifyXXX calls use the right sequence number for
- // checking ACK numbers.
- rep.NextSeqNum = savedSeqNum
- if sackEnabled {
- rep.VerifyACKHasSACK(sackBlocks)
- } else {
- rep.VerifyACKNoSACK()
- }
-
- // Send the missing segment.
- rep.SendPacket(data, nil)
- // The ACK should contain the cumulative ACK for all 9
- // bytes sent and no SACK blocks.
- rep.NextSeqNum += 3
- // Check that no SACK block is returned in the ACK.
- rep.VerifyACKNoSACK()
- })
- }
-}
-
-// TestSackDisabledConnect establishes a connection with the SACK option
-// disabled and verifies that no SACKs are sent for out of order segments.
-func TestSackDisabledConnect(t *testing.T) {
- for _, sackEnabled := range []bool{false, true} {
- t.Run(fmt.Sprintf("sackEnabled: %v", sackEnabled), func(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- setStackSACKPermitted(t, c, sackEnabled)
-
- rep := c.CreateConnectedWithOptions(header.TCPSynOptions{})
-
- data := []byte{1, 2, 3}
-
- rep.SendPacket(data, nil)
- savedSeqNum := rep.NextSeqNum
- rep.VerifyACKNoSACK()
-
- // Make an out of order packet and send it.
- rep.NextSeqNum += 3
- rep.SendPacket(data, nil)
-
- // The ACK should contain the older sequence number and
- // no SACK blocks.
- rep.NextSeqNum = savedSeqNum
- rep.VerifyACKNoSACK()
-
- // Send the missing segment.
- rep.SendPacket(data, nil)
- // The ACK should contain the cumulative ACK for all 9
- // bytes sent and no SACK blocks.
- rep.NextSeqNum += 3
- // Check that no SACK block is returned in the ACK.
- rep.VerifyACKNoSACK()
- })
- }
-}
-
-// TestSackPermittedAccept accepts and establishes a connection with the
-// SACKPermitted option enabled if the connection request specifies the
-// SACKPermitted option. In case of SYN cookies SACK should be disabled as we
-// don't encode the SACK information in the cookie.
-func TestSackPermittedAccept(t *testing.T) {
- type testCase struct {
- cookieEnabled bool
- sackPermitted bool
- wndScale int
- wndSize uint16
- }
-
- testCases := []testCase{
- // When cookie is used window scaling is disabled.
- {true, false, -1, 0xffff}, // When cookie is used window scaling is disabled.
- {false, true, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default).
- }
- savedSynCountThreshold := tcp.SynRcvdCountThreshold
- defer func() {
- tcp.SynRcvdCountThreshold = savedSynCountThreshold
- }()
- for _, tc := range testCases {
- t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) {
- if tc.cookieEnabled {
- tcp.SynRcvdCountThreshold = 0
- } else {
- tcp.SynRcvdCountThreshold = savedSynCountThreshold
- }
- for _, sackEnabled := range []bool{false, true} {
- t.Run(fmt.Sprintf("test stack.sackEnabled: %v", sackEnabled), func(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
- setStackSACKPermitted(t, c, sackEnabled)
-
- rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, SACKPermitted: tc.sackPermitted})
- // Now verify no SACK blocks are
- // received when sack is disabled.
- data := []byte{1, 2, 3}
- rep.SendPacket(data, nil)
- rep.VerifyACKNoSACK()
-
- savedSeqNum := rep.NextSeqNum
-
- // Make an out of order packet and send
- // it.
- rep.NextSeqNum += 3
- sackBlocks := []header.SACKBlock{
- {rep.NextSeqNum, rep.NextSeqNum.Add(seqnum.Size(len(data)))},
- }
- rep.SendPacket(data, nil)
-
- // The ACK should contain the older
- // sequence number.
- rep.NextSeqNum = savedSeqNum
- if sackEnabled && tc.sackPermitted {
- rep.VerifyACKHasSACK(sackBlocks)
- } else {
- rep.VerifyACKNoSACK()
- }
-
- // Send the missing segment.
- rep.SendPacket(data, nil)
- // The ACK should contain the cumulative
- // ACK for all 9 bytes sent and no SACK
- // blocks.
- rep.NextSeqNum += 3
- // Check that no SACK block is returned
- // in the ACK.
- rep.VerifyACKNoSACK()
- })
- }
- })
- }
-}
-
-// TestSackDisabledAccept accepts and establishes a connection with
-// the SACKPermitted option disabled and verifies that no SACKs are
-// sent for out of order packets.
-func TestSackDisabledAccept(t *testing.T) {
- type testCase struct {
- cookieEnabled bool
- wndScale int
- wndSize uint16
- }
-
- testCases := []testCase{
- // When cookie is used window scaling is disabled.
- {true, -1, 0xffff}, // When cookie is used window scaling is disabled.
- {false, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default).
- }
- savedSynCountThreshold := tcp.SynRcvdCountThreshold
- defer func() {
- tcp.SynRcvdCountThreshold = savedSynCountThreshold
- }()
- for _, tc := range testCases {
- t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) {
- if tc.cookieEnabled {
- tcp.SynRcvdCountThreshold = 0
- } else {
- tcp.SynRcvdCountThreshold = savedSynCountThreshold
- }
- for _, sackEnabled := range []bool{false, true} {
- t.Run(fmt.Sprintf("test: sackEnabled: %v", sackEnabled), func(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
- setStackSACKPermitted(t, c, sackEnabled)
-
- rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
-
- // Now verify no SACK blocks are
- // received when sack is disabled.
- data := []byte{1, 2, 3}
- rep.SendPacket(data, nil)
- rep.VerifyACKNoSACK()
- savedSeqNum := rep.NextSeqNum
-
- // Make an out of order packet and send
- // it.
- rep.NextSeqNum += 3
- rep.SendPacket(data, nil)
-
- // The ACK should contain the older
- // sequence number and no SACK blocks.
- rep.NextSeqNum = savedSeqNum
- rep.VerifyACKNoSACK()
-
- // Send the missing segment.
- rep.SendPacket(data, nil)
- // The ACK should contain the cumulative
- // ACK for all 9 bytes sent and no SACK
- // blocks.
- rep.NextSeqNum += 3
- // Check that no SACK block is returned
- // in the ACK.
- rep.VerifyACKNoSACK()
- })
- }
- })
- }
-}
-
-func TestUpdateSACKBlocks(t *testing.T) {
- testCases := []struct {
- segStart seqnum.Value
- segEnd seqnum.Value
- rcvNxt seqnum.Value
- sackBlocks []header.SACKBlock
- updated []header.SACKBlock
- }{
- // Trivial cases where current SACK block list is empty and we
- // have an out of order delivery.
- {10, 11, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 11}}},
- {10, 12, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 12}}},
- {10, 20, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 20}}},
-
- // Cases where current SACK block list is not empty and we have
- // an out of order delivery. Tests that the updated SACK block
- // list has the first block as the one that contains the new
- // SACK block representing the segment that was just delivered.
- {10, 11, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 11}, {12, 20}}},
- {24, 30, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{24, 30}, {12, 20}}},
- {24, 30, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{24, 30}, {12, 20}, {32, 40}}},
-
- // Ensure that we only retain header.MaxSACKBlocks and drop the
- // oldest one if adding a new block exceeds
- // header.MaxSACKBlocks.
- {24, 30, 9,
- []header.SACKBlock{{12, 20}, {32, 40}, {42, 50}, {52, 60}, {62, 70}, {72, 80}},
- []header.SACKBlock{{24, 30}, {12, 20}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}},
-
- // Cases where segment extends an existing SACK block.
- {10, 12, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 20}}},
- {10, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 22}}},
- {10, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 22}}},
- {15, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{12, 22}}},
- {15, 25, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{12, 25}}},
- {11, 25, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{11, 25}}},
- {10, 12, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 20}, {32, 40}}},
- {10, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 22}, {32, 40}}},
- {10, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 22}, {32, 40}}},
- {15, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{12, 22}, {32, 40}}},
- {15, 25, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{12, 25}, {32, 40}}},
- {11, 25, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{11, 25}, {32, 40}}},
-
- // Cases where segment contains rcvNxt.
- {10, 20, 15, []header.SACKBlock{{20, 30}, {40, 50}}, []header.SACKBlock{{40, 50}}},
- }
-
- for _, tc := range testCases {
- var sack tcp.SACKInfo
- copy(sack.Blocks[:], tc.sackBlocks)
- sack.NumBlocks = len(tc.sackBlocks)
- tcp.UpdateSACKBlocks(&sack, tc.segStart, tc.segEnd, tc.rcvNxt)
- if got, want := sack.Blocks[:sack.NumBlocks], tc.updated; !reflect.DeepEqual(got, want) {
- t.Errorf("UpdateSACKBlocks(%v, %v, %v, %v), got: %v, want: %v", tc.sackBlocks, tc.segStart, tc.segEnd, tc.rcvNxt, got, want)
- }
-
- }
-}
-
-func TestTrimSackBlockList(t *testing.T) {
- testCases := []struct {
- rcvNxt seqnum.Value
- sackBlocks []header.SACKBlock
- trimmed []header.SACKBlock
- }{
- // Simple cases where we trim whole entries.
- {2, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}},
- {21, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{22, 30}, {32, 40}}},
- {31, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{32, 40}}},
- {40, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{}},
- // Cases where we need to update a block.
- {12, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{12, 20}, {22, 30}, {32, 40}}},
- {23, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{23, 30}, {32, 40}}},
- {33, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{33, 40}}},
- {41, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{}},
- }
- for _, tc := range testCases {
- var sack tcp.SACKInfo
- copy(sack.Blocks[:], tc.sackBlocks)
- sack.NumBlocks = len(tc.sackBlocks)
- tcp.TrimSACKBlockList(&sack, tc.rcvNxt)
- if got, want := sack.Blocks[:sack.NumBlocks], tc.trimmed; !reflect.DeepEqual(got, want) {
- t.Errorf("TrimSackBlockList(%v, %v), got: %v, want: %v", tc.sackBlocks, tc.rcvNxt, got, want)
- }
- }
-}
-
-func TestSACKRecovery(t *testing.T) {
- const maxPayload = 10
- // See: tcp.makeOptions for why tsOptionSize is set to 12 here.
- const tsOptionSize = 12
- // Enabling SACK means the payload size is reduced to account
- // for the extra space required for the TCP options.
- //
- // We increase the MTU by 40 bytes to account for SACK and Timestamp
- // options.
- const maxTCPOptionSize = 40
-
- c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload))
- defer c.Cleanup()
-
- c.Stack().AddTCPProbe(func(s stack.TCPEndpointState) {
- // We use log.Printf instead of t.Logf here because this probe
- // can fire even when the test function has finished. This is
- // because closing the endpoint in cleanup() does not mean the
- // actual worker loop terminates immediately as it still has to
- // do a full TCP shutdown. But this test can finish running
- // before the shutdown is done. Using t.Logf in such a case
- // causes the test to panic due to logging after test finished.
- log.Printf("state: %+v\n", s)
- })
- setStackSACKPermitted(t, c, true)
- createConnectedWithSACKAndTS(c)
-
- const iterations = 7
- data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
- for i := range data {
- data[i] = byte(i)
- }
-
- // Write all the data in one shot. Packets will only be written at the
- // MTU size though.
- if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
- }
-
- // Do slow start for a few iterations.
- expected := tcp.InitialCwnd
- bytesRead := 0
- for i := 0; i < iterations; i++ {
- expected = tcp.InitialCwnd << uint(i)
- if i > 0 {
- // Acknowledge all the data received so far if not on
- // first iteration.
- c.SendAck(790, bytesRead)
- }
-
- // Read all packets expected on this iteration. Don't
- // acknowledge any of them just yet, so that we can measure the
- // congestion window.
- for j := 0; j < expected; j++ {
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
- bytesRead += maxPayload
- }
-
- // Check we don't receive any more packets on this iteration.
- // The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
- }
-
- // Send 3 duplicate acks. This should force an immediate retransmit of
- // the pending packet and put the sender into fast recovery.
- rtxOffset := bytesRead - maxPayload*expected
- start := c.IRS.Add(seqnum.Size(rtxOffset) + 30 + 1)
- end := start.Add(10)
- for i := 0; i < 3; i++ {
- c.SendAckWithSACK(790, rtxOffset, []header.SACKBlock{{start, end}})
- end = end.Add(10)
- }
-
- // Receive the retransmitted packet.
- c.ReceiveAndCheckPacketWithOptions(data, rtxOffset, maxPayload, tsOptionSize)
-
- tcpStats := c.Stack().Stats().TCP
- stats := []struct {
- stat *tcpip.StatCounter
- name string
- want uint64
- }{
- {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1},
- {tcpStats.Retransmits, "stats.TCP.Retransmits", 1},
- {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1},
- {tcpStats.FastRecovery, "stats.TCP.FastRecovery", 0},
- }
- for _, s := range stats {
- if got, want := s.stat.Value(), s.want; got != want {
- t.Errorf("got %s.Value() = %v, want = %v", s.name, got, want)
- }
- }
-
- // Now send 7 mode duplicate ACKs. In SACK TCP dupAcks do not cause
- // window inflation and sending of packets is completely handled by the
- // SACK Recovery algorithm. We should see no packets being released, as
- // the cwnd at this point after entering recovery should be half of the
- // outstanding number of packets in flight.
- for i := 0; i < 7; i++ {
- c.SendAckWithSACK(790, rtxOffset, []header.SACKBlock{{start, end}})
- end = end.Add(10)
- }
-
- recover := bytesRead
-
- // Ensure no new packets arrive.
- c.CheckNoPacketTimeout("More packets received than expected during recovery after dupacks for this cwnd.",
- 50*time.Millisecond)
-
- // Acknowledge half of the pending data. This along with the 10 sacked
- // segments above should reduce the outstanding below the current
- // congestion window allowing the sender to transmit data.
- rtxOffset = bytesRead - expected*maxPayload/2
-
- // Now send a partial ACK w/ a SACK block that indicates that the next 3
- // segments are lost and we have received 6 segments after the lost
- // segments. This should cause the sender to immediately transmit all 3
- // segments in response to this ACK unlike in FastRecovery where only 1
- // segment is retransmitted per ACK.
- start = c.IRS.Add(seqnum.Size(rtxOffset) + 30 + 1)
- end = start.Add(60)
- c.SendAckWithSACK(790, rtxOffset, []header.SACKBlock{{start, end}})
-
- // At this point, we acked expected/2 packets and we SACKED 6 packets and
- // 3 segments were considered lost due to the SACK block we sent.
- //
- // So total packets outstanding can be calculated as follows after 7
- // iterations of slow start -> 10/20/40/80/160/320/640. So expected
- // should be 640 at start, then we went to recover at which point the
- // cwnd should be set to 320 + 3 (for the 3 dupAcks which have left the
- // network).
- // Outstanding at this point after acking half the window
- // (320 packets) will be:
- // outstanding = 640-320-6(due to SACK block)-3 = 311
- //
- // The last 3 is due to the fact that the first 3 packets after
- // rtxOffset will be considered lost due to the SACK blocks sent.
- // Receive the retransmit due to partial ack.
-
- c.ReceiveAndCheckPacketWithOptions(data, rtxOffset, maxPayload, tsOptionSize)
- // Receive the 2 extra packets that should have been retransmitted as
- // those should be considered lost and immediately retransmitted based
- // on the SACK information in the previous ACK sent above.
- for i := 0; i < 2; i++ {
- c.ReceiveAndCheckPacketWithOptions(data, rtxOffset+maxPayload*(i+1), maxPayload, tsOptionSize)
- }
-
- // Now we should get 9 more new unsent packets as the cwnd is 323 and
- // outstanding is 311.
- for i := 0; i < 9; i++ {
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
- bytesRead += maxPayload
- }
-
- // In SACK recovery only the first segment is fast retransmitted when
- // entering recovery.
- if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want {
- t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want)
- }
-
- if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.FastRetransmit.Value(), uint64(1); got != want {
- t.Errorf("got EP stats SendErrors.FastRetransmit = %v, want = %v", got, want)
- }
-
- if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(4); got != want {
- t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want)
- }
-
- if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(4); got != want {
- t.Errorf("got EP stats Stats.SendErrors.Retransmits = %v, want = %v", got, want)
- }
-
- c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", 50*time.Millisecond)
-
- // Acknowledge all pending data to recover point.
- c.SendAck(790, recover)
-
- // At this point, the cwnd should reset to expected/2 and there are 9
- // packets outstanding.
- //
- // Now in the first iteration since there are 9 packets outstanding.
- // We would expect to get expected/2 - 9 packets. But subsequent
- // iterations will send us expected/2 + 1 (per iteration).
- expected = expected/2 - 9
- for i := 0; i < iterations; i++ {
- // Read all packets expected on this iteration. Don't
- // acknowledge any of them just yet, so that we can measure the
- // congestion window.
- for j := 0; j < expected; j++ {
- c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
- bytesRead += maxPayload
- }
- // Check we don't receive any more packets on this iteration.
- // The timeout can't be too high or we'll trigger a timeout.
- c.CheckNoPacketTimeout(fmt.Sprintf("More packets received(after deflation) than expected %d for this cwnd and iteration: %d.", expected, i), 50*time.Millisecond)
-
- // Acknowledge all the data received so far.
- c.SendAck(790, bytesRead)
-
- // In cogestion avoidance, the packets trains increase by 1 in
- // each iteration.
- if i == 0 {
- // After the first iteration we expect to get the full
- // congestion window worth of packets in every
- // iteration.
- expected += 9
- }
- expected++
- }
-}
diff --git a/pkg/tcpip/transport/tcp/tcp_segment_list.go b/pkg/tcpip/transport/tcp/tcp_segment_list.go
new file mode 100755
index 000000000..029f98a11
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/tcp_segment_list.go
@@ -0,0 +1,173 @@
+package tcp
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type segmentElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (segmentElementMapper) linkerFor(elem *segment) *segment { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type segmentList struct {
+ head *segment
+ tail *segment
+}
+
+// Reset resets list l to the empty state.
+func (l *segmentList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *segmentList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *segmentList) Front() *segment {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *segmentList) Back() *segment {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *segmentList) PushFront(e *segment) {
+ segmentElementMapper{}.linkerFor(e).SetNext(l.head)
+ segmentElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ segmentElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *segmentList) PushBack(e *segment) {
+ segmentElementMapper{}.linkerFor(e).SetNext(nil)
+ segmentElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ segmentElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *segmentList) PushBackList(m *segmentList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ segmentElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ segmentElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *segmentList) InsertAfter(b, e *segment) {
+ a := segmentElementMapper{}.linkerFor(b).Next()
+ segmentElementMapper{}.linkerFor(e).SetNext(a)
+ segmentElementMapper{}.linkerFor(e).SetPrev(b)
+ segmentElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ segmentElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *segmentList) InsertBefore(a, e *segment) {
+ b := segmentElementMapper{}.linkerFor(a).Prev()
+ segmentElementMapper{}.linkerFor(e).SetNext(a)
+ segmentElementMapper{}.linkerFor(e).SetPrev(b)
+ segmentElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ segmentElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *segmentList) Remove(e *segment) {
+ prev := segmentElementMapper{}.linkerFor(e).Prev()
+ next := segmentElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ segmentElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ segmentElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type segmentEntry struct {
+ next *segment
+ prev *segment
+}
+
+// Next returns the entry that follows e in the list.
+func (e *segmentEntry) Next() *segment {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *segmentEntry) Prev() *segment {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *segmentEntry) SetNext(elem *segment) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *segmentEntry) SetPrev(elem *segment) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_state_autogen.go
new file mode 100755
index 000000000..9c1514e62
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/tcp_state_autogen.go
@@ -0,0 +1,531 @@
+// automatically generated by stateify.
+
+package tcp
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+func (x *cubicState) beforeSave() {}
+func (x *cubicState) save(m state.Map) {
+ x.beforeSave()
+ var t unixTime = x.saveT()
+ m.SaveValue("t", t)
+ m.Save("wLastMax", &x.wLastMax)
+ m.Save("wMax", &x.wMax)
+ m.Save("numCongestionEvents", &x.numCongestionEvents)
+ m.Save("c", &x.c)
+ m.Save("k", &x.k)
+ m.Save("beta", &x.beta)
+ m.Save("wC", &x.wC)
+ m.Save("wEst", &x.wEst)
+ m.Save("s", &x.s)
+}
+
+func (x *cubicState) afterLoad() {}
+func (x *cubicState) load(m state.Map) {
+ m.Load("wLastMax", &x.wLastMax)
+ m.Load("wMax", &x.wMax)
+ m.Load("numCongestionEvents", &x.numCongestionEvents)
+ m.Load("c", &x.c)
+ m.Load("k", &x.k)
+ m.Load("beta", &x.beta)
+ m.Load("wC", &x.wC)
+ m.Load("wEst", &x.wEst)
+ m.Load("s", &x.s)
+ m.LoadValue("t", new(unixTime), func(y interface{}) { x.loadT(y.(unixTime)) })
+}
+
+func (x *SACKInfo) beforeSave() {}
+func (x *SACKInfo) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Blocks", &x.Blocks)
+ m.Save("NumBlocks", &x.NumBlocks)
+}
+
+func (x *SACKInfo) afterLoad() {}
+func (x *SACKInfo) load(m state.Map) {
+ m.Load("Blocks", &x.Blocks)
+ m.Load("NumBlocks", &x.NumBlocks)
+}
+
+func (x *rcvBufAutoTuneParams) beforeSave() {}
+func (x *rcvBufAutoTuneParams) save(m state.Map) {
+ x.beforeSave()
+ var measureTime unixTime = x.saveMeasureTime()
+ m.SaveValue("measureTime", measureTime)
+ var rttMeasureTime unixTime = x.saveRttMeasureTime()
+ m.SaveValue("rttMeasureTime", rttMeasureTime)
+ m.Save("copied", &x.copied)
+ m.Save("prevCopied", &x.prevCopied)
+ m.Save("rtt", &x.rtt)
+ m.Save("rttMeasureSeqNumber", &x.rttMeasureSeqNumber)
+ m.Save("disabled", &x.disabled)
+}
+
+func (x *rcvBufAutoTuneParams) afterLoad() {}
+func (x *rcvBufAutoTuneParams) load(m state.Map) {
+ m.Load("copied", &x.copied)
+ m.Load("prevCopied", &x.prevCopied)
+ m.Load("rtt", &x.rtt)
+ m.Load("rttMeasureSeqNumber", &x.rttMeasureSeqNumber)
+ m.Load("disabled", &x.disabled)
+ m.LoadValue("measureTime", new(unixTime), func(y interface{}) { x.loadMeasureTime(y.(unixTime)) })
+ m.LoadValue("rttMeasureTime", new(unixTime), func(y interface{}) { x.loadRttMeasureTime(y.(unixTime)) })
+}
+
+func (x *EndpointInfo) beforeSave() {}
+func (x *EndpointInfo) save(m state.Map) {
+ x.beforeSave()
+ var HardError string = x.saveHardError()
+ m.SaveValue("HardError", HardError)
+ m.Save("TransportEndpointInfo", &x.TransportEndpointInfo)
+}
+
+func (x *EndpointInfo) afterLoad() {}
+func (x *EndpointInfo) load(m state.Map) {
+ m.Load("TransportEndpointInfo", &x.TransportEndpointInfo)
+ m.LoadValue("HardError", new(string), func(y interface{}) { x.loadHardError(y.(string)) })
+}
+
+func (x *endpoint) save(m state.Map) {
+ x.beforeSave()
+ var lastError string = x.saveLastError()
+ m.SaveValue("lastError", lastError)
+ var state EndpointState = x.saveState()
+ m.SaveValue("state", state)
+ var acceptedChan []*endpoint = x.saveAcceptedChan()
+ m.SaveValue("acceptedChan", acceptedChan)
+ m.Save("EndpointInfo", &x.EndpointInfo)
+ m.Save("waiterQueue", &x.waiterQueue)
+ m.Save("uniqueID", &x.uniqueID)
+ m.Save("rcvList", &x.rcvList)
+ m.Save("rcvClosed", &x.rcvClosed)
+ m.Save("rcvBufSize", &x.rcvBufSize)
+ m.Save("rcvBufUsed", &x.rcvBufUsed)
+ m.Save("rcvAutoParams", &x.rcvAutoParams)
+ m.Save("zeroWindow", &x.zeroWindow)
+ m.Save("isRegistered", &x.isRegistered)
+ m.Save("ttl", &x.ttl)
+ m.Save("v6only", &x.v6only)
+ m.Save("isConnectNotified", &x.isConnectNotified)
+ m.Save("broadcast", &x.broadcast)
+ m.Save("boundBindToDevice", &x.boundBindToDevice)
+ m.Save("boundPortFlags", &x.boundPortFlags)
+ m.Save("workerRunning", &x.workerRunning)
+ m.Save("workerCleanup", &x.workerCleanup)
+ m.Save("sendTSOk", &x.sendTSOk)
+ m.Save("recentTS", &x.recentTS)
+ m.Save("tsOffset", &x.tsOffset)
+ m.Save("shutdownFlags", &x.shutdownFlags)
+ m.Save("sackPermitted", &x.sackPermitted)
+ m.Save("sack", &x.sack)
+ m.Save("reusePort", &x.reusePort)
+ m.Save("bindToDevice", &x.bindToDevice)
+ m.Save("delay", &x.delay)
+ m.Save("cork", &x.cork)
+ m.Save("scoreboard", &x.scoreboard)
+ m.Save("reuseAddr", &x.reuseAddr)
+ m.Save("slowAck", &x.slowAck)
+ m.Save("segmentQueue", &x.segmentQueue)
+ m.Save("synRcvdCount", &x.synRcvdCount)
+ m.Save("userMSS", &x.userMSS)
+ m.Save("sndBufSize", &x.sndBufSize)
+ m.Save("sndBufUsed", &x.sndBufUsed)
+ m.Save("sndClosed", &x.sndClosed)
+ m.Save("sndBufInQueue", &x.sndBufInQueue)
+ m.Save("sndQueue", &x.sndQueue)
+ m.Save("cc", &x.cc)
+ m.Save("packetTooBigCount", &x.packetTooBigCount)
+ m.Save("sndMTU", &x.sndMTU)
+ m.Save("keepalive", &x.keepalive)
+ m.Save("userTimeout", &x.userTimeout)
+ m.Save("deferAccept", &x.deferAccept)
+ m.Save("rcv", &x.rcv)
+ m.Save("snd", &x.snd)
+ m.Save("connectingAddress", &x.connectingAddress)
+ m.Save("amss", &x.amss)
+ m.Save("sendTOS", &x.sendTOS)
+ m.Save("gso", &x.gso)
+ m.Save("tcpLingerTimeout", &x.tcpLingerTimeout)
+ m.Save("closed", &x.closed)
+}
+
+func (x *endpoint) load(m state.Map) {
+ m.Load("EndpointInfo", &x.EndpointInfo)
+ m.LoadWait("waiterQueue", &x.waiterQueue)
+ m.Load("uniqueID", &x.uniqueID)
+ m.LoadWait("rcvList", &x.rcvList)
+ m.Load("rcvClosed", &x.rcvClosed)
+ m.Load("rcvBufSize", &x.rcvBufSize)
+ m.Load("rcvBufUsed", &x.rcvBufUsed)
+ m.Load("rcvAutoParams", &x.rcvAutoParams)
+ m.Load("zeroWindow", &x.zeroWindow)
+ m.Load("isRegistered", &x.isRegistered)
+ m.Load("ttl", &x.ttl)
+ m.Load("v6only", &x.v6only)
+ m.Load("isConnectNotified", &x.isConnectNotified)
+ m.Load("broadcast", &x.broadcast)
+ m.Load("boundBindToDevice", &x.boundBindToDevice)
+ m.Load("boundPortFlags", &x.boundPortFlags)
+ m.Load("workerRunning", &x.workerRunning)
+ m.Load("workerCleanup", &x.workerCleanup)
+ m.Load("sendTSOk", &x.sendTSOk)
+ m.Load("recentTS", &x.recentTS)
+ m.Load("tsOffset", &x.tsOffset)
+ m.Load("shutdownFlags", &x.shutdownFlags)
+ m.Load("sackPermitted", &x.sackPermitted)
+ m.Load("sack", &x.sack)
+ m.Load("reusePort", &x.reusePort)
+ m.Load("bindToDevice", &x.bindToDevice)
+ m.Load("delay", &x.delay)
+ m.Load("cork", &x.cork)
+ m.Load("scoreboard", &x.scoreboard)
+ m.Load("reuseAddr", &x.reuseAddr)
+ m.Load("slowAck", &x.slowAck)
+ m.LoadWait("segmentQueue", &x.segmentQueue)
+ m.Load("synRcvdCount", &x.synRcvdCount)
+ m.Load("userMSS", &x.userMSS)
+ m.Load("sndBufSize", &x.sndBufSize)
+ m.Load("sndBufUsed", &x.sndBufUsed)
+ m.Load("sndClosed", &x.sndClosed)
+ m.Load("sndBufInQueue", &x.sndBufInQueue)
+ m.LoadWait("sndQueue", &x.sndQueue)
+ m.Load("cc", &x.cc)
+ m.Load("packetTooBigCount", &x.packetTooBigCount)
+ m.Load("sndMTU", &x.sndMTU)
+ m.Load("keepalive", &x.keepalive)
+ m.Load("userTimeout", &x.userTimeout)
+ m.Load("deferAccept", &x.deferAccept)
+ m.LoadWait("rcv", &x.rcv)
+ m.LoadWait("snd", &x.snd)
+ m.Load("connectingAddress", &x.connectingAddress)
+ m.Load("amss", &x.amss)
+ m.Load("sendTOS", &x.sendTOS)
+ m.Load("gso", &x.gso)
+ m.Load("tcpLingerTimeout", &x.tcpLingerTimeout)
+ m.Load("closed", &x.closed)
+ m.LoadValue("lastError", new(string), func(y interface{}) { x.loadLastError(y.(string)) })
+ m.LoadValue("state", new(EndpointState), func(y interface{}) { x.loadState(y.(EndpointState)) })
+ m.LoadValue("acceptedChan", new([]*endpoint), func(y interface{}) { x.loadAcceptedChan(y.([]*endpoint)) })
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *keepalive) beforeSave() {}
+func (x *keepalive) save(m state.Map) {
+ x.beforeSave()
+ m.Save("enabled", &x.enabled)
+ m.Save("idle", &x.idle)
+ m.Save("interval", &x.interval)
+ m.Save("count", &x.count)
+ m.Save("unacked", &x.unacked)
+}
+
+func (x *keepalive) afterLoad() {}
+func (x *keepalive) load(m state.Map) {
+ m.Load("enabled", &x.enabled)
+ m.Load("idle", &x.idle)
+ m.Load("interval", &x.interval)
+ m.Load("count", &x.count)
+ m.Load("unacked", &x.unacked)
+}
+
+func (x *receiver) beforeSave() {}
+func (x *receiver) save(m state.Map) {
+ x.beforeSave()
+ var lastRcvdAckTime unixTime = x.saveLastRcvdAckTime()
+ m.SaveValue("lastRcvdAckTime", lastRcvdAckTime)
+ m.Save("ep", &x.ep)
+ m.Save("rcvNxt", &x.rcvNxt)
+ m.Save("rcvAcc", &x.rcvAcc)
+ m.Save("rcvWnd", &x.rcvWnd)
+ m.Save("rcvWndScale", &x.rcvWndScale)
+ m.Save("closed", &x.closed)
+ m.Save("pendingRcvdSegments", &x.pendingRcvdSegments)
+ m.Save("pendingBufUsed", &x.pendingBufUsed)
+ m.Save("pendingBufSize", &x.pendingBufSize)
+}
+
+func (x *receiver) afterLoad() {}
+func (x *receiver) load(m state.Map) {
+ m.Load("ep", &x.ep)
+ m.Load("rcvNxt", &x.rcvNxt)
+ m.Load("rcvAcc", &x.rcvAcc)
+ m.Load("rcvWnd", &x.rcvWnd)
+ m.Load("rcvWndScale", &x.rcvWndScale)
+ m.Load("closed", &x.closed)
+ m.Load("pendingRcvdSegments", &x.pendingRcvdSegments)
+ m.Load("pendingBufUsed", &x.pendingBufUsed)
+ m.Load("pendingBufSize", &x.pendingBufSize)
+ m.LoadValue("lastRcvdAckTime", new(unixTime), func(y interface{}) { x.loadLastRcvdAckTime(y.(unixTime)) })
+}
+
+func (x *renoState) beforeSave() {}
+func (x *renoState) save(m state.Map) {
+ x.beforeSave()
+ m.Save("s", &x.s)
+}
+
+func (x *renoState) afterLoad() {}
+func (x *renoState) load(m state.Map) {
+ m.Load("s", &x.s)
+}
+
+func (x *SACKScoreboard) beforeSave() {}
+func (x *SACKScoreboard) save(m state.Map) {
+ x.beforeSave()
+ m.Save("smss", &x.smss)
+ m.Save("maxSACKED", &x.maxSACKED)
+}
+
+func (x *SACKScoreboard) afterLoad() {}
+func (x *SACKScoreboard) load(m state.Map) {
+ m.Load("smss", &x.smss)
+ m.Load("maxSACKED", &x.maxSACKED)
+}
+
+func (x *segment) beforeSave() {}
+func (x *segment) save(m state.Map) {
+ x.beforeSave()
+ var data buffer.VectorisedView = x.saveData()
+ m.SaveValue("data", data)
+ var options []byte = x.saveOptions()
+ m.SaveValue("options", options)
+ var rcvdTime unixTime = x.saveRcvdTime()
+ m.SaveValue("rcvdTime", rcvdTime)
+ var xmitTime unixTime = x.saveXmitTime()
+ m.SaveValue("xmitTime", xmitTime)
+ m.Save("segmentEntry", &x.segmentEntry)
+ m.Save("refCnt", &x.refCnt)
+ m.Save("viewToDeliver", &x.viewToDeliver)
+ m.Save("sequenceNumber", &x.sequenceNumber)
+ m.Save("ackNumber", &x.ackNumber)
+ m.Save("flags", &x.flags)
+ m.Save("window", &x.window)
+ m.Save("csum", &x.csum)
+ m.Save("csumValid", &x.csumValid)
+ m.Save("parsedOptions", &x.parsedOptions)
+ m.Save("hasNewSACKInfo", &x.hasNewSACKInfo)
+}
+
+func (x *segment) afterLoad() {}
+func (x *segment) load(m state.Map) {
+ m.Load("segmentEntry", &x.segmentEntry)
+ m.Load("refCnt", &x.refCnt)
+ m.Load("viewToDeliver", &x.viewToDeliver)
+ m.Load("sequenceNumber", &x.sequenceNumber)
+ m.Load("ackNumber", &x.ackNumber)
+ m.Load("flags", &x.flags)
+ m.Load("window", &x.window)
+ m.Load("csum", &x.csum)
+ m.Load("csumValid", &x.csumValid)
+ m.Load("parsedOptions", &x.parsedOptions)
+ m.Load("hasNewSACKInfo", &x.hasNewSACKInfo)
+ m.LoadValue("data", new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) })
+ m.LoadValue("options", new([]byte), func(y interface{}) { x.loadOptions(y.([]byte)) })
+ m.LoadValue("rcvdTime", new(unixTime), func(y interface{}) { x.loadRcvdTime(y.(unixTime)) })
+ m.LoadValue("xmitTime", new(unixTime), func(y interface{}) { x.loadXmitTime(y.(unixTime)) })
+}
+
+func (x *segmentQueue) beforeSave() {}
+func (x *segmentQueue) save(m state.Map) {
+ x.beforeSave()
+ m.Save("list", &x.list)
+ m.Save("limit", &x.limit)
+ m.Save("used", &x.used)
+}
+
+func (x *segmentQueue) afterLoad() {}
+func (x *segmentQueue) load(m state.Map) {
+ m.LoadWait("list", &x.list)
+ m.Load("limit", &x.limit)
+ m.Load("used", &x.used)
+}
+
+func (x *sender) beforeSave() {}
+func (x *sender) save(m state.Map) {
+ x.beforeSave()
+ var lastSendTime unixTime = x.saveLastSendTime()
+ m.SaveValue("lastSendTime", lastSendTime)
+ var rttMeasureTime unixTime = x.saveRttMeasureTime()
+ m.SaveValue("rttMeasureTime", rttMeasureTime)
+ var firstRetransmittedSegXmitTime unixTime = x.saveFirstRetransmittedSegXmitTime()
+ m.SaveValue("firstRetransmittedSegXmitTime", firstRetransmittedSegXmitTime)
+ m.Save("ep", &x.ep)
+ m.Save("dupAckCount", &x.dupAckCount)
+ m.Save("fr", &x.fr)
+ m.Save("sndCwnd", &x.sndCwnd)
+ m.Save("sndSsthresh", &x.sndSsthresh)
+ m.Save("sndCAAckCount", &x.sndCAAckCount)
+ m.Save("outstanding", &x.outstanding)
+ m.Save("sndWnd", &x.sndWnd)
+ m.Save("sndUna", &x.sndUna)
+ m.Save("sndNxt", &x.sndNxt)
+ m.Save("sndNxtList", &x.sndNxtList)
+ m.Save("rttMeasureSeqNum", &x.rttMeasureSeqNum)
+ m.Save("closed", &x.closed)
+ m.Save("writeNext", &x.writeNext)
+ m.Save("writeList", &x.writeList)
+ m.Save("rtt", &x.rtt)
+ m.Save("rto", &x.rto)
+ m.Save("maxPayloadSize", &x.maxPayloadSize)
+ m.Save("gso", &x.gso)
+ m.Save("sndWndScale", &x.sndWndScale)
+ m.Save("maxSentAck", &x.maxSentAck)
+ m.Save("state", &x.state)
+ m.Save("cc", &x.cc)
+}
+
+func (x *sender) load(m state.Map) {
+ m.Load("ep", &x.ep)
+ m.Load("dupAckCount", &x.dupAckCount)
+ m.Load("fr", &x.fr)
+ m.Load("sndCwnd", &x.sndCwnd)
+ m.Load("sndSsthresh", &x.sndSsthresh)
+ m.Load("sndCAAckCount", &x.sndCAAckCount)
+ m.Load("outstanding", &x.outstanding)
+ m.Load("sndWnd", &x.sndWnd)
+ m.Load("sndUna", &x.sndUna)
+ m.Load("sndNxt", &x.sndNxt)
+ m.Load("sndNxtList", &x.sndNxtList)
+ m.Load("rttMeasureSeqNum", &x.rttMeasureSeqNum)
+ m.Load("closed", &x.closed)
+ m.Load("writeNext", &x.writeNext)
+ m.Load("writeList", &x.writeList)
+ m.Load("rtt", &x.rtt)
+ m.Load("rto", &x.rto)
+ m.Load("maxPayloadSize", &x.maxPayloadSize)
+ m.Load("gso", &x.gso)
+ m.Load("sndWndScale", &x.sndWndScale)
+ m.Load("maxSentAck", &x.maxSentAck)
+ m.Load("state", &x.state)
+ m.Load("cc", &x.cc)
+ m.LoadValue("lastSendTime", new(unixTime), func(y interface{}) { x.loadLastSendTime(y.(unixTime)) })
+ m.LoadValue("rttMeasureTime", new(unixTime), func(y interface{}) { x.loadRttMeasureTime(y.(unixTime)) })
+ m.LoadValue("firstRetransmittedSegXmitTime", new(unixTime), func(y interface{}) { x.loadFirstRetransmittedSegXmitTime(y.(unixTime)) })
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *rtt) beforeSave() {}
+func (x *rtt) save(m state.Map) {
+ x.beforeSave()
+ m.Save("srtt", &x.srtt)
+ m.Save("rttvar", &x.rttvar)
+ m.Save("srttInited", &x.srttInited)
+}
+
+func (x *rtt) afterLoad() {}
+func (x *rtt) load(m state.Map) {
+ m.Load("srtt", &x.srtt)
+ m.Load("rttvar", &x.rttvar)
+ m.Load("srttInited", &x.srttInited)
+}
+
+func (x *fastRecovery) beforeSave() {}
+func (x *fastRecovery) save(m state.Map) {
+ x.beforeSave()
+ m.Save("active", &x.active)
+ m.Save("first", &x.first)
+ m.Save("last", &x.last)
+ m.Save("maxCwnd", &x.maxCwnd)
+ m.Save("highRxt", &x.highRxt)
+ m.Save("rescueRxt", &x.rescueRxt)
+}
+
+func (x *fastRecovery) afterLoad() {}
+func (x *fastRecovery) load(m state.Map) {
+ m.Load("active", &x.active)
+ m.Load("first", &x.first)
+ m.Load("last", &x.last)
+ m.Load("maxCwnd", &x.maxCwnd)
+ m.Load("highRxt", &x.highRxt)
+ m.Load("rescueRxt", &x.rescueRxt)
+}
+
+func (x *unixTime) beforeSave() {}
+func (x *unixTime) save(m state.Map) {
+ x.beforeSave()
+ m.Save("second", &x.second)
+ m.Save("nano", &x.nano)
+}
+
+func (x *unixTime) afterLoad() {}
+func (x *unixTime) load(m state.Map) {
+ m.Load("second", &x.second)
+ m.Load("nano", &x.nano)
+}
+
+func (x *endpointList) beforeSave() {}
+func (x *endpointList) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *endpointList) afterLoad() {}
+func (x *endpointList) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *endpointEntry) beforeSave() {}
+func (x *endpointEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *endpointEntry) afterLoad() {}
+func (x *endpointEntry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
+func (x *segmentList) beforeSave() {}
+func (x *segmentList) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *segmentList) afterLoad() {}
+func (x *segmentList) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *segmentEntry) beforeSave() {}
+func (x *segmentEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *segmentEntry) afterLoad() {}
+func (x *segmentEntry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
+func init() {
+ state.Register("pkg/tcpip/transport/tcp.cubicState", (*cubicState)(nil), state.Fns{Save: (*cubicState).save, Load: (*cubicState).load})
+ state.Register("pkg/tcpip/transport/tcp.SACKInfo", (*SACKInfo)(nil), state.Fns{Save: (*SACKInfo).save, Load: (*SACKInfo).load})
+ state.Register("pkg/tcpip/transport/tcp.rcvBufAutoTuneParams", (*rcvBufAutoTuneParams)(nil), state.Fns{Save: (*rcvBufAutoTuneParams).save, Load: (*rcvBufAutoTuneParams).load})
+ state.Register("pkg/tcpip/transport/tcp.EndpointInfo", (*EndpointInfo)(nil), state.Fns{Save: (*EndpointInfo).save, Load: (*EndpointInfo).load})
+ state.Register("pkg/tcpip/transport/tcp.endpoint", (*endpoint)(nil), state.Fns{Save: (*endpoint).save, Load: (*endpoint).load})
+ state.Register("pkg/tcpip/transport/tcp.keepalive", (*keepalive)(nil), state.Fns{Save: (*keepalive).save, Load: (*keepalive).load})
+ state.Register("pkg/tcpip/transport/tcp.receiver", (*receiver)(nil), state.Fns{Save: (*receiver).save, Load: (*receiver).load})
+ state.Register("pkg/tcpip/transport/tcp.renoState", (*renoState)(nil), state.Fns{Save: (*renoState).save, Load: (*renoState).load})
+ state.Register("pkg/tcpip/transport/tcp.SACKScoreboard", (*SACKScoreboard)(nil), state.Fns{Save: (*SACKScoreboard).save, Load: (*SACKScoreboard).load})
+ state.Register("pkg/tcpip/transport/tcp.segment", (*segment)(nil), state.Fns{Save: (*segment).save, Load: (*segment).load})
+ state.Register("pkg/tcpip/transport/tcp.segmentQueue", (*segmentQueue)(nil), state.Fns{Save: (*segmentQueue).save, Load: (*segmentQueue).load})
+ state.Register("pkg/tcpip/transport/tcp.sender", (*sender)(nil), state.Fns{Save: (*sender).save, Load: (*sender).load})
+ state.Register("pkg/tcpip/transport/tcp.rtt", (*rtt)(nil), state.Fns{Save: (*rtt).save, Load: (*rtt).load})
+ state.Register("pkg/tcpip/transport/tcp.fastRecovery", (*fastRecovery)(nil), state.Fns{Save: (*fastRecovery).save, Load: (*fastRecovery).load})
+ state.Register("pkg/tcpip/transport/tcp.unixTime", (*unixTime)(nil), state.Fns{Save: (*unixTime).save, Load: (*unixTime).load})
+ state.Register("pkg/tcpip/transport/tcp.endpointList", (*endpointList)(nil), state.Fns{Save: (*endpointList).save, Load: (*endpointList).load})
+ state.Register("pkg/tcpip/transport/tcp.endpointEntry", (*endpointEntry)(nil), state.Fns{Save: (*endpointEntry).save, Load: (*endpointEntry).load})
+ state.Register("pkg/tcpip/transport/tcp.segmentList", (*segmentList)(nil), state.Fns{Save: (*segmentList).save, Load: (*segmentList).load})
+ state.Register("pkg/tcpip/transport/tcp.segmentEntry", (*segmentEntry)(nil), state.Fns{Save: (*segmentEntry).save, Load: (*segmentEntry).load})
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
deleted file mode 100644
index cc118c993..000000000
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ /dev/null
@@ -1,6969 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tcp_test
-
-import (
- "bytes"
- "fmt"
- "math"
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
- "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/ports"
- "gvisor.dev/gvisor/pkg/tcpip/seqnum"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-const (
- // defaultMTU is the MTU, in bytes, used throughout the tests, except
- // where another value is explicitly used. It is chosen to match the MTU
- // of loopback interfaces on linux systems.
- defaultMTU = 65535
-
- // defaultIPv4MSS is the MSS sent by the network stack in SYN/SYN-ACK for an
- // IPv4 endpoint when the MTU is set to defaultMTU in the test.
- defaultIPv4MSS = defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize
-)
-
-func TestGiveUpConnect(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- var wq waiter.Queue
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- // Register for notification, then start connection attempt.
- waitEntry, notifyCh := waiter.NewChannelEntry(nil)
- wq.EventRegister(&waitEntry, waiter.EventOut)
- defer wq.EventUnregister(&waitEntry)
-
- if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted)
- }
-
- // Close the connection, wait for completion.
- ep.Close()
-
- // Wait for ep to become writable.
- <-notifyCh
- if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != tcpip.ErrAborted {
- t.Fatalf("got ep.GetSockOpt(tcpip.ErrorOption{}) = %v, want = %v", err, tcpip.ErrAborted)
- }
-
- // Call Connect again to retreive the handshake failure status
- // and stats updates.
- if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAborted {
- t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrAborted)
- }
-
- if got := c.Stack().Stats().TCP.FailedConnectionAttempts.Value(); got != 1 {
- t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = 1", got)
- }
-
- if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
- }
-}
-
-func TestConnectIncrementActiveConnection(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- stats := c.Stack().Stats()
- want := stats.TCP.ActiveConnectionOpenings.Value() + 1
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want {
- t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %v, want = %v", got, want)
- }
-}
-
-func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- stats := c.Stack().Stats()
- want := stats.TCP.FailedConnectionAttempts.Value()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
- t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want)
- }
- if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want {
- t.Errorf("got EP stats.FailedConnectionAttempts = %v, want = %v", got, want)
- }
-}
-
-func TestActiveFailedConnectionAttemptIncrement(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- stats := c.Stack().Stats()
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
- c.EP = ep
- want := stats.TCP.FailedConnectionAttempts.Value() + 1
-
- if err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrNoRoute {
- t.Errorf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrNoRoute)
- }
-
- if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
- t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want)
- }
- if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want {
- t.Errorf("got EP stats FailedConnectionAttempts = %v, want = %v", got, want)
- }
-}
-
-func TestTCPSegmentsSentIncrement(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- stats := c.Stack().Stats()
- // SYN and ACK
- want := stats.TCP.SegmentsSent.Value() + 2
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- if got := stats.TCP.SegmentsSent.Value(); got != want {
- t.Errorf("got stats.TCP.SegmentsSent.Value() = %v, want = %v", got, want)
- }
- if got := c.EP.Stats().(*tcp.Stats).SegmentsSent.Value(); got != want {
- t.Errorf("got EP stats SegmentsSent.Value() = %v, want = %v", got, want)
- }
-}
-
-func TestTCPResetsSentIncrement(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
- stats := c.Stack().Stats()
- wq := &waiter.Queue{}
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
- want := stats.TCP.SegmentsSent.Value() + 1
-
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
- }
-
- // Send a SYN request.
- iss := seqnum.Value(789)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: iss,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetPacket()
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
-
- ackHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 1,
- // If the AckNum is not the increment of the last sequence number, a RST
- // segment is sent back in response.
- AckNum: c.IRS + 2,
- }
-
- // Send ACK.
- c.SendPacket(nil, ackHeaders)
-
- c.GetPacket()
- if got := stats.TCP.ResetsSent.Value(); got != want {
- t.Errorf("got stats.TCP.ResetsSent.Value() = %v, want = %v", got, want)
- }
-}
-
-// TestTCPResetSentForACKWhenNotUsingSynCookies checks that the stack generates
-// a RST if an ACK is received on the listening socket for which there is no
-// active handshake in progress and we are not using SYN cookies.
-func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Set TCPLingerTimeout to 5 seconds so that sockets are marked closed
- wq := &waiter.Queue{}
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Send a SYN request.
- iss := seqnum.Value(789)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: iss,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetPacket()
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
-
- ackHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 1,
- AckNum: c.IRS + 1,
- }
-
- // Send ACK.
- c.SendPacket(nil, ackHeaders)
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.EventIn)
- defer wq.EventUnregister(&we)
-
- c.EP, _, err = ep.Accept()
- if err == tcpip.ErrWouldBlock {
- // Wait for connection to be established.
- select {
- case <-ch:
- c.EP, _, err = ep.Accept()
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- // Lower stackwide TIME_WAIT timeout so that the reservations
- // are released instantly on Close.
- tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpTW); err != nil {
- t.Fatalf("e.stack.SetTransportProtocolOption(%d, %s) = %s", tcp.ProtocolNumber, tcpTW, err)
- }
-
- c.EP.Close()
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(iss)+1),
- checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
- finHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: iss + 1,
- AckNum: c.IRS + 2,
- }
-
- c.SendPacket(nil, finHeaders)
-
- // Get the ACK to the FIN we just sent.
- c.GetPacket()
-
- // Since an active close was done we need to wait for a little more than
- // tcpLingerTimeout for the port reservations to be released and the
- // socket to move to a CLOSED state.
- time.Sleep(20 * time.Millisecond)
-
- // Now resend the same ACK, this ACK should generate a RST as there
- // should be no endpoint in SYN-RCVD state and we are not using
- // syn-cookies yet. The reason we send the same ACK is we need a valid
- // cookie(IRS) generated by the netstack without which the ACK will be
- // rejected.
- c.SendPacket(nil, ackHeaders)
-
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(0),
- checker.TCPFlags(header.TCPFlagRst)))
-}
-
-func TestTCPResetsReceivedIncrement(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- stats := c.Stack().Stats()
- want := stats.TCP.ResetsReceived.Value() + 1
- iss := seqnum.Value(789)
- rcvWnd := seqnum.Size(30000)
- c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */)
-
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- SeqNum: iss.Add(1),
- AckNum: c.IRS.Add(1),
- RcvWnd: rcvWnd,
- Flags: header.TCPFlagRst,
- })
-
- if got := stats.TCP.ResetsReceived.Value(); got != want {
- t.Errorf("got stats.TCP.ResetsReceived.Value() = %v, want = %v", got, want)
- }
-}
-
-func TestTCPResetsDoNotGenerateResets(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- stats := c.Stack().Stats()
- want := stats.TCP.ResetsReceived.Value() + 1
- iss := seqnum.Value(789)
- rcvWnd := seqnum.Size(30000)
- c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */)
-
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- SeqNum: iss.Add(1),
- AckNum: c.IRS.Add(1),
- RcvWnd: rcvWnd,
- Flags: header.TCPFlagRst,
- })
-
- if got := stats.TCP.ResetsReceived.Value(); got != want {
- t.Errorf("got stats.TCP.ResetsReceived.Value() = %v, want = %v", got, want)
- }
- c.CheckNoPacketTimeout("got an unexpected packet", 100*time.Millisecond)
-}
-
-func TestActiveHandshake(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-}
-
-func TestNonBlockingClose(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- ep := c.EP
- c.EP = nil
-
- // Close the endpoint and measure how long it takes.
- t0 := time.Now()
- ep.Close()
- if diff := time.Now().Sub(t0); diff > 3*time.Second {
- t.Fatalf("Took too long to close: %v", diff)
- }
-}
-
-func TestConnectResetAfterClose(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Set TCPLinger to 3 seconds so that sockets are marked closed
- // after 3 second in FIN_WAIT2 state.
- tcpLingerTimeout := 3 * time.Second
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPLingerTimeoutOption(tcpLingerTimeout)); err != nil {
- t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpLingerTimeout, err)
- }
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- ep := c.EP
- c.EP = nil
-
- // Close the endpoint, make sure we get a FIN segment, then acknowledge
- // to complete closure of sender, but don't send our own FIN.
- ep.Close()
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- ),
- )
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: 790,
- AckNum: c.IRS.Add(2),
- RcvWnd: 30000,
- })
-
- // Wait for the ep to give up waiting for a FIN.
- time.Sleep(tcpLingerTimeout + 1*time.Second)
-
- // Now send an ACK and it should trigger a RST as the endpoint should
- // not exist anymore.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: 790,
- AckNum: c.IRS.Add(2),
- RcvWnd: 30000,
- })
-
- for {
- b := c.GetPacket()
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- if tcpHdr.Flags() == header.TCPFlagAck|header.TCPFlagFin {
- // This is a retransmit of the FIN, ignore it.
- continue
- }
-
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- // RST is always generated with sndNxt which if the FIN
- // has been sent will be 1 higher than the sequence number
- // of the FIN itself.
- checker.SeqNum(uint32(c.IRS)+2),
- checker.AckNum(0),
- checker.TCPFlags(header.TCPFlagRst),
- ),
- )
- break
- }
-}
-
-// TestCurrentConnectedIncrement tests increment of the current
-// established and connected counters.
-func TestCurrentConnectedIncrement(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed
- // after 1 second in TIME_WAIT state.
- tcpTimeWaitTimeout := 1 * time.Second
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
- t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err)
- }
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- ep := c.EP
- c.EP = nil
-
- if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 1 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 1", got)
- }
- gotConnected := c.Stack().Stats().TCP.CurrentConnected.Value()
- if gotConnected != 1 {
- t.Errorf("got stats.TCP.CurrentConnected.Value() = %v, want = 1", gotConnected)
- }
-
- ep.Close()
-
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- ),
- )
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: 790,
- AckNum: c.IRS.Add(2),
- RcvWnd: 30000,
- })
-
- if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
- }
- if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != gotConnected {
- t.Errorf("got stats.TCP.CurrentConnected.Value() = %v, want = %v", got, gotConnected)
- }
-
- // Ack and send FIN as well.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: 790,
- AckNum: c.IRS.Add(2),
- RcvWnd: 30000,
- })
-
- // Check that the stack acks the FIN.
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+2),
- checker.AckNum(791),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-
- // Wait for the TIME-WAIT state to transition to CLOSED.
- time.Sleep(1 * time.Second)
-
- if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
- }
- if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentConnected.Value() = %v, want = 0", got)
- }
-}
-
-// TestClosingWithEnqueuedSegments tests handling of still enqueued segments
-// when the endpoint transitions to StateClose. The in-flight segments would be
-// re-enqueued to a any listening endpoint.
-func TestClosingWithEnqueuedSegments(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- ep := c.EP
- c.EP = nil
-
- if got, want := tcp.EndpointState(ep.State()), tcp.StateEstablished; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- // Send a FIN for ESTABLISHED --> CLOSED-WAIT
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagFin | header.TCPFlagAck,
- SeqNum: 790,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Get the ACK for the FIN we sent.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(791),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-
- if got, want := tcp.EndpointState(ep.State()), tcp.StateCloseWait; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- // Close the application endpoint for CLOSE_WAIT --> LAST_ACK
- ep.Close()
-
- // Get the FIN
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(791),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- ),
- )
-
- if got, want := tcp.EndpointState(ep.State()), tcp.StateLastAck; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- // Pause the endpoint`s protocolMainLoop.
- ep.(interface{ StopWork() }).StopWork()
-
- // Enqueue last ACK followed by an ACK matching the endpoint
- //
- // Send Last ACK for LAST_ACK --> CLOSED
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: 791,
- AckNum: c.IRS.Add(2),
- RcvWnd: 30000,
- })
-
- // Send a packet with ACK set, this would generate RST when
- // not using SYN cookies as in this test.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: 792,
- AckNum: c.IRS.Add(2),
- RcvWnd: 30000,
- })
-
- // Unpause endpoint`s protocolMainLoop.
- ep.(interface{ ResumeWork() }).ResumeWork()
-
- // Wait for the protocolMainLoop to resume and update state.
- time.Sleep(10 * time.Millisecond)
-
- // Expect the endpoint to be closed.
- if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != 1 {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = 1", got)
- }
-
- if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
- }
-
- // Check if the endpoint was moved to CLOSED and netstack a reset in
- // response to the ACK packet that we sent after last-ACK.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+2),
- checker.AckNum(0),
- checker.TCPFlags(header.TCPFlagRst),
- ),
- )
-}
-
-func TestSimpleReceive(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.EventIn)
- defer c.WQ.EventUnregister(&we)
-
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
- }
-
- data := []byte{1, 2, 3}
- c.SendPacket(data, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: 790,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Wait for receive to be notified.
- select {
- case <-ch:
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for data to arrive")
- }
-
- // Receive data.
- v, _, err := c.EP.Read(nil)
- if err != nil {
- t.Fatalf("Read failed: %v", err)
- }
-
- if !bytes.Equal(data, v) {
- t.Fatalf("got data = %v, want = %v", v, data)
- }
-
- // Check that ACK is received.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-}
-
-// TestUserSuppliedMSSOnConnectV4 tests that the user supplied MSS is used when
-// creating a new active IPv4 TCP socket. It should be present in the sent TCP
-// SYN segment.
-func TestUserSuppliedMSSOnConnectV4(t *testing.T) {
- const mtu = 5000
- const maxMSS = mtu - header.IPv4MinimumSize - header.TCPMinimumSize
- tests := []struct {
- name string
- setMSS uint16
- expMSS uint16
- }{
- {
- "EqualToMaxMSS",
- maxMSS,
- maxMSS,
- },
- {
- "LessThanMTU",
- maxMSS - 1,
- maxMSS - 1,
- },
- {
- "GreaterThanMTU",
- maxMSS + 1,
- maxMSS,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- c := context.New(t, mtu)
- defer c.Cleanup()
-
- c.Create(-1)
-
- // Set the MSS socket option.
- opt := tcpip.MaxSegOption(test.setMSS)
- if err := c.EP.SetSockOpt(opt); err != nil {
- t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err)
- }
-
- // Get expected window size.
- rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
- if err != nil {
- t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err)
- }
- ws := tcp.FindWndScale(seqnum.Size(rcvBufSize))
-
- // Start connection attempt to IPv4 address.
- if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("Unexpected return value from Connect: %v", err)
- }
-
- // Receive SYN packet with our user supplied MSS.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws})))
- })
- }
-}
-
-// TestUserSuppliedMSSOnConnectV6 tests that the user supplied MSS is used when
-// creating a new active IPv6 TCP socket. It should be present in the sent TCP
-// SYN segment.
-func TestUserSuppliedMSSOnConnectV6(t *testing.T) {
- const mtu = 5000
- const maxMSS = mtu - header.IPv6MinimumSize - header.TCPMinimumSize
- tests := []struct {
- name string
- setMSS uint16
- expMSS uint16
- }{
- {
- "EqualToMaxMSS",
- maxMSS,
- maxMSS,
- },
- {
- "LessThanMTU",
- maxMSS - 1,
- maxMSS - 1,
- },
- {
- "GreaterThanMTU",
- maxMSS + 1,
- maxMSS,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- c := context.New(t, mtu)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(true)
-
- // Set the MSS socket option.
- opt := tcpip.MaxSegOption(test.setMSS)
- if err := c.EP.SetSockOpt(opt); err != nil {
- t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err)
- }
-
- // Get expected window size.
- rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
- if err != nil {
- t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err)
- }
- ws := tcp.FindWndScale(seqnum.Size(rcvBufSize))
-
- // Start connection attempt to IPv6 address.
- if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("Unexpected return value from Connect: %v", err)
- }
-
- // Receive SYN packet with our user supplied MSS.
- checker.IPv6(t, c.GetV6Packet(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws})))
- })
- }
-}
-
-func TestSendRstOnListenerRxSynAckV4(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.Create(-1)
-
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatal("Bind failed:", err)
- }
-
- if err := c.EP.Listen(10); err != nil {
- t.Fatal("Listen failed:", err)
- }
-
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn | header.TCPFlagAck,
- SeqNum: 100,
- AckNum: 200,
- })
-
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagRst),
- checker.SeqNum(200)))
-}
-
-func TestSendRstOnListenerRxSynAckV6(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(true)
-
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatal("Bind failed:", err)
- }
-
- if err := c.EP.Listen(10); err != nil {
- t.Fatal("Listen failed:", err)
- }
-
- c.SendV6Packet(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn | header.TCPFlagAck,
- SeqNum: 100,
- AckNum: 200,
- })
-
- checker.IPv6(t, c.GetV6Packet(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagRst),
- checker.SeqNum(200)))
-}
-
-// TestTCPAckBeforeAcceptV4 tests that once the 3-way handshake is complete,
-// peers can send data and expect a response within a reasonable ammount of time
-// without calling Accept on the listening endpoint first.
-//
-// This test uses IPv4.
-func TestTCPAckBeforeAcceptV4(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.Create(-1)
-
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatal("Bind failed:", err)
- }
-
- if err := c.EP.Listen(10); err != nil {
- t.Fatal("Listen failed:", err)
- }
-
- irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
-
- // Send data before accepting the connection.
- c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs + 1,
- AckNum: iss + 1,
- })
-
- // Receive ACK for the data we sent.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(iss+1)),
- checker.AckNum(uint32(irs+5))))
-}
-
-// TestTCPAckBeforeAcceptV6 tests that once the 3-way handshake is complete,
-// peers can send data and expect a response within a reasonable ammount of time
-// without calling Accept on the listening endpoint first.
-//
-// This test uses IPv6.
-func TestTCPAckBeforeAcceptV6(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(true)
-
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatal("Bind failed:", err)
- }
-
- if err := c.EP.Listen(10); err != nil {
- t.Fatal("Listen failed:", err)
- }
-
- irs, iss := executeV6Handshake(t, c, context.TestPort, false /* synCookiesInUse */)
-
- // Send data before accepting the connection.
- c.SendV6Packet([]byte{1, 2, 3, 4}, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs + 1,
- AckNum: iss + 1,
- })
-
- // Receive ACK for the data we sent.
- checker.IPv6(t, c.GetV6Packet(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(iss+1)),
- checker.AckNum(uint32(irs+5))))
-}
-
-func TestSendRstOnListenerRxAckV4(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.Create(-1 /* epRcvBuf */)
-
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatal("Bind failed:", err)
- }
-
- if err := c.EP.Listen(10 /* backlog */); err != nil {
- t.Fatal("Listen failed:", err)
- }
-
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagFin | header.TCPFlagAck,
- SeqNum: 100,
- AckNum: 200,
- })
-
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagRst),
- checker.SeqNum(200)))
-}
-
-func TestSendRstOnListenerRxAckV6(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(true /* v6Only */)
-
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatal("Bind failed:", err)
- }
-
- if err := c.EP.Listen(10 /* backlog */); err != nil {
- t.Fatal("Listen failed:", err)
- }
-
- c.SendV6Packet(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagFin | header.TCPFlagAck,
- SeqNum: 100,
- AckNum: 200,
- })
-
- checker.IPv6(t, c.GetV6Packet(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagRst),
- checker.SeqNum(200)))
-}
-
-// TestListenShutdown tests for the listening endpoint not processing
-// any receive when it is on read shutdown.
-func TestListenShutdown(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.Create(-1 /* epRcvBuf */)
-
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatal("Bind failed:", err)
- }
-
- if err := c.EP.Listen(10 /* backlog */); err != nil {
- t.Fatal("Listen failed:", err)
- }
-
- if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil {
- t.Fatal("Shutdown failed:", err)
- }
-
- // Wait for the endpoint state to be propagated.
- time.Sleep(10 * time.Millisecond)
-
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: 100,
- AckNum: 200,
- })
-
- c.CheckNoPacket("Packet received when listening socket was shutdown")
-}
-
-func TestTOSV4(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- c.EP = ep
-
- const tos = 0xC0
- if err := c.EP.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil {
- t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv4TOSOption(tos), err)
- }
-
- var v tcpip.IPv4TOSOption
- if err := c.EP.GetSockOpt(&v); err != nil {
- t.Errorf("GetSockopt failed: %s", err)
- }
-
- if want := tcpip.IPv4TOSOption(tos); v != want {
- t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want)
- }
-
- testV4Connect(t, c, checker.TOS(tos, 0))
-
- data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
-
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Check that data is received.
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.PayloadLen(len(data)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790), // Acknum is initial sequence number + 1
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
- ),
- checker.TOS(tos, 0),
- )
-
- if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
- t.Errorf("got data = %x, want = %x", p, data)
- }
-}
-
-func TestTrafficClassV6(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(false)
-
- const tos = 0xC0
- if err := c.EP.SetSockOpt(tcpip.IPv6TrafficClassOption(tos)); err != nil {
- t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv6TrafficClassOption(tos), err)
- }
-
- var v tcpip.IPv6TrafficClassOption
- if err := c.EP.GetSockOpt(&v); err != nil {
- t.Fatalf("GetSockopt failed: %s", err)
- }
-
- if want := tcpip.IPv6TrafficClassOption(tos); v != want {
- t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want)
- }
-
- // Test the connection request.
- testV6Connect(t, c, checker.TOS(tos, 0))
-
- data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
-
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Check that data is received.
- b := c.GetV6Packet()
- checker.IPv6(t, b,
- checker.PayloadLen(len(data)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
- ),
- checker.TOS(tos, 0),
- )
-
- if p := b[header.IPv6MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
- t.Errorf("got data = %x, want = %x", p, data)
- }
-}
-
-func TestConnectBindToDevice(t *testing.T) {
- for _, test := range []struct {
- name string
- device tcpip.NICID
- want tcp.EndpointState
- }{
- {"RightDevice", 1, tcp.StateEstablished},
- {"WrongDevice", 2, tcp.StateSynSent},
- {"AnyDevice", 0, tcp.StateEstablished},
- } {
- t.Run(test.name, func(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.Create(-1)
- bindToDevice := tcpip.BindToDeviceOption(test.device)
- c.EP.SetSockOpt(bindToDevice)
- // Start connection attempt.
- waitEntry, _ := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&waitEntry, waiter.EventOut)
- defer c.WQ.EventUnregister(&waitEntry)
-
- if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("Unexpected return value from Connect: %v", err)
- }
-
- // Receive SYN packet.
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- ),
- )
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
- t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
- }
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
-
- iss := seqnum.Value(789)
- rcvWnd := seqnum.Size(30000)
- c.SendPacket(nil, &context.Headers{
- SrcPort: tcpHdr.DestinationPort(),
- DstPort: tcpHdr.SourcePort(),
- Flags: header.TCPFlagSyn | header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: rcvWnd,
- TCPOpts: nil,
- })
-
- c.GetPacket()
- if got, want := tcp.EndpointState(c.EP.State()), test.want; got != want {
- t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
- }
- })
- }
-}
-
-func TestRstOnSynSent(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create an endpoint, don't handshake because we want to interfere with the
- // handshake process.
- c.Create(-1)
-
- // Start connection attempt.
- waitEntry, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&waitEntry, waiter.EventOut)
- defer c.WQ.EventUnregister(&waitEntry)
-
- addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}
- if err := c.EP.Connect(addr); err != tcpip.ErrConnectStarted {
- t.Fatalf("got Connect(%+v) = %v, want %s", addr, err, tcpip.ErrConnectStarted)
- }
-
- // Receive SYN packet.
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- ),
- )
-
- // Ensure that we've reached SynSent state
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
- t.Fatalf("got State() = %s, want %s", got, want)
- }
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
-
- // Send a packet with a proper ACK and a RST flag to cause the socket
- // to Error and close out
- iss := seqnum.Value(789)
- rcvWnd := seqnum.Size(30000)
- c.SendPacket(nil, &context.Headers{
- SrcPort: tcpHdr.DestinationPort(),
- DstPort: tcpHdr.SourcePort(),
- Flags: header.TCPFlagRst | header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: rcvWnd,
- TCPOpts: nil,
- })
-
- // Wait for receive to be notified.
- select {
- case <-ch:
- case <-time.After(3 * time.Second):
- t.Fatal("timed out waiting for packet to arrive")
- }
-
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionRefused {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrConnectionRefused)
- }
-
- // Due to the RST the endpoint should be in an error state.
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
- t.Fatalf("got State() = %s, want %s", got, want)
- }
-}
-
-func TestOutOfOrderReceive(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.EventIn)
- defer c.WQ.EventUnregister(&we)
-
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
- }
-
- // Send second half of data first, with seqnum 3 ahead of expected.
- data := []byte{1, 2, 3, 4, 5, 6}
- c.SendPacket(data[3:], &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: 793,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Check that we get an ACK specifying which seqnum is expected.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-
- // Wait 200ms and check that no data has been received.
- time.Sleep(200 * time.Millisecond)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
- }
-
- // Send the first 3 bytes now.
- c.SendPacket(data[:3], &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: 790,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Receive data.
- read := make([]byte, 0, 6)
- for len(read) < len(data) {
- v, _, err := c.EP.Read(nil)
- if err != nil {
- if err == tcpip.ErrWouldBlock {
- // Wait for receive to be notified.
- select {
- case <-ch:
- case <-time.After(5 * time.Second):
- t.Fatalf("Timed out waiting for data to arrive")
- }
- continue
- }
- t.Fatalf("Read failed: %v", err)
- }
-
- read = append(read, v...)
- }
-
- // Check that we received the data in proper order.
- if !bytes.Equal(data, read) {
- t.Fatalf("got data = %v, want = %v", read, data)
- }
-
- // Check that the whole data is acknowledged.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-}
-
-func TestOutOfOrderFlood(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create a new connection with initial window size of 10.
- c.CreateConnected(789, 30000, 10)
-
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
- }
-
- // Send 100 packets before the actual one that is expected.
- data := []byte{1, 2, 3, 4, 5, 6}
- for i := 0; i < 100; i++ {
- c.SendPacket(data[3:], &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: 796,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
- }
-
- // Send packet with seqnum 793. It must be discarded because the
- // out-of-order buffer was filled by the previous packets.
- c.SendPacket(data[3:], &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: 793,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-
- // Now send the expected packet, seqnum 790.
- c.SendPacket(data[:3], &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: 790,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Check that only packet 790 is acknowledged.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(793),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-}
-
-func TestRstOnCloseWithUnreadData(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.EventIn)
- defer c.WQ.EventUnregister(&we)
-
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
- }
-
- data := []byte{1, 2, 3}
- c.SendPacket(data, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: 790,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Wait for receive to be notified.
- select {
- case <-ch:
- case <-time.After(3 * time.Second):
- t.Fatalf("Timed out waiting for data to arrive")
- }
-
- // Check that ACK is received, this happens regardless of the read.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-
- // Now that we know we have unread data, let's just close the connection
- // and verify that netstack sends an RST rather than a FIN.
- c.EP.Close()
-
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
- // We shouldn't consume a sequence number on RST.
- checker.SeqNum(uint32(c.IRS)+1),
- ))
- // The RST puts the endpoint into an error state.
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- // This final ACK should be ignored because an ACK on a reset doesn't mean
- // anything.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: seqnum.Value(790 + len(data)),
- AckNum: c.IRS.Add(seqnum.Size(2)),
- RcvWnd: 30000,
- })
-}
-
-func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.EventIn)
- defer c.WQ.EventUnregister(&we)
-
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
- }
-
- data := []byte{1, 2, 3}
- c.SendPacket(data, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: 790,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Wait for receive to be notified.
- select {
- case <-ch:
- case <-time.After(3 * time.Second):
- t.Fatalf("Timed out waiting for data to arrive")
- }
-
- // Check that ACK is received, this happens regardless of the read.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-
- // Cause a FIN to be generated.
- c.EP.Shutdown(tcpip.ShutdownWrite)
-
- // Make sure we get the FIN but DON't ACK IT.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- checker.SeqNum(uint32(c.IRS)+1),
- ))
-
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- // Cause a RST to be generated by closing the read end now since we have
- // unread data.
- c.EP.Shutdown(tcpip.ShutdownRead)
-
- // Make sure we get the RST
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
- // RST is always generated with sndNxt which if the FIN
- // has been sent will be 1 higher than the sequence
- // number of the FIN itself.
- checker.SeqNum(uint32(c.IRS)+2),
- ))
- // The RST puts the endpoint into an error state.
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- // The ACK to the FIN should now be rejected since the connection has been
- // closed by a RST.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: seqnum.Value(790 + len(data)),
- AckNum: c.IRS.Add(seqnum.Size(2)),
- RcvWnd: 30000,
- })
-}
-
-func TestShutdownRead(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
- }
-
- if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
- }
-
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrClosedForReceive)
- }
- var want uint64 = 1
- if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want {
- t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %v want %v", got, want)
- }
-}
-
-func TestFullWindowReceive(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, 10)
-
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.EventIn)
- defer c.WQ.EventUnregister(&we)
-
- _, _, err := c.EP.Read(nil)
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("Read failed: %v", err)
- }
-
- // Fill up the window.
- data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
- c.SendPacket(data, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: 790,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Wait for receive to be notified.
- select {
- case <-ch:
- case <-time.After(5 * time.Second):
- t.Fatalf("Timed out waiting for data to arrive")
- }
-
- // Check that data is acknowledged, and window goes to zero.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
- checker.TCPFlags(header.TCPFlagAck),
- checker.Window(0),
- ),
- )
-
- // Receive data and check it.
- v, _, err := c.EP.Read(nil)
- if err != nil {
- t.Fatalf("Read failed: %v", err)
- }
-
- if !bytes.Equal(data, v) {
- t.Fatalf("got data = %v, want = %v", v, data)
- }
-
- var want uint64 = 1
- if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ZeroRcvWindowState.Value(); got != want {
- t.Fatalf("got EP stats ReceiveErrors.ZeroRcvWindowState got %v want %v", got, want)
- }
-
- // Check that we get an ACK for the newly non-zero window.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
- checker.TCPFlags(header.TCPFlagAck),
- checker.Window(10),
- ),
- )
-}
-
-func TestNoWindowShrinking(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Start off with a window size of 10, then shrink it to 5.
- c.CreateConnected(789, 30000, 10)
-
- if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil {
- t.Fatalf("SetSockOpt failed: %v", err)
- }
-
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.EventIn)
- defer c.WQ.EventUnregister(&we)
-
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
- }
-
- // Send 3 bytes, check that the peer acknowledges them.
- data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
- c.SendPacket(data[:3], &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: 790,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Wait for receive to be notified.
- select {
- case <-ch:
- case <-time.After(5 * time.Second):
- t.Fatalf("Timed out waiting for data to arrive")
- }
-
- // Check that data is acknowledged, and that window doesn't go to zero
- // just yet because it was previously set to 10. It must go to 7 now.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(793),
- checker.TCPFlags(header.TCPFlagAck),
- checker.Window(7),
- ),
- )
-
- // Send 7 more bytes, check that the window fills up.
- c.SendPacket(data[3:], &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: 793,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- select {
- case <-ch:
- case <-time.After(5 * time.Second):
- t.Fatalf("Timed out waiting for data to arrive")
- }
-
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
- checker.TCPFlags(header.TCPFlagAck),
- checker.Window(0),
- ),
- )
-
- // Receive data and check it.
- read := make([]byte, 0, 10)
- for len(read) < len(data) {
- v, _, err := c.EP.Read(nil)
- if err != nil {
- t.Fatalf("Read failed: %v", err)
- }
-
- read = append(read, v...)
- }
-
- if !bytes.Equal(data, read) {
- t.Fatalf("got data = %v, want = %v", read, data)
- }
-
- // Check that we get an ACK for the newly non-zero window, which is the
- // new size.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+len(data))),
- checker.TCPFlags(header.TCPFlagAck),
- checker.Window(5),
- ),
- )
-}
-
-func TestSimpleSend(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
-
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
- }
-
- // Check that data is received.
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.PayloadLen(len(data)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
- ),
- )
-
- if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
- t.Fatalf("got data = %v, want = %v", p, data)
- }
-
- // Acknowledge the data.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: 790,
- AckNum: c.IRS.Add(1 + seqnum.Size(len(data))),
- RcvWnd: 30000,
- })
-}
-
-func TestZeroWindowSend(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(789, 0, -1 /* epRcvBuf */)
-
- data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
-
- _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{})
- if err != nil {
- t.Fatalf("Write failed: %v", err)
- }
-
- // Since the window is currently zero, check that no packet is received.
- c.CheckNoPacket("Packet received when window is zero")
-
- // Open up the window. Data should be received now.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: 790,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Check that data is received.
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.PayloadLen(len(data)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
- ),
- )
-
- if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
- t.Fatalf("got data = %v, want = %v", p, data)
- }
-
- // Acknowledge the data.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: 790,
- AckNum: c.IRS.Add(1 + seqnum.Size(len(data))),
- RcvWnd: 30000,
- })
-}
-
-func TestScaledWindowConnect(t *testing.T) {
- // This test ensures that window scaling is used when the peer
- // does advertise it and connection is established with Connect().
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Set the window size greater than the maximum non-scaled window.
- c.CreateConnectedWithRawOptions(789, 30000, 65535*3, []byte{
- header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
- })
-
- data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
-
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
- }
-
- // Check that data is received, and that advertised window is 0xbfff,
- // that is, that it is scaled.
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.PayloadLen(len(data)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.Window(0xbfff),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
- ),
- )
-}
-
-func TestNonScaledWindowConnect(t *testing.T) {
- // This test ensures that window scaling is not used when the peer
- // doesn't advertise it and connection is established with Connect().
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Set the window size greater than the maximum non-scaled window.
- c.CreateConnected(789, 30000, 65535*3)
-
- data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
-
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
- }
-
- // Check that data is received, and that advertised window is 0xffff,
- // that is, that it's not scaled.
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.PayloadLen(len(data)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.Window(0xffff),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
- ),
- )
-}
-
-func TestScaledWindowAccept(t *testing.T) {
- // This test ensures that window scaling is used when the peer
- // does advertise it and connection is established with Accept().
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create EP and start listening.
- wq := &waiter.Queue{}
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
- defer ep.Close()
-
- // Set the window size greater than the maximum non-scaled window.
- if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil {
- t.Fatalf("SetSockOpt failed failed: %v", err)
- }
-
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
- }
-
- // Do 3-way handshake.
- c.PassiveConnectWithOptions(100, 2, header.TCPSynOptions{MSS: defaultIPv4MSS})
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.EventIn)
- defer wq.EventUnregister(&we)
-
- c.EP, _, err = ep.Accept()
- if err == tcpip.ErrWouldBlock {
- // Wait for connection to be established.
- select {
- case <-ch:
- c.EP, _, err = ep.Accept()
- if err != nil {
- t.Fatalf("Accept failed: %v", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
-
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
- }
-
- // Check that data is received, and that advertised window is 0xbfff,
- // that is, that it is scaled.
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.PayloadLen(len(data)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.Window(0xbfff),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
- ),
- )
-}
-
-func TestNonScaledWindowAccept(t *testing.T) {
- // This test ensures that window scaling is not used when the peer
- // doesn't advertise it and connection is established with Accept().
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create EP and start listening.
- wq := &waiter.Queue{}
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
- defer ep.Close()
-
- // Set the window size greater than the maximum non-scaled window.
- if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil {
- t.Fatalf("SetSockOpt failed failed: %v", err)
- }
-
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
- }
-
- // Do 3-way handshake w/ window scaling disabled. The SYN-ACK to the SYN
- // should not carry the window scaling option.
- c.PassiveConnect(100, -1, header.TCPSynOptions{MSS: defaultIPv4MSS})
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.EventIn)
- defer wq.EventUnregister(&we)
-
- c.EP, _, err = ep.Accept()
- if err == tcpip.ErrWouldBlock {
- // Wait for connection to be established.
- select {
- case <-ch:
- c.EP, _, err = ep.Accept()
- if err != nil {
- t.Fatalf("Accept failed: %v", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
-
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
- }
-
- // Check that data is received, and that advertised window is 0xffff,
- // that is, that it's not scaled.
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.PayloadLen(len(data)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.Window(0xffff),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
- ),
- )
-}
-
-func TestZeroScaledWindowReceive(t *testing.T) {
- // This test ensures that the endpoint sends a non-zero window size
- // advertisement when the scaled window transitions from 0 to non-zero,
- // but the actual window (not scaled) hasn't gotten to zero.
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Set the window size such that a window scale of 4 will be used.
- const wnd = 65535 * 10
- const ws = uint32(4)
- c.CreateConnectedWithRawOptions(789, 30000, wnd, []byte{
- header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
- })
-
- // Write chunks of 50000 bytes.
- remain := wnd
- sent := 0
- data := make([]byte, 50000)
- for remain > len(data) {
- c.SendPacket(data, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: seqnum.Value(790 + sent),
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
- sent += len(data)
- remain -= len(data)
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+sent)),
- checker.Window(uint16(remain>>ws)),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
- }
-
- // Make the window non-zero, but the scaled window zero.
- if remain >= 16 {
- data = data[:remain-15]
- c.SendPacket(data, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: seqnum.Value(790 + sent),
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
- sent += len(data)
- remain -= len(data)
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+sent)),
- checker.Window(0),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
- }
-
- // Read at least 1MSS of data. An ack should be sent in response to that.
- sz := 0
- for sz < defaultMTU {
- v, _, err := c.EP.Read(nil)
- if err != nil {
- t.Fatalf("Read failed: %v", err)
- }
- sz += len(v)
- }
-
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+sent)),
- checker.Window(uint16(sz>>ws)),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-}
-
-func TestSegmentMerging(t *testing.T) {
- tests := []struct {
- name string
- stop func(tcpip.Endpoint)
- resume func(tcpip.Endpoint)
- }{
- {
- "stop work",
- func(ep tcpip.Endpoint) {
- ep.(interface{ StopWork() }).StopWork()
- },
- func(ep tcpip.Endpoint) {
- ep.(interface{ ResumeWork() }).ResumeWork()
- },
- },
- {
- "cork",
- func(ep tcpip.Endpoint) {
- ep.SetSockOpt(tcpip.CorkOption(1))
- },
- func(ep tcpip.Endpoint) {
- ep.SetSockOpt(tcpip.CorkOption(0))
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- // Prevent the endpoint from processing packets.
- test.stop(c.EP)
-
- var allData []byte
- for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} {
- allData = append(allData, data...)
- view := buffer.NewViewFromBytes(data)
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write #%d failed: %v", i+1, err)
- }
- }
-
- // Let the endpoint process the segments that we just sent.
- test.resume(c.EP)
-
- // Check that data is received.
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.PayloadLen(len(allData)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
- ),
- )
-
- if got := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(got, allData) {
- t.Fatalf("got data = %v, want = %v", got, allData)
- }
-
- // Acknowledge the data.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: 790,
- AckNum: c.IRS.Add(1 + seqnum.Size(len(allData))),
- RcvWnd: 30000,
- })
- })
- }
-}
-
-func TestDelay(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- c.EP.SetSockOptInt(tcpip.DelayOption, 1)
-
- var allData []byte
- for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} {
- allData = append(allData, data...)
- view := buffer.NewViewFromBytes(data)
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write #%d failed: %v", i+1, err)
- }
- }
-
- seq := c.IRS.Add(1)
- for _, want := range [][]byte{allData[:1], allData[1:]} {
- // Check that data is received.
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.PayloadLen(len(want)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(seq)),
- checker.AckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
- ),
- )
-
- if got := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(got, want) {
- t.Fatalf("got data = %v, want = %v", got, want)
- }
-
- seq = seq.Add(seqnum.Size(len(want)))
- // Acknowledge the data.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: 790,
- AckNum: seq,
- RcvWnd: 30000,
- })
- }
-}
-
-func TestUndelay(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- c.EP.SetSockOptInt(tcpip.DelayOption, 1)
-
- allData := [][]byte{{0}, {1, 2, 3}}
- for i, data := range allData {
- view := buffer.NewViewFromBytes(data)
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write #%d failed: %v", i+1, err)
- }
- }
-
- seq := c.IRS.Add(1)
-
- // Check that data is received.
- first := c.GetPacket()
- checker.IPv4(t, first,
- checker.PayloadLen(len(allData[0])+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(seq)),
- checker.AckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
- ),
- )
-
- if got, want := first[header.IPv4MinimumSize+header.TCPMinimumSize:], allData[0]; !bytes.Equal(got, want) {
- t.Fatalf("got first packet's data = %v, want = %v", got, want)
- }
-
- seq = seq.Add(seqnum.Size(len(allData[0])))
-
- // Check that we don't get the second packet yet.
- c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond)
-
- c.EP.SetSockOptInt(tcpip.DelayOption, 0)
-
- // Check that data is received.
- second := c.GetPacket()
- checker.IPv4(t, second,
- checker.PayloadLen(len(allData[1])+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(seq)),
- checker.AckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
- ),
- )
-
- if got, want := second[header.IPv4MinimumSize+header.TCPMinimumSize:], allData[1]; !bytes.Equal(got, want) {
- t.Fatalf("got second packet's data = %v, want = %v", got, want)
- }
-
- seq = seq.Add(seqnum.Size(len(allData[1])))
-
- // Acknowledge the data.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: 790,
- AckNum: seq,
- RcvWnd: 30000,
- })
-}
-
-func TestMSSNotDelayed(t *testing.T) {
- tests := []struct {
- name string
- fn func(tcpip.Endpoint)
- }{
- {"no-op", func(tcpip.Endpoint) {}},
- {"delay", func(ep tcpip.Endpoint) { ep.SetSockOptInt(tcpip.DelayOption, 1) }},
- {"cork", func(ep tcpip.Endpoint) { ep.SetSockOpt(tcpip.CorkOption(1)) }},
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- const maxPayload = 100
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{
- header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
- })
-
- test.fn(c.EP)
-
- allData := [][]byte{{0}, make([]byte, maxPayload), make([]byte, maxPayload)}
- for i, data := range allData {
- view := buffer.NewViewFromBytes(data)
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write #%d failed: %v", i+1, err)
- }
- }
-
- seq := c.IRS.Add(1)
-
- for i, data := range allData {
- // Check that data is received.
- packet := c.GetPacket()
- checker.IPv4(t, packet,
- checker.PayloadLen(len(data)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(seq)),
- checker.AckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
- ),
- )
-
- if got, want := packet[header.IPv4MinimumSize+header.TCPMinimumSize:], data; !bytes.Equal(got, want) {
- t.Fatalf("got packet #%d's data = %v, want = %v", i+1, got, want)
- }
-
- seq = seq.Add(seqnum.Size(len(data)))
- }
-
- // Acknowledge the data.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: 790,
- AckNum: seq,
- RcvWnd: 30000,
- })
- })
- }
-}
-
-func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) {
- payloadMultiplier := 10
- dataLen := payloadMultiplier * maxPayload
- data := make([]byte, dataLen)
- for i := range data {
- data[i] = byte(i)
- }
-
- view := buffer.NewView(len(data))
- copy(view, data)
-
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
- }
-
- // Check that data is received in chunks.
- bytesReceived := 0
- numPackets := 0
- for bytesReceived != dataLen {
- b := c.GetPacket()
- numPackets++
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- payloadLen := len(tcpHdr.Payload())
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1+uint32(bytesReceived)),
- checker.AckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
- ),
- )
-
- pdata := data[bytesReceived : bytesReceived+payloadLen]
- if p := tcpHdr.Payload(); !bytes.Equal(pdata, p) {
- t.Fatalf("got data = %v, want = %v", p, pdata)
- }
- bytesReceived += payloadLen
- var options []byte
- if c.TimeStampEnabled {
- // If timestamp option is enabled, echo back the timestamp and increment
- // the TSEcr value included in the packet and send that back as the TSVal.
- parsedOpts := tcpHdr.ParsedOptions()
- tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
- header.EncodeTSOption(parsedOpts.TSEcr+1, parsedOpts.TSVal, tsOpt[2:])
- options = tsOpt[:]
- }
- // Acknowledge the data.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: 790,
- AckNum: c.IRS.Add(1 + seqnum.Size(bytesReceived)),
- RcvWnd: 30000,
- TCPOpts: options,
- })
- }
- if numPackets == 1 {
- t.Fatalf("expected write to be broken up into multiple packets, but got 1 packet")
- }
-}
-
-func TestSendGreaterThanMTU(t *testing.T) {
- const maxPayload = 100
- c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- testBrokenUpWrite(t, c, maxPayload)
-}
-
-func TestSetTTL(t *testing.T) {
- for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} {
- t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) {
- c := context.New(t, 65535)
- defer c.Cleanup()
-
- var err *tcpip.Error
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- if err := c.EP.SetSockOpt(tcpip.TTLOption(wantTTL)); err != nil {
- t.Fatalf("SetSockOpt failed: %v", err)
- }
-
- if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("Unexpected return value from Connect: %v", err)
- }
-
- // Receive SYN packet.
- b := c.GetPacket()
-
- checker.IPv4(t, b, checker.TTL(wantTTL))
- })
- }
-}
-
-func TestActiveSendMSSLessThanMTU(t *testing.T) {
- const maxPayload = 100
- c := context.New(t, 65535)
- defer c.Cleanup()
-
- c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{
- header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
- })
- testBrokenUpWrite(t, c, maxPayload)
-}
-
-func TestPassiveSendMSSLessThanMTU(t *testing.T) {
- const maxPayload = 100
- const mtu = 1200
- c := context.New(t, mtu)
- defer c.Cleanup()
-
- // Create EP and start listening.
- wq := &waiter.Queue{}
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
- defer ep.Close()
-
- // Set the buffer size to a deterministic size so that we can check the
- // window scaling option.
- const rcvBufferSize = 0x20000
- if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil {
- t.Fatalf("SetSockOpt failed failed: %v", err)
- }
-
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
- }
-
- // Do 3-way handshake.
- c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize})
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.EventIn)
- defer wq.EventUnregister(&we)
-
- c.EP, _, err = ep.Accept()
- if err == tcpip.ErrWouldBlock {
- // Wait for connection to be established.
- select {
- case <-ch:
- c.EP, _, err = ep.Accept()
- if err != nil {
- t.Fatalf("Accept failed: %v", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- // Check that data gets properly segmented.
- testBrokenUpWrite(t, c, maxPayload)
-}
-
-func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) {
- const maxPayload = 536
- const mtu = 2000
- c := context.New(t, mtu)
- defer c.Cleanup()
-
- // Set the SynRcvd threshold to zero to force a syn cookie based accept
- // to happen.
- saved := tcp.SynRcvdCountThreshold
- defer func() {
- tcp.SynRcvdCountThreshold = saved
- }()
- tcp.SynRcvdCountThreshold = 0
-
- // Create EP and start listening.
- wq := &waiter.Queue{}
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
- defer ep.Close()
-
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
- }
-
- // Do 3-way handshake.
- c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize})
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.EventIn)
- defer wq.EventUnregister(&we)
-
- c.EP, _, err = ep.Accept()
- if err == tcpip.ErrWouldBlock {
- // Wait for connection to be established.
- select {
- case <-ch:
- c.EP, _, err = ep.Accept()
- if err != nil {
- t.Fatalf("Accept failed: %v", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- // Check that data gets properly segmented.
- testBrokenUpWrite(t, c, maxPayload)
-}
-
-func TestForwarderSendMSSLessThanMTU(t *testing.T) {
- const maxPayload = 100
- const mtu = 1200
- c := context.New(t, mtu)
- defer c.Cleanup()
-
- s := c.Stack()
- ch := make(chan *tcpip.Error, 1)
- f := tcp.NewForwarder(s, 65536, 10, func(r *tcp.ForwarderRequest) {
- var err *tcpip.Error
- c.EP, err = r.CreateEndpoint(&c.WQ)
- ch <- err
- })
- s.SetTransportProtocolHandler(tcp.ProtocolNumber, f.HandlePacket)
-
- // Do 3-way handshake.
- c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize})
-
- // Wait for connection to be available.
- select {
- case err := <-ch:
- if err != nil {
- t.Fatalf("Error creating endpoint: %v", err)
- }
- case <-time.After(2 * time.Second):
- t.Fatalf("Timed out waiting for connection")
- }
-
- // Check that data gets properly segmented.
- testBrokenUpWrite(t, c, maxPayload)
-}
-
-func TestSynOptionsOnActiveConnect(t *testing.T) {
- const mtu = 1400
- c := context.New(t, mtu)
- defer c.Cleanup()
-
- // Create TCP endpoint.
- var err *tcpip.Error
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- // Set the buffer size to a deterministic size so that we can check the
- // window scaling option.
- const rcvBufferSize = 0x20000
- const wndScale = 2
- if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil {
- t.Fatalf("SetSockOpt failed failed: %v", err)
- }
-
- // Start connection attempt.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.EventOut)
- defer c.WQ.EventUnregister(&we)
-
- if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted)
- }
-
- // Receive SYN packet.
- b := c.GetPacket()
- mss := uint16(mtu - header.IPv4MinimumSize - header.TCPMinimumSize)
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}),
- ),
- )
-
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
-
- // Wait for retransmit.
- time.Sleep(1 * time.Second)
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- checker.SrcPort(tcpHdr.SourcePort()),
- checker.SeqNum(tcpHdr.SequenceNumber()),
- checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}),
- ),
- )
-
- // Send SYN-ACK.
- iss := seqnum.Value(789)
- c.SendPacket(nil, &context.Headers{
- SrcPort: tcpHdr.DestinationPort(),
- DstPort: tcpHdr.SourcePort(),
- Flags: header.TCPFlagSyn | header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- // Receive ACK packet.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(iss)+1),
- ),
- )
-
- // Wait for connection to be established.
- select {
- case <-ch:
- if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
- }
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for connection")
- }
-}
-
-func TestCloseListener(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create listener.
- var wq waiter.Queue
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- if err := ep.Bind(tcpip.FullAddress{}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
- }
-
- // Close the listener and measure how long it takes.
- t0 := time.Now()
- ep.Close()
- if diff := time.Now().Sub(t0); diff > 3*time.Second {
- t.Fatalf("Took too long to close: %v", diff)
- }
-}
-
-func TestReceiveOnResetConnection(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- // Send RST segment.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagRst,
- SeqNum: 790,
- RcvWnd: 30000,
- })
-
- // Try to read.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.EventIn)
- defer c.WQ.EventUnregister(&we)
-
-loop:
- for {
- switch _, _, err := c.EP.Read(nil); err {
- case tcpip.ErrWouldBlock:
- select {
- case <-ch:
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for reset to arrive")
- }
- case tcpip.ErrConnectionReset:
- break loop
- default:
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset)
- }
- }
- // Expect the state to be StateError and subsequent Reads to fail with HardError.
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset)
- }
- if tcp.EndpointState(c.EP.State()) != tcp.StateError {
- t.Fatalf("got EP state is not StateError")
- }
-
- if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 {
- t.Errorf("got stats.TCP.EstablishedResets.Value() = %v, want = 1", got)
- }
- if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
- }
-}
-
-func TestSendOnResetConnection(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- // Send RST segment.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagRst,
- SeqNum: 790,
- RcvWnd: 30000,
- })
-
- // Wait for the RST to be received.
- time.Sleep(1 * time.Second)
-
- // Try to write.
- view := buffer.NewView(10)
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != tcpip.ErrConnectionReset {
- t.Fatalf("got c.EP.Write(...) = %v, want = %v", err, tcpip.ErrConnectionReset)
- }
-}
-
-func TestFinImmediately(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- // Shutdown immediately, check that we get a FIN.
- if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
- }
-
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- ),
- )
-
- // Ack and send FIN as well.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: 790,
- AckNum: c.IRS.Add(2),
- RcvWnd: 30000,
- })
-
- // Check that the stack acks the FIN.
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+2),
- checker.AckNum(791),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-}
-
-func TestFinRetransmit(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- // Shutdown immediately, check that we get a FIN.
- if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
- }
-
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- ),
- )
-
- // Don't acknowledge yet. We should get a retransmit of the FIN.
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- ),
- )
-
- // Ack and send FIN as well.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: 790,
- AckNum: c.IRS.Add(2),
- RcvWnd: 30000,
- })
-
- // Check that the stack acks the FIN.
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+2),
- checker.AckNum(791),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-}
-
-func TestFinWithNoPendingData(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- // Write something out, and have it acknowledged.
- view := buffer.NewView(10)
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
- }
-
- next := uint32(c.IRS) + 1
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(len(view)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
- ),
- )
- next += uint32(len(view))
-
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: 790,
- AckNum: seqnum.Value(next),
- RcvWnd: 30000,
- })
-
- // Shutdown, check that we get a FIN.
- if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
- }
-
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- ),
- )
- next++
-
- // Ack and send FIN as well.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: 790,
- AckNum: seqnum.Value(next),
- RcvWnd: 30000,
- })
-
- // Check that the stack acks the FIN.
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(791),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-}
-
-func TestFinWithPendingDataCwndFull(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- // Write enough segments to fill the congestion window before ACK'ing
- // any of them.
- view := buffer.NewView(10)
- for i := tcp.InitialCwnd; i > 0; i-- {
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
- }
- }
-
- next := uint32(c.IRS) + 1
- for i := tcp.InitialCwnd; i > 0; i-- {
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(len(view)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
- ),
- )
- next += uint32(len(view))
- }
-
- // Shutdown the connection, check that the FIN segment isn't sent
- // because the congestion window doesn't allow it. Wait until a
- // retransmit is received.
- if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
- }
-
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(len(view)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
- ),
- )
-
- // Send the ACK that will allow the FIN to be sent as well.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: 790,
- AckNum: seqnum.Value(next),
- RcvWnd: 30000,
- })
-
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- ),
- )
- next++
-
- // Send a FIN that acknowledges everything. Get an ACK back.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: 790,
- AckNum: seqnum.Value(next),
- RcvWnd: 30000,
- })
-
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(791),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-}
-
-func TestFinWithPendingData(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- // Write something out, and acknowledge it to get cwnd to 2.
- view := buffer.NewView(10)
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
- }
-
- next := uint32(c.IRS) + 1
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(len(view)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
- ),
- )
- next += uint32(len(view))
-
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: 790,
- AckNum: seqnum.Value(next),
- RcvWnd: 30000,
- })
-
- // Write new data, but don't acknowledge it.
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
- }
-
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(len(view)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
- ),
- )
- next += uint32(len(view))
-
- // Shutdown the connection, check that we do get a FIN.
- if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
- }
-
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- ),
- )
- next++
-
- // Send a FIN that acknowledges everything. Get an ACK back.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: 790,
- AckNum: seqnum.Value(next),
- RcvWnd: 30000,
- })
-
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(791),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-}
-
-func TestFinWithPartialAck(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- // Write something out, and acknowledge it to get cwnd to 2. Also send
- // FIN from the test side.
- view := buffer.NewView(10)
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
- }
-
- next := uint32(c.IRS) + 1
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(len(view)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
- ),
- )
- next += uint32(len(view))
-
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: 790,
- AckNum: seqnum.Value(next),
- RcvWnd: 30000,
- })
-
- // Check that we get an ACK for the fin.
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(791),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
- ),
- )
-
- // Write new data, but don't acknowledge it.
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
- }
-
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(len(view)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(791),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
- ),
- )
- next += uint32(len(view))
-
- // Shutdown the connection, check that we do get a FIN.
- if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
- }
-
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(791),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- ),
- )
- next++
-
- // Send an ACK for the data, but not for the FIN yet.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: 791,
- AckNum: seqnum.Value(next - 1),
- RcvWnd: 30000,
- })
-
- // Check that we don't get a retransmit of the FIN.
- c.CheckNoPacketTimeout("FIN retransmitted when data was ack'd", 100*time.Millisecond)
-
- // Ack the FIN.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: 791,
- AckNum: seqnum.Value(next),
- RcvWnd: 30000,
- })
-}
-
-func TestUpdateListenBacklog(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create listener.
- var wq waiter.Queue
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- if err := ep.Bind(tcpip.FullAddress{}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
- }
-
- // Update the backlog with another Listen() on the same endpoint.
- if err := ep.Listen(20); err != nil {
- t.Fatalf("Listen failed to update backlog: %v", err)
- }
-
- ep.Close()
-}
-
-func scaledSendWindow(t *testing.T, scale uint8) {
- // This test ensures that the endpoint is using the right scaling by
- // sending a buffer that is larger than the window size, and ensuring
- // that the endpoint doesn't send more than allowed.
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- maxPayload := defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize
- c.CreateConnectedWithRawOptions(789, 0, -1 /* epRcvBuf */, []byte{
- header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
- header.TCPOptionWS, 3, scale, header.TCPOptionNOP,
- })
-
- // Open up the window with a scaled value.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: 790,
- AckNum: c.IRS.Add(1),
- RcvWnd: 1,
- })
-
- // Send some data. Check that it's capped by the window size.
- view := buffer.NewView(65535)
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
- }
-
- // Check that only data that fits in the scaled window is sent.
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen((1<<scale)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
- ),
- )
-
- // Reset the connection to free resources.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagRst,
- SeqNum: 790,
- })
-}
-
-func TestScaledSendWindow(t *testing.T) {
- for scale := uint8(0); scale <= 14; scale++ {
- scaledSendWindow(t, scale)
- }
-}
-
-func TestReceivedValidSegmentCountIncrement(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- stats := c.Stack().Stats()
- want := stats.TCP.ValidSegmentsReceived.Value() + 1
-
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: seqnum.Value(790),
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- if got := stats.TCP.ValidSegmentsReceived.Value(); got != want {
- t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %v, want = %v", got, want)
- }
- if got := c.EP.Stats().(*tcp.Stats).SegmentsReceived.Value(); got != want {
- t.Errorf("got EP stats Stats.SegmentsReceived = %v, want = %v", got, want)
- }
- // Ensure there were no errors during handshake. If these stats have
- // incremented, then the connection should not have been established.
- if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoRoute.Value(); got != 0 {
- t.Errorf("got EP stats Stats.SendErrors.NoRoute = %v, want = %v", got, 0)
- }
- if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoLinkAddr.Value(); got != 0 {
- t.Errorf("got EP stats Stats.SendErrors.NoLinkAddr = %v, want = %v", got, 0)
- }
-}
-
-func TestReceivedInvalidSegmentCountIncrement(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- stats := c.Stack().Stats()
- want := stats.TCP.InvalidSegmentsReceived.Value() + 1
- vv := c.BuildSegment(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: seqnum.Value(790),
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
- tcpbuf := vv.First()[header.IPv4MinimumSize:]
- tcpbuf[header.TCPDataOffset] = ((header.TCPMinimumSize - 1) / 4) << 4
-
- c.SendSegment(vv)
-
- if got := stats.TCP.InvalidSegmentsReceived.Value(); got != want {
- t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %v, want = %v", got, want)
- }
- if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want)
- }
-}
-
-func TestReceivedIncorrectChecksumIncrement(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- stats := c.Stack().Stats()
- want := stats.TCP.ChecksumErrors.Value() + 1
- vv := c.BuildSegment([]byte{0x1, 0x2, 0x3}, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: seqnum.Value(790),
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
- tcpbuf := vv.First()[header.IPv4MinimumSize:]
- // Overwrite a byte in the payload which should cause checksum
- // verification to fail.
- tcpbuf[(tcpbuf[header.TCPDataOffset]>>4)*4] = 0x4
-
- c.SendSegment(vv)
-
- if got := stats.TCP.ChecksumErrors.Value(); got != want {
- t.Errorf("got stats.TCP.ChecksumErrors.Value() = %d, want = %d", got, want)
- }
- if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ChecksumErrors.Value(); got != want {
- t.Errorf("got EP stats Stats.ReceiveErrors.ChecksumErrors = %d, want = %d", got, want)
- }
-}
-
-func TestReceivedSegmentQueuing(t *testing.T) {
- // This test sends 200 segments containing a few bytes each to an
- // endpoint and checks that they're all received and acknowledged by
- // the endpoint, that is, that none of the segments are dropped by
- // internal queues.
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- // Send 200 segments.
- data := []byte{1, 2, 3}
- for i := 0; i < 200; i++ {
- c.SendPacket(data, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: seqnum.Value(790 + i*len(data)),
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
- }
-
- // Receive ACKs for all segments.
- last := seqnum.Value(790 + 200*len(data))
- for {
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- ack := seqnum.Value(tcpHdr.AckNumber())
- if ack == last {
- break
- }
-
- if last.LessThan(ack) {
- t.Fatalf("Acknowledge (%v) beyond the expected (%v)", ack, last)
- }
- }
-}
-
-func TestReadAfterClosedState(t *testing.T) {
- // This test ensures that calling Read() or Peek() after the endpoint
- // has transitioned to closedState still works if there is pending
- // data. To transition to stateClosed without calling Close(), we must
- // shutdown the send path and the peer must send its own FIN.
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed
- // after 1 second in TIME_WAIT state.
- tcpTimeWaitTimeout := 1 * time.Second
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
- t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err)
- }
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.EventIn)
- defer c.WQ.EventUnregister(&we)
-
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrWouldBlock)
- }
-
- // Shutdown immediately for write, check that we get a FIN.
- if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %s", err)
- }
-
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
- ),
- )
-
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- // Send some data and acknowledge the FIN.
- data := []byte{1, 2, 3}
- c.SendPacket(data, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: 790,
- AckNum: c.IRS.Add(2),
- RcvWnd: 30000,
- })
-
- // Check that ACK is received.
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+2),
- checker.AckNum(uint32(791+len(data))),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-
- // Give the stack the chance to transition to closed state from
- // TIME_WAIT.
- time.Sleep(tcpTimeWaitTimeout * 2)
-
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- // Wait for receive to be notified.
- select {
- case <-ch:
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for data to arrive")
- }
-
- // Check that peek works.
- peekBuf := make([]byte, 10)
- n, _, err := c.EP.Peek([][]byte{peekBuf})
- if err != nil {
- t.Fatalf("Peek failed: %s", err)
- }
-
- peekBuf = peekBuf[:n]
- if !bytes.Equal(data, peekBuf) {
- t.Fatalf("got data = %v, want = %v", peekBuf, data)
- }
-
- // Receive data.
- v, _, err := c.EP.Read(nil)
- if err != nil {
- t.Fatalf("Read failed: %s", err)
- }
-
- if !bytes.Equal(data, v) {
- t.Fatalf("got data = %v, want = %v", v, data)
- }
-
- // Now that we drained the queue, check that functions fail with the
- // right error code.
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrClosedForReceive)
- }
-
- if _, _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive {
- t.Fatalf("got c.EP.Peek(...) = %v, want = %s", err, tcpip.ErrClosedForReceive)
- }
-}
-
-func TestReusePort(t *testing.T) {
- // This test ensures that ports are immediately available for reuse
- // after Close on the endpoints using them returns.
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // First case, just an endpoint that was bound.
- var err *tcpip.Error
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
- }
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- c.EP.Close()
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
- }
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
- c.EP.Close()
-
- // Second case, an endpoint that was bound and is connecting..
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
- }
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
- if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted)
- }
- c.EP.Close()
-
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
- }
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
- c.EP.Close()
-
- // Third case, an endpoint that was bound and is listening.
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
- }
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
- if err := c.EP.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
- }
- c.EP.Close()
-
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
- }
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
- if err := c.EP.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
- }
-}
-
-func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
- t.Helper()
-
- s, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
- if err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
- }
-
- if int(s) != v {
- t.Fatalf("got receive buffer size = %v, want = %v", s, v)
- }
-}
-
-func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
- t.Helper()
-
- s, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption)
- if err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
- }
-
- if int(s) != v {
- t.Fatalf("got send buffer size = %v, want = %v", s, v)
- }
-}
-
-func TestDefaultBufferSizes(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
- })
-
- // Check the default values.
- ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
- }
- defer func() {
- if ep != nil {
- ep.Close()
- }
- }()
-
- checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize)
- checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize)
-
- // Change the default send buffer size.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultSendBufferSize * 2, tcp.DefaultSendBufferSize * 20}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
- }
-
- ep.Close()
- ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
- }
-
- checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2)
- checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize)
-
- // Change the default receive buffer size.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultReceiveBufferSize * 3, tcp.DefaultReceiveBufferSize * 30}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
- }
-
- ep.Close()
- ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
- }
-
- checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2)
- checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*3)
-}
-
-func TestMinMaxBufferSizes(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
- })
-
- // Check the default values.
- ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
- }
- defer ep.Close()
-
- // Change the min/max values for send/receive
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{200, tcp.DefaultReceiveBufferSize * 2, tcp.DefaultReceiveBufferSize * 20}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
- }
-
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{300, tcp.DefaultSendBufferSize * 3, tcp.DefaultSendBufferSize * 30}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
- }
-
- // Set values below the min.
- if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 199); err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
- }
-
- checkRecvBufferSize(t, ep, 200)
-
- if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 299); err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
- }
-
- checkSendBufferSize(t, ep, 300)
-
- // Set values above the max.
- if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 1+tcp.DefaultReceiveBufferSize*20); err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
- }
-
- checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20)
-
- if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
- }
-
- checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30)
-}
-
-func TestBindToDeviceOption(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}})
-
- ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
- }
- defer ep.Close()
-
- if err := s.CreateNIC(321, loopback.New()); err != nil {
- t.Errorf("CreateNIC failed: %v", err)
- }
-
- // nicIDPtr is used instead of taking the address of NICID literals, which is
- // a compiler error.
- nicIDPtr := func(s tcpip.NICID) *tcpip.NICID {
- return &s
- }
-
- testActions := []struct {
- name string
- setBindToDevice *tcpip.NICID
- setBindToDeviceError *tcpip.Error
- getBindToDevice tcpip.BindToDeviceOption
- }{
- {"GetDefaultValue", nil, nil, 0},
- {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0},
- {"BindToExistent", nicIDPtr(321), nil, 321},
- {"UnbindToDevice", nicIDPtr(0), nil, 0},
- }
- for _, testAction := range testActions {
- t.Run(testAction.name, func(t *testing.T) {
- if testAction.setBindToDevice != nil {
- bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
- if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
- t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr)
- }
- }
- bindToDevice := tcpip.BindToDeviceOption(88888)
- if err := ep.GetSockOpt(&bindToDevice); err != nil {
- t.Errorf("GetSockOpt got %v, want %v", err, nil)
- }
- if got, want := bindToDevice, testAction.getBindToDevice; got != want {
- t.Errorf("bindToDevice got %d, want %d", got, want)
- }
- })
- }
-}
-
-func makeStack() (*stack.Stack, *tcpip.Error) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{
- ipv4.NewProtocol(),
- ipv6.NewProtocol(),
- },
- TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
- })
-
- id := loopback.New()
- if testing.Verbose() {
- id = sniffer.New(id)
- }
-
- if err := s.CreateNIC(1, id); err != nil {
- return nil, err
- }
-
- for _, ct := range []struct {
- number tcpip.NetworkProtocolNumber
- address tcpip.Address
- }{
- {ipv4.ProtocolNumber, context.StackAddr},
- {ipv6.ProtocolNumber, context.StackV6Addr},
- } {
- if err := s.AddAddress(1, ct.number, ct.address); err != nil {
- return nil, err
- }
- }
-
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv4EmptySubnet,
- NIC: 1,
- },
- {
- Destination: header.IPv6EmptySubnet,
- NIC: 1,
- },
- })
-
- return s, nil
-}
-
-func TestSelfConnect(t *testing.T) {
- // This test ensures that intentional self-connects work. In particular,
- // it checks that if an endpoint binds to say 127.0.0.1:1000 then
- // connects to 127.0.0.1:1000, then it will be connected to itself, and
- // is able to send and receive data through the same endpoint.
- s, err := makeStack()
- if err != nil {
- t.Fatal(err)
- }
-
- var wq waiter.Queue
- ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
- defer ep.Close()
-
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Register for notification, then start connection attempt.
- waitEntry, notifyCh := waiter.NewChannelEntry(nil)
- wq.EventRegister(&waitEntry, waiter.EventOut)
- defer wq.EventUnregister(&waitEntry)
-
- if err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted)
- }
-
- <-notifyCh
- if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != nil {
- t.Fatalf("Connect failed: %v", err)
- }
-
- // Write something.
- data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
- if _, _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
- }
-
- // Read back what was written.
- wq.EventUnregister(&waitEntry)
- wq.EventRegister(&waitEntry, waiter.EventIn)
- rd, _, err := ep.Read(nil)
- if err != nil {
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("Read failed: %v", err)
- }
- <-notifyCh
- rd, _, err = ep.Read(nil)
- if err != nil {
- t.Fatalf("Read failed: %v", err)
- }
- }
-
- if !bytes.Equal(data, rd) {
- t.Fatalf("got data = %v, want = %v", rd, data)
- }
-}
-
-func TestConnectAvoidsBoundPorts(t *testing.T) {
- addressTypes := func(t *testing.T, network string) []string {
- switch network {
- case "ipv4":
- return []string{"v4"}
- case "ipv6":
- return []string{"v6"}
- case "dual":
- return []string{"v6", "mapped"}
- default:
- t.Fatalf("unknown network: '%s'", network)
- }
-
- panic("unreachable")
- }
-
- address := func(t *testing.T, addressType string, isAny bool) tcpip.Address {
- switch addressType {
- case "v4":
- if isAny {
- return ""
- }
- return context.StackAddr
- case "v6":
- if isAny {
- return ""
- }
- return context.StackV6Addr
- case "mapped":
- if isAny {
- return context.V4MappedWildcardAddr
- }
- return context.StackV4MappedAddr
- default:
- t.Fatalf("unknown address type: '%s'", addressType)
- }
-
- panic("unreachable")
- }
- // This test ensures that Endpoint.Connect doesn't select already-bound ports.
- networks := []string{"ipv4", "ipv6", "dual"}
- for _, exhaustedNetwork := range networks {
- t.Run(fmt.Sprintf("exhaustedNetwork=%s", exhaustedNetwork), func(t *testing.T) {
- for _, exhaustedAddressType := range addressTypes(t, exhaustedNetwork) {
- t.Run(fmt.Sprintf("exhaustedAddressType=%s", exhaustedAddressType), func(t *testing.T) {
- for _, isAny := range []bool{false, true} {
- t.Run(fmt.Sprintf("isAny=%t", isAny), func(t *testing.T) {
- for _, candidateNetwork := range networks {
- t.Run(fmt.Sprintf("candidateNetwork=%s", candidateNetwork), func(t *testing.T) {
- for _, candidateAddressType := range addressTypes(t, candidateNetwork) {
- t.Run(fmt.Sprintf("candidateAddressType=%s", candidateAddressType), func(t *testing.T) {
- s, err := makeStack()
- if err != nil {
- t.Fatal(err)
- }
-
- var wq waiter.Queue
- var eps []tcpip.Endpoint
- defer func() {
- for _, ep := range eps {
- ep.Close()
- }
- }()
- makeEP := func(network string) tcpip.Endpoint {
- var networkProtocolNumber tcpip.NetworkProtocolNumber
- switch network {
- case "ipv4":
- networkProtocolNumber = ipv4.ProtocolNumber
- case "ipv6", "dual":
- networkProtocolNumber = ipv6.ProtocolNumber
- default:
- t.Fatalf("unknown network: '%s'", network)
- }
- ep, err := s.NewEndpoint(tcp.ProtocolNumber, networkProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
- eps = append(eps, ep)
- switch network {
- case "ipv4":
- case "ipv6":
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- t.Fatalf("SetSockOpt(V6OnlyOption(true)) failed: %v", err)
- }
- case "dual":
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, false); err != nil {
- t.Fatalf("SetSockOpt(V6OnlyOption(false)) failed: %v", err)
- }
- default:
- t.Fatalf("unknown network: '%s'", network)
- }
- return ep
- }
-
- var v4reserved, v6reserved bool
- switch exhaustedAddressType {
- case "v4", "mapped":
- v4reserved = true
- case "v6":
- v6reserved = true
- // Dual stack sockets bound to v6 any reserve on v4 as
- // well.
- if isAny {
- switch exhaustedNetwork {
- case "ipv6":
- case "dual":
- v4reserved = true
- default:
- t.Fatalf("unknown address type: '%s'", exhaustedNetwork)
- }
- }
- default:
- t.Fatalf("unknown address type: '%s'", exhaustedAddressType)
- }
- var collides bool
- switch candidateAddressType {
- case "v4", "mapped":
- collides = v4reserved
- case "v6":
- collides = v6reserved
- default:
- t.Fatalf("unknown address type: '%s'", candidateAddressType)
- }
-
- for i := ports.FirstEphemeral; i <= math.MaxUint16; i++ {
- if makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil {
- t.Fatalf("Bind(%d) failed: %v", i, err)
- }
- }
- want := tcpip.ErrConnectStarted
- if collides {
- want = tcpip.ErrNoPortAvailable
- }
- if err := makeEP(candidateNetwork).Connect(tcpip.FullAddress{Addr: address(t, candidateAddressType, false), Port: 31337}); err != want {
- t.Fatalf("got ep.Connect(..) = %v, want = %v", err, want)
- }
- })
- }
- })
- }
- })
- }
- })
- }
- })
- }
-}
-
-func TestPathMTUDiscovery(t *testing.T) {
- // This test verifies the stack retransmits packets after it receives an
- // ICMP packet indicating that the path MTU has been exceeded.
- c := context.New(t, 1500)
- defer c.Cleanup()
-
- // Create new connection with MSS of 1460.
- const maxPayload = 1500 - header.TCPMinimumSize - header.IPv4MinimumSize
- c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{
- header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
- })
-
- // Send 3200 bytes of data.
- const writeSize = 3200
- data := buffer.NewView(writeSize)
- for i := range data {
- data[i] = byte(i)
- }
-
- if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
- }
-
- receivePackets := func(c *context.Context, sizes []int, which int, seqNum uint32) []byte {
- var ret []byte
- for i, size := range sizes {
- p := c.GetPacket()
- if i == which {
- ret = p
- }
- checker.IPv4(t, p,
- checker.PayloadLen(size+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(seqNum),
- checker.AckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
- ),
- )
- seqNum += uint32(size)
- }
- return ret
- }
-
- // Receive three packets.
- sizes := []int{maxPayload, maxPayload, writeSize - 2*maxPayload}
- first := receivePackets(c, sizes, 0, uint32(c.IRS)+1)
-
- // Send "packet too big" messages back to netstack.
- const newMTU = 1200
- const newMaxPayload = newMTU - header.IPv4MinimumSize - header.TCPMinimumSize
- mtu := []byte{0, 0, newMTU / 256, newMTU % 256}
- c.SendICMPPacket(header.ICMPv4DstUnreachable, header.ICMPv4FragmentationNeeded, mtu, first, newMTU)
-
- // See retransmitted packets. None exceeding the new max.
- sizes = []int{newMaxPayload, maxPayload - newMaxPayload, newMaxPayload, maxPayload - newMaxPayload, writeSize - 2*maxPayload}
- receivePackets(c, sizes, -1, uint32(c.IRS)+1)
-}
-
-func TestTCPEndpointProbe(t *testing.T) {
- c := context.New(t, 1500)
- defer c.Cleanup()
-
- invoked := make(chan struct{})
- c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
- // Validate that the endpoint ID is what we expect.
- //
- // We don't do an extensive validation of every field but a
- // basic sanity test.
- if got, want := state.ID.LocalAddress, tcpip.Address(context.StackAddr); got != want {
- t.Fatalf("got LocalAddress: %q, want: %q", got, want)
- }
- if got, want := state.ID.LocalPort, c.Port; got != want {
- t.Fatalf("got LocalPort: %d, want: %d", got, want)
- }
- if got, want := state.ID.RemoteAddress, tcpip.Address(context.TestAddr); got != want {
- t.Fatalf("got RemoteAddress: %q, want: %q", got, want)
- }
- if got, want := state.ID.RemotePort, uint16(context.TestPort); got != want {
- t.Fatalf("got RemotePort: %d, want: %d", got, want)
- }
-
- invoked <- struct{}{}
- })
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- data := []byte{1, 2, 3}
- c.SendPacket(data, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: 790,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
-
- select {
- case <-invoked:
- case <-time.After(100 * time.Millisecond):
- t.Fatalf("TCP Probe function was not called")
- }
-}
-
-func TestStackSetCongestionControl(t *testing.T) {
- testCases := []struct {
- cc tcpip.CongestionControlOption
- err *tcpip.Error
- }{
- {"reno", nil},
- {"cubic", nil},
- {"blahblah", tcpip.ErrNoSuchFile},
- }
-
- for _, tc := range testCases {
- t.Run(fmt.Sprintf("SetTransportProtocolOption(.., %v)", tc.cc), func(t *testing.T) {
- c := context.New(t, 1500)
- defer c.Cleanup()
-
- s := c.Stack()
-
- var oldCC tcpip.CongestionControlOption
- if err := s.TransportProtocolOption(tcp.ProtocolNumber, &oldCC); err != nil {
- t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &oldCC, err)
- }
-
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != tc.err {
- t.Fatalf("s.SetTransportProtocolOption(%v, %v) = %v, want %v", tcp.ProtocolNumber, tc.cc, err, tc.err)
- }
-
- var cc tcpip.CongestionControlOption
- if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil {
- t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err)
- }
-
- got, want := cc, oldCC
- // If SetTransportProtocolOption is expected to succeed
- // then the returned value for congestion control should
- // match the one specified in the
- // SetTransportProtocolOption call above, else it should
- // be what it was before the call to
- // SetTransportProtocolOption.
- if tc.err == nil {
- want = tc.cc
- }
- if got != want {
- t.Fatalf("got congestion control: %v, want: %v", got, want)
- }
- })
- }
-}
-
-func TestStackAvailableCongestionControl(t *testing.T) {
- c := context.New(t, 1500)
- defer c.Cleanup()
-
- s := c.Stack()
-
- // Query permitted congestion control algorithms.
- var aCC tcpip.AvailableCongestionControlOption
- if err := s.TransportProtocolOption(tcp.ProtocolNumber, &aCC); err != nil {
- t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &aCC, err)
- }
- if got, want := aCC, tcpip.AvailableCongestionControlOption("reno cubic"); got != want {
- t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want)
- }
-}
-
-func TestStackSetAvailableCongestionControl(t *testing.T) {
- c := context.New(t, 1500)
- defer c.Cleanup()
-
- s := c.Stack()
-
- // Setting AvailableCongestionControlOption should fail.
- aCC := tcpip.AvailableCongestionControlOption("xyz")
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &aCC); err == nil {
- t.Fatalf("s.TransportProtocolOption(%v, %v) = nil, want non-nil", tcp.ProtocolNumber, &aCC)
- }
-
- // Verify that we still get the expected list of congestion control options.
- var cc tcpip.AvailableCongestionControlOption
- if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil {
- t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err)
- }
- if got, want := cc, tcpip.AvailableCongestionControlOption("reno cubic"); got != want {
- t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want)
- }
-}
-
-func TestEndpointSetCongestionControl(t *testing.T) {
- testCases := []struct {
- cc tcpip.CongestionControlOption
- err *tcpip.Error
- }{
- {"reno", nil},
- {"cubic", nil},
- {"blahblah", tcpip.ErrNoSuchFile},
- }
-
- for _, connected := range []bool{false, true} {
- for _, tc := range testCases {
- t.Run(fmt.Sprintf("SetSockOpt(.., %v) w/ connected = %v", tc.cc, connected), func(t *testing.T) {
- c := context.New(t, 1500)
- defer c.Cleanup()
-
- // Create TCP endpoint.
- var err *tcpip.Error
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- var oldCC tcpip.CongestionControlOption
- if err := c.EP.GetSockOpt(&oldCC); err != nil {
- t.Fatalf("c.EP.SockOpt(%v) = %v", &oldCC, err)
- }
-
- if connected {
- c.Connect(789 /* iss */, 32768 /* rcvWnd */, nil)
- }
-
- if err := c.EP.SetSockOpt(tc.cc); err != tc.err {
- t.Fatalf("c.EP.SetSockOpt(%v) = %v, want %v", tc.cc, err, tc.err)
- }
-
- var cc tcpip.CongestionControlOption
- if err := c.EP.GetSockOpt(&cc); err != nil {
- t.Fatalf("c.EP.SockOpt(%v) = %v", &cc, err)
- }
-
- got, want := cc, oldCC
- // If SetSockOpt is expected to succeed then the
- // returned value for congestion control should match
- // the one specified in the SetSockOpt above, else it
- // should be what it was before the call to SetSockOpt.
- if tc.err == nil {
- want = tc.cc
- }
- if got != want {
- t.Fatalf("got congestion control: %v, want: %v", got, want)
- }
- })
- }
- }
-}
-
-func enableCUBIC(t *testing.T, c *context.Context) {
- t.Helper()
- opt := tcpip.CongestionControlOption("cubic")
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil {
- t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, %v = %v", opt, err)
- }
-}
-
-func TestKeepalive(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- const keepAliveInterval = 10 * time.Millisecond
- c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond))
- c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval))
- c.EP.SetSockOpt(tcpip.KeepaliveCountOption(5))
- c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1))
-
- // 5 unacked keepalives are sent. ACK each one, and check that the
- // connection stays alive after 5.
- for i := 0; i < 10; i++ {
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)),
- checker.AckNum(uint32(790)),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-
- // Acknowledge the keepalive.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: 790,
- AckNum: c.IRS,
- RcvWnd: 30000,
- })
- }
-
- // Check that the connection is still alive.
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
- }
-
- // Send some data and wait before ACKing it. Keepalives should be disabled
- // during this period.
- view := buffer.NewView(3)
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
- }
-
- next := uint32(c.IRS) + 1
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(len(view)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
- ),
- )
-
- // Wait for the packet to be retransmitted. Verify that no keepalives
- // were sent.
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(len(view)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh),
- ),
- )
- c.CheckNoPacket("Keepalive packet received while unACKed data is pending")
-
- next += uint32(len(view))
-
- // Send ACK. Keepalives should start sending again.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: 790,
- AckNum: seqnum.Value(next),
- RcvWnd: 30000,
- })
-
- // Now receive 5 keepalives, but don't ACK them. The connection
- // should be reset after 5.
- for i := 0; i < 5; i++ {
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(next-1)),
- checker.AckNum(uint32(790)),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
- }
-
- // Sleep for a litte over the KeepAlive interval to make sure
- // the timer has time to fire after the last ACK and close the
- // close the socket.
- time.Sleep(keepAliveInterval + 5*time.Millisecond)
-
- // The connection should be terminated after 5 unacked keepalives.
- // Send an ACK to trigger a RST from the stack as the endpoint should
- // be dead.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: 790,
- AckNum: seqnum.Value(next),
- RcvWnd: 30000,
- })
-
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(next)),
- checker.AckNum(uint32(0)),
- checker.TCPFlags(header.TCPFlagRst),
- ),
- )
-
- if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %v, want = 1", got)
- }
-
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout)
- }
-
- if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
- }
-}
-
-func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) {
- // Send a SYN request.
- irs = seqnum.Value(789)
- c.SendPacket(nil, &context.Headers{
- SrcPort: srcPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetPacket()
- tcp := header.TCP(header.IPv4(b).Payload())
- iss = seqnum.Value(tcp.SequenceNumber())
- tcpCheckers := []checker.TransportChecker{
- checker.SrcPort(context.StackPort),
- checker.DstPort(srcPort),
- checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
- checker.AckNum(uint32(irs) + 1),
- }
-
- if synCookieInUse {
- // When cookies are in use window scaling is disabled.
- tcpCheckers = append(tcpCheckers, checker.TCPSynOptions(header.TCPSynOptions{
- WS: -1,
- MSS: c.MSSWithoutOptions(),
- }))
- }
-
- checker.IPv4(t, b, checker.TCP(tcpCheckers...))
-
- // Send ACK.
- c.SendPacket(nil, &context.Headers{
- SrcPort: srcPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs + 1,
- AckNum: iss + 1,
- RcvWnd: 30000,
- })
- return irs, iss
-}
-
-func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) {
- // Send a SYN request.
- irs = seqnum.Value(789)
- c.SendV6Packet(nil, &context.Headers{
- SrcPort: srcPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetV6Packet()
- tcp := header.TCP(header.IPv6(b).Payload())
- iss = seqnum.Value(tcp.SequenceNumber())
- tcpCheckers := []checker.TransportChecker{
- checker.SrcPort(context.StackPort),
- checker.DstPort(srcPort),
- checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
- checker.AckNum(uint32(irs) + 1),
- }
-
- if synCookieInUse {
- // When cookies are in use window scaling is disabled.
- tcpCheckers = append(tcpCheckers, checker.TCPSynOptions(header.TCPSynOptions{
- WS: -1,
- MSS: c.MSSWithoutOptionsV6(),
- }))
- }
-
- checker.IPv6(t, b, checker.TCP(tcpCheckers...))
-
- // Send ACK.
- c.SendV6Packet(nil, &context.Headers{
- SrcPort: srcPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs + 1,
- AckNum: iss + 1,
- RcvWnd: 30000,
- })
- return irs, iss
-}
-
-// TestListenBacklogFull tests that netstack does not complete handshakes if the
-// listen backlog for the endpoint is full.
-func TestListenBacklogFull(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create TCP endpoint.
- var err *tcpip.Error
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- // Bind to wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Test acceptance.
- // Start listening.
- listenBacklog := 2
- if err := c.EP.Listen(listenBacklog); err != nil {
- t.Fatalf("Listen failed: %v", err)
- }
-
- for i := 0; i < listenBacklog; i++ {
- executeHandshake(t, c, context.TestPort+uint16(i), false /*synCookieInUse */)
- }
-
- time.Sleep(50 * time.Millisecond)
-
- // Now execute send one more SYN. The stack should not respond as the backlog
- // is full at this point.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort + 2,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: seqnum.Value(789),
- RcvWnd: 30000,
- })
- c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond)
-
- // Try to accept the connections in the backlog.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.EventIn)
- defer c.WQ.EventUnregister(&we)
-
- for i := 0; i < listenBacklog; i++ {
- _, _, err = c.EP.Accept()
- if err == tcpip.ErrWouldBlock {
- // Wait for connection to be established.
- select {
- case <-ch:
- _, _, err = c.EP.Accept()
- if err != nil {
- t.Fatalf("Accept failed: %v", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
- }
-
- // Now verify that there are no more connections that can be accepted.
- _, _, err = c.EP.Accept()
- if err != tcpip.ErrWouldBlock {
- select {
- case <-ch:
- t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP)
- case <-time.After(1 * time.Second):
- }
- }
-
- // Now a new handshake must succeed.
- executeHandshake(t, c, context.TestPort+2, false /*synCookieInUse */)
-
- newEP, _, err := c.EP.Accept()
- if err == tcpip.ErrWouldBlock {
- // Wait for connection to be established.
- select {
- case <-ch:
- newEP, _, err = c.EP.Accept()
- if err != nil {
- t.Fatalf("Accept failed: %v", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- // Now verify that the TCP socket is usable and in a connected state.
- data := "Don't panic"
- newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{})
- b := c.GetPacket()
- tcp := header.TCP(header.IPv4(b).Payload())
- if string(tcp.Payload()) != data {
- t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data)
- }
-}
-
-// TestListenNoAcceptMulticastBroadcastV4 makes sure that TCP segments with a
-// non unicast IPv4 address are not accepted.
-func TestListenNoAcceptNonUnicastV4(t *testing.T) {
- multicastAddr := tcpip.Address("\xe0\x00\x01\x02")
- otherMulticastAddr := tcpip.Address("\xe0\x00\x01\x03")
-
- tests := []struct {
- name string
- srcAddr tcpip.Address
- dstAddr tcpip.Address
- }{
- {
- "SourceUnspecified",
- header.IPv4Any,
- context.StackAddr,
- },
- {
- "SourceBroadcast",
- header.IPv4Broadcast,
- context.StackAddr,
- },
- {
- "SourceOurMulticast",
- multicastAddr,
- context.StackAddr,
- },
- {
- "SourceOtherMulticast",
- otherMulticastAddr,
- context.StackAddr,
- },
- {
- "DestUnspecified",
- context.TestAddr,
- header.IPv4Any,
- },
- {
- "DestBroadcast",
- context.TestAddr,
- header.IPv4Broadcast,
- },
- {
- "DestOurMulticast",
- context.TestAddr,
- multicastAddr,
- },
- {
- "DestOtherMulticast",
- context.TestAddr,
- otherMulticastAddr,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- t.Parallel()
-
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.Create(-1)
-
- if err := c.Stack().JoinGroup(header.IPv4ProtocolNumber, 1, multicastAddr); err != nil {
- t.Fatalf("JoinGroup failed: %s", err)
- }
-
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- if err := c.EP.Listen(1); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- irs := seqnum.Value(789)
- c.SendPacketWithAddrs(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- }, test.srcAddr, test.dstAddr)
- c.CheckNoPacket("Should not have received a response")
-
- // Handle normal packet.
- c.SendPacketWithAddrs(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- }, context.TestAddr, context.StackAddr)
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
- checker.AckNum(uint32(irs)+1)))
- })
- }
-}
-
-// TestListenNoAcceptMulticastBroadcastV6 makes sure that TCP segments with a
-// non unicast IPv6 address are not accepted.
-func TestListenNoAcceptNonUnicastV6(t *testing.T) {
- multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01")
- otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02")
-
- tests := []struct {
- name string
- srcAddr tcpip.Address
- dstAddr tcpip.Address
- }{
- {
- "SourceUnspecified",
- header.IPv6Any,
- context.StackV6Addr,
- },
- {
- "SourceAllNodes",
- header.IPv6AllNodesMulticastAddress,
- context.StackV6Addr,
- },
- {
- "SourceOurMulticast",
- multicastAddr,
- context.StackV6Addr,
- },
- {
- "SourceOtherMulticast",
- otherMulticastAddr,
- context.StackV6Addr,
- },
- {
- "DestUnspecified",
- context.TestV6Addr,
- header.IPv6Any,
- },
- {
- "DestAllNodes",
- context.TestV6Addr,
- header.IPv6AllNodesMulticastAddress,
- },
- {
- "DestOurMulticast",
- context.TestV6Addr,
- multicastAddr,
- },
- {
- "DestOtherMulticast",
- context.TestV6Addr,
- otherMulticastAddr,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- t.Parallel()
-
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateV6Endpoint(true)
-
- if err := c.Stack().JoinGroup(header.IPv6ProtocolNumber, 1, multicastAddr); err != nil {
- t.Fatalf("JoinGroup failed: %s", err)
- }
-
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- if err := c.EP.Listen(1); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- irs := seqnum.Value(789)
- c.SendV6PacketWithAddrs(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- }, test.srcAddr, test.dstAddr)
- c.CheckNoPacket("Should not have received a response")
-
- // Handle normal packet.
- c.SendV6PacketWithAddrs(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- }, context.TestV6Addr, context.StackV6Addr)
- checker.IPv6(t, c.GetV6Packet(),
- checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
- checker.AckNum(uint32(irs)+1)))
- })
- }
-}
-
-func TestListenSynRcvdQueueFull(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create TCP endpoint.
- var err *tcpip.Error
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- // Bind to wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Test acceptance.
- // Start listening.
- listenBacklog := 1
- if err := c.EP.Listen(listenBacklog); err != nil {
- t.Fatalf("Listen failed: %v", err)
- }
-
- // Send two SYN's the first one should get a SYN-ACK, the
- // second one should not get any response and is dropped as
- // the synRcvd count will be equal to backlog.
- irs := seqnum.Value(789)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetPacket()
- tcp := header.TCP(header.IPv4(b).Payload())
- iss := seqnum.Value(tcp.SequenceNumber())
- tcpCheckers := []checker.TransportChecker{
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
- checker.AckNum(uint32(irs) + 1),
- }
- checker.IPv4(t, b, checker.TCP(tcpCheckers...))
-
- // Now execute send one more SYN. The stack should not respond as the backlog
- // is full at this point.
- //
- // NOTE: we did not complete the handshake for the previous one so the
- // accept backlog should be empty and there should be one connection in
- // synRcvd state.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort + 1,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: seqnum.Value(889),
- RcvWnd: 30000,
- })
- c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond)
-
- // Now complete the previous connection and verify that there is a connection
- // to accept.
- // Send ACK.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs + 1,
- AckNum: iss + 1,
- RcvWnd: 30000,
- })
-
- // Try to accept the connections in the backlog.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.EventIn)
- defer c.WQ.EventUnregister(&we)
-
- newEP, _, err := c.EP.Accept()
- if err == tcpip.ErrWouldBlock {
- // Wait for connection to be established.
- select {
- case <-ch:
- newEP, _, err = c.EP.Accept()
- if err != nil {
- t.Fatalf("Accept failed: %v", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- // Now verify that the TCP socket is usable and in a connected state.
- data := "Don't panic"
- newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{})
- pkt := c.GetPacket()
- tcp = header.TCP(header.IPv4(pkt).Payload())
- if string(tcp.Payload()) != data {
- t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data)
- }
-}
-
-func TestListenBacklogFullSynCookieInUse(t *testing.T) {
- saved := tcp.SynRcvdCountThreshold
- defer func() {
- tcp.SynRcvdCountThreshold = saved
- }()
- tcp.SynRcvdCountThreshold = 1
-
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create TCP endpoint.
- var err *tcpip.Error
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- // Bind to wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
-
- // Test acceptance.
- // Start listening.
- listenBacklog := 1
- portOffset := uint16(0)
- if err := c.EP.Listen(listenBacklog); err != nil {
- t.Fatalf("Listen failed: %v", err)
- }
-
- executeHandshake(t, c, context.TestPort+portOffset, false)
- portOffset++
- // Wait for this to be delivered to the accept queue.
- time.Sleep(50 * time.Millisecond)
-
- // Send a SYN request.
- irs := seqnum.Value(789)
- c.SendPacket(nil, &context.Headers{
- // pick a different src port for new SYN.
- SrcPort: context.TestPort + 1,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- })
- // The Syn should be dropped as the endpoint's backlog is full.
- c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond)
-
- // Verify that there is only one acceptable connection at this point.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.EventIn)
- defer c.WQ.EventUnregister(&we)
-
- _, _, err = c.EP.Accept()
- if err == tcpip.ErrWouldBlock {
- // Wait for connection to be established.
- select {
- case <-ch:
- _, _, err = c.EP.Accept()
- if err != nil {
- t.Fatalf("Accept failed: %v", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- // Now verify that there are no more connections that can be accepted.
- _, _, err = c.EP.Accept()
- if err != tcpip.ErrWouldBlock {
- select {
- case <-ch:
- t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP)
- case <-time.After(1 * time.Second):
- }
- }
-}
-
-func TestSynRcvdBadSeqNumber(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Create TCP endpoint.
- var err *tcpip.Error
- c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
-
- // Bind to wildcard.
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- // Start listening.
- if err := c.EP.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Send a SYN to get a SYN-ACK. This should put the ep into SYN-RCVD state
- irs := seqnum.Value(789)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: irs,
- RcvWnd: 30000,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetPacket()
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- iss := seqnum.Value(tcpHdr.SequenceNumber())
- tcpCheckers := []checker.TransportChecker{
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
- checker.AckNum(uint32(irs) + 1),
- }
- checker.IPv4(t, b, checker.TCP(tcpCheckers...))
-
- // Now send a packet with an out-of-window sequence number
- largeSeqnum := irs + seqnum.Value(tcpHdr.WindowSize()) + 1
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: largeSeqnum,
- AckNum: iss + 1,
- RcvWnd: 30000,
- })
-
- // Should receive an ACK with the expected SEQ number
- b = c.GetPacket()
- tcpCheckers = []checker.TransportChecker{
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.AckNum(uint32(irs) + 1),
- checker.SeqNum(uint32(iss + 1)),
- }
- checker.IPv4(t, b, checker.TCP(tcpCheckers...))
-
- // Now that the socket replied appropriately with the ACK,
- // complete the connection to test that the large SEQ num
- // did not change the state from SYN-RCVD.
-
- // Send ACK to move to ESTABLISHED state.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs + 1,
- AckNum: iss + 1,
- RcvWnd: 30000,
- })
-
- newEP, _, err := c.EP.Accept()
-
- if err != nil && err != tcpip.ErrWouldBlock {
- t.Fatalf("Accept failed: %s", err)
- }
-
- if err == tcpip.ErrWouldBlock {
- // Try to accept the connections in the backlog.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.EventIn)
- defer c.WQ.EventUnregister(&we)
-
- // Wait for connection to be established.
- select {
- case <-ch:
- newEP, _, err = c.EP.Accept()
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- // Now verify that the TCP socket is usable and in a connected state.
- data := "Don't panic"
- _, _, err = newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{})
-
- if err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- pkt := c.GetPacket()
- tcpHdr = header.TCP(header.IPv4(pkt).Payload())
- if string(tcpHdr.Payload()) != data {
- t.Fatalf("Unexpected data: got %s, want %s", string(tcpHdr.Payload()), data)
- }
-}
-
-func TestPassiveConnectionAttemptIncrement(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
- c.EP = ep
- if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
- if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
- }
- if err := c.EP.Listen(1); err != nil {
- t.Fatalf("Listen failed: %v", err)
- }
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateListen; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- stats := c.Stack().Stats()
- want := stats.TCP.PassiveConnectionOpenings.Value() + 1
-
- srcPort := uint16(context.TestPort)
- executeHandshake(t, c, srcPort+1, false)
-
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.EventIn)
- defer c.WQ.EventUnregister(&we)
-
- // Verify that there is only one acceptable connection at this point.
- _, _, err = c.EP.Accept()
- if err == tcpip.ErrWouldBlock {
- // Wait for connection to be established.
- select {
- case <-ch:
- _, _, err = c.EP.Accept()
- if err != nil {
- t.Fatalf("Accept failed: %v", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- if got := stats.TCP.PassiveConnectionOpenings.Value(); got != want {
- t.Errorf("got stats.TCP.PassiveConnectionOpenings.Value() = %v, want = %v", got, want)
- }
-}
-
-func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- stats := c.Stack().Stats()
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
- c.EP = ep
- if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
- if err := c.EP.Listen(1); err != nil {
- t.Fatalf("Listen failed: %v", err)
- }
-
- srcPort := uint16(context.TestPort)
- // Now attempt a handshakes it will fill up the accept backlog.
- executeHandshake(t, c, srcPort, false)
-
- // Give time for the final ACK to be processed as otherwise the next handshake could
- // get accepted before the previous one based on goroutine scheduling.
- time.Sleep(50 * time.Millisecond)
-
- want := stats.TCP.ListenOverflowSynDrop.Value() + 1
-
- // Now we will send one more SYN and this one should get dropped
- // Send a SYN request.
- c.SendPacket(nil, &context.Headers{
- SrcPort: srcPort + 2,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: seqnum.Value(789),
- RcvWnd: 30000,
- })
-
- time.Sleep(50 * time.Millisecond)
- if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want {
- t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %v, want = %v", got, want)
- }
- if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want {
- t.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %v, want = %v", got, want)
- }
-
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.EventIn)
- defer c.WQ.EventUnregister(&we)
-
- // Now check that there is one acceptable connections.
- _, _, err = c.EP.Accept()
- if err == tcpip.ErrWouldBlock {
- // Wait for connection to be established.
- select {
- case <-ch:
- _, _, err = c.EP.Accept()
- if err != nil {
- t.Fatalf("Accept failed: %v", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-}
-
-func TestEndpointBindListenAcceptState(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
- wq := &waiter.Queue{}
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
- }
- if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- if _, _, err := ep.Read(nil); err != tcpip.ErrNotConnected {
- t.Errorf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrNotConnected)
- }
- if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 {
- t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %v want %v", got, 1)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
- }
- if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS})
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.EventIn)
- defer wq.EventUnregister(&we)
-
- aep, _, err := ep.Accept()
- if err == tcpip.ErrWouldBlock {
- // Wait for connection to be established.
- select {
- case <-ch:
- aep, _, err = ep.Accept()
- if err != nil {
- t.Fatalf("Accept failed: %v", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
- if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
- }
- if err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAlreadyConnected {
- t.Errorf("Unexpected error attempting to call connect on an established endpoint, got: %v, want: %v", err, tcpip.ErrAlreadyConnected)
- }
- // Listening endpoint remains in listen state.
- if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- ep.Close()
- // Give worker goroutines time to receive the close notification.
- time.Sleep(1 * time.Second)
- if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
- }
- // Accepted endpoint remains open when the listen endpoint is closed.
- if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
-}
-
-// This test verifies that the auto tuning does not grow the receive buffer if
-// the application is not reading the data actively.
-func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
- const mtu = 1500
- const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize
-
- c := context.New(t, mtu)
- defer c.Cleanup()
-
- stk := c.Stack()
- // Set lower limits for auto-tuning tests. This is required because the
- // test stops the worker which can cause packets to be dropped because
- // the segment queue holding unprocessed packets is limited to 500.
- const receiveBufferSize = 80 << 10 // 80KB.
- const maxReceiveBufferSize = receiveBufferSize * 10
- if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, receiveBufferSize, maxReceiveBufferSize}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
- }
-
- // Enable auto-tuning.
- if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
- }
- // Change the expected window scale to match the value needed for the
- // maximum buffer size defined above.
- c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize))
-
- rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4})
-
- // NOTE: The timestamp values in the sent packets are meaningless to the
- // peer so we just increment the timestamp value by 1 every batch as we
- // are not really using them for anything. Send a single byte to verify
- // the advertised window.
- tsVal := rawEP.TSVal + 1
-
- // Introduce a 25ms latency by delaying the first byte.
- latency := 25 * time.Millisecond
- time.Sleep(latency)
- rawEP.SendPacketWithTS([]byte{1}, tsVal)
-
- // Verify that the ACK has the expected window.
- wantRcvWnd := receiveBufferSize
- wantRcvWnd = (wantRcvWnd >> uint32(c.WindowScale))
- rawEP.VerifyACKRcvWnd(uint16(wantRcvWnd - 1))
- time.Sleep(25 * time.Millisecond)
-
- // Allocate a large enough payload for the test.
- b := make([]byte, int(receiveBufferSize)*2)
- offset := 0
- payloadSize := receiveBufferSize - 1
- worker := (c.EP).(interface {
- StopWork()
- ResumeWork()
- })
- tsVal++
-
- // Stop the worker goroutine.
- worker.StopWork()
- start := offset
- end := offset + payloadSize
- packetsSent := 0
- for ; start < end; start += mss {
- rawEP.SendPacketWithTS(b[start:start+mss], tsVal)
- packetsSent++
- }
-
- // Resume the worker so that it only sees the packets once all of them
- // are waiting to be read.
- worker.ResumeWork()
-
- // Since we read no bytes the window should goto zero till the
- // application reads some of the data.
- // Discard all intermediate acks except the last one.
- if packetsSent > 100 {
- for i := 0; i < (packetsSent / 100); i++ {
- _ = c.GetPacket()
- }
- }
- rawEP.VerifyACKRcvWnd(0)
-
- time.Sleep(25 * time.Millisecond)
- // Verify that sending more data when window is closed is dropped and
- // not acked.
- rawEP.SendPacketWithTS(b[start:start+mss], tsVal)
-
- // Verify that the stack sends us back an ACK with the sequence number
- // of the last packet sent indicating it was dropped.
- p := c.GetPacket()
- checker.IPv4(t, p, checker.TCP(
- checker.AckNum(uint32(rawEP.NextSeqNum)-uint32(mss)),
- checker.Window(0),
- ))
-
- // Now read all the data from the endpoint and verify that advertised
- // window increases to the full available buffer size.
- for {
- _, _, err := c.EP.Read(nil)
- if err == tcpip.ErrWouldBlock {
- break
- }
- }
-
- // Verify that we receive a non-zero window update ACK. When running
- // under thread santizer this test can end up sending more than 1
- // ack, 1 for the non-zero window
- p = c.GetPacket()
- checker.IPv4(t, p, checker.TCP(
- checker.AckNum(uint32(rawEP.NextSeqNum)-uint32(mss)),
- func(t *testing.T, h header.Transport) {
- tcp, ok := h.(header.TCP)
- if !ok {
- return
- }
- if w := tcp.WindowSize(); w == 0 || w > uint16(wantRcvWnd) {
- t.Errorf("expected a non-zero window: got %d, want <= wantRcvWnd", w, wantRcvWnd)
- }
- },
- ))
-}
-
-// This test verifies that the auto tuning does not grow the receive buffer if
-// the application is not reading the data actively.
-func TestReceiveBufferAutoTuning(t *testing.T) {
- const mtu = 1500
- const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize
-
- c := context.New(t, mtu)
- defer c.Cleanup()
-
- // Enable Auto-tuning.
- stk := c.Stack()
- // Set lower limits for auto-tuning tests. This is required because the
- // test stops the worker which can cause packets to be dropped because
- // the segment queue holding unprocessed packets is limited to 300.
- const receiveBufferSize = 80 << 10 // 80KB.
- const maxReceiveBufferSize = receiveBufferSize * 10
- if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, receiveBufferSize, maxReceiveBufferSize}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
- }
-
- // Enable auto-tuning.
- if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
- }
- // Change the expected window scale to match the value needed for the
- // maximum buffer size used by stack.
- c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize))
-
- rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4})
-
- wantRcvWnd := receiveBufferSize
- scaleRcvWnd := func(rcvWnd int) uint16 {
- return uint16(rcvWnd >> uint16(c.WindowScale))
- }
- // Allocate a large array to send to the endpoint.
- b := make([]byte, receiveBufferSize*48)
-
- // In every iteration we will send double the number of bytes sent in
- // the previous iteration and read the same from the app. The received
- // window should grow by at least 2x of bytes read by the app in every
- // RTT.
- offset := 0
- payloadSize := receiveBufferSize / 8
- worker := (c.EP).(interface {
- StopWork()
- ResumeWork()
- })
- tsVal := rawEP.TSVal
- // We are going to do our own computation of what the moderated receive
- // buffer should be based on sent/copied data per RTT and verify that
- // the advertised window by the stack matches our calculations.
- prevCopied := 0
- done := false
- latency := 1 * time.Millisecond
- for i := 0; !done; i++ {
- tsVal++
-
- // Stop the worker goroutine.
- worker.StopWork()
- start := offset
- end := offset + payloadSize
- totalSent := 0
- packetsSent := 0
- for ; start < end; start += mss {
- rawEP.SendPacketWithTS(b[start:start+mss], tsVal)
- totalSent += mss
- packetsSent++
- }
-
- // Resume it so that it only sees the packets once all of them
- // are waiting to be read.
- worker.ResumeWork()
-
- // Give 1ms for the worker to process the packets.
- time.Sleep(1 * time.Millisecond)
-
- // Verify that the advertised window on the ACK is reduced by
- // the total bytes sent.
- expectedWnd := wantRcvWnd - totalSent
- if packetsSent > 100 {
- for i := 0; i < (packetsSent / 100); i++ {
- _ = c.GetPacket()
- }
- }
- rawEP.VerifyACKRcvWnd(scaleRcvWnd(expectedWnd))
-
- // Now read all the data from the endpoint and invoke the
- // moderation API to allow for receive buffer auto-tuning
- // to happen before we measure the new window.
- totalCopied := 0
- for {
- b, _, err := c.EP.Read(nil)
- if err == tcpip.ErrWouldBlock {
- break
- }
- totalCopied += len(b)
- }
-
- // Invoke the moderation API. This is required for auto-tuning
- // to happen. This method is normally expected to be invoked
- // from a higher layer than tcpip.Endpoint. So we simulate
- // copying to user-space by invoking it explicitly here.
- c.EP.ModerateRecvBuf(totalCopied)
-
- // Now send a keep-alive packet to trigger an ACK so that we can
- // measure the new window.
- rawEP.NextSeqNum--
- rawEP.SendPacketWithTS(nil, tsVal)
- rawEP.NextSeqNum++
-
- if i == 0 {
- // In the first iteration the receiver based RTT is not
- // yet known as a result the moderation code should not
- // increase the advertised window.
- rawEP.VerifyACKRcvWnd(scaleRcvWnd(wantRcvWnd))
- prevCopied = totalCopied
- } else {
- rttCopied := totalCopied
- if i == 1 {
- // The moderation code accumulates copied bytes till
- // RTT is established. So add in the bytes sent in
- // the first iteration to the total bytes for this
- // RTT.
- rttCopied += prevCopied
- // Now reset it to the initial value used by the
- // auto tuning logic.
- prevCopied = tcp.InitialCwnd * mss * 2
- }
- newWnd := rttCopied<<1 + 16*mss
- grow := (newWnd * (rttCopied - prevCopied)) / prevCopied
- newWnd += (grow << 1)
- if newWnd > maxReceiveBufferSize {
- newWnd = maxReceiveBufferSize
- done = true
- }
- rawEP.VerifyACKRcvWnd(scaleRcvWnd(newWnd))
- wantRcvWnd = newWnd
- prevCopied = rttCopied
- // Increase the latency after first two iterations to
- // establish a low RTT value in the receiver since it
- // only tracks the lowest value. This ensures that when
- // ModerateRcvBuf is called the elapsed time is always >
- // rtt. Without this the test is flaky due to delays due
- // to scheduling/wakeup etc.
- latency += 50 * time.Millisecond
- }
- time.Sleep(latency)
- offset += payloadSize
- payloadSize *= 2
- }
-}
-
-func TestDelayEnabled(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
- checkDelayOption(t, c, false, 0) // Delay is disabled by default.
-
- for _, v := range []struct {
- delayEnabled tcp.DelayEnabled
- wantDelayOption int
- }{
- {delayEnabled: false, wantDelayOption: 0},
- {delayEnabled: true, wantDelayOption: 1},
- } {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, v.delayEnabled); err != nil {
- t.Fatalf("SetTransportProtocolOption(tcp, %t) failed: %v", v.delayEnabled, err)
- }
- checkDelayOption(t, c, v.delayEnabled, v.wantDelayOption)
- }
-}
-
-func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption int) {
- t.Helper()
-
- var gotDelayEnabled tcp.DelayEnabled
- if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil {
- t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %v", err)
- }
- if gotDelayEnabled != wantDelayEnabled {
- t.Errorf("TransportProtocolOption(tcp, &gotDelayEnabled) got %t, want %t", gotDelayEnabled, wantDelayEnabled)
- }
-
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, new(waiter.Queue))
- if err != nil {
- t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %v", err)
- }
- gotDelayOption, err := ep.GetSockOptInt(tcpip.DelayOption)
- if err != nil {
- t.Fatalf("ep.GetSockOptInt(tcpip.DelayOption) failed: %v", err)
- }
- if gotDelayOption != wantDelayOption {
- t.Errorf("ep.GetSockOptInt(tcpip.DelayOption) got: %d, want: %d", gotDelayOption, wantDelayOption)
- }
-}
-
-func TestTCPLingerTimeout(t *testing.T) {
- c := context.New(t, 1500 /* mtu */)
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- testCases := []struct {
- name string
- tcpLingerTimeout time.Duration
- want time.Duration
- }{
- {"NegativeLingerTimeout", -123123, 0},
- {"ZeroLingerTimeout", 0, 0},
- {"InRangeLingerTimeout", 10 * time.Second, 10 * time.Second},
- // Values > stack's TCPLingerTimeout are capped to the stack's
- // value. Defaults to tcp.DefaultTCPLingerTimeout(60 seconds)
- {"AboveMaxLingerTimeout", 65 * time.Second, 60 * time.Second},
- }
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- if err := c.EP.SetSockOpt(tcpip.TCPLingerTimeoutOption(tc.tcpLingerTimeout)); err != nil {
- t.Fatalf("SetSockOpt(%s) = %s", tc.tcpLingerTimeout, err)
- }
- var v tcpip.TCPLingerTimeoutOption
- if err := c.EP.GetSockOpt(&v); err != nil {
- t.Fatalf("GetSockOpt(tcpip.TCPLingerTimeoutOption) = %s", err)
- }
- if got, want := time.Duration(v), tc.want; got != want {
- t.Fatalf("unexpected linger timeout got: %s, want: %s", got, want)
- }
- })
- }
-}
-
-func TestTCPTimeWaitRSTIgnored(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- wq := &waiter.Queue{}
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Send a SYN request.
- iss := seqnum.Value(789)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: iss,
- RcvWnd: 30000,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetPacket()
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
-
- ackHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 1,
- AckNum: c.IRS + 1,
- }
-
- // Send ACK.
- c.SendPacket(nil, ackHeaders)
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.EventIn)
- defer wq.EventUnregister(&we)
-
- c.EP, _, err = ep.Accept()
- if err == tcpip.ErrWouldBlock {
- // Wait for connection to be established.
- select {
- case <-ch:
- c.EP, _, err = ep.Accept()
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- c.EP.Close()
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(iss)+1),
- checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
-
- finHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: iss + 1,
- AckNum: c.IRS + 2,
- }
-
- c.SendPacket(nil, finHeaders)
-
- // Get the ACK to the FIN we just sent.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+2)),
- checker.AckNum(uint32(iss)+2),
- checker.TCPFlags(header.TCPFlagAck)))
-
- // Now send a RST and this should be ignored and not
- // generate an ACK.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagRst,
- SeqNum: iss + 1,
- AckNum: c.IRS + 2,
- })
-
- c.CheckNoPacketTimeout("unexpected packet received in TIME_WAIT state", 1*time.Second)
-
- // Out of order ACK should generate an immediate ACK in
- // TIME_WAIT.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 1,
- AckNum: c.IRS + 3,
- })
-
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+2)),
- checker.AckNum(uint32(iss)+2),
- checker.TCPFlags(header.TCPFlagAck)))
-}
-
-func TestTCPTimeWaitOutOfOrder(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- wq := &waiter.Queue{}
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Send a SYN request.
- iss := seqnum.Value(789)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: iss,
- RcvWnd: 30000,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetPacket()
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
-
- ackHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 1,
- AckNum: c.IRS + 1,
- }
-
- // Send ACK.
- c.SendPacket(nil, ackHeaders)
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.EventIn)
- defer wq.EventUnregister(&we)
-
- c.EP, _, err = ep.Accept()
- if err == tcpip.ErrWouldBlock {
- // Wait for connection to be established.
- select {
- case <-ch:
- c.EP, _, err = ep.Accept()
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- c.EP.Close()
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(iss)+1),
- checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
-
- finHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: iss + 1,
- AckNum: c.IRS + 2,
- }
-
- c.SendPacket(nil, finHeaders)
-
- // Get the ACK to the FIN we just sent.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+2)),
- checker.AckNum(uint32(iss)+2),
- checker.TCPFlags(header.TCPFlagAck)))
-
- // Out of order ACK should generate an immediate ACK in
- // TIME_WAIT.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 1,
- AckNum: c.IRS + 3,
- })
-
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+2)),
- checker.AckNum(uint32(iss)+2),
- checker.TCPFlags(header.TCPFlagAck)))
-}
-
-func TestTCPTimeWaitNewSyn(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- wq := &waiter.Queue{}
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Send a SYN request.
- iss := seqnum.Value(789)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: iss,
- RcvWnd: 30000,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetPacket()
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
-
- ackHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 1,
- AckNum: c.IRS + 1,
- }
-
- // Send ACK.
- c.SendPacket(nil, ackHeaders)
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.EventIn)
- defer wq.EventUnregister(&we)
-
- c.EP, _, err = ep.Accept()
- if err == tcpip.ErrWouldBlock {
- // Wait for connection to be established.
- select {
- case <-ch:
- c.EP, _, err = ep.Accept()
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- c.EP.Close()
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(iss)+1),
- checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
-
- finHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: iss + 1,
- AckNum: c.IRS + 2,
- }
-
- c.SendPacket(nil, finHeaders)
-
- // Get the ACK to the FIN we just sent.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+2)),
- checker.AckNum(uint32(iss)+2),
- checker.TCPFlags(header.TCPFlagAck)))
-
- // Send a SYN request w/ sequence number lower than
- // the highest sequence number sent. We just reuse
- // the same number.
- iss = seqnum.Value(789)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: iss,
- RcvWnd: 30000,
- })
-
- c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second)
-
- // Send a SYN request w/ sequence number higher than
- // the highest sequence number sent.
- iss = seqnum.Value(792)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: iss,
- RcvWnd: 30000,
- })
-
- // Receive the SYN-ACK reply.
- b = c.GetPacket()
- tcpHdr = header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
-
- ackHeaders = &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 1,
- AckNum: c.IRS + 1,
- }
-
- // Send ACK.
- c.SendPacket(nil, ackHeaders)
-
- // Try to accept the connection.
- c.EP, _, err = ep.Accept()
- if err == tcpip.ErrWouldBlock {
- // Wait for connection to be established.
- select {
- case <-ch:
- c.EP, _, err = ep.Accept()
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-}
-
-func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed
- // after 5 seconds in TIME_WAIT state.
- tcpTimeWaitTimeout := 5 * time.Second
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
- t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err)
- }
-
- want := c.Stack().Stats().TCP.EstablishedClosed.Value() + 1
-
- wq := &waiter.Queue{}
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Send a SYN request.
- iss := seqnum.Value(789)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: iss,
- RcvWnd: 30000,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetPacket()
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
-
- ackHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 1,
- AckNum: c.IRS + 1,
- }
-
- // Send ACK.
- c.SendPacket(nil, ackHeaders)
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.EventIn)
- defer wq.EventUnregister(&we)
-
- c.EP, _, err = ep.Accept()
- if err == tcpip.ErrWouldBlock {
- // Wait for connection to be established.
- select {
- case <-ch:
- c.EP, _, err = ep.Accept()
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- c.EP.Close()
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(iss)+1),
- checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
-
- finHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: iss + 1,
- AckNum: c.IRS + 2,
- }
-
- c.SendPacket(nil, finHeaders)
-
- // Get the ACK to the FIN we just sent.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+2)),
- checker.AckNum(uint32(iss)+2),
- checker.TCPFlags(header.TCPFlagAck)))
-
- time.Sleep(2 * time.Second)
-
- // Now send a duplicate FIN. This should cause the TIME_WAIT to extend
- // by another 5 seconds and also send us a duplicate ACK as it should
- // indicate that the final ACK was potentially lost.
- c.SendPacket(nil, finHeaders)
-
- // Get the ACK to the FIN we just sent.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+2)),
- checker.AckNum(uint32(iss)+2),
- checker.TCPFlags(header.TCPFlagAck)))
-
- // Sleep for 4 seconds so at this point we are 1 second past the
- // original tcpLingerTimeout of 5 seconds.
- time.Sleep(4 * time.Second)
-
- // Send an ACK and it should not generate any packet as the socket
- // should still be in TIME_WAIT for another another 5 seconds due
- // to the duplicate FIN we sent earlier.
- *ackHeaders = *finHeaders
- ackHeaders.SeqNum = ackHeaders.SeqNum + 1
- ackHeaders.Flags = header.TCPFlagAck
- c.SendPacket(nil, ackHeaders)
-
- c.CheckNoPacketTimeout("unexpected packet received from endpoint in TIME_WAIT", 1*time.Second)
- // Now sleep for another 2 seconds so that we are past the
- // extended TIME_WAIT of 7 seconds (2 + 5).
- time.Sleep(2 * time.Second)
-
- // Resend the same ACK.
- c.SendPacket(nil, ackHeaders)
-
- // Receive the RST that should be generated as there is no valid
- // endpoint.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(ackHeaders.AckNum)),
- checker.AckNum(0),
- checker.TCPFlags(header.TCPFlagRst)))
-
- if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = %v", got, want)
- }
- if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
- }
-}
-
-func TestTCPCloseWithData(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed
- // after 5 seconds in TIME_WAIT state.
- tcpTimeWaitTimeout := 5 * time.Second
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
- t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err)
- }
-
- wq := &waiter.Queue{}
- ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %s", err)
- }
- if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %s", err)
- }
-
- // Send a SYN request.
- iss := seqnum.Value(789)
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: iss,
- RcvWnd: 30000,
- })
-
- // Receive the SYN-ACK reply.
- b := c.GetPacket()
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
-
- ackHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 1,
- AckNum: c.IRS + 1,
- RcvWnd: 30000,
- }
-
- // Send ACK.
- c.SendPacket(nil, ackHeaders)
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.EventIn)
- defer wq.EventUnregister(&we)
-
- c.EP, _, err = ep.Accept()
- if err == tcpip.ErrWouldBlock {
- // Wait for connection to be established.
- select {
- case <-ch:
- c.EP, _, err = ep.Accept()
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- // Now trigger a passive close by sending a FIN.
- finHeaders := &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- SeqNum: iss + 1,
- AckNum: c.IRS + 2,
- RcvWnd: 30000,
- }
-
- c.SendPacket(nil, finHeaders)
-
- // Get the ACK to the FIN we just sent.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(iss)+2),
- checker.TCPFlags(header.TCPFlagAck)))
-
- // Now write a few bytes and then close the endpoint.
- data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
-
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %s", err)
- }
-
- // Check that data is received.
- b = c.GetPacket()
- checker.IPv4(t, b,
- checker.PayloadLen(len(data)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(iss)+2), // Acknum is initial sequence number + 1
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
- ),
- )
-
- if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
- t.Errorf("got data = %x, want = %x", p, data)
- }
-
- c.EP.Close()
- // Check the FIN.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)+uint32(len(data))),
- checker.AckNum(uint32(iss+2)),
- checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
-
- // First send a partial ACK.
- ackHeaders = &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 2,
- AckNum: c.IRS + 1 + seqnum.Value(len(data)-1),
- RcvWnd: 30000,
- }
- c.SendPacket(nil, ackHeaders)
-
- // Now send a full ACK.
- ackHeaders = &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 2,
- AckNum: c.IRS + 1 + seqnum.Value(len(data)),
- RcvWnd: 30000,
- }
- c.SendPacket(nil, ackHeaders)
-
- // Now ACK the FIN.
- ackHeaders.AckNum++
- c.SendPacket(nil, ackHeaders)
-
- // Now send an ACK and we should get a RST back as the endpoint should
- // be in CLOSED state.
- ackHeaders = &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 2,
- AckNum: c.IRS + 1 + seqnum.Value(len(data)),
- RcvWnd: 30000,
- }
- c.SendPacket(nil, ackHeaders)
-
- // Check the RST.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(ackHeaders.AckNum)),
- checker.AckNum(0),
- checker.TCPFlags(header.TCPFlagRst)))
-}
-
-func TestTCPUserTimeout(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value()
-
- userTimeout := 50 * time.Millisecond
- c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout))
-
- // Send some data and wait before ACKing it.
- view := buffer.NewView(3)
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
- }
-
- next := uint32(c.IRS) + 1
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(len(view)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(next),
- checker.AckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
- ),
- )
-
- // Wait for a little over the minimum retransmit timeout of 200ms for
- // the retransmitTimer to fire and close the connection.
- time.Sleep(tcp.MinRTO + 10*time.Millisecond)
-
- // No packet should be received as the connection should be silently
- // closed due to timeout.
- c.CheckNoPacket("unexpected packet received after userTimeout has expired")
-
- next += uint32(len(view))
-
- // The connection should be terminated after userTimeout has expired.
- // Send an ACK to trigger a RST from the stack as the endpoint should
- // be dead.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: 790,
- AckNum: seqnum.Value(next),
- RcvWnd: 30000,
- })
-
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(next)),
- checker.AckNum(uint32(0)),
- checker.TCPFlags(header.TCPFlagRst),
- ),
- )
-
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout)
- }
-
- if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want)
- }
-}
-
-func TestKeepaliveWithUserTimeout(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
-
- origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value()
-
- const keepAliveInterval = 10 * time.Millisecond
- c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond))
- c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval))
- c.EP.SetSockOpt(tcpip.KeepaliveCountOption(10))
- c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1))
-
- // Set userTimeout to be the duration for 3 keepalive probes.
- userTimeout := 30 * time.Millisecond
- c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout))
-
- // Check that the connection is still alive.
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
- }
-
- // Now receive 2 keepalives, but don't ACK them. The connection should
- // be reset when the 3rd one should be sent due to userTimeout being
- // 30ms and each keepalive probe should be sent 10ms apart as set above after
- // the connection has been idle for 10ms.
- for i := 0; i < 2; i++ {
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)),
- checker.AckNum(uint32(790)),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
- }
-
- // Sleep for a litte over the KeepAlive interval to make sure
- // the timer has time to fire after the last ACK and close the
- // close the socket.
- time.Sleep(keepAliveInterval + 5*time.Millisecond)
-
- // The connection should be terminated after 30ms.
- // Send an ACK to trigger a RST from the stack as the endpoint should
- // be dead.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: 790,
- AckNum: seqnum.Value(c.IRS + 1),
- RcvWnd: 30000,
- })
-
- checker.IPv4(t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(0)),
- checker.TCPFlags(header.TCPFlagRst),
- ),
- )
-
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout)
- }
- if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
- t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want)
- }
-}
-
-func TestIncreaseWindowOnReceive(t *testing.T) {
- // This test ensures that the endpoint sends an ack,
- // after recv() when the window grows to more than 1 MSS.
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- const rcvBuf = 65535 * 10
- c.CreateConnected(789, 30000, rcvBuf)
-
- // Write chunks of ~30000 bytes. It's important that two
- // payloads make it equal or longer than MSS.
- remain := rcvBuf
- sent := 0
- data := make([]byte, defaultMTU/2)
- lastWnd := uint16(0)
-
- for remain > len(data) {
- c.SendPacket(data, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: seqnum.Value(790 + sent),
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
- sent += len(data)
- remain -= len(data)
-
- lastWnd = uint16(remain)
- if remain > 0xffff {
- lastWnd = 0xffff
- }
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+sent)),
- checker.Window(lastWnd),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
- }
-
- if lastWnd == 0xffff || lastWnd == 0 {
- t.Fatalf("expected small, non-zero window: %d", lastWnd)
- }
-
- // We now have < 1 MSS in the buffer space. Read the data! An
- // ack should be sent in response to that. The window was not
- // zero, but it grew to larger than MSS.
- if _, _, err := c.EP.Read(nil); err != nil {
- t.Fatalf("Read failed: %v", err)
- }
-
- if _, _, err := c.EP.Read(nil); err != nil {
- t.Fatalf("Read failed: %v", err)
- }
-
- // After reading two packets, we surely crossed MSS. See the ack:
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+sent)),
- checker.Window(uint16(0xffff)),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-}
-
-func TestIncreaseWindowOnBufferResize(t *testing.T) {
- // This test ensures that the endpoint sends an ack,
- // after available recv buffer grows to more than 1 MSS.
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- const rcvBuf = 65535 * 10
- c.CreateConnected(789, 30000, rcvBuf)
-
- // Write chunks of ~30000 bytes. It's important that two
- // payloads make it equal or longer than MSS.
- remain := rcvBuf
- sent := 0
- data := make([]byte, defaultMTU/2)
- lastWnd := uint16(0)
-
- for remain > len(data) {
- c.SendPacket(data, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: seqnum.Value(790 + sent),
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- })
- sent += len(data)
- remain -= len(data)
-
- lastWnd = uint16(remain)
- if remain > 0xffff {
- lastWnd = 0xffff
- }
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+sent)),
- checker.Window(lastWnd),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
- }
-
- if lastWnd == 0xffff || lastWnd == 0 {
- t.Fatalf("expected small, non-zero window: %d", lastWnd)
- }
-
- // Increasing the buffer from should generate an ACK,
- // since window grew from small value to larger equal MSS
- c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBuf*2)
-
- // After reading two packets, we surely crossed MSS. See the ack:
- checker.IPv4(t, c.GetPacket(),
- checker.PayloadLen(header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(790+sent)),
- checker.Window(uint16(0xffff)),
- checker.TCPFlags(header.TCPFlagAck),
- ),
- )
-}
-
-func TestTCPDeferAccept(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.Create(-1)
-
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatal("Bind failed:", err)
- }
-
- if err := c.EP.Listen(10); err != nil {
- t.Fatal("Listen failed:", err)
- }
-
- const tcpDeferAccept = 1 * time.Second
- if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil {
- t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %v", tcpDeferAccept, err)
- }
-
- irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
-
- if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock {
- t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: %s", err, tcpip.ErrWouldBlock)
- }
-
- // Send data. This should result in an acceptable endpoint.
- c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs + 1,
- AckNum: iss + 1,
- })
-
- // Receive ACK for the data we sent.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(iss+1)),
- checker.AckNum(uint32(irs+5))))
-
- // Give a bit of time for the socket to be delivered to the accept queue.
- time.Sleep(50 * time.Millisecond)
- aep, _, err := c.EP.Accept()
- if err != nil {
- t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: nil", err)
- }
-
- aep.Close()
- // Closing aep without reading the data should trigger a RST.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
- checker.SeqNum(uint32(iss+1)),
- checker.AckNum(uint32(irs+5))))
-}
-
-func TestTCPDeferAcceptTimeout(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.Create(-1)
-
- if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatal("Bind failed:", err)
- }
-
- if err := c.EP.Listen(10); err != nil {
- t.Fatal("Listen failed:", err)
- }
-
- const tcpDeferAccept = 1 * time.Second
- if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil {
- t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %v", tcpDeferAccept, err)
- }
-
- irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
-
- if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock {
- t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: %s", err, tcpip.ErrWouldBlock)
- }
-
- // Sleep for a little of the tcpDeferAccept timeout.
- time.Sleep(tcpDeferAccept + 100*time.Millisecond)
-
- // On timeout expiry we should get a SYN-ACK retransmission.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
- checker.AckNum(uint32(irs)+1)))
-
- // Send data. This should result in an acceptable endpoint.
- c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs + 1,
- AckNum: iss + 1,
- })
-
- // Receive ACK for the data we sent.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(iss+1)),
- checker.AckNum(uint32(irs+5))))
-
- // Give sometime for the endpoint to be delivered to the accept queue.
- time.Sleep(50 * time.Millisecond)
- aep, _, err := c.EP.Accept()
- if err != nil {
- t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: nil", err)
- }
-
- aep.Close()
- // Closing aep without reading the data should trigger a RST.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.SrcPort(context.StackPort),
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
- checker.SeqNum(uint32(iss+1)),
- checker.AckNum(uint32(irs+5))))
-}
-
-func TestResetDuringClose(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- iss := seqnum.Value(789)
- c.CreateConnected(iss, 30000, -1 /* epRecvBuf */)
- // Send some data to make sure there is some unread
- // data to trigger a reset on c.Close.
- irs := c.IRS
- c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: iss.Add(1),
- AckNum: irs.Add(1),
- RcvWnd: 30000,
- })
-
- // Receive ACK for the data we sent.
- checker.IPv4(t, c.GetPacket(), checker.TCP(
- checker.DstPort(context.TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(irs.Add(1))),
- checker.AckNum(uint32(iss.Add(5)))))
-
- // Close in a separate goroutine so that we can trigger
- // a race with the RST we send below. This should not
- // panic due to the route being released depeding on
- // whether Close() sends an active RST or the RST sent
- // below is processed by the worker first.
- var wg sync.WaitGroup
-
- wg.Add(1)
- go func() {
- defer wg.Done()
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
- DstPort: c.Port,
- SeqNum: iss.Add(5),
- AckNum: c.IRS.Add(5),
- RcvWnd: 30000,
- Flags: header.TCPFlagRst,
- })
- }()
-
- wg.Add(1)
- go func() {
- defer wg.Done()
- c.EP.Close()
- }()
-
- wg.Wait()
-}
diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
deleted file mode 100644
index a641e953d..000000000
--- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
+++ /dev/null
@@ -1,295 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tcp_test
-
-import (
- "bytes"
- "math/rand"
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-// createConnectedWithTimestampOption creates and connects c.ep with the
-// timestamp option enabled.
-func createConnectedWithTimestampOption(c *context.Context) *context.RawEndpoint {
- return c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, TSVal: 1})
-}
-
-// TestTimeStampEnabledConnect tests that netstack sends the timestamp option on
-// an active connect and sets the TS Echo Reply fields correctly when the
-// SYN-ACK also indicates support for the TS option and provides a TSVal.
-func TestTimeStampEnabledConnect(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- rep := createConnectedWithTimestampOption(c)
-
- // Register for read and validate that we have data to read.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.EventIn)
- defer c.WQ.EventUnregister(&we)
-
- // The following tests ensure that TS option once enabled behaves
- // correctly as described in
- // https://tools.ietf.org/html/rfc7323#section-4.3.
- //
- // We are not testing delayed ACKs here, but we do test out of order
- // packet delivery and filling the sequence number hole created due to
- // the out of order packet.
- //
- // The test also verifies that the sequence numbers and timestamps are
- // as expected.
- data := []byte{1, 2, 3}
-
- // First we increment tsVal by a small amount.
- tsVal := rep.TSVal + 100
- rep.SendPacketWithTS(data, tsVal)
- rep.VerifyACKWithTS(tsVal)
-
- // Next we send an out of order packet.
- rep.NextSeqNum += 3
- tsVal += 200
- rep.SendPacketWithTS(data, tsVal)
-
- // The ACK should contain the original sequenceNumber and an older TS.
- rep.NextSeqNum -= 6
- rep.VerifyACKWithTS(tsVal - 200)
-
- // Next we fill the hole and the returned ACK should contain the
- // cumulative sequence number acking all data sent till now and have the
- // latest timestamp sent below in its TSEcr field.
- tsVal -= 100
- rep.SendPacketWithTS(data, tsVal)
- rep.NextSeqNum += 3
- rep.VerifyACKWithTS(tsVal)
-
- // Increment tsVal by a large value that doesn't result in a wrap around.
- tsVal += 0x7fffffff
- rep.SendPacketWithTS(data, tsVal)
- rep.VerifyACKWithTS(tsVal)
-
- // Increment tsVal again by a large value which should cause the
- // timestamp value to wrap around. The returned ACK should contain the
- // wrapped around timestamp in its tsEcr field and not the tsVal from
- // the previous packet sent above.
- tsVal += 0x7fffffff
- rep.SendPacketWithTS(data, tsVal)
- rep.VerifyACKWithTS(tsVal)
-
- select {
- case <-ch:
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for data to arrive")
- }
-
- // There should be 5 views to read and each of them should
- // contain the same data.
- for i := 0; i < 5; i++ {
- got, _, err := c.EP.Read(nil)
- if err != nil {
- t.Fatalf("Unexpected error from Read: %v", err)
- }
- if want := data; bytes.Compare(got, want) != 0 {
- t.Fatalf("Data is different: got: %v, want: %v", got, want)
- }
- }
-}
-
-// TestTimeStampDisabledConnect tests that netstack sends timestamp option on an
-// active connect but if the SYN-ACK doesn't specify the TS option then
-// timestamp option is not enabled and future packets do not contain a
-// timestamp.
-func TestTimeStampDisabledConnect(t *testing.T) {
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- c.CreateConnectedWithOptions(header.TCPSynOptions{})
-}
-
-func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) {
- savedSynCountThreshold := tcp.SynRcvdCountThreshold
- defer func() {
- tcp.SynRcvdCountThreshold = savedSynCountThreshold
- }()
-
- if cookieEnabled {
- tcp.SynRcvdCountThreshold = 0
- }
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- t.Logf("Test w/ CookieEnabled = %v", cookieEnabled)
- tsVal := rand.Uint32()
- c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, TS: true, TSVal: tsVal})
-
- // Now send some data and validate that timestamp is echoed correctly in the ACK.
- data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
-
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Unexpected error from Write: %v", err)
- }
-
- // Check that data is received and that the timestamp option TSEcr field
- // matches the expected value.
- b := c.GetPacket()
- checker.IPv4(t, b,
- // Add 12 bytes for the timestamp option + 2 NOPs to align at 4
- // byte boundary.
- checker.PayloadLen(len(data)+header.TCPMinimumSize+12),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.Window(wndSize),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
- checker.TCPTimestampChecker(true, 0, tsVal+1),
- ),
- )
-}
-
-// TestTimeStampEnabledAccept tests that if the SYN on a passive connect
-// specifies the Timestamp option then the Timestamp option is sent on a SYN-ACK
-// and echoes the tsVal field of the original SYN in the tcEcr field of the
-// SYN-ACK. We cover the cases where SYN cookies are enabled/disabled and verify
-// that Timestamp option is enabled in both cases if requested in the original
-// SYN.
-func TestTimeStampEnabledAccept(t *testing.T) {
- testCases := []struct {
- cookieEnabled bool
- wndScale int
- wndSize uint16
- }{
- {true, -1, 0xffff}, // When cookie is used window scaling is disabled.
- {false, 5, 0x8000}, // DefaultReceiveBufferSize is 1MB >> 5.
- }
- for _, tc := range testCases {
- timeStampEnabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize)
- }
-}
-
-func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) {
- savedSynCountThreshold := tcp.SynRcvdCountThreshold
- defer func() {
- tcp.SynRcvdCountThreshold = savedSynCountThreshold
- }()
- if cookieEnabled {
- tcp.SynRcvdCountThreshold = 0
- }
-
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
-
- t.Logf("Test w/ CookieEnabled = %v", cookieEnabled)
- c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
-
- // Now send some data with the accepted connection endpoint and validate
- // that no timestamp option is sent in the TCP segment.
- data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
-
- if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Unexpected error from Write: %v", err)
- }
-
- // Check that data is received and that the timestamp option is disabled
- // when SYN cookies are enabled/disabled.
- b := c.GetPacket()
- checker.IPv4(t, b,
- checker.PayloadLen(len(data)+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.Window(wndSize),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
- checker.TCPTimestampChecker(false, 0, 0),
- ),
- )
-}
-
-// TestTimeStampDisabledAccept tests that Timestamp option is not used when the
-// peer doesn't advertise it and connection is established with Accept().
-func TestTimeStampDisabledAccept(t *testing.T) {
- testCases := []struct {
- cookieEnabled bool
- wndScale int
- wndSize uint16
- }{
- {true, -1, 0xffff}, // When cookie is used window scaling is disabled.
- {false, 5, 0x8000}, // DefaultReceiveBufferSize is 1MB >> 5.
- }
- for _, tc := range testCases {
- timeStampDisabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize)
- }
-}
-
-func TestSendGreaterThanMTUWithOptions(t *testing.T) {
- const maxPayload = 100
- c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
- defer c.Cleanup()
-
- createConnectedWithTimestampOption(c)
- testBrokenUpWrite(t, c, maxPayload)
-}
-
-func TestSegmentNotDroppedWhenTimestampMissing(t *testing.T) {
- const maxPayload = 100
- c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
- defer c.Cleanup()
-
- rep := createConnectedWithTimestampOption(c)
-
- // Register for read.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.EventIn)
- defer c.WQ.EventUnregister(&we)
-
- droppedPacketsStat := c.Stack().Stats().DroppedPackets
- droppedPackets := droppedPacketsStat.Value()
- data := []byte{1, 2, 3}
- // Send a packet with no TCP options/timestamp.
- rep.SendPacket(data, nil)
-
- select {
- case <-ch:
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for data to arrive")
- }
-
- // Assert that DroppedPackets was not incremented.
- if got, want := droppedPacketsStat.Value(), droppedPackets; got != want {
- t.Fatalf("incorrect number of dropped packets, got: %v, want: %v", got, want)
- }
-
- // Issue a read and we should data.
- got, _, err := c.EP.Read(nil)
- if err != nil {
- t.Fatalf("Unexpected error from Read: %v", err)
- }
- if want := data; bytes.Compare(got, want) != 0 {
- t.Fatalf("Data is different: got: %v, want: %v", got, want)
- }
-}
diff --git a/pkg/tcpip/transport/tcp/testing/context/BUILD b/pkg/tcpip/transport/tcp/testing/context/BUILD
deleted file mode 100644
index ce6a2c31d..000000000
--- a/pkg/tcpip/transport/tcp/testing/context/BUILD
+++ /dev/null
@@ -1,26 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "context",
- testonly = 1,
- srcs = ["context.go"],
- visibility = [
- "//visibility:public",
- ],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/checker",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/link/sniffer",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/seqnum",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/transport/tcp",
- "//pkg/waiter",
- ],
-)
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
deleted file mode 100644
index 1e9a0dea3..000000000
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ /dev/null
@@ -1,1102 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package context provides a test context for use in tcp tests. It also
-// provides helper methods to assert/check certain behaviours.
-package context
-
-import (
- "bytes"
- "context"
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/seqnum"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-const (
- // StackAddr is the IPv4 address assigned to the stack.
- StackAddr = "\x0a\x00\x00\x01"
-
- // StackPort is used as the listening port in tests for passive
- // connects.
- StackPort = 1234
-
- // TestAddr is the source address for packets sent to the stack via the
- // link layer endpoint.
- TestAddr = "\x0a\x00\x00\x02"
-
- // TestPort is the TCP port used for packets sent to the stack
- // via the link layer endpoint.
- TestPort = 4096
-
- // StackV6Addr is the IPv6 address assigned to the stack.
- StackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
-
- // TestV6Addr is the source address for packets sent to the stack via
- // the link layer endpoint.
- TestV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
-
- // StackV4MappedAddr is StackAddr as a mapped v6 address.
- StackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + StackAddr
-
- // TestV4MappedAddr is TestAddr as a mapped v6 address.
- TestV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + TestAddr
-
- // V4MappedWildcardAddr is the mapped v6 representation of 0.0.0.0.
- V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"
-
- // testInitialSequenceNumber is the initial sequence number sent in packets that
- // are sent in response to a SYN or in the initial SYN sent to the stack.
- testInitialSequenceNumber = 789
-)
-
-// Headers is used to represent the TCP header fields when building a
-// new packet.
-type Headers struct {
- // SrcPort holds the src port value to be used in the packet.
- SrcPort uint16
-
- // DstPort holds the destination port value to be used in the packet.
- DstPort uint16
-
- // SeqNum is the value of the sequence number field in the TCP header.
- SeqNum seqnum.Value
-
- // AckNum represents the acknowledgement number field in the TCP header.
- AckNum seqnum.Value
-
- // Flags are the TCP flags in the TCP header.
- Flags int
-
- // RcvWnd is the window to be advertised in the ReceiveWindow field of
- // the TCP header.
- RcvWnd seqnum.Size
-
- // TCPOpts holds the options to be sent in the option field of the TCP
- // header.
- TCPOpts []byte
-}
-
-// Context provides an initialized Network stack and a link layer endpoint
-// for use in TCP tests.
-type Context struct {
- t *testing.T
- linkEP *channel.Endpoint
- s *stack.Stack
-
- // IRS holds the initial sequence number in the SYN sent by endpoint in
- // case of an active connect or the sequence number sent by the endpoint
- // in the SYN-ACK sent in response to a SYN when listening in passive
- // mode.
- IRS seqnum.Value
-
- // Port holds the port bound by EP below in case of an active connect or
- // the listening port number in case of a passive connect.
- Port uint16
-
- // EP is the test endpoint in the stack owned by this context. This endpoint
- // is used in various tests to either initiate an active connect or is used
- // as a passive listening endpoint to accept inbound connections.
- EP tcpip.Endpoint
-
- // Wq is the wait queue associated with EP and is used to block for events
- // on EP.
- WQ waiter.Queue
-
- // TimeStampEnabled is true if ep is connected with the timestamp option
- // enabled.
- TimeStampEnabled bool
-
- // WindowScale is the expected window scale in SYN packets sent by
- // the stack.
- WindowScale uint8
-}
-
-// New allocates and initializes a test context containing a new
-// stack and a link-layer endpoint.
-func New(t *testing.T, mtu uint32) *Context {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
- })
-
- // Allow minimum send/receive buffer sizes to be 1 during tests.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultSendBufferSize, 10 * tcp.DefaultSendBufferSize}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
- }
-
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultReceiveBufferSize, 10 * tcp.DefaultReceiveBufferSize}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
- }
-
- // Some of the congestion control tests send up to 640 packets, we so
- // set the channel size to 1000.
- ep := channel.New(1000, mtu, "")
- wep := stack.LinkEndpoint(ep)
- if testing.Verbose() {
- wep = sniffer.New(ep)
- }
- opts := stack.NICOptions{Name: "nic1"}
- if err := s.CreateNICWithOptions(1, wep, opts); err != nil {
- t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err)
- }
- wep2 := stack.LinkEndpoint(channel.New(1000, mtu, ""))
- if testing.Verbose() {
- wep2 = sniffer.New(channel.New(1000, mtu, ""))
- }
- opts2 := stack.NICOptions{Name: "nic2"}
- if err := s.CreateNICWithOptions(2, wep2, opts2); err != nil {
- t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err)
- }
-
- if err := s.AddAddress(1, ipv4.ProtocolNumber, StackAddr); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
-
- if err := s.AddAddress(1, ipv6.ProtocolNumber, StackV6Addr); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv4EmptySubnet,
- NIC: 1,
- },
- {
- Destination: header.IPv6EmptySubnet,
- NIC: 1,
- },
- })
-
- return &Context{
- t: t,
- s: s,
- linkEP: ep,
- WindowScale: uint8(tcp.FindWndScale(tcp.DefaultReceiveBufferSize)),
- }
-}
-
-// Cleanup closes the context endpoint if required.
-func (c *Context) Cleanup() {
- if c.EP != nil {
- c.EP.Close()
- }
-}
-
-// Stack returns a reference to the stack in the Context.
-func (c *Context) Stack() *stack.Stack {
- return c.s
-}
-
-// CheckNoPacketTimeout verifies that no packet is received during the time
-// specified by wait.
-func (c *Context) CheckNoPacketTimeout(errMsg string, wait time.Duration) {
- c.t.Helper()
-
- ctx, _ := context.WithTimeout(context.Background(), wait)
- if _, ok := c.linkEP.ReadContext(ctx); ok {
- c.t.Fatal(errMsg)
- }
-}
-
-// CheckNoPacket verifies that no packet is received for 1 second.
-func (c *Context) CheckNoPacket(errMsg string) {
- c.CheckNoPacketTimeout(errMsg, 1*time.Second)
-}
-
-// GetPacket reads a packet from the link layer endpoint and verifies
-// that it is an IPv4 packet with the expected source and destination
-// addresses. It will fail with an error if no packet is received for
-// 2 seconds.
-func (c *Context) GetPacket() []byte {
- c.t.Helper()
-
- ctx, _ := context.WithTimeout(context.Background(), 2*time.Second)
- p, ok := c.linkEP.ReadContext(ctx)
- if !ok {
- c.t.Fatalf("Packet wasn't written out")
- return nil
- }
-
- if p.Proto != ipv4.ProtocolNumber {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
- }
-
- hdr := p.Pkt.Header.View()
- b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
-
- if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize {
- c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize)
- }
-
- checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
- return b
-}
-
-// GetPacketNonBlocking reads a packet from the link layer endpoint
-// and verifies that it is an IPv4 packet with the expected source
-// and destination address. If no packet is available it will return
-// nil immediately.
-func (c *Context) GetPacketNonBlocking() []byte {
- c.t.Helper()
-
- p, ok := c.linkEP.Read()
- if !ok {
- return nil
- }
-
- if p.Proto != ipv4.ProtocolNumber {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
- }
-
- hdr := p.Pkt.Header.View()
- b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
-
- checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
- return b
-}
-
-// SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint.
-func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byte, maxTotalSize int) {
- // Allocate a buffer data and headers.
- buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p2))
- if len(buf) > maxTotalSize {
- buf = buf[:maxTotalSize]
- }
-
- ip := header.IPv4(buf)
- ip.Encode(&header.IPv4Fields{
- IHL: header.IPv4MinimumSize,
- TotalLength: uint16(len(buf)),
- TTL: 65,
- Protocol: uint8(header.ICMPv4ProtocolNumber),
- SrcAddr: TestAddr,
- DstAddr: StackAddr,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
-
- icmp := header.ICMPv4(buf[header.IPv4MinimumSize:])
- icmp.SetType(typ)
- icmp.SetCode(code)
- const icmpv4VariableHeaderOffset = 4
- copy(icmp[icmpv4VariableHeaderOffset:], p1)
- copy(icmp[header.ICMPv4PayloadOffset:], p2)
-
- // Inject packet.
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
- Data: buf.ToVectorisedView(),
- })
-}
-
-// BuildSegment builds a TCP segment based on the given Headers and payload.
-func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView {
- return c.BuildSegmentWithAddrs(payload, h, TestAddr, StackAddr)
-}
-
-// BuildSegmentWithAddrs builds a TCP segment based on the given Headers,
-// payload and source and destination IPv4 addresses.
-func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) buffer.VectorisedView {
- // Allocate a buffer for data and headers.
- buf := buffer.NewView(header.TCPMinimumSize + header.IPv4MinimumSize + len(h.TCPOpts) + len(payload))
- copy(buf[len(buf)-len(payload):], payload)
- copy(buf[len(buf)-len(payload)-len(h.TCPOpts):], h.TCPOpts)
-
- // Initialize the IP header.
- ip := header.IPv4(buf)
- ip.Encode(&header.IPv4Fields{
- IHL: header.IPv4MinimumSize,
- TotalLength: uint16(len(buf)),
- TTL: 65,
- Protocol: uint8(tcp.ProtocolNumber),
- SrcAddr: src,
- DstAddr: dst,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
-
- // Initialize the TCP header.
- t := header.TCP(buf[header.IPv4MinimumSize:])
- t.Encode(&header.TCPFields{
- SrcPort: h.SrcPort,
- DstPort: h.DstPort,
- SeqNum: uint32(h.SeqNum),
- AckNum: uint32(h.AckNum),
- DataOffset: uint8(header.TCPMinimumSize + len(h.TCPOpts)),
- Flags: uint8(h.Flags),
- WindowSize: uint16(h.RcvWnd),
- })
-
- // Calculate the TCP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t)))
-
- // Calculate the TCP checksum and set it.
- xsum = header.Checksum(payload, xsum)
- t.SetChecksum(^t.CalculateChecksum(xsum))
-
- // Inject packet.
- return buf.ToVectorisedView()
-}
-
-// SendSegment sends a TCP segment that has already been built and written to a
-// buffer.VectorisedView.
-func (c *Context) SendSegment(s buffer.VectorisedView) {
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
- Data: s,
- })
-}
-
-// SendPacket builds and sends a TCP segment(with the provided payload & TCP
-// headers) in an IPv4 packet via the link layer endpoint.
-func (c *Context) SendPacket(payload []byte, h *Headers) {
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
- Data: c.BuildSegment(payload, h),
- })
-}
-
-// SendPacketWithAddrs builds and sends a TCP segment(with the provided payload
-// & TCPheaders) in an IPv4 packet via the link layer endpoint using the
-// provided source and destination IPv4 addresses.
-func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) {
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
- Data: c.BuildSegmentWithAddrs(payload, h, src, dst),
- })
-}
-
-// SendAck sends an ACK packet.
-func (c *Context) SendAck(seq seqnum.Value, bytesReceived int) {
- c.SendAckWithSACK(seq, bytesReceived, nil)
-}
-
-// SendAckWithSACK sends an ACK packet which includes the sackBlocks specified.
-func (c *Context) SendAckWithSACK(seq seqnum.Value, bytesReceived int, sackBlocks []header.SACKBlock) {
- options := make([]byte, 40)
- offset := 0
- if len(sackBlocks) > 0 {
- offset += header.EncodeNOP(options[offset:])
- offset += header.EncodeNOP(options[offset:])
- offset += header.EncodeSACKBlocks(sackBlocks, options[offset:])
- }
-
- c.SendPacket(nil, &Headers{
- SrcPort: TestPort,
- DstPort: c.Port,
- Flags: header.TCPFlagAck,
- SeqNum: seq,
- AckNum: c.IRS.Add(1 + seqnum.Size(bytesReceived)),
- RcvWnd: 30000,
- TCPOpts: options[:offset],
- })
-}
-
-// ReceiveAndCheckPacket reads a packet from the link layer endpoint and
-// verifies that the packet packet payload of packet matches the slice
-// of data indicated by offset & size.
-func (c *Context) ReceiveAndCheckPacket(data []byte, offset, size int) {
- c.ReceiveAndCheckPacketWithOptions(data, offset, size, 0)
-}
-
-// ReceiveAndCheckPacketWithOptions reads a packet from the link layer endpoint
-// and verifies that the packet packet payload of packet matches the slice of
-// data indicated by offset & size and skips optlen bytes in addition to the IP
-// TCP headers when comparing the data.
-func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, optlen int) {
- b := c.GetPacket()
- checker.IPv4(c.t, b,
- checker.PayloadLen(size+header.TCPMinimumSize+optlen),
- checker.TCP(
- checker.DstPort(TestPort),
- checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
- checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
- ),
- )
-
- pdata := data[offset:][:size]
- if p := b[header.IPv4MinimumSize+header.TCPMinimumSize+optlen:]; bytes.Compare(pdata, p) != 0 {
- c.t.Fatalf("Data is different: expected %v, got %v", pdata, p)
- }
-}
-
-// ReceiveNonBlockingAndCheckPacket reads a packet from the link layer endpoint
-// and verifies that the packet packet payload of packet matches the slice of
-// data indicated by offset & size. It returns true if a packet was received and
-// processed.
-func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int) bool {
- b := c.GetPacketNonBlocking()
- if b == nil {
- return false
- }
- checker.IPv4(c.t, b,
- checker.PayloadLen(size+header.TCPMinimumSize),
- checker.TCP(
- checker.DstPort(TestPort),
- checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
- checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
- ),
- )
-
- pdata := data[offset:][:size]
- if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; bytes.Compare(pdata, p) != 0 {
- c.t.Fatalf("Data is different: expected %v, got %v", pdata, p)
- }
- return true
-}
-
-// CreateV6Endpoint creates and initializes c.ep as a IPv6 Endpoint. If v6Only
-// is true then it sets the IP_V6ONLY option on the socket to make it a IPv6
-// only endpoint instead of a default dual stack socket.
-func (c *Context) CreateV6Endpoint(v6only bool) {
- var err *tcpip.Error
- c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &c.WQ)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- if err := c.EP.SetSockOptBool(tcpip.V6OnlyOption, v6only); err != nil {
- c.t.Fatalf("SetSockOpt failed failed: %v", err)
- }
-}
-
-// GetV6Packet reads a single packet from the link layer endpoint of the context
-// and asserts that it is an IPv6 Packet with the expected src/dest addresses.
-func (c *Context) GetV6Packet() []byte {
- c.t.Helper()
-
- ctx, _ := context.WithTimeout(context.Background(), 2*time.Second)
- p, ok := c.linkEP.ReadContext(ctx)
- if !ok {
- c.t.Fatalf("Packet wasn't written out")
- return nil
- }
-
- if p.Proto != ipv6.ProtocolNumber {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber)
- }
- b := make([]byte, p.Pkt.Header.UsedLength()+p.Pkt.Data.Size())
- copy(b, p.Pkt.Header.View())
- copy(b[p.Pkt.Header.UsedLength():], p.Pkt.Data.ToView())
-
- checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr))
- return b
-}
-
-// SendV6Packet builds and sends an IPv6 Packet via the link layer endpoint of
-// the context.
-func (c *Context) SendV6Packet(payload []byte, h *Headers) {
- c.SendV6PacketWithAddrs(payload, h, TestV6Addr, StackV6Addr)
-}
-
-// SendV6PacketWithAddrs builds and sends an IPv6 Packet via the link layer
-// endpoint of the context using the provided source and destination IPv6
-// addresses.
-func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) {
- // Allocate a buffer for data and headers.
- buf := buffer.NewView(header.TCPMinimumSize + header.IPv6MinimumSize + len(payload))
- copy(buf[len(buf)-len(payload):], payload)
-
- // Initialize the IP header.
- ip := header.IPv6(buf)
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(header.TCPMinimumSize + len(payload)),
- NextHeader: uint8(tcp.ProtocolNumber),
- HopLimit: 65,
- SrcAddr: src,
- DstAddr: dst,
- })
-
- // Initialize the TCP header.
- t := header.TCP(buf[header.IPv6MinimumSize:])
- t.Encode(&header.TCPFields{
- SrcPort: h.SrcPort,
- DstPort: h.DstPort,
- SeqNum: uint32(h.SeqNum),
- AckNum: uint32(h.AckNum),
- DataOffset: header.TCPMinimumSize,
- Flags: uint8(h.Flags),
- WindowSize: uint16(h.RcvWnd),
- })
-
- // Calculate the TCP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t)))
-
- // Calculate the TCP checksum and set it.
- xsum = header.Checksum(payload, xsum)
- t.SetChecksum(^t.CalculateChecksum(xsum))
-
- // Inject packet.
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{
- Data: buf.ToVectorisedView(),
- })
-}
-
-// CreateConnected creates a connected TCP endpoint.
-func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int) {
- c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil)
-}
-
-// Connect performs the 3-way handshake for c.EP with the provided Initial
-// Sequence Number (iss) and receive window(rcvWnd) and any options if
-// specified.
-//
-// It also sets the receive buffer for the endpoint to the specified
-// value in epRcvBuf.
-//
-// PreCondition: c.EP must already be created.
-func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) {
- // Start connection attempt.
- waitEntry, notifyCh := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&waitEntry, waiter.EventOut)
- defer c.WQ.EventUnregister(&waitEntry)
-
- if err := c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort}); err != tcpip.ErrConnectStarted {
- c.t.Fatalf("Unexpected return value from Connect: %v", err)
- }
-
- // Receive SYN packet.
- b := c.GetPacket()
- checker.IPv4(c.t, b,
- checker.TCP(
- checker.DstPort(TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- ),
- )
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
- c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- tcpHdr := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
-
- c.SendPacket(nil, &Headers{
- SrcPort: tcpHdr.DestinationPort(),
- DstPort: tcpHdr.SourcePort(),
- Flags: header.TCPFlagSyn | header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: rcvWnd,
- TCPOpts: options,
- })
-
- // Receive ACK packet.
- checker.IPv4(c.t, c.GetPacket(),
- checker.TCP(
- checker.DstPort(TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(uint32(iss)+1),
- ),
- )
-
- // Wait for connection to be established.
- select {
- case <-notifyCh:
- if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil {
- c.t.Fatalf("Unexpected error when connecting: %v", err)
- }
- case <-time.After(1 * time.Second):
- c.t.Fatalf("Timed out waiting for connection")
- }
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want {
- c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- c.Port = tcpHdr.SourcePort()
-}
-
-// Create creates a TCP endpoint.
-func (c *Context) Create(epRcvBuf int) {
- // Create TCP endpoint.
- var err *tcpip.Error
- c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- if epRcvBuf != -1 {
- if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, epRcvBuf); err != nil {
- c.t.Fatalf("SetSockOpt failed failed: %v", err)
- }
- }
-}
-
-// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends
-// the specified option bytes as the Option field in the initial SYN packet.
-//
-// It also sets the receive buffer for the endpoint to the specified
-// value in epRcvBuf.
-func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int, options []byte) {
- c.Create(epRcvBuf)
- c.Connect(iss, rcvWnd, options)
-}
-
-// RawEndpoint is just a small wrapper around a TCP endpoint's state to make
-// sending data and ACK packets easy while being able to manipulate the sequence
-// numbers and timestamp values as needed.
-type RawEndpoint struct {
- C *Context
- SrcPort uint16
- DstPort uint16
- Flags int
- NextSeqNum seqnum.Value
- AckNum seqnum.Value
- WndSize seqnum.Size
- RecentTS uint32 // Stores the latest timestamp to echo back.
- TSVal uint32 // TSVal stores the last timestamp sent by this endpoint.
-
- // SackPermitted is true if SACKPermitted option was negotiated for this endpoint.
- SACKPermitted bool
-}
-
-// SendPacketWithTS embeds the provided tsVal in the Timestamp option
-// for the packet to be sent out.
-func (r *RawEndpoint) SendPacketWithTS(payload []byte, tsVal uint32) {
- r.TSVal = tsVal
- tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
- header.EncodeTSOption(r.TSVal, r.RecentTS, tsOpt[2:])
- r.SendPacket(payload, tsOpt[:])
-}
-
-// SendPacket is a small wrapper function to build and send packets.
-func (r *RawEndpoint) SendPacket(payload []byte, opts []byte) {
- packetHeaders := &Headers{
- SrcPort: r.SrcPort,
- DstPort: r.DstPort,
- Flags: r.Flags,
- SeqNum: r.NextSeqNum,
- AckNum: r.AckNum,
- RcvWnd: r.WndSize,
- TCPOpts: opts,
- }
- r.C.SendPacket(payload, packetHeaders)
- r.NextSeqNum = r.NextSeqNum.Add(seqnum.Size(len(payload)))
-}
-
-// VerifyACKWithTS verifies that the tsEcr field in the ack matches the provided
-// tsVal.
-func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) {
- // Read ACK and verify that tsEcr of ACK packet is [1,2,3,4]
- ackPacket := r.C.GetPacket()
- checker.IPv4(r.C.t, ackPacket,
- checker.TCP(
- checker.DstPort(r.SrcPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(r.AckNum)),
- checker.AckNum(uint32(r.NextSeqNum)),
- checker.TCPTimestampChecker(true, 0, tsVal),
- ),
- )
- // Store the parsed TSVal from the ack as recentTS.
- tcpSeg := header.TCP(header.IPv4(ackPacket).Payload())
- opts := tcpSeg.ParsedOptions()
- r.RecentTS = opts.TSVal
-}
-
-// VerifyACKRcvWnd verifies that the window advertised by the incoming ACK
-// matches the provided rcvWnd.
-func (r *RawEndpoint) VerifyACKRcvWnd(rcvWnd uint16) {
- ackPacket := r.C.GetPacket()
- checker.IPv4(r.C.t, ackPacket,
- checker.TCP(
- checker.DstPort(r.SrcPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(r.AckNum)),
- checker.AckNum(uint32(r.NextSeqNum)),
- checker.Window(rcvWnd),
- ),
- )
-}
-
-// VerifyACKNoSACK verifies that the ACK does not contain a SACK block.
-func (r *RawEndpoint) VerifyACKNoSACK() {
- r.VerifyACKHasSACK(nil)
-}
-
-// VerifyACKHasSACK verifies that the ACK contains the specified SACKBlocks.
-func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) {
- // Read ACK and verify that the TCP options in the segment do
- // not contain a SACK block.
- ackPacket := r.C.GetPacket()
- checker.IPv4(r.C.t, ackPacket,
- checker.TCP(
- checker.DstPort(r.SrcPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(r.AckNum)),
- checker.AckNum(uint32(r.NextSeqNum)),
- checker.TCPSACKBlockChecker(sackBlocks),
- ),
- )
-}
-
-// CreateConnectedWithOptions creates and connects c.ep with the specified TCP
-// options enabled and returns a RawEndpoint which represents the other end of
-// the connection.
-//
-// It also verifies where required(eg.Timestamp) that the ACK to the SYN-ACK
-// does not carry an option that was not requested.
-func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *RawEndpoint {
- var err *tcpip.Error
- c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- c.t.Fatalf("c.s.NewEndpoint(tcp, ipv4...) = %v", err)
- }
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateInitial; got != want {
- c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- // Start connection attempt.
- waitEntry, notifyCh := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&waitEntry, waiter.EventOut)
- defer c.WQ.EventUnregister(&waitEntry)
-
- testFullAddr := tcpip.FullAddress{Addr: TestAddr, Port: TestPort}
- err = c.EP.Connect(testFullAddr)
- if err != tcpip.ErrConnectStarted {
- c.t.Fatalf("c.ep.Connect(%v) = %v", testFullAddr, err)
- }
- // Receive SYN packet.
- b := c.GetPacket()
- // Validate that the syn has the timestamp option and a valid
- // TS value.
- mss := uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize)
-
- checker.IPv4(c.t, b,
- checker.TCP(
- checker.DstPort(TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- checker.TCPSynOptions(header.TCPSynOptions{
- MSS: mss,
- TS: true,
- WS: int(c.WindowScale),
- SACKPermitted: c.SACKEnabled(),
- }),
- ),
- )
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
- c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- tcpSeg := header.TCP(header.IPv4(b).Payload())
- synOptions := header.ParseSynOptions(tcpSeg.Options(), false)
-
- // Build options w/ tsVal to be sent in the SYN-ACK.
- synAckOptions := make([]byte, header.TCPOptionsMaximumSize)
- offset := 0
- if wantOptions.WS != -1 {
- offset += header.EncodeWSOption(wantOptions.WS, synAckOptions[offset:])
- }
- if wantOptions.TS {
- offset += header.EncodeTSOption(wantOptions.TSVal, synOptions.TSVal, synAckOptions[offset:])
- }
- if wantOptions.SACKPermitted {
- offset += header.EncodeSACKPermittedOption(synAckOptions[offset:])
- }
-
- offset += header.AddTCPOptionPadding(synAckOptions, offset)
-
- // Build SYN-ACK.
- c.IRS = seqnum.Value(tcpSeg.SequenceNumber())
- iss := seqnum.Value(testInitialSequenceNumber)
- c.SendPacket(nil, &Headers{
- SrcPort: tcpSeg.DestinationPort(),
- DstPort: tcpSeg.SourcePort(),
- Flags: header.TCPFlagSyn | header.TCPFlagAck,
- SeqNum: iss,
- AckNum: c.IRS.Add(1),
- RcvWnd: 30000,
- TCPOpts: synAckOptions[:offset],
- })
-
- // Read ACK.
- ackPacket := c.GetPacket()
-
- // Verify TCP header fields.
- tcpCheckers := []checker.TransportChecker{
- checker.DstPort(TestPort),
- checker.TCPFlags(header.TCPFlagAck),
- checker.SeqNum(uint32(c.IRS) + 1),
- checker.AckNum(uint32(iss) + 1),
- }
-
- // Verify that tsEcr of ACK packet is wantOptions.TSVal if the
- // timestamp option was enabled, if not then we verify that
- // there is no timestamp in the ACK packet.
- if wantOptions.TS {
- tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(true, 0, wantOptions.TSVal))
- } else {
- tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(false, 0, 0))
- }
-
- checker.IPv4(c.t, ackPacket, checker.TCP(tcpCheckers...))
-
- ackSeg := header.TCP(header.IPv4(ackPacket).Payload())
- ackOptions := ackSeg.ParsedOptions()
-
- // Wait for connection to be established.
- select {
- case <-notifyCh:
- err = c.EP.GetSockOpt(tcpip.ErrorOption{})
- if err != nil {
- c.t.Fatalf("Unexpected error when connecting: %v", err)
- }
- case <-time.After(1 * time.Second):
- c.t.Fatalf("Timed out waiting for connection")
- }
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want {
- c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- // Store the source port in use by the endpoint.
- c.Port = tcpSeg.SourcePort()
-
- // Mark in context that timestamp option is enabled for this endpoint.
- c.TimeStampEnabled = true
-
- return &RawEndpoint{
- C: c,
- SrcPort: tcpSeg.DestinationPort(),
- DstPort: tcpSeg.SourcePort(),
- Flags: header.TCPFlagAck | header.TCPFlagPsh,
- NextSeqNum: iss + 1,
- AckNum: c.IRS.Add(1),
- WndSize: 30000,
- RecentTS: ackOptions.TSVal,
- TSVal: wantOptions.TSVal,
- SACKPermitted: wantOptions.SACKPermitted,
- }
-}
-
-// AcceptWithOptions initializes a listening endpoint and connects to it with the
-// provided options enabled. It also verifies that the SYN-ACK has the expected
-// values for the provided options.
-//
-// The function returns a RawEndpoint representing the other end of the accepted
-// endpoint.
-func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOptions) *RawEndpoint {
- // Create EP and start listening.
- wq := &waiter.Queue{}
- ep, err := c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
- defer ep.Close()
-
- if err := ep.Bind(tcpip.FullAddress{Port: StackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
- }
- if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want {
- c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- if err := ep.Listen(10); err != nil {
- c.t.Fatalf("Listen failed: %v", err)
- }
- if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
- c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- rep := c.PassiveConnectWithOptions(100, wndScale, synOptions)
-
- // Try to accept the connection.
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.EventIn)
- defer wq.EventUnregister(&we)
-
- c.EP, _, err = ep.Accept()
- if err == tcpip.ErrWouldBlock {
- // Wait for connection to be established.
- select {
- case <-ch:
- c.EP, _, err = ep.Accept()
- if err != nil {
- c.t.Fatalf("Accept failed: %v", err)
- }
-
- case <-time.After(1 * time.Second):
- c.t.Fatalf("Timed out waiting for accept")
- }
- }
- if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want {
- c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
- }
-
- return rep
-}
-
-// PassiveConnect just disables WindowScaling and delegates the call to
-// PassiveConnectWithOptions.
-func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCPSynOptions) {
- synOptions.WS = -1
- c.PassiveConnectWithOptions(maxPayload, wndScale, synOptions)
-}
-
-// PassiveConnectWithOptions initiates a new connection (with the specified TCP
-// options enabled) to the port on which the Context.ep is listening for new
-// connections. It also validates that the SYN-ACK has the expected values for
-// the enabled options.
-//
-// NOTE: MSS is not a negotiated option and it can be asymmetric
-// in each direction. This function uses the maxPayload to set the MSS to be
-// sent to the peer on a connect and validates that the MSS in the SYN-ACK
-// response is equal to the MTU - (tcphdr len + iphdr len).
-//
-// wndScale is the expected window scale in the SYN-ACK and synOptions.WS is the
-// value of the window scaling option to be sent in the SYN. If synOptions.WS >
-// 0 then we send the WindowScale option.
-func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions) *RawEndpoint {
- opts := make([]byte, header.TCPOptionsMaximumSize)
- offset := 0
- offset += header.EncodeMSSOption(uint32(maxPayload), opts)
-
- if synOptions.WS >= 0 {
- offset += header.EncodeWSOption(3, opts[offset:])
- }
- if synOptions.TS {
- offset += header.EncodeTSOption(synOptions.TSVal, synOptions.TSEcr, opts[offset:])
- }
-
- if synOptions.SACKPermitted {
- offset += header.EncodeSACKPermittedOption(opts[offset:])
- }
-
- paddingToAdd := 4 - offset%4
- // Now add any padding bytes that might be required to quad align the
- // options.
- for i := offset; i < offset+paddingToAdd; i++ {
- opts[i] = header.TCPOptionNOP
- }
- offset += paddingToAdd
-
- // Send a SYN request.
- iss := seqnum.Value(testInitialSequenceNumber)
- c.SendPacket(nil, &Headers{
- SrcPort: TestPort,
- DstPort: StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: iss,
- RcvWnd: 30000,
- TCPOpts: opts[:offset],
- })
-
- // Receive the SYN-ACK reply. Make sure MSS and other expected options
- // are present.
- b := c.GetPacket()
- tcp := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcp.SequenceNumber())
-
- tcpCheckers := []checker.TransportChecker{
- checker.SrcPort(StackPort),
- checker.DstPort(TestPort),
- checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
- checker.AckNum(uint32(iss) + 1),
- checker.TCPSynOptions(header.TCPSynOptions{MSS: synOptions.MSS, WS: wndScale, SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled()}),
- }
-
- // If TS option was enabled in the original SYN then add a checker to
- // validate the Timestamp option in the SYN-ACK.
- if synOptions.TS {
- tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(synOptions.TS, 0, synOptions.TSVal))
- } else {
- tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(false, 0, 0))
- }
-
- checker.IPv4(c.t, b, checker.TCP(tcpCheckers...))
- rcvWnd := seqnum.Size(30000)
- ackHeaders := &Headers{
- SrcPort: TestPort,
- DstPort: StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: iss + 1,
- AckNum: c.IRS + 1,
- RcvWnd: rcvWnd,
- }
-
- // If WS was expected to be in effect then scale the advertised window
- // correspondingly.
- if synOptions.WS > 0 {
- ackHeaders.RcvWnd = rcvWnd >> byte(synOptions.WS)
- }
-
- parsedOpts := tcp.ParsedOptions()
- if synOptions.TS {
- // Echo the tsVal back to the peer in the tsEcr field of the
- // timestamp option.
- // Increment TSVal by 1 from the value sent in the SYN and echo
- // the TSVal in the SYN-ACK in the TSEcr field.
- opts := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
- header.EncodeTSOption(synOptions.TSVal+1, parsedOpts.TSVal, opts[2:])
- ackHeaders.TCPOpts = opts[:]
- }
-
- // Send ACK.
- c.SendPacket(nil, ackHeaders)
-
- c.Port = StackPort
-
- return &RawEndpoint{
- C: c,
- SrcPort: TestPort,
- DstPort: StackPort,
- Flags: header.TCPFlagPsh | header.TCPFlagAck,
- NextSeqNum: iss + 1,
- AckNum: c.IRS + 1,
- WndSize: rcvWnd,
- SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled(),
- RecentTS: parsedOpts.TSVal,
- TSVal: synOptions.TSVal + 1,
- }
-}
-
-// SACKEnabled returns true if the TCP Protocol option SACKEnabled is set to true
-// for the Stack in the context.
-func (c *Context) SACKEnabled() bool {
- var v tcp.SACKEnabled
- if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &v); err != nil {
- // Stack doesn't support SACK. So just return.
- return false
- }
- return bool(v)
-}
-
-// SetGSOEnabled enables or disables generic segmentation offload.
-func (c *Context) SetGSOEnabled(enable bool) {
- if enable {
- c.linkEP.LinkEPCapabilities |= stack.CapabilityHardwareGSO
- } else {
- c.linkEP.LinkEPCapabilities &^= stack.CapabilityHardwareGSO
- }
-}
-
-// MSSWithoutOptions returns the value for the MSS used by the stack when no
-// options are in use.
-func (c *Context) MSSWithoutOptions() uint16 {
- return uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize)
-}
-
-// MSSWithoutOptionsV6 returns the value for the MSS used by the stack when no
-// options are in use for IPv6 packets.
-func (c *Context) MSSWithoutOptionsV6() uint16 {
- return uint16(c.linkEP.MTU() - header.IPv6MinimumSize - header.TCPMinimumSize)
-}
diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD
deleted file mode 100644
index 3ad6994a7..000000000
--- a/pkg/tcpip/transport/tcpconntrack/BUILD
+++ /dev/null
@@ -1,23 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "tcpconntrack",
- srcs = ["tcp_conntrack.go"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/tcpip/header",
- "//pkg/tcpip/seqnum",
- ],
-)
-
-go_test(
- name = "tcpconntrack_test",
- size = "small",
- srcs = ["tcp_conntrack_test.go"],
- deps = [
- ":tcpconntrack",
- "//pkg/tcpip/header",
- ],
-)
diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
index 93712cd45..93712cd45 100644..100755
--- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
+++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go
deleted file mode 100644
index 5e271b7ca..000000000
--- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go
+++ /dev/null
@@ -1,511 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tcpconntrack_test
-
-import (
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack"
-)
-
-// connected creates a connection tracker TCB and sets it to a connected state
-// by performing a 3-way handshake.
-func connected(t *testing.T, iss, irs uint32, isw, irw uint16) *tcpconntrack.TCB {
- // Send SYN.
- tcp := make(header.TCP, header.TCPMinimumSize)
- tcp.Encode(&header.TCPFields{
- SeqNum: iss,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn,
- WindowSize: irw,
- })
-
- tcb := tcpconntrack.TCB{}
- tcb.Init(tcp)
-
- // Receive SYN-ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: irs,
- AckNum: iss + 1,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn | header.TCPFlagAck,
- WindowSize: isw,
- })
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Send ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: iss + 1,
- AckNum: irs + 1,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck,
- WindowSize: irw,
- })
-
- if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- return &tcb
-}
-
-func TestConnectionRefused(t *testing.T) {
- // Send SYN.
- tcp := make(header.TCP, header.TCPMinimumSize)
- tcp.Encode(&header.TCPFields{
- SeqNum: 1234,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn,
- WindowSize: 30000,
- })
-
- tcb := tcpconntrack.TCB{}
- tcb.Init(tcp)
-
- // Receive RST.
- tcp.Encode(&header.TCPFields{
- SeqNum: 789,
- AckNum: 1235,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagRst | header.TCPFlagAck,
- WindowSize: 50000,
- })
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultReset {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset)
- }
-}
-
-func TestConnectionRefusedInSynRcvd(t *testing.T) {
- // Send SYN.
- tcp := make(header.TCP, header.TCPMinimumSize)
- tcp.Encode(&header.TCPFields{
- SeqNum: 1234,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn,
- WindowSize: 30000,
- })
-
- tcb := tcpconntrack.TCB{}
- tcb.Init(tcp)
-
- // Receive SYN.
- tcp.Encode(&header.TCPFields{
- SeqNum: 789,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn,
- WindowSize: 50000,
- })
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Receive RST with no ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: 790,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagRst,
- WindowSize: 50000,
- })
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultReset {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset)
- }
-}
-
-func TestConnectionResetInSynRcvd(t *testing.T) {
- // Send SYN.
- tcp := make(header.TCP, header.TCPMinimumSize)
- tcp.Encode(&header.TCPFields{
- SeqNum: 1234,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn,
- WindowSize: 30000,
- })
-
- tcb := tcpconntrack.TCB{}
- tcb.Init(tcp)
-
- // Receive SYN.
- tcp.Encode(&header.TCPFields{
- SeqNum: 789,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn,
- WindowSize: 50000,
- })
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Send RST with no ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: 1235,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagRst,
- })
-
- if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultReset {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset)
- }
-}
-
-func TestRetransmitOnSynSent(t *testing.T) {
- // Send initial SYN.
- tcp := make(header.TCP, header.TCPMinimumSize)
- tcp.Encode(&header.TCPFields{
- SeqNum: 1234,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn,
- WindowSize: 30000,
- })
-
- tcb := tcpconntrack.TCB{}
- tcb.Init(tcp)
-
- // Retransmit the same SYN.
- if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultConnecting {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultConnecting)
- }
-}
-
-func TestRetransmitOnSynRcvd(t *testing.T) {
- // Send initial SYN.
- tcp := make(header.TCP, header.TCPMinimumSize)
- tcp.Encode(&header.TCPFields{
- SeqNum: 1234,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn,
- WindowSize: 30000,
- })
-
- tcb := tcpconntrack.TCB{}
- tcb.Init(tcp)
-
- // Receive SYN. This will cause the state to go to SYN-RCVD.
- tcp.Encode(&header.TCPFields{
- SeqNum: 789,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn,
- WindowSize: 50000,
- })
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Retransmit the original SYN.
- tcp.Encode(&header.TCPFields{
- SeqNum: 1234,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn,
- WindowSize: 30000,
- })
-
- if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Transmit a SYN-ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: 1234,
- AckNum: 790,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn | header.TCPFlagAck,
- WindowSize: 30000,
- })
-
- if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-}
-
-func TestClosedBySelf(t *testing.T) {
- tcb := connected(t, 1234, 789, 30000, 50000)
-
- // Send FIN.
- tcp := make(header.TCP, header.TCPMinimumSize)
- tcp.Encode(&header.TCPFields{
- SeqNum: 1235,
- AckNum: 790,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- WindowSize: 30000,
- })
-
- if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Receive FIN/ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: 790,
- AckNum: 1236,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- WindowSize: 50000,
- })
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Send ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: 1236,
- AckNum: 791,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck,
- WindowSize: 30000,
- })
-
- if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultClosedBySelf {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedBySelf)
- }
-}
-
-func TestClosedByPeer(t *testing.T) {
- tcb := connected(t, 1234, 789, 30000, 50000)
-
- // Receive FIN.
- tcp := make(header.TCP, header.TCPMinimumSize)
- tcp.Encode(&header.TCPFields{
- SeqNum: 790,
- AckNum: 1235,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- WindowSize: 50000,
- })
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Send FIN/ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: 1235,
- AckNum: 791,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- WindowSize: 30000,
- })
-
- if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Receive ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: 791,
- AckNum: 1236,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck,
- WindowSize: 50000,
- })
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultClosedByPeer {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedByPeer)
- }
-}
-
-func TestSendAndReceiveDataClosedBySelf(t *testing.T) {
- sseq := uint32(1234)
- rseq := uint32(789)
- tcb := connected(t, sseq, rseq, 30000, 50000)
- sseq++
- rseq++
-
- // Send some data.
- tcp := make(header.TCP, header.TCPMinimumSize+1024)
-
- for i := uint32(0); i < 10; i++ {
- // Send some data.
- tcp.Encode(&header.TCPFields{
- SeqNum: sseq,
- AckNum: rseq,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck,
- WindowSize: 30000,
- })
- sseq += uint32(len(tcp)) - header.TCPMinimumSize
-
- if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Receive ack for data.
- tcp.Encode(&header.TCPFields{
- SeqNum: rseq,
- AckNum: sseq,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck,
- WindowSize: 50000,
- })
-
- if r := tcb.UpdateStateInbound(tcp[:header.TCPMinimumSize]); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
- }
-
- for i := uint32(0); i < 10; i++ {
- // Receive some data.
- tcp.Encode(&header.TCPFields{
- SeqNum: rseq,
- AckNum: sseq,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck,
- WindowSize: 50000,
- })
- rseq += uint32(len(tcp)) - header.TCPMinimumSize
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Send ack for data.
- tcp.Encode(&header.TCPFields{
- SeqNum: sseq,
- AckNum: rseq,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck,
- WindowSize: 30000,
- })
-
- if r := tcb.UpdateStateOutbound(tcp[:header.TCPMinimumSize]); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
- }
-
- // Send FIN.
- tcp = tcp[:header.TCPMinimumSize]
- tcp.Encode(&header.TCPFields{
- SeqNum: sseq,
- AckNum: rseq,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- WindowSize: 30000,
- })
- sseq++
-
- if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Receive FIN/ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: rseq,
- AckNum: sseq,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck | header.TCPFlagFin,
- WindowSize: 50000,
- })
- rseq++
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Send ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: sseq,
- AckNum: rseq,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck,
- WindowSize: 30000,
- })
-
- if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultClosedBySelf {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedBySelf)
- }
-}
-
-func TestIgnoreBadResetOnSynSent(t *testing.T) {
- // Send SYN.
- tcp := make(header.TCP, header.TCPMinimumSize)
- tcp.Encode(&header.TCPFields{
- SeqNum: 1234,
- AckNum: 0,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn,
- WindowSize: 30000,
- })
-
- tcb := tcpconntrack.TCB{}
- tcb.Init(tcp)
-
- // Receive a RST with a bad ACK, it should not cause the connection to
- // be reset.
- acks := []uint32{1234, 1236, 1000, 5000}
- flags := []uint8{header.TCPFlagRst, header.TCPFlagRst | header.TCPFlagAck}
- for _, a := range acks {
- for _, f := range flags {
- tcp.Encode(&header.TCPFields{
- SeqNum: 789,
- AckNum: a,
- DataOffset: header.TCPMinimumSize,
- Flags: f,
- WindowSize: 50000,
- })
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultConnecting {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
- }
- }
-
- // Complete the handshake.
- // Receive SYN-ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: 789,
- AckNum: 1235,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagSyn | header.TCPFlagAck,
- WindowSize: 50000,
- })
-
- if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-
- // Send ACK.
- tcp.Encode(&header.TCPFields{
- SeqNum: 1235,
- AckNum: 790,
- DataOffset: header.TCPMinimumSize,
- Flags: header.TCPFlagAck,
- WindowSize: 30000,
- })
-
- if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
- t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
- }
-}
diff --git a/pkg/tcpip/transport/tcpconntrack/tcpconntrack_state_autogen.go b/pkg/tcpip/transport/tcpconntrack/tcpconntrack_state_autogen.go
new file mode 100755
index 000000000..ff53204da
--- /dev/null
+++ b/pkg/tcpip/transport/tcpconntrack/tcpconntrack_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package tcpconntrack
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
deleted file mode 100644
index adc908e24..000000000
--- a/pkg/tcpip/transport/udp/BUILD
+++ /dev/null
@@ -1,61 +0,0 @@
-load("//tools:defs.bzl", "go_library", "go_test")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
-package(licenses = ["notice"])
-
-go_template_instance(
- name = "udp_packet_list",
- out = "udp_packet_list.go",
- package = "udp",
- prefix = "udpPacket",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*udpPacket",
- "Linker": "*udpPacket",
- },
-)
-
-go_library(
- name = "udp",
- srcs = [
- "endpoint.go",
- "endpoint_state.go",
- "forwarder.go",
- "protocol.go",
- "udp_packet_list.go",
- ],
- imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/sleep",
- "//pkg/sync",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
- "//pkg/tcpip/ports",
- "//pkg/tcpip/stack",
- "//pkg/tcpip/transport/raw",
- "//pkg/waiter",
- ],
-)
-
-go_test(
- name = "udp_x_test",
- size = "small",
- srcs = ["udp_test.go"],
- deps = [
- ":udp",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/checker",
- "//pkg/tcpip/header",
- "//pkg/tcpip/link/channel",
- "//pkg/tcpip/link/loopback",
- "//pkg/tcpip/link/sniffer",
- "//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/stack",
- "//pkg/waiter",
- ],
-)
diff --git a/pkg/tcpip/transport/udp/udp_packet_list.go b/pkg/tcpip/transport/udp/udp_packet_list.go
new file mode 100755
index 000000000..673a9373b
--- /dev/null
+++ b/pkg/tcpip/transport/udp/udp_packet_list.go
@@ -0,0 +1,173 @@
+package udp
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type udpPacketElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (udpPacketElementMapper) linkerFor(elem *udpPacket) *udpPacket { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type udpPacketList struct {
+ head *udpPacket
+ tail *udpPacket
+}
+
+// Reset resets list l to the empty state.
+func (l *udpPacketList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *udpPacketList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *udpPacketList) Front() *udpPacket {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *udpPacketList) Back() *udpPacket {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *udpPacketList) PushFront(e *udpPacket) {
+ udpPacketElementMapper{}.linkerFor(e).SetNext(l.head)
+ udpPacketElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ udpPacketElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *udpPacketList) PushBack(e *udpPacket) {
+ udpPacketElementMapper{}.linkerFor(e).SetNext(nil)
+ udpPacketElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ udpPacketElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *udpPacketList) PushBackList(m *udpPacketList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ udpPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ udpPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *udpPacketList) InsertAfter(b, e *udpPacket) {
+ a := udpPacketElementMapper{}.linkerFor(b).Next()
+ udpPacketElementMapper{}.linkerFor(e).SetNext(a)
+ udpPacketElementMapper{}.linkerFor(e).SetPrev(b)
+ udpPacketElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ udpPacketElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *udpPacketList) InsertBefore(a, e *udpPacket) {
+ b := udpPacketElementMapper{}.linkerFor(a).Prev()
+ udpPacketElementMapper{}.linkerFor(e).SetNext(a)
+ udpPacketElementMapper{}.linkerFor(e).SetPrev(b)
+ udpPacketElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ udpPacketElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *udpPacketList) Remove(e *udpPacket) {
+ prev := udpPacketElementMapper{}.linkerFor(e).Prev()
+ next := udpPacketElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ udpPacketElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ udpPacketElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type udpPacketEntry struct {
+ next *udpPacket
+ prev *udpPacket
+}
+
+// Next returns the entry that follows e in the list.
+func (e *udpPacketEntry) Next() *udpPacket {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *udpPacketEntry) Prev() *udpPacket {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *udpPacketEntry) SetNext(elem *udpPacket) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *udpPacketEntry) SetPrev(elem *udpPacket) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/transport/udp/udp_state_autogen.go b/pkg/tcpip/transport/udp/udp_state_autogen.go
new file mode 100755
index 000000000..5cf9848cc
--- /dev/null
+++ b/pkg/tcpip/transport/udp/udp_state_autogen.go
@@ -0,0 +1,144 @@
+// automatically generated by stateify.
+
+package udp
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+func (x *udpPacket) beforeSave() {}
+func (x *udpPacket) save(m state.Map) {
+ x.beforeSave()
+ var data buffer.VectorisedView = x.saveData()
+ m.SaveValue("data", data)
+ m.Save("udpPacketEntry", &x.udpPacketEntry)
+ m.Save("senderAddress", &x.senderAddress)
+ m.Save("packetInfo", &x.packetInfo)
+ m.Save("timestamp", &x.timestamp)
+ m.Save("tos", &x.tos)
+}
+
+func (x *udpPacket) afterLoad() {}
+func (x *udpPacket) load(m state.Map) {
+ m.Load("udpPacketEntry", &x.udpPacketEntry)
+ m.Load("senderAddress", &x.senderAddress)
+ m.Load("packetInfo", &x.packetInfo)
+ m.Load("timestamp", &x.timestamp)
+ m.Load("tos", &x.tos)
+ m.LoadValue("data", new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) })
+}
+
+func (x *endpoint) save(m state.Map) {
+ x.beforeSave()
+ var rcvBufSizeMax int = x.saveRcvBufSizeMax()
+ m.SaveValue("rcvBufSizeMax", rcvBufSizeMax)
+ m.Save("TransportEndpointInfo", &x.TransportEndpointInfo)
+ m.Save("waiterQueue", &x.waiterQueue)
+ m.Save("uniqueID", &x.uniqueID)
+ m.Save("rcvReady", &x.rcvReady)
+ m.Save("rcvList", &x.rcvList)
+ m.Save("rcvBufSize", &x.rcvBufSize)
+ m.Save("rcvClosed", &x.rcvClosed)
+ m.Save("sndBufSize", &x.sndBufSize)
+ m.Save("state", &x.state)
+ m.Save("dstPort", &x.dstPort)
+ m.Save("v6only", &x.v6only)
+ m.Save("ttl", &x.ttl)
+ m.Save("multicastTTL", &x.multicastTTL)
+ m.Save("multicastAddr", &x.multicastAddr)
+ m.Save("multicastNICID", &x.multicastNICID)
+ m.Save("multicastLoop", &x.multicastLoop)
+ m.Save("reusePort", &x.reusePort)
+ m.Save("bindToDevice", &x.bindToDevice)
+ m.Save("broadcast", &x.broadcast)
+ m.Save("boundBindToDevice", &x.boundBindToDevice)
+ m.Save("boundPortFlags", &x.boundPortFlags)
+ m.Save("sendTOS", &x.sendTOS)
+ m.Save("receiveTOS", &x.receiveTOS)
+ m.Save("receiveTClass", &x.receiveTClass)
+ m.Save("receiveIPPacketInfo", &x.receiveIPPacketInfo)
+ m.Save("shutdownFlags", &x.shutdownFlags)
+ m.Save("multicastMemberships", &x.multicastMemberships)
+ m.Save("effectiveNetProtos", &x.effectiveNetProtos)
+}
+
+func (x *endpoint) load(m state.Map) {
+ m.Load("TransportEndpointInfo", &x.TransportEndpointInfo)
+ m.Load("waiterQueue", &x.waiterQueue)
+ m.Load("uniqueID", &x.uniqueID)
+ m.Load("rcvReady", &x.rcvReady)
+ m.Load("rcvList", &x.rcvList)
+ m.Load("rcvBufSize", &x.rcvBufSize)
+ m.Load("rcvClosed", &x.rcvClosed)
+ m.Load("sndBufSize", &x.sndBufSize)
+ m.Load("state", &x.state)
+ m.Load("dstPort", &x.dstPort)
+ m.Load("v6only", &x.v6only)
+ m.Load("ttl", &x.ttl)
+ m.Load("multicastTTL", &x.multicastTTL)
+ m.Load("multicastAddr", &x.multicastAddr)
+ m.Load("multicastNICID", &x.multicastNICID)
+ m.Load("multicastLoop", &x.multicastLoop)
+ m.Load("reusePort", &x.reusePort)
+ m.Load("bindToDevice", &x.bindToDevice)
+ m.Load("broadcast", &x.broadcast)
+ m.Load("boundBindToDevice", &x.boundBindToDevice)
+ m.Load("boundPortFlags", &x.boundPortFlags)
+ m.Load("sendTOS", &x.sendTOS)
+ m.Load("receiveTOS", &x.receiveTOS)
+ m.Load("receiveTClass", &x.receiveTClass)
+ m.Load("receiveIPPacketInfo", &x.receiveIPPacketInfo)
+ m.Load("shutdownFlags", &x.shutdownFlags)
+ m.Load("multicastMemberships", &x.multicastMemberships)
+ m.Load("effectiveNetProtos", &x.effectiveNetProtos)
+ m.LoadValue("rcvBufSizeMax", new(int), func(y interface{}) { x.loadRcvBufSizeMax(y.(int)) })
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *multicastMembership) beforeSave() {}
+func (x *multicastMembership) save(m state.Map) {
+ x.beforeSave()
+ m.Save("nicID", &x.nicID)
+ m.Save("multicastAddr", &x.multicastAddr)
+}
+
+func (x *multicastMembership) afterLoad() {}
+func (x *multicastMembership) load(m state.Map) {
+ m.Load("nicID", &x.nicID)
+ m.Load("multicastAddr", &x.multicastAddr)
+}
+
+func (x *udpPacketList) beforeSave() {}
+func (x *udpPacketList) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *udpPacketList) afterLoad() {}
+func (x *udpPacketList) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *udpPacketEntry) beforeSave() {}
+func (x *udpPacketEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *udpPacketEntry) afterLoad() {}
+func (x *udpPacketEntry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
+func init() {
+ state.Register("pkg/tcpip/transport/udp.udpPacket", (*udpPacket)(nil), state.Fns{Save: (*udpPacket).save, Load: (*udpPacket).load})
+ state.Register("pkg/tcpip/transport/udp.endpoint", (*endpoint)(nil), state.Fns{Save: (*endpoint).save, Load: (*endpoint).load})
+ state.Register("pkg/tcpip/transport/udp.multicastMembership", (*multicastMembership)(nil), state.Fns{Save: (*multicastMembership).save, Load: (*multicastMembership).load})
+ state.Register("pkg/tcpip/transport/udp.udpPacketList", (*udpPacketList)(nil), state.Fns{Save: (*udpPacketList).save, Load: (*udpPacketList).load})
+ state.Register("pkg/tcpip/transport/udp.udpPacketEntry", (*udpPacketEntry)(nil), state.Fns{Save: (*udpPacketEntry).save, Load: (*udpPacketEntry).load})
+}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
deleted file mode 100644
index 34b7c2360..000000000
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ /dev/null
@@ -1,1802 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package udp_test
-
-import (
- "bytes"
- "context"
- "fmt"
- "math/rand"
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/checker"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
- "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-// Addresses and ports used for testing. It is recommended that tests stick to
-// using these addresses as it allows using the testFlow helper.
-// Naming rules: 'stack*'' denotes local addresses and ports, while 'test*'
-// represents the remote endpoint.
-const (
- v4MappedAddrPrefix = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff"
- stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- stackV4MappedAddr = v4MappedAddrPrefix + stackAddr
- testV4MappedAddr = v4MappedAddrPrefix + testAddr
- multicastV4MappedAddr = v4MappedAddrPrefix + multicastAddr
- broadcastV4MappedAddr = v4MappedAddrPrefix + broadcastAddr
- v4MappedWildcardAddr = v4MappedAddrPrefix + "\x00\x00\x00\x00"
-
- stackAddr = "\x0a\x00\x00\x01"
- stackPort = 1234
- testAddr = "\x0a\x00\x00\x02"
- testPort = 4096
- multicastAddr = "\xe8\x2b\xd3\xea"
- multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
- broadcastAddr = header.IPv4Broadcast
- testTOS = 0x80
-
- // defaultMTU is the MTU, in bytes, used throughout the tests, except
- // where another value is explicitly used. It is chosen to match the MTU
- // of loopback interfaces on linux systems.
- defaultMTU = 65536
-)
-
-// header4Tuple stores the 4-tuple {src-IP, src-port, dst-IP, dst-port} used in
-// a packet header. These values are used to populate a header or verify one.
-// Note that because they are used in packet headers, the addresses are never in
-// a V4-mapped format.
-type header4Tuple struct {
- srcAddr tcpip.FullAddress
- dstAddr tcpip.FullAddress
-}
-
-// testFlow implements a helper type used for sending and receiving test
-// packets. A given test flow value defines 1) the socket endpoint used for the
-// test and 2) the type of packet send or received on the endpoint. E.g., a
-// multicastV6Only flow is a V6 multicast packet passing through a V6-only
-// endpoint. The type provides helper methods to characterize the flow (e.g.,
-// isV4) as well as return a proper header4Tuple for it.
-type testFlow int
-
-const (
- unicastV4 testFlow = iota // V4 unicast on a V4 socket
- unicastV4in6 // V4-mapped unicast on a V6-dual socket
- unicastV6 // V6 unicast on a V6 socket
- unicastV6Only // V6 unicast on a V6-only socket
- multicastV4 // V4 multicast on a V4 socket
- multicastV4in6 // V4-mapped multicast on a V6-dual socket
- multicastV6 // V6 multicast on a V6 socket
- multicastV6Only // V6 multicast on a V6-only socket
- broadcast // V4 broadcast on a V4 socket
- broadcastIn6 // V4-mapped broadcast on a V6-dual socket
-)
-
-func (flow testFlow) String() string {
- switch flow {
- case unicastV4:
- return "unicastV4"
- case unicastV6:
- return "unicastV6"
- case unicastV6Only:
- return "unicastV6Only"
- case unicastV4in6:
- return "unicastV4in6"
- case multicastV4:
- return "multicastV4"
- case multicastV6:
- return "multicastV6"
- case multicastV6Only:
- return "multicastV6Only"
- case multicastV4in6:
- return "multicastV4in6"
- case broadcast:
- return "broadcast"
- case broadcastIn6:
- return "broadcastIn6"
- default:
- return "unknown"
- }
-}
-
-// packetDirection explains if a flow is incoming (read) or outgoing (write).
-type packetDirection int
-
-const (
- incoming packetDirection = iota
- outgoing
-)
-
-// header4Tuple returns the header4Tuple for the given flow and direction. Note
-// that the tuple contains no mapped addresses as those only exist at the socket
-// level but not at the packet header level.
-func (flow testFlow) header4Tuple(d packetDirection) header4Tuple {
- var h header4Tuple
- if flow.isV4() {
- if d == outgoing {
- h = header4Tuple{
- srcAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort},
- dstAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort},
- }
- } else {
- h = header4Tuple{
- srcAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort},
- dstAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort},
- }
- }
- if flow.isMulticast() {
- h.dstAddr.Addr = multicastAddr
- } else if flow.isBroadcast() {
- h.dstAddr.Addr = broadcastAddr
- }
- } else { // IPv6
- if d == outgoing {
- h = header4Tuple{
- srcAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort},
- dstAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
- }
- } else {
- h = header4Tuple{
- srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
- dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort},
- }
- }
- if flow.isMulticast() {
- h.dstAddr.Addr = multicastV6Addr
- }
- }
- return h
-}
-
-func (flow testFlow) getMcastAddr() tcpip.Address {
- if flow.isV4() {
- return multicastAddr
- }
- return multicastV6Addr
-}
-
-// mapAddrIfApplicable converts the given V4 address into its V4-mapped version
-// if it is applicable to the flow.
-func (flow testFlow) mapAddrIfApplicable(v4Addr tcpip.Address) tcpip.Address {
- if flow.isMapped() {
- return v4MappedAddrPrefix + v4Addr
- }
- return v4Addr
-}
-
-// netProto returns the protocol number used for the network packet.
-func (flow testFlow) netProto() tcpip.NetworkProtocolNumber {
- if flow.isV4() {
- return ipv4.ProtocolNumber
- }
- return ipv6.ProtocolNumber
-}
-
-// sockProto returns the protocol number used when creating the socket
-// endpoint for this flow.
-func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber {
- switch flow {
- case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6:
- return ipv6.ProtocolNumber
- case unicastV4, multicastV4, broadcast:
- return ipv4.ProtocolNumber
- default:
- panic(fmt.Sprintf("invalid testFlow given: %d", flow))
- }
-}
-
-func (flow testFlow) checkerFn() func(*testing.T, []byte, ...checker.NetworkChecker) {
- if flow.isV4() {
- return checker.IPv4
- }
- return checker.IPv6
-}
-
-func (flow testFlow) isV6() bool { return !flow.isV4() }
-func (flow testFlow) isV4() bool {
- return flow.sockProto() == ipv4.ProtocolNumber || flow.isMapped()
-}
-
-func (flow testFlow) isV6Only() bool {
- switch flow {
- case unicastV6Only, multicastV6Only:
- return true
- case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6:
- return false
- default:
- panic(fmt.Sprintf("invalid testFlow given: %d", flow))
- }
-}
-
-func (flow testFlow) isMulticast() bool {
- switch flow {
- case multicastV4, multicastV4in6, multicastV6, multicastV6Only:
- return true
- case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6:
- return false
- default:
- panic(fmt.Sprintf("invalid testFlow given: %d", flow))
- }
-}
-
-func (flow testFlow) isBroadcast() bool {
- switch flow {
- case broadcast, broadcastIn6:
- return true
- case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only:
- return false
- default:
- panic(fmt.Sprintf("invalid testFlow given: %d", flow))
- }
-}
-
-func (flow testFlow) isMapped() bool {
- switch flow {
- case unicastV4in6, multicastV4in6, broadcastIn6:
- return true
- case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast:
- return false
- default:
- panic(fmt.Sprintf("invalid testFlow given: %d", flow))
- }
-}
-
-type testContext struct {
- t *testing.T
- linkEP *channel.Endpoint
- s *stack.Stack
-
- ep tcpip.Endpoint
- wq waiter.Queue
-}
-
-func newDualTestContext(t *testing.T, mtu uint32) *testContext {
- t.Helper()
- return newDualTestContextWithOptions(t, mtu, stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
- })
-}
-
-func newDualTestContextWithOptions(t *testing.T, mtu uint32, options stack.Options) *testContext {
- t.Helper()
-
- s := stack.New(options)
- ep := channel.New(256, mtu, "")
- wep := stack.LinkEndpoint(ep)
-
- if testing.Verbose() {
- wep = sniffer.New(ep)
- }
- if err := s.CreateNIC(1, wep); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
- }
-
- if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
-
- if err := s.AddAddress(1, ipv6.ProtocolNumber, stackV6Addr); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv4EmptySubnet,
- NIC: 1,
- },
- {
- Destination: header.IPv6EmptySubnet,
- NIC: 1,
- },
- })
-
- return &testContext{
- t: t,
- s: s,
- linkEP: ep,
- }
-}
-
-func (c *testContext) cleanup() {
- if c.ep != nil {
- c.ep.Close()
- }
-}
-
-func (c *testContext) createEndpoint(proto tcpip.NetworkProtocolNumber) {
- c.t.Helper()
-
- var err *tcpip.Error
- c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, proto, &c.wq)
- if err != nil {
- c.t.Fatal("NewEndpoint failed: ", err)
- }
-}
-
-func (c *testContext) createEndpointForFlow(flow testFlow) {
- c.t.Helper()
-
- c.createEndpoint(flow.sockProto())
- if flow.isV6Only() {
- if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
- }
- } else if flow.isBroadcast() {
- if err := c.ep.SetSockOpt(tcpip.BroadcastOption(1)); err != nil {
- c.t.Fatal("SetSockOpt failed:", err)
- }
- }
-}
-
-// getPacketAndVerify reads a packet from the link endpoint and verifies the
-// header against expected values from the given test flow. In addition, it
-// calls any extra checker functions provided.
-func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte {
- c.t.Helper()
-
- ctx, _ := context.WithTimeout(context.Background(), 2*time.Second)
- p, ok := c.linkEP.ReadContext(ctx)
- if !ok {
- c.t.Fatalf("Packet wasn't written out")
- return nil
- }
-
- if p.Proto != flow.netProto() {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto())
- }
-
- hdr := p.Pkt.Header.View()
- b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
-
- h := flow.header4Tuple(outgoing)
- checkers = append(
- checkers,
- checker.SrcAddr(h.srcAddr.Addr),
- checker.DstAddr(h.dstAddr.Addr),
- checker.UDP(checker.DstPort(h.dstAddr.Port)),
- )
- flow.checkerFn()(c.t, b, checkers...)
- return b
-}
-
-// injectPacket creates a packet of the given flow and with the given payload,
-// and injects it into the link endpoint.
-func (c *testContext) injectPacket(flow testFlow, payload []byte) {
- c.t.Helper()
-
- h := flow.header4Tuple(incoming)
- if flow.isV4() {
- c.injectV4Packet(payload, &h, true /* valid */)
- } else {
- c.injectV6Packet(payload, &h, true /* valid */)
- }
-}
-
-// injectV6Packet creates a V6 test packet with the given payload and header
-// values, and injects it into the link endpoint. valid indicates if the
-// caller intends to inject a packet with a valid or an invalid UDP header.
-// We can invalidate the header by corrupting the UDP payload length.
-func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool) {
- // Allocate a buffer for data and headers.
- buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
- payloadStart := len(buf) - len(payload)
- copy(buf[payloadStart:], payload)
-
- // Initialize the IP header.
- ip := header.IPv6(buf)
- ip.Encode(&header.IPv6Fields{
- TrafficClass: testTOS,
- PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
- NextHeader: uint8(udp.ProtocolNumber),
- HopLimit: 65,
- SrcAddr: h.srcAddr.Addr,
- DstAddr: h.dstAddr.Addr,
- })
-
- // Initialize the UDP header.
- u := header.UDP(buf[header.IPv6MinimumSize:])
- l := uint16(header.UDPMinimumSize + len(payload))
- if !valid {
- // Change the UDP payload length to corrupt the header
- // as requested by the caller.
- l++
- }
- u.Encode(&header.UDPFields{
- SrcPort: h.srcAddr.Port,
- DstPort: h.dstAddr.Port,
- Length: l,
- })
-
- // Calculate the UDP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u)))
-
- // Calculate the UDP checksum and set it.
- xsum = header.Checksum(payload, xsum)
- u.SetChecksum(^u.CalculateChecksum(xsum))
-
- // Inject packet.
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{
- Data: buf.ToVectorisedView(),
- NetworkHeader: buffer.View(ip),
- TransportHeader: buffer.View(u),
- })
-}
-
-// injectV4Packet creates a V4 test packet with the given payload and header
-// values, and injects it into the link endpoint. valid indicates if the
-// caller intends to inject a packet with a valid or an invalid UDP header.
-// We can invalidate the header by corrupting the UDP payload length.
-func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool) {
- // Allocate a buffer for data and headers.
- buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload))
- payloadStart := len(buf) - len(payload)
- copy(buf[payloadStart:], payload)
-
- // Initialize the IP header.
- ip := header.IPv4(buf)
- ip.Encode(&header.IPv4Fields{
- IHL: header.IPv4MinimumSize,
- TOS: testTOS,
- TotalLength: uint16(len(buf)),
- TTL: 65,
- Protocol: uint8(udp.ProtocolNumber),
- SrcAddr: h.srcAddr.Addr,
- DstAddr: h.dstAddr.Addr,
- })
- ip.SetChecksum(^ip.CalculateChecksum())
-
- // Initialize the UDP header.
- u := header.UDP(buf[header.IPv4MinimumSize:])
- u.Encode(&header.UDPFields{
- SrcPort: h.srcAddr.Port,
- DstPort: h.dstAddr.Port,
- Length: uint16(header.UDPMinimumSize + len(payload)),
- })
-
- // Calculate the UDP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u)))
-
- // Calculate the UDP checksum and set it.
- xsum = header.Checksum(payload, xsum)
- u.SetChecksum(^u.CalculateChecksum(xsum))
-
- // Inject packet.
-
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
- Data: buf.ToVectorisedView(),
- NetworkHeader: buffer.View(ip),
- TransportHeader: buffer.View(u),
- })
-}
-
-func newPayload() []byte {
- return newMinPayload(30)
-}
-
-func newMinPayload(minSize int) []byte {
- b := make([]byte, minSize+rand.Intn(100))
- for i := range b {
- b[i] = byte(rand.Intn(256))
- }
- return b
-}
-
-func TestBindToDeviceOption(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
-
- ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
- }
- defer ep.Close()
-
- opts := stack.NICOptions{Name: "my_device"}
- if err := s.CreateNICWithOptions(321, loopback.New(), opts); err != nil {
- t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err)
- }
-
- // nicIDPtr is used instead of taking the address of NICID literals, which is
- // a compiler error.
- nicIDPtr := func(s tcpip.NICID) *tcpip.NICID {
- return &s
- }
-
- testActions := []struct {
- name string
- setBindToDevice *tcpip.NICID
- setBindToDeviceError *tcpip.Error
- getBindToDevice tcpip.BindToDeviceOption
- }{
- {"GetDefaultValue", nil, nil, 0},
- {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0},
- {"BindToExistent", nicIDPtr(321), nil, 321},
- {"UnbindToDevice", nicIDPtr(0), nil, 0},
- }
- for _, testAction := range testActions {
- t.Run(testAction.name, func(t *testing.T) {
- if testAction.setBindToDevice != nil {
- bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
- if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
- t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr)
- }
- }
- bindToDevice := tcpip.BindToDeviceOption(88888)
- if err := ep.GetSockOpt(&bindToDevice); err != nil {
- t.Errorf("GetSockOpt got %v, want %v", err, nil)
- }
- if got, want := bindToDevice, testAction.getBindToDevice; got != want {
- t.Errorf("bindToDevice got %d, want %d", got, want)
- }
- })
- }
-}
-
-// testReadInternal sends a packet of the given test flow into the stack by
-// injecting it into the link endpoint. It then attempts to read it from the
-// UDP endpoint and depending on if this was expected to succeed verifies its
-// correctness including any additional checker functions provided.
-func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool, checkers ...checker.ControlMessagesChecker) {
- c.t.Helper()
-
- payload := newPayload()
- c.injectPacket(flow, payload)
-
- // Try to receive the data.
- we, ch := waiter.NewChannelEntry(nil)
- c.wq.EventRegister(&we, waiter.EventIn)
- defer c.wq.EventUnregister(&we)
-
- // Take a snapshot of the stats to validate them at the end of the test.
- epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
-
- var addr tcpip.FullAddress
- v, cm, err := c.ep.Read(&addr)
- if err == tcpip.ErrWouldBlock {
- // Wait for data to become available.
- select {
- case <-ch:
- v, cm, err = c.ep.Read(&addr)
-
- case <-time.After(300 * time.Millisecond):
- if packetShouldBeDropped {
- return // expected to time out
- }
- c.t.Fatal("timed out waiting for data")
- }
- }
-
- if expectReadError && err != nil {
- c.checkEndpointReadStats(1, epstats, err)
- return
- }
-
- if err != nil {
- c.t.Fatal("Read failed:", err)
- }
-
- if packetShouldBeDropped {
- c.t.Fatalf("Read unexpectedly received data from %s", addr.Addr)
- }
-
- // Check the peer address.
- h := flow.header4Tuple(incoming)
- if addr.Addr != h.srcAddr.Addr {
- c.t.Fatalf("unexpected remote address: got %s, want %s", addr.Addr, h.srcAddr)
- }
-
- // Check the payload.
- if !bytes.Equal(payload, v) {
- c.t.Fatalf("bad payload: got %x, want %x", v, payload)
- }
-
- // Run any checkers against the ControlMessages.
- for _, f := range checkers {
- f(c.t, cm)
- }
-
- c.checkEndpointReadStats(1, epstats, err)
-}
-
-// testRead sends a packet of the given test flow into the stack by injecting it
-// into the link endpoint. It then reads it from the UDP endpoint and verifies
-// its correctness including any additional checker functions provided.
-func testRead(c *testContext, flow testFlow, checkers ...checker.ControlMessagesChecker) {
- c.t.Helper()
- testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */, checkers...)
-}
-
-// testFailingRead sends a packet of the given test flow into the stack by
-// injecting it into the link endpoint. It then tries to read it from the UDP
-// endpoint and expects this to fail.
-func testFailingRead(c *testContext, flow testFlow, expectReadError bool) {
- c.t.Helper()
- testReadInternal(c, flow, true /* packetShouldBeDropped */, expectReadError)
-}
-
-func TestBindEphemeralPort(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- if err := c.ep.Bind(tcpip.FullAddress{}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %v", err)
- }
-}
-
-func TestBindReservedPort(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
- }
-
- addr, err := c.ep.GetLocalAddress()
- if err != nil {
- t.Fatalf("GetLocalAddress failed: %v", err)
- }
-
- // We can't bind the address reserved by the connected endpoint above.
- {
- ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
- defer ep.Close()
- if got, want := ep.Bind(addr), tcpip.ErrPortInUse; got != want {
- t.Fatalf("got ep.Bind(...) = %v, want = %v", got, want)
- }
- }
-
- func() {
- ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
- defer ep.Close()
- // We can't bind ipv4-any on the port reserved by the connected endpoint
- // above, since the endpoint is dual-stack.
- if got, want := ep.Bind(tcpip.FullAddress{Port: addr.Port}), tcpip.ErrPortInUse; got != want {
- t.Fatalf("got ep.Bind(...) = %v, want = %v", got, want)
- }
- // We can bind an ipv4 address on this port, though.
- if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %v", err)
- }
- }()
-
- // Once the connected endpoint releases its port reservation, we are able to
- // bind ipv4-any once again.
- c.ep.Close()
- func() {
- ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
- defer ep.Close()
- if err := ep.Bind(tcpip.FullAddress{Port: addr.Port}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %v", err)
- }
- }()
-}
-
-func TestV4ReadOnV6(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(unicastV4in6)
-
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
- }
-
- // Test acceptance.
- testRead(c, unicastV4in6)
-}
-
-func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(unicastV4in6)
-
- // Bind to v4 mapped wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: v4MappedWildcardAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
- }
-
- // Test acceptance.
- testRead(c, unicastV4in6)
-}
-
-func TestV4ReadOnBoundToV4Mapped(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(unicastV4in6)
-
- // Bind to local address.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
- }
-
- // Test acceptance.
- testRead(c, unicastV4in6)
-}
-
-func TestV6ReadOnV6(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(unicastV6)
-
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
- }
-
- // Test acceptance.
- testRead(c, unicastV6)
-}
-
-// TestV4ReadSelfSource checks that packets coming from a local IP address are
-// correctly dropped when handleLocal is true and not otherwise.
-func TestV4ReadSelfSource(t *testing.T) {
- for _, tt := range []struct {
- name string
- handleLocal bool
- wantErr *tcpip.Error
- wantInvalidSource uint64
- }{
- {"HandleLocal", false, nil, 0},
- {"NoHandleLocal", true, tcpip.ErrWouldBlock, 1},
- } {
- t.Run(tt.name, func(t *testing.T) {
- c := newDualTestContextWithOptions(t, defaultMTU, stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
- HandleLocal: tt.handleLocal,
- })
- defer c.cleanup()
-
- c.createEndpointForFlow(unicastV4)
-
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- t.Fatalf("Bind failed: %s", err)
- }
-
- payload := newPayload()
- h := unicastV4.header4Tuple(incoming)
- h.srcAddr = h.dstAddr
-
- c.injectV4Packet(payload, &h, true /* valid */)
-
- if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != tt.wantInvalidSource {
- t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource)
- }
-
- if _, _, err := c.ep.Read(nil); err != tt.wantErr {
- t.Errorf("c.ep.Read() got error %v, want %v", err, tt.wantErr)
- }
- })
- }
-}
-
-func TestV4ReadOnV4(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(unicastV4)
-
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
- }
-
- // Test acceptance.
- testRead(c, unicastV4)
-}
-
-// TestReadOnBoundToMulticast checks that an endpoint can bind to a multicast
-// address and receive data sent to that address.
-func TestReadOnBoundToMulticast(t *testing.T) {
- // FIXME(b/128189410): multicastV4in6 currently doesn't work as
- // AddMembershipOption doesn't handle V4in6 addresses.
- for _, flow := range []testFlow{multicastV4, multicastV6, multicastV6Only} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Bind to multicast address.
- mcastAddr := flow.mapAddrIfApplicable(flow.getMcastAddr())
- if err := c.ep.Bind(tcpip.FullAddress{Addr: mcastAddr, Port: stackPort}); err != nil {
- c.t.Fatal("Bind failed:", err)
- }
-
- // Join multicast group.
- ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: mcastAddr}
- if err := c.ep.SetSockOpt(ifoptSet); err != nil {
- c.t.Fatal("SetSockOpt failed:", err)
- }
-
- // Check that we receive multicast packets but not unicast or broadcast
- // ones.
- testRead(c, flow)
- testFailingRead(c, broadcast, false /* expectReadError */)
- testFailingRead(c, unicastV4, false /* expectReadError */)
- })
- }
-}
-
-// TestV4ReadOnBoundToBroadcast checks that an endpoint can bind to a broadcast
-// address and can receive only broadcast data.
-func TestV4ReadOnBoundToBroadcast(t *testing.T) {
- for _, flow := range []testFlow{broadcast, broadcastIn6} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Bind to broadcast address.
- bcastAddr := flow.mapAddrIfApplicable(broadcastAddr)
- if err := c.ep.Bind(tcpip.FullAddress{Addr: bcastAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- // Check that we receive broadcast packets but not unicast ones.
- testRead(c, flow)
- testFailingRead(c, unicastV4, false /* expectReadError */)
- })
- }
-}
-
-// TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY
-// and receive broadcast and unicast data.
-func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) {
- for _, flow := range []testFlow{broadcast, broadcastIn6} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s (", err)
- }
-
- // Check that we receive both broadcast and unicast packets.
- testRead(c, flow)
- testRead(c, unicastV4)
- })
- }
-}
-
-// testFailingWrite sends a packet of the given test flow into the UDP endpoint
-// and verifies it fails with the provided error code.
-func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) {
- c.t.Helper()
- // Take a snapshot of the stats to validate them at the end of the test.
- epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
- h := flow.header4Tuple(outgoing)
- writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr)
-
- payload := buffer.View(newPayload())
- _, _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port},
- })
- c.checkEndpointWriteStats(1, epstats, gotErr)
- if gotErr != wantErr {
- c.t.Fatalf("Write returned unexpected error: got %v, want %v", gotErr, wantErr)
- }
-}
-
-// testWrite sends a packet of the given test flow from the UDP endpoint to the
-// flow's destination address:port. It then receives it from the link endpoint
-// and verifies its correctness including any additional checker functions
-// provided.
-func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 {
- c.t.Helper()
- return testWriteInternal(c, flow, true, checkers...)
-}
-
-// testWriteWithoutDestination sends a packet of the given test flow from the
-// UDP endpoint without giving a destination address:port. It then receives it
-// from the link endpoint and verifies its correctness including any additional
-// checker functions provided.
-func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 {
- c.t.Helper()
- return testWriteInternal(c, flow, false, checkers...)
-}
-
-func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 {
- c.t.Helper()
- // Take a snapshot of the stats to validate them at the end of the test.
- epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
-
- writeOpts := tcpip.WriteOptions{}
- if setDest {
- h := flow.header4Tuple(outgoing)
- writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr)
- writeOpts = tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port},
- }
- }
- payload := buffer.View(newPayload())
- n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts)
- if err != nil {
- c.t.Fatalf("Write failed: %v", err)
- }
- if n != int64(len(payload)) {
- c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
- }
- c.checkEndpointWriteStats(1, epstats, err)
- // Received the packet and check the payload.
- b := c.getPacketAndVerify(flow, checkers...)
- var udp header.UDP
- if flow.isV4() {
- udp = header.UDP(header.IPv4(b).Payload())
- } else {
- udp = header.UDP(header.IPv6(b).Payload())
- }
- if !bytes.Equal(payload, udp.Payload()) {
- c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload)
- }
-
- return udp.SourcePort()
-}
-
-func testDualWrite(c *testContext) uint16 {
- c.t.Helper()
-
- v4Port := testWrite(c, unicastV4in6)
- v6Port := testWrite(c, unicastV6)
- if v4Port != v6Port {
- c.t.Fatalf("expected v4 and v6 ports to be equal: got v4Port = %d, v6Port = %d", v4Port, v6Port)
- }
-
- return v4Port
-}
-
-func TestDualWriteUnbound(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- testDualWrite(c)
-}
-
-func TestDualWriteBoundToWildcard(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
- }
-
- p := testDualWrite(c)
- if p != stackPort {
- c.t.Fatalf("Bad port: got %v, want %v", p, stackPort)
- }
-}
-
-func TestDualWriteConnectedToV6(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- // Connect to v6 address.
- if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
- }
-
- testWrite(c, unicastV6)
-
- // Write to V4 mapped address.
- testFailingWrite(c, unicastV4in6, tcpip.ErrNetworkUnreachable)
- const want = 1
- if got := c.ep.Stats().(*tcpip.TransportEndpointStats).SendErrors.NoRoute.Value(); got != want {
- c.t.Fatalf("Endpoint stat not updated. got %d want %d", got, want)
- }
-}
-
-func TestDualWriteConnectedToV4Mapped(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- // Connect to v4 mapped address.
- if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
- }
-
- testWrite(c, unicastV4in6)
-
- // Write to v6 address.
- testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState)
-}
-
-func TestV4WriteOnV6Only(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(unicastV6Only)
-
- // Write to V4 mapped address.
- testFailingWrite(c, unicastV4in6, tcpip.ErrNoRoute)
-}
-
-func TestV6WriteOnBoundToV4Mapped(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- // Bind to v4 mapped address.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
- }
-
- // Write to v6 address.
- testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState)
-}
-
-func TestV6WriteOnConnected(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- // Connect to v6 address.
- if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
- }
-
- testWriteWithoutDestination(c, unicastV6)
-}
-
-func TestV4WriteOnConnected(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- // Connect to v4 mapped address.
- if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
- }
-
- testWriteWithoutDestination(c, unicastV4)
-}
-
-// TestWriteOnBoundToV4Multicast checks that we can send packets out of a socket
-// that is bound to a V4 multicast address.
-func TestWriteOnBoundToV4Multicast(t *testing.T) {
- for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
- t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Bind to V4 mcast address.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastAddr, Port: stackPort}); err != nil {
- c.t.Fatal("Bind failed:", err)
- }
-
- testWrite(c, flow)
- })
- }
-}
-
-// TestWriteOnBoundToV4MappedMulticast checks that we can send packets out of a
-// socket that is bound to a V4-mapped multicast address.
-func TestWriteOnBoundToV4MappedMulticast(t *testing.T) {
- for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} {
- t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Bind to V4Mapped mcast address.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV4MappedAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- testWrite(c, flow)
- })
- }
-}
-
-// TestWriteOnBoundToV6Multicast checks that we can send packets out of a
-// socket that is bound to a V6 multicast address.
-func TestWriteOnBoundToV6Multicast(t *testing.T) {
- for _, flow := range []testFlow{unicastV6, multicastV6} {
- t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Bind to V6 mcast address.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- testWrite(c, flow)
- })
- }
-}
-
-// TestWriteOnBoundToV6Multicast checks that we can send packets out of a
-// V6-only socket that is bound to a V6 multicast address.
-func TestWriteOnBoundToV6OnlyMulticast(t *testing.T) {
- for _, flow := range []testFlow{unicastV6Only, multicastV6Only} {
- t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Bind to V6 mcast address.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- testWrite(c, flow)
- })
- }
-}
-
-// TestWriteOnBoundToBroadcast checks that we can send packets out of a
-// socket that is bound to the broadcast address.
-func TestWriteOnBoundToBroadcast(t *testing.T) {
- for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
- t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Bind to V4 broadcast address.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastAddr, Port: stackPort}); err != nil {
- c.t.Fatal("Bind failed:", err)
- }
-
- testWrite(c, flow)
- })
- }
-}
-
-// TestWriteOnBoundToV4MappedBroadcast checks that we can send packets out of a
-// socket that is bound to the V4-mapped broadcast address.
-func TestWriteOnBoundToV4MappedBroadcast(t *testing.T) {
- for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} {
- t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- // Bind to V4Mapped mcast address.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastV4MappedAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
-
- testWrite(c, flow)
- })
- }
-}
-
-func TestReadIncrementsPacketsReceived(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- // Create IPv4 UDP endpoint
- c.createEndpoint(ipv6.ProtocolNumber)
-
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
- }
-
- testRead(c, unicastV4)
-
- var want uint64 = 1
- if got := c.s.Stats().UDP.PacketsReceived.Value(); got != want {
- c.t.Fatalf("Read did not increment PacketsReceived: got %v, want %v", got, want)
- }
-}
-
-func TestWriteIncrementsPacketsSent(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- testDualWrite(c)
-
- var want uint64 = 2
- if got := c.s.Stats().UDP.PacketsSent.Value(); got != want {
- c.t.Fatalf("Write did not increment PacketsSent: got %v, want %v", got, want)
- }
-}
-
-func TestTTL(t *testing.T) {
- for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- const multicastTTL = 42
- if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
- }
-
- var wantTTL uint8
- if flow.isMulticast() {
- wantTTL = multicastTTL
- } else {
- var p stack.NetworkProtocol
- if flow.isV4() {
- p = ipv4.NewProtocol()
- } else {
- p = ipv6.NewProtocol()
- }
- ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil, stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
- }))
- if err != nil {
- t.Fatal(err)
- }
- wantTTL = ep.DefaultTTL()
- ep.Close()
- }
-
- testWrite(c, flow, checker.TTL(wantTTL))
- })
- }
-}
-
-func TestSetTTL(t *testing.T) {
- for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} {
- t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- if err := c.ep.SetSockOpt(tcpip.TTLOption(wantTTL)); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
- }
-
- var p stack.NetworkProtocol
- if flow.isV4() {
- p = ipv4.NewProtocol()
- } else {
- p = ipv6.NewProtocol()
- }
- ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil, stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
- }))
- if err != nil {
- t.Fatal(err)
- }
- ep.Close()
-
- testWrite(c, flow, checker.TTL(wantTTL))
- })
- }
- })
- }
-}
-
-func TestSetTOS(t *testing.T) {
- for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- const tos = testTOS
- var v tcpip.IPv4TOSOption
- if err := c.ep.GetSockOpt(&v); err != nil {
- c.t.Errorf("GetSockopt(%T) failed: %s", v, err)
- }
- // Test for expected default value.
- if v != 0 {
- c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, 0)
- }
-
- if err := c.ep.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil {
- c.t.Errorf("SetSockOpt(%T, 0x%x) failed: %s", v, tcpip.IPv4TOSOption(tos), err)
- }
-
- if err := c.ep.GetSockOpt(&v); err != nil {
- c.t.Errorf("GetSockopt(%T) failed: %s", v, err)
- }
-
- if want := tcpip.IPv4TOSOption(tos); v != want {
- c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, want)
- }
-
- testWrite(c, flow, checker.TOS(tos, 0))
- })
- }
-}
-
-func TestSetTClass(t *testing.T) {
- for _, flow := range []testFlow{unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, broadcastIn6} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
-
- const tClass = testTOS
- var v tcpip.IPv6TrafficClassOption
- if err := c.ep.GetSockOpt(&v); err != nil {
- c.t.Errorf("GetSockopt(%T) failed: %s", v, err)
- }
- // Test for expected default value.
- if v != 0 {
- c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, 0)
- }
-
- if err := c.ep.SetSockOpt(tcpip.IPv6TrafficClassOption(tClass)); err != nil {
- c.t.Errorf("SetSockOpt(%T, 0x%x) failed: %s", v, tcpip.IPv6TrafficClassOption(tClass), err)
- }
-
- if err := c.ep.GetSockOpt(&v); err != nil {
- c.t.Errorf("GetSockopt(%T) failed: %s", v, err)
- }
-
- if want := tcpip.IPv6TrafficClassOption(tClass); v != want {
- c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, want)
- }
-
- // The header getter for TClass is called TOS, so use that checker.
- testWrite(c, flow, checker.TOS(tClass, 0))
- })
- }
-}
-
-func TestReceiveTosTClass(t *testing.T) {
- testCases := []struct {
- name string
- getReceiveOption tcpip.SockOptBool
- tests []testFlow
- }{
- {"ReceiveTosOption", tcpip.ReceiveTOSOption, []testFlow{unicastV4, broadcast}},
- {"ReceiveTClassOption", tcpip.ReceiveTClassOption, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}},
- }
- for _, testCase := range testCases {
- for _, flow := range testCase.tests {
- t.Run(fmt.Sprintf("%s:flow:%s", testCase.name, flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpointForFlow(flow)
- option := testCase.getReceiveOption
- name := testCase.name
-
- // Verify that setting and reading the option works.
- v, err := c.ep.GetSockOptBool(option)
- if err != nil {
- c.t.Errorf("GetSockoptBool(%s) failed: %s", name, err)
- }
- // Test for expected default value.
- if v != false {
- c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, v, false)
- }
-
- want := true
- if err := c.ep.SetSockOptBool(option, want); err != nil {
- c.t.Fatalf("SetSockOptBool(%s, %t) failed: %s", name, want, err)
- }
-
- got, err := c.ep.GetSockOptBool(option)
- if err != nil {
- c.t.Errorf("GetSockoptBool(%s) failed: %s", name, err)
- }
-
- if got != want {
- c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, got, want)
- }
-
- // Verify that the correct received TOS or TClass is handed through as
- // ancillary data to the ControlMessages struct.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %s", err)
- }
- switch option {
- case tcpip.ReceiveTClassOption:
- testRead(c, flow, checker.ReceiveTClass(testTOS))
- case tcpip.ReceiveTOSOption:
- testRead(c, flow, checker.ReceiveTOS(testTOS))
- default:
- t.Fatalf("unknown test variant: %s", name)
- }
- })
- }
- }
-}
-
-func TestMulticastInterfaceOption(t *testing.T) {
- for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- for _, bindTyp := range []string{"bound", "unbound"} {
- t.Run(bindTyp, func(t *testing.T) {
- for _, optTyp := range []string{"use local-addr", "use NICID", "use local-addr and NIC"} {
- t.Run(optTyp, func(t *testing.T) {
- h := flow.header4Tuple(outgoing)
- mcastAddr := h.dstAddr.Addr
- localIfAddr := h.srcAddr.Addr
-
- var ifoptSet tcpip.MulticastInterfaceOption
- switch optTyp {
- case "use local-addr":
- ifoptSet.InterfaceAddr = localIfAddr
- case "use NICID":
- ifoptSet.NIC = 1
- case "use local-addr and NIC":
- ifoptSet.InterfaceAddr = localIfAddr
- ifoptSet.NIC = 1
- default:
- t.Fatal("unknown test variant")
- }
-
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(flow.sockProto())
-
- if bindTyp == "bound" {
- // Bind the socket by connecting to the multicast address.
- // This may have an influence on how the multicast interface
- // is set.
- addr := tcpip.FullAddress{
- Addr: flow.mapAddrIfApplicable(mcastAddr),
- Port: stackPort,
- }
- if err := c.ep.Connect(addr); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
- }
- }
-
- if err := c.ep.SetSockOpt(ifoptSet); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
- }
-
- // Verify multicast interface addr and NIC were set correctly.
- // Note that NIC must be 1 since this is our outgoing interface.
- ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}
- var ifoptGot tcpip.MulticastInterfaceOption
- if err := c.ep.GetSockOpt(&ifoptGot); err != nil {
- c.t.Fatalf("GetSockOpt failed: %v", err)
- }
- if ifoptGot != ifoptWant {
- c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant)
- }
- })
- }
- })
- }
- })
- }
-}
-
-// TestV4UnknownDestination verifies that we generate an ICMPv4 Destination
-// Unreachable message when a udp datagram is received on ports for which there
-// is no bound udp socket.
-func TestV4UnknownDestination(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- testCases := []struct {
- flow testFlow
- icmpRequired bool
- // largePayload if true, will result in a payload large enough
- // so that the final generated IPv4 packet is larger than
- // header.IPv4MinimumProcessableDatagramSize.
- largePayload bool
- }{
- {unicastV4, true, false},
- {unicastV4, true, true},
- {multicastV4, false, false},
- {multicastV4, false, true},
- {broadcast, false, false},
- {broadcast, false, true},
- }
- for _, tc := range testCases {
- t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) {
- payload := newPayload()
- if tc.largePayload {
- payload = newMinPayload(576)
- }
- c.injectPacket(tc.flow, payload)
- if !tc.icmpRequired {
- ctx, _ := context.WithTimeout(context.Background(), time.Second)
- if p, ok := c.linkEP.ReadContext(ctx); ok {
- t.Fatalf("unexpected packet received: %+v", p)
- }
- return
- }
-
- // ICMP required.
- ctx, _ := context.WithTimeout(context.Background(), time.Second)
- p, ok := c.linkEP.ReadContext(ctx)
- if !ok {
- t.Fatalf("packet wasn't written out")
- return
- }
-
- var pkt []byte
- pkt = append(pkt, p.Pkt.Header.View()...)
- pkt = append(pkt, p.Pkt.Data.ToView()...)
- if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want {
- t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
- }
-
- hdr := header.IPv4(pkt)
- checker.IPv4(t, hdr, checker.ICMPv4(
- checker.ICMPv4Type(header.ICMPv4DstUnreachable),
- checker.ICMPv4Code(header.ICMPv4PortUnreachable)))
-
- icmpPkt := header.ICMPv4(hdr.Payload())
- payloadIPHeader := header.IPv4(icmpPkt.Payload())
- wantLen := len(payload)
- if tc.largePayload {
- wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize
- }
-
- // In case of large payloads the IP packet may be truncated. Update
- // the length field before retrieving the udp datagram payload.
- payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize))
-
- origDgram := header.UDP(payloadIPHeader.Payload())
- if got, want := len(origDgram.Payload()), wantLen; got != want {
- t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
- }
- if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
- t.Fatalf("unexpected payload got: %d, want: %d", got, want)
- }
- })
- }
-}
-
-// TestV6UnknownDestination verifies that we generate an ICMPv6 Destination
-// Unreachable message when a udp datagram is received on ports for which there
-// is no bound udp socket.
-func TestV6UnknownDestination(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- testCases := []struct {
- flow testFlow
- icmpRequired bool
- // largePayload if true will result in a payload large enough to
- // create an IPv6 packet > header.IPv6MinimumMTU bytes.
- largePayload bool
- }{
- {unicastV6, true, false},
- {unicastV6, true, true},
- {multicastV6, false, false},
- {multicastV6, false, true},
- }
- for _, tc := range testCases {
- t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) {
- payload := newPayload()
- if tc.largePayload {
- payload = newMinPayload(1280)
- }
- c.injectPacket(tc.flow, payload)
- if !tc.icmpRequired {
- ctx, _ := context.WithTimeout(context.Background(), time.Second)
- if p, ok := c.linkEP.ReadContext(ctx); ok {
- t.Fatalf("unexpected packet received: %+v", p)
- }
- return
- }
-
- // ICMP required.
- ctx, _ := context.WithTimeout(context.Background(), time.Second)
- p, ok := c.linkEP.ReadContext(ctx)
- if !ok {
- t.Fatalf("packet wasn't written out")
- return
- }
-
- var pkt []byte
- pkt = append(pkt, p.Pkt.Header.View()...)
- pkt = append(pkt, p.Pkt.Data.ToView()...)
- if got, want := len(pkt), header.IPv6MinimumMTU; got > want {
- t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
- }
-
- hdr := header.IPv6(pkt)
- checker.IPv6(t, hdr, checker.ICMPv6(
- checker.ICMPv6Type(header.ICMPv6DstUnreachable),
- checker.ICMPv6Code(header.ICMPv6PortUnreachable)))
-
- icmpPkt := header.ICMPv6(hdr.Payload())
- payloadIPHeader := header.IPv6(icmpPkt.Payload())
- wantLen := len(payload)
- if tc.largePayload {
- wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize
- }
- // In case of large payloads the IP packet may be truncated. Update
- // the length field before retrieving the udp datagram payload.
- payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize))
-
- origDgram := header.UDP(payloadIPHeader.Payload())
- if got, want := len(origDgram.Payload()), wantLen; got != want {
- t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
- }
- if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
- t.Fatalf("unexpected payload got: %v, want: %v", got, want)
- }
- })
- }
-}
-
-// TestIncrementMalformedPacketsReceived verifies if the malformed received
-// global and endpoint stats get incremented.
-func TestIncrementMalformedPacketsReceived(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
- }
-
- payload := newPayload()
- c.t.Helper()
- h := unicastV6.header4Tuple(incoming)
- c.injectV6Packet(payload, &h, false /* !valid */)
-
- var want uint64 = 1
- if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want {
- t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %v, want = %v", got, want)
- }
- if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want)
- }
-}
-
-// TestShutdownRead verifies endpoint read shutdown and error
-// stats increment on packet receive.
-func TestShutdownRead(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- // Bind to wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
- }
-
- if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
- }
-
- if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
- }
-
- testFailingRead(c, unicastV6, true /* expectReadError */)
-
- var want uint64 = 1
- if got := c.s.Stats().UDP.ReceiveBufferErrors.Value(); got != want {
- t.Errorf("got stats.UDP.ReceiveBufferErrors.Value() = %v, want = %v", got, want)
- }
- if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ClosedReceiver.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.ClosedReceiver stats = %v, want = %v", got, want)
- }
-}
-
-// TestShutdownWrite verifies endpoint write shutdown and error
-// stats increment on packet write.
-func TestShutdownWrite(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
- }
-
- if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
- }
-
- testFailingWrite(c, unicastV6, tcpip.ErrClosedForSend)
-}
-
-func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) {
- got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
- switch err {
- case nil:
- want.PacketsSent.IncrementBy(incr)
- case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue:
- want.WriteErrors.InvalidArgs.IncrementBy(incr)
- case tcpip.ErrClosedForSend:
- want.WriteErrors.WriteClosed.IncrementBy(incr)
- case tcpip.ErrInvalidEndpointState:
- want.WriteErrors.InvalidEndpointState.IncrementBy(incr)
- case tcpip.ErrNoLinkAddress:
- want.SendErrors.NoLinkAddr.IncrementBy(incr)
- case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable:
- want.SendErrors.NoRoute.IncrementBy(incr)
- default:
- want.SendErrors.SendToNetworkFailed.IncrementBy(incr)
- }
- if got != want {
- c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want)
- }
-}
-
-func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) {
- got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
- switch err {
- case nil, tcpip.ErrWouldBlock:
- case tcpip.ErrClosedForReceive:
- want.ReadErrors.ReadClosed.IncrementBy(incr)
- default:
- c.t.Errorf("Endpoint error missing stats update err %v", err)
- }
- if got != want {
- c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want)
- }
-}