summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--.bazelignore1
-rw-r--r--.devcontainer.json9
-rw-r--r--.vscode/tasks.json31
-rw-r--r--WORKSPACE15
-rw-r--r--g3doc/user_guide/FAQ.md4
-rw-r--r--go.mod23
-rw-r--r--go.sum10
-rw-r--r--images/basic/integrationtest/Dockerfile.x86_646
-rw-r--r--images/basic/integrationtest/test_sticky.c96
-rw-r--r--images/syzkaller/Dockerfile2
-rw-r--r--images/syzkaller/README.md55
-rw-r--r--nogo.yaml5
-rw-r--r--pkg/abi/linux/BUILD1
-rw-r--r--pkg/abi/linux/netfilter.go73
-rw-r--r--pkg/abi/linux/netfilter_ipv6.go67
-rw-r--r--pkg/sentry/fs/host/file.go13
-rw-r--r--pkg/sentry/fs/host/inode.go7
-rw-r--r--pkg/sentry/fs/host/socket.go16
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer.go8
-rw-r--r--pkg/sentry/fsimpl/host/socket.go18
-rw-r--r--pkg/sentry/fsimpl/kernfs/filesystem.go6
-rw-r--r--pkg/sentry/fsimpl/kernfs/synthetic_directory.go11
-rw-r--r--pkg/sentry/fsimpl/overlay/filesystem.go6
-rw-r--r--pkg/sentry/fsimpl/overlay/overlay.go10
-rw-r--r--pkg/sentry/fsimpl/tmpfs/directory.go8
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file.go3
-rw-r--r--pkg/sentry/kernel/pipe/pipe.go8
-rw-r--r--pkg/sentry/kernel/task_exit.go6
-rw-r--r--pkg/sentry/loader/elf.go6
-rw-r--r--pkg/sentry/platform/ptrace/filters.go2
-rw-r--r--pkg/sentry/socket/hostinet/socket_vfs2.go3
-rw-r--r--pkg/sentry/socket/netstack/netstack.go12
-rw-r--r--pkg/sentry/socket/netstack/netstack_vfs2.go3
-rw-r--r--pkg/sentry/socket/unix/BUILD1
-rw-r--r--pkg/sentry/socket/unix/transport/BUILD2
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned.go30
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned_state.go5
-rw-r--r--pkg/sentry/socket/unix/transport/connectionless.go13
-rw-r--r--pkg/sentry/socket/unix/transport/connectionless_state.go20
-rw-r--r--pkg/sentry/socket/unix/transport/queue.go9
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go77
-rw-r--r--pkg/sentry/socket/unix/unix_vfs2.go3
-rw-r--r--pkg/sentry/syscalls/linux/BUILD1
-rw-r--r--pkg/sentry/syscalls/linux/sys_socket.go3
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/BUILD1
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/socket.go3
-rw-r--r--pkg/sentry/vfs/file_description_impl_util.go2
-rw-r--r--pkg/sentry/vfs/permissions.go6
-rw-r--r--pkg/tcpip/link/pipe/pipe.go44
-rw-r--r--pkg/tcpip/network/BUILD1
-rw-r--r--pkg/tcpip/network/arp/BUILD4
-rw-r--r--pkg/tcpip/network/arp/arp.go86
-rw-r--r--pkg/tcpip/network/arp/arp_test.go151
-rw-r--r--pkg/tcpip/network/arp/stats_test.go2
-rw-r--r--pkg/tcpip/network/internal/fragmentation/BUILD (renamed from pkg/tcpip/network/fragmentation/BUILD)7
-rw-r--r--pkg/tcpip/network/internal/fragmentation/fragmentation.go (renamed from pkg/tcpip/network/fragmentation/fragmentation.go)0
-rw-r--r--pkg/tcpip/network/internal/fragmentation/fragmentation_test.go (renamed from pkg/tcpip/network/fragmentation/fragmentation_test.go)2
-rw-r--r--pkg/tcpip/network/internal/fragmentation/reassembler.go (renamed from pkg/tcpip/network/fragmentation/reassembler.go)0
-rw-r--r--pkg/tcpip/network/internal/fragmentation/reassembler_test.go (renamed from pkg/tcpip/network/fragmentation/reassembler_test.go)0
-rw-r--r--pkg/tcpip/network/internal/ip/BUILD (renamed from pkg/tcpip/network/ip/BUILD)16
-rw-r--r--pkg/tcpip/network/internal/ip/duplicate_address_detection.go172
-rw-r--r--pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go279
-rw-r--r--pkg/tcpip/network/internal/ip/generic_multicast_protocol.go (renamed from pkg/tcpip/network/ip/generic_multicast_protocol.go)0
-rw-r--r--pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go (renamed from pkg/tcpip/network/ip/generic_multicast_protocol_test.go)11
-rw-r--r--pkg/tcpip/network/internal/ip/stats.go (renamed from pkg/tcpip/network/ip/stats.go)0
-rw-r--r--pkg/tcpip/network/internal/testutil/BUILD (renamed from pkg/tcpip/network/testutil/BUILD)2
-rw-r--r--pkg/tcpip/network/internal/testutil/testutil.go (renamed from pkg/tcpip/network/testutil/testutil.go)0
-rw-r--r--pkg/tcpip/network/internal/testutil/testutil_unsafe.go (renamed from pkg/tcpip/network/testutil/testutil_unsafe.go)0
-rw-r--r--pkg/tcpip/network/ip_test.go22
-rw-r--r--pkg/tcpip/network/ipv4/BUILD8
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go67
-rw-r--r--pkg/tcpip/network/ipv4/igmp.go2
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go285
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go129
-rw-r--r--pkg/tcpip/network/ipv4/stats.go2
-rw-r--r--pkg/tcpip/network/ipv4/stats_test.go2
-rw-r--r--pkg/tcpip/network/ipv6/BUILD6
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go55
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go462
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go231
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go2
-rw-r--r--pkg/tcpip/network/ipv6/mld.go2
-rw-r--r--pkg/tcpip/network/ipv6/mld_test.go2
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go417
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go1087
-rw-r--r--pkg/tcpip/network/ipv6/stats.go2
-rw-r--r--pkg/tcpip/socketops.go73
-rw-r--r--pkg/tcpip/stack/BUILD15
-rw-r--r--pkg/tcpip/stack/conntrack.go6
-rw-r--r--pkg/tcpip/stack/forwarding_test.go644
-rw-r--r--pkg/tcpip/stack/iptables.go2
-rw-r--r--pkg/tcpip/stack/linkaddrcache.go359
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go291
-rw-r--r--pkg/tcpip/stack/ndp_test.go881
-rw-r--r--pkg/tcpip/stack/neighbor_cache.go156
-rw-r--r--pkg/tcpip/stack/neighbor_cache_test.go26
-rw-r--r--pkg/tcpip/stack/neighbor_entry.go398
-rw-r--r--pkg/tcpip/stack/neighbor_entry_test.go575
-rw-r--r--pkg/tcpip/stack/neighborstate_string.go7
-rw-r--r--pkg/tcpip/stack/nic.go258
-rw-r--r--pkg/tcpip/stack/nic_test.go2
-rw-r--r--pkg/tcpip/stack/nud_test.go9
-rw-r--r--pkg/tcpip/stack/pending_packets.go4
-rw-r--r--pkg/tcpip/stack/registration.go104
-rw-r--r--pkg/tcpip/stack/route.go20
-rw-r--r--pkg/tcpip/stack/stack.go105
-rw-r--r--pkg/tcpip/stack/stack_test.go15
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go4
-rw-r--r--pkg/tcpip/tests/integration/iptables_test.go310
-rw-r--r--pkg/tcpip/tests/integration/link_resolution_test.go290
-rw-r--r--pkg/tcpip/transport/tcp/connect.go9
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go6
-rw-r--r--pkg/tcpip/transport/tcp/rack.go115
-rw-r--r--pkg/tcpip/transport/tcp/rack_state.go5
-rw-r--r--pkg/tcpip/transport/tcp/segment.go3
-rw-r--r--pkg/tcpip/transport/tcp/snd.go68
-rw-r--r--pkg/tcpip/transport/tcp/snd_state.go2
-rw-r--r--pkg/tcpip/transport/tcp/tcp_rack_test.go58
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go7
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go7
-rw-r--r--runsc/boot/filter/config.go9
-rw-r--r--runsc/container/container_test.go22
-rw-r--r--runsc/fsgofer/filter/config.go9
-rw-r--r--test/e2e/integration_test.go53
-rw-r--r--test/packetimpact/tests/tcp_rack_test.go37
-rw-r--r--test/packetimpact/tests/udp_send_recv_dgram_test.go1
-rw-r--r--test/syscalls/BUILD4
-rw-r--r--test/syscalls/linux/BUILD15
-rw-r--r--test/syscalls/linux/exec_binary.cc28
-rw-r--r--test/syscalls/linux/open_create.cc4
-rw-r--r--test/syscalls/linux/processes.cc90
-rw-r--r--test/syscalls/linux/rename.cc33
-rw-r--r--test/syscalls/linux/sendfile.cc23
-rw-r--r--test/syscalls/linux/socket_generic.cc2
-rw-r--r--test/syscalls/linux/socket_inet_loopback.cc6
-rw-r--r--test/syscalls/linux/socket_unix_dgram.cc35
-rw-r--r--test/syscalls/linux/socket_unix_seqpacket.cc35
-rw-r--r--test/syscalls/linux/socket_unix_stream.cc80
-rw-r--r--test/syscalls/linux/udp_socket.cc81
-rw-r--r--tools/bazel_gazelle_noise.patch24
-rw-r--r--tools/go_marshal/README.md12
-rw-r--r--tools/go_marshal/gomarshal/generator.go31
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces_struct.go339
-rw-r--r--tools/go_marshal/gomarshal/generator_tests.go12
-rw-r--r--tools/go_marshal/test/BUILD6
-rw-r--r--tools/go_marshal/test/marshal_test.go19
-rw-r--r--tools/go_marshal/test/test.go35
147 files changed, 5644 insertions, 4515 deletions
diff --git a/.bazelignore b/.bazelignore
new file mode 100644
index 000000000..511b10433
--- /dev/null
+++ b/.bazelignore
@@ -0,0 +1 @@
+bazel-gvisor
diff --git a/.devcontainer.json b/.devcontainer.json
new file mode 100644
index 000000000..6f7fe4bf8
--- /dev/null
+++ b/.devcontainer.json
@@ -0,0 +1,9 @@
+{
+ "dockerFile": "images/default/Dockerfile",
+ "overrideCommand": true,
+ "mounts": ["source=/var/run/docker.sock,target=/var/run/docker-host.sock,type=bind"],
+ "runArgs": ["--cap-add=SYS_PTRACE", "--security-opt", "seccomp=unconfined"],
+ "extensions": [
+ "bazelbuild.vscode-bazel"
+ ]
+}
diff --git a/.vscode/tasks.json b/.vscode/tasks.json
new file mode 100644
index 000000000..42a018434
--- /dev/null
+++ b/.vscode/tasks.json
@@ -0,0 +1,31 @@
+{
+ "version": "2.0.0",
+ "tasks": [
+ {
+ "label": "Build",
+ "type": "shell",
+ "command": "bazel build //...",
+ "group": {
+ "kind": "build",
+ "isDefault": true
+ },
+ "presentation": {
+ "reveal": "always",
+ "panel": "new"
+ }
+ },
+ {
+ "label": "Test",
+ "type": "shell",
+ "command": "bazel test //...",
+ "group": {
+ "kind": "test",
+ "isDefault": true
+ },
+ "presentation": {
+ "reveal": "always",
+ "panel": "new"
+ }
+ }
+ ]
+}
diff --git a/WORKSPACE b/WORKSPACE
index a6b3144e4..4ee93a670 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -57,9 +57,6 @@ http_archive(
# This is actually a no-op with the hacky patch above, but should
# slightly future proof this mechanism.
"//tools:bazel_gazelle_generate.patch",
- # False positive output complaining about Go logrus versions spam the
- # logs. Strip this message in this case. Does not affect control flow.
- "//tools:bazel_gazelle_noise.patch",
],
sha256 = "222e49f034ca7a1d1231422cdb67066b885819885c356673cb1f72f748a3c9d4",
urls = [
@@ -377,8 +374,8 @@ go_repository(
name = "org_golang_google_grpc",
build_file_proto_mode = "disable",
importpath = "google.golang.org/grpc",
- sum = "h1:cb+I9RwgcErlwAuOVnGhJ2d3YrcdwGXw+RPArsTWot4=",
- version = "v1.36.0-dev.0.20210122012134-2c42474aca0c",
+ sum = "h1:iGG0ZwQMaxJT/qsL0nnzZCg+4aiWHuQy3MytzLieAjo=",
+ version = "v1.36.0-dev.0.20210208035533-9280052d3665",
)
go_repository(
@@ -517,8 +514,8 @@ go_repository(
go_repository(
name = "com_github_konsorten_go_windows_terminal_sequences",
importpath = "github.com/konsorten/go-windows-terminal-sequences",
- sum = "h1:CE8S1cTafDpPvMhIxNJKvHsGVBgn1xWYf1NbHQhywc8=",
- version = "v1.0.3",
+ sum = "h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s=",
+ version = "v1.0.2",
)
go_repository(
@@ -637,8 +634,8 @@ go_repository(
go_repository(
name = "com_github_containerd_continuity",
importpath = "github.com/containerd/continuity",
- sum = "h1:6ejg6Lkk8dskcM7wQ28gONkukbQkM4qpj4RnYbpFzrI=",
- version = "v0.0.0-20201208142359-180525291bb7",
+ sum = "h1:6JKvHHt396/qabvMhnhUZvWaHZzfVfldxE60TK8YLhg=",
+ version = "v0.0.0-20210208174643-50096c924a4e",
)
go_repository(
diff --git a/g3doc/user_guide/FAQ.md b/g3doc/user_guide/FAQ.md
index 8e5721ad1..26c836ddf 100644
--- a/g3doc/user_guide/FAQ.md
+++ b/g3doc/user_guide/FAQ.md
@@ -107,11 +107,11 @@ kubeadm to create your cluster please check if Docker is also installed on that
system. Kubeadm prefers using Docker if both Docker and containerd are
installed.
-Please recreate your cluster and set the `--cni-socket` option on kubeadm
+Please recreate your cluster and set the `--cri-socket` option on kubeadm
commands. For example:
```bash
-kubeadm init --cni-socket=/var/run/containerd/containerd.sock ...
+kubeadm init --cri-socket=/var/run/containerd/containerd.sock ...
```
To fix an existing cluster edit the `/var/lib/kubelet/kubeadm-flags.env` file
diff --git a/go.mod b/go.mod
index 942e67b44..0774d2930 100644
--- a/go.mod
+++ b/go.mod
@@ -1,18 +1,17 @@
module gvisor.dev/gvisor
-go 1.15
-
-replace github.com/Sirupsen/logrus => github.com/sirupsen/logrus v1.6.0
+go 1.16
require (
cloud.google.com/go v0.75.0 // indirect
+ github.com/BurntSushi/toml v0.3.1 // indirect
github.com/Microsoft/go-winio v0.4.16 // indirect
github.com/Microsoft/hcsshim v0.8.14 // indirect
github.com/cenkalti/backoff v1.1.1-0.20190506075156-2146c9339422 // indirect
github.com/containerd/cgroups v0.0.0-20201119153540-4cbc285b3327 // indirect
github.com/containerd/console v1.0.1 // indirect
github.com/containerd/containerd v1.3.9 // indirect
- github.com/containerd/continuity v0.0.0-20201208142359-180525291bb7 // indirect
+ github.com/containerd/continuity v0.0.0-20210208174643-50096c924a4e // indirect
github.com/containerd/fifo v0.0.0-20191213151349-ff969a566b00 // indirect
github.com/containerd/go-runc v0.0.0-20200220073739-7016d3ce2328 // indirect
github.com/containerd/ttrpc v1.0.2 // indirect
@@ -23,6 +22,10 @@ require (
github.com/docker/go-events v0.0.0-20190806004212-e31b211e4f1c // indirect
github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 // indirect
github.com/gogo/googleapis v1.4.0 // indirect
+ github.com/gogo/protobuf v1.3.1 // indirect
+ github.com/golang/mock v1.4.4 // indirect
+ github.com/google/btree v1.0.0 // indirect
+ github.com/google/go-cmp v0.5.4 // indirect
github.com/google/go-github/v32 v32.1.0 // indirect
github.com/google/pprof v0.0.0-20210115211752-39141e76b647 // indirect
github.com/google/subcommands v1.0.2-0.20190508160503-636abe8753b8 // indirect
@@ -32,15 +35,25 @@ require (
github.com/mohae/deepcopy v0.0.0-20170308212314-bb9b5e7adda9 // indirect
github.com/opencontainers/image-spec v1.0.1 // indirect
github.com/opencontainers/runc v0.1.1 // indirect
+ github.com/opencontainers/runtime-spec v1.0.2 // indirect
github.com/pborman/uuid v1.2.0 // indirect
+ github.com/sirupsen/logrus v1.7.0 // indirect
github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2 // indirect
github.com/vishvananda/netlink v1.0.1-0.20190930145447-2ec5bdc52b86 // indirect
github.com/vishvananda/netns v0.0.0-20210104183010-2eb08e3e575f // indirect
github.com/xeipuuv/gojsonschema v1.2.0 // indirect
go.uber.org/multierr v1.6.0 // indirect
- google.golang.org/grpc v1.36.0-dev.0.20210122012134-2c42474aca0c // indirect
+ golang.org/x/net v0.0.0-20201224014010-6772e930b67b // indirect
+ golang.org/x/oauth2 v0.0.0-20201208152858-08078c50e5b5 // indirect
+ golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4 // indirect
+ golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect
+ golang.org/x/tools v0.1.0 // indirect
+ google.golang.org/api v0.36.0 // indirect
+ google.golang.org/grpc v1.36.0-dev.0.20210208035533-9280052d3665 // indirect
google.golang.org/protobuf v1.25.1-0.20201020201750-d3470999428b // indirect
+ gopkg.in/yaml.v2 v2.2.8 // indirect
honnef.co/go/tools v0.1.1 // indirect
+ k8s.io/api v0.16.13 // indirect
k8s.io/apimachinery v0.16.14-rc.0 // indirect
k8s.io/client-go v0.16.13 // indirect
)
diff --git a/go.sum b/go.sum
index 4df67c5dd..9d7ef2243 100644
--- a/go.sum
+++ b/go.sum
@@ -77,8 +77,8 @@ github.com/containerd/containerd v1.3.2/go.mod h1:bC6axHOhabU15QhwfG7w5PipXdVtMX
github.com/containerd/containerd v1.3.9 h1:K2U/F4jGAMBqeUssfgJRbFuomLcS2Fxo1vR3UM/Mbh8=
github.com/containerd/containerd v1.3.9/go.mod h1:bC6axHOhabU15QhwfG7w5PipXdVtMXFTttgp+kVtyUA=
github.com/containerd/continuity v0.0.0-20190426062206-aaeac12a7ffc/go.mod h1:GL3xCUCBDV3CZiTSEKksMWbLE66hEyuu9qyDOOqM47Y=
-github.com/containerd/continuity v0.0.0-20201208142359-180525291bb7 h1:6ejg6Lkk8dskcM7wQ28gONkukbQkM4qpj4RnYbpFzrI=
-github.com/containerd/continuity v0.0.0-20201208142359-180525291bb7/go.mod h1:kR3BEg7bDFaEddKm54WSmrol1fKWDU1nKYkgrcgZT7Y=
+github.com/containerd/continuity v0.0.0-20210208174643-50096c924a4e h1:6JKvHHt396/qabvMhnhUZvWaHZzfVfldxE60TK8YLhg=
+github.com/containerd/continuity v0.0.0-20210208174643-50096c924a4e/go.mod h1:EXlVlkqNba9rJe3j7w3Xa924itAMLgZH4UD/Q4PExuQ=
github.com/containerd/fifo v0.0.0-20190226154929-a9fb20d87448/go.mod h1:ODA38xgv3Kuk8dQz2ZQXpnv/UZZUHUCL7pnLehbXgQI=
github.com/containerd/fifo v0.0.0-20191213151349-ff969a566b00 h1:lsjC5ENBl+Zgf38+B0ymougXFp0BaubeIVETltYZTQw=
github.com/containerd/fifo v0.0.0-20191213151349-ff969a566b00/go.mod h1:jPQ2IAeZRCYxpS/Cm1495vGFww6ecHmMk1YJH2Q5ln0=
@@ -236,7 +236,6 @@ github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQL
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
-github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/pty v1.1.4-0.20190131011033-7dc38fb350b1 h1:zc0R6cOw98cMengLA0fvU55mqbnN7sd/tBMLzSejp+M=
@@ -286,7 +285,6 @@ github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQD
github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
-github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
github.com/sirupsen/logrus v1.7.0 h1:ShrD1U9pZB12TX0cVy0DtePoCH97K8EtX+mg7ZARUtM=
github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/spf13/afero v1.2.2/go.mod h1:9ZxEEn6pIJ8Rxe320qSDBk6AsU0r9pR7Q4OcevTdifk=
@@ -612,8 +610,8 @@ google.golang.org/grpc v1.31.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM
google.golang.org/grpc v1.31.1/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak=
google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc=
google.golang.org/grpc v1.34.0/go.mod h1:WotjhfgOW/POjDeRt8vscBtXq+2VjORFy659qA51WJ8=
-google.golang.org/grpc v1.36.0-dev.0.20210122012134-2c42474aca0c h1:cb+I9RwgcErlwAuOVnGhJ2d3YrcdwGXw+RPArsTWot4=
-google.golang.org/grpc v1.36.0-dev.0.20210122012134-2c42474aca0c/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU=
+google.golang.org/grpc v1.36.0-dev.0.20210208035533-9280052d3665 h1:iGG0ZwQMaxJT/qsL0nnzZCg+4aiWHuQy3MytzLieAjo=
+google.golang.org/grpc v1.36.0-dev.0.20210208035533-9280052d3665/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU=
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
diff --git a/images/basic/integrationtest/Dockerfile.x86_64 b/images/basic/integrationtest/Dockerfile.x86_64
index e80e17527..b9fed05cb 100644
--- a/images/basic/integrationtest/Dockerfile.x86_64
+++ b/images/basic/integrationtest/Dockerfile.x86_64
@@ -5,3 +5,9 @@ COPY . .
RUN chmod +x *.sh
RUN apt-get update && apt-get install -y gcc iputils-ping iproute2
+
+# Compilation Steps.
+RUN gcc -O2 -o test_copy_up test_copy_up.c
+RUN gcc -O2 -o test_rewinddir test_rewinddir.c
+RUN gcc -O2 -o link_test link_test.c
+RUN gcc -O2 -o test_sticky test_sticky.c
diff --git a/images/basic/integrationtest/test_sticky.c b/images/basic/integrationtest/test_sticky.c
new file mode 100644
index 000000000..58dcf91d3
--- /dev/null
+++ b/images/basic/integrationtest/test_sticky.c
@@ -0,0 +1,96 @@
+#include <err.h>
+#include <errno.h>
+#include <fcntl.h>
+#include <stdlib.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <sys/wait.h>
+#include <unistd.h>
+
+// createFile opens |path| write-only, creating it with requested mode 0777 if
+// it is missing, and closes it again right away. If the open fails, a
+// diagnostic is printed and the process terminates with status 1.
+void createFile(const char* path) {
+  int fd = open(path, O_WRONLY | O_CREAT, 0777);
+  if (fd < 0) {
+    // err() prints the message followed by strerror(errno) and exits; it never
+    // returns, so no explicit exit() is needed after it.
+    err(1, "open(%s)", path);
+  }
+  // Only the file's existence matters; nothing is written to it.
+  close(fd);
+}
+
+// waitAndCheckStatus blocks until |child| terminates and ends the current
+// process with status 1 unless the child exited normally with status 0.
+void waitAndCheckStatus(pid_t child) {
+  int status;
+  if (waitpid(child, &status, 0) == -1) {
+    // waitpid() sets errno on failure, so err() is the right reporter here.
+    err(1, "waitpid() failed");
+  }
+
+  // The checks below do not involve errno; use errx() so that no stale
+  // strerror(errno) text gets appended to the message.
+  if (!WIFEXITED(status)) {
+    errx(1, "child did not exit normally");
+  }
+  int es = WEXITSTATUS(status);
+  if (es) {
+    errx(1, "child exit status %d", es);
+  }
+}
+
+// deleteFile unlinks |path| from a forked child that first switches to UID
+// |user|. setuid() is only called in the child, so the caller keeps its own
+// credentials. The calling process terminates with status 1 if the child
+// cannot be created, or if the child fails to switch user or to unlink.
+void deleteFile(uid_t user, const char* path) {
+  pid_t child = fork();
+  if (child == -1) {
+    // Without this check, waitpid(-1, ...) would wait for an arbitrary child
+    // and the real cause of the failure would be lost.
+    err(1, "fork()");
+  }
+  if (child == 0) {
+    if (setuid(user)) {
+      // Cast explicitly so the argument matches the conversion specifier.
+      err(1, "setuid(%u)", (unsigned)user);
+    }
+
+    if (unlink(path)) {
+      err(1, "unlink(%s)", path);
+    }
+    exit(0);
+  }
+  waitAndCheckStatus(child);
+}
+
+// main builds a sticky directory owned by user1 that contains two files owned
+// by user2, then checks that both the directory owner and the file owner can
+// unlink a file from it. Any failing step ends the process with status 1.
+int main(int argc, char** argv) {
+  const char kUser1Dir[] = "/user1dir";
+  const char kUser2File[] = "/user1dir/user2file";
+  const char kUser2File2[] = "/user1dir/user2file2";
+
+  const uid_t user1 = 6666;
+  const uid_t user2 = 6667;
+
+  // err() never returns, so no exit() calls are needed after it below.
+  if (mkdir(kUser1Dir, 0755) != 0) {
+    err(1, "mkdir(%s)", kUser1Dir);
+  }
+  // Enable sticky bit for user1dir.
+  if (chmod(kUser1Dir, 01777) != 0) {
+    err(1, "chmod(%s)", kUser1Dir);
+  }
+  createFile(kUser2File);
+  createFile(kUser2File2);
+
+  // Hand the directory to user1 and both files to user2; the group is set to
+  // the caller's effective GID in every case.
+  if (chown(kUser1Dir, user1, getegid())) {
+    err(1, "chown(%s)", kUser1Dir);
+  }
+  if (chown(kUser2File, user2, getegid())) {
+    err(1, "chown(%s)", kUser2File);
+  }
+  if (chown(kUser2File2, user2, getegid())) {
+    err(1, "chown(%s)", kUser2File2);
+  }
+
+  // User1 owns the sticky directory, so it should be able to delete any file
+  // inside it, even files owned by other users.
+  deleteFile(user1, kUser2File);
+
+  // User2 should be able to delete its own file even though the file is
+  // inside a sticky dir owned by someone else.
+  deleteFile(user2, kUser2File2);
+
+  return 0;
+}
diff --git a/images/syzkaller/Dockerfile b/images/syzkaller/Dockerfile
index df6680f40..9a85ae345 100644
--- a/images/syzkaller/Dockerfile
+++ b/images/syzkaller/Dockerfile
@@ -1,5 +1,7 @@
FROM gcr.io/syzkaller/env
+# This image is mostly for investigating syzkaller crashes, so let's install
+# developer tools.
RUN apt update && apt install -y git vim strace gdb procps
WORKDIR /syzkaller/gopath/src/github.com/google/syzkaller
diff --git a/images/syzkaller/README.md b/images/syzkaller/README.md
index 1eac474f3..47e309422 100644
--- a/images/syzkaller/README.md
+++ b/images/syzkaller/README.md
@@ -5,21 +5,54 @@ syzkaller is an unsupervised coverage-guided kernel fuzzer.
# How to run syzkaller.
-* Build the syzkaller docker image `make load-syzkaller`
-* Build runsc and place it in /tmp/syzkaller. `make RUNTIME_DIR=/tmp/syzkaller
- refresh`
-* Copy the syzkaller config in /tmp/syzkaller `cp
- images/syzkaller/default-gvisor-config.cfg /tmp/syzkaller/syzkaller.cfg`
-* Run syzkaller `docker run --privileged -it --rm -v
- /tmp/syzkaller:/tmp/syzkaller gvisor.dev/images/syzkaller:latest`
+First, we need to load a syzkaller docker image:
+
+```bash
+make load-syzkaller
+```
+
+or we can rebuild it to use an up-to-date version of the master branch:
+
+```bash
+make rebuild-syzkaller
+```
+
+Then we need to create a directory containing all the artifacts needed to run
+syzkaller. This directory will be bind-mounted into the docker container.
+
+We need to build runsc and place it in the artifact directory:
+
+```bash
+make RUNTIME_DIR=/tmp/syzkaller refresh
+```
+
+The next step is to create a syzkaller config. We can copy the default one and
+customize it:
+
+```bash
+cp images/syzkaller/default-gvisor-config.cfg /tmp/syzkaller/syzkaller.cfg
+```
+
+Now we can start syzkaller in a docker container:
+
+```bash
+docker run --privileged -it --rm \
+ -v /tmp/syzkaller:/tmp/syzkaller \
+ gvisor.dev/images/syzkaller:latest
+```
+
+All logs will be in /tmp/syzkaller/workdir.
# How to run a syz repro.
-* Repeate all steps except the last one from the previous section.
+We need to repeat all preparation steps from the previous section and save a
+syzkaller repro in /tmp/syzkaller/repro.
-* Save a syzkaller repro in /tmp/syzkaller/repro
+Now we can run syz-repro to reproduce a crash:
-* Run syz-repro `docker run --privileged -it --rm -v
+```bash
+docker run --privileged -it --rm -v
/tmp/syzkaller:/tmp/syzkaller --entrypoint=""
gvisor.dev/images/syzkaller:latest ./bin/syz-repro -config
- /tmp/syzkaller/syzkaller.cfg /tmp/syzkaller/repro`
+ /tmp/syzkaller/syzkaller.cfg /tmp/syzkaller/repro
+```
diff --git a/nogo.yaml b/nogo.yaml
index d9cbd900d..e43148a6d 100644
--- a/nogo.yaml
+++ b/nogo.yaml
@@ -147,6 +147,11 @@ analyzers:
- pkg/sentry/fs/fs.go # Intentional.
- pkg/sentry/fs/gofer/inode.go # Intentional.
- pkg/refs/refcounter_test.go # Intentional.
+ ST1019:
+ generated:
+ exclude:
+ # package ".../kubeapi/core/v1/v1" is being imported more than once
+ - generated.gen.pb.go
ST1021:
internal:
suppress:
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD
index 8fa61d6f7..ecaeb11ac 100644
--- a/pkg/abi/linux/BUILD
+++ b/pkg/abi/linux/BUILD
@@ -80,7 +80,6 @@ go_library(
"//pkg/bits",
"//pkg/marshal",
"//pkg/marshal/primitive",
- "//pkg/usermem",
],
)
diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go
index b521144d9..378f1baf3 100644
--- a/pkg/abi/linux/netfilter.go
+++ b/pkg/abi/linux/netfilter.go
@@ -15,11 +15,8 @@
package linux
import (
- "io"
-
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/marshal/primitive"
- "gvisor.dev/gvisor/pkg/usermem"
)
// This file contains structures required to support netfilter, specifically
@@ -129,8 +126,8 @@ type IPTEntry struct {
const SizeOfIPTEntry = 112
// KernelIPTEntry is identical to IPTEntry, but includes the Elems field.
-// KernelIPTEntry itself is not Marshallable but it implements some methods of
-// marshal.Marshallable that help in other implementations of Marshallable.
+//
+// +marshal dynamic
type KernelIPTEntry struct {
Entry IPTEntry
@@ -158,6 +155,8 @@ func (ke *KernelIPTEntry) UnmarshalBytes(src []byte) {
ke.Elems.UnmarshalBytes(src[ke.Entry.SizeBytes():])
}
+var _ marshal.Marshallable = (*KernelIPTEntry)(nil)
+
// IPTIP contains information for matching a packet's IP header.
// It corresponds to struct ipt_ip in
// include/uapi/linux/netfilter_ipv4/ip_tables.h.
@@ -411,8 +410,9 @@ type IPTGetEntries struct {
const SizeOfIPTGetEntries = 40
// KernelIPTGetEntries is identical to IPTGetEntries, but includes the
-// Entrytable field. This has been manually made marshal.Marshallable since it
-// is dynamically sized.
+// Entrytable field.
+//
+// +marshal dynamic
type KernelIPTGetEntries struct {
IPTGetEntries
Entrytable []KernelIPTEntry
@@ -447,65 +447,6 @@ func (ke *KernelIPTGetEntries) UnmarshalBytes(src []byte) {
}
}
-// Packed implements marshal.Marshallable.Packed.
-func (ke *KernelIPTGetEntries) Packed() bool {
- // KernelIPTGetEntries isn't packed because the ke.Entrytable contains an
- // indirection to the actual data we want to marshal (the slice data
- // pointer), and the memory for KernelIPTGetEntries contains the slice
- // header which we don't want to marshal.
- return false
-}
-
-// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.
-func (ke *KernelIPTGetEntries) MarshalUnsafe(dst []byte) {
- // Fall back to safe Marshal because the type in not packed.
- ke.MarshalBytes(dst)
-}
-
-// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.
-func (ke *KernelIPTGetEntries) UnmarshalUnsafe(src []byte) {
- // Fall back to safe Unmarshal because the type in not packed.
- ke.UnmarshalBytes(src)
-}
-
-// CopyIn implements marshal.Marshallable.CopyIn.
-func (ke *KernelIPTGetEntries) CopyIn(cc marshal.CopyContext, addr usermem.Addr) (int, error) {
- buf := cc.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay.
- length, err := cc.CopyInBytes(addr, buf) // escapes: okay.
- // Unmarshal unconditionally. If we had a short copy-in, this results in a
- // partially unmarshalled struct.
- ke.UnmarshalBytes(buf) // escapes: fallback.
- return length, err
-}
-
-// CopyOut implements marshal.Marshallable.CopyOut.
-func (ke *KernelIPTGetEntries) CopyOut(cc marshal.CopyContext, addr usermem.Addr) (int, error) {
- // Type KernelIPTGetEntries doesn't have a packed layout in memory, fall
- // back to MarshalBytes.
- return cc.CopyOutBytes(addr, ke.marshalAll(cc))
-}
-
-// CopyOutN implements marshal.Marshallable.CopyOutN.
-func (ke *KernelIPTGetEntries) CopyOutN(cc marshal.CopyContext, addr usermem.Addr, limit int) (int, error) {
- // Type KernelIPTGetEntries doesn't have a packed layout in memory, fall
- // back to MarshalBytes.
- return cc.CopyOutBytes(addr, ke.marshalAll(cc)[:limit])
-}
-
-func (ke *KernelIPTGetEntries) marshalAll(cc marshal.CopyContext) []byte {
- buf := cc.CopyScratchBuffer(ke.SizeBytes())
- ke.MarshalBytes(buf)
- return buf
-}
-
-// WriteTo implements io.WriterTo.WriteTo.
-func (ke *KernelIPTGetEntries) WriteTo(w io.Writer) (int64, error) {
- buf := make([]byte, ke.SizeBytes())
- ke.MarshalBytes(buf)
- length, err := w.Write(buf)
- return int64(length), err
-}
-
var _ marshal.Marshallable = (*KernelIPTGetEntries)(nil)
// IPTReplace is the argument for the IPT_SO_SET_REPLACE sockopt. It
diff --git a/pkg/abi/linux/netfilter_ipv6.go b/pkg/abi/linux/netfilter_ipv6.go
index bcb57642e..b953e62dc 100644
--- a/pkg/abi/linux/netfilter_ipv6.go
+++ b/pkg/abi/linux/netfilter_ipv6.go
@@ -15,11 +15,8 @@
package linux
import (
- "io"
-
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/marshal/primitive"
- "gvisor.dev/gvisor/pkg/usermem"
)
// This file contains structures required to support IPv6 netfilter and
@@ -70,8 +67,9 @@ type IP6TReplace struct {
const SizeOfIP6TReplace = 96
// KernelIP6TGetEntries is identical to IP6TGetEntries, but includes the
-// Entrytable field. This has been manually made marshal.Marshallable since it
-// is dynamically sized.
+// Entrytable field.
+//
+// +marshal dynamic
type KernelIP6TGetEntries struct {
IPTGetEntries
Entrytable []KernelIP6TEntry
@@ -106,65 +104,6 @@ func (ke *KernelIP6TGetEntries) UnmarshalBytes(src []byte) {
}
}
-// Packed implements marshal.Marshallable.Packed.
-func (ke *KernelIP6TGetEntries) Packed() bool {
- // KernelIP6TGetEntries isn't packed because the ke.Entrytable contains
- // an indirection to the actual data we want to marshal (the slice data
- // pointer), and the memory for KernelIP6TGetEntries contains the slice
- // header which we don't want to marshal.
- return false
-}
-
-// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.
-func (ke *KernelIP6TGetEntries) MarshalUnsafe(dst []byte) {
- // Fall back to safe Marshal because the type in not packed.
- ke.MarshalBytes(dst)
-}
-
-// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.
-func (ke *KernelIP6TGetEntries) UnmarshalUnsafe(src []byte) {
- // Fall back to safe Unmarshal because the type in not packed.
- ke.UnmarshalBytes(src)
-}
-
-// CopyIn implements marshal.Marshallable.CopyIn.
-func (ke *KernelIP6TGetEntries) CopyIn(cc marshal.CopyContext, addr usermem.Addr) (int, error) {
- buf := cc.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay.
- length, err := cc.CopyInBytes(addr, buf) // escapes: okay.
- // Unmarshal unconditionally. If we had a short copy-in, this results
- // in a partially unmarshalled struct.
- ke.UnmarshalBytes(buf) // escapes: fallback.
- return length, err
-}
-
-// CopyOut implements marshal.Marshallable.CopyOut.
-func (ke *KernelIP6TGetEntries) CopyOut(cc marshal.CopyContext, addr usermem.Addr) (int, error) {
- // Type KernelIP6TGetEntries doesn't have a packed layout in memory,
- // fall back to MarshalBytes.
- return cc.CopyOutBytes(addr, ke.marshalAll(cc))
-}
-
-// CopyOutN implements marshal.Marshallable.CopyOutN.
-func (ke *KernelIP6TGetEntries) CopyOutN(cc marshal.CopyContext, addr usermem.Addr, limit int) (int, error) {
- // Type KernelIP6TGetEntries doesn't have a packed layout in memory, fall
- // back to MarshalBytes.
- return cc.CopyOutBytes(addr, ke.marshalAll(cc)[:limit])
-}
-
-func (ke *KernelIP6TGetEntries) marshalAll(cc marshal.CopyContext) []byte {
- buf := cc.CopyScratchBuffer(ke.SizeBytes())
- ke.MarshalBytes(buf)
- return buf
-}
-
-// WriteTo implements io.WriterTo.WriteTo.
-func (ke *KernelIP6TGetEntries) WriteTo(w io.Writer) (int64, error) {
- buf := make([]byte, ke.SizeBytes())
- ke.MarshalBytes(buf)
- length, err := w.Write(buf)
- return int64(length), err
-}
-
var _ marshal.Marshallable = (*KernelIP6TGetEntries)(nil)
// IP6TEntry is an iptables rule. It corresponds to struct ip6t_entry in
diff --git a/pkg/sentry/fs/host/file.go b/pkg/sentry/fs/host/file.go
index 86d1a87f0..fd4e057d8 100644
--- a/pkg/sentry/fs/host/file.go
+++ b/pkg/sentry/fs/host/file.go
@@ -180,16 +180,9 @@ func (f *fileOperations) Readdir(ctx context.Context, file *fs.File, serializer
// IterateDir implements fs.DirIterator.IterateDir.
func (f *fileOperations) IterateDir(ctx context.Context, d *fs.Dirent, dirCtx *fs.DirCtx, offset int) (int, error) {
- if f.dirinfo == nil {
- f.dirinfo = new(dirInfo)
- f.dirinfo.buf = make([]byte, usermem.PageSize)
- }
- entries, err := f.iops.readdirAll(f.dirinfo)
- if err != nil {
- return offset, err
- }
- count, err := fs.GenericReaddir(dirCtx, fs.NewSortedDentryMap(entries))
- return offset + count, err
+ // We only support non-directory file descriptors that have been
+ // imported, so just claim that this isn't a directory, even if it is.
+ return offset, syscall.ENOTDIR
}
// Write implements fs.FileOperations.Write.
diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go
index 2c14aa6d9..df4b265fa 100644
--- a/pkg/sentry/fs/host/inode.go
+++ b/pkg/sentry/fs/host/inode.go
@@ -411,10 +411,3 @@ func (i *inodeOperations) DropLink() {}
// NotifyStatusChange implements fs.InodeOperations.NotifyStatusChange.
func (i *inodeOperations) NotifyStatusChange(ctx context.Context) {}
-
-// readdirAll returns all of the directory entries in i.
-func (i *inodeOperations) readdirAll(d *dirInfo) (map[string]fs.DentAttr, error) {
- // We only support non-directory file descriptors that have been
- // imported, so just claim that this isn't a directory, even if it is.
- return nil, syscall.ENOTDIR
-}
diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go
index 07b4fb70f..2b58fc52c 100644
--- a/pkg/sentry/fs/host/socket.go
+++ b/pkg/sentry/fs/host/socket.go
@@ -16,6 +16,7 @@ package host
import (
"fmt"
+ "sync/atomic"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -206,7 +207,7 @@ func (c *ConnectedEndpoint) Send(ctx context.Context, data [][]byte, controlMess
// only as much of the message as fits in the send buffer.
truncate := c.stype == linux.SOCK_STREAM
- n, totalLen, err := fdWriteVec(c.file.FD(), data, c.sndbuf, truncate)
+ n, totalLen, err := fdWriteVec(c.file.FD(), data, c.SendMaxQueueSize(), truncate)
if n < totalLen && err == nil {
// The host only returns a short write if it would otherwise
// block (and only for stream sockets).
@@ -282,7 +283,7 @@ func (c *ConnectedEndpoint) Recv(ctx context.Context, data [][]byte, creds bool,
// N.B. Unix sockets don't have a receive buffer, the send buffer
// serves both purposes.
- rl, ml, cl, cTrunc, err := fdReadVec(c.file.FD(), data, []byte(cm), peek, c.sndbuf)
+ rl, ml, cl, cTrunc, err := fdReadVec(c.file.FD(), data, []byte(cm), peek, c.RecvMaxQueueSize())
if rl > 0 && err != nil {
// We got some data, so all we need to do on error is return
// the data that we got. Short reads are fine, no need to
@@ -363,14 +364,14 @@ func (c *ConnectedEndpoint) RecvQueuedSize() int64 {
// SendMaxQueueSize implements transport.Receiver.SendMaxQueueSize.
func (c *ConnectedEndpoint) SendMaxQueueSize() int64 {
- return int64(c.sndbuf)
+ return atomic.LoadInt64(&c.sndbuf)
}
// RecvMaxQueueSize implements transport.Receiver.RecvMaxQueueSize.
func (c *ConnectedEndpoint) RecvMaxQueueSize() int64 {
// N.B. Unix sockets don't use the receive buffer. We'll claim it is
// the same size as the send buffer.
- return int64(c.sndbuf)
+ return atomic.LoadInt64(&c.sndbuf)
}
// Release implements transport.ConnectedEndpoint.Release and transport.Receiver.Release.
@@ -381,4 +382,11 @@ func (c *ConnectedEndpoint) Release(ctx context.Context) {
// CloseUnread implements transport.ConnectedEndpoint.CloseUnread.
func (c *ConnectedEndpoint) CloseUnread() {}
+// SetSendBufferSize implements transport.ConnectedEndpoint.SetSendBufferSize.
+func (c *ConnectedEndpoint) SetSendBufferSize(v int64) (newSz int64) {
+ // gVisor does not permit setting of SO_SNDBUF for host backed unix domain
+ // sockets.
+ return atomic.LoadInt64(&c.sndbuf)
+}
+
// LINT.ThenChange(../../fsimpl/host/socket.go)
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index 98f7bc52f..094d993a8 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -1216,7 +1216,13 @@ func (d *dentry) checkXattrPermissions(creds *auth.Credentials, name string, ats
}
func (d *dentry) mayDelete(creds *auth.Credentials, child *dentry) error {
- return vfs.CheckDeleteSticky(creds, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&child.uid)))
+ return vfs.CheckDeleteSticky(
+ creds,
+ linux.FileMode(atomic.LoadUint32(&d.mode)),
+ auth.KUID(atomic.LoadUint32(&d.uid)),
+ auth.KUID(atomic.LoadUint32(&child.uid)),
+ auth.KGID(atomic.LoadUint32(&child.gid)),
+ )
}
func dentryUIDFromP9UID(uid p9.UID) uint32 {
diff --git a/pkg/sentry/fsimpl/host/socket.go b/pkg/sentry/fsimpl/host/socket.go
index 72aa535f8..6763f5b0c 100644
--- a/pkg/sentry/fsimpl/host/socket.go
+++ b/pkg/sentry/fsimpl/host/socket.go
@@ -16,6 +16,7 @@ package host
import (
"fmt"
+ "sync/atomic"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -111,7 +112,7 @@ func (c *ConnectedEndpoint) init() *syserr.Error {
}
c.stype = linux.SockType(stype)
- c.sndbuf = int64(sndbuf)
+ atomic.StoreInt64(&c.sndbuf, int64(sndbuf))
return nil
}
@@ -150,7 +151,7 @@ func (c *ConnectedEndpoint) Send(ctx context.Context, data [][]byte, controlMess
// only as much of the message as fits in the send buffer.
truncate := c.stype == linux.SOCK_STREAM
- n, totalLen, err := fdWriteVec(c.fd, data, c.sndbuf, truncate)
+ n, totalLen, err := fdWriteVec(c.fd, data, c.SendMaxQueueSize(), truncate)
if n < totalLen && err == nil {
// The host only returns a short write if it would otherwise
// block (and only for stream sockets).
@@ -226,7 +227,7 @@ func (c *ConnectedEndpoint) Recv(ctx context.Context, data [][]byte, creds bool,
// N.B. Unix sockets don't have a receive buffer, the send buffer
// serves both purposes.
- rl, ml, cl, cTrunc, err := fdReadVec(c.fd, data, []byte(cm), peek, c.sndbuf)
+ rl, ml, cl, cTrunc, err := fdReadVec(c.fd, data, []byte(cm), peek, c.RecvMaxQueueSize())
if rl > 0 && err != nil {
// We got some data, so all we need to do on error is return
// the data that we got. Short reads are fine, no need to
@@ -300,14 +301,14 @@ func (c *ConnectedEndpoint) RecvQueuedSize() int64 {
// SendMaxQueueSize implements transport.Receiver.SendMaxQueueSize.
func (c *ConnectedEndpoint) SendMaxQueueSize() int64 {
- return int64(c.sndbuf)
+ return atomic.LoadInt64(&c.sndbuf)
}
// RecvMaxQueueSize implements transport.Receiver.RecvMaxQueueSize.
func (c *ConnectedEndpoint) RecvMaxQueueSize() int64 {
// N.B. Unix sockets don't use the receive buffer. We'll claim it is
// the same size as the send buffer.
- return int64(c.sndbuf)
+ return atomic.LoadInt64(&c.sndbuf)
}
func (c *ConnectedEndpoint) destroyLocked() {
@@ -327,6 +328,13 @@ func (c *ConnectedEndpoint) Release(ctx context.Context) {
// CloseUnread implements transport.ConnectedEndpoint.CloseUnread.
func (c *ConnectedEndpoint) CloseUnread() {}
+// SetSendBufferSize implements transport.ConnectedEndpoint.SetSendBufferSize.
+func (c *ConnectedEndpoint) SetSendBufferSize(v int64) (newSz int64) {
+ // gVisor does not permit setting of SO_SNDBUF for host backed unix domain
+ // sockets.
+ return atomic.LoadInt64(&c.sndbuf)
+}
+
// SCMConnectedEndpoint represents an endpoint backed by a host fd that was
// passed through a gofer Unix socket. It resembles ConnectedEndpoint, with the
// following differences:
diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go
index a7a553619..d6dd6bc41 100644
--- a/pkg/sentry/fsimpl/kernfs/filesystem.go
+++ b/pkg/sentry/fsimpl/kernfs/filesystem.go
@@ -668,6 +668,12 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
// Can we create the dst dentry?
var dst *Dentry
pc := rp.Component()
+ if pc == "." || pc == ".." {
+ if noReplace {
+ return syserror.EEXIST
+ }
+ return syserror.EBUSY
+ }
switch err := checkCreateLocked(ctx, rp.Credentials(), pc, dstDir); err {
case nil:
// Ok, continue with rename as replacement.
diff --git a/pkg/sentry/fsimpl/kernfs/synthetic_directory.go b/pkg/sentry/fsimpl/kernfs/synthetic_directory.go
index 463d77d79..11694c392 100644
--- a/pkg/sentry/fsimpl/kernfs/synthetic_directory.go
+++ b/pkg/sentry/fsimpl/kernfs/synthetic_directory.go
@@ -42,19 +42,16 @@ type syntheticDirectory struct {
var _ Inode = (*syntheticDirectory)(nil)
func newSyntheticDirectory(ctx context.Context, creds *auth.Credentials, perm linux.FileMode) Inode {
- inode := &syntheticDirectory{}
- inode.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, 0 /* ino */, perm)
- return inode
-}
-
-func (dir *syntheticDirectory) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) {
if perm&^linux.PermissionsMask != 0 {
panic(fmt.Sprintf("perm contains non-permission bits: %#o", perm))
}
- dir.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.S_IFDIR|perm)
+ dir := &syntheticDirectory{}
+ dir.InitRefs()
+ dir.InodeAttrs.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, 0 /* ino */, linux.S_IFDIR|perm)
dir.OrderedChildren.Init(OrderedChildrenOptions{
Writable: true,
})
+ return dir
}
// Open implements Inode.Open.
diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go
index e46f593c7..b36031291 100644
--- a/pkg/sentry/fsimpl/overlay/filesystem.go
+++ b/pkg/sentry/fsimpl/overlay/filesystem.go
@@ -1068,7 +1068,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
if err != nil {
return err
}
- if err := vfs.CheckDeleteSticky(creds, linux.FileMode(atomic.LoadUint32(&oldParent.mode)), auth.KUID(atomic.LoadUint32(&renamed.uid))); err != nil {
+ if err := oldParent.mayDelete(creds, renamed); err != nil {
return err
}
if renamed.isDir() {
@@ -1317,7 +1317,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
if !child.isDir() {
return syserror.ENOTDIR
}
- if err := vfs.CheckDeleteSticky(rp.Credentials(), linux.FileMode(atomic.LoadUint32(&parent.mode)), auth.KUID(atomic.LoadUint32(&child.uid))); err != nil {
+ if err := parent.mayDelete(rp.Credentials(), child); err != nil {
return err
}
child.dirMu.Lock()
@@ -1584,7 +1584,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
if child.isDir() {
return syserror.EISDIR
}
- if err := vfs.CheckDeleteSticky(rp.Credentials(), linux.FileMode(parentMode), auth.KUID(atomic.LoadUint32(&child.uid))); err != nil {
+ if err := parent.mayDelete(rp.Credentials(), child); err != nil {
return err
}
if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil {
diff --git a/pkg/sentry/fsimpl/overlay/overlay.go b/pkg/sentry/fsimpl/overlay/overlay.go
index 082fa6504..acd3684c6 100644
--- a/pkg/sentry/fsimpl/overlay/overlay.go
+++ b/pkg/sentry/fsimpl/overlay/overlay.go
@@ -760,6 +760,16 @@ func (d *dentry) updateAfterSetStatLocked(opts *vfs.SetStatOptions) {
}
}
+func (d *dentry) mayDelete(creds *auth.Credentials, child *dentry) error {
+ return vfs.CheckDeleteSticky(
+ creds,
+ linux.FileMode(atomic.LoadUint32(&d.mode)),
+ auth.KUID(atomic.LoadUint32(&d.uid)),
+ auth.KUID(atomic.LoadUint32(&child.uid)),
+ auth.KGID(atomic.LoadUint32(&child.gid)),
+ )
+}
+
// fileDescription is embedded by overlay implementations of
// vfs.FileDescriptionImpl.
//
diff --git a/pkg/sentry/fsimpl/tmpfs/directory.go b/pkg/sentry/fsimpl/tmpfs/directory.go
index e90669cf0..417ac2eff 100644
--- a/pkg/sentry/fsimpl/tmpfs/directory.go
+++ b/pkg/sentry/fsimpl/tmpfs/directory.go
@@ -84,7 +84,13 @@ func (dir *directory) removeChildLocked(child *dentry) {
}
func (dir *directory) mayDelete(creds *auth.Credentials, child *dentry) error {
- return vfs.CheckDeleteSticky(creds, linux.FileMode(atomic.LoadUint32(&dir.inode.mode)), auth.KUID(atomic.LoadUint32(&child.inode.uid)))
+ return vfs.CheckDeleteSticky(
+ creds,
+ linux.FileMode(atomic.LoadUint32(&dir.inode.mode)),
+ auth.KUID(atomic.LoadUint32(&dir.inode.uid)),
+ auth.KUID(atomic.LoadUint32(&child.inode.uid)),
+ auth.KGID(atomic.LoadUint32(&child.inode.gid)),
+ )
}
// +stateify savable
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go
index 6255a7c84..82a743ff3 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go
@@ -656,6 +656,9 @@ func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64,
// Write to that memory as usual.
seg, gap = rw.file.data.Insert(gap, gapMR, fr.Start), fsutil.FileRangeGapIterator{}
+
+ default:
+ panic("unreachable")
}
}
exitLoop:
diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go
index c551acd99..2c8668fc4 100644
--- a/pkg/sentry/kernel/pipe/pipe.go
+++ b/pkg/sentry/kernel/pipe/pipe.go
@@ -247,11 +247,15 @@ func (p *Pipe) writeLocked(count int64, f func(safemem.BlockSeq) (uint64, error)
return 0, syscall.EPIPE
}
- // POSIX requires that a write smaller than atomicIOBytes (PIPE_BUF) be
- // atomic, but requires no atomicity for writes larger than this.
avail := p.max - p.size
+ if avail == 0 {
+ return 0, syserror.ErrWouldBlock
+ }
short := false
if count > avail {
+ // POSIX requires that a write smaller than atomicIOBytes
+ // (PIPE_BUF) be atomic, but requires no atomicity for writes
+ // larger than this.
if count <= atomicIOBytes {
return 0, syserror.ErrWouldBlock
}
diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go
index 16986244c..f7765fa3a 100644
--- a/pkg/sentry/kernel/task_exit.go
+++ b/pkg/sentry/kernel/task_exit.go
@@ -415,6 +415,12 @@ func (tg *ThreadGroup) anyNonExitingTaskLocked() *Task {
func (t *Task) reparentLocked(parent *Task) {
oldParent := t.parent
t.parent = parent
+ if oldParent != nil {
+ delete(oldParent.children, t)
+ }
+ if parent != nil {
+ parent.children[t] = struct{}{}
+ }
// If a thread group leader's parent changes, reset the thread group's
// termination signal to SIGCHLD and re-check exit notification. (Compare
// kernel/exit.c:reparent_leader().)
diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go
index 98af2cc38..cd9fa4031 100644
--- a/pkg/sentry/loader/elf.go
+++ b/pkg/sentry/loader/elf.go
@@ -517,12 +517,14 @@ func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, in
start, ok = start.AddLength(uint64(offset))
if !ok {
- panic(fmt.Sprintf("Start %#x + offset %#x overflows?", start, offset))
+ ctx.Infof(fmt.Sprintf("Start %#x + offset %#x overflows?", start, offset))
+ return loadedELF{}, syserror.EINVAL
}
end, ok = end.AddLength(uint64(offset))
if !ok {
- panic(fmt.Sprintf("End %#x + offset %#x overflows?", end, offset))
+ ctx.Infof(fmt.Sprintf("End %#x + offset %#x overflows?", end, offset))
+ return loadedELF{}, syserror.EINVAL
}
info.entry, ok = info.entry.AddLength(uint64(offset))
diff --git a/pkg/sentry/platform/ptrace/filters.go b/pkg/sentry/platform/ptrace/filters.go
index b0970e356..20fc62acb 100644
--- a/pkg/sentry/platform/ptrace/filters.go
+++ b/pkg/sentry/platform/ptrace/filters.go
@@ -17,14 +17,12 @@ package ptrace
import (
"syscall"
- "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/seccomp"
)
// SyscallFilters returns syscalls made exclusively by the ptrace platform.
func (*PTrace) SyscallFilters() seccomp.SyscallRules {
return seccomp.SyscallRules{
- unix.SYS_GETCPU: {},
syscall.SYS_PTRACE: {},
syscall.SYS_TGKILL: {},
syscall.SYS_WAIT4: {},
diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go
index f82c7c224..dc03ccb47 100644
--- a/pkg/sentry/socket/hostinet/socket_vfs2.go
+++ b/pkg/sentry/socket/hostinet/socket_vfs2.go
@@ -80,8 +80,7 @@ func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol in
// Release implements vfs.FileDescriptionImpl.Release.
func (s *socketVFS2) Release(ctx context.Context) {
- t := kernel.TaskFromContext(ctx)
- t.Kernel().DeleteSocketVFS2(&s.vfsfd)
+ kernel.KernelFromContext(ctx).DeleteSocketVFS2(&s.vfsfd)
s.socketOpsCommon.Release(ctx)
}
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 69693f263..cee8120ab 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -855,10 +855,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- size, err := ep.SocketOptions().GetSendBufferSize()
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
+ size := ep.SocketOptions().GetSendBufferSize()
if size > math.MaxInt32 {
size = math.MaxInt32
@@ -1647,13 +1644,6 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
return syserr.ErrInvalidArgument
}
- family, _, _ := s.Type()
- // TODO(gvisor.dev/issue/5132): We currently do not support
- // setting this option for unix sockets.
- if family == linux.AF_UNIX {
- return nil
- }
-
v := usermem.ByteOrder.Uint32(optVal)
ep.SocketOptions().SetSendBufferSize(int64(v), true)
return nil
diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go
index 24922c400..fc29f8f13 100644
--- a/pkg/sentry/socket/netstack/netstack_vfs2.go
+++ b/pkg/sentry/socket/netstack/netstack_vfs2.go
@@ -79,8 +79,7 @@ func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, qu
// Release implements vfs.FileDescriptionImpl.Release.
func (s *SocketVFS2) Release(ctx context.Context) {
- t := kernel.TaskFromContext(ctx)
- t.Kernel().DeleteSocketVFS2(&s.vfsfd)
+ kernel.KernelFromContext(ctx).DeleteSocketVFS2(&s.vfsfd)
s.socketOpsCommon.Release(ctx)
}
diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD
index cce0acc33..acf2ab8e7 100644
--- a/pkg/sentry/socket/unix/BUILD
+++ b/pkg/sentry/socket/unix/BUILD
@@ -51,6 +51,7 @@ go_library(
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/fs/lock",
"//pkg/sentry/fsimpl/sockfs",
+ "//pkg/sentry/inet",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/time",
"//pkg/sentry/socket",
diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD
index 3ebbd28b0..0d11bb251 100644
--- a/pkg/sentry/socket/unix/transport/BUILD
+++ b/pkg/sentry/socket/unix/transport/BUILD
@@ -32,6 +32,7 @@ go_library(
"connectioned.go",
"connectioned_state.go",
"connectionless.go",
+ "connectionless_state.go",
"queue.go",
"queue_refs.go",
"transport_message_list.go",
@@ -45,6 +46,7 @@ go_library(
"//pkg/log",
"//pkg/refs",
"//pkg/refsvfs2",
+ "//pkg/sentry/inet",
"//pkg/sync",
"//pkg/syserr",
"//pkg/tcpip",
diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go
index fc5b823b0..809c95429 100644
--- a/pkg/sentry/socket/unix/transport/connectioned.go
+++ b/pkg/sentry/socket/unix/transport/connectioned.go
@@ -128,7 +128,9 @@ func newConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProv
idGenerator: uid,
stype: stype,
}
- ep.ops.InitHandler(ep, nil, nil)
+
+ ep.ops.SetSendBufferSize(defaultBufferSize, false /* notify */)
+ ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits)
return ep
}
@@ -137,9 +139,9 @@ func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (E
a := newConnectioned(ctx, stype, uid)
b := newConnectioned(ctx, stype, uid)
- q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: initialLimit}
+ q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: defaultBufferSize}
q1.InitRefs()
- q2 := &queue{ReaderQueue: b.Queue, WriterQueue: a.Queue, limit: initialLimit}
+ q2 := &queue{ReaderQueue: b.Queue, WriterQueue: a.Queue, limit: defaultBufferSize}
q2.InitRefs()
if stype == linux.SOCK_STREAM {
@@ -173,7 +175,8 @@ func NewExternal(ctx context.Context, stype linux.SockType, uid UniqueIDProvider
idGenerator: uid,
stype: stype,
}
- ep.ops.InitHandler(ep, nil, nil)
+ ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits)
+ ep.ops.SetSendBufferSize(connected.SendMaxQueueSize(), false /* notify */)
return ep
}
@@ -296,16 +299,18 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn
idGenerator: e.idGenerator,
stype: e.stype,
}
- ne.ops.InitHandler(ne, nil, nil)
+ ne.ops.InitHandler(ne, &stackHandler{}, getSendBufferLimits)
+ ne.ops.SetSendBufferSize(defaultBufferSize, false /* notify */)
- readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: initialLimit}
+ readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: defaultBufferSize}
readQueue.InitRefs()
ne.connected = &connectedEndpoint{
endpoint: ce,
writeQueue: readQueue,
}
- writeQueue := &queue{ReaderQueue: ne.Queue, WriterQueue: ce.WaiterQueue(), limit: initialLimit}
+ // Make sure the accepted endpoint inherits this listening socket's SO_SNDBUF.
+ writeQueue := &queue{ReaderQueue: ne.Queue, WriterQueue: ce.WaiterQueue(), limit: e.ops.GetSendBufferSize()}
writeQueue.InitRefs()
if e.stype == linux.SOCK_STREAM {
ne.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{readQueue: writeQueue}}
@@ -357,6 +362,9 @@ func (e *connectionedEndpoint) Connect(ctx context.Context, server BoundEndpoint
returnConnect := func(r Receiver, ce ConnectedEndpoint) {
e.receiver = r
e.connected = ce
+ // Make sure the newly created connected endpoint's write queue is updated
+ // to reflect this endpoint's send buffer size.
+ e.connected.SetSendBufferSize(e.ops.GetSendBufferSize())
}
return server.BidirectionalConnect(ctx, e, returnConnect)
@@ -495,3 +503,11 @@ func (e *connectionedEndpoint) State() uint32 {
}
return linux.SS_UNCONNECTED
}
+
+// OnSetSendBufferSize implements tcpip.SocketOptionsHandler.OnSetSendBufferSize.
+func (e *connectionedEndpoint) OnSetSendBufferSize(v int64) (newSz int64) {
+ if e.Connected() {
+ return e.baseEndpoint.connected.SetSendBufferSize(v)
+ }
+ return v
+}
diff --git a/pkg/sentry/socket/unix/transport/connectioned_state.go b/pkg/sentry/socket/unix/transport/connectioned_state.go
index 7e02a5db8..590b0bd01 100644
--- a/pkg/sentry/socket/unix/transport/connectioned_state.go
+++ b/pkg/sentry/socket/unix/transport/connectioned_state.go
@@ -51,3 +51,8 @@ func (e *connectionedEndpoint) loadAcceptedChan(acceptedSlice []*connectionedEnd
}
}
}
+
+// afterLoad is invoked by stateify.
+func (e *connectionedEndpoint) afterLoad() {
+ e.ops.InitHandler(e, &stackHandler{}, getSendBufferLimits)
+}
diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go
index 20fa8b874..0be78480c 100644
--- a/pkg/sentry/socket/unix/transport/connectionless.go
+++ b/pkg/sentry/socket/unix/transport/connectionless.go
@@ -41,10 +41,11 @@ var (
// NewConnectionless creates a new unbound dgram endpoint.
func NewConnectionless(ctx context.Context) Endpoint {
ep := &connectionlessEndpoint{baseEndpoint{Queue: &waiter.Queue{}}}
- q := queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: initialLimit}
+ q := queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: defaultBufferSize}
q.InitRefs()
ep.receiver = &queueReceiver{readQueue: &q}
- ep.ops.InitHandler(ep, nil, nil)
+ ep.ops.SetSendBufferSize(defaultBufferSize, false /* notify */)
+ ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits)
return ep
}
@@ -217,3 +218,11 @@ func (e *connectionlessEndpoint) State() uint32 {
return linux.SS_DISCONNECTING
}
}
+
+// OnSetSendBufferSize implements tcpip.SocketOptionsHandler.OnSetSendBufferSize.
+func (e *connectionlessEndpoint) OnSetSendBufferSize(v int64) (newSz int64) {
+ if e.Connected() {
+ return e.baseEndpoint.connected.SetSendBufferSize(v)
+ }
+ return v
+}
diff --git a/pkg/sentry/socket/unix/transport/connectionless_state.go b/pkg/sentry/socket/unix/transport/connectionless_state.go
new file mode 100644
index 000000000..2ef337ec8
--- /dev/null
+++ b/pkg/sentry/socket/unix/transport/connectionless_state.go
@@ -0,0 +1,20 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package transport
+
+// afterLoad is invoked by stateify.
+func (e *connectionlessEndpoint) afterLoad() {
+ e.ops.InitHandler(e, &stackHandler{}, getSendBufferLimits)
+}
diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go
index 342def28f..698a9a82c 100644
--- a/pkg/sentry/socket/unix/transport/queue.go
+++ b/pkg/sentry/socket/unix/transport/queue.go
@@ -237,9 +237,18 @@ func (q *queue) QueuedSize() int64 {
// MaxQueueSize returns the maximum number of bytes storable in the queue.
func (q *queue) MaxQueueSize() int64 {
+ q.mu.Lock()
+ defer q.mu.Unlock()
return q.limit
}
+// SetMaxQueueSize sets the maximum number of bytes storable in the queue.
+func (q *queue) SetMaxQueueSize(v int64) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ q.limit = v
+}
+
// CloseUnread sets flag to indicate that the peer is closed (not shutdown)
// with unread data. So if read on this queue shall return ECONNRESET error.
func (q *queue) CloseUnread() {
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index 70227bbd2..ceada54a8 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -26,8 +26,16 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
-// initialLimit is the starting limit for the socket buffers.
-const initialLimit = 16 * 1024
+const (
+ // The minimum size of the send/receive buffers.
+ minimumBufferSize = 4 << 10 // 4 KiB (match default in linux)
+
+ // The default size of the send/receive buffers.
+ defaultBufferSize = 208 << 10 // 208 KiB (default in linux for net.core.wmem_default)
+
+ // The maximum permitted size for the send/receive buffers.
+ maxBufferSize = 4 << 20 // 4 MiB (default in linux for net.core.wmem_max)
+)
// A RightsControlMessage is a control message containing FDs.
//
@@ -627,6 +635,10 @@ type ConnectedEndpoint interface {
// CloseUnread sets the fact that this end is closed with unread data to
// the peer socket.
CloseUnread()
+
+ // SetSendBufferSize is called when the endpoint's send buffer size is
+ // changed.
+ SetSendBufferSize(v int64) (newSz int64)
}
// +stateify savable
@@ -722,6 +734,14 @@ func (e *connectedEndpoint) CloseUnread() {
e.writeQueue.CloseUnread()
}
+// SetSendBufferSize implements ConnectedEndpoint.SetSendBufferSize.
+// SetSendBufferSize sets the send buffer size for the write queue to the
+// specified value.
+func (e *connectedEndpoint) SetSendBufferSize(v int64) (newSz int64) {
+ e.writeQueue.SetMaxQueueSize(v)
+ return v
+}
+
// baseEndpoint is an embeddable unix endpoint base used in both the connected and connectionless
// unix domain socket Endpoint implementations.
//
@@ -849,27 +869,6 @@ func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
return nil
}
-// IsUnixSocket implements tcpip.SocketOptionsHandler.IsUnixSocket.
-func (e *baseEndpoint) IsUnixSocket() bool {
- return true
-}
-
-// GetSendBufferSize implements tcpip.SocketOptionsHandler.GetSendBufferSize.
-func (e *baseEndpoint) GetSendBufferSize() (int64, tcpip.Error) {
- e.Lock()
- defer e.Unlock()
-
- if !e.Connected() {
- return -1, &tcpip.ErrNotConnected{}
- }
-
- v := e.connected.SendMaxQueueSize()
- if v < 0 {
- return -1, &tcpip.ErrQueueSizeNotSupported{}
- }
- return v, nil
-}
-
func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
@@ -987,3 +986,35 @@ func (e *baseEndpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
func (*baseEndpoint) Release(context.Context) {
// Binding a baseEndpoint doesn't take a reference.
}
+
+// stackHandler is just a stub implementation of tcpip.StackHandler to provide
+// when initializing socketoptions.
+type stackHandler struct {
+}
+
+// Option implements tcpip.StackHandler.
+func (h *stackHandler) Option(option interface{}) tcpip.Error {
+ panic("unimplemented")
+}
+
+// TransportProtocolOption implements tcpip.StackHandler.
+func (h *stackHandler) TransportProtocolOption(proto tcpip.TransportProtocolNumber, option tcpip.GettableTransportProtocolOption) tcpip.Error {
+ panic("unimplemented")
+}
+
+// getSendBufferLimits implements tcpip.GetSendBufferLimits.
+//
+// AF_UNIX sockets buffer sizes are not tied to the networking stack/namespace
+// in linux but are bound by net.core.(wmem|rmem)_(max|default).
+//
+// In gVisor net.core sysctls today are not exposed or if exposed are currently
+// tied to the networking stack in use. This makes it complicated for AF_UNIX
+// when we are in a new namespace w/ no networking stack. As a result for now we
+// define default/max values here in the unix socket implementation itself.
+func getSendBufferLimits(tcpip.StackHandler) tcpip.SendBufferSizeOption {
+ return tcpip.SendBufferSizeOption{
+ Min: minimumBufferSize,
+ Default: defaultBufferSize,
+ Max: maxBufferSize,
+ }
+}
diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go
index a7d4d7f1f..9c037cbae 100644
--- a/pkg/sentry/socket/unix/unix_vfs2.go
+++ b/pkg/sentry/socket/unix/unix_vfs2.go
@@ -95,8 +95,7 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3
// DecRef implements RefCounter.DecRef.
func (s *SocketVFS2) DecRef(ctx context.Context) {
s.socketVFS2Refs.DecRef(func() {
- t := kernel.TaskFromContext(ctx)
- t.Kernel().DeleteSocketVFS2(&s.vfsfd)
+ kernel.KernelFromContext(ctx).DeleteSocketVFS2(&s.vfsfd)
s.ep.Close(ctx)
if s.abstractNamespace != nil {
s.abstractNamespace.Remove(s.abstractName, s)
diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD
index a2e441448..4188502dc 100644
--- a/pkg/sentry/syscalls/linux/BUILD
+++ b/pkg/sentry/syscalls/linux/BUILD
@@ -62,7 +62,6 @@ go_library(
deps = [
"//pkg/abi",
"//pkg/abi/linux",
- "//pkg/binary",
"//pkg/bpf",
"//pkg/context",
"//pkg/log",
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index fe45225c1..686392cc8 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -18,7 +18,6 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -457,7 +456,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return 0, nil, e.ToError()
}
- vLen := int32(binary.Size(v))
+ vLen := int32(v.SizeBytes())
if _, err := primitive.CopyInt32Out(t, optLenAddr, vLen); err != nil {
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD
index 9ee766552..2e59bd5b1 100644
--- a/pkg/sentry/syscalls/linux/vfs2/BUILD
+++ b/pkg/sentry/syscalls/linux/vfs2/BUILD
@@ -39,7 +39,6 @@ go_library(
visibility = ["//:sandbox"],
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
"//pkg/bits",
"//pkg/context",
"//pkg/fspath",
diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go
index f5795b4a8..7636ca453 100644
--- a/pkg/sentry/syscalls/linux/vfs2/socket.go
+++ b/pkg/sentry/syscalls/linux/vfs2/socket.go
@@ -18,7 +18,6 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -460,7 +459,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return 0, nil, e.ToError()
}
- vLen := int32(binary.Size(v))
+ vLen := int32(v.SizeBytes())
if _, err := primitive.CopyInt32Out(t, optLenAddr, vLen); err != nil {
return 0, nil, err
}
diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go
index eb7d2fd3b..d2050b3f7 100644
--- a/pkg/sentry/vfs/file_description_impl_util.go
+++ b/pkg/sentry/vfs/file_description_impl_util.go
@@ -238,6 +238,8 @@ func (s *StaticData) Generate(ctx context.Context, buf *bytes.Buffer) error {
// WritableDynamicBytesSource extends DynamicBytesSource to allow writes to the
// underlying source.
+//
+// TODO(b/179825241): Make utility for integer-based writable files.
type WritableDynamicBytesSource interface {
DynamicBytesSource
diff --git a/pkg/sentry/vfs/permissions.go b/pkg/sentry/vfs/permissions.go
index d48520d58..db6146fd2 100644
--- a/pkg/sentry/vfs/permissions.go
+++ b/pkg/sentry/vfs/permissions.go
@@ -243,11 +243,13 @@ func CheckSetStat(ctx context.Context, creds *auth.Credentials, opts *SetStatOpt
// the given file mode, and if so, checks whether creds has permission to
// remove a file owned by childKUID from a directory with the given mode.
// CheckDeleteSticky is consistent with fs/linux.h:check_sticky().
-func CheckDeleteSticky(creds *auth.Credentials, parentMode linux.FileMode, childKUID auth.KUID) error {
+func CheckDeleteSticky(creds *auth.Credentials, parentMode linux.FileMode, parentKUID auth.KUID, childKUID auth.KUID, childKGID auth.KGID) error {
if parentMode&linux.ModeSticky == 0 {
return nil
}
- if CanActAsOwner(creds, childKUID) {
+ if creds.EffectiveKUID == childKUID ||
+ creds.EffectiveKUID == parentKUID ||
+ HasCapabilityOnFile(creds, linux.CAP_FOWNER, childKUID, childKGID) {
return nil
}
return syserror.EPERM
diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go
index bbe84f220..21fb87757 100644
--- a/pkg/tcpip/link/pipe/pipe.go
+++ b/pkg/tcpip/link/pipe/pipe.go
@@ -46,6 +46,10 @@ type Endpoint struct {
}
func (e *Endpoint) deliverPackets(r stack.RouteInfo, proto tcpip.NetworkProtocolNumber, pkts stack.PacketBufferList) {
+ if !e.linked.IsAttached() {
+ return
+ }
+
// Note that the local address from the perspective of this endpoint is the
// remote address from the perspective of the other end of the pipe
// (e.linked). Similarly, the remote address from the perspective of this
@@ -54,44 +58,26 @@ func (e *Endpoint) deliverPackets(r stack.RouteInfo, proto tcpip.NetworkProtocol
// Deliver the packet in a new goroutine to escape this goroutine's stack and
// avoid a deadlock when a packet triggers a response which leads the stack to
// try and take a lock it already holds.
- //
- // As of writing, a deadlock may occur when performing link resolution as the
- // neighbor table will send a solicitation while holding a lock and the
- // response advertisement will be sent in the same stack that sent the
- // solictation. When the response is received, the stack attempts to take the
- // same lock it already took before sending the solicitation, leading to a
- // deadlock. Basically, we attempt to lock the same lock twice in the same
- // call stack.
- //
- // TODO(gvisor.dev/issue/5289): don't use a new goroutine once we support send
- // and receive queues.
- go func() {
- for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
- }))
- }
- }()
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
+ }))
+ }
}
// WritePacket implements stack.LinkEndpoint.
func (e *Endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- if e.linked.IsAttached() {
- var pkts stack.PacketBufferList
- pkts.PushBack(pkt)
- e.deliverPackets(r, proto, pkts)
- }
-
+ var pkts stack.PacketBufferList
+ pkts.PushBack(pkt)
+ e.deliverPackets(r, proto, pkts)
return nil
}
// WritePackets implements stack.LinkEndpoint.
func (e *Endpoint) WritePackets(r stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
- if e.linked.IsAttached() {
- e.deliverPackets(r, proto, pkts)
- }
-
- return pkts.Len(), nil
+ n := pkts.Len()
+ e.deliverPackets(r, proto, pkts)
+ return n, nil
}
// Attach implements stack.LinkEndpoint.
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index 0caa65251..fa8814bac 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -16,7 +16,6 @@ go_test(
"//pkg/tcpip/checker",
"//pkg/tcpip/faketime",
"//pkg/tcpip/header",
- "//pkg/tcpip/header/parse",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
"//pkg/tcpip/network/ipv4",
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
index 933845269..d59d678b2 100644
--- a/pkg/tcpip/network/arp/BUILD
+++ b/pkg/tcpip/network/arp/BUILD
@@ -10,10 +10,12 @@ go_library(
],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/header/parse",
+ "//pkg/tcpip/network/internal/ip",
"//pkg/tcpip/stack",
],
)
@@ -44,7 +46,7 @@ go_test(
library = ":arp",
deps = [
"//pkg/tcpip",
- "//pkg/tcpip/network/testutil",
+ "//pkg/tcpip/network/internal/testutil",
"//pkg/tcpip/stack",
],
)
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 0d7fadc31..3fcdea119 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -22,10 +22,12 @@ import (
"reflect"
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/header/parse"
+ "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -34,6 +36,7 @@ const (
ProtocolNumber = header.ARPProtocolNumber
)
+var _ stack.DuplicateAddressDetector = (*endpoint)(nil)
var _ stack.LinkAddressResolver = (*endpoint)(nil)
// ARP endpoints need to implement stack.NetworkEndpoint because the stack
@@ -52,6 +55,35 @@ type endpoint struct {
nic stack.NetworkInterface
stats sharedStats
+
+ mu struct {
+ sync.Mutex
+
+ dad ip.DAD
+ }
+}
+
+// CheckDuplicateAddress implements stack.DuplicateAddressDetector.
+//
+// It holds e.mu for the duration of the call and forwards addr and h to the
+// endpoint's DAD state, returning whatever disposition
+// CheckDuplicateAddressLocked reports.
+func (e *endpoint) CheckDuplicateAddress(addr tcpip.Address, h stack.DADCompletionHandler) stack.DADCheckAddressDisposition {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.mu.dad.CheckDuplicateAddressLocked(addr, h)
+}
+
+// SetDADConfigurations implements stack.DuplicateAddressDetector.
+//
+// It holds e.mu while handing c to the endpoint's DAD state.
+func (e *endpoint) SetDADConfigurations(c stack.DADConfigurations) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.mu.dad.SetConfigsLocked(c)
+}
+
+// DuplicateAddressProtocol implements stack.DuplicateAddressDetector.
+//
+// It always reports header.IPv4ProtocolNumber.
+// NOTE(review): presumably because ARP performs DAD on behalf of IPv4
+// addresses - confirm against the stack's use of this method.
+func (*endpoint) DuplicateAddressProtocol() tcpip.NetworkProtocolNumber {
+ return header.IPv4ProtocolNumber
+}
+
+// SendDADMessage implements ip.DADProtocol.
+//
+// The probe is an ARP request for addr whose sender protocol address is
+// header.IPv4Any and whose link-layer destination is
+// header.EthernetBroadcastAddress.
+func (e *endpoint) SendDADMessage(addr tcpip.Address) tcpip.Error {
+ return e.sendARPRequest(header.IPv4Any, addr, header.EthernetBroadcastAddress)
}
func (e *endpoint) Enable() tcpip.Error {
@@ -129,6 +161,11 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
return
}
+ if _, _, ok := e.protocol.Parse(pkt); !ok {
+ stats.malformedPacketsReceived.Increment()
+ return
+ }
+
h := header.ARP(pkt.NetworkHeader().View())
if !h.IsValid() {
stats.malformedPacketsReceived.Increment()
@@ -140,7 +177,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
stats.requestsReceived.Increment()
localAddr := tcpip.Address(h.ProtocolAddressTarget())
- if e.protocol.stack.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 {
+ if !e.nic.CheckLocalAddress(header.IPv4ProtocolNumber, localAddr) {
stats.requestsReceivedUnknownTargetAddress.Increment()
return // we have no useful answer, ignore the request
}
@@ -194,6 +231,10 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
addr := tcpip.Address(h.ProtocolAddressSender())
linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
+ e.mu.Lock()
+ e.mu.dad.StopLocked(addr, false /* aborted */)
+ e.mu.Unlock()
+
// The solicited, override, and isRouter flags are not available for ARP;
// they are only available for IPv6 Neighbor Advertisements.
switch err := e.nic.HandleNeighborConfirmation(header.IPv4ProtocolNumber, addr, linkAddr, stack.ReachabilityConfirmationFlags{
@@ -222,9 +263,9 @@ func (e *endpoint) Stats() stack.NetworkEndpointStats {
var _ stack.NetworkProtocol = (*protocol)(nil)
-// protocol implements stack.NetworkProtocol and stack.LinkAddressResolver.
type protocol struct {
- stack *stack.Stack
+ stack *stack.Stack
+ options Options
}
func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber }
@@ -241,6 +282,14 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.Tran
nic: nic,
}
+ e.mu.Lock()
+ e.mu.dad.Init(&e.mu, p.options.DADConfigs, ip.DADOptions{
+ Clock: p.stack.Clock(),
+ Protocol: e,
+ NICID: nic.ID(),
+ })
+ e.mu.Unlock()
+
tcpip.InitStatCounters(reflect.ValueOf(&e.stats.localStats).Elem())
stackStats := p.stack.Stats()
@@ -276,13 +325,17 @@ func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
}
localAddr = addr.Address
- } else if e.protocol.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 {
+ } else if !e.nic.CheckLocalAddress(header.IPv4ProtocolNumber, localAddr) {
stats.outgoingRequestBadLocalAddressErrors.Increment()
return &tcpip.ErrBadLocalAddress{}
}
+ return e.sendARPRequest(localAddr, targetAddr, remoteLinkAddr)
+}
+
+func (e *endpoint) sendARPRequest(localAddr, targetAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error {
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize,
+ ReserveHeaderBytes: int(e.MaxHeaderLength()),
})
h := header.ARP(pkt.NetworkHeader().Push(header.ARPSize))
pkt.NetworkProtocolNumber = ProtocolNumber
@@ -297,6 +350,8 @@ func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
if n := copy(h.ProtocolAddressTarget(), targetAddr); n != header.IPv4AddressSize {
panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize))
}
+
+ stats := e.stats.arp
if err := e.nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil {
stats.outgoingRequestsDropped.Increment()
return err
@@ -337,9 +392,24 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
return 0, false, parse.ARP(pkt)
}
+// Options holds options to configure a protocol.
+//
+// NewProtocol uses the zero value of Options.
+type Options struct {
+ // DADConfigs is the default DAD configurations used by ARP endpoints.
+ //
+ // Each new endpoint initializes its DAD state from this value.
+ DADConfigs stack.DADConfigurations
+}
+
+// NewProtocolWithOptions returns a factory that builds ARP network protocols
+// configured with opts.
+//
+// Every protocol produced by the returned factory carries the same opts value.
+func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
+	factory := func(s *stack.Stack) stack.NetworkProtocol {
+		p := protocol{
+			stack:   s,
+			options: opts,
+		}
+		return &p
+	}
+	return factory
+}
+
// NewProtocol returns an ARP network protocol.
func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
- return &protocol{
- stack: s,
- }
+ return NewProtocolWithOptions(Options{})(s)
}
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 24357e15d..018d6a578 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -17,7 +17,6 @@ package arp_test
import (
"context"
"fmt"
- "strconv"
"testing"
"time"
@@ -155,7 +154,7 @@ type testContext struct {
nudDisp *arpDispatcher
}
-func newTestContext(t *testing.T, useNeighborCache bool) *testContext {
+func newTestContext(t *testing.T) *testContext {
c := stack.DefaultNUDConfigurations()
// Transition from Reachable to Stale almost immediately to test if receiving
// probes refreshes positive reachability.
@@ -173,7 +172,6 @@ func newTestContext(t *testing.T, useNeighborCache bool) *testContext {
TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
NUDConfigs: c,
NUDDisp: &d,
- UseNeighborCache: useNeighborCache,
})
ep := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr)
@@ -191,15 +189,6 @@ func newTestContext(t *testing.T, useNeighborCache bool) *testContext {
if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil {
t.Fatalf("AddAddress for ipv4 failed: %v", err)
}
- if !useNeighborCache {
- // The remote address needs to be assigned to the NIC so we can receive and
- // verify outgoing ARP packets. The neighbor cache isn't concerned with
- // this; the tests that use linkAddrCache expect the ARP responses to be
- // received by the same NIC.
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, remoteAddr); err != nil {
- t.Fatalf("AddAddress for ipv4 failed: %v", err)
- }
- }
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv4EmptySubnet,
@@ -217,86 +206,8 @@ func (c *testContext) cleanup() {
c.linkEP.Close()
}
-func TestDirectRequest(t *testing.T) {
- c := newTestContext(t, false /* useNeighborCache */)
- defer c.cleanup()
-
- const senderMAC = "\x01\x02\x03\x04\x05\x06"
- const senderIPv4 = "\x0a\x00\x00\x02"
-
- v := make(buffer.View, header.ARPSize)
- h := header.ARP(v)
- h.SetIPv4OverEthernet()
- h.SetOp(header.ARPRequest)
- copy(h.HardwareAddressSender(), senderMAC)
- copy(h.ProtocolAddressSender(), senderIPv4)
-
- inject := func(addr tcpip.Address) {
- copy(h.ProtocolAddressTarget(), addr)
- c.linkEP.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: v.ToVectorisedView(),
- }))
- }
-
- for i, address := range []tcpip.Address{stackAddr, remoteAddr} {
- t.Run(strconv.Itoa(i), func(t *testing.T) {
- expectedPacketsReceived := c.s.Stats().ARP.PacketsReceived.Value() + 1
- expectedRequestsReceived := c.s.Stats().ARP.RequestsReceived.Value() + 1
- expectedRepliesSent := c.s.Stats().ARP.OutgoingRepliesSent.Value() + 1
-
- inject(address)
- pi, _ := c.linkEP.ReadContext(context.Background())
- if pi.Proto != arp.ProtocolNumber {
- t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto)
- }
- rep := header.ARP(pi.Pkt.NetworkHeader().View())
- if !rep.IsValid() {
- t.Fatalf("invalid ARP response: len = %d; response = %x", len(rep), rep)
- }
- if got := rep.Op(); got != header.ARPReply {
- t.Fatalf("got Op = %d, want = %d", got, header.ARPReply)
- }
- if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want {
- t.Errorf("got HardwareAddressSender = %s, want = %s", got, want)
- }
- if got, want := tcpip.Address(rep.ProtocolAddressSender()), tcpip.Address(h.ProtocolAddressTarget()); got != want {
- t.Errorf("got ProtocolAddressSender = %s, want = %s", got, want)
- }
- if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress(h.HardwareAddressSender()); got != want {
- t.Errorf("got HardwareAddressTarget = %s, want = %s", got, want)
- }
- if got, want := tcpip.Address(rep.ProtocolAddressTarget()), tcpip.Address(h.ProtocolAddressSender()); got != want {
- t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, want)
- }
-
- if got := c.s.Stats().ARP.PacketsReceived.Value(); got != expectedPacketsReceived {
- t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = %d", got, expectedPacketsReceived)
- }
- if got := c.s.Stats().ARP.RequestsReceived.Value(); got != expectedRequestsReceived {
- t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = %d", got, expectedRequestsReceived)
- }
- if got := c.s.Stats().ARP.OutgoingRepliesSent.Value(); got != expectedRepliesSent {
- t.Errorf("got c.s.Stats().ARP.OutgoingRepliesSent.Value() = %d, want = %d", got, expectedRepliesSent)
- }
- })
- }
-
- inject(unknownAddr)
- // Sleep tests are gross, but this will only potentially flake
- // if there's a bug. If there is no bug this will reliably
- // succeed.
- ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
- defer cancel()
- if pkt, ok := c.linkEP.ReadContext(ctx); ok {
- t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto)
- }
- if got := c.s.Stats().ARP.RequestsReceivedUnknownTargetAddress.Value(); got != 1 {
- t.Errorf("got c.s.Stats().ARP.RequestsReceivedUnKnownTargetAddress.Value() = %d, want = 1", got)
- }
-}
-
func TestMalformedPacket(t *testing.T) {
- c := newTestContext(t, false)
+ c := newTestContext(t)
defer c.cleanup()
v := make(buffer.View, header.ARPSize)
@@ -315,7 +226,7 @@ func TestMalformedPacket(t *testing.T) {
}
func TestDisabledEndpoint(t *testing.T) {
- c := newTestContext(t, false)
+ c := newTestContext(t)
defer c.cleanup()
ep, err := c.s.GetNetworkEndpoint(nicID, header.ARPProtocolNumber)
@@ -340,7 +251,7 @@ func TestDisabledEndpoint(t *testing.T) {
}
func TestDirectReply(t *testing.T) {
- c := newTestContext(t, false)
+ c := newTestContext(t)
defer c.cleanup()
const senderMAC = "\x01\x02\x03\x04\x05\x06"
@@ -370,8 +281,8 @@ func TestDirectReply(t *testing.T) {
}
}
-func TestDirectRequestWithNeighborCache(t *testing.T) {
- c := newTestContext(t, true /* useNeighborCache */)
+func TestDirectRequest(t *testing.T) {
+ c := newTestContext(t)
defer c.cleanup()
tests := []struct {
@@ -748,3 +659,53 @@ func TestLinkAddressRequest(t *testing.T) {
})
}
}
+
+// TestDADARPRequestPacket checks the ARP request that is emitted when DAD is
+// started for an IPv4 address: it must be sent to the Ethernet broadcast
+// address, carry the NIC's link address as hardware sender, the unspecified
+// IPv4 address as protocol sender, an all-zero hardware target and the address
+// under test as protocol target.
+func TestDADARPRequestPacket(t *testing.T) {
+	s := stack.New(stack.Options{
+		NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocolWithOptions(arp.Options{
+			DADConfigs: stack.DADConfigurations{
+				DupAddrDetectTransmits: 1,
+				RetransmitTimer:        time.Second,
+			},
+		}), ipv4.NewProtocol},
+	})
+	e := channel.New(1, defaultMTU, stackLinkAddr)
+	if err := s.CreateNIC(nicID, e); err != nil {
+		t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+	}
+
+	if res, err := s.CheckDuplicateAddress(nicID, header.IPv4ProtocolNumber, remoteAddr, func(stack.DADResult) {}); err != nil {
+		t.Fatalf("s.CheckDuplicateAddress(%d, %d, %s, _): %s", nicID, header.IPv4ProtocolNumber, remoteAddr, err)
+	} else if res != stack.DADStarting {
+		t.Fatalf("got s.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", nicID, header.IPv4ProtocolNumber, remoteAddr, res, stack.DADStarting)
+	}
+
+	pkt, ok := e.ReadContext(context.Background())
+	if !ok {
+		t.Fatal("expected to send an ARP request")
+	}
+
+	if pkt.Route.RemoteLinkAddress != header.EthernetBroadcastAddress {
+		t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, header.EthernetBroadcastAddress)
+	}
+
+	req := header.ARP(stack.PayloadSince(pkt.Pkt.NetworkHeader()))
+	// Stop here on an invalid header: the field checks below read from req and
+	// are meaningless (and possibly unsafe) once validation has failed.
+	if !req.IsValid() {
+		t.Fatal("got req.IsValid() = false, want = true")
+	}
+	if got := req.Op(); got != header.ARPRequest {
+		t.Errorf("got req.Op() = %d, want = %d", got, header.ARPRequest)
+	}
+	if got := tcpip.LinkAddress(req.HardwareAddressSender()); got != stackLinkAddr {
+		t.Errorf("got req.HardwareAddressSender() = %s, want = %s", got, stackLinkAddr)
+	}
+	if got := tcpip.Address(req.ProtocolAddressSender()); got != header.IPv4Any {
+		t.Errorf("got req.ProtocolAddressSender() = %s, want = %s", got, header.IPv4Any)
+	}
+	if got, want := tcpip.LinkAddress(req.HardwareAddressTarget()), tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"); got != want {
+		t.Errorf("got req.HardwareAddressTarget() = %s, want = %s", got, want)
+	}
+	if got := tcpip.Address(req.ProtocolAddressTarget()); got != remoteAddr {
+		t.Errorf("got req.ProtocolAddressTarget() = %s, want = %s", got, remoteAddr)
+	}
+}
diff --git a/pkg/tcpip/network/arp/stats_test.go b/pkg/tcpip/network/arp/stats_test.go
index 65c708ac4..e867b3c3f 100644
--- a/pkg/tcpip/network/arp/stats_test.go
+++ b/pkg/tcpip/network/arp/stats_test.go
@@ -19,7 +19,7 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/network/testutil"
+ "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/internal/fragmentation/BUILD
index 429af69ee..274f09092 100644
--- a/pkg/tcpip/network/fragmentation/BUILD
+++ b/pkg/tcpip/network/internal/fragmentation/BUILD
@@ -22,7 +22,10 @@ go_library(
"reassembler.go",
"reassembler_list.go",
],
- visibility = ["//visibility:public"],
+ visibility = [
+ "//pkg/tcpip/network/ipv4:__pkg__",
+ "//pkg/tcpip/network/ipv6:__pkg__",
+ ],
deps = [
"//pkg/log",
"//pkg/sync",
@@ -44,7 +47,7 @@ go_test(
deps = [
"//pkg/tcpip/buffer",
"//pkg/tcpip/faketime",
- "//pkg/tcpip/network/testutil",
+ "//pkg/tcpip/network/internal/testutil",
"//pkg/tcpip/stack",
"@com_github_google_go_cmp//cmp:go_default_library",
],
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/internal/fragmentation/fragmentation.go
index 243738951..243738951 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation.go
+++ b/pkg/tcpip/network/internal/fragmentation/fragmentation.go
diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go
index 905bbc19b..47ea3173e 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation_test.go
+++ b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go
@@ -22,7 +22,7 @@ import (
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/network/testutil"
+ "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/internal/fragmentation/reassembler.go
index 933d63d32..933d63d32 100644
--- a/pkg/tcpip/network/fragmentation/reassembler.go
+++ b/pkg/tcpip/network/internal/fragmentation/reassembler.go
diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/internal/fragmentation/reassembler_test.go
index 214a93709..214a93709 100644
--- a/pkg/tcpip/network/fragmentation/reassembler_test.go
+++ b/pkg/tcpip/network/internal/fragmentation/reassembler_test.go
diff --git a/pkg/tcpip/network/ip/BUILD b/pkg/tcpip/network/internal/ip/BUILD
index 411bca25d..d21b4c7ef 100644
--- a/pkg/tcpip/network/ip/BUILD
+++ b/pkg/tcpip/network/internal/ip/BUILD
@@ -5,25 +5,35 @@ package(licenses = ["notice"])
go_library(
name = "ip",
srcs = [
+ "duplicate_address_detection.go",
"generic_multicast_protocol.go",
"stats.go",
],
- visibility = ["//visibility:public"],
+ visibility = [
+ "//pkg/tcpip/network/arp:__pkg__",
+ "//pkg/tcpip/network/ipv4:__pkg__",
+ "//pkg/tcpip/network/ipv6:__pkg__",
+ ],
deps = [
"//pkg/sync",
"//pkg/tcpip",
+ "//pkg/tcpip/stack",
],
)
go_test(
- name = "ip_test",
+ name = "ip_x_test",
size = "small",
- srcs = ["generic_multicast_protocol_test.go"],
+ srcs = [
+ "duplicate_address_detection_test.go",
+ "generic_multicast_protocol_test.go",
+ ],
deps = [
":ip",
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/faketime",
+ "//pkg/tcpip/stack",
"@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection.go
new file mode 100644
index 000000000..6f89a6a16
--- /dev/null
+++ b/pkg/tcpip/network/internal/ip/duplicate_address_detection.go
@@ -0,0 +1,172 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package ip holds IPv4/IPv6 common utilities.
+package ip
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// dadState is the bookkeeping kept for one address while DAD is running for
+// it.
+type dadState struct {
+ // done points at the flag that the timer callback checks after acquiring the
+ // protocol lock; StopLocked sets it so that a pending callback returns
+ // without touching the (already removed) state.
+ done *bool
+ // timer runs the callback that sends DAD messages and eventually resolves
+ // DAD.
+ timer tcpip.Timer
+
+ // completionHandlers are all called with the result once DAD resolves, fails
+ // to send, or is stopped.
+ completionHandlers []stack.DADCompletionHandler
+}
+
+// DADProtocol is a protocol whose core state machine can be represented by DAD.
+type DADProtocol interface {
+ // SendDADMessage attempts to send a DAD probe message.
+ //
+ // DAD calls this from its timer callback before it acquires the protocol
+ // lock. A non-nil error ends the DAD process and is reported to the
+ // completion handlers.
+ SendDADMessage(tcpip.Address) tcpip.Error
+}
+
+// DADOptions holds options for DAD.
+type DADOptions struct {
+ // Clock is used to create the timers that drive DAD.
+ Clock tcpip.Clock
+ // Protocol sends the DAD probe messages.
+ Protocol DADProtocol
+ // NICID identifies the NIC in the panic message raised when a timer fires
+ // without matching state.
+ NICID tcpip.NICID
+}
+
+// DAD performs duplicate address detection for addresses.
+//
+// A DAD must be set up with Init before use. Its *Locked methods require the
+// caller to hold the lock that was passed to Init.
+type DAD struct {
+ // opts and configs are the values supplied to Init; configs may later be
+ // replaced through SetConfigsLocked.
+ opts DADOptions
+ configs stack.DADConfigurations
+
+ // protocolMU is the lock supplied to Init; timer callbacks acquire it.
+ protocolMU sync.Locker
+ // addresses holds the state of every DAD process currently running, keyed by
+ // the address being checked. It is non-nil once Init has run.
+ addresses map[tcpip.Address]dadState
+}
+
+// Init prepares d for use.
+//
+// Init must be called exactly once for the lifetime of d; a second call
+// panics.
+//
+// protocolMU is only acquired by d itself when one of its timers fires.
+func (d *DAD) Init(protocolMU sync.Locker, configs stack.DADConfigurations, opts DADOptions) {
+	// A non-nil address map means Init has already run.
+	if d.addresses != nil {
+		panic("attempted to initialize DAD state twice")
+	}
+
+	d.opts = opts
+	d.configs = configs
+	d.protocolMU = protocolMU
+	d.addresses = make(map[tcpip.Address]dadState)
+}
+
+// CheckDuplicateAddressLocked performs DAD for an address, calling the
+// completion handler once DAD resolves.
+//
+// If no DAD process is running for addr, one is started and stack.DADStarting
+// is returned. If one is already running, h is queued behind it and
+// stack.DADAlreadyRunning is returned; h is then called when the running
+// process completes. If the configured number of transmits is zero,
+// stack.DADDisabled is returned and h is never called.
+//
+// Completion handlers are invoked with d.protocolMU held.
+//
+// Precondition: d.protocolMU must be locked.
+func (d *DAD) CheckDuplicateAddressLocked(addr tcpip.Address, h stack.DADCompletionHandler) stack.DADCheckAddressDisposition {
+	if d.configs.DupAddrDetectTransmits == 0 {
+		return stack.DADDisabled
+	}
+
+	ret := stack.DADAlreadyRunning
+	s, ok := d.addresses[addr]
+	if !ok {
+		ret = stack.DADStarting
+
+		// Number of DAD messages still to be sent. Only the timer callback
+		// below reads or writes it.
+		remaining := d.configs.DupAddrDetectTransmits
+
+		// Protected by d.protocolMU.
+		done := false
+
+		s = dadState{
+			done: &done,
+			timer: d.opts.Clock.AfterFunc(0, func() {
+				// The probe is sent before the lock is taken, so the
+				// protocol's SendDADMessage runs without d.protocolMU held.
+				var err tcpip.Error
+				dadDone := remaining == 0
+				if !dadDone {
+					err = d.opts.Protocol.SendDADMessage(addr)
+				}
+
+				d.protocolMU.Lock()
+				defer d.protocolMU.Unlock()
+
+				// The process was stopped while this callback was waiting for
+				// the lock; its state is gone and its handlers were already
+				// called.
+				if done {
+					return
+				}
+
+				s, ok := d.addresses[addr]
+				if !ok {
+					panic(fmt.Sprintf("dad: timer fired but missing state for %s on NIC(%d)", addr, d.opts.NICID))
+				}
+
+				if !dadDone && err == nil {
+					remaining--
+					s.timer.Reset(d.configs.RetransmitTimer)
+					return
+				}
+
+				// At this point we know that either DAD has resolved or we hit an error
+				// sending the last DAD message. Either way, clear the DAD state.
+				//
+				// Mark the process as finished, matching what StopLocked does, so
+				// that the guard above holds for any later invocation.
+				done = true
+				s.timer.Stop()
+				delete(d.addresses, addr)
+
+				r := stack.DADResult{Resolved: dadDone, Err: err}
+				for _, h := range s.completionHandlers {
+					h(r)
+				}
+			}),
+		}
+	}
+
+	// The state is stored (or re-stored with the extra handler) while the
+	// caller still holds the lock, so a callback that has already fired finds
+	// it once it gets the lock.
+	s.completionHandlers = append(s.completionHandlers, h)
+	d.addresses[addr] = s
+	return ret
+}
+
+// StopLocked stops the DAD process running for addr, if there is one.
+//
+// When a process is stopped, every completion handler registered for addr is
+// called with an unresolved result whose error is tcpip.ErrAborted if aborted
+// is true and nil otherwise. When no process is running for addr, StopLocked
+// does nothing.
+//
+// Precondition: d.protocolMU must be locked.
+func (d *DAD) StopLocked(addr tcpip.Address, aborted bool) {
+	state, running := d.addresses[addr]
+	if !running {
+		return
+	}
+
+	// Tell a timer callback that may already be waiting on the lock to bail
+	// out, then drop the state.
+	*state.done = true
+	state.timer.Stop()
+	delete(d.addresses, addr)
+
+	result := stack.DADResult{Resolved: false}
+	if aborted {
+		result.Err = &tcpip.ErrAborted{}
+	}
+
+	for i := range state.completionHandlers {
+		state.completionHandlers[i](result)
+	}
+}
+
+// SetConfigsLocked sets the DAD configurations.
+//
+// The new retransmit interval is picked up the next time a running DAD
+// process re-arms its timer; the number of transmits of a process that has
+// already started is not affected.
+//
+// Precondition: d.protocolMU must be locked.
+func (d *DAD) SetConfigsLocked(c stack.DADConfigurations) {
+ // NOTE(review): presumably Validate normalizes invalid values in the local
+ // copy c before it is stored - confirm against stack.DADConfigurations.
+ c.Validate()
+ d.configs = c
+}
diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go
new file mode 100644
index 000000000..18c357b56
--- /dev/null
+++ b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go
@@ -0,0 +1,279 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ip_test
+
+import (
+ "testing"
+ "time"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// mockDADProtocol is a test double that drives an ip.DAD instance and records
+// how many DAD messages were sent for each address.
+type mockDADProtocol struct {
+ t *testing.T
+
+ // mu groups the mutex with the state the methods below access while
+ // holding it; init hands a pointer to mu to dad.Init.
+ mu struct {
+ sync.Mutex
+
+ dad ip.DAD
+ // sendCount counts SendDADMessage calls per address since the last reset.
+ sendCount map[tcpip.Address]int
+ }
+}
+
+// init registers the mock as the protocol of its embedded ip.DAD, initializes
+// the DAD state with configuration c and options opts, and starts from an
+// empty send-count table.
+func (m *mockDADProtocol) init(t *testing.T, c stack.DADConfigurations, opts ip.DADOptions) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	m.t = t
+
+	// The mock receives the DAD messages itself, and its own mutex is the one
+	// handed to the DAD state.
+	opts.Protocol = m
+	m.mu.dad.Init(&m.mu, c, opts)
+
+	m.mu.sendCount = make(map[tcpip.Address]int)
+}
+
+// initLocked replaces the send-count table with a fresh, empty map.
+//
+// Callers in this file invoke it while holding m.mu.
+func (m *mockDADProtocol) initLocked() {
+ m.mu.sendCount = make(map[tcpip.Address]int)
+}
+
+// SendDADMessage records that a DAD message was requested for addr and always
+// reports success. It takes m.mu itself.
+func (m *mockDADProtocol) SendDADMessage(addr tcpip.Address) tcpip.Error {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.sendCount[addr]++
+ return nil
+}
+
+// check compares the DAD messages recorded since the previous reset against
+// addrs, where each occurrence of an address stands for one expected send.
+// It returns the cmp.Diff output (expected counts first, recorded counts
+// second) and resets the recorded counts before returning.
+func (m *mockDADProtocol) check(addrs []tcpip.Address) string {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	want := make(map[tcpip.Address]int)
+	for i := range addrs {
+		want[addrs[i]]++
+	}
+
+	result := cmp.Diff(want, m.mu.sendCount)
+
+	// Start counting afresh for the next check.
+	m.mu.sendCount = make(map[tcpip.Address]int)
+	return result
+}
+
+// checkDuplicateAddress calls CheckDuplicateAddressLocked on the wrapped DAD
+// state while holding m.mu and returns its disposition.
+func (m *mockDADProtocol) checkDuplicateAddress(addr tcpip.Address, h stack.DADCompletionHandler) stack.DADCheckAddressDisposition {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ return m.mu.dad.CheckDuplicateAddressLocked(addr, h)
+}
+
+// stop calls StopLocked on the wrapped DAD state for addr while holding m.mu;
+// aborted is forwarded unchanged.
+func (m *mockDADProtocol) stop(addr tcpip.Address, aborted bool) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.dad.StopLocked(addr, aborted)
+}
+
+// setConfigs calls SetConfigsLocked on the wrapped DAD state with c while
+// holding m.mu.
+func (m *mockDADProtocol) setConfigs(c stack.DADConfigurations) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.mu.dad.SetConfigsLocked(c)
+}
+
+// Single-byte placeholder addresses used as targets by the tests.
+// NOTE(review): the same constants are removed from
+// generic_multicast_protocol_test.go in this change, so they are presumably
+// shared across the package's tests -- confirm.
+const (
+ addr1 = tcpip.Address("\x01")
+ addr2 = tcpip.Address("\x02")
+ addr3 = tcpip.Address("\x03")
+ addr4 = tcpip.Address("\x04")
+)
+
+// dadResult pairs a DAD completion result with the address whose handler
+// reported it.
+type dadResult struct {
+ Addr tcpip.Address
+ R stack.DADResult
+}
+
+// handler builds a DAD completion callback that forwards each result it
+// receives, tagged with address a, onto ch.
+func handler(ch chan<- dadResult, a tcpip.Address) func(stack.DADResult) {
+	report := func(res stack.DADResult) {
+		ch <- dadResult{Addr: a, R: res}
+	}
+	return report
+}
+
+// TestDADCheckDuplicateAddress exercises CheckDuplicateAddressLocked through
+// the mock protocol: the disabled case with a zero configuration, starting
+// DAD, joining a DAD process that is already running, resolution timing for
+// addresses started under different configurations, and restarting DAD on an
+// address after it resolved.
+func TestDADCheckDuplicateAddress(t *testing.T) {
+ var dad mockDADProtocol
+ clock := faketime.NewManualClock()
+ dad.init(t, stack.DADConfigurations{}, ip.DADOptions{
+ Clock: clock,
+ })
+
+ // Two handlers are registered for addr1 below and both report on ch when
+ // addr1 resolves, hence the buffer of two.
+ ch := make(chan dadResult, 2)
+
+ // DAD should initially be disabled.
+ if res := dad.checkDuplicateAddress(addr1, handler(nil, "")); res != stack.DADDisabled {
+ t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADDisabled)
+ }
+ // Wait for any initially fired timers to complete.
+ clock.Advance(0)
+ if diff := dad.check(nil); diff != "" {
+ t.Errorf("dad check mismatch (-want +got):\n%s", diff)
+ }
+
+ // Enable and request DAD.
+ dadConfigs1 := stack.DADConfigurations{
+ DupAddrDetectTransmits: 1,
+ RetransmitTimer: time.Second,
+ }
+ dad.setConfigs(dadConfigs1)
+ if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting {
+ t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting)
+ }
+ clock.Advance(0)
+ if diff := dad.check([]tcpip.Address{addr1}); diff != "" {
+ t.Errorf("dad check mismatch (-want +got):\n%s", diff)
+ }
+ // The second request for DAD on the same address should use the original
+ // request since it has not completed yet.
+ if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADAlreadyRunning {
+ t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADAlreadyRunning)
+ }
+ clock.Advance(0)
+ if diff := dad.check(nil); diff != "" {
+ t.Errorf("dad check mismatch (-want +got):\n%s", diff)
+ }
+
+ dadConfigs2 := stack.DADConfigurations{
+ DupAddrDetectTransmits: 2,
+ RetransmitTimer: time.Second,
+ }
+ dad.setConfigs(dadConfigs2)
+ // A new address should start a new DAD process.
+ if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting {
+ t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting)
+ }
+ clock.Advance(0)
+ if diff := dad.check([]tcpip.Address{addr2}); diff != "" {
+ t.Errorf("dad check mismatch (-want +got):\n%s", diff)
+ }
+
+ // Make sure DAD for addr1 only resolves after the expected timeout.
+ const delta = time.Nanosecond
+ dadConfig1Duration := time.Duration(dadConfigs1.DupAddrDetectTransmits) * dadConfigs1.RetransmitTimer
+ clock.Advance(dadConfig1Duration - delta)
+ select {
+ case r := <-ch:
+ t.Fatalf("unexpectedly got a DAD result before the expected timeout of %s; r = %#v", dadConfig1Duration, r)
+ default:
+ }
+ clock.Advance(delta)
+ // Both handlers registered for addr1 are expected to report the result.
+ for i := 0; i < 2; i++ {
+ if diff := cmp.Diff(dadResult{Addr: addr1, R: stack.DADResult{Resolved: true, Err: nil}}, <-ch); diff != "" {
+ t.Errorf("(i=%d) dad result mismatch (-want +got):\n%s", i, diff)
+ }
+ }
+
+ // Make sure DAD for addr2 only resolves after the expected timeout.
+ // The clock has already advanced by dadConfig1Duration since addr2 started.
+ dadConfig2Duration := time.Duration(dadConfigs2.DupAddrDetectTransmits) * dadConfigs2.RetransmitTimer
+ clock.Advance(dadConfig2Duration - dadConfig1Duration - delta)
+ select {
+ case r := <-ch:
+ t.Fatalf("unexpectedly got a DAD result before the expected timeout of %s; r = %#v", dadConfig2Duration, r)
+ default:
+ }
+ clock.Advance(delta)
+ if diff := cmp.Diff(dadResult{Addr: addr2, R: stack.DADResult{Resolved: true, Err: nil}}, <-ch); diff != "" {
+ t.Errorf("dad result mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should be able to restart DAD for addr2 after it resolved.
+ if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting {
+ t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting)
+ }
+ clock.Advance(0)
+ // addr2 is listed twice: the send counts were last reset right after the
+ // first addr2 message, so this check presumably also sees the
+ // retransmission from the earlier addr2 process in addition to the first
+ // message of the restarted one.
+ if diff := dad.check([]tcpip.Address{addr2, addr2}); diff != "" {
+ t.Errorf("dad check mismatch (-want +got):\n%s", diff)
+ }
+ clock.Advance(dadConfig2Duration)
+ if diff := cmp.Diff(dadResult{Addr: addr2, R: stack.DADResult{Resolved: true, Err: nil}}, <-ch); diff != "" {
+ t.Errorf("dad result mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should not have anymore results.
+ select {
+ case r := <-ch:
+ t.Fatalf("unexpectedly got an extra DAD result; r = %#v", r)
+ default:
+ }
+}
+
+// TestDADStop verifies that stopping DAD invokes the completion handler with
+// an unresolved result (carrying tcpip.ErrAborted only when aborted is set),
+// that an address which was not stopped still resolves, and that DAD can be
+// restarted on an address it was previously stopped for.
+func TestDADStop(t *testing.T) {
+	var dad mockDADProtocol
+	clock := faketime.NewManualClock()
+	dadConfigs := stack.DADConfigurations{
+		DupAddrDetectTransmits: 1,
+		RetransmitTimer:        time.Second,
+	}
+	dad.init(t, dadConfigs, ip.DADOptions{
+		Clock: clock,
+	})
+
+	// Each result is read back before the next one is produced, so a buffer
+	// of one is enough for the handlers' sends.
+	ch := make(chan dadResult, 1)
+
+	if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting {
+		t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting)
+	}
+	if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting {
+		t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting)
+	}
+	if res := dad.checkDuplicateAddress(addr3, handler(ch, addr3)); res != stack.DADStarting {
+		// Report the address that was actually checked (addr3).
+		t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr3, res, stack.DADStarting)
+	}
+	clock.Advance(0)
+	if diff := dad.check([]tcpip.Address{addr1, addr2, addr3}); diff != "" {
+		t.Errorf("dad check mismatch (-want +got):\n%s", diff)
+	}
+
+	dad.stop(addr1, true /* aborted */)
+	if diff := cmp.Diff(dadResult{Addr: addr1, R: stack.DADResult{Resolved: false, Err: &tcpip.ErrAborted{}}}, <-ch); diff != "" {
+		t.Errorf("dad result mismatch (-want +got):\n%s", diff)
+	}
+
+	dad.stop(addr2, false /* aborted */)
+	if diff := cmp.Diff(dadResult{Addr: addr2, R: stack.DADResult{Resolved: false, Err: nil}}, <-ch); diff != "" {
+		t.Errorf("dad result mismatch (-want +got):\n%s", diff)
+	}
+
+	// addr3 was never stopped, so it should resolve on schedule.
+	dadResolutionDuration := time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer
+	clock.Advance(dadResolutionDuration)
+	if diff := cmp.Diff(dadResult{Addr: addr3, R: stack.DADResult{Resolved: true, Err: nil}}, <-ch); diff != "" {
+		t.Errorf("dad result mismatch (-want +got):\n%s", diff)
+	}
+
+	// Should be able to restart DAD for an address we stopped DAD on.
+	if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting {
+		t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting)
+	}
+	clock.Advance(0)
+	if diff := dad.check([]tcpip.Address{addr1}); diff != "" {
+		t.Errorf("dad check mismatch (-want +got):\n%s", diff)
+	}
+	clock.Advance(dadResolutionDuration)
+	if diff := cmp.Diff(dadResult{Addr: addr1, R: stack.DADResult{Resolved: true, Err: nil}}, <-ch); diff != "" {
+		t.Errorf("dad result mismatch (-want +got):\n%s", diff)
+	}
+
+	// Should not have any more updates.
+	select {
+	case r := <-ch:
+		t.Fatalf("unexpectedly got an extra DAD result; r = %#v", r)
+	default:
+	}
+}
diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go
index b9f129728..b9f129728 100644
--- a/pkg/tcpip/network/ip/generic_multicast_protocol.go
+++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go
diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go
index 60eaea37e..381460c82 100644
--- a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
+++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go
@@ -23,17 +23,10 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/faketime"
- "gvisor.dev/gvisor/pkg/tcpip/network/ip"
+ "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip"
)
-const (
- addr1 = tcpip.Address("\x01")
- addr2 = tcpip.Address("\x02")
- addr3 = tcpip.Address("\x03")
- addr4 = tcpip.Address("\x04")
-
- maxUnsolicitedReportDelay = time.Second
-)
+const maxUnsolicitedReportDelay = time.Second
var _ ip.MulticastGroupProtocol = (*mockMulticastGroupProtocol)(nil)
diff --git a/pkg/tcpip/network/ip/stats.go b/pkg/tcpip/network/internal/ip/stats.go
index 898f8b356..898f8b356 100644
--- a/pkg/tcpip/network/ip/stats.go
+++ b/pkg/tcpip/network/internal/ip/stats.go
diff --git a/pkg/tcpip/network/testutil/BUILD b/pkg/tcpip/network/internal/testutil/BUILD
index bd62c4482..1c4f583c7 100644
--- a/pkg/tcpip/network/testutil/BUILD
+++ b/pkg/tcpip/network/internal/testutil/BUILD
@@ -10,7 +10,7 @@ go_library(
],
visibility = [
"//pkg/tcpip/network/arp:__pkg__",
- "//pkg/tcpip/network/fragmentation:__pkg__",
+ "//pkg/tcpip/network/internal/fragmentation:__pkg__",
"//pkg/tcpip/network/ipv4:__pkg__",
"//pkg/tcpip/network/ipv6:__pkg__",
],
diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/internal/testutil/testutil.go
index f5fa77b65..f5fa77b65 100644
--- a/pkg/tcpip/network/testutil/testutil.go
+++ b/pkg/tcpip/network/internal/testutil/testutil.go
diff --git a/pkg/tcpip/network/testutil/testutil_unsafe.go b/pkg/tcpip/network/internal/testutil/testutil_unsafe.go
index 5ff764800..5ff764800 100644
--- a/pkg/tcpip/network/testutil/testutil_unsafe.go
+++ b/pkg/tcpip/network/internal/testutil/testutil_unsafe.go
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 6a1f11a36..90236ed9e 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -24,7 +24,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
@@ -315,6 +314,10 @@ func (*testInterface) Promiscuous() bool {
return false
}
+// Spoofing always reports false for the test interface.
+func (*testInterface) Spoofing() bool {
+ return false
+}
+
func (t *testInterface) setEnabled(v bool) {
t.mu.Lock()
defer t.mu.Unlock()
@@ -333,6 +336,10 @@ func (*testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tc
return nil
}
+// CheckLocalAddress always reports false for the test interface, regardless
+// of the protocol number and address passed in.
+func (*testInterface) CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool {
+ return false
+}
+
func TestSourceAddressValidation(t *testing.T) {
rxIPv4ICMP := func(e *channel.Endpoint, src tcpip.Address) {
totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
@@ -626,9 +633,6 @@ func TestReceive(t *testing.T) {
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: view.ToVectorisedView(),
})
- if ok := parse.IPv4(pkt); !ok {
- t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
- }
ep.HandlePacket(pkt)
},
},
@@ -664,9 +668,6 @@ func TestReceive(t *testing.T) {
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: view.ToVectorisedView(),
})
- if _, _, _, _, ok := parse.IPv6(pkt); !ok {
- t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
- }
ep.HandlePacket(pkt)
},
},
@@ -943,9 +944,6 @@ func TestIPv4FragmentationReceive(t *testing.T) {
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: frag1.ToVectorisedView(),
})
- if _, _, ok := proto.Parse(pkt); !ok {
- t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
- }
addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
if !ok {
@@ -967,9 +965,6 @@ func TestIPv4FragmentationReceive(t *testing.T) {
pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: frag2.ToVectorisedView(),
})
- if _, _, ok := proto.Parse(pkt); !ok {
- t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
- }
ep.HandlePacket(pkt)
if nic.testObject.dataCalls != 1 {
t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
@@ -1234,7 +1229,6 @@ func truncatedPacket(view buffer.View, trunc, netHdrLen int) *stack.PacketBuffer
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: v.ToVectorisedView(),
})
- _, _ = pkt.NetworkHeader().Consume(netHdrLen)
return pkt
}
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index 9713c4448..4b21ee79c 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -17,9 +17,9 @@ go_library(
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/header/parse",
- "//pkg/tcpip/network/fragmentation",
"//pkg/tcpip/network/hash",
- "//pkg/tcpip/network/ip",
+ "//pkg/tcpip/network/internal/fragmentation",
+ "//pkg/tcpip/network/internal/ip",
"//pkg/tcpip/stack",
],
)
@@ -40,8 +40,8 @@ go_test(
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/arp",
+ "//pkg/tcpip/network/internal/testutil",
"//pkg/tcpip/network/ipv4",
- "//pkg/tcpip/network/testutil",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/raw",
@@ -59,7 +59,7 @@ go_test(
library = ":ipv4",
deps = [
"//pkg/tcpip",
- "//pkg/tcpip/network/testutil",
+ "//pkg/tcpip/network/internal/testutil",
"//pkg/tcpip/stack",
],
)
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 74e70e283..b44304cee 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -120,6 +120,18 @@ func (*icmpv4FragmentationNeededSockError) Kind() stack.TransportErrorKind {
return stack.PacketTooBigTransportError
}
+// checkLocalAddress reports whether addr should be treated as local to this
+// endpoint: always true when the NIC has spoofing enabled, otherwise true
+// only when addr can be acquired as an assigned address on the endpoint.
+func (e *endpoint) checkLocalAddress(addr tcpip.Address) bool {
+	if e.nic.Spoofing() {
+		return true
+	}
+
+	assigned := e.AcquireAssignedAddress(addr, false, stack.NeverPrimaryEndpoint)
+	if assigned == nil {
+		return false
+	}
+
+	// Only existence matters; drop the reference that was just acquired.
+	assigned.DecRef()
+	return true
+}
+
// handleControl handles the case when an ICMP error packet contains the headers
// of the original packet that caused the ICMP one to be sent. This information
// is used to find out which transport endpoint must be notified about the ICMP
@@ -139,7 +151,7 @@ func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.Packet
// Drop packet if it doesn't have the basic IPv4 header or if the
// original source address doesn't match an address we own.
srcAddr := hdr.SourceAddress()
- if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, srcAddr) == 0 {
+ if !e.checkLocalAddress(srcAddr) {
return
}
@@ -201,7 +213,8 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
} else {
op = &optionUsageReceive{}
}
- tmp, optProblem := e.processIPOptions(pkt, opts, op)
+ var optProblem *header.IPv4OptParameterProblem
+ newOptions, optProblem = e.processIPOptions(pkt, opts, op)
if optProblem != nil {
if optProblem.NeedICMP {
_ = e.protocol.returnError(&icmpReasonParamProblem{
@@ -212,7 +225,14 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
}
return
}
- newOptions = tmp
+ copied := copy(opts, newOptions)
+ if copied != len(newOptions) {
+ panic(fmt.Sprintf("copied %d bytes of new options, expected %d bytes", copied, len(newOptions)))
+ }
+ for i := copied; i < len(opts); i++ {
+ // Pad with 0 (EOL). RFC 791 page 23 says "The padding is zero".
+ opts[i] = byte(header.IPv4OptionListEndType)
+ }
}
// TODO(b/112892170): Meaningfully handle all ICMP types.
@@ -363,6 +383,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
// icmpReason is a marker interface for IPv4 specific ICMP errors.
type icmpReason interface {
isICMPReason()
+ // isForwarding reports whether the error arose while forwarding a packet;
+ // returnError clears the local address for the response when it is true.
+ isForwarding() bool
}
// icmpReasonPortUnreachable is an error where the transport protocol has no
@@ -370,12 +391,18 @@ type icmpReason interface {
type icmpReasonPortUnreachable struct{}
func (*icmpReasonPortUnreachable) isICMPReason() {}
+// isForwarding always reports false for a port unreachable error.
+func (*icmpReasonPortUnreachable) isForwarding() bool {
+ return false
+}
// icmpReasonProtoUnreachable is an error where the transport protocol is
// not supported.
type icmpReasonProtoUnreachable struct{}
func (*icmpReasonProtoUnreachable) isICMPReason() {}
+// isForwarding always reports false for a protocol unreachable error.
+func (*icmpReasonProtoUnreachable) isForwarding() bool {
+ return false
+}
// icmpReasonTTLExceeded is an error where a packet's time to live exceeded in
// transit to its final destination, as per RFC 792 page 6, Time Exceeded
@@ -383,6 +410,15 @@ func (*icmpReasonProtoUnreachable) isICMPReason() {}
type icmpReasonTTLExceeded struct{}
func (*icmpReasonTTLExceeded) isICMPReason() {}
+// isForwarding always reports true for a TTL exceeded error.
+func (*icmpReasonTTLExceeded) isForwarding() bool {
+ // If we hit a TTL Exceeded error, then we know we are operating as a router.
+ // As per RFC 792 page 6, Time Exceeded Message,
+ //
+ // If the gateway processing a datagram finds the time to live field
+ // is zero it must discard the datagram. The gateway may also notify
+ // the source host via the time exceeded message.
+ return true
+}
// icmpReasonReassemblyTimeout is an error where insufficient fragments are
// received to complete reassembly of a packet within a configured time after
@@ -390,14 +426,21 @@ func (*icmpReasonTTLExceeded) isICMPReason() {}
type icmpReasonReassemblyTimeout struct{}
func (*icmpReasonReassemblyTimeout) isICMPReason() {}
+// isForwarding always reports false for a reassembly timeout error.
+func (*icmpReasonReassemblyTimeout) isForwarding() bool {
+ return false
+}
// icmpReasonParamProblem is an error to use to request a Parameter Problem
// message to be sent.
type icmpReasonParamProblem struct {
- pointer byte
+ pointer byte
+ forwarding bool
}
func (*icmpReasonParamProblem) isICMPReason() {}
+// isForwarding reports the value of the forwarding field, which the creator
+// of the reason sets.
+func (r *icmpReasonParamProblem) isForwarding() bool {
+ return r.forwarding
+}
// returnError takes an error descriptor and generates the appropriate ICMP
// error packet for IPv4 and sends it back to the remote device that sent
@@ -436,26 +479,14 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
return nil
}
- // If we hit a TTL Exceeded error, then we know we are operating as a router.
- // As per RFC 792 page 6, Time Exceeded Message,
- //
- // If the gateway processing a datagram finds the time to live field
- // is zero it must discard the datagram. The gateway may also notify
- // the source host via the time exceeded message.
- //
- // ...
- //
- // Code 0 may be received from a gateway. ...
- //
- // Note, Code 0 is the TTL exceeded error.
- //
// If we are operating as a router/gateway, don't use the packet's destination
// address as the response's source address as we should not own the
// destination address of a packet we are forwarding.
localAddr := origIPHdrDst
- if _, ok := reason.(*icmpReasonTTLExceeded); ok {
+ if reason.isForwarding() {
localAddr = ""
}
+
// Even if we were able to receive a packet from some remote, we may not have
// a route to it - the remote may be blocked via routing rules. We must always
// consult our routing table and find a route to the remote before sending any
diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go
index acc126c3b..12632aceb 100644
--- a/pkg/tcpip/network/ipv4/igmp.go
+++ b/pkg/tcpip/network/ipv4/igmp.go
@@ -22,7 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/network/ip"
+ "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index b2d626107..250e4846a 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -27,8 +27,8 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/header/parse"
- "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation"
"gvisor.dev/gvisor/pkg/tcpip/network/hash"
+ "gvisor.dev/gvisor/pkg/tcpip/network/internal/fragmentation"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -130,6 +130,20 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.Tran
return e
}
+// findEndpointWithAddress returns an endpoint of this protocol that has addr
+// as an assigned address, or nil when none does. The endpoint table is
+// walked under the protocol's read lock.
+func (p *protocol) findEndpointWithAddress(addr tcpip.Address) *endpoint {
+	p.mu.RLock()
+	defer p.mu.RUnlock()
+
+	for _, candidate := range p.mu.eps {
+		assigned := candidate.AcquireAssignedAddress(addr, false /* allowTemp */, stack.NeverPrimaryEndpoint)
+		if assigned == nil {
+			continue
+		}
+
+		// The address endpoint was only needed as proof of ownership.
+		assigned.DecRef()
+		return candidate
+	}
+
+	return nil
+}
+
func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
p.mu.Lock()
defer p.mu.Unlock()
@@ -347,15 +361,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// short circuits broadcasts before they are sent out to other hosts.
if pkt.NatDone {
netHeader := header.IPv4(pkt.NetworkHeader().View())
- ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress())
- if err == nil {
- pkt := pkt.CloneToInbound()
- if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- // Since we rewrote the packet but it is being routed back to us, we can
- // safely assume the checksum is valid.
- pkt.RXTransportChecksumValidated = true
- ep.(*endpoint).handlePacket(pkt)
- }
+ if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil {
+ // Since we rewrote the packet but it is being routed back to us, we
+ // can safely assume the checksum is valid.
+ ep.handleLocalPacket(pkt, true /* canSkipRXChecksum */)
return nil
}
}
@@ -365,14 +374,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, headerIncluded bool) tcpip.Error {
if r.Loop&stack.PacketLoop != 0 {
- pkt := pkt.CloneToInbound()
- if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- // If the packet was generated by the stack (not a raw/packet endpoint
- // where a packet may be written with the header included), then we can
- // safely assume the checksum is valid.
- pkt.RXTransportChecksumValidated = !headerIncluded
- e.handlePacket(pkt)
- }
+ // If the packet was generated by the stack (not a raw/packet endpoint
+ // where a packet may be written with the header included), then we can
+ // safely assume the checksum is valid.
+ e.handleLocalPacket(pkt, !headerIncluded /* canSkipRXChecksum */)
}
if r.Loop&stack.PacketOut == 0 {
return nil
@@ -450,51 +455,37 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// iptables filtering. All packets that reach here are locally
// generated.
dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "", outNicName)
- if len(dropped) == 0 && len(natPkts) == 0 {
- // Fast path: If no packets are to be dropped then we can just invoke the
- // faster WritePackets API directly.
- n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
- stats.PacketsSent.IncrementBy(uint64(n))
- if err != nil {
- stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n))
- }
- return n, err
- }
stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
+ for pkt := range dropped {
+ pkts.Remove(pkt)
+ }
- // Slow path as we are dropping some packets in the batch degrade to
- // emitting one packet at a time.
- n := 0
- for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- if _, ok := dropped[pkt]; ok {
+ // The NAT-ed packets may now be destined for us.
+ locallyDelivered := 0
+ for pkt := range natPkts {
+ ep := e.protocol.findEndpointWithAddress(header.IPv4(pkt.NetworkHeader().View()).DestinationAddress())
+ if ep == nil {
+ // The NAT-ed packet is still destined for some remote node.
continue
}
- if _, ok := natPkts[pkt]; ok {
- netHeader := header.IPv4(pkt.NetworkHeader().View())
- if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
- pkt := pkt.CloneToInbound()
- if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- // Since we rewrote the packet but it is being routed back to us, we
- // can safely assume the checksum is valid.
- pkt.RXTransportChecksumValidated = true
- ep.(*endpoint).handlePacket(pkt)
- }
- n++
- continue
- }
- }
- if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
- stats.PacketsSent.IncrementBy(uint64(n))
- stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n - len(dropped)))
- // Dropped packets aren't errors, so include them in
- // the return value.
- return n + len(dropped), err
- }
- n++
+
+ // Do not send the locally destined packet out the NIC.
+ pkts.Remove(pkt)
+
+ // Deliver the packet locally.
+ ep.handleLocalPacket(pkt, true /* canSkipRXChecksum */)
+ locallyDelivered++
+
}
- stats.PacketsSent.IncrementBy(uint64(n))
+
+ // The rest of the packets can be delivered to the NIC as a batch.
+ pktsLen := pkts.Len()
+ written, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
+ stats.PacketsSent.IncrementBy(uint64(written))
+ stats.OutgoingPacketErrors.IncrementBy(uint64(pktsLen - written))
+
// Dropped packets aren't errors, so include them in the return value.
- return n + len(dropped), nil
+ return locallyDelivered + written + len(dropped), err
}
// WriteHeaderIncludedPacket implements stack.NetworkEndpoint.
@@ -570,17 +561,41 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
return e.protocol.returnError(&icmpReasonTTLExceeded{}, pkt)
}
+ if opts := h.Options(); len(opts) != 0 {
+ newOpts, optProblem := e.processIPOptions(pkt, opts, &optionUsageForward{})
+ if optProblem != nil {
+ if optProblem.NeedICMP {
+ _ = e.protocol.returnError(&icmpReasonParamProblem{
+ pointer: optProblem.Pointer,
+ forwarding: true,
+ }, pkt)
+ e.protocol.stack.Stats().MalformedRcvdPackets.Increment()
+ e.stats.ip.MalformedPacketsReceived.Increment()
+ }
+ return nil // option problems are not reported locally.
+ }
+ copied := copy(opts, newOpts)
+ if copied != len(newOpts) {
+ panic(fmt.Sprintf("copied %d bytes of new options, expected %d bytes", copied, len(newOpts)))
+ }
+ // Since in forwarding we handle all options, including copying those we
+ // do not recognise, the options region should remain the same size which
+ // simplifies processing. As we MAY receive a packet with a lot of padded
+ // bytes after the "end of options list" byte, make sure we copy
+ // them as the legal padding value (0).
+ for i := copied; i < len(opts); i++ {
+ // Pad with 0 (EOL). RFC 791 page 23 says "The padding is zero".
+ opts[i] = byte(header.IPv4OptionListEndType)
+ }
+ }
+
dstAddr := h.DestinationAddress()
// Check if the destination is owned by the stack.
- networkEndpoint, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr)
- if err == nil {
- networkEndpoint.(*endpoint).handlePacket(pkt)
+ if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil {
+ ep.handlePacket(pkt)
return nil
}
- if _, ok := err.(*tcpip.ErrBadAddress); !ok {
- return err
- }
r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */)
if err != nil {
@@ -619,8 +634,26 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
return
}
- // Loopback traffic skips the prerouting chain.
+ if !e.protocol.parse(pkt) {
+ stats.MalformedPacketsReceived.Increment()
+ return
+ }
+
if !e.nic.IsLoopback() {
+ if e.protocol.stack.HandleLocal() {
+ addressEndpoint := e.AcquireAssignedAddress(header.IPv4(pkt.NetworkHeader().View()).SourceAddress(), e.nic.Promiscuous(), stack.CanBePrimaryEndpoint)
+ if addressEndpoint != nil {
+ addressEndpoint.DecRef()
+
+ // The source address is one of our own, so we never should have gotten
+ // a packet like this unless HandleLocal is false or our NIC is the
+ // loopback interface.
+ stats.InvalidSourceAddressesReceived.Increment()
+ return
+ }
+ }
+
+ // Loopback traffic skips the prerouting chain.
inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok {
// iptables is telling us to drop the packet.
@@ -632,6 +665,21 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
e.handlePacket(pkt)
}
+// handleLocalPacket delivers a locally originated packet back into the
+// receive path. The packet is counted as received, cloned to an inbound
+// packet and parsed; a clone that fails to parse is counted as malformed and
+// dropped. canSkipRXChecksum is recorded on the clone as
+// RXTransportChecksumValidated before it is handed to handlePacket.
+func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum bool) {
+	ipStats := e.stats.ip
+	ipStats.PacketsReceived.Increment()
+
+	inbound := pkt.CloneToInbound()
+	if !e.protocol.parse(inbound) {
+		ipStats.MalformedPacketsReceived.Increment()
+		return
+	}
+
+	inbound.RXTransportChecksumValidated = canSkipRXChecksum
+	e.handlePacket(inbound)
+}
+
// handlePacket is like HandlePacket except it does not perform the prerouting
// iptables hook.
func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
@@ -691,8 +739,9 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
}
}
- // The destination address should be an address we own or a group we joined
- // for us to receive the packet. Otherwise, attempt to forward the packet.
+ // Before we do any processing, note if the packet was received as some
+ // sort of broadcast. The destination address should be an address we own
+ // or a group we joined.
if addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint); addressEndpoint != nil {
subnet := addressEndpoint.AddressWithPrefix().Subnet()
addressEndpoint.DecRef()
@@ -702,7 +751,6 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
stats.ip.InvalidDestinationAddressesReceived.Increment()
return
}
-
_ = e.forwardPacket(pkt)
return
}
@@ -724,6 +772,21 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
stats.ip.MalformedFragmentsReceived.Increment()
return
}
+ if opts := h.Options(); len(opts) != 0 {
+ // If there are options we need to check them before we do assembly
+ // or we could be assembling errant packets. However we do not change the
+ // options as that could lead to double processing later.
+ if _, optProblem := e.processIPOptions(pkt, opts, &optionUsageVerify{}); optProblem != nil {
+ if optProblem.NeedICMP {
+ _ = e.protocol.returnError(&icmpReasonParamProblem{
+ pointer: optProblem.Pointer,
+ }, pkt)
+ e.protocol.stack.Stats().MalformedRcvdPackets.Increment()
+ e.stats.ip.MalformedPacketsReceived.Increment()
+ }
+ return
+ }
+ }
// The packet is a fragment, let's try to reassemble it.
start := h.FragmentOffset()
// Drop the fragment if the size of the reassembled payload would exceed the
@@ -782,17 +845,10 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
e.handleICMP(pkt)
return
}
- if p == header.IGMPProtocolNumber {
- e.mu.Lock()
- e.mu.igmp.handleIGMP(pkt)
- e.mu.Unlock()
- return
- }
+ // ICMP handles options itself but do it here for all remaining destinations.
if opts := h.Options(); len(opts) != 0 {
- // TODO(gvisor.dev/issue/4586):
- // When we add forwarding support we should use the verified options
- // rather than just throwing them away.
- if _, optProblem := e.processIPOptions(pkt, opts, &optionUsageReceive{}); optProblem != nil {
+ newOpts, optProblem := e.processIPOptions(pkt, opts, &optionUsageReceive{})
+ if optProblem != nil {
if optProblem.NeedICMP {
_ = e.protocol.returnError(&icmpReasonParamProblem{
pointer: optProblem.Pointer,
@@ -802,6 +858,20 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
}
return
}
+ copied := copy(opts, newOpts)
+ if copied != len(newOpts) {
+ panic(fmt.Sprintf("copied %d bytes of new options, expected %d bytes", copied, len(newOpts)))
+ }
+ for i := copied; i < len(opts); i++ {
+ // Pad with 0 (EOL). RFC 791 page 23 says "The padding is zero".
+ opts[i] = byte(header.IPv4OptionListEndType)
+ }
+ }
+ if p == header.IGMPProtocolNumber {
+ e.mu.Lock()
+ e.mu.igmp.handleIGMP(pkt)
+ e.mu.Unlock()
+ return
}
switch res := e.dispatcher.DeliverTransportPacket(p, pkt); res {
@@ -1043,6 +1113,29 @@ func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
+// parse is like Parse but also attempts to parse the transport layer.
+//
+// Returns true if the network header was successfully parsed. The result only
+// reflects the network header: an unknown transport protocol or a transport
+// layer parse error still yields true.
+//
+// Panics if ParsePacketBufferTransport returns a result other than ParsedOK,
+// UnknownTransportProtocol or TransportLayerParseError.
+func (p *protocol) parse(pkt *stack.PacketBuffer) bool {
+ transProtoNum, hasTransportHdr, ok := p.Parse(pkt)
+ if !ok {
+ return false
+ }
+
+ if hasTransportHdr {
+ switch err := p.stack.ParsePacketBufferTransport(transProtoNum, pkt); err {
+ case stack.ParsedOK:
+ // Nothing more to do.
+ case stack.UnknownTransportProtocol, stack.TransportLayerParseError:
+ // The transport layer will handle unknown protocols and transport layer
+ // parsing errors.
+ default:
+ panic(fmt.Sprintf("unexpected error parsing transport header = %d", err))
+ }
+ }
+
+ return true
+}
+
// Parse implements stack.NetworkProtocol.Parse.
func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
if ok := parse.IPv4(pkt); !ok {
@@ -1211,23 +1304,49 @@ type optionsUsage interface {
actions() optionActions
}
-// optionUsageReceive implements optionsUsage for received packets.
-type optionUsageReceive struct{}
+// optionUsageVerify implements optionsUsage for when we just want to check
+// fragments. Don't change anything, just check and reject if bad. No
+// replacement options are generated.
+type optionUsageVerify struct{}
// actions implements optionsUsage.
-func (*optionUsageReceive) actions() optionActions {
+func (*optionUsageVerify) actions() optionActions {
return optionActions{
timestamp: optionVerify,
recordRoute: optionVerify,
+ unknown: optionRemove,
+ }
+}
+
+// optionUsageReceive implements optionsUsage for packets we will pass
+// to the transport layer (with the exception of Echo requests).
+type optionUsageReceive struct{}
+
+// actions implements optionsUsage.
+func (*optionUsageReceive) actions() optionActions {
+ return optionActions{
+ timestamp: optionProcess,
+ recordRoute: optionProcess,
unknown: optionPass,
}
}
-// TODO(gvisor.dev/issue/4586): Add an entry here for forwarding when it
-// is enabled (Process, Process, Pass) and for fragmenting (Process, Process,
-// Pass for frag1, but Remove,Remove,Remove for all other frags).
+// optionUsageForward implements optionsUsage for packets about to be forwarded.
+// All options are passed on regardless of whether we recognise them, however
+// we do process the Timestamp and Record Route options.
+type optionUsageForward struct{}
+
+// actions implements optionsUsage.
+//
+// It returns optionProcess for the timestamp and record route options and
+// optionPass for unknown options.
+func (*optionUsageForward) actions() optionActions {
+ return optionActions{
+ timestamp: optionProcess,
+ recordRoute: optionProcess,
+ unknown: optionPass,
+ }
+}
// optionUsageEcho implements optionsUsage for echo packet processing.
+// Only Timestamp and RecordRoute are processed and sent back.
type optionUsageEcho struct{}
// actions implements optionsUsage.
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index a296bed79..dc4db6e5f 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -34,8 +34,8 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
+ "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
- "gvisor.dev/gvisor/pkg/tcpip/network/testutil"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/raw"
@@ -111,10 +111,11 @@ func TestExcludeBroadcast(t *testing.T) {
func TestForwarding(t *testing.T) {
const (
- nicID1 = 1
- nicID2 = 2
- randomSequence = 123
- randomIdent = 42
+ nicID1 = 1
+ nicID2 = 2
+ randomSequence = 123
+ randomIdent = 42
+ randomTimeOffset = 0x10203040
)
ipv4Addr1 := tcpip.AddressWithPrefix{
@@ -129,14 +130,20 @@ func TestForwarding(t *testing.T) {
remoteIPv4Addr2 := tcpip.Address(net.ParseIP("11.0.0.2").To4())
tests := []struct {
- name string
- TTL uint8
- expectErrorICMP bool
+ name string
+ TTL uint8
+ expectErrorICMP bool
+ options header.IPv4Options
+ forwardedOptions header.IPv4Options
+ icmpType header.ICMPv4Type
+ icmpCode header.ICMPv4Code
}{
{
name: "TTL of zero",
TTL: 0,
expectErrorICMP: true,
+ icmpType: header.ICMPv4TimeExceeded,
+ icmpCode: header.ICMPv4TTLExceeded,
},
{
name: "TTL of one",
@@ -153,14 +160,78 @@ func TestForwarding(t *testing.T) {
TTL: math.MaxUint8,
expectErrorICMP: false,
},
+ {
+ name: "four EOL options",
+ TTL: 2,
+ expectErrorICMP: false,
+ options: header.IPv4Options{0, 0, 0, 0},
+ forwardedOptions: header.IPv4Options{0, 0, 0, 0},
+ },
+ {
+ name: "TS type 1 full",
+ TTL: 2,
+ options: header.IPv4Options{
+ 68, 12, 13, 0xF1,
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ },
+ expectErrorICMP: true,
+ icmpType: header.ICMPv4ParamProblem,
+ icmpCode: header.ICMPv4UnusedCode,
+ },
+ {
+ name: "TS type 0",
+ TTL: 2,
+ options: header.IPv4Options{
+ 68, 24, 21, 0x00,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 0, 0, 0, 0,
+ },
+ forwardedOptions: header.IPv4Options{
+ 68, 24, 25, 0x00,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock
+ },
+ },
+ {
+ name: "end of options list",
+ TTL: 2,
+ options: header.IPv4Options{
+ 68, 12, 13, 0x11,
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ 0, 10, 3, 99, // EOL followed by junk
+ 1, 2, 3, 4,
+ },
+ forwardedOptions: header.IPv4Options{
+ 68, 12, 13, 0x21,
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ 0, // End of Options hides following bytes.
+ 0, 0, 0, // 7 bytes unknown option removed.
+ 0, 0, 0, 0,
+ },
+ },
}
-
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
+ Clock: clock,
})
+
+ // Advance the clock by some unimportant amount to make
+ // it give a more recognisable signature than 00,00,00,00.
+ clock.Advance(time.Millisecond * randomTimeOffset)
+
// We expect at most a single packet in response to our ICMP Echo Request.
e1 := channel.New(1, ipv4.MaxTotalSize, "")
if err := s.CreateNIC(nicID1, e1); err != nil {
@@ -195,7 +266,11 @@ func TestForwarding(t *testing.T) {
t.Fatalf("SetForwarding(%d, true): %s", header.IPv4ProtocolNumber, err)
}
- totalLen := uint16(header.IPv4MinimumSize + header.ICMPv4MinimumSize)
+ ipHeaderLength := header.IPv4MinimumSize + len(test.options)
+ if ipHeaderLength > header.IPv4MaximumHeaderSize {
+ t.Fatalf("got ipHeaderLength = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize)
+ }
+ totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize)
hdr := buffer.NewPrependable(int(totalLen))
icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
icmp.SetIdent(randomIdent)
@@ -204,7 +279,7 @@ func TestForwarding(t *testing.T) {
icmp.SetCode(header.ICMPv4UnusedCode)
icmp.SetChecksum(0)
icmp.SetChecksum(^header.Checksum(icmp, 0))
- ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip := header.IPv4(hdr.Prepend(ipHeaderLength))
ip.Encode(&header.IPv4Fields{
TotalLength: totalLen,
Protocol: uint8(header.ICMPv4ProtocolNumber),
@@ -212,6 +287,14 @@ func TestForwarding(t *testing.T) {
SrcAddr: remoteIPv4Addr1,
DstAddr: remoteIPv4Addr2,
})
+ if len(test.options) != 0 {
+ ip.SetHeaderLength(uint8(ipHeaderLength))
+ // Copy options manually. We do not use Encode for options so we can
+ // verify malformed options with handcrafted payloads.
+ if got, want := copy(ip.Options(), test.options), len(test.options); got != want {
+ t.Fatalf("got copy(ip.Options(), test.options) = %d, want = %d", got, want)
+ }
+ }
ip.SetChecksum(0)
ip.SetChecksum(^ip.CalculateChecksum())
requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -222,7 +305,7 @@ func TestForwarding(t *testing.T) {
if test.expectErrorICMP {
reply, ok := e1.Read()
if !ok {
- t.Fatal("expected ICMP TTL Exceeded packet through incoming NIC")
+ t.Fatalf("expected ICMP packet type %d through incoming NIC", test.icmpType)
}
checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())),
@@ -231,8 +314,8 @@ func TestForwarding(t *testing.T) {
checker.TTL(ipv4.DefaultTTL),
checker.ICMPv4(
checker.ICMPv4Checksum(),
- checker.ICMPv4Type(header.ICMPv4TimeExceeded),
- checker.ICMPv4Code(header.ICMPv4TTLExceeded),
+ checker.ICMPv4Type(test.icmpType),
+ checker.ICMPv4Code(test.icmpCode),
checker.ICMPv4Payload([]byte(hdr.View())),
),
)
@@ -250,6 +333,7 @@ func TestForwarding(t *testing.T) {
checker.SrcAddr(remoteIPv4Addr1),
checker.DstAddr(remoteIPv4Addr2),
checker.TTL(test.TTL-1),
+ checker.IPv4Options(test.forwardedOptions),
checker.ICMPv4(
checker.ICMPv4Checksum(),
checker.ICMPv4Type(header.ICMPv4Echo),
@@ -279,6 +363,7 @@ func TestIPv4Sanity(t *testing.T) {
// (offset 1). For compatibility we must do the same. Use this constant
// to indicate where this happens.
pointerOffsetForInvalidLength = 0
+ randomTimeOffset = 0x10203040
)
var (
ipv4Addr = tcpip.AddressWithPrefix{
@@ -893,6 +978,10 @@ func TestIPv4Sanity(t *testing.T) {
TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
Clock: clock,
})
+ // Advance the clock by some unimportant amount to make
+ // it give a more recognisable signature than 00,00,00,00.
+ clock.Advance(time.Millisecond * randomTimeOffset)
+
// We expect at most a single packet in response to our ICMP Echo Request.
e := channel.New(1, ipv4.MaxTotalSize, "")
if err := s.CreateNIC(nicID, e); err != nil {
@@ -902,9 +991,6 @@ func TestIPv4Sanity(t *testing.T) {
if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil {
t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err)
}
- // Advance the clock by some unimportant amount to make
- // sure it's all set up.
- clock.Advance(time.Millisecond * 0x10203040)
// Default routes for IPv4 so ICMP can find a route to the remote
// node when attempting to send the ICMP Echo Reply.
@@ -1739,12 +1825,19 @@ func TestInvalidFragments(t *testing.T) {
ip := header.IPv4(hdr.Prepend(pktSize))
ip.Encode(&f.ipv4fields)
+ if want, got := len(f.payload), copy(ip[header.IPv4MinimumSize:], f.payload); want != got {
+ t.Fatalf("copied %d bytes, expected %d bytes.", got, want)
+ }
// Encode sets this up correctly. If we want a different value for
// testing then we need to overwrite the good value.
if f.overrideIHL != 0 {
ip.SetHeaderLength(uint8(f.overrideIHL))
+ // If we are asked to add options (type not specified) then pad
+ // with 0 (EOL). RFC 791 page 23 says "The padding is zero".
+ for i := header.IPv4MinimumSize; i < f.overrideIHL; i++ {
+ ip[i] = byte(header.IPv4OptionListEndType)
+ }
}
- copy(ip[header.IPv4MinimumSize:], f.payload)
if f.autoChecksum {
ip.SetChecksum(0)
diff --git a/pkg/tcpip/network/ipv4/stats.go b/pkg/tcpip/network/ipv4/stats.go
index bee72c649..5ae73fbfb 100644
--- a/pkg/tcpip/network/ipv4/stats.go
+++ b/pkg/tcpip/network/ipv4/stats.go
@@ -16,7 +16,7 @@ package ipv4
import (
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/network/ip"
+ "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
diff --git a/pkg/tcpip/network/ipv4/stats_test.go b/pkg/tcpip/network/ipv4/stats_test.go
index fbbc6e69c..a637f9d50 100644
--- a/pkg/tcpip/network/ipv4/stats_test.go
+++ b/pkg/tcpip/network/ipv4/stats_test.go
@@ -19,7 +19,7 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/network/testutil"
+ "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index 0c5f8d683..bb9a02ed0 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -19,9 +19,9 @@ go_library(
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/header/parse",
- "//pkg/tcpip/network/fragmentation",
"//pkg/tcpip/network/hash",
- "//pkg/tcpip/network/ip",
+ "//pkg/tcpip/network/internal/fragmentation",
+ "//pkg/tcpip/network/internal/ip",
"//pkg/tcpip/stack",
],
)
@@ -43,7 +43,7 @@ go_test(
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/sniffer",
- "//pkg/tcpip/network/testutil",
+ "//pkg/tcpip/network/internal/testutil",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/tcp",
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index dcfd93bab..2690644d6 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -148,6 +148,18 @@ func (*icmpv6PacketTooBigSockError) Kind() stack.TransportErrorKind {
return stack.PacketTooBigTransportError
}
+// checkLocalAddress reports whether addr is accepted as a local address of
+// this endpoint.
+//
+// It returns true unconditionally when the NIC has spoofing enabled.
+// Otherwise it returns true only if AcquireAssignedAddress finds an address
+// endpoint for addr.
+func (e *endpoint) checkLocalAddress(addr tcpip.Address) bool {
+ if e.nic.Spoofing() {
+ return true
+ }
+
+ if addressEndpoint := e.AcquireAssignedAddress(addr, false, stack.NeverPrimaryEndpoint); addressEndpoint != nil {
+ // Only existence matters here, so drop the acquired reference right away.
+ addressEndpoint.DecRef()
+ return true
+ }
+ return false
+}
+
// handleControl handles the case when an ICMP packet contains the headers of
// the original packet that caused the ICMP one to be sent. This information is
// used to find out which transport endpoint must be notified about the ICMP
@@ -165,8 +177,8 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe
//
// Drop packet if it doesn't have the basic IPv6 header or if the
// original source address doesn't match an address we own.
- src := hdr.SourceAddress()
- if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, src) == 0 {
+ srcAddr := hdr.SourceAddress()
+ if !e.checkLocalAddress(srcAddr) {
return
}
@@ -192,7 +204,7 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe
p = fragHdr.TransportProtocol()
}
- e.dispatcher.DeliverTransportError(src, hdr.DestinationAddress(), ProtocolNumber, p, transErr, pkt)
+ e.dispatcher.DeliverTransportError(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, transErr, pkt)
}
// getLinkAddrOption searches NDP options for a given link address option using
@@ -377,7 +389,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
// section 5.4.3.
// Is the NS targeting us?
- if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, targetAddr) == 0 {
+ if !e.checkLocalAddress(targetAddr) {
return
}
@@ -525,6 +537,11 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
// NDP datagrams are very small and ToView() will not incur allocations.
na := header.NDPNeighborAdvert(payload.ToView())
targetAddr := na.TargetAddress()
+
+ e.dad.mu.Lock()
+ e.dad.mu.dad.StopLocked(targetAddr, false /* aborted */)
+ e.dad.mu.Unlock()
+
if e.hasTentativeAddr(targetAddr) {
// We just got an NA from a node that owns an address we are performing
// DAD on, implying the address is not unique. In this case we let the
@@ -854,37 +871,9 @@ func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
return &tcpip.ErrBadLocalAddress{}
}
- optsSerializer := header.NDPOptionsSerializer{
+ return e.sendNDPNS(localAddr, remoteAddr, targetAddr, remoteLinkAddr, header.NDPOptionsSerializer{
header.NDPSourceLinkLayerAddressOption(e.nic.LinkAddress()),
- }
- neighborSolicitSize := header.ICMPv6NeighborSolicitMinimumSize + optsSerializer.Length()
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.IPv6FixedHeaderSize + neighborSolicitSize,
})
- pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
- packet := header.ICMPv6(pkt.TransportHeader().Push(neighborSolicitSize))
- packet.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(packet.MessageBody())
- ns.SetTargetAddress(targetAddr)
- ns.Options().Serialize(optsSerializer)
- packet.SetChecksum(header.ICMPv6Checksum(packet, localAddr, remoteAddr, buffer.VectorisedView{}))
-
- if err := addIPHeader(localAddr, remoteAddr, pkt, stack.NetworkHeaderParams{
- Protocol: header.ICMPv6ProtocolNumber,
- TTL: header.NDPHopLimit,
- }, header.IPv6ExtHdrSerializer{}); err != nil {
- panic(fmt.Sprintf("failed to add IP header: %s", err))
- }
-
- stat := e.stats.icmp.packetsSent
-
- if err := e.nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil {
- stat.dropped.Increment()
- return err
- }
-
- stat.neighborSolicit.Increment()
- return nil
}
// ResolveStaticAddress implements stack.LinkAddressResolver.
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 92f9ee2c2..69c1e4bea 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -124,6 +124,10 @@ func (*testInterface) Promiscuous() bool {
return false
}
+// Spoofing always returns false for testInterface.
+func (*testInterface) Spoofing() bool {
+ return false
+}
+
func (t *testInterface) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
return t.LinkEndpoint.WritePacket(r.Fields(), gso, protocol, pkt)
}
@@ -149,185 +153,31 @@ func (t *testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber,
return nil
}
-func TestICMPCounts(t *testing.T) {
- tests := []struct {
- name string
- useNeighborCache bool
- }{
- {
- name: "linkAddrCache",
- useNeighborCache: false,
- },
- {
- name: "neighborCache",
- useNeighborCache: true,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
- UseNeighborCache: test.useNeighborCache,
- })
- if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
- t.Fatalf("CreateNIC(_, _) = %s", err)
- }
- {
- subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable(
- []tcpip.Route{{
- Destination: subnet,
- NIC: nicID,
- }},
- )
- }
-
- netProto := s.NetworkProtocolInstance(ProtocolNumber)
- if netProto == nil {
- t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
- }
- ep := netProto.NewEndpoint(&testInterface{}, &stubDispatcher{})
- defer ep.Close()
-
- if err := ep.Enable(); err != nil {
- t.Fatalf("ep.Enable(): %s", err)
- }
-
- addressableEndpoint, ok := ep.(stack.AddressableEndpoint)
- if !ok {
- t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint")
- }
- addr := lladdr0.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
- } else {
- ep.DecRef()
- }
-
- var tllData [header.NDPLinkLayerAddressSize]byte
- header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
- header.NDPTargetLinkLayerAddressOption(linkAddr1),
- })
-
- types := []struct {
- typ header.ICMPv6Type
- size int
- extraData []byte
- }{
- {
- typ: header.ICMPv6DstUnreachable,
- size: header.ICMPv6DstUnreachableMinimumSize,
- },
- {
- typ: header.ICMPv6PacketTooBig,
- size: header.ICMPv6PacketTooBigMinimumSize,
- },
- {
- typ: header.ICMPv6TimeExceeded,
- size: header.ICMPv6MinimumSize,
- },
- {
- typ: header.ICMPv6ParamProblem,
- size: header.ICMPv6MinimumSize,
- },
- {
- typ: header.ICMPv6EchoRequest,
- size: header.ICMPv6EchoMinimumSize,
- },
- {
- typ: header.ICMPv6EchoReply,
- size: header.ICMPv6EchoMinimumSize,
- },
- {
- typ: header.ICMPv6RouterSolicit,
- size: header.ICMPv6MinimumSize,
- },
- {
- typ: header.ICMPv6RouterAdvert,
- size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
- },
- {
- typ: header.ICMPv6NeighborSolicit,
- size: header.ICMPv6NeighborSolicitMinimumSize,
- },
- {
- typ: header.ICMPv6NeighborAdvert,
- size: header.ICMPv6NeighborAdvertMinimumSize,
- extraData: tllData[:],
- },
- {
- typ: header.ICMPv6RedirectMsg,
- size: header.ICMPv6MinimumSize,
- },
- {
- typ: header.ICMPv6MulticastListenerQuery,
- size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
- },
- {
- typ: header.ICMPv6MulticastListenerReport,
- size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
- },
- {
- typ: header.ICMPv6MulticastListenerDone,
- size: header.MLDMinimumSize + header.ICMPv6HeaderSize,
- },
- {
- typ: 255, /* Unrecognized */
- size: 50,
- },
- }
-
- handleIPv6Payload := func(icmp header.ICMPv6) {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.IPv6MinimumSize,
- Data: buffer.View(icmp).ToVectorisedView(),
- })
- ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(icmp)),
- TransportProtocol: header.ICMPv6ProtocolNumber,
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
- })
- ep.HandlePacket(pkt)
- }
-
- for _, typ := range types {
- icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
- copy(icmp[typ.size:], typ.extraData)
- icmp.SetType(typ.typ)
- icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView()))
- handleIPv6Payload(icmp)
- }
-
- // Construct an empty ICMP packet so that
- // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented.
- handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize)))
+// CheckLocalAddress always returns false for testInterface; both arguments
+// are ignored.
+func (*testInterface) CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool {
+ return false
+}
- icmpv6Stats := s.Stats().ICMP.V6.PacketsReceived
- visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) {
- if got, want := s.Value(), uint64(1); got != want {
- t.Errorf("got %s = %d, want = %d", name, got, want)
- }
- })
- if t.Failed() {
- t.Logf("stats:\n%+v", s.Stats())
- }
- })
- }
+// handleICMPInIPv6 prepends an IPv6 header to icmp and passes the resulting
+// packet to ep.HandlePacket.
+//
+// The header is header.IPv6MinimumSize bytes long and is encoded with src and
+// dst as its addresses, len(icmp) as the payload length, ICMPv6 as the
+// transport protocol and header.NDPHopLimit as the hop limit.
+func handleICMPInIPv6(ep stack.NetworkEndpoint, src, dst tcpip.Address, icmp header.ICMPv6) {
+ ip := buffer.NewView(header.IPv6MinimumSize)
+ header.IPv6(ip).Encode(&header.IPv6Fields{
+ PayloadLength: uint16(len(icmp)),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: src,
+ DstAddr: dst,
+ })
+ // The IPv6 header and the ICMPv6 message are both carried in the packet's
+ // data: header view first, ICMPv6 view appended after it.
+ vv := ip.ToVectorisedView()
+ vv.AppendView(buffer.View(icmp))
+ ep.HandlePacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.IPv6MinimumSize,
+ Data: vv,
+ }))
}
-func TestICMPCountsWithNeighborCache(t *testing.T) {
+func TestICMPCounts(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
- UseNeighborCache: true,
})
if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
t.Fatalf("CreateNIC(_, _) = %s", err)
@@ -440,33 +290,17 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
},
}
- handleIPv6Payload := func(icmp header.ICMPv6) {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.IPv6MinimumSize,
- Data: buffer.View(icmp).ToVectorisedView(),
- })
- ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(icmp)),
- TransportProtocol: header.ICMPv6ProtocolNumber,
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
- })
- ep.HandlePacket(pkt)
- }
-
for _, typ := range types {
icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
copy(icmp[typ.size:], typ.extraData)
icmp.SetType(typ.typ)
icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView()))
- handleIPv6Payload(icmp)
+ handleICMPInIPv6(ep, lladdr1, lladdr0, icmp)
}
// Construct an empty ICMP packet so that
// Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented.
- handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize)))
+ handleICMPInIPv6(ep, lladdr1, lladdr0, header.ICMPv6(buffer.NewView(header.IPv6MinimumSize)))
icmpv6Stats := s.Stats().ICMP.V6.PacketsReceived
visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) {
@@ -777,135 +611,116 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
},
}
- tests := []struct {
- name string
- useNeighborCache bool
- }{
- {
- name: "linkAddrCache",
- useNeighborCache: false,
- },
- {
- name: "neighborCache",
- useNeighborCache: true,
- },
- }
+ for _, typ := range types {
+ for _, isRouter := range []bool{false, true} {
+ name := typ.name
+ if isRouter {
+ name += " (Router)"
+ }
+ t.Run(name, func(t *testing.T) {
+ e := channel.New(0, 1280, linkAddr0)
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- for _, typ := range types {
- for _, isRouter := range []bool{false, true} {
- name := typ.name
- if isRouter {
- name += " (Router)"
+ // Indicate that resolution for link layer addresses is required to
+ // send packets over this link. This is needed so the NIC knows to
+ // allocate a neighbor table.
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ })
+ if isRouter {
+ // Enabling forwarding makes the stack act as a router.
+ s.SetForwarding(ProtocolNumber, true)
+ }
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(_, _) = %s", err)
+ }
+
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ }
+ {
+ subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
+ if err != nil {
+ t.Fatal(err)
}
- t.Run(name, func(t *testing.T) {
- e := channel.New(0, 1280, linkAddr0)
-
- // Indicate that resolution for link layer addresses is required to
- // send packets over this link. This is needed so the NIC knows to
- // allocate a neighbor table.
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- UseNeighborCache: test.useNeighborCache,
- })
- if isRouter {
- // Enabling forwarding makes the stack act as a router.
- s.SetForwarding(ProtocolNumber, true)
- }
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(_, _) = %s", err)
- }
-
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
- }
- {
- subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable(
- []tcpip.Route{{
- Destination: subnet,
- NIC: nicID,
- }},
- )
- }
-
- handleIPv6Payload := func(checksum bool) {
- icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
- copy(icmp[typ.size:], typ.extraData)
- icmp.SetType(typ.typ)
- if checksum {
- icmp.SetChecksum(header.ICMPv6Checksum(icmp, lladdr1, lladdr0, buffer.View{}.ToVectorisedView()))
- }
- ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(icmp)),
- TransportProtocol: header.ICMPv6ProtocolNumber,
- HopLimit: header.NDPHopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
- })
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}),
- })
- e.InjectInbound(ProtocolNumber, pkt)
- }
-
- stats := s.Stats().ICMP.V6.PacketsReceived
- invalid := stats.Invalid
- routerOnly := stats.RouterOnlyPacketsDroppedByHost
- typStat := typ.statCounter(stats)
-
- // Initial stat counts should be 0.
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
- if got := routerOnly.Value(); got != 0 {
- t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got)
- }
- if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
- }
-
- // Without setting checksum, the incoming packet should
- // be invalid.
- handleIPv6Payload(false)
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
- // Router only count should not have increased.
- if got := routerOnly.Value(); got != 0 {
- t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got)
- }
- // Rx count of type typ.typ should not have increased.
- if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
- }
-
- // When checksum is set, it should be received.
- handleIPv6Payload(true)
- if got := typStat.Value(); got != 1 {
- t.Fatalf("got %s = %d, want = 1", typ.name, got)
- }
- // Invalid count should not have increased again.
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
- if !isRouter && typ.routerOnly && test.useNeighborCache {
- // Router only count should have increased.
- if got := routerOnly.Value(); got != 1 {
- t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 1", got)
- }
- }
+ s.SetRouteTable(
+ []tcpip.Route{{
+ Destination: subnet,
+ NIC: nicID,
+ }},
+ )
+ }
+
+ handleIPv6Payload := func(checksum bool) {
+ icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
+ copy(icmp[typ.size:], typ.extraData)
+ icmp.SetType(typ.typ)
+ if checksum {
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, lladdr1, lladdr0, buffer.View{}.ToVectorisedView()))
+ }
+ ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(len(icmp)),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ })
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}),
})
+ e.InjectInbound(ProtocolNumber, pkt)
}
- }
- })
+
+ stats := s.Stats().ICMP.V6.PacketsReceived
+ invalid := stats.Invalid
+ routerOnly := stats.RouterOnlyPacketsDroppedByHost
+ typStat := typ.statCounter(stats)
+
+ // Initial stat counts should be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+ if got := routerOnly.Value(); got != 0 {
+ t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got)
+ }
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // Without setting checksum, the incoming packet should
+ // be invalid.
+ handleIPv6Payload(false)
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ // Router only count should not have increased.
+ if got := routerOnly.Value(); got != 0 {
+ t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got)
+ }
+ // Rx count of type typ.typ should not have increased.
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // When checksum is set, it should be received.
+ handleIPv6Payload(true)
+ if got := typStat.Value(); got != 1 {
+ t.Fatalf("got %s = %d, want = 1", typ.name, got)
+ }
+ // Invalid count should not have increased again.
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ if !isRouter && typ.routerOnly {
+ // Router only count should have increased.
+ if got := routerOnly.Value(); got != 1 {
+ t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 1", got)
+ }
+ }
+ })
+ }
}
}
@@ -1769,7 +1584,6 @@ func TestCallsToNeighborCache(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
- UseNeighborCache: true,
})
{
if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
@@ -1818,19 +1632,7 @@ func TestCallsToNeighborCache(t *testing.T) {
icmp := test.createPacket()
icmp.SetChecksum(header.ICMPv6Checksum(icmp, test.source, test.destination, buffer.VectorisedView{}))
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.IPv6MinimumSize,
- Data: buffer.View(icmp).ToVectorisedView(),
- })
- ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(icmp)),
- TransportProtocol: header.ICMPv6ProtocolNumber,
- HopLimit: header.NDPHopLimit,
- SrcAddr: test.source,
- DstAddr: test.destination,
- })
- ep.HandlePacket(pkt)
+ handleICMPInIPv6(ep, test.source, test.destination, icmp)
// Confirm the endpoint calls the correct NUDHandler method.
if testInterface.probeCount != test.wantProbeCount {
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index c2e8c3ea7..c5c3ef882 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -30,8 +30,9 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/header/parse"
- "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation"
"gvisor.dev/gvisor/pkg/tcpip/network/hash"
+ "gvisor.dev/gvisor/pkg/tcpip/network/internal/fragmentation"
+ "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -164,6 +165,7 @@ func getLabel(addr tcpip.Address) uint8 {
panic(fmt.Sprintf("should have a label for address = %s", addr))
}
+var _ stack.DuplicateAddressDetector = (*endpoint)(nil)
var _ stack.LinkAddressResolver = (*endpoint)(nil)
var _ stack.LinkResolvableNetworkEndpoint = (*endpoint)(nil)
var _ stack.GroupAddressableEndpoint = (*endpoint)(nil)
@@ -192,6 +194,23 @@ type endpoint struct {
ndp ndpState
mld mldState
}
+
+ // dad is used to check if an arbitrary address is already assigned to some
+ // neighbor.
+ //
+ // Note: this is different from mu.ndp.dad which is used to perform DAD for
+ // addresses that are assigned to the interface. Removing an address aborts
+ // DAD; if we had used the same state, handlers for a removed address would
+ // not be called with the actual DAD result.
+ //
+ // LOCK ORDERING: mu > dad.mu.
+ dad struct {
+ mu struct {
+ sync.Mutex
+
+ dad ip.DAD
+ }
+ }
}
// NICNameFromID is a function that returns a stable name for the specified NIC,
@@ -226,6 +245,29 @@ type OpaqueInterfaceIdentifierOptions struct {
SecretKey []byte
}
+// CheckDuplicateAddress implements stack.DuplicateAddressDetector.
+func (e *endpoint) CheckDuplicateAddress(addr tcpip.Address, h stack.DADCompletionHandler) stack.DADCheckAddressDisposition {
+ e.dad.mu.Lock()
+ defer e.dad.mu.Unlock()
+ return e.dad.mu.dad.CheckDuplicateAddressLocked(addr, h)
+}
+
+// SetDADConfigurations implements stack.DuplicateAddressDetector.
+func (e *endpoint) SetDADConfigurations(c stack.DADConfigurations) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.dad.mu.Lock()
+ defer e.dad.mu.Unlock()
+
+ e.mu.ndp.dad.SetConfigsLocked(c)
+ e.dad.mu.dad.SetConfigsLocked(c)
+}
+
+// DuplicateAddressProtocol implements stack.DuplicateAddressDetector.
+func (*endpoint) DuplicateAddressProtocol() tcpip.NetworkProtocolNumber {
+ return ProtocolNumber
+}
+
// HandleLinkResolutionFailure implements stack.LinkResolvableNetworkEndpoint.
func (e *endpoint) HandleLinkResolutionFailure(pkt *stack.PacketBuffer) {
// handleControl expects the entire offending packet to be in the packet
@@ -321,7 +363,7 @@ func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address) tcpip.Error {
// If the address is a SLAAC address, do not invalidate its SLAAC prefix as an
// attempt will be made to generate a new address for it.
- if err := e.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */); err != nil {
+ if err := e.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */, true /* dadFailure */); err != nil {
return err
}
@@ -525,7 +567,7 @@ func (e *endpoint) stopDADForPermanentAddressesLocked() {
addr := addressEndpoint.AddressWithPrefix().Address
if header.IsV6UnicastAddress(addr) {
- e.mu.ndp.stopDuplicateAddressDetection(addr)
+ e.mu.ndp.stopDuplicateAddressDetection(addr, false /* failed */)
}
return true
@@ -648,14 +690,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// short circuits broadcasts before they are sent out to other hosts.
if pkt.NatDone {
netHeader := header.IPv6(pkt.NetworkHeader().View())
- if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
- pkt := pkt.CloneToInbound()
- if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- // Since we rewrote the packet but it is being routed back to us, we can
- // safely assume the checksum is valid.
- pkt.RXTransportChecksumValidated = true
- ep.(*endpoint).handlePacket(pkt)
- }
+ if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil {
+ // Since we rewrote the packet but it is being routed back to us, we
+ // can safely assume the checksum is valid.
+ ep.handleLocalPacket(pkt, true /* canSkipRXChecksum */)
return nil
}
}
@@ -665,14 +703,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber, headerIncluded bool) tcpip.Error {
if r.Loop&stack.PacketLoop != 0 {
- pkt := pkt.CloneToInbound()
- if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- // If the packet was generated by the stack (not a raw/packet endpoint
- // where a packet may be written with the header included), then we can
- // safely assume the checksum is valid.
- pkt.RXTransportChecksumValidated = !headerIncluded
- e.handlePacket(pkt)
- }
+ // If the packet was generated by the stack (not a raw/packet endpoint
+ // where a packet may be written with the header included), then we can
+ // safely assume the checksum is valid.
+ e.handleLocalPacket(pkt, !headerIncluded /* canSkipRXChecksum */)
}
if r.Loop&stack.PacketOut == 0 {
return nil
@@ -750,52 +784,36 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// generated.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "" /* inNicName */, outNicName)
- if len(dropped) == 0 && len(natPkts) == 0 {
- // Fast path: If no packets are to be dropped then we can just invoke the
- // faster WritePackets API directly.
- n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
- stats.PacketsSent.IncrementBy(uint64(n))
- if err != nil {
- stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n))
- }
- return n, err
- }
stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
+ for pkt := range dropped {
+ pkts.Remove(pkt)
+ }
- // Slow path as we are dropping some packets in the batch degrade to
- // emitting one packet at a time.
- n := 0
- for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- if _, ok := dropped[pkt]; ok {
+ // The NAT-ed packets may now be destined for us.
+ locallyDelivered := 0
+ for pkt := range natPkts {
+ ep := e.protocol.findEndpointWithAddress(header.IPv6(pkt.NetworkHeader().View()).DestinationAddress())
+ if ep == nil {
+ // The NAT-ed packet is still destined for some remote node.
continue
}
- if _, ok := natPkts[pkt]; ok {
- netHeader := header.IPv6(pkt.NetworkHeader().View())
- if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
- pkt := pkt.CloneToInbound()
- if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
- // Since we rewrote the packet but it is being routed back to us, we
- // can safely assume the checksum is valid.
- pkt.RXTransportChecksumValidated = true
- ep.(*endpoint).handlePacket(pkt)
- }
- n++
- continue
- }
- }
- if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
- stats.PacketsSent.IncrementBy(uint64(n))
- stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n + len(dropped)))
- // Dropped packets aren't errors, so include them in
- // the return value.
- return n + len(dropped), err
- }
- n++
+
+ // Do not send the locally destined packet out the NIC.
+ pkts.Remove(pkt)
+
+ // Deliver the packet locally.
+ ep.handleLocalPacket(pkt, true /* canSkipRXChecksum */)
+ locallyDelivered++
}
- stats.PacketsSent.IncrementBy(uint64(n))
+ // The rest of the packets can be delivered to the NIC as a batch.
+ pktsLen := pkts.Len()
+ written, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
+ stats.PacketsSent.IncrementBy(uint64(written))
+ stats.OutgoingPacketErrors.IncrementBy(uint64(pktsLen - written))
+
// Dropped packets aren't errors, so include them in the return value.
- return n + len(dropped), nil
+ return locallyDelivered + written + len(dropped), err
}
// WriteHeaderIncludedPacket implements stack.NetworkEndpoint.
@@ -852,14 +870,11 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
dstAddr := h.DestinationAddress()
// Check if the destination is owned by the stack.
- networkEndpoint, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr)
- if err == nil {
- networkEndpoint.(*endpoint).handlePacket(pkt)
+
+ if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil {
+ ep.handlePacket(pkt)
return nil
}
- if _, ok := err.(*tcpip.ErrBadAddress); !ok {
- return err
- }
r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */)
if err != nil {
@@ -896,8 +911,26 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
return
}
- // Loopback traffic skips the prerouting chain.
+ if !e.protocol.parse(pkt) {
+ stats.MalformedPacketsReceived.Increment()
+ return
+ }
+
if !e.nic.IsLoopback() {
+ if e.protocol.stack.HandleLocal() {
+ addressEndpoint := e.AcquireAssignedAddress(header.IPv6(pkt.NetworkHeader().View()).SourceAddress(), e.nic.Promiscuous(), stack.CanBePrimaryEndpoint)
+ if addressEndpoint != nil {
+ addressEndpoint.DecRef()
+
+ // The source address is one of our own, so we never should have gotten
+ // a packet like this unless HandleLocal is false or our NIC is the
+ // loopback interface.
+ stats.InvalidSourceAddressesReceived.Increment()
+ return
+ }
+ }
+
+ // Loopback traffic skips the prerouting chain.
inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok {
// iptables is telling us to drop the packet.
@@ -909,6 +942,21 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
e.handlePacket(pkt)
}
+func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum bool) {
+ stats := e.stats.ip
+
+ stats.PacketsReceived.Increment()
+
+ pkt = pkt.CloneToInbound()
+ if e.protocol.parse(pkt) {
+ pkt.RXTransportChecksumValidated = canSkipRXChecksum
+ e.handlePacket(pkt)
+ return
+ }
+
+ stats.MalformedPacketsReceived.Increment()
+}
+
// handlePacket is like HandlePacket except it does not perform the prerouting
// iptables hook.
func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
@@ -1384,18 +1432,18 @@ func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error {
return &tcpip.ErrBadLocalAddress{}
}
- return e.removePermanentEndpointLocked(addressEndpoint, true)
+ return e.removePermanentEndpointLocked(addressEndpoint, true /* allowSLAACInvalidation */, false /* dadFailure */)
}
// removePermanentEndpointLocked is like removePermanentAddressLocked except
// it works with a stack.AddressEndpoint.
//
// Precondition: e.mu must be write locked.
-func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEndpoint, allowSLAACInvalidation bool) tcpip.Error {
+func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEndpoint, allowSLAACInvalidation, dadFailure bool) tcpip.Error {
addr := addressEndpoint.AddressWithPrefix()
unicast := header.IsV6UnicastAddress(addr.Address)
if unicast {
- e.mu.ndp.stopDuplicateAddressDetection(addr.Address)
+ e.mu.ndp.stopDuplicateAddressDetection(addr.Address, dadFailure)
// If we are removing an address generated via SLAAC, cleanup
// its SLAAC resources and notify the integrator.
@@ -1741,6 +1789,13 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.Tran
e.mu.addressableEndpointState.Init(e)
e.mu.ndp.init(e)
e.mu.mld.init(e)
+ e.dad.mu.Lock()
+ e.dad.mu.dad.Init(&e.dad.mu, p.options.DADConfigs, ip.DADOptions{
+ Clock: p.stack.Clock(),
+ Protocol: &e.mu.ndp,
+ NICID: nic.ID(),
+ })
+ e.dad.mu.Unlock()
e.mu.Unlock()
stackStats := p.stack.Stats()
@@ -1754,6 +1809,20 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.Tran
return e
}
+func (p *protocol) findEndpointWithAddress(addr tcpip.Address) *endpoint {
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+
+ for _, e := range p.mu.eps {
+ if addressEndpoint := e.AcquireAssignedAddress(addr, false /* allowTemp */, stack.NeverPrimaryEndpoint); addressEndpoint != nil {
+ addressEndpoint.DecRef()
+ return e
+ }
+ }
+
+ return nil
+}
+
func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
p.mu.Lock()
defer p.mu.Unlock()
@@ -1798,6 +1867,29 @@ func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
+// parse is like Parse but also attempts to parse the transport layer.
+//
+// Returns true if the network header was successfully parsed.
+func (p *protocol) parse(pkt *stack.PacketBuffer) bool {
+ transProtoNum, hasTransportHdr, ok := p.Parse(pkt)
+ if !ok {
+ return false
+ }
+
+ if hasTransportHdr {
+ switch err := p.stack.ParsePacketBufferTransport(transProtoNum, pkt); err {
+ case stack.ParsedOK:
+ case stack.UnknownTransportProtocol, stack.TransportLayerParseError:
+ // The transport layer will handle unknown protocols and transport layer
+ // parsing errors.
+ default:
+ panic(fmt.Sprintf("unexpected error parsing transport header = %d", err))
+ }
+ }
+
+ return true
+}
+
// Parse implements stack.NetworkProtocol.Parse.
func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
proto, _, fragOffset, fragMore, ok := parse.IPv6(pkt)
@@ -1906,6 +1998,9 @@ type Options struct {
// MLD holds options for MLD.
MLD MLDOptions
+
+ // DADConfigs holds the default DAD configurations used by IPv6 endpoints.
+ DADConfigs stack.DADConfigurations
}
// NewProtocolWithOptions returns an IPv6 network protocol.
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 1c6c37c91..7e714b50e 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -31,7 +31,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/network/testutil"
+ "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go
index 2cc0dfebd..205e36cdd 100644
--- a/pkg/tcpip/network/ipv6/mld.go
+++ b/pkg/tcpip/network/ipv6/mld.go
@@ -21,7 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/network/ip"
+ "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go
index f6ffa7133..fe39555e0 100644
--- a/pkg/tcpip/network/ipv6/mld_test.go
+++ b/pkg/tcpip/network/ipv6/mld_test.go
@@ -126,7 +126,7 @@ func TestSendQueuedMLDReports(t *testing.T) {
clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
+ DADConfigs: stack.DADConfigurations{
DupAddrDetectTransmits: test.dadTransmits,
RetransmitTimer: test.retransmitTimer,
},
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index d7dde1767..53c043dcd 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -16,7 +16,6 @@ package ipv6
import (
"fmt"
- "log"
"math/rand"
"time"
@@ -24,34 +23,11 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
const (
- // defaultRetransmitTimer is the default amount of time to wait between
- // sending reachability probes.
- //
- // Default taken from RETRANS_TIMER of RFC 4861 section 10.
- defaultRetransmitTimer = time.Second
-
- // minimumRetransmitTimer is the minimum amount of time to wait between
- // sending reachability probes.
- //
- // Note, RFC 4861 does not impose a minimum Retransmit Timer, but we do here
- // to make sure the messages are not sent all at once. We also come to this
- // value because in the RetransmitTimer field of a Router Advertisement, a
- // value of 0 means unspecified, so the smallest valid value is 1. Note, the
- // unit of the RetransmitTimer field in the Router Advertisement is
- // milliseconds.
- minimumRetransmitTimer = time.Millisecond
-
- // defaultDupAddrDetectTransmits is the default number of NDP Neighbor
- // Solicitation messages to send when doing Duplicate Address Detection
- // for a tentative address.
- //
- // Default = 1 (from RFC 4862 section 5.1)
- defaultDupAddrDetectTransmits = 1
-
// defaultMaxRtrSolicitations is the default number of Router
// Solicitation messages to send when an IPv6 endpoint becomes enabled.
//
@@ -331,18 +307,6 @@ type NDPDispatcher interface {
// NDPConfigurations is the NDP configurations for the netstack.
type NDPConfigurations struct {
- // The number of Neighbor Solicitation messages to send when doing
- // Duplicate Address Detection for a tentative address.
- //
- // Note, a value of zero effectively disables DAD.
- DupAddrDetectTransmits uint8
-
- // The amount of time to wait between sending Neighbor solicitation
- // messages.
- //
- // Must be greater than or equal to 1ms.
- RetransmitTimer time.Duration
-
// The number of Router Solicitation messages to send when the IPv6 endpoint
// becomes enabled.
MaxRtrSolicitations uint8
@@ -414,8 +378,6 @@ type NDPConfigurations struct {
// default values.
func DefaultNDPConfigurations() NDPConfigurations {
return NDPConfigurations{
- DupAddrDetectTransmits: defaultDupAddrDetectTransmits,
- RetransmitTimer: defaultRetransmitTimer,
MaxRtrSolicitations: defaultMaxRtrSolicitations,
RtrSolicitationInterval: defaultRtrSolicitationInterval,
MaxRtrSolicitationDelay: defaultMaxRtrSolicitationDelay,
@@ -433,10 +395,6 @@ func DefaultNDPConfigurations() NDPConfigurations {
// validate modifies an NDPConfigurations with valid values. If invalid values
// are present in c, the corresponding default values are used instead.
func (c *NDPConfigurations) validate() {
- if c.RetransmitTimer < minimumRetransmitTimer {
- c.RetransmitTimer = defaultRetransmitTimer
- }
-
if c.RtrSolicitationInterval < minimumRtrSolicitationInterval {
c.RtrSolicitationInterval = defaultRtrSolicitationInterval
}
@@ -458,7 +416,14 @@ func (c *NDPConfigurations) validate() {
}
}
-// ndpState is the per-interface NDP state.
+type timer struct {
+ // done indicates to the timer that the timer was stopped.
+ done *bool
+
+ timer tcpip.Timer
+}
+
+// ndpState is the per-Interface NDP state.
type ndpState struct {
// Do not allow overwriting this state.
_ sync.NoCopy
@@ -469,14 +434,17 @@ type ndpState struct {
// configs is the per-interface NDP configurations.
configs NDPConfigurations
- // The DAD state to send the next NS message, or resolve the address.
- dad map[tcpip.Address]dadState
+ // The DAD timers to send the next NS message, or resolve the address.
+ dad ip.DAD
// The default routers discovered through Router Advertisements.
defaultRouters map[tcpip.Address]defaultRouterState
- // The job used to send the next router solicitation message.
- rtrSolicitJob *tcpip.Job
+ // rtrSolicitTimer is the timer used to send the next router solicitation
+ // message.
+ //
+ // rtrSolicitTimer is the zero value when NDP is not soliciting routers.
+ rtrSolicitTimer timer
// The on-link prefixes discovered through Router Advertisements' Prefix
// Information option.
@@ -498,19 +466,6 @@ type ndpState struct {
temporaryAddressDesyncFactor time.Duration
}
-// dadState holds the Duplicate Address Detection timer and channel to signal
-// to the DAD goroutine that DAD should stop.
-type dadState struct {
- // The DAD timer to send the next NS message, or resolve the address.
- job *tcpip.Job
-
- // Used to let the DAD timer know that it has been stopped.
- //
- // Must only be read from or written to while protected by the lock of
- // the IPv6 endpoint this dadState is associated with.
- done *bool
-}
-
// defaultRouterState holds data associated with a default router discovered by
// a Router Advertisement (RA).
type defaultRouterState struct {
@@ -625,125 +580,47 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.ep.nic.ID()))
}
- // Should not attempt to perform DAD on an address that is currently in the
- // DAD process.
- if _, ok := ndp.dad[addr]; ok {
- // Should never happen because we should only ever call this function for
- // newly created addresses. If we attemped to "add" an address that already
- // existed, we would get an error since we attempted to add a duplicate
- // address, or its reference count would have been increased without doing
- // the work that would have been done for an address that was brand new.
- // See endpoint.addAddressLocked.
- panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.ep.nic.ID()))
- }
+ ret := ndp.dad.CheckDuplicateAddressLocked(addr, func(r stack.DADResult) {
+ if addressEndpoint.GetKind() != stack.PermanentTentative {
+ // The endpoint should still be marked as tentative since we are still
+ // performing DAD on it.
+ panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID()))
+ }
- remaining := ndp.configs.DupAddrDetectTransmits
- if remaining == 0 {
- addressEndpoint.SetKind(stack.Permanent)
+ if r.Resolved {
+ addressEndpoint.SetKind(stack.Permanent)
+ }
- // Consider DAD to have resolved even if no DAD messages were actually
- // transmitted.
if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
- ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, true, nil)
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, r.Resolved, r.Err)
}
- ndp.ep.onAddressAssignedLocked(addr)
- return nil
- }
-
- state := dadState{
- job: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
- state, ok := ndp.dad[addr]
- if !ok {
- panic(fmt.Sprintf("ndpdad: DAD timer fired but missing state for %s on NIC(%d)", addr, ndp.ep.nic.ID()))
- }
-
- if addressEndpoint.GetKind() != stack.PermanentTentative {
- // The endpoint should still be marked as tentative since we are still
- // performing DAD on it.
- panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID()))
- }
-
- dadDone := remaining == 0
-
- var err tcpip.Error
- if !dadDone {
- err = ndp.sendDADPacket(addr, addressEndpoint)
- }
-
- if dadDone {
- // DAD has resolved.
- addressEndpoint.SetKind(stack.Permanent)
- } else if err == nil {
- // DAD is not done and we had no errors when sending the last NDP NS,
- // schedule the next DAD timer.
- remaining--
- state.job.Schedule(ndp.configs.RetransmitTimer)
- return
- }
-
- // At this point we know that either DAD is done or we hit an error
- // sending the last NDP NS. Either way, clean up addr's DAD state and let
- // the integrator know DAD has completed.
- delete(ndp.dad, addr)
-
- if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
- ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err)
+ if r.Resolved {
+ if addressEndpoint.ConfigType() == stack.AddressConfigSlaac {
+ // Reset the generation attempts counter as we are starting the
+ // generation of a new address for the SLAAC prefix.
+ ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */)
}
- if dadDone {
- if addressEndpoint.ConfigType() == stack.AddressConfigSlaac {
- // Reset the generation attempts counter as we are starting the
- // generation of a new address for the SLAAC prefix.
- ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */)
- }
-
- ndp.ep.onAddressAssignedLocked(addr)
- }
- }),
- }
-
- // We initially start a timer to fire immediately because some of the DAD work
- // cannot be done while holding the IPv6 endpoint's lock. This is effectively
- // the same as starting a goroutine but we use a timer that fires immediately
- // so we can reset it for the next DAD iteration.
- state.job.Schedule(0)
- ndp.dad[addr] = state
-
- return nil
-}
-
-// sendDADPacket sends a NS message to see if any nodes on ndp's NIC's link owns
-// addr.
-//
-// addr must be a tentative IPv6 address on ndp's IPv6 endpoint.
-func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) tcpip.Error {
- snmc := header.SolicitedNodeAddr(addr)
-
- icmp := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize))
- icmp.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(icmp.MessageBody())
- ns.SetTargetAddress(addr)
- icmp.SetChecksum(header.ICMPv6Checksum(icmp, header.IPv6Any, snmc, buffer.VectorisedView{}))
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(ndp.ep.MaxHeaderLength()),
- Data: buffer.View(icmp).ToVectorisedView(),
+ ndp.ep.onAddressAssignedLocked(addr)
+ }
})
- sent := ndp.ep.stats.icmp.packetsSent
- if err := addIPHeader(header.IPv6Any, snmc, pkt, stack.NetworkHeaderParams{
- Protocol: header.ICMPv6ProtocolNumber,
- TTL: header.NDPHopLimit,
- }, nil /* extensionHeaders */); err != nil {
- panic(fmt.Sprintf("failed to add IP header: %s", err))
- }
+ switch ret {
+ case stack.DADStarting:
+ case stack.DADAlreadyRunning:
+ panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.ep.nic.ID()))
+ case stack.DADDisabled:
+ addressEndpoint.SetKind(stack.Permanent)
- if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(snmc), nil /* gso */, ProtocolNumber, pkt); err != nil {
- sent.dropped.Increment()
- return err
+ // Consider DAD to have resolved even if no DAD messages were actually
+ // transmitted.
+ if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, true, nil)
+ }
+
+ ndp.ep.onAddressAssignedLocked(addr)
}
- sent.neighborSolicit.Increment()
return nil
}
@@ -756,20 +633,8 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.Add
// of this function to handle such a scenario.
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
-func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) {
- dad, ok := ndp.dad[addr]
- if !ok {
- // Not currently performing DAD on addr, just return.
- return
- }
-
- dad.job.Cancel()
- delete(ndp.dad, addr)
-
- // Let the integrator know DAD did not resolve.
- if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
- ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, false, nil)
- }
+func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address, failed bool) {
+ ndp.dad.StopLocked(addr, !failed)
}
// handleRA handles a Router Advertisement message that arrived on the NIC
@@ -1634,7 +1499,7 @@ func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefi
if addressEndpoint := state.stableAddr.addressEndpoint; addressEndpoint != nil {
// Since we are already invalidating the prefix, do not invalidate the
// prefix when removing the address.
- if err := ndp.ep.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */); err != nil {
+ if err := ndp.ep.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */, false /* dadFailure */); err != nil {
panic(fmt.Sprintf("ndp: error removing stable SLAAC address %s: %s", addressEndpoint.AddressWithPrefix(), err))
}
}
@@ -1693,7 +1558,7 @@ func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaa
func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) {
// Since we are already invalidating the address, do not invalidate the
// address when removing the address.
- if err := ndp.ep.removePermanentEndpointLocked(tempAddrState.addressEndpoint, false /* allowSLAACInvalidation */); err != nil {
+ if err := ndp.ep.removePermanentEndpointLocked(tempAddrState.addressEndpoint, false /* allowSLAACInvalidation */, false /* dadFailure */); err != nil {
panic(fmt.Sprintf("error removing temporary SLAAC address %s: %s", tempAddrState.addressEndpoint.AddressWithPrefix(), err))
}
@@ -1803,7 +1668,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) {
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) startSolicitingRouters() {
- if ndp.rtrSolicitJob != nil {
+ if ndp.rtrSolicitTimer.timer != nil {
// We are already soliciting routers.
return
}
@@ -1820,65 +1685,85 @@ func (ndp *ndpState) startSolicitingRouters() {
delay = time.Duration(rand.Int63n(int64(ndp.configs.MaxRtrSolicitationDelay)))
}
- ndp.rtrSolicitJob = ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
- // As per RFC 4861 section 4.1, the source of the RS is an address assigned
- // to the sending interface, or the unspecified address if no address is
- // assigned to the sending interface.
- localAddr := header.IPv6Any
- if addressEndpoint := ndp.ep.acquireOutgoingPrimaryAddressRLocked(header.IPv6AllRoutersMulticastAddress, false); addressEndpoint != nil {
- localAddr = addressEndpoint.AddressWithPrefix().Address
- addressEndpoint.DecRef()
- }
+ // Protected by ndp.ep.mu.
+ done := false
- // As per RFC 4861 section 4.1, an NDP RS SHOULD include the source
- // link-layer address option if the source address of the NDP RS is
- // specified. This option MUST NOT be included if the source address is
- // unspecified.
- //
- // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by
- // LinkEndpoint.LinkAddress) before reaching this point.
- var optsSerializer header.NDPOptionsSerializer
- linkAddress := ndp.ep.nic.LinkAddress()
- if localAddr != header.IPv6Any && header.IsValidUnicastEthernetAddress(linkAddress) {
- optsSerializer = header.NDPOptionsSerializer{
- header.NDPSourceLinkLayerAddressOption(linkAddress),
+ ndp.rtrSolicitTimer = timer{
+ done: &done,
+ timer: ndp.ep.protocol.stack.Clock().AfterFunc(delay, func() {
+ // As per RFC 4861 section 4.1:
+ //
+ // IP Fields:
+ // Source Address
+ // An IP address assigned to the sending interface, or
+ // the unspecified address if no address is assigned
+ // to the sending interface.
+ localAddr := header.IPv6Any
+ if addressEndpoint := ndp.ep.AcquireOutgoingPrimaryAddress(header.IPv6AllRoutersMulticastAddress, false); addressEndpoint != nil {
+ localAddr = addressEndpoint.AddressWithPrefix().Address
+ addressEndpoint.DecRef()
}
- }
- payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length())
- icmpData := header.ICMPv6(buffer.NewView(payloadSize))
- icmpData.SetType(header.ICMPv6RouterSolicit)
- rs := header.NDPRouterSolicit(icmpData.MessageBody())
- rs.Options().Serialize(optsSerializer)
- icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, localAddr, header.IPv6AllRoutersMulticastAddress, buffer.VectorisedView{}))
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(ndp.ep.MaxHeaderLength()),
- Data: buffer.View(icmpData).ToVectorisedView(),
- })
-
- sent := ndp.ep.stats.icmp.packetsSent
- if err := addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{
- Protocol: header.ICMPv6ProtocolNumber,
- TTL: header.NDPHopLimit,
- }, nil /* extensionHeaders */); err != nil {
- panic(fmt.Sprintf("failed to add IP header: %s", err))
- }
- if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
- sent.dropped.Increment()
- log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.ep.nic.ID(), err)
- // Don't send any more messages if we had an error.
- remaining = 0
- } else {
- sent.routerSolicit.Increment()
- remaining--
- }
- if remaining != 0 {
- ndp.rtrSolicitJob.Schedule(ndp.configs.RtrSolicitationInterval)
- }
- })
+ // As per RFC 4861 section 4.1, an NDP RS SHOULD include the source
+ // link-layer address option if the source address of the NDP RS is
+ // specified. This option MUST NOT be included if the source address is
+ // unspecified.
+ //
+ // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by
+ // LinkEndpoint.LinkAddress) before reaching this point.
+ var optsSerializer header.NDPOptionsSerializer
+ linkAddress := ndp.ep.nic.LinkAddress()
+ if localAddr != header.IPv6Any && header.IsValidUnicastEthernetAddress(linkAddress) {
+ optsSerializer = header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(linkAddress),
+ }
+ }
+ payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length())
+ icmpData := header.ICMPv6(buffer.NewView(payloadSize))
+ icmpData.SetType(header.ICMPv6RouterSolicit)
+ rs := header.NDPRouterSolicit(icmpData.MessageBody())
+ rs.Options().Serialize(optsSerializer)
+ icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, localAddr, header.IPv6AllRoutersMulticastAddress, buffer.VectorisedView{}))
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(ndp.ep.MaxHeaderLength()),
+ Data: buffer.View(icmpData).ToVectorisedView(),
+ })
+
+ sent := ndp.ep.stats.icmp.packetsSent
+ if err := addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: header.NDPHopLimit,
+ }, nil /* extensionHeaders */); err != nil {
+ panic(fmt.Sprintf("failed to add IP header: %s", err))
+ }
- ndp.rtrSolicitJob.Schedule(delay)
+ if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
+ sent.dropped.Increment()
+ // Don't send any more messages if we had an error.
+ remaining = 0
+ } else {
+ sent.routerSolicit.Increment()
+ remaining--
+ }
+
+ ndp.ep.mu.Lock()
+ defer ndp.ep.mu.Unlock()
+
+ if done {
+ // Router solicitation was stopped.
+ return
+ }
+
+ if remaining == 0 {
+ // We are done soliciting routers.
+ ndp.stopSolicitingRouters()
+ return
+ }
+
+ ndp.rtrSolicitTimer.timer.Reset(ndp.configs.RtrSolicitationInterval)
+ }),
+ }
}
// stopSolicitingRouters stops soliciting routers. If routers are not currently
@@ -1886,23 +1771,28 @@ func (ndp *ndpState) startSolicitingRouters() {
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) stopSolicitingRouters() {
- if ndp.rtrSolicitJob == nil {
+ if ndp.rtrSolicitTimer.timer == nil {
// Nothing to do.
return
}
- ndp.rtrSolicitJob.Cancel()
- ndp.rtrSolicitJob = nil
+ ndp.rtrSolicitTimer.timer.Stop()
+ *ndp.rtrSolicitTimer.done = true
+ ndp.rtrSolicitTimer = timer{}
}
func (ndp *ndpState) init(ep *endpoint) {
- if ndp.dad != nil {
+ if ndp.defaultRouters != nil {
panic("attempted to initialize NDP state twice")
}
ndp.ep = ep
ndp.configs = ep.protocol.options.NDPConfigs
- ndp.dad = make(map[tcpip.Address]dadState)
+ ndp.dad.Init(&ndp.ep.mu, ep.protocol.options.DADConfigs, ip.DADOptions{
+ Clock: ep.protocol.stack.Clock(),
+ Protocol: ndp,
+ NICID: ep.nic.ID(),
+ })
ndp.defaultRouters = make(map[tcpip.Address]defaultRouterState)
ndp.onLinkPrefixes = make(map[tcpip.Subnet]onLinkPrefixState)
ndp.slaacPrefixes = make(map[tcpip.Subnet]slaacPrefixState)
@@ -1912,3 +1802,38 @@ func (ndp *ndpState) init(ep *endpoint) {
ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor)))
}
}
+
+func (ndp *ndpState) SendDADMessage(addr tcpip.Address) tcpip.Error {
+ snmc := header.SolicitedNodeAddr(addr)
+ return ndp.ep.sendNDPNS(header.IPv6Any, snmc, addr, header.EthernetAddressFromMulticastIPv6Address(snmc), nil /* opts */)
+}
+
+func (e *endpoint) sendNDPNS(srcAddr, dstAddr, targetAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, opts header.NDPOptionsSerializer) tcpip.Error {
+ icmp := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize + opts.Length()))
+ icmp.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
+ ns.SetTargetAddress(targetAddr)
+ ns.Options().Serialize(opts)
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, srcAddr, dstAddr, buffer.VectorisedView{}))
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(e.MaxHeaderLength()),
+ Data: buffer.View(icmp).ToVectorisedView(),
+ })
+
+ if err := addIPHeader(srcAddr, dstAddr, pkt, stack.NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: header.NDPHopLimit,
+ }, nil /* extensionHeaders */); err != nil {
+ panic(fmt.Sprintf("failed to add IP header: %s", err))
+ }
+
+ sent := e.stats.icmp.packetsSent
+ err := e.nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt)
+ if err != nil {
+ sent.dropped.Increment()
+ } else {
+ sent.neighborSolicit.Increment()
+ }
+ return err
+}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 8edaa9508..ce20af0e3 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -33,13 +34,12 @@ import (
// setupStackAndEndpoint creates a stack with a single NIC with a link-local
// address llladdr and an IPv6 endpoint to a remote with link-local address
// rlladdr
-func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeighborCache bool) (*stack.Stack, stack.NetworkEndpoint) {
+func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) {
t.Helper()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
- UseNeighborCache: useNeighborCache,
})
if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
@@ -237,107 +237,6 @@ func TestNeighborSolicitationWithSourceLinkLayerOption(t *testing.T) {
Data: hdr.View().ToVectorisedView(),
}))
- ch := make(chan stack.LinkResolutionResult, 1)
- err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, func(r stack.LinkResolutionResult) {
- ch <- r
- })
-
- wantInvalid := uint64(0)
- wantSucccess := true
- if len(test.expectedLinkAddr) == 0 {
- wantInvalid = 1
- wantSucccess = false
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = %s", nicID, lladdr1, lladdr0, ProtocolNumber, err, &tcpip.ErrWouldBlock{})
- }
- } else {
- if err != nil {
- t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = nil", nicID, lladdr1, lladdr0, ProtocolNumber, err)
- }
- }
-
- if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Success: wantSucccess}, <-ch); diff != "" {
- t.Errorf("linkResolutionResult mismatch (-want +got):\n%s", diff)
- }
- if got := invalid.Value(); got != wantInvalid {
- t.Errorf("got invalid = %d, want = %d", got, wantInvalid)
- }
- })
- }
-}
-
-// TestNeighborSolicitationWithSourceLinkLayerOptionUsingNeighborCache tests
-// that receiving a valid NDP NS message with the Source Link Layer Address
-// option results in a new entry in the link address cache for the sender of
-// the message.
-func TestNeighborSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testing.T) {
- const nicID = 1
-
- tests := []struct {
- name string
- optsBuf []byte
- expectedLinkAddr tcpip.LinkAddress
- }{
- {
- name: "Valid",
- optsBuf: []byte{1, 1, 2, 3, 4, 5, 6, 7},
- expectedLinkAddr: "\x02\x03\x04\x05\x06\x07",
- },
- {
- name: "Too Small",
- optsBuf: []byte{1, 1, 2, 3, 4, 5, 6},
- },
- {
- name: "Invalid Length",
- optsBuf: []byte{1, 2, 2, 3, 4, 5, 6, 7},
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- UseNeighborCache: true,
- })
- e := channel.New(0, 1280, linkAddr0)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
- }
-
- ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + len(test.optsBuf)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
- pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
- pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.MessageBody())
- ns.SetTargetAddress(lladdr0)
- opts := ns.Options()
- copy(opts, test.optsBuf)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- TransportProtocol: header.ICMPv6ProtocolNumber,
- HopLimit: 255,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
- })
-
- invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
-
- // Invalid count should initially be 0.
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
-
- e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
-
neighbors, err := s.Neighbors(nicID, ProtocolNumber)
if err != nil {
t.Fatalf("s.Neighbors(%d, %d): %s", nicID, ProtocolNumber, err)
@@ -392,20 +291,6 @@ func TestNeighborSolicitationResponse(t *testing.T) {
remoteLinkAddr0 := linkAddr1
remoteLinkAddr1 := linkAddr2
- stacks := []struct {
- name string
- useNeighborCache bool
- }{
- {
- name: "linkAddrCache",
- useNeighborCache: false,
- },
- {
- name: "neighborCache",
- useNeighborCache: true,
- },
- }
-
tests := []struct {
name string
nsOpts header.NDPOptionsSerializer
@@ -564,229 +449,44 @@ func TestNeighborSolicitationResponse(t *testing.T) {
},
}
- for _, stackTyp := range stacks {
- t.Run(stackTyp.name, func(t *testing.T) {
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- UseNeighborCache: stackTyp.useNeighborCache,
- })
- e := channel.New(1, 1280, nicLinkAddr)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv6EmptySubnet,
- NIC: 1,
- },
- })
-
- ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length()
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
- pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
- pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.MessageBody())
- ns.SetTargetAddress(nicAddr)
- opts := ns.Options()
- opts.Serialize(test.nsOpts)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, test.nsDst, buffer.VectorisedView{}))
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- TransportProtocol: header.ICMPv6ProtocolNumber,
- HopLimit: 255,
- SrcAddr: test.nsSrc,
- DstAddr: test.nsDst,
- })
-
- invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
-
- // Invalid count should initially be 0.
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
-
- e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
-
- if test.nsInvalid {
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
-
- if p, got := e.Read(); got {
- t.Fatalf("unexpected response to an invalid NS = %+v", p.Pkt)
- }
-
- // If we expected the NS to be invalid, we have nothing else to check.
- return
- }
-
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
-
- if test.performsLinkResolution {
- p, got := e.ReadContext(context.Background())
- if !got {
- t.Fatal("expected an NDP NS response")
- }
-
- respNSDst := header.SolicitedNodeAddr(test.nsSrc)
- var want stack.RouteInfo
- want.NetProto = ProtocolNumber
- want.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(respNSDst)
- if diff := cmp.Diff(want, p.Route, cmp.AllowUnexported(want)); diff != "" {
- t.Errorf("route info mismatch (-want +got):\n%s", diff)
- }
-
- checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
- checker.SrcAddr(nicAddr),
- checker.DstAddr(respNSDst),
- checker.TTL(header.NDPHopLimit),
- checker.NDPNS(
- checker.NDPNSTargetAddress(test.nsSrc),
- checker.NDPNSOptions([]header.NDPOption{
- header.NDPSourceLinkLayerAddressOption(nicLinkAddr),
- }),
- ))
-
- ser := header.NDPOptionsSerializer{
- header.NDPTargetLinkLayerAddressOption(linkAddr1),
- }
- ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + ser.Length()
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
- pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
- pkt.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(pkt.MessageBody())
- na.SetSolicitedFlag(true)
- na.SetOverrideFlag(true)
- na.SetTargetAddress(test.nsSrc)
- na.Options().Serialize(ser)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, nicAddr, buffer.VectorisedView{}))
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- TransportProtocol: header.ICMPv6ProtocolNumber,
- HopLimit: header.NDPHopLimit,
- SrcAddr: test.nsSrc,
- DstAddr: nicAddr,
- })
- e.InjectLinkAddr(ProtocolNumber, "", stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
- }
-
- p, got := e.ReadContext(context.Background())
- if !got {
- t.Fatal("expected an NDP NA response")
- }
-
- if p.Route.LocalAddress != test.naSrc {
- t.Errorf("got p.Route.LocalAddress = %s, want = %s", p.Route.LocalAddress, test.naSrc)
- }
- if p.Route.LocalLinkAddress != nicLinkAddr {
- t.Errorf("p.Route.LocalLinkAddress = %s, want = %s", p.Route.LocalLinkAddress, nicLinkAddr)
- }
- if p.Route.RemoteAddress != test.naDst {
- t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, test.naDst)
- }
- if p.Route.RemoteLinkAddress != test.naDstLinkAddr {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr)
- }
-
- checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
- checker.SrcAddr(test.naSrc),
- checker.DstAddr(test.naDst),
- checker.TTL(header.NDPHopLimit),
- checker.NDPNA(
- checker.NDPNASolicitedFlag(test.naSolicited),
- checker.NDPNATargetAddress(nicAddr),
- checker.NDPNAOptions([]header.NDPOption{
- header.NDPTargetLinkLayerAddressOption(nicLinkAddr[:]),
- }),
- ))
- })
- }
- })
- }
-}
-
-// TestNeighborAdvertisementWithTargetLinkLayerOption tests that receiving a
-// valid NDP NA message with the Target Link Layer Address option results in a
-// new entry in the link address cache for the target of the message.
-func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) {
- const nicID = 1
-
- tests := []struct {
- name string
- optsBuf []byte
- expectedLinkAddr tcpip.LinkAddress
- }{
- {
- name: "Valid",
- optsBuf: []byte{2, 1, 2, 3, 4, 5, 6, 7},
- expectedLinkAddr: "\x02\x03\x04\x05\x06\x07",
- },
- {
- name: "Too Small",
- optsBuf: []byte{2, 1, 2, 3, 4, 5, 6},
- },
- {
- name: "Invalid Length",
- optsBuf: []byte{2, 2, 2, 3, 4, 5, 6, 7},
- },
- {
- name: "Multiple",
- optsBuf: []byte{
- 2, 1, 2, 3, 4, 5, 6, 7,
- 2, 1, 2, 3, 4, 5, 6, 8,
- },
- },
- }
-
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- UseLinkAddrCache: true,
})
- e := channel.New(0, 1280, linkAddr0)
+ e := channel.New(1, 1280, nicLinkAddr)
e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err)
}
- ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + len(test.optsBuf)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
- pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
- pkt.SetType(header.ICMPv6NeighborAdvert)
- ns := header.NDPNeighborAdvert(pkt.MessageBody())
- ns.SetTargetAddress(lladdr1)
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: 1,
+ },
+ })
+
+ ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length()
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
+ pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
+ pkt.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(pkt.MessageBody())
+ ns.SetTargetAddress(nicAddr)
opts := ns.Options()
- copy(opts, test.optsBuf)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
+ opts.Serialize(test.nsOpts)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, test.nsDst, buffer.VectorisedView{}))
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(payloadLength),
TransportProtocol: header.ICMPv6ProtocolNumber,
HopLimit: 255,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
+ SrcAddr: test.nsSrc,
+ DstAddr: test.nsDst,
})
invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid
@@ -796,44 +496,116 @@ func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) {
t.Fatalf("got invalid = %d, want = 0", got)
}
- e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
}))
- ch := make(chan stack.LinkResolutionResult, 1)
- err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, func(r stack.LinkResolutionResult) {
- ch <- r
- })
+ if test.nsInvalid {
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
- wantInvalid := uint64(0)
- wantSucccess := true
- if len(test.expectedLinkAddr) == 0 {
- wantInvalid = 1
- wantSucccess = false
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = %s", nicID, lladdr1, lladdr0, ProtocolNumber, err, &tcpip.ErrWouldBlock{})
+ if p, got := e.Read(); got {
+ t.Fatalf("unexpected response to an invalid NS = %+v", p.Pkt)
}
- } else {
- if err != nil {
- t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = nil", nicID, lladdr1, lladdr0, ProtocolNumber, err)
+
+ // If we expected the NS to be invalid, we have nothing else to check.
+ return
+ }
+
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+
+ if test.performsLinkResolution {
+ p, got := e.ReadContext(context.Background())
+ if !got {
+ t.Fatal("expected an NDP NS response")
+ }
+
+ respNSDst := header.SolicitedNodeAddr(test.nsSrc)
+ var want stack.RouteInfo
+ want.NetProto = ProtocolNumber
+ want.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(respNSDst)
+ if diff := cmp.Diff(want, p.Route, cmp.AllowUnexported(want)); diff != "" {
+ t.Errorf("route info mismatch (-want +got):\n%s", diff)
}
+
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(nicAddr),
+ checker.DstAddr(respNSDst),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNS(
+ checker.NDPNSTargetAddress(test.nsSrc),
+ checker.NDPNSOptions([]header.NDPOption{
+ header.NDPSourceLinkLayerAddressOption(nicLinkAddr),
+ }),
+ ))
+
+ ser := header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ }
+ ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + ser.Length()
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
+ pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
+ pkt.SetType(header.ICMPv6NeighborAdvert)
+ na := header.NDPNeighborAdvert(pkt.MessageBody())
+ na.SetSolicitedFlag(true)
+ na.SetOverrideFlag(true)
+ na.SetTargetAddress(test.nsSrc)
+ na.Options().Serialize(ser)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, nicAddr, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: test.nsSrc,
+ DstAddr: nicAddr,
+ })
+ e.InjectLinkAddr(ProtocolNumber, "", stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
}
- if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Success: wantSucccess}, <-ch); diff != "" {
- t.Errorf("linkResolutionResult mismatch (-want +got):\n%s", diff)
+ p, got := e.ReadContext(context.Background())
+ if !got {
+ t.Fatal("expected an NDP NA response")
}
- if got := invalid.Value(); got != wantInvalid {
- t.Errorf("got invalid = %d, want = %d", got, wantInvalid)
+
+ if p.Route.LocalAddress != test.naSrc {
+ t.Errorf("got p.Route.LocalAddress = %s, want = %s", p.Route.LocalAddress, test.naSrc)
+ }
+ if p.Route.LocalLinkAddress != nicLinkAddr {
+ t.Errorf("p.Route.LocalLinkAddress = %s, want = %s", p.Route.LocalLinkAddress, nicLinkAddr)
+ }
+ if p.Route.RemoteAddress != test.naDst {
+ t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, test.naDst)
}
+ if p.Route.RemoteLinkAddress != test.naDstLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr)
+ }
+
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(test.naSrc),
+ checker.DstAddr(test.naDst),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNA(
+ checker.NDPNASolicitedFlag(test.naSolicited),
+ checker.NDPNATargetAddress(nicAddr),
+ checker.NDPNAOptions([]header.NDPOption{
+ header.NDPTargetLinkLayerAddressOption(nicLinkAddr[:]),
+ }),
+ ))
})
}
}
-// TestNeighborAdvertisementWithTargetLinkLayerOptionUsingNeighborCache tests
-// that receiving a valid NDP NA message with the Target Link Layer Address
-// option does not result in a new entry in the neighbor cache for the target
-// of the message.
-func TestNeighborAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *testing.T) {
+// TestNeighborAdvertisementWithTargetLinkLayerOption tests that receiving a
+// valid NDP NA message with the Target Link Layer Address option does not
+// result in a new entry in the neighbor cache for the target of the message.
+func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) {
const nicID = 1
tests := []struct {
@@ -867,7 +639,6 @@ func TestNeighborAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *tes
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- UseNeighborCache: true,
})
e := channel.New(0, 1280, linkAddr0)
e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
@@ -944,235 +715,216 @@ func TestNeighborAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *tes
}
func TestNDPValidation(t *testing.T) {
- stacks := []struct {
- name string
- useNeighborCache bool
+ setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint) {
+ t.Helper()
+
+ // Create a stack with the assigned link-local address lladdr0
+ // and an endpoint to lladdr1.
+ s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1)
+
+ return s, ep
+ }
+
+ handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) {
+ var extHdrs header.IPv6ExtHdrSerializer
+ if atomicFragment {
+ extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{})
+ }
+ extHdrsLen := extHdrs.Length()
+
+ ip := buffer.NewView(header.IPv6MinimumSize + extHdrsLen)
+ header.IPv6(ip).Encode(&header.IPv6Fields{
+ PayloadLength: uint16(len(payload) + extHdrsLen),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: hopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ ExtensionHeaders: extHdrs,
+ })
+ vv := ip.ToVectorisedView()
+ vv.AppendView(payload)
+ ep.HandlePacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv,
+ }))
+ }
+
+ var tllData [header.NDPLinkLayerAddressSize]byte
+ header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ })
+
+ var sllData [header.NDPLinkLayerAddressSize]byte
+ header.NDPOptions(sllData[:]).Serialize(header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(linkAddr1),
+ })
+
+ types := []struct {
+ name string
+ typ header.ICMPv6Type
+ size int
+ extraData []byte
+ statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
+ routerOnly bool
}{
{
- name: "linkAddrCache",
- useNeighborCache: false,
+ name: "RouterSolicit",
+ typ: header.ICMPv6RouterSolicit,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RouterSolicit
+ },
+ routerOnly: true,
+ },
+ {
+ name: "RouterAdvert",
+ typ: header.ICMPv6RouterAdvert,
+ size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RouterAdvert
+ },
},
{
- name: "neighborCache",
- useNeighborCache: true,
+ name: "NeighborSolicit",
+ typ: header.ICMPv6NeighborSolicit,
+ size: header.ICMPv6NeighborSolicitMinimumSize,
+ extraData: sllData[:],
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.NeighborSolicit
+ },
+ },
+ {
+ name: "NeighborAdvert",
+ typ: header.ICMPv6NeighborAdvert,
+ size: header.ICMPv6NeighborAdvertMinimumSize,
+ extraData: tllData[:],
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.NeighborAdvert
+ },
+ },
+ {
+ name: "RedirectMsg",
+ typ: header.ICMPv6RedirectMsg,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RedirectMsg
+ },
},
}
- for _, stackTyp := range stacks {
- t.Run(stackTyp.name, func(t *testing.T) {
- setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint) {
- t.Helper()
-
- // Create a stack with the assigned link-local address lladdr0
- // and an endpoint to lladdr1.
- s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1, stackTyp.useNeighborCache)
+ subTests := []struct {
+ name string
+ atomicFragment bool
+ hopLimit uint8
+ code header.ICMPv6Code
+ valid bool
+ }{
+ {
+ name: "Valid",
+ atomicFragment: false,
+ hopLimit: header.NDPHopLimit,
+ code: 0,
+ valid: true,
+ },
+ {
+ name: "Fragmented",
+ atomicFragment: true,
+ hopLimit: header.NDPHopLimit,
+ code: 0,
+ valid: false,
+ },
+ {
+ name: "Invalid hop limit",
+ atomicFragment: false,
+ hopLimit: header.NDPHopLimit - 1,
+ code: 0,
+ valid: false,
+ },
+ {
+ name: "Invalid ICMPv6 code",
+ atomicFragment: false,
+ hopLimit: header.NDPHopLimit,
+ code: 1,
+ valid: false,
+ },
+ }
- return s, ep
+ for _, typ := range types {
+ for _, isRouter := range []bool{false, true} {
+ name := typ.name
+ if isRouter {
+ name += " (Router)"
}
- handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) {
- var extHdrs header.IPv6ExtHdrSerializer
- if atomicFragment {
- extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{})
- }
- extHdrsLen := extHdrs.Length()
+ t.Run(name, func(t *testing.T) {
+ for _, test := range subTests {
+ t.Run(test.name, func(t *testing.T) {
+ s, ep := setup(t)
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.IPv6MinimumSize + extHdrsLen,
- Data: payload.ToVectorisedView(),
- })
- ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(len(payload) + extHdrsLen),
- TransportProtocol: header.ICMPv6ProtocolNumber,
- HopLimit: hopLimit,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
- ExtensionHeaders: extHdrs,
- })
- ep.HandlePacket(pkt)
- }
+ if isRouter {
+ // Enabling forwarding makes the stack act as a router.
+ s.SetForwarding(ProtocolNumber, true)
+ }
- var tllData [header.NDPLinkLayerAddressSize]byte
- header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
- header.NDPTargetLinkLayerAddressOption(linkAddr1),
- })
+ stats := s.Stats().ICMP.V6.PacketsReceived
+ invalid := stats.Invalid
+ routerOnly := stats.RouterOnlyPacketsDroppedByHost
+ typStat := typ.statCounter(stats)
- var sllData [header.NDPLinkLayerAddressSize]byte
- header.NDPOptions(sllData[:]).Serialize(header.NDPOptionsSerializer{
- header.NDPSourceLinkLayerAddressOption(linkAddr1),
- })
+ icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
+ copy(icmp[typ.size:], typ.extraData)
+ icmp.SetType(typ.typ)
+ icmp.SetCode(test.code)
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView()))
- types := []struct {
- name string
- typ header.ICMPv6Type
- size int
- extraData []byte
- statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
- routerOnly bool
- }{
- {
- name: "RouterSolicit",
- typ: header.ICMPv6RouterSolicit,
- size: header.ICMPv6MinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RouterSolicit
- },
- routerOnly: true,
- },
- {
- name: "RouterAdvert",
- typ: header.ICMPv6RouterAdvert,
- size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RouterAdvert
- },
- },
- {
- name: "NeighborSolicit",
- typ: header.ICMPv6NeighborSolicit,
- size: header.ICMPv6NeighborSolicitMinimumSize,
- extraData: sllData[:],
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.NeighborSolicit
- },
- },
- {
- name: "NeighborAdvert",
- typ: header.ICMPv6NeighborAdvert,
- size: header.ICMPv6NeighborAdvertMinimumSize,
- extraData: tllData[:],
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.NeighborAdvert
- },
- },
- {
- name: "RedirectMsg",
- typ: header.ICMPv6RedirectMsg,
- size: header.ICMPv6MinimumSize,
- statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RedirectMsg
- },
- },
- }
+ // Rx count of the NDP message should initially be 0.
+ if got := typStat.Value(); got != 0 {
+ t.Errorf("got %s = %d, want = 0", typ.name, got)
+ }
- subTests := []struct {
- name string
- atomicFragment bool
- hopLimit uint8
- code header.ICMPv6Code
- valid bool
- }{
- {
- name: "Valid",
- atomicFragment: false,
- hopLimit: header.NDPHopLimit,
- code: 0,
- valid: true,
- },
- {
- name: "Fragmented",
- atomicFragment: true,
- hopLimit: header.NDPHopLimit,
- code: 0,
- valid: false,
- },
- {
- name: "Invalid hop limit",
- atomicFragment: false,
- hopLimit: header.NDPHopLimit - 1,
- code: 0,
- valid: false,
- },
- {
- name: "Invalid ICMPv6 code",
- atomicFragment: false,
- hopLimit: header.NDPHopLimit,
- code: 1,
- valid: false,
- },
- }
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Errorf("got invalid = %d, want = 0", got)
+ }
- for _, typ := range types {
- for _, isRouter := range []bool{false, true} {
- name := typ.name
- if isRouter {
- name += " (Router)"
- }
+ // RouterOnlyPacketsDroppedByHost count should initially be 0.
+ if got := routerOnly.Value(); got != 0 {
+ t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep)
+
+ // Rx count of the NDP packet should have increased.
+ if got := typStat.Value(); got != 1 {
+ t.Errorf("got %s = %d, want = 1", typ.name, got)
+ }
+
+ want := uint64(0)
+ if !test.valid {
+ // Invalid count should have increased.
+ want = 1
+ }
+ if got := invalid.Value(); got != want {
+ t.Errorf("got invalid = %d, want = %d", got, want)
+ }
- t.Run(name, func(t *testing.T) {
- for _, test := range subTests {
- t.Run(test.name, func(t *testing.T) {
- s, ep := setup(t)
-
- if isRouter {
- // Enabling forwarding makes the stack act as a router.
- s.SetForwarding(ProtocolNumber, true)
- }
-
- stats := s.Stats().ICMP.V6.PacketsReceived
- invalid := stats.Invalid
- routerOnly := stats.RouterOnlyPacketsDroppedByHost
- typStat := typ.statCounter(stats)
-
- icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
- copy(icmp[typ.size:], typ.extraData)
- icmp.SetType(typ.typ)
- icmp.SetCode(test.code)
- icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView()))
-
- // Rx count of the NDP message should initially be 0.
- if got := typStat.Value(); got != 0 {
- t.Errorf("got %s = %d, want = 0", typ.name, got)
- }
-
- // Invalid count should initially be 0.
- if got := invalid.Value(); got != 0 {
- t.Errorf("got invalid = %d, want = 0", got)
- }
-
- // RouterOnlyPacketsReceivedByHost count should initially be 0.
- if got := routerOnly.Value(); got != 0 {
- t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got)
- }
-
- if t.Failed() {
- t.FailNow()
- }
-
- handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep)
-
- // Rx count of the NDP packet should have increased.
- if got := typStat.Value(); got != 1 {
- t.Errorf("got %s = %d, want = 1", typ.name, got)
- }
-
- want := uint64(0)
- if !test.valid {
- // Invalid count should have increased.
- want = 1
- }
- if got := invalid.Value(); got != want {
- t.Errorf("got invalid = %d, want = %d", got, want)
- }
-
- want = 0
- if test.valid && !isRouter && typ.routerOnly {
- // RouterOnlyPacketsReceivedByHost count should have increased.
- want = 1
- }
- if got := routerOnly.Value(); got != want {
- t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = %d", got, want)
- }
-
- })
+ want = 0
+ if test.valid && !isRouter && typ.routerOnly {
+ // RouterOnlyPacketsDroppedByHost count should have increased.
+ want = 1
+ }
+ if got := routerOnly.Value(); got != want {
+ t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = %d", got, want)
}
+
})
}
- }
- })
+ })
+ }
}
-
}
// TestNeighborAdvertisementValidation tests that the NIC validates received
@@ -1218,7 +970,6 @@ func TestNeighborAdvertisementValidation(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- UseNeighborCache: true,
})
e := channel.New(0, header.IPv6MinimumMTU, linkAddr0)
e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
@@ -1291,20 +1042,6 @@ func TestNeighborAdvertisementValidation(t *testing.T) {
// NDP Router Advertisement packets, it validates the Router Advertisement
// properly before handling them.
func TestRouterAdvertValidation(t *testing.T) {
- stacks := []struct {
- name string
- useNeighborCache bool
- }{
- {
- name: "linkAddrCache",
- useNeighborCache: false,
- },
- {
- name: "neighborCache",
- useNeighborCache: true,
- },
- }
-
tests := []struct {
name string
src tcpip.Address
@@ -1426,68 +1163,170 @@ func TestRouterAdvertValidation(t *testing.T) {
},
}
- for _, stackTyp := range stacks {
- t.Run(stackTyp.name, func(t *testing.T) {
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- e := channel.New(10, 1280, linkAddr1)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- UseNeighborCache: stackTyp.useNeighborCache,
- })
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e := channel.New(10, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ })
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
- }
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
- icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
- pkt := header.ICMPv6(hdr.Prepend(icmpSize))
- pkt.SetType(header.ICMPv6RouterAdvert)
- pkt.SetCode(test.code)
- copy(pkt.MessageBody(), test.ndpPayload)
- payloadLength := hdr.UsedLength()
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- TransportProtocol: icmp.ProtocolNumber6,
- HopLimit: test.hopLimit,
- SrcAddr: test.src,
- DstAddr: header.IPv6AllNodesMulticastAddress,
- })
+ icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
+ pkt := header.ICMPv6(hdr.Prepend(icmpSize))
+ pkt.SetType(header.ICMPv6RouterAdvert)
+ pkt.SetCode(test.code)
+ copy(pkt.MessageBody(), test.ndpPayload)
+ payloadLength := hdr.UsedLength()
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: icmp.ProtocolNumber6,
+ HopLimit: test.hopLimit,
+ SrcAddr: test.src,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
+ })
- stats := s.Stats().ICMP.V6.PacketsReceived
- invalid := stats.Invalid
- rxRA := stats.RouterAdvert
+ stats := s.Stats().ICMP.V6.PacketsReceived
+ invalid := stats.Invalid
+ rxRA := stats.RouterAdvert
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
- if got := rxRA.Value(); got != 0 {
- t.Fatalf("got rxRA = %d, want = 0", got)
- }
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+ if got := rxRA.Value(); got != 0 {
+ t.Fatalf("got rxRA = %d, want = 0", got)
+ }
- e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: hdr.View().ToVectorisedView(),
- }))
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
- if got := rxRA.Value(); got != 1 {
- t.Fatalf("got rxRA = %d, want = 1", got)
- }
+ if got := rxRA.Value(); got != 1 {
+ t.Fatalf("got rxRA = %d, want = 1", got)
+ }
- if test.expectedSuccess {
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
- } else {
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
- }
- }
- })
+ if test.expectedSuccess {
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+ } else {
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
}
})
}
}
+
+// TestCheckDuplicateAddress checks that calls to CheckDuplicateAddress and DAD
+// performed when adding new addresses do not interfere with each other.
+func TestCheckDuplicateAddress(t *testing.T) {
+ const nicID = 1
+
+ clock := faketime.NewManualClock()
+ dadConfigs := stack.DADConfigurations{
+ DupAddrDetectTransmits: 1,
+ RetransmitTimer: time.Second,
+ }
+ s := stack.New(stack.Options{
+ Clock: clock,
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocolWithOptions(Options{
+ DADConfigs: dadConfigs,
+ })},
+ })
+ // This test is expected to send at most 2 DAD messages. We allow an extra
+ // packet to be stored to catch unexpected packets.
+ e := channel.New(3, header.IPv6MinimumMTU, linkAddr0)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ dadPacketsSent := 1
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ }
+
+ // Start DAD for the address we just added.
+ //
+ // Even though the stack will perform DAD before the added address transitions
+ // from tentative to assigned, this DAD request should be independent of that.
+ ch := make(chan stack.DADResult, 3)
+ dadRequestsMade := 1
+ dadPacketsSent++
+ if res, err := s.CheckDuplicateAddress(nicID, ProtocolNumber, lladdr0, func(r stack.DADResult) {
+ ch <- r
+ }); err != nil {
+ t.Fatalf("s.CheckDuplicateAddress(%d, %d, %s, _): %s", nicID, ProtocolNumber, lladdr0, err)
+ } else if res != stack.DADStarting {
+ t.Fatalf("got s.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", nicID, ProtocolNumber, lladdr0, res, stack.DADStarting)
+ }
+
+ // Remove the address and make sure our DAD request was not stopped.
+ if err := s.RemoveAddress(nicID, lladdr0); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s): %s", nicID, lladdr0, err)
+ }
+ // Should not restart DAD since we already requested DAD above - the handler
+ // should be called when the original request completes so we should not send
+ // an extra DAD message here.
+ dadRequestsMade++
+ if res, err := s.CheckDuplicateAddress(nicID, ProtocolNumber, lladdr0, func(r stack.DADResult) {
+ ch <- r
+ }); err != nil {
+ t.Fatalf("s.CheckDuplicateAddress(%d, %d, %s, _): %s", nicID, ProtocolNumber, lladdr0, err)
+ } else if res != stack.DADAlreadyRunning {
+ t.Fatalf("got s.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", nicID, ProtocolNumber, lladdr0, res, stack.DADAlreadyRunning)
+ }
+
+ // Wait for DAD to resolve.
+ clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer)
+ for i := 0; i < dadRequestsMade; i++ {
+ if diff := cmp.Diff(stack.DADResult{Resolved: true}, <-ch); diff != "" {
+ t.Errorf("(i=%d) DAD result mismatch (-want +got):\n%s", i, diff)
+ }
+ }
+ // Should have no more results.
+ select {
+ case r := <-ch:
+ t.Errorf("unexpectedly got an extra DAD result; r = %#v", r)
+ default:
+ }
+
+ snmc := header.SolicitedNodeAddr(lladdr0)
+ remoteLinkAddr := header.EthernetAddressFromMulticastIPv6Address(snmc)
+
+ for i := 0; i < dadPacketsSent; i++ {
+ p, ok := e.Read()
+ if !ok {
+ t.Fatalf("expected %d-th DAD message", i)
+ }
+
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Errorf("(i=%d) got p.Proto = %d, want = %d", i, p.Proto, header.IPv6ProtocolNumber)
+ }
+
+ if p.Route.RemoteLinkAddress != remoteLinkAddr {
+ t.Errorf("(i=%d) got p.Route.RemoteLinkAddress = %s, want = %s", i, p.Route.RemoteLinkAddress, remoteLinkAddr)
+ }
+
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(header.IPv6Any),
+ checker.DstAddr(snmc),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNS(
+ checker.NDPNSTargetAddress(lladdr0),
+ checker.NDPNSOptions(nil),
+ ))
+ }
+
+ // Should have no more packets.
+ if p, ok := e.Read(); ok {
+ t.Errorf("got unexpected packet = %#v", p)
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/stats.go b/pkg/tcpip/network/ipv6/stats.go
index 0839be3cd..c2758352f 100644
--- a/pkg/tcpip/network/ipv6/stats.go
+++ b/pkg/tcpip/network/ipv6/stats.go
@@ -16,7 +16,7 @@ package ipv6
import (
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/network/ip"
+ "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go
index 1e00144a5..dc37e61a4 100644
--- a/pkg/tcpip/socketops.go
+++ b/pkg/tcpip/socketops.go
@@ -54,11 +54,10 @@ type SocketOptionsHandler interface {
// HasNIC is invoked to check if the NIC is valid for SO_BINDTODEVICE.
HasNIC(v int32) bool
- // GetSendBufferSize is invoked to get the SO_SNDBUFSIZE.
- GetSendBufferSize() (int64, Error)
-
- // IsUnixSocket is invoked to check if the socket is of unix domain.
- IsUnixSocket() bool
+ // OnSetSendBufferSize is invoked when the send buffer size for an endpoint is
+ // changed. The handler is invoked with the new value for the socket send
+ // buffer size. It also returns the newly set value.
+ OnSetSendBufferSize(v int64) (newSz int64)
}
// DefaultSocketOptionsHandler is an embeddable type that implements no-op
@@ -95,14 +94,9 @@ func (*DefaultSocketOptionsHandler) HasNIC(int32) bool {
return false
}
-// GetSendBufferSize implements SocketOptionsHandler.GetSendBufferSize.
-func (*DefaultSocketOptionsHandler) GetSendBufferSize() (int64, Error) {
- return 0, nil
-}
-
-// IsUnixSocket implements SocketOptionsHandler.IsUnixSocket.
-func (*DefaultSocketOptionsHandler) IsUnixSocket() bool {
- return false
+// OnSetSendBufferSize implements SocketOptionsHandler.OnSetSendBufferSize.
+func (*DefaultSocketOptionsHandler) OnSetSendBufferSize(v int64) (newSz int64) {
+ return v
}
// StackHandler holds methods to access the stack options. These must be
@@ -600,42 +594,41 @@ func (so *SocketOptions) SetBindToDevice(bindToDevice int32) Error {
}
// GetSendBufferSize gets value for SO_SNDBUF option.
-func (so *SocketOptions) GetSendBufferSize() (int64, Error) {
- if so.handler.IsUnixSocket() {
- return so.handler.GetSendBufferSize()
- }
- return atomic.LoadInt64(&so.sendBufferSize), nil
+func (so *SocketOptions) GetSendBufferSize() int64 {
+ return atomic.LoadInt64(&so.sendBufferSize)
}
// SetSendBufferSize sets value for SO_SNDBUF option. notify indicates if the
// stack handler should be invoked to set the send buffer size.
func (so *SocketOptions) SetSendBufferSize(sendBufferSize int64, notify bool) {
- if so.handler.IsUnixSocket() {
+ v := sendBufferSize
+
+ if !notify {
+ atomic.StoreInt64(&so.sendBufferSize, v)
return
}
- v := sendBufferSize
- if notify {
- // TODO(b/176170271): Notify waiters after size has grown.
- // Make sure the send buffer size is within the min and max
- // allowed.
- ss := so.getSendBufferLimits(so.stackHandler)
- min := int64(ss.Min)
- max := int64(ss.Max)
- // Validate the send buffer size with min and max values.
- // Multiply it by factor of 2.
- if v > max {
- v = max
- }
+ // Make sure the send buffer size is within the min and max
+ // allowed.
+ ss := so.getSendBufferLimits(so.stackHandler)
+ min := int64(ss.Min)
+ max := int64(ss.Max)
+ // Validate the send buffer size with min and max values.
+ // Multiply it by factor of 2.
+ if v > max {
+ v = max
+ }
- if v < math.MaxInt32/PacketOverheadFactor {
- v *= PacketOverheadFactor
- if v < min {
- v = min
- }
- } else {
- v = math.MaxInt32
+ if v < math.MaxInt32/PacketOverheadFactor {
+ v *= PacketOverheadFactor
+ if v < min {
+ v = min
}
+ } else {
+ v = math.MaxInt32
}
- atomic.StoreInt64(&so.sendBufferSize, v)
+
+ // Notify endpoint about change in buffer size.
+ newSz := so.handler.OnSetSendBufferSize(v)
+ atomic.StoreInt64(&so.sendBufferSize, newSz)
}
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index ee23c9b98..49362333a 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -4,18 +4,6 @@ load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
go_template_instance(
- name = "linkaddrentry_list",
- out = "linkaddrentry_list.go",
- package = "stack",
- prefix = "linkAddrEntry",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*linkAddrEntry",
- "Linker": "*linkAddrEntry",
- },
-)
-
-go_template_instance(
name = "neighbor_entry_list",
out = "neighbor_entry_list.go",
package = "stack",
@@ -62,8 +50,6 @@ go_library(
"iptables_state.go",
"iptables_targets.go",
"iptables_types.go",
- "linkaddrcache.go",
- "linkaddrentry_list.go",
"neighbor_cache.go",
"neighbor_entry.go",
"neighbor_entry_list.go",
@@ -141,7 +127,6 @@ go_test(
size = "small",
srcs = [
"forwarding_test.go",
- "linkaddrcache_test.go",
"neighbor_cache_test.go",
"neighbor_entry_test.go",
"nic_test.go",
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 54617f2e6..cdb435644 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -231,6 +231,12 @@ func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn {
return &conn
}
+func (ct *ConnTrack) init() {
+ ct.mu.Lock()
+ defer ct.mu.Unlock()
+ ct.buckets = make([]bucket, numBuckets)
+}
+
// connFor gets the conn for pkt if it exists, or returns nil
// if it does not. It returns an error when pkt does not contain a valid TCP
// header.
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go
index c24f56ece..c987c1851 100644
--- a/pkg/tcpip/stack/forwarding_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -75,6 +75,10 @@ func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 {
}
func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) {
+ if _, _, ok := f.proto.Parse(pkt); !ok {
+ return
+ }
+
netHdr := pkt.NetworkHeader().View()
_, dst := f.proto.ParseAddresses(netHdr)
@@ -161,9 +165,9 @@ var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil)
type fwdTestNetworkProtocol struct {
stack *Stack
- neighborTable neighborTable
+ neigh *neighborCache
addrResolveDelay time.Duration
- onLinkAddressResolved func(neighborTable, tcpip.Address, tcpip.LinkAddress)
+ onLinkAddressResolved func(*neighborCache, tcpip.Address, tcpip.LinkAddress)
onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool)
mu struct {
@@ -221,7 +225,7 @@ func (*fwdTestNetworkProtocol) Wait() {}
func (f *fwdTestNetworkEndpoint) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error {
if fn := f.proto.onLinkAddressResolved; fn != nil {
time.AfterFunc(f.proto.addrResolveDelay, func() {
- fn(f.proto.neighborTable, addr, remoteLinkAddr)
+ fn(f.proto.neigh, addr, remoteLinkAddr)
})
}
return nil
@@ -357,14 +361,13 @@ func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protoco
panic("not implemented")
}
-func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborCache bool) (ep1, ep2 *fwdTestLinkEndpoint) {
+func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) {
// Create a stack with the network protocol and two NICs.
s := New(Options{
NetworkProtocols: []NetworkProtocolFactory{func(s *Stack) NetworkProtocol {
proto.stack = s
return proto
}},
- UseNeighborCache: useNeighborCache,
})
// Enable forwarding.
@@ -402,7 +405,7 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborC
}
if l, ok := nic.linkAddrResolvers[fwdTestNetNumber]; ok {
- proto.neighborTable = l.neighborTable
+ proto.neigh = &l.neigh
}
// Route all packets to NIC 2.
@@ -418,121 +421,85 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborC
}
func TestForwardingWithStaticResolver(t *testing.T) {
- tests := []struct {
- name string
- useNeighborCache bool
- }{
- {
- name: "linkAddrCache",
- useNeighborCache: false,
- },
- {
- name: "neighborCache",
- useNeighborCache: true,
+ // Create a network protocol with a static resolver.
+ proto := &fwdTestNetworkProtocol{
+ onResolveStaticAddress:
+ // The network address 3 is resolved to the link address "c".
+ func(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ if addr == "\x03" {
+ return "c", true
+ }
+ return "", false
},
}
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- // Create a network protocol with a static resolver.
- proto := &fwdTestNetworkProtocol{
- onResolveStaticAddress:
- // The network address 3 is resolved to the link address "c".
- func(addr tcpip.Address) (tcpip.LinkAddress, bool) {
- if addr == "\x03" {
- return "c", true
- }
- return "", false
- },
- }
-
- ep1, ep2 := fwdTestNetFactory(t, proto, test.useNeighborCache)
+ ep1, ep2 := fwdTestNetFactory(t, proto)
- // Inject an inbound packet to address 3 on NIC 1, and see if it is
- // forwarded to NIC 2.
- buf := buffer.NewView(30)
- buf[dstAddrOffset] = 3
- ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
+ // Inject an inbound packet to address 3 on NIC 1, and see if it is
+ // forwarded to NIC 2.
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
- var p fwdTestPacketInfo
+ var p fwdTestPacketInfo
- select {
- case p = <-ep2.C:
- default:
- t.Fatal("packet not forwarded")
- }
+ select {
+ case p = <-ep2.C:
+ default:
+ t.Fatal("packet not forwarded")
+ }
- // Test that the static address resolution happened correctly.
- if p.RemoteLinkAddress != "c" {
- t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
- }
- if p.LocalLinkAddress != "b" {
- t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
- }
- })
+ // Test that the static address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
}
}
func TestForwardingWithFakeResolver(t *testing.T) {
- tests := []struct {
- name string
- useNeighborCache bool
- }{
- {
- name: "linkAddrCache",
- useNeighborCache: false,
- },
- {
- name: "neighborCache",
- useNeighborCache: true,
+ proto := fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
+ t.Helper()
+ if len(linkAddr) != 0 {
+ t.Fatalf("got linkAddr=%q, want unspecified", linkAddr)
+ }
+ // Any address will be resolved to the link address "c".
+ neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
},
}
+ ep1, ep2 := fwdTestNetFactory(t, &proto)
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- proto := fwdTestNetworkProtocol{
- addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
- t.Helper()
- if len(linkAddr) != 0 {
- t.Fatalf("got linkAddr=%q, want unspecified", linkAddr)
- }
- // Any address will be resolved to the link address "c".
- neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- },
- }
- ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache)
-
- // Inject an inbound packet to address 3 on NIC 1, and see if it is
- // forwarded to NIC 2.
- buf := buffer.NewView(30)
- buf[dstAddrOffset] = 3
- ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- var p fwdTestPacketInfo
-
- select {
- case p = <-ep2.C:
- case <-time.After(time.Second):
- t.Fatal("packet not forwarded")
- }
+ // Inject an inbound packet to address 3 on NIC 1, and see if it is
+ // forwarded to NIC 2.
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
- // Test that the address resolution happened correctly.
- if p.RemoteLinkAddress != "c" {
- t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
- }
- if p.LocalLinkAddress != "b" {
- t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
- }
- })
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
+
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
}
}
@@ -542,7 +509,7 @@ func TestForwardingWithNoResolver(t *testing.T) {
// Whether or not we use the neighbor cache here does not matter since
// neither linkAddrCache nor neighborCache will be used.
- ep1, ep2 := fwdTestNetFactory(t, proto, false /* useNeighborCache */)
+ ep1, ep2 := fwdTestNetFactory(t, proto)
// inject an inbound packet to address 3 on NIC 1, and see if it is
// forwarded to NIC 2.
@@ -562,12 +529,12 @@ func TestForwardingWithNoResolver(t *testing.T) {
func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) {
proto := &fwdTestNetworkProtocol{
addrResolveDelay: 50 * time.Millisecond,
- onLinkAddressResolved: func(neighborTable, tcpip.Address, tcpip.LinkAddress) {
+ onLinkAddressResolved: func(*neighborCache, tcpip.Address, tcpip.LinkAddress) {
// Don't resolve the link address.
},
}
- ep1, ep2 := fwdTestNetFactory(t, proto, true /* useNeighborCache */)
+ ep1, ep2 := fwdTestNetFactory(t, proto)
const numPackets int = 5
// These packets will all be enqueued in the packet queue to wait for link
@@ -592,300 +559,227 @@ func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) {
}
func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
- tests := []struct {
- name string
- useNeighborCache bool
- }{
- {
- name: "linkAddrCache",
- useNeighborCache: false,
- },
- {
- name: "neighborCache",
- useNeighborCache: true,
+ proto := fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
+ t.Helper()
+ if len(linkAddr) != 0 {
+ t.Fatalf("got linkAddr=%q, want unspecified", linkAddr)
+ }
+ // Only packets to address 3 will be resolved to the
+ // link address "c".
+ if addr == "\x03" {
+ neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ }
},
}
+ ep1, ep2 := fwdTestNetFactory(t, &proto)
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- proto := fwdTestNetworkProtocol{
- addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
- t.Helper()
- if len(linkAddr) != 0 {
- t.Fatalf("got linkAddr=%q, want unspecified", linkAddr)
- }
- // Only packets to address 3 will be resolved to the
- // link address "c".
- if addr == "\x03" {
- neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- }
- },
- }
- ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache)
-
- // Inject an inbound packet to address 4 on NIC 1. This packet should
- // not be forwarded.
- buf := buffer.NewView(30)
- buf[dstAddrOffset] = 4
- ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- // Inject an inbound packet to address 3 on NIC 1, and see if it is
- // forwarded to NIC 2.
- buf = buffer.NewView(30)
- buf[dstAddrOffset] = 3
- ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- var p fwdTestPacketInfo
-
- select {
- case p = <-ep2.C:
- case <-time.After(time.Second):
- t.Fatal("packet not forwarded")
- }
+ // Inject an inbound packet to address 4 on NIC 1. This packet should
+ // not be forwarded.
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = 4
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
- if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 {
- t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset])
- }
+ // Inject an inbound packet to address 3 on NIC 1, and see if it is
+ // forwarded to NIC 2.
+ buf = buffer.NewView(30)
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
- // Test that the address resolution happened correctly.
- if p.RemoteLinkAddress != "c" {
- t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
- }
- if p.LocalLinkAddress != "b" {
- t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
- }
- })
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
+
+ if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 {
+ t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset])
+ }
+
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
}
}
func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
- tests := []struct {
- name string
- useNeighborCache bool
- }{
- {
- name: "linkAddrCache",
- useNeighborCache: false,
- },
- {
- name: "neighborCache",
- useNeighborCache: true,
+ proto := fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
+ t.Helper()
+ if len(linkAddr) != 0 {
+ t.Fatalf("got linkAddr=%q, want unspecified", linkAddr)
+ }
+ // Any packets will be resolved to the link address "c".
+ neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
},
}
+ ep1, ep2 := fwdTestNetFactory(t, &proto)
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- proto := fwdTestNetworkProtocol{
- addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
- t.Helper()
- if len(linkAddr) != 0 {
- t.Fatalf("got linkAddr=%q, want unspecified", linkAddr)
- }
- // Any packets will be resolved to the link address "c".
- neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- },
- }
- ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache)
-
- // Inject two inbound packets to address 3 on NIC 1.
- for i := 0; i < 2; i++ {
- buf := buffer.NewView(30)
- buf[dstAddrOffset] = 3
- ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
- }
+ // Inject two inbound packets to address 3 on NIC 1.
+ for i := 0; i < 2; i++ {
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+ }
- for i := 0; i < 2; i++ {
- var p fwdTestPacketInfo
-
- select {
- case p = <-ep2.C:
- case <-time.After(time.Second):
- t.Fatal("packet not forwarded")
- }
-
- if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 {
- t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset])
- }
-
- // Test that the address resolution happened correctly.
- if p.RemoteLinkAddress != "c" {
- t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
- }
- if p.LocalLinkAddress != "b" {
- t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
- }
- }
- })
+ for i := 0; i < 2; i++ {
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
+
+ if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 {
+ t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset])
+ }
+
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
}
}
func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
- tests := []struct {
- name string
- useNeighborCache bool
- }{
- {
- name: "linkAddrCache",
- useNeighborCache: false,
- },
- {
- name: "neighborCache",
- useNeighborCache: true,
+ proto := fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
+ t.Helper()
+ if len(linkAddr) != 0 {
+ t.Fatalf("got linkAddr=%q, want unspecified", linkAddr)
+ }
+ // Any packets will be resolved to the link address "c".
+ neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
},
}
+ ep1, ep2 := fwdTestNetFactory(t, &proto)
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- proto := fwdTestNetworkProtocol{
- addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
- t.Helper()
- if len(linkAddr) != 0 {
- t.Fatalf("got linkAddr=%q, want unspecified", linkAddr)
- }
- // Any packets will be resolved to the link address "c".
- neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- },
- }
- ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache)
-
- for i := 0; i < maxPendingPacketsPerResolution+5; i++ {
- // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1.
- buf := buffer.NewView(30)
- buf[dstAddrOffset] = 3
- // Set the packet sequence number.
- binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i))
- ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
- }
+ for i := 0; i < maxPendingPacketsPerResolution+5; i++ {
+ // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1.
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = 3
+ // Set the packet sequence number.
+ binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i))
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+ }
- for i := 0; i < maxPendingPacketsPerResolution; i++ {
- var p fwdTestPacketInfo
-
- select {
- case p = <-ep2.C:
- case <-time.After(time.Second):
- t.Fatal("packet not forwarded")
- }
-
- b := PayloadSince(p.Pkt.NetworkHeader())
- if b[dstAddrOffset] != 3 {
- t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset])
- }
- if len(b) < fwdTestNetHeaderLen+2 {
- t.Fatalf("packet is too short to hold a sequence number: len(b) = %d", b)
- }
- seqNumBuf := b[fwdTestNetHeaderLen:]
-
- // The first 5 packets should not be forwarded so the sequence number should
- // start with 5.
- want := uint16(i + 5)
- if n := binary.BigEndian.Uint16(seqNumBuf); n != want {
- t.Fatalf("got the packet #%d, want = #%d", n, want)
- }
-
- // Test that the address resolution happened correctly.
- if p.RemoteLinkAddress != "c" {
- t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
- }
- if p.LocalLinkAddress != "b" {
- t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
- }
- }
- })
+ for i := 0; i < maxPendingPacketsPerResolution; i++ {
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
+
+ b := PayloadSince(p.Pkt.NetworkHeader())
+ if b[dstAddrOffset] != 3 {
+ t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset])
+ }
+ if len(b) < fwdTestNetHeaderLen+2 {
+ t.Fatalf("packet is too short to hold a sequence number: len(b) = %d", b)
+ }
+ seqNumBuf := b[fwdTestNetHeaderLen:]
+
+ // The first 5 packets should not be forwarded so the sequence number should
+ // start with 5.
+ want := uint16(i + 5)
+ if n := binary.BigEndian.Uint16(seqNumBuf); n != want {
+ t.Fatalf("got the packet #%d, want = #%d", n, want)
+ }
+
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
}
}
func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
- tests := []struct {
- name string
- useNeighborCache bool
- proto *fwdTestNetworkProtocol
- }{
- {
- name: "linkAddrCache",
- useNeighborCache: false,
- },
- {
- name: "neighborCache",
- useNeighborCache: true,
+ proto := fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
+ t.Helper()
+ if len(linkAddr) != 0 {
+ t.Fatalf("got linkAddr=%q, want unspecified", linkAddr)
+ }
+ // Any packets will be resolved to the link address "c".
+ neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
},
}
+ ep1, ep2 := fwdTestNetFactory(t, &proto)
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- proto := fwdTestNetworkProtocol{
- addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
- t.Helper()
- if len(linkAddr) != 0 {
- t.Fatalf("got linkAddr=%q, want unspecified", linkAddr)
- }
- // Any packets will be resolved to the link address "c".
- neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- },
- }
- ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache)
-
- for i := 0; i < maxPendingResolutions+5; i++ {
- // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1.
- // Each packet has a different destination address (3 to
- // maxPendingResolutions + 7).
- buf := buffer.NewView(30)
- buf[dstAddrOffset] = byte(3 + i)
- ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
- }
+ for i := 0; i < maxPendingResolutions+5; i++ {
+ // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1.
+ // Each packet has a different destination address (3 to
+ // maxPendingResolutions + 7).
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = byte(3 + i)
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+ }
- for i := 0; i < maxPendingResolutions; i++ {
- var p fwdTestPacketInfo
-
- select {
- case p = <-ep2.C:
- case <-time.After(time.Second):
- t.Fatal("packet not forwarded")
- }
-
- // The first 5 packets (address 3 to 7) should not be forwarded
- // because their address resolutions are interrupted.
- if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] < 8 {
- t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", nh[dstAddrOffset])
- }
-
- // Test that the address resolution happened correctly.
- if p.RemoteLinkAddress != "c" {
- t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
- }
- if p.LocalLinkAddress != "b" {
- t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
- }
- }
- })
+ for i := 0; i < maxPendingResolutions; i++ {
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
+
+ // The first 5 packets (address 3 to 7) should not be forwarded
+ // because their address resolutions are interrupted.
+ if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] < 8 {
+ t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", nh[dstAddrOffset])
+ }
+
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
}
}
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 63832c200..52890f6eb 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -235,7 +235,7 @@ func (it *IPTables) ReplaceTable(id TableID, table Table, ipv6 bool) tcpip.Error
// If iptables is being enabled, initialize the conntrack table and
// reaper.
if !it.modified {
- it.connections.buckets = make([]bucket, numBuckets)
+ it.connections.init()
it.startReaper(reaperDelay)
}
it.modified = true
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go
deleted file mode 100644
index 5b6b58b1d..000000000
--- a/pkg/tcpip/stack/linkaddrcache.go
+++ /dev/null
@@ -1,359 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package stack
-
-import (
- "fmt"
- "time"
-
- "gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/tcpip"
-)
-
-const linkAddrCacheSize = 512 // max cache entries
-
-// linkAddrCache is a fixed-sized cache mapping IP addresses to link addresses.
-//
-// The entries are stored in a ring buffer, oldest entry replaced first.
-//
-// This struct is safe for concurrent use.
-type linkAddrCache struct {
- nic *NIC
-
- linkRes LinkAddressResolver
-
- // ageLimit is how long a cache entry is valid for.
- ageLimit time.Duration
-
- // resolutionTimeout is the amount of time to wait for a link request to
- // resolve an address.
- resolutionTimeout time.Duration
-
- // resolutionAttempts is the number of times an address is attempted to be
- // resolved before failing.
- resolutionAttempts int
-
- mu struct {
- sync.Mutex
- table map[tcpip.Address]*linkAddrEntry
- lru linkAddrEntryList
- }
-}
-
-// entryState controls the state of a single entry in the cache.
-type entryState int
-
-const (
- // incomplete means that there is an outstanding request to resolve the
- // address. This is the initial state.
- incomplete entryState = iota
- // ready means that the address has been resolved and can be used.
- ready
-)
-
-// String implements Stringer.
-func (s entryState) String() string {
- switch s {
- case incomplete:
- return "incomplete"
- case ready:
- return "ready"
- default:
- return fmt.Sprintf("unknown(%d)", s)
- }
-}
-
-// A linkAddrEntry is an entry in the linkAddrCache.
-// This struct is thread-compatible.
-type linkAddrEntry struct {
- // linkAddrEntryEntry access is synchronized by the linkAddrCache lock.
- linkAddrEntryEntry
-
- cache *linkAddrCache
-
- mu struct {
- sync.RWMutex
-
- addr tcpip.Address
- linkAddr tcpip.LinkAddress
- expiration time.Time
- s entryState
-
- // done is closed when address resolution is complete. It is nil iff s is
- // incomplete and resolution is not yet in progress.
- done chan struct{}
-
- // onResolve is called with the result of address resolution.
- onResolve []func(LinkResolutionResult)
- }
-}
-
-func (e *linkAddrEntry) notifyCompletionLocked(linkAddr tcpip.LinkAddress) {
- res := LinkResolutionResult{LinkAddress: linkAddr, Success: len(linkAddr) != 0}
- for _, callback := range e.mu.onResolve {
- callback(res)
- }
- e.mu.onResolve = nil
- if ch := e.mu.done; ch != nil {
- close(ch)
- e.mu.done = nil
- // Dequeue the pending packets in a new goroutine to not hold up the current
- // goroutine as writing packets may be a costly operation.
- //
- // At the time of writing, when writing packets, a neighbor's link address
- // is resolved (which ends up obtaining the entry's lock) while holding the
- // link resolution queue's lock. Dequeuing packets in a new goroutine avoids
- // a lock ordering violation.
- go e.cache.nic.linkResQueue.dequeue(ch, linkAddr, len(linkAddr) != 0)
- }
-}
-
-// changeStateLocked sets the entry's state to ns.
-//
-// The entry's expiration is bumped up to the greater of itself and the passed
-// expiration; the zero value indicates immediate expiration, and is set
-// unconditionally - this is an implementation detail that allows for entries
-// to be reused.
-//
-// Precondition: e.mu must be locked
-func (e *linkAddrEntry) changeStateLocked(ns entryState, expiration time.Time) {
- if e.mu.s == incomplete && ns == ready {
- e.notifyCompletionLocked(e.mu.linkAddr)
- }
-
- if expiration.IsZero() || expiration.After(e.mu.expiration) {
- e.mu.expiration = expiration
- }
- e.mu.s = ns
-}
-
-// add adds a k -> v mapping to the cache.
-func (c *linkAddrCache) add(k tcpip.Address, v tcpip.LinkAddress) {
- // Calculate expiration time before acquiring the lock, since expiration is
- // relative to the time when information was learned, rather than when it
- // happened to be inserted into the cache.
- expiration := time.Now().Add(c.ageLimit)
-
- c.mu.Lock()
- entry := c.getOrCreateEntryLocked(k)
- entry.mu.Lock()
- defer entry.mu.Unlock()
- c.mu.Unlock()
-
- entry.mu.linkAddr = v
- entry.changeStateLocked(ready, expiration)
-}
-
-// getOrCreateEntryLocked retrieves a cache entry associated with k. The
-// returned entry is always refreshed in the cache (it is reachable via the
-// map, and its place is bumped in LRU).
-//
-// If a matching entry exists in the cache, it is returned. If no matching
-// entry exists and the cache is full, an existing entry is evicted via LRU,
-// reset to state incomplete, and returned. If no matching entry exists and the
-// cache is not full, a new entry with state incomplete is allocated and
-// returned.
-func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.Address) *linkAddrEntry {
- if entry, ok := c.mu.table[k]; ok {
- c.mu.lru.Remove(entry)
- c.mu.lru.PushFront(entry)
- return entry
- }
- var entry *linkAddrEntry
- if len(c.mu.table) == linkAddrCacheSize {
- entry = c.mu.lru.Back()
- entry.mu.Lock()
-
- delete(c.mu.table, entry.mu.addr)
- c.mu.lru.Remove(entry)
-
- // Wake waiters and mark the soon-to-be-reused entry as expired.
- entry.notifyCompletionLocked("" /* linkAddr */)
- entry.mu.Unlock()
- } else {
- entry = new(linkAddrEntry)
- }
-
- *entry = linkAddrEntry{
- cache: c,
- }
- entry.mu.Lock()
- entry.mu.addr = k
- entry.mu.s = incomplete
- entry.mu.Unlock()
- c.mu.table[k] = entry
- c.mu.lru.PushFront(entry)
- return entry
-}
-
-// get reports any known link address for addr.
-func (c *linkAddrCache) get(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) {
- c.mu.Lock()
- entry := c.getOrCreateEntryLocked(addr)
- entry.mu.Lock()
- defer entry.mu.Unlock()
- c.mu.Unlock()
-
- switch s := entry.mu.s; s {
- case ready:
- if !time.Now().After(entry.mu.expiration) {
- // Not expired.
- if onResolve != nil {
- onResolve(LinkResolutionResult{LinkAddress: entry.mu.linkAddr, Success: true})
- }
- return entry.mu.linkAddr, nil, nil
- }
-
- entry.changeStateLocked(incomplete, time.Time{})
- fallthrough
- case incomplete:
- if onResolve != nil {
- entry.mu.onResolve = append(entry.mu.onResolve, onResolve)
- }
- if entry.mu.done == nil {
- entry.mu.done = make(chan struct{})
- go c.startAddressResolution(addr, localAddr, entry.mu.done) // S/R-SAFE: link non-savable; wakers dropped synchronously.
- }
- return entry.mu.linkAddr, entry.mu.done, &tcpip.ErrWouldBlock{}
- default:
- panic(fmt.Sprintf("invalid cache entry state: %s", s))
- }
-}
-
-func (c *linkAddrCache) startAddressResolution(k tcpip.Address, localAddr tcpip.Address, done <-chan struct{}) {
- for i := 0; ; i++ {
- // Send link request, then wait for the timeout limit and check
- // whether the request succeeded.
- c.linkRes.LinkAddressRequest(k, localAddr, "" /* linkAddr */)
-
- select {
- case now := <-time.After(c.resolutionTimeout):
- if stop := c.checkLinkRequest(now, k, i); stop {
- return
- }
- case <-done:
- return
- }
- }
-}
-
-// checkLinkRequest checks whether previous attempt to resolve address has
-// succeeded and mark the entry accordingly. Returns true if request can stop,
-// false if another request should be sent.
-func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.Address, attempt int) bool {
- c.mu.Lock()
- defer c.mu.Unlock()
- entry, ok := c.mu.table[k]
- if !ok {
- // Entry was evicted from the cache.
- return true
- }
- entry.mu.Lock()
- defer entry.mu.Unlock()
-
- switch s := entry.mu.s; s {
- case ready:
- // Entry was made ready by resolver.
- case incomplete:
- if attempt+1 < c.resolutionAttempts {
- // No response yet, need to send another ARP request.
- return false
- }
- // Max number of retries reached, delete entry.
- entry.notifyCompletionLocked("" /* linkAddr */)
- delete(c.mu.table, k)
- default:
- panic(fmt.Sprintf("invalid cache entry state: %s", s))
- }
- return true
-}
-
-func (c *linkAddrCache) init(nic *NIC, ageLimit, resolutionTimeout time.Duration, resolutionAttempts int, linkRes LinkAddressResolver) {
- *c = linkAddrCache{
- nic: nic,
- linkRes: linkRes,
- ageLimit: ageLimit,
- resolutionTimeout: resolutionTimeout,
- resolutionAttempts: resolutionAttempts,
- }
-
- c.mu.Lock()
- c.mu.table = make(map[tcpip.Address]*linkAddrEntry, linkAddrCacheSize)
- c.mu.Unlock()
-}
-
-var _ neighborTable = (*linkAddrCache)(nil)
-
-func (*linkAddrCache) neighbors() ([]NeighborEntry, tcpip.Error) {
- return nil, &tcpip.ErrNotSupported{}
-}
-
-func (c *linkAddrCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAddress) {
- c.add(addr, linkAddr)
-}
-
-func (*linkAddrCache) remove(addr tcpip.Address) tcpip.Error {
- return &tcpip.ErrNotSupported{}
-}
-
-func (*linkAddrCache) removeAll() tcpip.Error {
- return &tcpip.ErrNotSupported{}
-}
-
-func (c *linkAddrCache) handleProbe(addr tcpip.Address, linkAddr tcpip.LinkAddress) {
- if len(linkAddr) != 0 {
- // NUD allows probes without a link address but linkAddrCache
- // is a simple neighbor table which does not implement NUD.
- //
- // As per RFC 4861 section 4.3,
- //
- // Source link-layer address
- // The link-layer address for the sender. MUST NOT be
- // included when the source IP address is the
- // unspecified address. Otherwise, on link layers
- // that have addresses this option MUST be included in
- // multicast solicitations and SHOULD be included in
- // unicast solicitations.
- c.add(addr, linkAddr)
- }
-}
-
-func (c *linkAddrCache) handleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) {
- if len(linkAddr) != 0 {
- // NUD allows confirmations without a link address but linkAddrCache
- // is a simple neighbor table which does not implement NUD.
- //
- // As per RFC 4861 section 4.4,
- //
- // Target link-layer address
- // The link-layer address for the target, i.e., the
- // sender of the advertisement. This option MUST be
- // included on link layers that have addresses when
- // responding to multicast solicitations. When
- // responding to a unicast Neighbor Solicitation this
- // option SHOULD be included.
- c.add(addr, linkAddr)
- }
-}
-
-func (c *linkAddrCache) handleUpperLevelConfirmation(tcpip.Address) {}
-
-func (*linkAddrCache) nudConfig() (NUDConfigurations, tcpip.Error) {
- return NUDConfigurations{}, &tcpip.ErrNotSupported{}
-}
-
-func (*linkAddrCache) setNUDConfig(NUDConfigurations) tcpip.Error {
- return &tcpip.ErrNotSupported{}
-}
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
deleted file mode 100644
index 9e7f331c9..000000000
--- a/pkg/tcpip/stack/linkaddrcache_test.go
+++ /dev/null
@@ -1,291 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package stack
-
-import (
- "fmt"
- "math"
- "sync/atomic"
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/tcpip"
-)
-
-type testaddr struct {
- addr tcpip.Address
- linkAddr tcpip.LinkAddress
-}
-
-var testAddrs = func() []testaddr {
- var addrs []testaddr
- for i := 0; i < 4*linkAddrCacheSize; i++ {
- addr := fmt.Sprintf("Addr%06d", i)
- addrs = append(addrs, testaddr{
- addr: tcpip.Address(addr),
- linkAddr: tcpip.LinkAddress("Link" + addr),
- })
- }
- return addrs
-}()
-
-type testLinkAddressResolver struct {
- cache *linkAddrCache
- delay time.Duration
- onLinkAddressRequest func()
-}
-
-func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress) tcpip.Error {
- // TODO(gvisor.dev/issue/5141): Use a fake clock.
- time.AfterFunc(r.delay, func() { r.fakeRequest(targetAddr) })
- if f := r.onLinkAddressRequest; f != nil {
- f()
- }
- return nil
-}
-
-func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) {
- for _, ta := range testAddrs {
- if ta.addr == addr {
- r.cache.add(ta.addr, ta.linkAddr)
- break
- }
- }
-}
-
-func (*testLinkAddressResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
- if addr == "broadcast" {
- return "mac_broadcast", true
- }
- return "", false
-}
-
-func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
- return 1
-}
-
-func getBlocking(c *linkAddrCache, addr tcpip.Address) (tcpip.LinkAddress, tcpip.Error) {
- var attemptedResolution bool
- for {
- got, ch, err := c.get(addr, "", nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); ok {
- if attemptedResolution {
- return got, &tcpip.ErrTimeout{}
- }
- attemptedResolution = true
- <-ch
- continue
- }
- return got, err
- }
-}
-
-func newEmptyNIC() *NIC {
- n := &NIC{}
- n.linkResQueue.init(n)
- return n
-}
-
-func TestCacheOverflow(t *testing.T) {
- var c linkAddrCache
- c.init(newEmptyNIC(), 1<<63-1, 1*time.Second, 3, nil)
- for i := len(testAddrs) - 1; i >= 0; i-- {
- e := testAddrs[i]
- c.add(e.addr, e.linkAddr)
- got, _, err := c.get(e.addr, "", nil)
- if err != nil {
- t.Errorf("insert %d, c.get(%s, '', nil): %s", i, e.addr, err)
- }
- if got != e.linkAddr {
- t.Errorf("insert %d, got c.get(%s, '', nil) = %s, want = %s", i, e.addr, got, e.linkAddr)
- }
- }
- // Expect to find at least half of the most recent entries.
- for i := 0; i < linkAddrCacheSize/2; i++ {
- e := testAddrs[i]
- got, _, err := c.get(e.addr, "", nil)
- if err != nil {
- t.Errorf("check %d, c.get(%s, '', nil): %s", i, e.addr, err)
- }
- if got != e.linkAddr {
- t.Errorf("check %d, got c.get(%s, '', nil) = %s, want = %s", i, e.addr, got, e.linkAddr)
- }
- }
- // The earliest entries should no longer be in the cache.
- c.mu.Lock()
- defer c.mu.Unlock()
- for i := len(testAddrs) - 1; i >= len(testAddrs)-linkAddrCacheSize; i-- {
- e := testAddrs[i]
- if entry, ok := c.mu.table[e.addr]; ok {
- t.Errorf("unexpected entry at c.mu.table[%s]: %#v", e.addr, entry)
- }
- }
-}
-
-func TestCacheConcurrent(t *testing.T) {
- var c linkAddrCache
- linkRes := &testLinkAddressResolver{cache: &c}
- c.init(newEmptyNIC(), 1<<63-1, 1*time.Second, 3, linkRes)
-
- var wg sync.WaitGroup
- for r := 0; r < 16; r++ {
- wg.Add(1)
- go func() {
- for _, e := range testAddrs {
- c.add(e.addr, e.linkAddr)
- }
- wg.Done()
- }()
- }
- wg.Wait()
-
- // All goroutines add in the same order and add more values than
- // can fit in the cache, so our eviction strategy requires that
- // the last entry be present and the first be missing.
- e := testAddrs[len(testAddrs)-1]
- got, _, err := c.get(e.addr, "", nil)
- if err != nil {
- t.Errorf("c.get(%s, '', nil): %s", e.addr, err)
- }
- if got != e.linkAddr {
- t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, e.linkAddr)
- }
-
- e = testAddrs[0]
- c.mu.Lock()
- defer c.mu.Unlock()
- if entry, ok := c.mu.table[e.addr]; ok {
- t.Errorf("unexpected entry at c.mu.table[%s]: %#v", e.addr, entry)
- }
-}
-
-func TestCacheAgeLimit(t *testing.T) {
- var c linkAddrCache
- linkRes := &testLinkAddressResolver{cache: &c}
- c.init(newEmptyNIC(), 1*time.Millisecond, 1*time.Second, 3, linkRes)
-
- e := testAddrs[0]
- c.add(e.addr, e.linkAddr)
- time.Sleep(50 * time.Millisecond)
- _, _, err := c.get(e.addr, "", nil)
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Errorf("got c.get(%s, '', nil) = %s, want = ErrWouldBlock", e.addr, err)
- }
-}
-
-func TestCacheReplace(t *testing.T) {
- var c linkAddrCache
- c.init(newEmptyNIC(), 1<<63-1, 1*time.Second, 3, nil)
- e := testAddrs[0]
- l2 := e.linkAddr + "2"
- c.add(e.addr, e.linkAddr)
- got, _, err := c.get(e.addr, "", nil)
- if err != nil {
- t.Errorf("c.get(%s, '', nil): %s", e.addr, err)
- }
- if got != e.linkAddr {
- t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, e.linkAddr)
- }
-
- c.add(e.addr, l2)
- got, _, err = c.get(e.addr, "", nil)
- if err != nil {
- t.Errorf("c.get(%s, '', nil): %s", e.addr, err)
- }
- if got != l2 {
- t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, l2)
- }
-}
-
-func TestCacheResolution(t *testing.T) {
- // There is a race condition causing this test to fail when the executor
- // takes longer than the resolution timeout to call linkAddrCache.get. This
- // is especially common when this test is run with gotsan.
- //
- // Using a large resolution timeout decreases the probability of experiencing
- // this race condition and does not affect how long this test takes to run.
- var c linkAddrCache
- linkRes := &testLinkAddressResolver{cache: &c}
- c.init(newEmptyNIC(), 1<<63-1, math.MaxInt64, 1, linkRes)
- for i, ta := range testAddrs {
- got, err := getBlocking(&c, ta.addr)
- if err != nil {
- t.Errorf("check %d, getBlocking(_, %s): %s", i, ta.addr, err)
- }
- if got != ta.linkAddr {
- t.Errorf("check %d, got getBlocking(_, %s) = %s, want = %s", i, ta.addr, got, ta.linkAddr)
- }
- }
-
- // Check that after resolved, address stays in the cache and never returns WouldBlock.
- for i := 0; i < 10; i++ {
- e := testAddrs[len(testAddrs)-1]
- got, _, err := c.get(e.addr, "", nil)
- if err != nil {
- t.Errorf("c.get(%s, '', nil): %s", e.addr, err)
- }
- if got != e.linkAddr {
- t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, e.linkAddr)
- }
- }
-}
-
-func TestCacheResolutionFailed(t *testing.T) {
- var c linkAddrCache
- linkRes := &testLinkAddressResolver{cache: &c}
- c.init(newEmptyNIC(), 1<<63-1, 10*time.Millisecond, 5, linkRes)
-
- var requestCount uint32
- linkRes.onLinkAddressRequest = func() {
- atomic.AddUint32(&requestCount, 1)
- }
-
- // First, sanity check that resolution is working...
- e := testAddrs[0]
- got, err := getBlocking(&c, e.addr)
- if err != nil {
- t.Errorf("getBlocking(_, %s): %s", e.addr, err)
- }
- if got != e.linkAddr {
- t.Errorf("got getBlocking(_, %s) = %s, want = %s", e.addr, got, e.linkAddr)
- }
-
- before := atomic.LoadUint32(&requestCount)
-
- e.addr += "2"
- a, err := getBlocking(&c, e.addr)
- if _, ok := err.(*tcpip.ErrTimeout); !ok {
- t.Errorf("got getBlocking(_, %s) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{})
- }
-
- if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want {
- t.Errorf("got link address request count = %d, want = %d", got, want)
- }
-}
-
-func TestCacheResolutionTimeout(t *testing.T) {
- resolverDelay := 500 * time.Millisecond
- expiration := resolverDelay / 10
- var c linkAddrCache
- linkRes := &testLinkAddressResolver{cache: &c, delay: resolverDelay}
- c.init(newEmptyNIC(), expiration, 1*time.Millisecond, 3, linkRes)
-
- e := testAddrs[0]
- a, err := getBlocking(&c, e.addr)
- if _, ok := err.(*tcpip.ErrTimeout); !ok {
- t.Errorf("got getBlocking(_, %s) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{})
- }
-}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 0238605af..3b6ba9509 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -424,7 +424,7 @@ func TestDADResolve(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPDisp: &ndpDisp,
- NDPConfigs: ipv6.NDPConfigurations{
+ DADConfigs: stack.DADConfigurations{
RetransmitTimer: test.retransTimer,
DupAddrDetectTransmits: test.dupAddrDetectTransmits,
},
@@ -642,14 +642,14 @@ func TestDADFail(t *testing.T) {
ndpDisp := ndpDispatcher{
dadC: make(chan ndpDADEvent, 1),
}
- ndpConfigs := ipv6.DefaultNDPConfigurations()
- ndpConfigs.RetransmitTimer = time.Second * 2
+ dadConfigs := stack.DefaultDADConfigurations()
+ dadConfigs.RetransmitTimer = time.Second * 2
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPDisp: &ndpDisp,
- NDPConfigs: ndpConfigs,
+ DADConfigs: dadConfigs,
})},
})
if err := s.CreateNIC(nicID, e); err != nil {
@@ -677,7 +677,7 @@ func TestDADFail(t *testing.T) {
// Wait for DAD to fail and make sure the address did
// not get resolved.
select {
- case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
+ case <-time.After(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer + time.Second):
// If we don't get a failure event after the
// expected resolution time + extra 1s buffer,
// something is wrong.
@@ -748,7 +748,7 @@ func TestDADStop(t *testing.T) {
dadC: make(chan ndpDADEvent, 1),
}
- ndpConfigs := ipv6.NDPConfigurations{
+ dadConfigs := stack.DADConfigurations{
RetransmitTimer: time.Second,
DupAddrDetectTransmits: 2,
}
@@ -757,7 +757,7 @@ func TestDADStop(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPDisp: &ndpDisp,
- NDPConfigs: ndpConfigs,
+ DADConfigs: dadConfigs,
})},
})
if err := s.CreateNIC(nicID, e); err != nil {
@@ -777,12 +777,12 @@ func TestDADStop(t *testing.T) {
// Wait for DAD to fail (since the address was removed during DAD).
select {
- case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
+ case <-time.After(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer + time.Second):
// If we don't get a failure event after the expected resolution
// time + extra 1s buffer, something is wrong.
t.Fatal("timed out waiting for DAD failure")
case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" {
+ if diff := checkDADEvent(e, nicID, addr1, false, &tcpip.ErrAborted{}); diff != "" {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
}
@@ -865,16 +865,15 @@ func TestSetNDPConfigurations(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err)
}
- // Update the NDP configurations on NIC(1) to use DAD.
- configs := ipv6.NDPConfigurations{
- DupAddrDetectTransmits: test.dupAddrDetectTransmits,
- RetransmitTimer: test.retransmitTimer,
- }
+ // Update the configurations on NIC(1) to use DAD.
if ipv6Ep, err := s.GetNetworkEndpoint(nicID1, header.IPv6ProtocolNumber); err != nil {
t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, header.IPv6ProtocolNumber, err)
} else {
- ndpEP := ipv6Ep.(ipv6.NDPEndpoint)
- ndpEP.SetNDPConfigurations(configs)
+ dad := ipv6Ep.(stack.DuplicateAddressDetector)
+ dad.SetDADConfigurations(stack.DADConfigurations{
+ DupAddrDetectTransmits: test.dupAddrDetectTransmits,
+ RetransmitTimer: test.retransmitTimer,
+ })
}
// Created after updating NIC(1)'s NDP configurations
@@ -1903,9 +1902,11 @@ func TestAutoGenTempAddr(t *testing.T) {
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ DADConfigs: stack.DADConfigurations{
+ DupAddrDetectTransmits: test.dupAddrTransmits,
+ RetransmitTimer: test.retransmitTimer,
+ },
NDPConfigs: ipv6.NDPConfigurations{
- DupAddrDetectTransmits: test.dupAddrTransmits,
- RetransmitTimer: test.retransmitTimer,
HandleRAs: true,
AutoGenGlobalAddresses: true,
AutoGenTempGlobalAddresses: true,
@@ -2202,9 +2203,11 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ DADConfigs: stack.DADConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ },
NDPConfigs: ipv6.NDPConfigurations{
- DupAddrDetectTransmits: dadTransmits,
- RetransmitTimer: retransmitTimer,
HandleRAs: true,
AutoGenGlobalAddresses: true,
AutoGenTempGlobalAddresses: true,
@@ -2635,16 +2638,15 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
}
e := channel.New(0, 1280, linkAddr1)
- ndpConfigs := ipv6.NDPConfigurations{
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- AutoGenTempGlobalAddresses: test.tempAddrs,
- AutoGenAddressConflictRetries: 1,
- }
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ndpConfigs,
- NDPDisp: &ndpDisp,
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenTempGlobalAddresses: test.tempAddrs,
+ AutoGenAddressConflictRetries: 1,
+ },
+ NDPDisp: &ndpDisp,
OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{
NICNameFromID: test.nicNameFromID,
},
@@ -2718,13 +2720,14 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
// Enable DAD.
ndpDisp.dadC = make(chan ndpDADEvent, 2)
- ndpConfigs.DupAddrDetectTransmits = dupAddrTransmits
- ndpConfigs.RetransmitTimer = retransmitTimer
if ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber); err != nil {
t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
} else {
- ndpEP := ipv6Ep.(ipv6.NDPEndpoint)
- ndpEP.SetNDPConfigurations(ndpConfigs)
+ ndpEP := ipv6Ep.(stack.DuplicateAddressDetector)
+ ndpEP.SetDADConfigurations(stack.DADConfigurations{
+ DupAddrDetectTransmits: dupAddrTransmits,
+ RetransmitTimer: retransmitTimer,
+ })
}
// Do SLAAC for prefix.
@@ -2769,7 +2772,7 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
// stack.Stack will have a default route through the router (llAddr3) installed
// and a static link-address (linkAddr3) added to the link address cache for the
// router.
-func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useNeighborCache bool) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) {
+func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) {
t.Helper()
ndpDisp := &ndpDispatcher{
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
@@ -2785,7 +2788,6 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useN
NDPDisp: ndpDisp,
})},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- UseNeighborCache: useNeighborCache,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -2868,126 +2870,108 @@ func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullA
// TestAutoGenAddrDeprecateFromPI tests deprecating a SLAAC address when
// receiving a PI with 0 preferred lifetime.
func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
- stacks := []struct {
- name string
- useNeighborCache bool
- }{
- {
- name: "linkAddrCache",
- useNeighborCache: false,
- },
- {
- name: "neighborCache",
- useNeighborCache: true,
- },
- }
-
- for _, stackTyp := range stacks {
- t.Run(stackTyp.name, func(t *testing.T) {
- const nicID = 1
+ const nicID = 1
- prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
- prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
+ prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
+ prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
- ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID, stackTyp.useNeighborCache)
+ ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
- expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
- }
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
- expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
- t.Helper()
+ expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
+ t.Helper()
- if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil {
- t.Fatal(err)
- }
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil {
+ t.Fatal(err)
+ }
- if got := addrForNewConnection(t, s); got != addr.Address {
- t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
- }
- }
+ if got := addrForNewConnection(t, s); got != addr.Address {
+ t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
+ }
+ }
- // Receive PI for prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
- expectAutoGenAddrEvent(addr1, newAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should have %s in the list of addresses", addr1)
- }
- expectPrimaryAddr(addr1)
+ // Receive PI for prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should have %s in the list of addresses", addr1)
+ }
+ expectPrimaryAddr(addr1)
- // Deprecate addr for prefix1 immedaitely.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
- expectAutoGenAddrEvent(addr1, deprecatedAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should have %s in the list of addresses", addr1)
- }
- // addr should still be the primary endpoint as there are no other addresses.
- expectPrimaryAddr(addr1)
+ // Deprecate addr for prefix1 immedaitely.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr1, deprecatedAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should have %s in the list of addresses", addr1)
+ }
+ // addr should still be the primary endpoint as there are no other addresses.
+ expectPrimaryAddr(addr1)
- // Refresh lifetimes of addr generated from prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
- expectPrimaryAddr(addr1)
+ // Refresh lifetimes of addr generated from prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr1)
- // Receive PI for prefix2.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
- expectAutoGenAddrEvent(addr2, newAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- expectPrimaryAddr(addr2)
+ // Receive PI for prefix2.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ expectPrimaryAddr(addr2)
- // Deprecate addr for prefix2 immedaitely.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
- expectAutoGenAddrEvent(addr2, deprecatedAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- // addr1 should be the primary endpoint now since addr2 is deprecated but
- // addr1 is not.
- expectPrimaryAddr(addr1)
- // addr2 is deprecated but if explicitly requested, it should be used.
- fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID}
- if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address)
- }
+ // Deprecate addr for prefix2 immedaitely.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr2, deprecatedAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ // addr1 should be the primary endpoint now since addr2 is deprecated but
+ // addr1 is not.
+ expectPrimaryAddr(addr1)
+ // addr2 is deprecated but if explicitly requested, it should be used.
+ fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID}
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address)
+ }
- // Another PI w/ 0 preferred lifetime should not result in a deprecation
- // event.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
- expectPrimaryAddr(addr1)
- if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address)
- }
+ // Another PI w/ 0 preferred lifetime should not result in a deprecation
+ // event.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr1)
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address)
+ }
- // Refresh lifetimes of addr generated from prefix2.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
- expectPrimaryAddr(addr2)
- })
+ // Refresh lifetimes of addr generated from prefix2.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
}
+ expectPrimaryAddr(addr2)
}
// TestAutoGenAddrJobDeprecation tests that an address is properly deprecated
@@ -2997,232 +2981,214 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
const newMinVL = 2
newMinVLDuration := newMinVL * time.Second
- stacks := []struct {
- name string
- useNeighborCache bool
- }{
- {
- name: "linkAddrCache",
- useNeighborCache: false,
- },
- {
- name: "neighborCache",
- useNeighborCache: true,
- },
- }
-
- for _, stackTyp := range stacks {
- t.Run(stackTyp.name, func(t *testing.T) {
- saved := ipv6.MinPrefixInformationValidLifetimeForUpdate
- defer func() {
- ipv6.MinPrefixInformationValidLifetimeForUpdate = saved
- }()
- ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
+ saved := ipv6.MinPrefixInformationValidLifetimeForUpdate
+ defer func() {
+ ipv6.MinPrefixInformationValidLifetimeForUpdate = saved
+ }()
+ ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
- prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
- prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
+ prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
+ prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
- ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID, stackTyp.useNeighborCache)
+ ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
- expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
- }
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
- expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) {
- t.Helper()
+ expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) {
+ t.Helper()
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- case <-time.After(timeout):
- t.Fatal("timed out waiting for addr auto gen event")
- }
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
+ case <-time.After(timeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ }
- expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
- t.Helper()
+ expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
+ t.Helper()
- if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil {
- t.Fatal(err)
- }
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil {
+ t.Fatal(err)
+ }
- if got := addrForNewConnection(t, s); got != addr.Address {
- t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
- }
- }
+ if got := addrForNewConnection(t, s); got != addr.Address {
+ t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
+ }
+ }
- // Receive PI for prefix2.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
- expectAutoGenAddrEvent(addr2, newAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- expectPrimaryAddr(addr2)
+ // Receive PI for prefix2.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ expectPrimaryAddr(addr2)
- // Receive a PI for prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90))
- expectAutoGenAddrEvent(addr1, newAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should have %s in the list of addresses", addr1)
- }
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- expectPrimaryAddr(addr1)
+ // Receive a PI for prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should have %s in the list of addresses", addr1)
+ }
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ expectPrimaryAddr(addr1)
- // Refresh lifetime for addr of prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
- expectPrimaryAddr(addr1)
+ // Refresh lifetime for addr of prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr1)
- // Wait for addr of prefix1 to be deprecated.
- expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should not have %s in the list of addresses", addr1)
- }
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- // addr2 should be the primary endpoint now since addr1 is deprecated but
- // addr2 is not.
- expectPrimaryAddr(addr2)
- // addr1 is deprecated but if explicitly requested, it should be used.
- fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID}
- if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
- }
+ // Wait for addr of prefix1 to be deprecated.
+ expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should not have %s in the list of addresses", addr1)
+ }
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ // addr2 should be the primary endpoint now since addr1 is deprecated but
+ // addr2 is not.
+ expectPrimaryAddr(addr2)
+ // addr1 is deprecated but if explicitly requested, it should be used.
+ fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID}
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
+ }
- // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make
- // sure we do not get a deprecation event again.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
- expectPrimaryAddr(addr2)
- if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
- }
+ // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make
+ // sure we do not get a deprecation event again.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr2)
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
+ }
- // Refresh lifetimes for addr of prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
- // addr1 is the primary endpoint again since it is non-deprecated now.
- expectPrimaryAddr(addr1)
+ // Refresh lifetimes for addr of prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ // addr1 is the primary endpoint again since it is non-deprecated now.
+ expectPrimaryAddr(addr1)
- // Wait for addr of prefix1 to be deprecated.
- expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should not have %s in the list of addresses", addr1)
- }
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- // addr2 should be the primary endpoint now since it is not deprecated.
- expectPrimaryAddr(addr2)
- if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
- }
+ // Wait for addr of prefix1 to be deprecated.
+ expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should not have %s in the list of addresses", addr1)
+ }
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ // addr2 should be the primary endpoint now since it is not deprecated.
+ expectPrimaryAddr(addr2)
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
+ }
- // Wait for addr of prefix1 to be invalidated.
- expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout)
- if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should not have %s in the list of addresses", addr1)
- }
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- expectPrimaryAddr(addr2)
+ // Wait for addr of prefix1 to be invalidated.
+ expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout)
+ if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should not have %s in the list of addresses", addr1)
+ }
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ expectPrimaryAddr(addr2)
- // Refresh both lifetimes for addr of prefix2 to the same value.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
+ // Refresh both lifetimes for addr of prefix2 to the same value.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
- // Wait for a deprecation then invalidation events, or just an invalidation
- // event. We need to cover both cases but cannot deterministically hit both
- // cases because the deprecation and invalidation handlers could be handled in
- // either deprecation then invalidation, or invalidation then deprecation
- // (which should be cancelled by the invalidation handler).
+ // Wait for a deprecation then invalidation events, or just an invalidation
+ // event. We need to cover both cases but cannot deterministically hit both
+ // cases because the deprecation and invalidation handlers could be handled in
+ // either deprecation then invalidation, or invalidation then deprecation
+ // (which should be cancelled by the invalidation handler).
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" {
+ // If we get a deprecation event first, we should get an invalidation
+ // event almost immediately after.
select {
case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" {
- // If we get a deprecation event first, we should get an invalidation
- // event almost immediately after.
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- case <-time.After(defaultAsyncPositiveEventTimeout):
- t.Fatal("timed out waiting for addr auto gen event")
- }
- } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" {
- // If we get an invalidation event first, we should not get a deprecation
- // event after.
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- case <-time.After(defaultAsyncNegativeEventTimeout):
- }
- } else {
- t.Fatalf("got unexpected auto-generated event")
+ if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
- if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should not have %s in the list of addresses", addr1)
- }
- if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should not have %s in the list of addresses", addr2)
- }
- // Should not have any primary endpoints.
- if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
- t.Fatal(err)
- }
- wq := waiter.Queue{}
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.EventIn)
- defer wq.EventUnregister(&we)
- defer close(ch)
- ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
+ } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" {
+ // If we get an invalidation event first, we should not get a deprecation
+ // event after.
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ case <-time.After(defaultAsyncNegativeEventTimeout):
}
- defer ep.Close()
- ep.SocketOptions().SetV6Only(true)
+ } else {
+ t.Fatalf("got unexpected auto-generated event")
+ }
+ case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should not have %s in the list of addresses", addr1)
+ }
+ if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should not have %s in the list of addresses", addr2)
+ }
+ // Should not have any primary endpoints.
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
+ }
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+ ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
+ }
+ defer ep.Close()
+ ep.SocketOptions().SetV6Only(true)
- {
- err := ep.Connect(dstAddr)
- if _, ok := err.(*tcpip.ErrNoRoute); !ok {
- t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, &tcpip.ErrNoRoute{})
- }
- }
- })
+ {
+ err := ep.Connect(dstAddr)
+ if _, ok := err.(*tcpip.ErrNoRoute); !ok {
+ t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, &tcpip.ErrNoRoute{})
+ }
}
}
@@ -3543,126 +3509,108 @@ func TestAutoGenAddrRemoval(t *testing.T) {
func TestAutoGenAddrAfterRemoval(t *testing.T) {
const nicID = 1
- stacks := []struct {
- name string
- useNeighborCache bool
- }{
- {
- name: "linkAddrCache",
- useNeighborCache: false,
- },
- {
- name: "neighborCache",
- useNeighborCache: true,
- },
+ prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
+ prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
+ ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
}
- for _, stackTyp := range stacks {
- t.Run(stackTyp.name, func(t *testing.T) {
- prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
- prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
- ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID, stackTyp.useNeighborCache)
+ expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
+ t.Helper()
- expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil {
+ t.Fatal(err)
+ }
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
- }
- }
+ if got := addrForNewConnection(t, s); got != addr.Address {
+ t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
+ }
+ }
- expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
- t.Helper()
+ // Receive a PI to auto-generate addr1 with a large valid and preferred
+ // lifetime.
+ const largeLifetimeSeconds = 999
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix1, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ expectPrimaryAddr(addr1)
- if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil {
- t.Fatal(err)
- }
+ // Add addr2 as a static address.
+ protoAddr2 := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addr2,
+ }
+ if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil {
+ t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err)
+ }
+ // addr2 should be more preferred now since it is at the front of the primary
+ // list.
+ expectPrimaryAddr(addr2)
- if got := addrForNewConnection(t, s); got != addr.Address {
- t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
- }
- }
+ // Get a route using addr2 to increment its reference count then remove it
+ // to leave it in the permanentExpired state.
+ r, err := s.FindRoute(nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, false)
+ if err != nil {
+ t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, err)
+ }
+ defer r.Release()
+ if err := s.RemoveAddress(nicID, addr2.Address); err != nil {
+ t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, addr2.Address, err)
+ }
+ // addr1 should be preferred again since addr2 is in the expired state.
+ expectPrimaryAddr(addr1)
- // Receive a PI to auto-generate addr1 with a large valid and preferred
- // lifetime.
- const largeLifetimeSeconds = 999
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix1, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
- expectAutoGenAddrEvent(addr1, newAddr)
- expectPrimaryAddr(addr1)
-
- // Add addr2 as a static address.
- protoAddr2 := tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: addr2,
- }
- if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil {
- t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err)
- }
- // addr2 should be more preferred now since it is at the front of the primary
- // list.
- expectPrimaryAddr(addr2)
-
- // Get a route using addr2 to increment its reference count then remove it
- // to leave it in the permanentExpired state.
- r, err := s.FindRoute(nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, false)
- if err != nil {
- t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, err)
- }
- defer r.Release()
- if err := s.RemoveAddress(nicID, addr2.Address); err != nil {
- t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, addr2.Address, err)
- }
- // addr1 should be preferred again since addr2 is in the expired state.
- expectPrimaryAddr(addr1)
-
- // Receive a PI to auto-generate addr2 as valid and preferred.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
- expectAutoGenAddrEvent(addr2, newAddr)
- // addr2 should be more preferred now that it is closer to the front of the
- // primary list and not deprecated.
- expectPrimaryAddr(addr2)
-
- // Removing the address should result in an invalidation event immediately.
- // It should still be in the permanentExpired state because r is still held.
- //
- // We remove addr2 here to make sure addr2 was marked as a SLAAC address
- // (it was previously marked as a static address).
- if err := s.RemoveAddress(1, addr2.Address); err != nil {
- t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err)
- }
- expectAutoGenAddrEvent(addr2, invalidatedAddr)
- // addr1 should be more preferred since addr2 is in the expired state.
- expectPrimaryAddr(addr1)
-
- // Receive a PI to auto-generate addr2 as valid and deprecated.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, 0))
- expectAutoGenAddrEvent(addr2, newAddr)
- // addr1 should still be more preferred since addr2 is deprecated, even though
- // it is closer to the front of the primary list.
- expectPrimaryAddr(addr1)
-
- // Receive a PI to refresh addr2's preferred lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto gen addr event")
- default:
- }
- // addr2 should be more preferred now that it is not deprecated.
- expectPrimaryAddr(addr2)
+ // Receive a PI to auto-generate addr2 as valid and preferred.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ // addr2 should be more preferred now that it is closer to the front of the
+ // primary list and not deprecated.
+ expectPrimaryAddr(addr2)
- if err := s.RemoveAddress(1, addr2.Address); err != nil {
- t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err)
- }
- expectAutoGenAddrEvent(addr2, invalidatedAddr)
- expectPrimaryAddr(addr1)
- })
+ // Removing the address should result in an invalidation event immediately.
+ // It should still be in the permanentExpired state because r is still held.
+ //
+ // We remove addr2 here to make sure addr2 was marked as a SLAAC address
+ // (it was previously marked as a static address).
+ if err := s.RemoveAddress(1, addr2.Address); err != nil {
+ t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err)
+ }
+ expectAutoGenAddrEvent(addr2, invalidatedAddr)
+ // addr1 should be more preferred since addr2 is in the expired state.
+ expectPrimaryAddr(addr1)
+
+ // Receive a PI to auto-generate addr2 as valid and deprecated.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, 0))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ // addr1 should still be more preferred since addr2 is deprecated, even though
+ // it is closer to the front of the primary list.
+ expectPrimaryAddr(addr1)
+
+ // Receive a PI to refresh addr2's preferred lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto gen addr event")
+ default:
+ }
+ // addr2 should be more preferred now that it is not deprecated.
+ expectPrimaryAddr(addr2)
+
+ if err := s.RemoveAddress(1, addr2.Address); err != nil {
+ t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err)
}
+ expectAutoGenAddrEvent(addr2, invalidatedAddr)
+ expectPrimaryAddr(addr1)
}
// TestAutoGenAddrStaticConflict tests that if SLAAC generates an address that
@@ -3885,12 +3833,12 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
}
}
- expectDADEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, resolved bool) {
+ expectDADEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, resolved bool, err tcpip.Error) {
t.Helper()
select {
case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr, resolved, nil); diff != "" {
+ if diff := checkDADEvent(e, nicID, addr, resolved, err); diff != "" {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
default:
@@ -3923,8 +3871,6 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
{
name: "Global address",
ndpConfigs: ipv6.NDPConfigurations{
- DupAddrDetectTransmits: dadTransmits,
- RetransmitTimer: retransmitTimer,
HandleRAs: true,
AutoGenGlobalAddresses: true,
},
@@ -3939,11 +3885,8 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
},
},
{
- name: "LinkLocal address",
- ndpConfigs: ipv6.NDPConfigurations{
- DupAddrDetectTransmits: dadTransmits,
- RetransmitTimer: retransmitTimer,
- },
+ name: "LinkLocal address",
+ ndpConfigs: ipv6.NDPConfigurations{},
autoGenLinkLocal: true,
prepareFn: func(*testing.T, *ndpDispatcher, *channel.Endpoint, []byte) []tcpip.AddressWithPrefix {
return nil
@@ -3955,8 +3898,6 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
{
name: "Temporary address",
ndpConfigs: ipv6.NDPConfigurations{
- DupAddrDetectTransmits: dadTransmits,
- RetransmitTimer: retransmitTimer,
HandleRAs: true,
AutoGenGlobalAddresses: true,
AutoGenTempGlobalAddresses: true,
@@ -4008,8 +3949,12 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
AutoGenLinkLocal: addrType.autoGenLinkLocal,
- NDPConfigs: ndpConfigs,
- NDPDisp: &ndpDisp,
+ DADConfigs: stack.DADConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ },
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{
NICNameFromID: func(_ tcpip.NICID, nicName string) string {
return nicName
@@ -4039,7 +3984,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
// Simulate a DAD conflict.
rxNDPSolicit(e, addr.Address)
expectAutoGenAddrEvent(t, &ndpDisp, addr, invalidatedAddr)
- expectDADEvent(t, &ndpDisp, addr.Address, false)
+ expectDADEvent(t, &ndpDisp, addr.Address, false, nil)
// Attempting to add the address manually should not fail if the
// address's state was cleaned up when DAD failed.
@@ -4049,7 +3994,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
if err := s.RemoveAddress(nicID, addr.Address); err != nil {
t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err)
}
- expectDADEvent(t, &ndpDisp, addr.Address, false)
+ expectDADEvent(t, &ndpDisp, addr.Address, false, &tcpip.ErrAborted{})
}
// Should not have any new addresses assigned to the NIC.
@@ -4103,8 +4048,6 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
{
name: "Global address",
ndpConfigs: ipv6.NDPConfigurations{
- DupAddrDetectTransmits: dadTransmits,
- RetransmitTimer: retransmitTimer,
HandleRAs: true,
AutoGenGlobalAddresses: true,
AutoGenAddressConflictRetries: maxRetries,
@@ -4119,8 +4062,6 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
{
name: "LinkLocal address",
ndpConfigs: ipv6.NDPConfigurations{
- DupAddrDetectTransmits: dadTransmits,
- RetransmitTimer: retransmitTimer,
AutoGenAddressConflictRetries: maxRetries,
},
autoGenLinkLocal: true,
@@ -4145,6 +4086,10 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
AutoGenLinkLocal: addrType.autoGenLinkLocal,
NDPConfigs: addrType.ndpConfigs,
NDPDisp: &ndpDisp,
+ DADConfigs: stack.DADConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ },
})},
})
if err := s.CreateNIC(nicID, e); err != nil {
@@ -4226,9 +4171,11 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ DADConfigs: stack.DADConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ },
NDPConfigs: ipv6.NDPConfigurations{
- DupAddrDetectTransmits: dadTransmits,
- RetransmitTimer: retransmitTimer,
HandleRAs: true,
AutoGenGlobalAddresses: true,
AutoGenAddressConflictRetries: maxRetries,
diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go
index 7e3132058..533287c4c 100644
--- a/pkg/tcpip/stack/neighbor_cache.go
+++ b/pkg/tcpip/stack/neighbor_cache.go
@@ -25,9 +25,13 @@ const neighborCacheSize = 512 // max entries per interface
// NeighborStats holds metrics for the neighbor table.
type NeighborStats struct {
- // FailedEntryLookups counts the number of lookups performed on an entry in
- // Failed state.
+ // FailedEntryLookups is deprecated; UnreachableEntryLookups should be used
+ // instead.
FailedEntryLookups *tcpip.StatCounter
+
+ // UnreachableEntryLookups counts the number of lookups performed on an
+ // entry in Unreachable state.
+ UnreachableEntryLookups *tcpip.StatCounter
}
// neighborCache maps IP addresses to link addresses. It uses the Least
@@ -43,21 +47,22 @@ type NeighborStats struct {
// Their state is always Static. The amount of static entries stored in the
// cache is unbounded.
type neighborCache struct {
- nic *NIC
+ nic *nic
state *NUDState
linkRes LinkAddressResolver
- // mu protects the fields below.
- mu sync.RWMutex
+ mu struct {
+ sync.RWMutex
- cache map[tcpip.Address]*neighborEntry
- dynamic struct {
- lru neighborEntryList
+ cache map[tcpip.Address]*neighborEntry
+ dynamic struct {
+ lru neighborEntryList
- // count tracks the amount of dynamic entries in the cache. This is
- // needed since static entries do not count towards the LRU cache
- // eviction strategy.
- count uint16
+ // count tracks the amount of dynamic entries in the cache. This is
+ // needed since static entries do not count towards the LRU cache
+ // eviction strategy.
+ count uint16
+ }
}
}
@@ -74,11 +79,11 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address) *neighborEntr
n.mu.Lock()
defer n.mu.Unlock()
- if entry, ok := n.cache[remoteAddr]; ok {
+ if entry, ok := n.mu.cache[remoteAddr]; ok {
entry.mu.RLock()
- if entry.neigh.State != Static {
- n.dynamic.lru.Remove(entry)
- n.dynamic.lru.PushFront(entry)
+ if entry.mu.neigh.State != Static {
+ n.mu.dynamic.lru.Remove(entry)
+ n.mu.dynamic.lru.PushFront(entry)
}
entry.mu.RUnlock()
return entry
@@ -87,20 +92,20 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address) *neighborEntr
// The entry that needs to be created must be dynamic since all static
// entries are directly added to the cache via addStaticEntry.
entry := newNeighborEntry(n, remoteAddr, n.state)
- if n.dynamic.count == neighborCacheSize {
- e := n.dynamic.lru.Back()
+ if n.mu.dynamic.count == neighborCacheSize {
+ e := n.mu.dynamic.lru.Back()
e.mu.Lock()
- delete(n.cache, e.neigh.Addr)
- n.dynamic.lru.Remove(e)
- n.dynamic.count--
+ delete(n.mu.cache, e.mu.neigh.Addr)
+ n.mu.dynamic.lru.Remove(e)
+ n.mu.dynamic.count--
e.removeLocked()
e.mu.Unlock()
}
- n.cache[remoteAddr] = entry
- n.dynamic.lru.PushFront(entry)
- n.dynamic.count++
+ n.mu.cache[remoteAddr] = entry
+ n.mu.dynamic.lru.PushFront(entry)
+ n.mu.dynamic.count++
return entry
}
@@ -128,7 +133,7 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, onResolve fun
entry.mu.Lock()
defer entry.mu.Unlock()
- switch s := entry.neigh.State; s {
+ switch s := entry.mu.neigh.State; s {
case Stale:
entry.handlePacketQueuedLocked(localAddr)
fallthrough
@@ -139,19 +144,19 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, onResolve fun
// a node continues sending packets to that neighbor using the cached
// link-layer address."
if onResolve != nil {
- onResolve(LinkResolutionResult{LinkAddress: entry.neigh.LinkAddr, Success: true})
+ onResolve(LinkResolutionResult{LinkAddress: entry.mu.neigh.LinkAddr, Success: true})
}
- return entry.neigh, nil, nil
- case Unknown, Incomplete, Failed:
+ return entry.mu.neigh, nil, nil
+ case Unknown, Incomplete, Unreachable:
if onResolve != nil {
- entry.onResolve = append(entry.onResolve, onResolve)
+ entry.mu.onResolve = append(entry.mu.onResolve, onResolve)
}
- if entry.done == nil {
+ if entry.mu.done == nil {
// Address resolution needs to be initiated.
- entry.done = make(chan struct{})
+ entry.mu.done = make(chan struct{})
}
entry.handlePacketQueuedLocked(localAddr)
- return entry.neigh, entry.done, &tcpip.ErrWouldBlock{}
+ return entry.mu.neigh, entry.mu.done, &tcpip.ErrWouldBlock{}
default:
panic(fmt.Sprintf("Invalid cache entry state: %s", s))
}
@@ -162,10 +167,10 @@ func (n *neighborCache) entries() []NeighborEntry {
n.mu.RLock()
defer n.mu.RUnlock()
- entries := make([]NeighborEntry, 0, len(n.cache))
- for _, entry := range n.cache {
+ entries := make([]NeighborEntry, 0, len(n.mu.cache))
+ for _, entry := range n.mu.cache {
entry.mu.RLock()
- entries = append(entries, entry.neigh)
+ entries = append(entries, entry.mu.neigh)
entry.mu.RUnlock()
}
return entries
@@ -181,19 +186,19 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd
n.mu.Lock()
defer n.mu.Unlock()
- if entry, ok := n.cache[addr]; ok {
+ if entry, ok := n.mu.cache[addr]; ok {
entry.mu.Lock()
- if entry.neigh.State != Static {
+ if entry.mu.neigh.State != Static {
// Dynamic entry found with the same address.
- n.dynamic.lru.Remove(entry)
- n.dynamic.count--
- } else if entry.neigh.LinkAddr == linkAddr {
+ n.mu.dynamic.lru.Remove(entry)
+ n.mu.dynamic.count--
+ } else if entry.mu.neigh.LinkAddr == linkAddr {
// Static entry found with the same address and link address.
entry.mu.Unlock()
return
} else {
// Static entry found with the same address but different link address.
- entry.neigh.LinkAddr = linkAddr
+ entry.mu.neigh.LinkAddr = linkAddr
entry.dispatchChangeEventLocked()
entry.mu.Unlock()
return
@@ -203,7 +208,12 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd
entry.mu.Unlock()
}
- n.cache[addr] = newStaticNeighborEntry(n, addr, linkAddr, n.state)
+ entry := newStaticNeighborEntry(n, addr, linkAddr, n.state)
+ n.mu.cache[addr] = entry
+
+ entry.mu.Lock()
+ defer entry.mu.Unlock()
+ entry.dispatchAddEventLocked()
}
// removeEntry removes a dynamic or static entry by address from the neighbor
@@ -212,7 +222,7 @@ func (n *neighborCache) removeEntry(addr tcpip.Address) bool {
n.mu.Lock()
defer n.mu.Unlock()
- entry, ok := n.cache[addr]
+ entry, ok := n.mu.cache[addr]
if !ok {
return false
}
@@ -220,13 +230,13 @@ func (n *neighborCache) removeEntry(addr tcpip.Address) bool {
entry.mu.Lock()
defer entry.mu.Unlock()
- if entry.neigh.State != Static {
- n.dynamic.lru.Remove(entry)
- n.dynamic.count--
+ if entry.mu.neigh.State != Static {
+ n.mu.dynamic.lru.Remove(entry)
+ n.mu.dynamic.count--
}
entry.removeLocked()
- delete(n.cache, entry.neigh.Addr)
+ delete(n.mu.cache, entry.mu.neigh.Addr)
return true
}
@@ -235,15 +245,15 @@ func (n *neighborCache) clear() {
n.mu.Lock()
defer n.mu.Unlock()
- for _, entry := range n.cache {
+ for _, entry := range n.mu.cache {
entry.mu.Lock()
entry.removeLocked()
entry.mu.Unlock()
}
- n.dynamic.lru = neighborEntryList{}
- n.cache = make(map[tcpip.Address]*neighborEntry)
- n.dynamic.count = 0
+ n.mu.dynamic.lru = neighborEntryList{}
+ n.mu.cache = make(map[tcpip.Address]*neighborEntry)
+ n.mu.dynamic.count = 0
}
// config returns the NUD configuration.
@@ -260,30 +270,6 @@ func (n *neighborCache) setConfig(config NUDConfigurations) {
n.state.SetConfig(config)
}
-var _ neighborTable = (*neighborCache)(nil)
-
-func (n *neighborCache) neighbors() ([]NeighborEntry, tcpip.Error) {
- return n.entries(), nil
-}
-
-func (n *neighborCache) get(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) {
- entry, ch, err := n.entry(addr, localAddr, onResolve)
- return entry.LinkAddr, ch, err
-}
-
-func (n *neighborCache) remove(addr tcpip.Address) tcpip.Error {
- if !n.removeEntry(addr) {
- return &tcpip.ErrBadAddress{}
- }
-
- return nil
-}
-
-func (n *neighborCache) removeAll() tcpip.Error {
- n.clear()
- return nil
-}
-
// handleProbe handles a neighbor probe as defined by RFC 4861 section 7.2.3.
//
// Validation of the probe is expected to be handled by the caller.
@@ -300,7 +286,7 @@ func (n *neighborCache) handleProbe(remoteAddr tcpip.Address, remoteLinkAddr tcp
// Validation of the confirmation is expected to be handled by the caller.
func (n *neighborCache) handleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) {
n.mu.RLock()
- entry, ok := n.cache[addr]
+ entry, ok := n.mu.cache[addr]
n.mu.RUnlock()
if ok {
entry.mu.Lock()
@@ -316,7 +302,7 @@ func (n *neighborCache) handleConfirmation(addr tcpip.Address, linkAddr tcpip.Li
// some protocol that operates at a layer above the IP/link layer.
func (n *neighborCache) handleUpperLevelConfirmation(addr tcpip.Address) {
n.mu.RLock()
- entry, ok := n.cache[addr]
+ entry, ok := n.mu.cache[addr]
n.mu.RUnlock()
if ok {
entry.mu.Lock()
@@ -325,11 +311,13 @@ func (n *neighborCache) handleUpperLevelConfirmation(addr tcpip.Address) {
}
}
-func (n *neighborCache) nudConfig() (NUDConfigurations, tcpip.Error) {
- return n.config(), nil
-}
-
-func (n *neighborCache) setNUDConfig(c NUDConfigurations) tcpip.Error {
- n.setConfig(c)
- return nil
+func (n *neighborCache) init(nic *nic, r LinkAddressResolver) {
+ *n = neighborCache{
+ nic: nic,
+ state: NewNUDState(nic.stack.nudConfigs, nic.stack.randomGenerator),
+ linkRes: r,
+ }
+ n.mu.Lock()
+ n.mu.cache = make(map[tcpip.Address]*neighborEntry, neighborCacheSize)
+ n.mu.Unlock()
}
diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go
index b489b5e08..909912662 100644
--- a/pkg/tcpip/stack/neighbor_cache_test.go
+++ b/pkg/tcpip/stack/neighbor_cache_test.go
@@ -84,19 +84,16 @@ func newTestNeighborResolver(nudDisp NUDDispatcher, config NUDConfigurations, cl
entries: newTestEntryStore(),
delay: typicalLatency,
}
- linkRes.neigh = &neighborCache{
- nic: &NIC{
- stack: &Stack{
- clock: clock,
- nudDisp: nudDisp,
- },
- id: 1,
- stats: makeNICStats(),
+ linkRes.neigh.init(&nic{
+ stack: &Stack{
+ clock: clock,
+ nudDisp: nudDisp,
+ nudConfigs: config,
+ randomGenerator: rng,
},
- state: NewNUDState(config, rng),
- linkRes: linkRes,
- cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize),
- }
+ id: 1,
+ stats: makeNICStats(),
+ }, linkRes)
return linkRes
}
@@ -190,7 +187,7 @@ func (s *testEntryStore) set(i int, linkAddr tcpip.LinkAddress) {
// neighbor probe.
type testNeighborResolver struct {
clock tcpip.Clock
- neigh *neighborCache
+ neigh neighborCache
entries *testEntryStore
delay time.Duration
onLinkAddressRequest func()
@@ -1613,12 +1610,11 @@ func TestNeighborCacheRetryResolution(t *testing.T) {
}
}
- // Verify the entry is in Failed state.
wantEntries := []NeighborEntry{
{
Addr: entry.Addr,
LinkAddr: "",
- State: Failed,
+ State: Unreachable,
},
}
if diff := cmp.Diff(linkRes.neigh.entries(), wantEntries, entryDiffOptsWithSort()...); diff != "" {
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
index b05f96d4f..03fef52ee 100644
--- a/pkg/tcpip/stack/neighbor_entry.go
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -38,7 +38,8 @@ type NeighborEntry struct {
}
// NeighborState defines the state of a NeighborEntry within the Neighbor
-// Unreachability Detection state machine, as per RFC 4861 section 7.3.2.
+// Unreachability Detection state machine, as per RFC 4861 section 7.3.2 and
+// RFC 7048.
type NeighborState uint8
const (
@@ -61,15 +62,33 @@ const (
Delay
// Probe means a reachability confirmation is actively being sought by
// periodically retransmitting reachability probes until a reachability
- // confirmation is received, or until the max amount of probes has been sent.
+ // confirmation is received, or until the maximum number of probes has been
+ // sent.
Probe
// Static describes entries that have been explicitly added by the user. They
// do not expire and are not deleted until explicitly removed.
Static
- // Failed means recent attempts of reachability have returned inconclusive.
+ // Failed is deprecated and should no longer be used.
+ //
+ // TODO(gvisor.dev/issue/4667): Remove this once all references to Failed
+ // are removed from Fuchsia.
Failed
+ // Unreachable means reachability confirmation failed; the maximum number of
+ // reachability probes has been sent and no replies have been received.
+ //
+ // TODO(gvisor.dev/issue/5472): Add the following sentence when we implement
+ // RFC 7048: "Packets continue to be sent to the neighbor while
+ // re-attempting to resolve the address."
+ Unreachable
)
+type timer struct {
+ // done indicates to the timer that the timer was stopped.
+ done *bool
+
+ timer tcpip.Timer
+}
+
// neighborEntry implements a neighbor entry's individual node behavior, as per
// RFC 4861 section 7.3.3. Neighbor Unreachability Detection operates in
// parallel with the sending of packets to a neighbor, necessitating the
@@ -82,20 +101,22 @@ type neighborEntry struct {
// nudState points to the Neighbor Unreachability Detection configuration.
nudState *NUDState
- // mu protects the fields below.
- mu sync.RWMutex
+ mu struct {
+ sync.RWMutex
+
+ neigh NeighborEntry
- neigh NeighborEntry
+ // done is closed when address resolution is complete. It is nil iff s is
+ // incomplete and resolution is not yet in progress.
+ done chan struct{}
- // done is closed when address resolution is complete. It is nil iff s is
- // incomplete and resolution is not yet in progress.
- done chan struct{}
+ // onResolve is called with the result of address resolution.
+ onResolve []func(LinkResolutionResult)
- // onResolve is called with the result of address resolution.
- onResolve []func(LinkResolutionResult)
+ isRouter bool
- isRouter bool
- job *tcpip.Job
+ timer timer
+ }
}
// newNeighborEntry creates a neighbor cache entry starting at the default
@@ -103,14 +124,18 @@ type neighborEntry struct {
// `handlePacketQueuedLocked` or `handleProbeLocked` on the newly created
// neighborEntry.
func newNeighborEntry(cache *neighborCache, remoteAddr tcpip.Address, nudState *NUDState) *neighborEntry {
- return &neighborEntry{
+ n := &neighborEntry{
cache: cache,
nudState: nudState,
- neigh: NeighborEntry{
- Addr: remoteAddr,
- State: Unknown,
- },
}
+ n.mu.Lock()
+ n.mu.neigh = NeighborEntry{
+ Addr: remoteAddr,
+ State: Unknown,
+ }
+ n.mu.Unlock()
+ return n
+
}
// newStaticNeighborEntry creates a neighbor cache entry starting at the
@@ -123,14 +148,14 @@ func newStaticNeighborEntry(cache *neighborCache, addr tcpip.Address, linkAddr t
State: Static,
UpdatedAtNanos: cache.nic.stack.clock.NowNanoseconds(),
}
- if nudDisp := cache.nic.stack.nudDisp; nudDisp != nil {
- nudDisp.OnNeighborAdded(cache.nic.id, entry)
- }
- return &neighborEntry{
+ n := &neighborEntry{
cache: cache,
nudState: state,
- neigh: entry,
}
+ n.mu.Lock()
+ n.mu.neigh = entry
+ n.mu.Unlock()
+ return n
}
// notifyCompletionLocked notifies those waiting for address resolution, with
@@ -138,14 +163,14 @@ func newStaticNeighborEntry(cache *neighborCache, addr tcpip.Address, linkAddr t
//
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) notifyCompletionLocked(succeeded bool) {
- res := LinkResolutionResult{LinkAddress: e.neigh.LinkAddr, Success: succeeded}
- for _, callback := range e.onResolve {
+ res := LinkResolutionResult{LinkAddress: e.mu.neigh.LinkAddr, Success: succeeded}
+ for _, callback := range e.mu.onResolve {
callback(res)
}
- e.onResolve = nil
- if ch := e.done; ch != nil {
+ e.mu.onResolve = nil
+ if ch := e.mu.done; ch != nil {
close(ch)
- e.done = nil
+ e.mu.done = nil
// Dequeue the pending packets in a new goroutine to not hold up the current
// goroutine as writing packets may be a costly operation.
//
@@ -153,7 +178,7 @@ func (e *neighborEntry) notifyCompletionLocked(succeeded bool) {
// is resolved (which ends up obtaining the entry's lock) while holding the
// link resolution queue's lock. Dequeuing packets in a new goroutine avoids
// a lock ordering violation.
- go e.cache.nic.linkResQueue.dequeue(ch, e.neigh.LinkAddr, succeeded)
+ go e.cache.nic.linkResQueue.dequeue(ch, e.mu.neigh.LinkAddr, succeeded)
}
}
@@ -163,7 +188,7 @@ func (e *neighborEntry) notifyCompletionLocked(succeeded bool) {
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) dispatchAddEventLocked() {
if nudDisp := e.cache.nic.stack.nudDisp; nudDisp != nil {
- nudDisp.OnNeighborAdded(e.cache.nic.id, e.neigh)
+ nudDisp.OnNeighborAdded(e.cache.nic.id, e.mu.neigh)
}
}
@@ -173,7 +198,7 @@ func (e *neighborEntry) dispatchAddEventLocked() {
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) dispatchChangeEventLocked() {
if nudDisp := e.cache.nic.stack.nudDisp; nudDisp != nil {
- nudDisp.OnNeighborChanged(e.cache.nic.id, e.neigh)
+ nudDisp.OnNeighborChanged(e.cache.nic.id, e.mu.neigh)
}
}
@@ -183,17 +208,20 @@ func (e *neighborEntry) dispatchChangeEventLocked() {
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) dispatchRemoveEventLocked() {
if nudDisp := e.cache.nic.stack.nudDisp; nudDisp != nil {
- nudDisp.OnNeighborRemoved(e.cache.nic.id, e.neigh)
+ nudDisp.OnNeighborRemoved(e.cache.nic.id, e.mu.neigh)
}
}
-// cancelJobLocked cancels the currently scheduled action, if there is one.
+// cancelTimerLocked cancels the currently scheduled action, if there is one.
// Entries in Unknown, Stale, or Static state do not have a scheduled action.
//
// Precondition: e.mu MUST be locked.
-func (e *neighborEntry) cancelJobLocked() {
- if job := e.job; job != nil {
- job.Cancel()
+func (e *neighborEntry) cancelTimerLocked() {
+ if e.mu.timer.timer != nil {
+ e.mu.timer.timer.Stop()
+ *e.mu.timer.done = true
+
+ e.mu.timer = timer{}
}
}
@@ -201,9 +229,9 @@ func (e *neighborEntry) cancelJobLocked() {
//
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) removeLocked() {
- e.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds()
+ e.mu.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds()
e.dispatchRemoveEventLocked()
- e.cancelJobLocked()
+ e.cancelTimerLocked()
e.notifyCompletionLocked(false /* succeeded */)
}
@@ -213,61 +241,98 @@ func (e *neighborEntry) removeLocked() {
//
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) setStateLocked(next NeighborState) {
- e.cancelJobLocked()
+ e.cancelTimerLocked()
- prev := e.neigh.State
- e.neigh.State = next
- e.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds()
+ prev := e.mu.neigh.State
+ e.mu.neigh.State = next
+ e.mu.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds()
config := e.nudState.Config()
switch next {
case Incomplete:
- panic(fmt.Sprintf("should never transition to Incomplete with setStateLocked; neigh = %#v, prev state = %s", e.neigh, prev))
+ panic(fmt.Sprintf("should never transition to Incomplete with setStateLocked; neigh = %#v, prev state = %s", e.mu.neigh, prev))
case Reachable:
- e.job = e.cache.nic.stack.newJob(&e.mu, func() {
- e.setStateLocked(Stale)
- e.dispatchChangeEventLocked()
- })
- e.job.Schedule(e.nudState.ReachableTime())
+ // Protected by e.mu.
+ done := false
+
+ e.mu.timer = timer{
+ done: &done,
+ timer: e.cache.nic.stack.Clock().AfterFunc(e.nudState.ReachableTime(), func() {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if done {
+ // The timer was stopped because the entry changed state.
+ return
+ }
+
+ e.setStateLocked(Stale)
+ e.dispatchChangeEventLocked()
+ }),
+ }
case Delay:
- e.job = e.cache.nic.stack.newJob(&e.mu, func() {
- e.setStateLocked(Probe)
- e.dispatchChangeEventLocked()
- })
- e.job.Schedule(config.DelayFirstProbeTime)
+ // Protected by e.mu.
+ done := false
+
+ e.mu.timer = timer{
+ done: &done,
+ timer: e.cache.nic.stack.Clock().AfterFunc(config.DelayFirstProbeTime, func() {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if done {
+ // The timer was stopped because the entry changed state.
+ return
+ }
- case Probe:
- var retryCounter uint32
- var sendUnicastProbe func()
-
- sendUnicastProbe = func() {
- if retryCounter == config.MaxUnicastProbes {
- e.dispatchRemoveEventLocked()
- e.setStateLocked(Failed)
- return
- }
+ e.setStateLocked(Probe)
+ e.dispatchChangeEventLocked()
+ }),
+ }
- if err := e.cache.linkRes.LinkAddressRequest(e.neigh.Addr, "" /* localAddr */, e.neigh.LinkAddr); err != nil {
- e.dispatchRemoveEventLocked()
- e.setStateLocked(Failed)
- return
- }
+ case Probe:
+ // Protected by e.mu.
+ done := false
- retryCounter++
- e.job = e.cache.nic.stack.newJob(&e.mu, sendUnicastProbe)
- e.job.Schedule(config.RetransmitTimer)
- }
+ remaining := config.MaxUnicastProbes
+ addr := e.mu.neigh.Addr
+ linkAddr := e.mu.neigh.LinkAddr
// Send a probe in another gorountine to free this thread of execution
- // for finishing the state transition. This is necessary to avoid
- // deadlock where sending and processing probes are done synchronously,
- // such as loopback and integration tests.
- e.job = e.cache.nic.stack.newJob(&e.mu, sendUnicastProbe)
- e.job.Schedule(immediateDuration)
+ // for finishing the state transition. This is necessary to escape the
+ // currently held lock so we can send the probe message without holding
+ // a shared lock.
+ e.mu.timer = timer{
+ done: &done,
+ timer: e.cache.nic.stack.Clock().AfterFunc(0, func() {
+ var err tcpip.Error
+ timedoutResolution := remaining == 0
+ if !timedoutResolution {
+ err = e.cache.linkRes.LinkAddressRequest(addr, "" /* localAddr */, linkAddr)
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if done {
+ // The timer was stopped because the entry changed state.
+ return
+ }
- case Failed:
+ if timedoutResolution || err != nil {
+ e.setStateLocked(Unreachable)
+ e.dispatchChangeEventLocked()
+ return
+ }
+
+ remaining--
+ e.mu.timer.timer.Reset(config.RetransmitTimer)
+ }),
+ }
+
+ case Unreachable:
e.notifyCompletionLocked(false /* succeeded */)
case Unknown, Stale, Static:
@@ -285,76 +350,66 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
//
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
- switch e.neigh.State {
- case Failed:
- e.cache.nic.stats.Neighbor.FailedEntryLookups.Increment()
+ switch e.mu.neigh.State {
+ case Unknown, Unreachable:
+ prev := e.mu.neigh.State
+ e.mu.neigh.State = Incomplete
+ e.mu.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds()
+
+ switch prev {
+ case Unknown:
+ e.dispatchAddEventLocked()
+ case Unreachable:
+ e.dispatchChangeEventLocked()
+ e.cache.nic.stats.Neighbor.UnreachableEntryLookups.Increment()
+ }
- fallthrough
- case Unknown:
- e.neigh.State = Incomplete
- e.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds()
+ config := e.nudState.Config()
- e.dispatchAddEventLocked()
+ // Protected by e.mu.
+ done := false
- config := e.nudState.Config()
+ remaining := config.MaxMulticastProbes
+ addr := e.mu.neigh.Addr
- var retryCounter uint32
- var sendMulticastProbe func()
-
- sendMulticastProbe = func() {
- if retryCounter == config.MaxMulticastProbes {
- // "If no Neighbor Advertisement is received after
- // MAX_MULTICAST_SOLICIT solicitations, address resolution has failed.
- // The sender MUST return ICMP destination unreachable indications with
- // code 3 (Address Unreachable) for each packet queued awaiting address
- // resolution." - RFC 4861 section 7.2.2
- //
- // There is no need to send an ICMP destination unreachable indication
- // since the failure to resolve the address is expected to only occur
- // on this node. Thus, redirecting traffic is currently not supported.
- //
- // "If the error occurs on a node other than the node originating the
- // packet, an ICMP error message is generated. If the error occurs on
- // the originating node, an implementation is not required to actually
- // create and send an ICMP error packet to the source, as long as the
- // upper-layer sender is notified through an appropriate mechanism
- // (e.g. return value from a procedure call). Note, however, that an
- // implementation may find it convenient in some cases to return errors
- // to the sender by taking the offending packet, generating an ICMP
- // error message, and then delivering it (locally) through the generic
- // error-handling routines." - RFC 4861 section 2.1
- e.dispatchRemoveEventLocked()
- e.setStateLocked(Failed)
- return
- }
+ // Send a probe in another goroutine to free this thread of execution
+ // for finishing the state transition. This is necessary to escape the
+ // currently held lock so we can send the probe message without holding
+ // a shared lock.
+ e.mu.timer = timer{
+ done: &done,
+ timer: e.cache.nic.stack.Clock().AfterFunc(0, func() {
+ var err tcpip.Error
+ timedoutResolution := remaining == 0
+ if !timedoutResolution {
+ // As per RFC 4861 section 7.2.2:
+ //
+ // If the source address of the packet prompting the solicitation is
+ // the same as one of the addresses assigned to the outgoing interface,
+ // that address SHOULD be placed in the IP Source Address of the
+ // outgoing solicitation.
+ //
+ err = e.cache.linkRes.LinkAddressRequest(addr, localAddr, "" /* linkAddr */)
+ }
- // As per RFC 4861 section 7.2.2:
- //
- // If the source address of the packet prompting the solicitation is the
- // same as one of the addresses assigned to the outgoing interface, that
- // address SHOULD be placed in the IP Source Address of the outgoing
- // solicitation.
- //
- if err := e.cache.linkRes.LinkAddressRequest(e.neigh.Addr, localAddr, ""); err != nil {
- // There is no need to log the error here; the NUD implementation may
- // assume a working link. A valid link should be the responsibility of
- // the NIC/stack.LinkEndpoint.
- e.dispatchRemoveEventLocked()
- e.setStateLocked(Failed)
- return
- }
+ e.mu.Lock()
+ defer e.mu.Unlock()
- retryCounter++
- e.job = e.cache.nic.stack.newJob(&e.mu, sendMulticastProbe)
- e.job.Schedule(config.RetransmitTimer)
- }
+ if done {
+ // The timer was stopped because the entry changed state.
+ return
+ }
- // Send a probe in another gorountine to free this thread of execution
- // for finishing the state transition. This is necessary to avoid
- // deadlock where sending and processing probes are done synchronously,
- // such as loopback and integration tests.
- e.job = e.cache.nic.stack.newJob(&e.mu, sendMulticastProbe)
- e.job.Schedule(immediateDuration)
+ if timedoutResolution || err != nil {
+ e.setStateLocked(Unreachable)
+ e.dispatchChangeEventLocked()
+ return
+ }
+
+ remaining--
+ e.mu.timer.timer.Reset(config.RetransmitTimer)
+ }),
+ }
case Stale:
e.setStateLocked(Delay)
@@ -363,7 +418,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
case Incomplete, Reachable, Delay, Probe, Static:
// Do nothing
default:
- panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State))
+ panic(fmt.Sprintf("Invalid cache entry state: %s", e.mu.neigh.State))
}
}
@@ -378,9 +433,9 @@ func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) {
// not exist, or not bound to the NIC as per RFC 4861 section 7.2.3. These
// checks MUST be done by the NetworkEndpoint.
- switch e.neigh.State {
- case Unknown, Failed:
- e.neigh.LinkAddr = remoteLinkAddr
+ switch e.mu.neigh.State {
+ case Unknown:
+ e.mu.neigh.LinkAddr = remoteLinkAddr
e.setStateLocked(Stale)
e.dispatchAddEventLocked()
@@ -390,29 +445,36 @@ func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) {
// cached address should be replaced by the received address, and the
// entry's reachability state MUST be set to STALE."
// - RFC 4861 section 7.2.3
- e.neigh.LinkAddr = remoteLinkAddr
+ e.mu.neigh.LinkAddr = remoteLinkAddr
e.setStateLocked(Stale)
e.notifyCompletionLocked(true /* succeeded */)
e.dispatchChangeEventLocked()
case Reachable, Delay, Probe:
- if e.neigh.LinkAddr != remoteLinkAddr {
- e.neigh.LinkAddr = remoteLinkAddr
+ if e.mu.neigh.LinkAddr != remoteLinkAddr {
+ e.mu.neigh.LinkAddr = remoteLinkAddr
e.setStateLocked(Stale)
e.dispatchChangeEventLocked()
}
case Stale:
- if e.neigh.LinkAddr != remoteLinkAddr {
- e.neigh.LinkAddr = remoteLinkAddr
+ if e.mu.neigh.LinkAddr != remoteLinkAddr {
+ e.mu.neigh.LinkAddr = remoteLinkAddr
e.dispatchChangeEventLocked()
}
+ case Unreachable:
+ // TODO(gvisor.dev/issue/5472): Do not change the entry if the link
+ // address is the same, as per RFC 7048.
+ e.mu.neigh.LinkAddr = remoteLinkAddr
+ e.setStateLocked(Stale)
+ e.dispatchChangeEventLocked()
+
case Static:
// Do nothing
default:
- panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State))
+ panic(fmt.Sprintf("Invalid cache entry state: %s", e.mu.neigh.State))
}
}
@@ -430,7 +492,7 @@ func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) {
//
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) {
- switch e.neigh.State {
+ switch e.mu.neigh.State {
case Incomplete:
if len(linkAddr) == 0 {
// "If the link layer has addresses and no Target Link-Layer Address
@@ -439,35 +501,35 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
break
}
- e.neigh.LinkAddr = linkAddr
+ e.mu.neigh.LinkAddr = linkAddr
if flags.Solicited {
e.setStateLocked(Reachable)
} else {
e.setStateLocked(Stale)
}
e.dispatchChangeEventLocked()
- e.isRouter = flags.IsRouter
+ e.mu.isRouter = flags.IsRouter
e.notifyCompletionLocked(true /* succeeded */)
// "Note that the Override flag is ignored if the entry is in the
// INCOMPLETE state." - RFC 4861 section 7.2.5
case Reachable, Stale, Delay, Probe:
- isLinkAddrDifferent := len(linkAddr) != 0 && e.neigh.LinkAddr != linkAddr
+ isLinkAddrDifferent := len(linkAddr) != 0 && e.mu.neigh.LinkAddr != linkAddr
if isLinkAddrDifferent {
if !flags.Override {
- if e.neigh.State == Reachable {
+ if e.mu.neigh.State == Reachable {
e.setStateLocked(Stale)
e.dispatchChangeEventLocked()
}
break
}
- e.neigh.LinkAddr = linkAddr
+ e.mu.neigh.LinkAddr = linkAddr
if !flags.Solicited {
- if e.neigh.State != Stale {
+ if e.mu.neigh.State != Stale {
e.setStateLocked(Stale)
e.dispatchChangeEventLocked()
} else {
@@ -479,7 +541,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
}
if flags.Solicited && (flags.Override || !isLinkAddrDifferent) {
- wasReachable := e.neigh.State == Reachable
+ wasReachable := e.mu.neigh.State == Reachable
// Set state to Reachable again to refresh timers.
e.setStateLocked(Reachable)
e.notifyCompletionLocked(true /* succeeded */)
@@ -488,7 +550,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
}
}
- if e.isRouter && !flags.IsRouter && header.IsV6UnicastAddress(e.neigh.Addr) {
+ if e.mu.isRouter && !flags.IsRouter && header.IsV6UnicastAddress(e.mu.neigh.Addr) {
// "In those cases where the IsRouter flag changes from TRUE to FALSE as
// a result of this update, the node MUST remove that router from the
// Default Router List and update the Destination Cache entries for all
@@ -505,16 +567,16 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
}
if ndpEP, ok := ep.(NDPEndpoint); ok {
- ndpEP.InvalidateDefaultRouter(e.neigh.Addr)
+ ndpEP.InvalidateDefaultRouter(e.mu.neigh.Addr)
}
}
- e.isRouter = flags.IsRouter
+ e.mu.isRouter = flags.IsRouter
- case Unknown, Failed, Static:
+ case Unknown, Unreachable, Static:
// Do nothing
default:
- panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State))
+ panic(fmt.Sprintf("Invalid cache entry state: %s", e.mu.neigh.State))
}
}
@@ -523,19 +585,19 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
//
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) handleUpperLevelConfirmationLocked() {
- switch e.neigh.State {
+ switch e.mu.neigh.State {
case Reachable, Stale, Delay, Probe:
- wasReachable := e.neigh.State == Reachable
+ wasReachable := e.mu.neigh.State == Reachable
// Set state to Reachable again to refresh timers.
e.setStateLocked(Reachable)
if !wasReachable {
e.dispatchChangeEventLocked()
}
- case Unknown, Incomplete, Failed, Static:
+ case Unknown, Incomplete, Unreachable, Static:
// Do nothing
default:
- panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State))
+ panic(fmt.Sprintf("Invalid cache entry state: %s", e.mu.neigh.State))
}
}
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
index 57cfbdb8b..47a9e2448 100644
--- a/pkg/tcpip/stack/neighbor_entry_test.go
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -70,38 +70,39 @@ func eventDiffOptsWithSort() []cmp.Option {
}
// The following unit tests exercise every state transition and verify its
-// behavior with RFC 4681.
+// behavior with RFC 4861 and RFC 7048.
//
-// | From | To | Cause | Update | Action | Event |
-// | ========== | ========== | ========================================== | ======== | ===========| ======= |
-// | Unknown | Unknown | Confirmation w/ unknown address | | | Added |
-// | Unknown | Incomplete | Packet queued to unknown address | | Send probe | Added |
-// | Unknown | Stale | Probe | | | Added |
-// | Incomplete | Incomplete | Retransmit timer expired | | Send probe | Changed |
-// | Incomplete | Reachable | Solicited confirmation | LinkAddr | Notify | Changed |
-// | Incomplete | Stale | Unsolicited confirmation | LinkAddr | Notify | Changed |
-// | Incomplete | Stale | Probe | LinkAddr | Notify | Changed |
-// | Incomplete | Failed | Max probes sent without reply | | Notify | Removed |
-// | Reachable | Reachable | Confirmation w/ different isRouter flag | IsRouter | | |
-// | Reachable | Stale | Reachable timer expired | | | Changed |
-// | Reachable | Stale | Probe or confirmation w/ different address | | | Changed |
-// | Stale | Reachable | Solicited override confirmation | LinkAddr | | Changed |
-// | Stale | Reachable | Solicited confirmation w/o address | | Notify | Changed |
-// | Stale | Stale | Override confirmation | LinkAddr | | Changed |
-// | Stale | Stale | Probe w/ different address | LinkAddr | | Changed |
-// | Stale | Delay | Packet sent | | | Changed |
-// | Delay | Reachable | Upper-layer confirmation | | | Changed |
-// | Delay | Reachable | Solicited override confirmation | LinkAddr | | Changed |
-// | Delay | Reachable | Solicited confirmation w/o address | | Notify | Changed |
-// | Delay | Stale | Probe or confirmation w/ different address | | | Changed |
-// | Delay | Probe | Delay timer expired | | Send probe | Changed |
-// | Probe | Reachable | Solicited override confirmation | LinkAddr | | Changed |
-// | Probe | Reachable | Solicited confirmation w/ same address | | Notify | Changed |
-// | Probe | Reachable | Solicited confirmation w/o address | | Notify | Changed |
-// | Probe | Stale | Probe or confirmation w/ different address | | | Changed |
-// | Probe | Probe | Retransmit timer expired | | | Changed |
-// | Probe | Failed | Max probes sent without reply | | Notify | Removed |
-// | Failed | Incomplete | Packet queued | | Send probe | Added |
+// | From | To | Cause | Update | Action | Event |
+// | =========== | =========== | ========================================== | ======== | ===========| ======= |
+// | Unknown | Unknown | Confirmation w/ unknown address | | | Added |
+// | Unknown | Incomplete | Packet queued to unknown address | | Send probe | Added |
+// | Unknown | Stale | Probe | | | Added |
+// | Incomplete | Incomplete | Retransmit timer expired | | Send probe | Changed |
+// | Incomplete | Reachable | Solicited confirmation | LinkAddr | Notify | Changed |
+// | Incomplete | Stale | Unsolicited confirmation | LinkAddr | Notify | Changed |
+// | Incomplete | Stale | Probe | LinkAddr | Notify | Changed |
+// | Incomplete | Unreachable | Max probes sent without reply | | Notify | Changed |
+// | Reachable | Reachable | Confirmation w/ different isRouter flag | IsRouter | | |
+// | Reachable | Stale | Reachable timer expired | | | Changed |
+// | Reachable | Stale | Probe or confirmation w/ different address | | | Changed |
+// | Stale | Reachable | Solicited override confirmation | LinkAddr | | Changed |
+// | Stale | Reachable | Solicited confirmation w/o address | | Notify | Changed |
+// | Stale | Stale | Override confirmation | LinkAddr | | Changed |
+// | Stale | Stale | Probe w/ different address | LinkAddr | | Changed |
+// | Stale | Delay | Packet sent | | | Changed |
+// | Delay | Reachable | Upper-layer confirmation | | | Changed |
+// | Delay | Reachable | Solicited override confirmation | LinkAddr | | Changed |
+// | Delay | Reachable | Solicited confirmation w/o address | | Notify | Changed |
+// | Delay | Stale | Probe or confirmation w/ different address | | | Changed |
+// | Delay | Probe | Delay timer expired | | Send probe | Changed |
+// | Probe | Reachable | Solicited override confirmation | LinkAddr | | Changed |
+// | Probe | Reachable | Solicited confirmation w/ same address | | Notify | Changed |
+// | Probe | Reachable | Solicited confirmation w/o address | | Notify | Changed |
+// | Probe | Stale | Probe or confirmation w/ different address | | | Changed |
+// | Probe | Probe | Retransmit timer expired | | | Changed |
+// | Probe | Unreachable | Max probes sent without reply | | Notify | Changed |
+// | Unreachable | Incomplete | Packet queued | | Send probe | Changed |
+// | Unreachable | Stale | Probe w/ different address | LinkAddr | | Changed |
type testEntryEventType uint8
@@ -220,13 +221,15 @@ func (r *entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumbe
func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *entryTestLinkResolver, *faketime.ManualClock) {
clock := faketime.NewManualClock()
disp := testNUDDispatcher{}
- nic := NIC{
+ nic := nic{
LinkEndpoint: nil, // entryTestLinkResolver doesn't use a LinkEndpoint
id: entryTestNICID,
stack: &Stack{
- clock: clock,
- nudDisp: &disp,
+ clock: clock,
+ nudDisp: &disp,
+ nudConfigs: c,
+ randomGenerator: rand.New(rand.NewSource(time.Now().UnixNano())),
},
stats: makeNICStats(),
}
@@ -235,23 +238,18 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e
header.IPv6ProtocolNumber: netEP,
}
- rng := rand.New(rand.NewSource(time.Now().UnixNano()))
- nudState := NewNUDState(c, rng)
var linkRes entryTestLinkResolver
// Stub out the neighbor cache to verify deletion from the cache.
- neigh := &neighborCache{
- nic: &nic,
- state: nudState,
- linkRes: &linkRes,
- cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize),
- }
- l := linkResolver{
- resolver: &linkRes,
- neighborTable: neigh,
- }
- entry := newNeighborEntry(neigh, entryTestAddr1 /* remoteAddr */, nudState)
- neigh.cache[entryTestAddr1] = entry
- nic.linkAddrResolvers = map[tcpip.NetworkProtocolNumber]linkResolver{
+ l := &linkResolver{
+ resolver: &linkRes,
+ }
+ l.neigh.init(&nic, &linkRes)
+
+ entry := newNeighborEntry(&l.neigh, entryTestAddr1 /* remoteAddr */, l.neigh.state)
+ l.neigh.mu.Lock()
+ l.neigh.mu.cache[entryTestAddr1] = entry
+ l.neigh.mu.Unlock()
+ nic.linkAddrResolvers = map[tcpip.NetworkProtocolNumber]*linkResolver{
header.IPv6ProtocolNumber: l,
}
@@ -265,8 +263,8 @@ func TestEntryInitiallyUnknown(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- if e.neigh.State != Unknown {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Unknown)
+ if e.mu.neigh.State != Unknown {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unknown)
}
e.mu.Unlock()
@@ -298,8 +296,8 @@ func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) {
Override: false,
IsRouter: false,
})
- if e.neigh.State != Unknown {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Unknown)
+ if e.mu.neigh.State != Unknown {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unknown)
}
e.mu.Unlock()
@@ -327,8 +325,8 @@ func TestEntryUnknownToIncomplete(t *testing.T) {
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- if e.neigh.State != Incomplete {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
+ if e.mu.neigh.State != Incomplete {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete)
}
e.mu.Unlock()
@@ -374,8 +372,8 @@ func TestEntryUnknownToStale(t *testing.T) {
e.mu.Lock()
e.handleProbeLocked(entryTestLinkAddr1)
- if e.neigh.State != Stale {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ if e.mu.neigh.State != Stale {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale)
}
e.mu.Unlock()
@@ -413,10 +411,10 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- if e.neigh.State != Incomplete {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
+ if e.mu.neigh.State != Incomplete {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete)
}
- updatedAtNanos := e.neigh.UpdatedAtNanos
+ updatedAtNanos := e.mu.neigh.UpdatedAtNanos
e.mu.Unlock()
clock.Advance(c.RetransmitTimer)
@@ -443,15 +441,15 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
}
e.mu.Lock()
- if got, want := e.neigh.UpdatedAtNanos, updatedAtNanos; got != want {
- t.Errorf("got e.neigh.UpdatedAt = %q, want = %q", got, want)
+ if got, want := e.mu.neigh.UpdatedAtNanos, updatedAtNanos; got != want {
+ t.Errorf("got e.mu.neigh.UpdatedAt = %q, want = %q", got, want)
}
e.mu.Unlock()
clock.Advance(c.RetransmitTimer)
// UpdatedAt should change after failing address resolution. Timing out after
- // sending the last probe transitions the entry to Failed.
+ // sending the last probe transitions the entry to Unreachable.
{
wantProbes := []entryTestProbeInfo{
{
@@ -481,12 +479,12 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
},
},
{
- EventType: entryTestRemoved,
+ EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
Addr: entryTestAddr1,
LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ State: Unreachable,
},
},
}
@@ -497,8 +495,8 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
nudDisp.mu.Unlock()
e.mu.Lock()
- if got, notWant := e.neigh.UpdatedAtNanos, updatedAtNanos; got == notWant {
- t.Errorf("expected e.neigh.UpdatedAt to change, got = %q", got)
+ if got, notWant := e.mu.neigh.UpdatedAtNanos, updatedAtNanos; got == notWant {
+ t.Errorf("expected e.mu.neigh.UpdatedAt to change, got = %q", got)
}
e.mu.Unlock()
}
@@ -509,8 +507,8 @@ func TestEntryIncompleteToReachable(t *testing.T) {
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- if e.neigh.State != Incomplete {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
+ if e.mu.neigh.State != Incomplete {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete)
}
e.mu.Unlock()
@@ -535,8 +533,8 @@ func TestEntryIncompleteToReachable(t *testing.T) {
Override: false,
IsRouter: false,
})
- if e.neigh.State != Reachable {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ if e.mu.neigh.State != Reachable {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable)
}
e.mu.Unlock()
@@ -573,8 +571,8 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- if e.neigh.State != Incomplete {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
+ if e.mu.neigh.State != Incomplete {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete)
}
e.mu.Unlock()
@@ -599,11 +597,11 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
Override: false,
IsRouter: true,
})
- if e.neigh.State != Reachable {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ if e.mu.neigh.State != Reachable {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable)
}
- if !e.isRouter {
- t.Errorf("got e.isRouter = %t, want = true", e.isRouter)
+ if !e.mu.isRouter {
+ t.Errorf("got e.mu.isRouter = %t, want = true", e.mu.isRouter)
}
e.mu.Unlock()
@@ -640,8 +638,8 @@ func TestEntryIncompleteToStaleWhenUnsolicitedConfirmation(t *testing.T) {
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- if e.neigh.State != Incomplete {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
+ if e.mu.neigh.State != Incomplete {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete)
}
e.mu.Unlock()
@@ -666,8 +664,8 @@ func TestEntryIncompleteToStaleWhenUnsolicitedConfirmation(t *testing.T) {
Override: false,
IsRouter: false,
})
- if e.neigh.State != Stale {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ if e.mu.neigh.State != Stale {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale)
}
e.mu.Unlock()
@@ -704,8 +702,8 @@ func TestEntryIncompleteToStaleWhenProbe(t *testing.T) {
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- if e.neigh.State != Incomplete {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
+ if e.mu.neigh.State != Incomplete {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete)
}
e.mu.Unlock()
@@ -726,8 +724,8 @@ func TestEntryIncompleteToStaleWhenProbe(t *testing.T) {
e.mu.Lock()
e.handleProbeLocked(entryTestLinkAddr1)
- if e.neigh.State != Stale {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ if e.mu.neigh.State != Stale {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale)
}
e.mu.Unlock()
@@ -758,15 +756,15 @@ func TestEntryIncompleteToStaleWhenProbe(t *testing.T) {
nudDisp.mu.Unlock()
}
-func TestEntryIncompleteToFailed(t *testing.T) {
+func TestEntryIncompleteToUnreachable(t *testing.T) {
c := DefaultNUDConfigurations()
c.MaxMulticastProbes = 3
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- if e.neigh.State != Incomplete {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
+ if e.mu.neigh.State != Incomplete {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete)
}
e.mu.Unlock()
@@ -810,12 +808,12 @@ func TestEntryIncompleteToFailed(t *testing.T) {
},
},
{
- EventType: entryTestRemoved,
+ EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
Addr: entryTestAddr1,
LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ State: Unreachable,
},
},
}
@@ -826,8 +824,8 @@ func TestEntryIncompleteToFailed(t *testing.T) {
nudDisp.mu.Unlock()
e.mu.Lock()
- if e.neigh.State != Failed {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Failed)
+ if e.mu.neigh.State != Unreachable {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unreachable)
}
e.mu.Unlock()
}
@@ -870,11 +868,11 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
Override: false,
IsRouter: true,
})
- if e.neigh.State != Reachable {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ if e.mu.neigh.State != Reachable {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable)
}
- if got, want := e.isRouter, true; got != want {
- t.Errorf("got e.isRouter = %t, want = %t", got, want)
+ if got, want := e.mu.isRouter, true; got != want {
+ t.Errorf("got e.mu.isRouter = %t, want = %t", got, want)
}
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
@@ -882,11 +880,11 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
Override: false,
IsRouter: false,
})
- if got, want := e.isRouter, false; got != want {
- t.Errorf("got e.isRouter = %t, want = %t", got, want)
+ if got, want := e.mu.isRouter, false; got != want {
+ t.Errorf("got e.mu.isRouter = %t, want = %t", got, want)
}
- if ipv6EP.invalidatedRtr != e.neigh.Addr {
- t.Errorf("got ipv6EP.invalidatedRtr = %s, want = %s", ipv6EP.invalidatedRtr, e.neigh.Addr)
+ if ipv6EP.invalidatedRtr != e.mu.neigh.Addr {
+ t.Errorf("got ipv6EP.invalidatedRtr = %s, want = %s", ipv6EP.invalidatedRtr, e.mu.neigh.Addr)
}
e.mu.Unlock()
@@ -917,8 +915,8 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
nudDisp.mu.Unlock()
e.mu.Lock()
- if e.neigh.State != Reachable {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ if e.mu.neigh.State != Reachable {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable)
}
e.mu.Unlock()
}
@@ -952,15 +950,15 @@ func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) {
Override: false,
IsRouter: false,
})
- if e.neigh.State != Reachable {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ if e.mu.neigh.State != Reachable {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable)
}
e.handleProbeLocked(entryTestLinkAddr1)
- if e.neigh.State != Reachable {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ if e.mu.neigh.State != Reachable {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable)
}
- if e.neigh.LinkAddr != entryTestLinkAddr1 {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1)
+ if e.mu.neigh.LinkAddr != entryTestLinkAddr1 {
+ t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr1)
}
e.mu.Unlock()
@@ -1025,8 +1023,8 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) {
Override: false,
IsRouter: false,
})
- if e.neigh.State != Reachable {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ if e.mu.neigh.State != Reachable {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable)
}
e.mu.Unlock()
@@ -1068,8 +1066,8 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) {
nudDisp.mu.Unlock()
e.mu.Lock()
- if e.neigh.State != Stale {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ if e.mu.neigh.State != Stale {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale)
}
e.mu.Unlock()
}
@@ -1103,12 +1101,12 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) {
Override: false,
IsRouter: false,
})
- if e.neigh.State != Reachable {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ if e.mu.neigh.State != Reachable {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable)
}
e.handleProbeLocked(entryTestLinkAddr2)
- if e.neigh.State != Stale {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ if e.mu.neigh.State != Stale {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale)
}
e.mu.Unlock()
@@ -1177,16 +1175,16 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T)
Override: false,
IsRouter: false,
})
- if e.neigh.State != Reachable {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ if e.mu.neigh.State != Reachable {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable)
}
e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- if e.neigh.State != Stale {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ if e.mu.neigh.State != Stale {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale)
}
e.mu.Unlock()
@@ -1255,16 +1253,16 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t
Override: false,
IsRouter: false,
})
- if e.neigh.State != Reachable {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ if e.mu.neigh.State != Reachable {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable)
}
e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
Solicited: false,
Override: true,
IsRouter: false,
})
- if e.neigh.State != Stale {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ if e.mu.neigh.State != Stale {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale)
}
e.mu.Unlock()
@@ -1333,15 +1331,15 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) {
Override: false,
IsRouter: false,
})
- if e.neigh.State != Stale {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ if e.mu.neigh.State != Stale {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale)
}
e.handleProbeLocked(entryTestLinkAddr1)
- if e.neigh.State != Stale {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ if e.mu.neigh.State != Stale {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale)
}
- if e.neigh.LinkAddr != entryTestLinkAddr1 {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1)
+ if e.mu.neigh.LinkAddr != entryTestLinkAddr1 {
+ t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr1)
}
e.mu.Unlock()
@@ -1401,19 +1399,19 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
Override: false,
IsRouter: false,
})
- if e.neigh.State != Stale {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ if e.mu.neigh.State != Stale {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale)
}
e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
Solicited: true,
Override: true,
IsRouter: false,
})
- if e.neigh.State != Reachable {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ if e.mu.neigh.State != Reachable {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable)
}
- if e.neigh.LinkAddr != entryTestLinkAddr2 {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2)
+ if e.mu.neigh.LinkAddr != entryTestLinkAddr2 {
+ t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr2)
}
e.mu.Unlock()
@@ -1482,19 +1480,19 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
Override: false,
IsRouter: false,
})
- if e.neigh.State != Stale {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ if e.mu.neigh.State != Stale {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale)
}
e.handleConfirmationLocked("" /* linkAddr */, ReachabilityConfirmationFlags{
Solicited: true,
Override: false,
IsRouter: false,
})
- if e.neigh.State != Reachable {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ if e.mu.neigh.State != Reachable {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable)
}
- if e.neigh.LinkAddr != entryTestLinkAddr1 {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1)
+ if e.mu.neigh.LinkAddr != entryTestLinkAddr1 {
+ t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr1)
}
e.mu.Unlock()
@@ -1563,19 +1561,19 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) {
Override: false,
IsRouter: false,
})
- if e.neigh.State != Stale {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ if e.mu.neigh.State != Stale {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale)
}
e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
Solicited: false,
Override: true,
IsRouter: false,
})
- if e.neigh.State != Stale {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ if e.mu.neigh.State != Stale {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale)
}
- if e.neigh.LinkAddr != entryTestLinkAddr2 {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2)
+ if e.mu.neigh.LinkAddr != entryTestLinkAddr2 {
+ t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr2)
}
e.mu.Unlock()
@@ -1644,15 +1642,15 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) {
Override: false,
IsRouter: false,
})
- if e.neigh.State != Stale {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ if e.mu.neigh.State != Stale {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale)
}
e.handleProbeLocked(entryTestLinkAddr2)
- if e.neigh.State != Stale {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ if e.mu.neigh.State != Stale {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale)
}
- if e.neigh.LinkAddr != entryTestLinkAddr2 {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2)
+ if e.mu.neigh.LinkAddr != entryTestLinkAddr2 {
+ t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr2)
}
e.mu.Unlock()
@@ -1721,12 +1719,12 @@ func TestEntryStaleToDelay(t *testing.T) {
Override: false,
IsRouter: false,
})
- if e.neigh.State != Stale {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ if e.mu.neigh.State != Stale {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale)
}
e.handlePacketQueuedLocked(entryTestAddr2)
- if e.neigh.State != Delay {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ if e.mu.neigh.State != Delay {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale)
}
e.mu.Unlock()
@@ -1801,12 +1799,12 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) {
IsRouter: false,
})
e.handlePacketQueuedLocked(entryTestAddr2)
- if e.neigh.State != Delay {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ if e.mu.neigh.State != Delay {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Delay)
}
e.handleUpperLevelConfirmationLocked()
- if e.neigh.State != Reachable {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ if e.mu.neigh.State != Reachable {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable)
}
e.mu.Unlock()
@@ -1901,19 +1899,19 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
IsRouter: false,
})
e.handlePacketQueuedLocked(entryTestAddr2)
- if e.neigh.State != Delay {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ if e.mu.neigh.State != Delay {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Delay)
}
e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
Solicited: true,
Override: true,
IsRouter: false,
})
- if e.neigh.State != Reachable {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ if e.mu.neigh.State != Reachable {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable)
}
- if e.neigh.LinkAddr != entryTestLinkAddr2 {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2)
+ if e.mu.neigh.LinkAddr != entryTestLinkAddr2 {
+ t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr2)
}
e.mu.Unlock()
@@ -2008,19 +2006,19 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
IsRouter: false,
})
e.handlePacketQueuedLocked(entryTestAddr2)
- if e.neigh.State != Delay {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ if e.mu.neigh.State != Delay {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Delay)
}
e.handleConfirmationLocked("" /* linkAddr */, ReachabilityConfirmationFlags{
Solicited: true,
Override: false,
IsRouter: false,
})
- if e.neigh.State != Reachable {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ if e.mu.neigh.State != Reachable {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable)
}
- if e.neigh.LinkAddr != entryTestLinkAddr1 {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1)
+ if e.mu.neigh.LinkAddr != entryTestLinkAddr1 {
+ t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr1)
}
e.mu.Unlock()
@@ -2109,19 +2107,19 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) {
IsRouter: false,
})
e.handlePacketQueuedLocked(entryTestAddr2)
- if e.neigh.State != Delay {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ if e.mu.neigh.State != Delay {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Delay)
}
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: true,
IsRouter: false,
})
- if e.neigh.State != Delay {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ if e.mu.neigh.State != Delay {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Delay)
}
- if e.neigh.LinkAddr != entryTestLinkAddr1 {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1)
+ if e.mu.neigh.LinkAddr != entryTestLinkAddr1 {
+ t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr1)
}
e.mu.Unlock()
@@ -2191,12 +2189,12 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) {
IsRouter: false,
})
e.handlePacketQueuedLocked(entryTestAddr2)
- if e.neigh.State != Delay {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ if e.mu.neigh.State != Delay {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Delay)
}
e.handleProbeLocked(entryTestLinkAddr2)
- if e.neigh.State != Stale {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ if e.mu.neigh.State != Stale {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale)
}
e.mu.Unlock()
@@ -2275,16 +2273,16 @@ func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
IsRouter: false,
})
e.handlePacketQueuedLocked(entryTestAddr2)
- if e.neigh.State != Delay {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ if e.mu.neigh.State != Delay {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Delay)
}
e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
Solicited: false,
Override: true,
IsRouter: false,
})
- if e.neigh.State != Stale {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ if e.mu.neigh.State != Stale {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale)
}
e.mu.Unlock()
@@ -2366,8 +2364,8 @@ func TestEntryDelayToProbe(t *testing.T) {
IsRouter: false,
})
e.handlePacketQueuedLocked(entryTestAddr2)
- if e.neigh.State != Delay {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ if e.mu.neigh.State != Delay {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Delay)
}
e.mu.Unlock()
@@ -2432,8 +2430,8 @@ func TestEntryDelayToProbe(t *testing.T) {
nudDisp.mu.Unlock()
e.mu.Lock()
- if e.neigh.State != Probe {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
+ if e.mu.neigh.State != Probe {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe)
}
e.mu.Unlock()
}
@@ -2490,12 +2488,12 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
}
e.mu.Lock()
- if e.neigh.State != Probe {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
+ if e.mu.neigh.State != Probe {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe)
}
e.handleProbeLocked(entryTestLinkAddr2)
- if e.neigh.State != Stale {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ if e.mu.neigh.State != Stale {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale)
}
e.mu.Unlock()
@@ -2605,16 +2603,16 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
}
e.mu.Lock()
- if e.neigh.State != Probe {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
+ if e.mu.neigh.State != Probe {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe)
}
e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
Solicited: false,
Override: true,
IsRouter: false,
})
- if e.neigh.State != Stale {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ if e.mu.neigh.State != Stale {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale)
}
e.mu.Unlock()
@@ -2725,19 +2723,19 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) {
}
e.mu.Lock()
- if e.neigh.State != Probe {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
+ if e.mu.neigh.State != Probe {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe)
}
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: true,
IsRouter: false,
})
- if e.neigh.State != Probe {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
+ if e.mu.neigh.State != Probe {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe)
}
- if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ if got, want := e.mu.neigh.LinkAddr, entryTestLinkAddr1; got != want {
+ t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", got, want)
}
e.mu.Unlock()
@@ -2821,19 +2819,19 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) {
}
e.mu.Lock()
- if e.neigh.State != Probe {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
+ if e.mu.neigh.State != Probe {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe)
}
e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
Solicited: true,
Override: true,
IsRouter: false,
})
- if e.neigh.State != Reachable {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ if e.mu.neigh.State != Reachable {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable)
}
- if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ if got, want := e.mu.neigh.LinkAddr, entryTestLinkAddr2; got != want {
+ t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", got, want)
}
e.mu.Unlock()
@@ -2949,19 +2947,19 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
}
e.mu.Lock()
- if e.neigh.State != Probe {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
+ if e.mu.neigh.State != Probe {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe)
}
e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
Solicited: true,
Override: true,
IsRouter: false,
})
- if e.neigh.State != Reachable {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ if e.mu.neigh.State != Reachable {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable)
}
- if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ if got, want := e.mu.neigh.LinkAddr, entryTestLinkAddr2; got != want {
+ t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", got, want)
}
e.mu.Unlock()
@@ -3086,16 +3084,16 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin
}
e.mu.Lock()
- if e.neigh.State != Probe {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
+ if e.mu.neigh.State != Probe {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe)
}
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: true,
Override: false,
IsRouter: false,
})
- if e.neigh.State != Reachable {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ if e.mu.neigh.State != Reachable {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable)
}
e.mu.Unlock()
@@ -3220,16 +3218,16 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
}
e.mu.Lock()
- if e.neigh.State != Probe {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
+ if e.mu.neigh.State != Probe {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe)
}
e.handleConfirmationLocked("" /* linkAddr */, ReachabilityConfirmationFlags{
Solicited: true,
Override: false,
IsRouter: false,
})
- if e.neigh.State != Reachable {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ if e.mu.neigh.State != Reachable {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable)
}
e.mu.Unlock()
@@ -3297,7 +3295,7 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
nudDisp.mu.Unlock()
}
-func TestEntryProbeToFailed(t *testing.T) {
+func TestEntryProbeToUnreachable(t *testing.T) {
c := DefaultNUDConfigurations()
c.MaxMulticastProbes = 3
c.MaxUnicastProbes = 3
@@ -3352,17 +3350,17 @@ func TestEntryProbeToFailed(t *testing.T) {
}
e.mu.Lock()
- if e.neigh.State != Probe {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
+ if e.mu.neigh.State != Probe {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe)
}
e.mu.Unlock()
}
- // Wait for the last probe to expire, causing a transition to Failed.
+ // Wait for the last probe to expire, causing a transition to Unreachable.
clock.Advance(c.RetransmitTimer)
e.mu.Lock()
- if e.neigh.State != Failed {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Failed)
+ if e.mu.neigh.State != Unreachable {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unreachable)
}
e.mu.Unlock()
@@ -3404,12 +3402,12 @@ func TestEntryProbeToFailed(t *testing.T) {
},
},
{
- EventType: entryTestRemoved,
+ EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
Addr: entryTestAddr1,
LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ State: Unreachable,
},
},
}
@@ -3420,7 +3418,7 @@ func TestEntryProbeToFailed(t *testing.T) {
nudDisp.mu.Unlock()
}
-func TestEntryFailedToIncomplete(t *testing.T) {
+func TestEntryUnreachableToIncomplete(t *testing.T) {
c := DefaultNUDConfigurations()
c.MaxMulticastProbes = 3
e, nudDisp, linkRes, clock := entryTestSetup(c)
@@ -3429,8 +3427,8 @@ func TestEntryFailedToIncomplete(t *testing.T) {
// their expected state.
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- if e.neigh.State != Incomplete {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
+ if e.mu.neigh.State != Incomplete {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete)
}
e.mu.Unlock()
@@ -3464,15 +3462,15 @@ func TestEntryFailedToIncomplete(t *testing.T) {
}
e.mu.Lock()
- if e.neigh.State != Failed {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Failed)
+ if e.mu.neigh.State != Unreachable {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unreachable)
}
e.mu.Unlock()
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- if e.neigh.State != Incomplete {
- t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
+ if e.mu.neigh.State != Incomplete {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete)
}
e.mu.Unlock()
@@ -3487,7 +3485,16 @@ func TestEntryFailedToIncomplete(t *testing.T) {
},
},
{
- EventType: entryTestRemoved,
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Unreachable,
+ },
+ },
+ {
+ EventType: entryTestChanged,
NICID: entryTestNICID,
Entry: NeighborEntry{
Addr: entryTestAddr1,
@@ -3495,6 +3502,72 @@ func TestEntryFailedToIncomplete(t *testing.T) {
State: Incomplete,
},
},
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryUnreachableToStale(t *testing.T) {
+ wantProbes := []entryTestProbeInfo{
+ // The Incomplete-to-Incomplete state transition is tested here by
+ // verifying that 3 reachability probes were sent.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+
+ c := DefaultNUDConfigurations()
+ c.MaxMulticastProbes = uint32(len(wantProbes))
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ // TODO(gvisor.dev/issue/4872): Use helper functions to start entry tests in
+ // their expected state.
+ e.mu.Lock()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.mu.neigh.State != Incomplete {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete)
+ }
+ e.mu.Unlock()
+
+ waitFor := c.RetransmitTimer * time.Duration(c.MaxMulticastProbes)
+ clock.Advance(waitFor)
+
+ linkRes.mu.Lock()
+ diff := cmp.Diff(wantProbes, linkRes.probes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
+ }
+
+ e.mu.Lock()
+ if e.mu.neigh.State != Unreachable {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unreachable)
+ }
+ e.mu.Unlock()
+
+ e.mu.Lock()
+ e.handleProbeLocked(entryTestLinkAddr2)
+ if e.mu.neigh.State != Stale {
+ t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale)
+ }
+ e.mu.Unlock()
+
+ wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
@@ -3504,6 +3577,24 @@ func TestEntryFailedToIncomplete(t *testing.T) {
State: Incomplete,
},
},
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Unreachable,
+ },
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ },
}
nudDisp.mu.Lock()
if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
diff --git a/pkg/tcpip/stack/neighborstate_string.go b/pkg/tcpip/stack/neighborstate_string.go
index aa7311ec6..765df4d7a 100644
--- a/pkg/tcpip/stack/neighborstate_string.go
+++ b/pkg/tcpip/stack/neighborstate_string.go
@@ -1,4 +1,4 @@
-// Copyright 2020 The gVisor Authors.
+// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -30,11 +30,12 @@ func _() {
_ = x[Probe-5]
_ = x[Static-6]
_ = x[Failed-7]
+ _ = x[Unreachable-8]
}
-const _NeighborState_name = "UnknownIncompleteReachableStaleDelayProbeStaticFailed"
+const _NeighborState_name = "UnknownIncompleteReachableStaleDelayProbeStaticFailedUnreachable"
-var _NeighborState_index = [...]uint8{0, 7, 17, 26, 31, 36, 41, 47, 53}
+var _NeighborState_index = [...]uint8{0, 7, 17, 26, 31, 36, 41, 47, 53, 64}
func (i NeighborState) String() string {
if i >= NeighborState(len(_NeighborState_index)-1) {
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 41a489047..f66db16a7 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -24,40 +24,26 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
)
-type neighborTable interface {
- neighbors() ([]NeighborEntry, tcpip.Error)
- addStaticEntry(tcpip.Address, tcpip.LinkAddress)
- get(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error)
- remove(tcpip.Address) tcpip.Error
- removeAll() tcpip.Error
-
- handleProbe(tcpip.Address, tcpip.LinkAddress)
- handleConfirmation(tcpip.Address, tcpip.LinkAddress, ReachabilityConfirmationFlags)
- handleUpperLevelConfirmation(tcpip.Address)
-
- nudConfig() (NUDConfigurations, tcpip.Error)
- setNUDConfig(NUDConfigurations) tcpip.Error
-}
-
-var _ NetworkInterface = (*NIC)(nil)
-
type linkResolver struct {
resolver LinkAddressResolver
- neighborTable neighborTable
+ neigh neighborCache
}
func (l *linkResolver) getNeighborLinkAddress(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) {
- return l.neighborTable.get(addr, localAddr, onResolve)
+ entry, ch, err := l.neigh.entry(addr, localAddr, onResolve)
+ return entry.LinkAddr, ch, err
}
func (l *linkResolver) confirmReachable(addr tcpip.Address) {
- l.neighborTable.handleUpperLevelConfirmation(addr)
+ l.neigh.handleUpperLevelConfirmation(addr)
}
-// NIC represents a "network interface card" to which the networking stack is
+var _ NetworkInterface = (*nic)(nil)
+
+// nic represents a "network interface card" to which the networking stack is
// attached.
-type NIC struct {
+type nic struct {
LinkEndpoint
stack *Stack
@@ -69,8 +55,9 @@ type NIC struct {
// The network endpoints themselves may be modified by calling the interface's
// methods, but the map reference and entries must be constant.
- networkEndpoints map[tcpip.NetworkProtocolNumber]NetworkEndpoint
- linkAddrResolvers map[tcpip.NetworkProtocolNumber]linkResolver
+ networkEndpoints map[tcpip.NetworkProtocolNumber]NetworkEndpoint
+ linkAddrResolvers map[tcpip.NetworkProtocolNumber]*linkResolver
+ duplicateAddressDetectors map[tcpip.NetworkProtocolNumber]DuplicateAddressDetector
// enabled is set to 1 when the NIC is enabled and 0 when it is disabled.
//
@@ -147,7 +134,7 @@ func (p *packetEndpointList) forEach(fn func(PacketEndpoint)) {
}
// newNIC returns a new NIC using the default NDP configurations from stack.
-func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *NIC {
+func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *nic {
// TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For
// example, make sure that the link address it provides is a valid
// unicast ethernet address.
@@ -156,16 +143,17 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
// observe an MTU of at least 1280 bytes. Ensure that this requirement
// of IPv6 is supported on this endpoint's LinkEndpoint.
- nic := &NIC{
+ nic := &nic{
LinkEndpoint: ep,
- stack: stack,
- id: id,
- name: name,
- context: ctx,
- stats: makeNICStats(),
- networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint),
- linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]linkResolver),
+ stack: stack,
+ id: id,
+ name: name,
+ context: ctx,
+ stats: makeNICStats(),
+ networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint),
+ linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]*linkResolver),
+ duplicateAddressDetectors: make(map[tcpip.NetworkProtocolNumber]DuplicateAddressDetector),
}
nic.linkResQueue.init(nic)
nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList)
@@ -185,26 +173,15 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
if resolutionRequired {
if r, ok := netEP.(LinkAddressResolver); ok {
- l := linkResolver{
- resolver: r,
- }
-
- if stack.useNeighborCache {
- l.neighborTable = &neighborCache{
- nic: nic,
- state: NewNUDState(stack.nudConfigs, stack.randomGenerator),
- linkRes: r,
-
- cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize),
- }
- } else {
- cache := new(linkAddrCache)
- cache.init(nic, ageLimit, resolutionTimeout, resolutionAttempts, r)
- l.neighborTable = cache
- }
+ l := &linkResolver{resolver: r}
+ l.neigh.init(nic, r)
nic.linkAddrResolvers[r.LinkAddressProtocol()] = l
}
}
+
+ if d, ok := netEP.(DuplicateAddressDetector); ok {
+ nic.duplicateAddressDetectors[d.DuplicateAddressProtocol()] = d
+ }
}
nic.LinkEndpoint.Attach(nic)
@@ -212,19 +189,19 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
return nic
}
-func (n *NIC) getNetworkEndpoint(proto tcpip.NetworkProtocolNumber) NetworkEndpoint {
+func (n *nic) getNetworkEndpoint(proto tcpip.NetworkProtocolNumber) NetworkEndpoint {
return n.networkEndpoints[proto]
}
// Enabled implements NetworkInterface.
-func (n *NIC) Enabled() bool {
+func (n *nic) Enabled() bool {
return atomic.LoadUint32(&n.enabled) == 1
}
// setEnabled sets the enabled status for the NIC.
//
// Returns true if the enabled status was updated.
-func (n *NIC) setEnabled(v bool) bool {
+func (n *nic) setEnabled(v bool) bool {
if v {
return atomic.SwapUint32(&n.enabled, 1) == 0
}
@@ -234,7 +211,7 @@ func (n *NIC) setEnabled(v bool) bool {
// disable disables n.
//
// It undoes the work done by enable.
-func (n *NIC) disable() {
+func (n *nic) disable() {
n.mu.Lock()
n.disableLocked()
n.mu.Unlock()
@@ -245,7 +222,7 @@ func (n *NIC) disable() {
// It undoes the work done by enable.
//
// n MUST be locked.
-func (n *NIC) disableLocked() {
+func (n *nic) disableLocked() {
if !n.Enabled() {
return
}
@@ -283,7 +260,7 @@ func (n *NIC) disableLocked() {
// address (ff02::1), start DAD for permanent addresses, and start soliciting
// routers if the stack is not operating as a router. If the stack is also
// configured to auto-generate a link-local address, one will be generated.
-func (n *NIC) enable() tcpip.Error {
+func (n *nic) enable() tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
@@ -303,7 +280,7 @@ func (n *NIC) enable() tcpip.Error {
// remove detaches NIC from the link endpoint and releases network endpoint
// resources. This guarantees no packets between this NIC and the network
// stack.
-func (n *NIC) remove() tcpip.Error {
+func (n *nic) remove() tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
@@ -319,14 +296,14 @@ func (n *NIC) remove() tcpip.Error {
}
// setPromiscuousMode enables or disables promiscuous mode.
-func (n *NIC) setPromiscuousMode(enable bool) {
+func (n *nic) setPromiscuousMode(enable bool) {
n.mu.Lock()
n.mu.promiscuous = enable
n.mu.Unlock()
}
// Promiscuous implements NetworkInterface.
-func (n *NIC) Promiscuous() bool {
+func (n *nic) Promiscuous() bool {
n.mu.RLock()
rv := n.mu.promiscuous
n.mu.RUnlock()
@@ -334,17 +311,17 @@ func (n *NIC) Promiscuous() bool {
}
// IsLoopback implements NetworkInterface.
-func (n *NIC) IsLoopback() bool {
+func (n *nic) IsLoopback() bool {
return n.LinkEndpoint.Capabilities()&CapabilityLoopback != 0
}
// WritePacket implements NetworkLinkEndpoint.
-func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error {
+func (n *nic) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error {
_, err := n.enqueuePacketBuffer(r, gso, protocol, pkt)
return err
}
-func (n *NIC) writePacketBuffer(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) {
+func (n *nic) writePacketBuffer(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) {
switch pkt := pkt.(type) {
case *PacketBuffer:
if err := n.writePacket(r, gso, protocol, pkt); err != nil {
@@ -358,7 +335,7 @@ func (n *NIC) writePacketBuffer(r RouteInfo, gso *GSO, protocol tcpip.NetworkPro
}
}
-func (n *NIC) enqueuePacketBuffer(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) {
+func (n *nic) enqueuePacketBuffer(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) {
routeInfo, _, err := r.resolvedFields(nil)
switch err.(type) {
case nil:
@@ -388,14 +365,14 @@ func (n *NIC) enqueuePacketBuffer(r *Route, gso *GSO, protocol tcpip.NetworkProt
}
// WritePacketToRemote implements NetworkInterface.
-func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error {
+func (n *nic) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error {
var r RouteInfo
r.NetProto = protocol
r.RemoteLinkAddress = remoteLinkAddr
return n.writePacket(r, gso, protocol, pkt)
}
-func (n *NIC) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error {
+func (n *nic) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error {
// WritePacket takes ownership of pkt, calculate numBytes first.
numBytes := pkt.Size()
@@ -412,11 +389,11 @@ func (n *NIC) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolN
}
// WritePackets implements NetworkLinkEndpoint.
-func (n *NIC) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+func (n *nic) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
return n.enqueuePacketBuffer(r, gso, protocol, &pkts)
}
-func (n *NIC) writePackets(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkts PacketBufferList) (int, tcpip.Error) {
+func (n *nic) writePackets(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkts PacketBufferList) (int, tcpip.Error) {
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
pkt.EgressRoute = r
pkt.GSOOptions = gso
@@ -435,15 +412,22 @@ func (n *NIC) writePackets(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocol
}
// setSpoofing enables or disables address spoofing.
-func (n *NIC) setSpoofing(enable bool) {
+func (n *nic) setSpoofing(enable bool) {
n.mu.Lock()
n.mu.spoofing = enable
n.mu.Unlock()
}
+// Spoofing implements NetworkInterface.
+func (n *nic) Spoofing() bool {
+ n.mu.RLock()
+ defer n.mu.RUnlock()
+ return n.mu.spoofing
+}
+
// primaryAddress returns an address that can be used to communicate with
// remoteAddr.
-func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) AssignableAddressEndpoint {
+func (n *nic) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) AssignableAddressEndpoint {
ep, ok := n.networkEndpoints[protocol]
if !ok {
return nil
@@ -473,11 +457,11 @@ const (
promiscuous
)
-func (n *NIC) getAddress(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) AssignableAddressEndpoint {
+func (n *nic) getAddress(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) AssignableAddressEndpoint {
return n.getAddressOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, promiscuous)
}
-func (n *NIC) hasAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
+func (n *nic) hasAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
ep := n.getAddressOrCreateTempInner(protocol, addr, false, NeverPrimaryEndpoint)
if ep != nil {
ep.DecRef()
@@ -488,7 +472,7 @@ func (n *NIC) hasAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Addres
}
// findEndpoint finds the endpoint, if any, with the given address.
-func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) AssignableAddressEndpoint {
+func (n *nic) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) AssignableAddressEndpoint {
return n.getAddressOrCreateTemp(protocol, address, peb, spoofing)
}
@@ -501,7 +485,7 @@ func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.A
//
// If the address is the IPv4 broadcast address for an endpoint's network, that
// endpoint will be returned.
-func (n *NIC) getAddressOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getAddressBehaviour) AssignableAddressEndpoint {
+func (n *nic) getAddressOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getAddressBehaviour) AssignableAddressEndpoint {
n.mu.RLock()
var spoofingOrPromiscuous bool
switch tempRef {
@@ -516,7 +500,7 @@ func (n *NIC) getAddressOrCreateTemp(protocol tcpip.NetworkProtocolNumber, addre
// getAddressOrCreateTempInner is like getAddressEpOrCreateTemp except a boolean
// is passed to indicate whether or not we should generate temporary endpoints.
-func (n *NIC) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, createTemp bool, peb PrimaryEndpointBehavior) AssignableAddressEndpoint {
+func (n *nic) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, createTemp bool, peb PrimaryEndpointBehavior) AssignableAddressEndpoint {
ep, ok := n.networkEndpoints[protocol]
if !ok {
return nil
@@ -532,7 +516,7 @@ func (n *NIC) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber,
// addAddress adds a new address to n, so that it starts accepting packets
// targeted at the given address (and network protocol).
-func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error {
+func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error {
ep, ok := n.networkEndpoints[protocolAddress.Protocol]
if !ok {
return &tcpip.ErrUnknownProtocol{}
@@ -553,7 +537,7 @@ func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo
// allPermanentAddresses returns all permanent addresses associated with
// this NIC.
-func (n *NIC) allPermanentAddresses() []tcpip.ProtocolAddress {
+func (n *nic) allPermanentAddresses() []tcpip.ProtocolAddress {
var addrs []tcpip.ProtocolAddress
for p, ep := range n.networkEndpoints {
addressableEndpoint, ok := ep.(AddressableEndpoint)
@@ -569,7 +553,7 @@ func (n *NIC) allPermanentAddresses() []tcpip.ProtocolAddress {
}
// primaryAddresses returns the primary addresses associated with this NIC.
-func (n *NIC) primaryAddresses() []tcpip.ProtocolAddress {
+func (n *nic) primaryAddresses() []tcpip.ProtocolAddress {
var addrs []tcpip.ProtocolAddress
for p, ep := range n.networkEndpoints {
addressableEndpoint, ok := ep.(AddressableEndpoint)
@@ -589,7 +573,7 @@ func (n *NIC) primaryAddresses() []tcpip.ProtocolAddress {
// primaryAddress will return the first non-deprecated address if such an
// address exists. If no non-deprecated address exists, the first deprecated
// address will be returned.
-func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWithPrefix {
+func (n *nic) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWithPrefix {
ep, ok := n.networkEndpoints[proto]
if !ok {
return tcpip.AddressWithPrefix{}
@@ -604,7 +588,7 @@ func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWit
}
// removeAddress removes an address from n.
-func (n *NIC) removeAddress(addr tcpip.Address) tcpip.Error {
+func (n *nic) removeAddress(addr tcpip.Address) tcpip.Error {
for _, ep := range n.networkEndpoints {
addressableEndpoint, ok := ep.(AddressableEndpoint)
if !ok {
@@ -622,7 +606,7 @@ func (n *NIC) removeAddress(addr tcpip.Address) tcpip.Error {
return &tcpip.ErrBadLocalAddress{}
}
-func (n *NIC) getLinkAddress(addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(LinkResolutionResult)) tcpip.Error {
+func (n *nic) getLinkAddress(addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(LinkResolutionResult)) tcpip.Error {
linkRes, ok := n.linkAddrResolvers[protocol]
if !ok {
return &tcpip.ErrNotSupported{}
@@ -637,34 +621,38 @@ func (n *NIC) getLinkAddress(addr, localAddr tcpip.Address, protocol tcpip.Netwo
return err
}
-func (n *NIC) neighbors(protocol tcpip.NetworkProtocolNumber) ([]NeighborEntry, tcpip.Error) {
+func (n *nic) neighbors(protocol tcpip.NetworkProtocolNumber) ([]NeighborEntry, tcpip.Error) {
if linkRes, ok := n.linkAddrResolvers[protocol]; ok {
- return linkRes.neighborTable.neighbors()
+ return linkRes.neigh.entries(), nil
}
return nil, &tcpip.ErrNotSupported{}
}
-func (n *NIC) addStaticNeighbor(addr tcpip.Address, protocol tcpip.NetworkProtocolNumber, linkAddress tcpip.LinkAddress) tcpip.Error {
+func (n *nic) addStaticNeighbor(addr tcpip.Address, protocol tcpip.NetworkProtocolNumber, linkAddress tcpip.LinkAddress) tcpip.Error {
if linkRes, ok := n.linkAddrResolvers[protocol]; ok {
- linkRes.neighborTable.addStaticEntry(addr, linkAddress)
+ linkRes.neigh.addStaticEntry(addr, linkAddress)
return nil
}
return &tcpip.ErrNotSupported{}
}
-func (n *NIC) removeNeighbor(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error {
+func (n *nic) removeNeighbor(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error {
if linkRes, ok := n.linkAddrResolvers[protocol]; ok {
- return linkRes.neighborTable.remove(addr)
+ if !linkRes.neigh.removeEntry(addr) {
+ return &tcpip.ErrBadAddress{}
+ }
+ return nil
}
return &tcpip.ErrNotSupported{}
}
-func (n *NIC) clearNeighbors(protocol tcpip.NetworkProtocolNumber) tcpip.Error {
+func (n *nic) clearNeighbors(protocol tcpip.NetworkProtocolNumber) tcpip.Error {
if linkRes, ok := n.linkAddrResolvers[protocol]; ok {
- return linkRes.neighborTable.removeAll()
+ linkRes.neigh.clear()
+ return nil
}
return &tcpip.ErrNotSupported{}
@@ -672,7 +660,7 @@ func (n *NIC) clearNeighbors(protocol tcpip.NetworkProtocolNumber) tcpip.Error {
// joinGroup adds a new endpoint for the given multicast address, if none
// exists yet. Otherwise it just increments its count.
-func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error {
+func (n *nic) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error {
// TODO(b/143102137): When implementing MLD, make sure MLD packets are
// not sent unless a valid link-local address is available for use on n
// as an MLD packet's source address must be a link-local address as
@@ -693,7 +681,7 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address
// leaveGroup decrements the count for the given multicast address, and when it
// reaches zero removes the endpoint for this address.
-func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error {
+func (n *nic) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error {
ep, ok := n.networkEndpoints[protocol]
if !ok {
return &tcpip.ErrNotSupported{}
@@ -708,7 +696,7 @@ func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Addres
}
// isInGroup returns true if n has joined the multicast group addr.
-func (n *NIC) isInGroup(addr tcpip.Address) bool {
+func (n *nic) isInGroup(addr tcpip.Address) bool {
for _, ep := range n.networkEndpoints {
gep, ok := ep.(GroupAddressableEndpoint)
if !ok {
@@ -729,7 +717,7 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool {
// Note that the ownership of the slice backing vv is retained by the caller.
// This rule applies only to the slice itself, not to the items of the slice;
// the ownership of the items is not retained by the caller.
-func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
n.mu.RLock()
enabled := n.Enabled()
// If the NIC is not yet enabled, don't receive any packets.
@@ -777,41 +765,11 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
anyEPs.forEach(deliverPacketEPs)
}
- // Parse headers.
- netProto := n.stack.NetworkProtocolInstance(protocol)
- transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt)
- if !ok {
- // The packet is too small to contain a network header.
- n.stack.stats.MalformedRcvdPackets.Increment()
- return
- }
- if hasTransportHdr {
- pkt.TransportProtocolNumber = transProtoNum
- // Parse the transport header if present.
- if state, ok := n.stack.transportProtocols[transProtoNum]; ok {
- state.proto.Parse(pkt)
- }
- }
-
- if n.stack.handleLocal && !n.IsLoopback() {
- src, _ := netProto.ParseAddresses(pkt.NetworkHeader().View())
- if r := n.getAddress(protocol, src); r != nil {
- r.DecRef()
-
- // The source address is one of our own, so we never should have gotten a
- // packet like this unless handleLocal is false. Loopback also calls this
- // function even though the packets didn't come from the physical interface
- // so don't drop those.
- n.stack.stats.IP.InvalidSourceAddressesReceived.Increment()
- return
- }
- }
-
networkEndpoint.HandlePacket(pkt)
}
// DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket.
-func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+func (n *nic) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
n.mu.RLock()
// We do not deliver to protocol specific packet endpoints as on Linux
// only ETH_P_ALL endpoints get outbound packets.
@@ -831,7 +789,7 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc
// DeliverTransportPacket delivers the packets to the appropriate transport
// protocol endpoint.
-func (n *NIC) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition {
+func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition {
state, ok := n.stack.transportProtocols[protocol]
if !ok {
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
@@ -912,7 +870,7 @@ func (n *NIC) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt
}
// DeliverTransportError implements TransportDispatcher.
-func (n *NIC) DeliverTransportError(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr TransportError, pkt *PacketBuffer) {
+func (n *nic) DeliverTransportError(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr TransportError, pkt *PacketBuffer) {
state, ok := n.stack.transportProtocols[trans]
if !ok {
return
@@ -940,19 +898,19 @@ func (n *NIC) DeliverTransportError(local, remote tcpip.Address, net tcpip.Netwo
}
// ID implements NetworkInterface.
-func (n *NIC) ID() tcpip.NICID {
+func (n *nic) ID() tcpip.NICID {
return n.id
}
// Name implements NetworkInterface.
-func (n *NIC) Name() string {
+func (n *nic) Name() string {
return n.name
}
// nudConfigs gets the NUD configurations for n.
-func (n *NIC) nudConfigs(protocol tcpip.NetworkProtocolNumber) (NUDConfigurations, tcpip.Error) {
+func (n *nic) nudConfigs(protocol tcpip.NetworkProtocolNumber) (NUDConfigurations, tcpip.Error) {
if linkRes, ok := n.linkAddrResolvers[protocol]; ok {
- return linkRes.neighborTable.nudConfig()
+ return linkRes.neigh.config(), nil
}
return NUDConfigurations{}, &tcpip.ErrNotSupported{}
@@ -962,16 +920,17 @@ func (n *NIC) nudConfigs(protocol tcpip.NetworkProtocolNumber) (NUDConfiguration
//
// Note, if c contains invalid NUD configuration values, it will be fixed to
// use default values for the erroneous values.
-func (n *NIC) setNUDConfigs(protocol tcpip.NetworkProtocolNumber, c NUDConfigurations) tcpip.Error {
+func (n *nic) setNUDConfigs(protocol tcpip.NetworkProtocolNumber, c NUDConfigurations) tcpip.Error {
if linkRes, ok := n.linkAddrResolvers[protocol]; ok {
c.resetInvalidFields()
- return linkRes.neighborTable.setNUDConfig(c)
+ linkRes.neigh.setConfig(c)
+ return nil
}
return &tcpip.ErrNotSupported{}
}
-func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) tcpip.Error {
+func (n *nic) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
@@ -984,7 +943,7 @@ func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep Pa
return nil
}
-func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) {
+func (n *nic) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) {
n.mu.Lock()
defer n.mu.Unlock()
@@ -998,7 +957,7 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep
// isValidForOutgoing returns true if the endpoint can be used to send out a
// packet. It requires the endpoint to not be marked expired (i.e., its address
// has been removed) unless the NIC is in spoofing mode, or temporary.
-func (n *NIC) isValidForOutgoing(ep AssignableAddressEndpoint) bool {
+func (n *nic) isValidForOutgoing(ep AssignableAddressEndpoint) bool {
n.mu.RLock()
spoofing := n.mu.spoofing
n.mu.RUnlock()
@@ -1006,9 +965,9 @@ func (n *NIC) isValidForOutgoing(ep AssignableAddressEndpoint) bool {
}
// HandleNeighborProbe implements NetworkInterface.
-func (n *NIC) HandleNeighborProbe(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error {
+func (n *nic) HandleNeighborProbe(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error {
if l, ok := n.linkAddrResolvers[protocol]; ok {
- l.neighborTable.handleProbe(addr, linkAddr)
+ l.neigh.handleProbe(addr, linkAddr)
return nil
}
@@ -1016,11 +975,34 @@ func (n *NIC) HandleNeighborProbe(protocol tcpip.NetworkProtocolNumber, addr tcp
}
// HandleNeighborConfirmation implements NetworkInterface.
-func (n *NIC) HandleNeighborConfirmation(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) tcpip.Error {
+func (n *nic) HandleNeighborConfirmation(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) tcpip.Error {
if l, ok := n.linkAddrResolvers[protocol]; ok {
- l.neighborTable.handleConfirmation(addr, linkAddr, flags)
+ l.neigh.handleConfirmation(addr, linkAddr, flags)
return nil
}
return &tcpip.ErrNotSupported{}
}
+
+// CheckLocalAddress implements NetworkInterface.
+func (n *nic) CheckLocalAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
+ if n.Spoofing() {
+ return true
+ }
+
+ if addressEndpoint := n.getAddressOrCreateTempInner(protocol, addr, false /* createTemp */, NeverPrimaryEndpoint); addressEndpoint != nil {
+ addressEndpoint.DecRef()
+ return true
+ }
+
+ return false
+}
+
+func (n *nic) checkDuplicateAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, h DADCompletionHandler) (DADCheckAddressDisposition, tcpip.Error) {
+ d, ok := n.duplicateAddressDetectors[protocol]
+ if !ok {
+ return 0, &tcpip.ErrNotSupported{}
+ }
+
+ return d.CheckDuplicateAddress(addr, h), nil
+}
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index 9992d6eb4..c0f956e53 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -170,7 +170,7 @@ func (*testIPv6Protocol) Parse(*PacketBuffer) (tcpip.TransportProtocolNumber, bo
func TestDisabledRxStatsWhenNICDisabled(t *testing.T) {
// When the NIC is disabled, the only field that matters is the stats field.
// This test is limited to stats counter checks.
- nic := NIC{
+ nic := nic{
stats: makeNICStats(),
}
diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go
index e9acef6a2..e1253f310 100644
--- a/pkg/tcpip/stack/nud_test.go
+++ b/pkg/tcpip/stack/nud_test.go
@@ -101,7 +101,6 @@ func TestNUDFunctions(t *testing.T) {
clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NUDConfigs: stack.DefaultNUDConfigurations(),
- UseNeighborCache: true,
NetworkProtocols: test.netProtoFactory,
Clock: clock,
})
@@ -206,7 +205,6 @@ func TestDefaultNUDConfigurations(t *testing.T) {
// address resolution is specified (e.g. ARP or IPv6).
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
NUDConfigs: stack.DefaultNUDConfigurations(),
- UseNeighborCache: true,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -261,7 +259,6 @@ func TestNUDConfigurationsBaseReachableTime(t *testing.T) {
// providing link address resolution is specified (e.g. ARP or IPv6).
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
NUDConfigs: c,
- UseNeighborCache: true,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -318,7 +315,6 @@ func TestNUDConfigurationsMinRandomFactor(t *testing.T) {
// providing link address resolution is specified (e.g. ARP or IPv6).
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
NUDConfigs: c,
- UseNeighborCache: true,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -398,7 +394,6 @@ func TestNUDConfigurationsMaxRandomFactor(t *testing.T) {
// providing link address resolution is specified (e.g. ARP or IPv6).
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
NUDConfigs: c,
- UseNeighborCache: true,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -460,7 +455,6 @@ func TestNUDConfigurationsRetransmitTimer(t *testing.T) {
// providing link address resolution is specified (e.g. ARP or IPv6).
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
NUDConfigs: c,
- UseNeighborCache: true,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -512,7 +506,6 @@ func TestNUDConfigurationsDelayFirstProbeTime(t *testing.T) {
// providing link address resolution is specified (e.g. ARP or IPv6).
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
NUDConfigs: c,
- UseNeighborCache: true,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -564,7 +557,6 @@ func TestNUDConfigurationsMaxMulticastProbes(t *testing.T) {
// providing link address resolution is specified (e.g. ARP or IPv6).
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
NUDConfigs: c,
- UseNeighborCache: true,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -616,7 +608,6 @@ func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) {
// providing link address resolution is specified (e.g. ARP or IPv6).
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
NUDConfigs: c,
- UseNeighborCache: true,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go
index 1c651e216..dc139ebb2 100644
--- a/pkg/tcpip/stack/pending_packets.go
+++ b/pkg/tcpip/stack/pending_packets.go
@@ -55,7 +55,7 @@ type pendingPacket struct {
//
// Once link resolution completes successfully, the packets will be written.
type packetsPendingLinkResolution struct {
- nic *NIC
+ nic *nic
mu struct {
sync.Mutex
@@ -82,7 +82,7 @@ func (f *packetsPendingLinkResolution) incrementOutgoingPacketErrors(proto tcpip
}
}
-func (f *packetsPendingLinkResolution) init(nic *NIC) {
+func (f *packetsPendingLinkResolution) init(nic *nic) {
f.mu.Lock()
defer f.mu.Unlock()
f.nic = nic
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index d589f798d..43e9e4beb 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -16,6 +16,7 @@ package stack
import (
"fmt"
+ "time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -514,8 +515,19 @@ type NetworkInterface interface {
Enabled() bool
// Promiscuous returns true if the interface is in promiscuous mode.
+ //
+ // When in promiscuous mode, the interface should accept all packets.
Promiscuous() bool
+ // Spoofing returns true if the interface is in spoofing mode.
+ //
+ // When in spoofing mode, the interface should consider all addresses as
+ // assigned to it.
+ Spoofing() bool
+
+ // CheckLocalAddress returns true if the address exists on the interface.
+ CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool
+
// WritePacketToRemote writes the packet to the given remote link address.
WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error
@@ -840,7 +852,97 @@ type InjectableLinkEndpoint interface {
InjectOutbound(dest tcpip.Address, packet []byte) tcpip.Error
}
-// A LinkAddressResolver handles link address resolution for a network protocol.
+// DADResult is the result of a duplicate address detection process.
+type DADResult struct {
+ // Resolved is true when DAD completed without detecting a duplicate address
+ // on the link.
+ //
+ // Ignored when Err is non-nil.
+ Resolved bool
+
+ // Err is an error encountered while performing DAD.
+ Err tcpip.Error
+}
+
+// DADCompletionHandler is a handler for DAD completion.
+type DADCompletionHandler func(DADResult)
+
+// DADCheckAddressDisposition enumerates the possible return values from
+// DuplicateAddressDetector.CheckDuplicateAddress.
+type DADCheckAddressDisposition int
+
+const (
+ // The zero value is deliberately left unnamed, so the named dispositions
+ // below start at 1.
+ _ DADCheckAddressDisposition = iota
+
+ // DADDisabled indicates that DAD is disabled.
+ DADDisabled
+
+ // DADStarting indicates that DAD is starting for an address.
+ DADStarting
+
+ // DADAlreadyRunning indicates that DAD was already started for an address.
+ DADAlreadyRunning
+)
+
+const (
+ // defaultDupAddrDetectTransmits is the default number of NDP Neighbor
+ // Solicitation messages to send when doing Duplicate Address Detection
+ // for a tentative address.
+ //
+ // Default = 1 (from RFC 4862 section 5.1)
+ defaultDupAddrDetectTransmits = 1
+)
+
+// DADConfigurations holds configurations for duplicate address detection.
+type DADConfigurations struct {
+ // The number of Neighbor Solicitation messages to send when doing
+ // Duplicate Address Detection for a tentative address.
+ //
+ // Note, a value of zero effectively disables DAD.
+ DupAddrDetectTransmits uint8
+
+ // The amount of time to wait between sending Neighbor Solicitation
+ // messages.
+ //
+ // Must be greater than or equal to 1ms.
+ RetransmitTimer time.Duration
+}
+
+// DefaultDADConfigurations returns the default DAD configurations.
+func DefaultDADConfigurations() DADConfigurations {
+ return DADConfigurations{
+ DupAddrDetectTransmits: defaultDupAddrDetectTransmits,
+ RetransmitTimer: defaultRetransmitTimer,
+ }
+}
+
+// Validate replaces invalid values in the configuration with their defaults.
+//
+// RetransmitTimer is the only field checked: a value below
+// minimumRetransmitTimer is reset to defaultRetransmitTimer.
+func (c *DADConfigurations) Validate() {
+ if c.RetransmitTimer >= minimumRetransmitTimer {
+ // Already valid; nothing to change.
+ return
+ }
+
+ c.RetransmitTimer = defaultRetransmitTimer
+}
+
+// DuplicateAddressDetector handles checking if an address is already assigned
+// to some neighboring node on the link.
+type DuplicateAddressDetector interface {
+ // CheckDuplicateAddress checks if an address is assigned to a neighbor.
+ //
+ // If DAD is already being performed for the address, the handler will be
+ // called with the result of the original DAD request.
+ CheckDuplicateAddress(tcpip.Address, DADCompletionHandler) DADCheckAddressDisposition
+
+ // SetDADConfigurations sets the configurations for DAD.
+ SetDADConfigurations(c DADConfigurations)
+
+ // DuplicateAddressProtocol returns the network protocol the receiver can
+ // perform duplicate address detection for.
+ DuplicateAddressProtocol() tcpip.NetworkProtocolNumber
+}
+
+// LinkAddressResolver handles link address resolution for a network protocol.
type LinkAddressResolver interface {
// LinkAddressRequest sends a request for the link address of the target
// address. The request is broadcasted on the local network if a remote link
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index bab55ce49..e946f9fe3 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -35,7 +35,7 @@ type Route struct {
// localAddressNIC is the interface the address is associated with.
// TODO(gvisor.dev/issue/4548): Remove this field once we can query the
// address's assigned status without the NIC.
- localAddressNIC *NIC
+ localAddressNIC *nic
mu struct {
sync.RWMutex
@@ -49,11 +49,11 @@ type Route struct {
}
// outgoingNIC is the interface this route uses to write packets.
- outgoingNIC *NIC
+ outgoingNIC *nic
// linkRes is set if link address resolution is enabled for this protocol on
// the route's NIC.
- linkRes linkResolver
+ linkRes *linkResolver
}
type routeInfo struct {
@@ -108,7 +108,7 @@ func (r *Route) fieldsLocked() RouteInfo {
// ownership of the provided local address.
//
// Returns an empty route if validation fails.
-func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, localAddr, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) *Route {
+func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *nic, gateway, localAddr, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) *Route {
if len(localAddr) == 0 {
localAddr = addressEndpoint.AddressWithPrefix().Address
}
@@ -140,7 +140,7 @@ func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndp
// makeRoute initializes a new route. It takes ownership of the provided
// AssignableAddressEndpoint.
-func makeRoute(netProto tcpip.NetworkProtocolNumber, gateway, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) *Route {
+func makeRoute(netProto tcpip.NetworkProtocolNumber, gateway, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *nic, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) *Route {
if localAddressNIC.stack != outgoingNIC.stack {
panic(fmt.Sprintf("cannot create a route with NICs from different stacks"))
}
@@ -184,7 +184,7 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, gateway, localAddr, remoteA
return r
}
- if r.linkRes.resolver == nil {
+ if r.linkRes == nil {
return r
}
@@ -206,7 +206,7 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, gateway, localAddr, remoteA
return r
}
-func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) *Route {
+func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *nic, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) *Route {
r := &Route{
routeInfo: routeInfo{
NetProto: netProto,
@@ -230,7 +230,7 @@ func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr
// provided AssignableAddressEndpoint.
//
// A local route is a route to a destination that is local to the stack.
-func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint) *Route {
+func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *nic, localAddressEndpoint AssignableAddressEndpoint) *Route {
loop := PacketLoop
// TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the
// link endpoint level. We can remove this check once loopback interfaces
@@ -400,7 +400,7 @@ func (r *Route) IsResolutionRequired() bool {
}
func (r *Route) isResolutionRequiredRLocked() bool {
- return len(r.mu.remoteLinkAddress) == 0 && r.linkRes.resolver != nil && r.isValidForOutgoingRLocked() && !r.local()
+ return len(r.mu.remoteLinkAddress) == 0 && r.linkRes != nil && r.isValidForOutgoingRLocked() && !r.local()
}
func (r *Route) isValidForOutgoing() bool {
@@ -528,7 +528,7 @@ func (r *Route) IsOutboundBroadcast() bool {
// "Reachable" is defined as having full-duplex communication between the
// local and remote ends of the route.
func (r *Route) ConfirmReachable() {
- if r.linkRes.resolver != nil {
+ if r.linkRes != nil {
r.linkRes.confirmReachable(r.nextHop())
}
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index a51d758d0..674c9a1ff 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -395,7 +395,7 @@ type Stack struct {
}
mu sync.RWMutex
- nics map[tcpip.NICID]*NIC
+ nics map[tcpip.NICID]*nic
// cleanupEndpointsMu protects cleanupEndpoints.
cleanupEndpointsMu sync.Mutex
@@ -434,12 +434,6 @@ type Stack struct {
// nudConfigs is the default NUD configurations used by interfaces.
nudConfigs NUDConfigurations
- // useNeighborCache indicates whether ARP and NDP packets should be handled
- // by the NIC's neighborCache instead of linkAddrCache.
- //
- // TODO(gvisor.dev/issue/4658): Remove this field.
- useNeighborCache bool
-
// nudDisp is the NUD event dispatcher that is used to send the netstack
// integrator NUD related events.
nudDisp NUDDispatcher
@@ -516,17 +510,6 @@ type Options struct {
// NUDConfigs is the default NUD configurations used by interfaces.
NUDConfigs NUDConfigurations
- // UseNeighborCache is unused.
- //
- // TODO(gvisor.dev/issue/4658): Remove this field.
- UseNeighborCache bool
-
- // UseLinkAddrCache indicates that the legacy link address cache should be
- // used for link resolution.
- //
- // TODO(gvisor.dev/issue/4658): Remove this field.
- UseLinkAddrCache bool
-
// NUDDisp is the NUD event dispatcher that an integrator can provide to
// receive NUD related events.
NUDDisp NUDDispatcher
@@ -656,7 +639,7 @@ func New(opts Options) *Stack {
s := &Stack{
transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
- nics: make(map[tcpip.NICID]*NIC),
+ nics: make(map[tcpip.NICID]*nic),
cleanupEndpoints: make(map[TransportEndpoint]struct{}),
PortManager: ports.NewPortManager(),
clock: clock,
@@ -666,7 +649,6 @@ func New(opts Options) *Stack {
icmpRateLimiter: NewICMPRateLimiter(),
seed: generateRandUint32(),
nudConfigs: opts.NUDConfigs,
- useNeighborCache: !opts.UseLinkAddrCache,
uniqueIDGenerator: opts.UniqueID,
nudDisp: opts.NUDDisp,
randomGenerator: mathrand.New(randSrc),
@@ -1233,7 +1215,7 @@ func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocol
return nic.primaryAddress(protocol), true
}
-func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) AssignableAddressEndpoint {
+func (s *Stack) getAddressEP(nic *nic, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) AssignableAddressEndpoint {
if len(localAddr) == 0 {
return nic.primaryEndpoint(netProto, remoteAddr)
}
@@ -1244,13 +1226,13 @@ func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netP
// from the specified NIC.
//
// Precondition: s.mu must be read locked.
-func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) *Route {
+func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *nic, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) *Route {
localAddressEndpoint := localAddressNIC.getAddressOrCreateTempInner(netProto, localAddr, false /* createTemp */, NeverPrimaryEndpoint)
if localAddressEndpoint == nil {
return nil
}
- var outgoingNIC *NIC
+ var outgoingNIC *nic
// Prefer a local route to the same interface as the local address.
if localAddressNIC.hasAddress(netProto, remoteAddr) {
outgoingNIC = localAddressNIC
@@ -1319,6 +1301,11 @@ func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr,
return nil
}
+// HandleLocal returns true if non-loopback interfaces are allowed to loop packets.
+//
+// It reports the value of the stack's handleLocal field.
+func (s *Stack) HandleLocal() bool {
+ return s.handleLocal
+}
+
// FindRoute creates a route to the given destination address, leaving through
// the given NIC and local address (if provided).
//
@@ -1479,6 +1466,17 @@ func (s *Stack) CheckNetworkProtocol(protocol tcpip.NetworkProtocolNumber) bool
return ok
}
+// CheckDuplicateAddress performs duplicate address detection for the address on
+// the specified interface.
+//
+// If no NIC with the given ID exists, a zero disposition and
+// *tcpip.ErrUnknownNICID are returned. Otherwise the request is passed to the
+// NIC and its result is returned unchanged.
+func (s *Stack) CheckDuplicateAddress(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, h DADCompletionHandler) (DADCheckAddressDisposition, tcpip.Error) {
+ // The nics map is shared with code that takes s.mu, so the lookup must hold
+ // the read lock. The lock is released before calling into the NIC so that
+ // the stack lock is not held for the duration of that call.
+ s.mu.RLock()
+ nic, ok := s.nics[nicID]
+ s.mu.RUnlock()
+
+ if !ok {
+ return 0, &tcpip.ErrUnknownNICID{}
+ }
+
+ return nic.checkDuplicateAddress(protocol, addr, h)
+}
+
// CheckLocalAddress determines if the given local address exists, and if it
// does, returns the id of the NIC it's bound to. Returns 0 if the address
// does not exist.
@@ -1493,20 +1491,16 @@ func (s *Stack) CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProto
return 0
}
- addressEndpoint := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint)
- if addressEndpoint == nil {
- return 0
+ if nic.CheckLocalAddress(protocol, addr) {
+ return nic.id
}
- addressEndpoint.DecRef()
-
- return nic.id
+ return 0
}
// Go through all the NICs.
for _, nic := range s.nics {
- if addressEndpoint := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint); addressEndpoint != nil {
- addressEndpoint.DecRef()
+ if nic.CheckLocalAddress(protocol, addr) {
return nic.id
}
}
@@ -2062,22 +2056,6 @@ func generateRandInt64() int64 {
return v
}
-// FindNetworkEndpoint returns the network endpoint for the given address.
-func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) (NetworkEndpoint, tcpip.Error) {
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- for _, nic := range s.nics {
- addressEndpoint := nic.getAddressOrCreateTempInner(netProto, address, false /* createTemp */, NeverPrimaryEndpoint)
- if addressEndpoint == nil {
- continue
- }
- addressEndpoint.DecRef()
- return nic.getNetworkEndpoint(netProto), nil
- }
- return nil, &tcpip.ErrBadAddress{}
-}
-
// FindNICNameFromID returns the name of the NIC for the given NICID.
func (s *Stack) FindNICNameFromID(id tcpip.NICID) string {
s.mu.RLock()
@@ -2103,13 +2081,6 @@ const (
// ParsedOK indicates that a packet was successfully parsed.
ParsedOK ParseResult = iota
- // UnknownNetworkProtocol indicates that the network protocol is unknown.
- UnknownNetworkProtocol
-
- // NetworkLayerParseError indicates that the network packet was not
- // successfully parsed.
- NetworkLayerParseError
-
// UnknownTransportProtocol indicates that the transport protocol is unknown.
UnknownTransportProtocol
@@ -2118,31 +2089,19 @@ const (
TransportLayerParseError
)
-// ParsePacketBuffer parses the provided packet buffer.
-func (s *Stack) ParsePacketBuffer(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) ParseResult {
- netProto, ok := s.networkProtocols[protocol]
- if !ok {
- return UnknownNetworkProtocol
- }
-
- transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt)
- if !ok {
- return NetworkLayerParseError
- }
- if !hasTransportHdr {
- return ParsedOK
- }
-
+// ParsePacketBufferTransport parses the provided packet buffer's transport
+// header.
+func (s *Stack) ParsePacketBufferTransport(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) ParseResult {
// TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader
// fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a
// full explanation.
- if transProtoNum == header.ICMPv4ProtocolNumber || transProtoNum == header.ICMPv6ProtocolNumber {
+ if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber {
return ParsedOK
}
- pkt.TransportProtocolNumber = transProtoNum
+ pkt.TransportProtocolNumber = protocol
// Parse the transport header if present.
- state, ok := s.transportProtocols[transProtoNum]
+ state, ok := s.transportProtocols[protocol]
if !ok {
return UnknownTransportProtocol
}
@@ -2164,7 +2123,7 @@ func (s *Stack) networkProtocolNumbers() []tcpip.NetworkProtocolNumber {
return protos
}
-func isSubnetBroadcastOnNIC(nic *NIC, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
+func isSubnetBroadcastOnNIC(nic *nic, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
addressEndpoint := nic.getAddressOrCreateTempInner(protocol, addr, false /* createTemp */, NeverPrimaryEndpoint)
if addressEndpoint == nil {
return false
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index b641a4aaa..92a0cb401 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -119,6 +119,10 @@ func (*fakeNetworkEndpoint) DefaultTTL() uint8 {
}
func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) {
+ if _, _, ok := f.proto.Parse(pkt); !ok {
+ return
+ }
+
// Increment the received packet count in the protocol descriptor.
netHdr := pkt.NetworkHeader().View()
@@ -2569,16 +2573,16 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) {
ndpDisp := ndpDispatcher{
dadC: make(chan ndpDADEvent),
}
- ndpConfigs := ipv6.DefaultNDPConfigurations()
+ dadConfigs := stack.DefaultDADConfigurations()
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ndpConfigs,
AutoGenLinkLocal: true,
NDPDisp: &ndpDisp,
+ DADConfigs: dadConfigs,
})},
}
- e := channel.New(int(ndpConfigs.DupAddrDetectTransmits), 1280, linkAddr1)
+ e := channel.New(int(dadConfigs.DupAddrDetectTransmits), 1280, linkAddr1)
s := stack.New(opts)
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -2594,7 +2598,7 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) {
// Wait for DAD to resolve.
select {
- case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
+ case <-time.After(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer + time.Second):
// We should get a resolution event after 1s (default time to
// resolve as per default NDP configurations). Waiting for that
// resolution time + an extra 1s without a resolution event
@@ -3231,7 +3235,7 @@ func TestDoDADWhenNICEnabled(t *testing.T) {
}
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
+ DADConfigs: stack.DADConfigurations{
DupAddrDetectTransmits: dadTransmits,
RetransmitTimer: retransmitTimer,
},
@@ -4320,7 +4324,6 @@ func TestClearNeighborCacheOnNICDisable(t *testing.T) {
clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
- UseNeighborCache: true,
Clock: clock,
})
e := channel.New(0, 0, "")
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 7d8d0851e..292e51d20 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -183,7 +183,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *Packet
}
// handleError delivers an error to the transport endpoint identified by id.
-func (epsByNIC *endpointsByNIC) handleError(n *NIC, id TransportEndpointID, transErr TransportError, pkt *PacketBuffer) {
+func (epsByNIC *endpointsByNIC) handleError(n *nic, id TransportEndpointID, transErr TransportError, pkt *PacketBuffer) {
epsByNIC.mu.RLock()
defer epsByNIC.mu.RUnlock()
@@ -599,7 +599,7 @@ func (d *transportDemuxer) deliverRawPacket(protocol tcpip.TransportProtocolNumb
// endpoint.
//
// Returns true if the error was delivered.
-func (d *transportDemuxer) deliverError(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr TransportError, pkt *PacketBuffer, id TransportEndpointID) bool {
+func (d *transportDemuxer) deliverError(n *nic, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr TransportError, pkt *PacketBuffer, id TransportEndpointID) bool {
eps, ok := d.protocol[protocolIDs{net, trans}]
if !ok {
return false
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go
index 21a8dd291..b56706357 100644
--- a/pkg/tcpip/tests/integration/iptables_test.go
+++ b/pkg/tcpip/tests/integration/iptables_test.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
type inputIfNameMatcher struct {
@@ -334,3 +335,312 @@ func TestIPTablesStatsForInput(t *testing.T) {
})
}
}
+
+var _ stack.LinkEndpoint = (*channelEndpointWithoutWritePacket)(nil)
+
+// channelEndpointWithoutWritePacket is a channel endpoint that does not support
+// stack.LinkEndpoint.WritePacket.
+type channelEndpointWithoutWritePacket struct {
+ *channel.Endpoint
+
+ t *testing.T
+}
+
+func (c *channelEndpointWithoutWritePacket) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error {
+ c.t.Error("unexpectedly called WritePacket; all writes should go through WritePackets")
+ return &tcpip.ErrNotSupported{}
+}
+
+var _ stack.Matcher = (*udpSourcePortMatcher)(nil)
+
+type udpSourcePortMatcher struct {
+ port uint16
+}
+
+func (*udpSourcePortMatcher) Name() string {
+ return "udpSourcePortMatcher"
+}
+
+// Match reports whether the packet's UDP source port equals m.port.
+//
+// A transport header shorter than header.UDPMinimumSize yields no match and
+// sets hotdrop.
+func (m *udpSourcePortMatcher) Match(_ stack.Hook, pkt *stack.PacketBuffer, _, _ string) (matches, hotdrop bool) {
+ transportHdr := header.UDP(pkt.TransportHeader().View())
+ if len(transportHdr) >= header.UDPMinimumSize {
+ return transportHdr.SourcePort() == m.port, false
+ }
+
+ // Drop immediately as the packet is invalid.
+ return false, true
+}
+
+func TestIPTableWritePackets(t *testing.T) {
+ const (
+ nicID = 1
+
+ dropLocalPort = localPort - 1
+ acceptPackets = 2
+ dropPackets = 3
+ )
+
+ udpHdr := func(hdr buffer.View, srcAddr, dstAddr tcpip.Address, srcPort, dstPort uint16) {
+ u := header.UDP(hdr)
+ u.Encode(&header.UDPFields{
+ SrcPort: srcPort,
+ DstPort: dstPort,
+ Length: header.UDPMinimumSize,
+ })
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, srcAddr, dstAddr, header.UDPMinimumSize)
+ sum = header.Checksum(hdr, sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+ }
+
+ tests := []struct {
+ name string
+ setupFilter func(*testing.T, *stack.Stack)
+ genPacket func(*stack.Route) stack.PacketBufferList
+ proto tcpip.NetworkProtocolNumber
+ remoteAddr tcpip.Address
+ expectSent uint64
+ expectOutputDropped uint64
+ }{
+ {
+ name: "IPv4 Accept",
+ setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ },
+ genPacket: func(r *stack.Route) stack.PacketBufferList {
+ var pkts stack.PacketBufferList
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
+ })
+ hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
+ udpHdr(hdr, r.LocalAddress, r.RemoteAddress, localPort, remotePort)
+ pkts.PushFront(pkt)
+
+ return pkts
+ },
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: dstAddrV4,
+ expectSent: 1,
+ expectOutputDropped: 0,
+ },
+ {
+ name: "IPv4 Drop Other Port",
+ // setupFilter installs an IPv4 filter table whose Output hook starts at
+ // the rule that drops packets matched by udpSourcePortMatcher for
+ // dropLocalPort. The test fails fatally if the table cannot be installed.
+ setupFilter: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+
+ table := stack.Table{
+ // Rules are referenced by index from BuiltinChains and Underflows.
+ Rules: []stack.Rule{
+ // Index 0: referenced by the Input hook.
+ {
+ Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber},
+ },
+ // Index 1: referenced by the Forward hook.
+ {
+ Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber},
+ },
+ // Index 2: referenced by the Output hook.
+ {
+ Matchers: []stack.Matcher{&udpSourcePortMatcher{port: dropLocalPort}},
+ Target: &stack.DropTarget{NetworkProtocol: header.IPv4ProtocolNumber},
+ },
+ {
+ Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber},
+ },
+ {
+ Target: &stack.ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber},
+ },
+ },
+ BuiltinChains: [stack.NumHooks]int{
+ stack.Prerouting: stack.HookUnset,
+ stack.Input: 0,
+ stack.Forward: 1,
+ stack.Output: 2,
+ stack.Postrouting: stack.HookUnset,
+ },
+ Underflows: [stack.NumHooks]int{
+ stack.Prerouting: stack.HookUnset,
+ stack.Input: 0,
+ stack.Forward: 1,
+ stack.Output: 2,
+ stack.Postrouting: stack.HookUnset,
+ },
+ }
+
+ if err := s.IPTables().ReplaceTable(stack.FilterID, table, false /* ipv4 */); err != nil {
+ t.Fatalf("ReplaceTable(%d, _, false): %s", stack.FilterID, err)
+ }
+ },
+ genPacket: func(r *stack.Route) stack.PacketBufferList {
+ var pkts stack.PacketBufferList
+
+ for i := 0; i < acceptPackets; i++ {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
+ })
+ hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
+ udpHdr(hdr, r.LocalAddress, r.RemoteAddress, localPort, remotePort)
+ pkts.PushFront(pkt)
+ }
+ for i := 0; i < dropPackets; i++ {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
+ })
+ hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
+ udpHdr(hdr, r.LocalAddress, r.RemoteAddress, dropLocalPort, remotePort)
+ pkts.PushFront(pkt)
+ }
+
+ return pkts
+ },
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: dstAddrV4,
+ expectSent: acceptPackets,
+ expectOutputDropped: dropPackets,
+ },
+ {
+ name: "IPv6 Accept",
+ setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ },
+ genPacket: func(r *stack.Route) stack.PacketBufferList {
+ var pkts stack.PacketBufferList
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
+ })
+ hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
+ udpHdr(hdr, r.LocalAddress, r.RemoteAddress, localPort, remotePort)
+ pkts.PushFront(pkt)
+
+ return pkts
+ },
+ proto: header.IPv6ProtocolNumber,
+ remoteAddr: dstAddrV6,
+ expectSent: 1,
+ expectOutputDropped: 0,
+ },
+ {
+ name: "IPv6 Drop Other Port",
+ // setupFilter installs an IPv6 filter table whose Output hook starts at
+ // the rule that drops packets matched by udpSourcePortMatcher for
+ // dropLocalPort. The test fails fatally if the table cannot be installed.
+ setupFilter: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+
+ table := stack.Table{
+ // Rules are referenced by index from BuiltinChains and Underflows.
+ Rules: []stack.Rule{
+ // Index 0: referenced by the Input hook.
+ {
+ Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber},
+ },
+ // Index 1: referenced by the Forward hook.
+ {
+ Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber},
+ },
+ // Index 2: referenced by the Output hook.
+ {
+ Matchers: []stack.Matcher{&udpSourcePortMatcher{port: dropLocalPort}},
+ Target: &stack.DropTarget{NetworkProtocol: header.IPv6ProtocolNumber},
+ },
+ {
+ Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber},
+ },
+ {
+ Target: &stack.ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber},
+ },
+ },
+ BuiltinChains: [stack.NumHooks]int{
+ stack.Prerouting: stack.HookUnset,
+ stack.Input: 0,
+ stack.Forward: 1,
+ stack.Output: 2,
+ stack.Postrouting: stack.HookUnset,
+ },
+ Underflows: [stack.NumHooks]int{
+ stack.Prerouting: stack.HookUnset,
+ stack.Input: 0,
+ stack.Forward: 1,
+ stack.Output: 2,
+ stack.Postrouting: stack.HookUnset,
+ },
+ }
+
+ if err := s.IPTables().ReplaceTable(stack.FilterID, table, true /* ipv6 */); err != nil {
+ t.Fatalf("ReplaceTable(%d, _, true): %s", stack.FilterID, err)
+ }
+ },
+ genPacket: func(r *stack.Route) stack.PacketBufferList {
+ var pkts stack.PacketBufferList
+
+ for i := 0; i < acceptPackets; i++ {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
+ })
+ hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
+ udpHdr(hdr, r.LocalAddress, r.RemoteAddress, localPort, remotePort)
+ pkts.PushFront(pkt)
+ }
+ for i := 0; i < dropPackets; i++ {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
+ })
+ hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
+ udpHdr(hdr, r.LocalAddress, r.RemoteAddress, dropLocalPort, remotePort)
+ pkts.PushFront(pkt)
+ }
+
+ return pkts
+ },
+ proto: header.IPv6ProtocolNumber,
+ remoteAddr: dstAddrV6,
+ expectSent: acceptPackets,
+ expectOutputDropped: dropPackets,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+ e := channelEndpointWithoutWritePacket{
+ Endpoint: channel.New(4, header.IPv6MinimumMTU, linkAddr),
+ t: t,
+ }
+ if err := s.CreateNIC(nicID, &e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, srcAddrV6); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, srcAddrV6, err)
+ }
+ if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, srcAddrV4); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, srcAddrV4, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID,
+ },
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
+ test.setupFilter(t, s)
+
+ r, err := s.FindRoute(nicID, "", test.remoteAddr, test.proto, false)
+ if err != nil {
+ t.Fatalf("FindRoute(%d, '', %s, %d, false): %s", nicID, test.remoteAddr, test.proto, err)
+ }
+ defer r.Release()
+
+ pkts := test.genPacket(r)
+ pktsLen := pkts.Len()
+ if n, err := r.WritePackets(nil /* gso */, pkts, stack.NetworkHeaderParams{
+ Protocol: header.UDPProtocolNumber,
+ TTL: 64,
+ }); err != nil {
+ t.Fatalf("WritePackets(...): %s", err)
+ } else if n != pktsLen {
+ t.Fatalf("got WritePackets(...) = %d, want = %d", n, pktsLen)
+ }
+
+ if got := s.Stats().IP.PacketsSent.Value(); got != test.expectSent {
+ t.Errorf("got PacketSent = %d, want = %d", got, test.expectSent)
+ }
+ if got := s.Stats().IP.IPTablesOutputDropped.Value(); got != test.expectOutputDropped {
+ t.Errorf("got IPTablesOutputDropped = %d, want = %d", got, test.expectOutputDropped)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go
index f2301a9e6..824f81a42 100644
--- a/pkg/tcpip/tests/integration/link_resolution_test.go
+++ b/pkg/tcpip/tests/integration/link_resolution_test.go
@@ -487,30 +487,25 @@ func TestGetLinkAddress(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- for _, useNeighborCache := range []bool{true, false} {
- t.Run(fmt.Sprintf("UseNeighborCache=%t", useNeighborCache), func(t *testing.T) {
- stackOpts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
- UseNeighborCache: useNeighborCache,
- }
+ stackOpts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ }
- host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID)
+ host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID)
- ch := make(chan stack.LinkResolutionResult, 1)
- err := host1Stack.GetLinkAddress(host1NICID, test.remoteAddr, "", test.netProto, func(r stack.LinkResolutionResult) {
- ch <- r
- })
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", host1NICID, test.remoteAddr, test.netProto, err, &tcpip.ErrWouldBlock{})
- }
- wantRes := stack.LinkResolutionResult{Success: test.expectedOk}
- if test.expectedOk {
- wantRes.LinkAddress = linkAddr2
- }
- if diff := cmp.Diff(wantRes, <-ch); diff != "" {
- t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff)
- }
- })
+ ch := make(chan stack.LinkResolutionResult, 1)
+ err := host1Stack.GetLinkAddress(host1NICID, test.remoteAddr, "", test.netProto, func(r stack.LinkResolutionResult) {
+ ch <- r
+ })
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", host1NICID, test.remoteAddr, test.netProto, err, &tcpip.ErrWouldBlock{})
+ }
+ wantRes := stack.LinkResolutionResult{Success: test.expectedOk}
+ if test.expectedOk {
+ wantRes.LinkAddress = linkAddr2
+ }
+ if diff := cmp.Diff(wantRes, <-ch); diff != "" {
+ t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff)
}
})
}
@@ -587,66 +582,61 @@ func TestRouteResolvedFields(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- for _, useNeighborCache := range []bool{true, false} {
- t.Run(fmt.Sprintf("UseNeighborCache=%t", useNeighborCache), func(t *testing.T) {
- stackOpts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
- UseNeighborCache: useNeighborCache,
- }
+ stackOpts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ }
- host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID)
- r, err := host1Stack.FindRoute(host1NICID, "", test.remoteAddr, test.netProto, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("host1Stack.FindRoute(%d, '', %s, %d, false): %s", host1NICID, test.remoteAddr, test.netProto, err)
- }
- defer r.Release()
-
- var wantRouteInfo stack.RouteInfo
- wantRouteInfo.LocalLinkAddress = linkAddr1
- wantRouteInfo.LocalAddress = test.localAddr
- wantRouteInfo.RemoteAddress = test.remoteAddr
- wantRouteInfo.NetProto = test.netProto
- wantRouteInfo.Loop = stack.PacketOut
- wantRouteInfo.RemoteLinkAddress = test.expectedLinkAddr
-
- ch := make(chan stack.ResolvedFieldsResult, 1)
-
- if !test.immediatelyResolvable {
- wantUnresolvedRouteInfo := wantRouteInfo
- wantUnresolvedRouteInfo.RemoteLinkAddress = ""
-
- err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) {
- ch <- r
- })
- if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
- t.Errorf("got r.ResolvedFields(_) = %s, want = %s", err, &tcpip.ErrWouldBlock{})
- }
- if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Success: test.expectedSuccess}, <-ch, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" {
- t.Errorf("route resolve result mismatch (-want +got):\n%s", diff)
- }
+ host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID)
+ r, err := host1Stack.FindRoute(host1NICID, "", test.remoteAddr, test.netProto, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("host1Stack.FindRoute(%d, '', %s, %d, false): %s", host1NICID, test.remoteAddr, test.netProto, err)
+ }
+ defer r.Release()
- if !test.expectedSuccess {
- return
- }
+ var wantRouteInfo stack.RouteInfo
+ wantRouteInfo.LocalLinkAddress = linkAddr1
+ wantRouteInfo.LocalAddress = test.localAddr
+ wantRouteInfo.RemoteAddress = test.remoteAddr
+ wantRouteInfo.NetProto = test.netProto
+ wantRouteInfo.Loop = stack.PacketOut
+ wantRouteInfo.RemoteLinkAddress = test.expectedLinkAddr
- // At this point the neighbor table should be populated so the route
- // should be immediately resolvable.
- }
+ ch := make(chan stack.ResolvedFieldsResult, 1)
- if err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) {
- ch <- r
- }); err != nil {
- t.Errorf("r.ResolvedFields(_): %s", err)
- }
- select {
- case routeResolveRes := <-ch:
- if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Success: true}, routeResolveRes, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" {
- t.Errorf("route resolve result from resolved route mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected route to be immediately resolvable")
- }
+ if !test.immediatelyResolvable {
+ wantUnresolvedRouteInfo := wantRouteInfo
+ wantUnresolvedRouteInfo.RemoteLinkAddress = ""
+
+ err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) {
+ ch <- r
})
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Errorf("got r.ResolvedFields(_) = %s, want = %s", err, &tcpip.ErrWouldBlock{})
+ }
+ if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Success: test.expectedSuccess}, <-ch, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" {
+ t.Errorf("route resolve result mismatch (-want +got):\n%s", diff)
+ }
+
+ if !test.expectedSuccess {
+ return
+ }
+
+ // At this point the neighbor table should be populated so the route
+ // should be immediately resolvable.
+ }
+
+ if err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) {
+ ch <- r
+ }); err != nil {
+ t.Errorf("r.ResolvedFields(_): %s", err)
+ }
+ select {
+ case routeResolveRes := <-ch:
+ if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Success: true}, routeResolveRes, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" {
+ t.Errorf("route resolve result from resolved route mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected route to be immediately resolvable")
}
})
}
@@ -1065,7 +1055,6 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
Clock: clock,
- UseNeighborCache: true,
}
host1StackOpts := stackOpts
host1StackOpts.NUDDisp = &nudDisp
@@ -1210,3 +1199,148 @@ func TestTCPConfirmNeighborReachability(t *testing.T) {
})
}
}
+
+func TestDAD(t *testing.T) {
+ const (
+ host1NICID = 1
+ host2NICID = 4
+ )
+
+ dadConfigs := stack.DADConfigurations{
+ DupAddrDetectTransmits: 1,
+ RetransmitTimer: time.Second,
+ }
+
+ tests := []struct {
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ dadNetProto tcpip.NetworkProtocolNumber
+ remoteAddr tcpip.Address
+ expectedResolved bool
+ }{
+ {
+ name: "IPv4 own address",
+ netProto: ipv4.ProtocolNumber,
+ dadNetProto: arp.ProtocolNumber,
+ remoteAddr: ipv4Addr1.AddressWithPrefix.Address,
+ expectedResolved: true,
+ },
+ {
+ name: "IPv6 own address",
+ netProto: ipv6.ProtocolNumber,
+ dadNetProto: ipv6.ProtocolNumber,
+ remoteAddr: ipv6Addr1.AddressWithPrefix.Address,
+ expectedResolved: true,
+ },
+ {
+ name: "IPv4 duplicate address",
+ netProto: ipv4.ProtocolNumber,
+ dadNetProto: arp.ProtocolNumber,
+ remoteAddr: ipv4Addr2.AddressWithPrefix.Address,
+ expectedResolved: false,
+ },
+ {
+ name: "IPv6 duplicate address",
+ netProto: ipv6.ProtocolNumber,
+ dadNetProto: ipv6.ProtocolNumber,
+ remoteAddr: ipv6Addr2.AddressWithPrefix.Address,
+ expectedResolved: false,
+ },
+ {
+ name: "IPv4 no duplicate address",
+ netProto: ipv4.ProtocolNumber,
+ dadNetProto: arp.ProtocolNumber,
+ remoteAddr: ipv4Addr3.AddressWithPrefix.Address,
+ expectedResolved: true,
+ },
+ {
+ name: "IPv6 no duplicate address",
+ netProto: ipv6.ProtocolNumber,
+ dadNetProto: ipv6.ProtocolNumber,
+ remoteAddr: ipv6Addr3.AddressWithPrefix.Address,
+ expectedResolved: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
+ stackOpts := stack.Options{
+ Clock: clock,
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ arp.NewProtocol,
+ ipv4.NewProtocol,
+ ipv6.NewProtocol,
+ },
+ }
+
+ host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID)
+
+ // DAD should be disabled by default.
+ if res, err := host1Stack.CheckDuplicateAddress(host1NICID, test.netProto, test.remoteAddr, func(r stack.DADResult) {
+ t.Errorf("unexpectedly called DAD completion handler when DAD was supposed to be disabled")
+ }); err != nil {
+ t.Fatalf("host1Stack.CheckDuplicateAddress(%d, %d, %s, _): %s", host1NICID, test.netProto, test.remoteAddr, err)
+ } else if res != stack.DADDisabled {
+ t.Errorf("got host1Stack.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", host1NICID, test.netProto, test.remoteAddr, res, stack.DADDisabled)
+ }
+
+ // Enable DAD then attempt to check if an address is duplicated.
+ netEP, err := host1Stack.GetNetworkEndpoint(host1NICID, test.dadNetProto)
+ if err != nil {
+ t.Fatalf("host1Stack.GetNetworkEndpoint(%d, %d): %s", host1NICID, test.dadNetProto, err)
+ }
+ dad, ok := netEP.(stack.DuplicateAddressDetector)
+ if !ok {
+ t.Fatalf("expected %T to implement stack.DuplicateAddressDetector", netEP)
+ }
+ dad.SetDADConfigurations(dadConfigs)
+ ch := make(chan stack.DADResult, 3)
+ if res, err := host1Stack.CheckDuplicateAddress(host1NICID, test.netProto, test.remoteAddr, func(r stack.DADResult) {
+ ch <- r
+ }); err != nil {
+ t.Fatalf("host1Stack.CheckDuplicateAddress(%d, %d, %s, _): %s", host1NICID, test.netProto, test.remoteAddr, err)
+ } else if res != stack.DADStarting {
+ t.Errorf("got host1Stack.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", host1NICID, test.netProto, test.remoteAddr, res, stack.DADStarting)
+ }
+
+ expectResults := 1
+ if test.expectedResolved {
+ const delta = time.Nanosecond
+ clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer - delta)
+ select {
+ case r := <-ch:
+ t.Fatalf("unexpectedly got DAD result before the DAD timeout; r = %#v", r)
+ default:
+ }
+
+ // If we expect the resolve to succeed try requesting DAD again on the
+ // same address. The handler for the new request should be called once
+ // the original DAD request completes.
+ expectResults = 2
+ if res, err := host1Stack.CheckDuplicateAddress(host1NICID, test.netProto, test.remoteAddr, func(r stack.DADResult) {
+ ch <- r
+ }); err != nil {
+ t.Fatalf("host1Stack.CheckDuplicateAddress(%d, %d, %s, _): %s", host1NICID, test.netProto, test.remoteAddr, err)
+ } else if res != stack.DADAlreadyRunning {
+ t.Errorf("got host1Stack.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", host1NICID, test.netProto, test.remoteAddr, res, stack.DADAlreadyRunning)
+ }
+
+ clock.Advance(delta)
+ }
+
+ for i := 0; i < expectResults; i++ {
+ if diff := cmp.Diff(stack.DADResult{Resolved: test.expectedResolved}, <-ch); diff != "" {
+ t.Errorf("(i=%d) DAD result mismatch (-want +got):\n%s", i, diff)
+ }
+ }
+
+ // Should have no more results.
+ select {
+ case r := <-ch:
+ t.Errorf("unexpectedly got an extra DAD result; r = %#v", r)
+ default:
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 34a631b53..461b1a9d7 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -1301,7 +1301,8 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
// e.mu is expected to be hold upon entering this section.
if e.snd != nil {
e.snd.resendTimer.cleanup()
- e.snd.rc.probeTimer.cleanup()
+ e.snd.probeTimer.cleanup()
+ e.snd.reorderTimer.cleanup()
}
if closeTimer != nil {
@@ -1396,7 +1397,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
},
},
{
- w: &e.snd.rc.probeWaker,
+ w: &e.snd.probeWaker,
f: e.snd.probeTimerExpired,
},
{
@@ -1475,6 +1476,10 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
return nil
},
},
+ {
+ w: &e.snd.reorderWaker,
+ f: e.snd.rc.reorderTimerExpired,
+ },
}
// Initialize the sleeper based on the wakers in funcs.
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 4e5a6089f..8c5be0586 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1698,11 +1698,7 @@ func (e *endpoint) OnCorkOptionSet(v bool) {
}
func (e *endpoint) getSendBufferSize() int {
- sndBufSize, err := e.ops.GetSendBufferSize()
- if err != nil {
- panic(fmt.Sprintf("e.ops.GetSendBufferSize() = %s", err))
- }
- return int(sndBufSize)
+ return int(e.ops.GetSendBufferSize())
}
// SetSockOptInt sets a socket option.
diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go
index e862f159e..9959b60b8 100644
--- a/pkg/tcpip/transport/tcp/rack.go
+++ b/pkg/tcpip/transport/tcp/rack.go
@@ -17,7 +17,6 @@ package tcp
import (
"time"
- "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
)
@@ -50,7 +49,8 @@ type rackControl struct {
// dsackSeen indicates if the connection has seen a DSACK.
dsackSeen bool
- // endSequence is the ending TCP sequence number of rackControl.seg.
+ // endSequence is the ending TCP sequence number of the most recent
+ // acknowledged segment.
endSequence seqnum.Value
// exitedRecovery indicates if the connection is exiting loss recovery.
@@ -90,13 +90,10 @@ type rackControl struct {
// rttSeq is the SND.NXT when rtt is updated.
rttSeq seqnum.Value
- // xmitTime is the latest transmission timestamp of rackControl.seg.
+ // xmitTime is the latest transmission timestamp of the most recent
+ // acknowledged segment.
xmitTime time.Time `state:".(unixTime)"`
- // probeTimer and probeWaker are used to schedule PTO for RACK TLP algorithm.
- probeTimer timer `state:"nosave"`
- probeWaker sleep.Waker `state:"nosave"`
-
// tlpRxtOut indicates whether there is an unacknowledged
// TLP retransmission.
tlpRxtOut bool
@@ -114,7 +111,6 @@ func (rc *rackControl) init(snd *sender, iss seqnum.Value) {
rc.fack = iss
rc.reoWndIncr = 1
rc.snd = snd
- rc.probeTimer.init(&rc.probeWaker)
}
// update will update the RACK related fields when an ACK has been received.
@@ -223,13 +219,13 @@ func (s *sender) schedulePTO() {
s.resendTimer.disable()
}
- s.rc.probeTimer.enable(pto)
+ s.probeTimer.enable(pto)
}
// probeTimerExpired is the same as TLP_send_probe() as defined in
// https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.5.2.
func (s *sender) probeTimerExpired() tcpip.Error {
- if !s.rc.probeTimer.checkExpiration() {
+ if !s.probeTimer.checkExpiration() {
return nil
}
@@ -386,3 +382,102 @@ func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) {
func (rc *rackControl) exitRecovery() {
rc.exitedRecovery = true
}
+
+// detectLoss marks the segment as lost if the reordering window has elapsed
+// and the ACK is not received. It will also arm the reorder timer.
+// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 Step 5.
+func (rc *rackControl) detectLoss(rcvTime time.Time) int {
+ var timeout time.Duration
+ numLost := 0
+ for seg := rc.snd.writeList.Front(); seg != nil && seg.xmitCount != 0; seg = seg.Next() {
+ if rc.snd.ep.scoreboard.IsSACKED(seg.sackBlock()) {
+ continue
+ }
+
+ if seg.lost && seg.xmitCount == 1 {
+ numLost++
+ continue
+ }
+
+ endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size()))
+ if seg.xmitTime.Before(rc.xmitTime) || (seg.xmitTime.Equal(rc.xmitTime) && rc.endSequence.LessThan(endSeq)) {
+ timeRemaining := seg.xmitTime.Sub(rcvTime) + rc.rtt + rc.reoWnd
+ if timeRemaining <= 0 {
+ seg.lost = true
+ numLost++
+ } else if timeRemaining > timeout {
+ timeout = timeRemaining
+ }
+ }
+ }
+
+ if timeout != 0 && !rc.snd.reorderTimer.enabled() {
+ rc.snd.reorderTimer.enable(timeout)
+ }
+ return numLost
+}
+
+// reorderTimerExpired will retransmit the segments which have not been acked
+// before the reorder timer expired.
+func (rc *rackControl) reorderTimerExpired() tcpip.Error {
+ // Check if the timer actually expired or if it's a spurious wake due
+ // to a previously orphaned runtime timer.
+ if !rc.snd.reorderTimer.checkExpiration() {
+ return nil
+ }
+
+ numLost := rc.detectLoss(time.Now())
+ if numLost == 0 {
+ return nil
+ }
+
+ fastRetransmit := false
+ if !rc.snd.fr.active {
+ rc.snd.cc.HandleLossDetected()
+ rc.snd.enterRecovery()
+ fastRetransmit = true
+ }
+
+ rc.DoRecovery(nil, fastRetransmit)
+ return nil
+}
+
+// DoRecovery implements lossRecovery.DoRecovery.
+func (rc *rackControl) DoRecovery(_ *segment, fastRetransmit bool) {
+ snd := rc.snd
+ if fastRetransmit {
+ snd.resendSegment()
+ }
+
+ var dataSent bool
+ // Iterate the writeList and retransmit the segments which are marked
+ // as lost by RACK.
+ for seg := snd.writeList.Front(); seg != nil && seg.xmitCount > 0; seg = seg.Next() {
+ if seg == snd.writeNext {
+ break
+ }
+
+ if !seg.lost {
+ continue
+ }
+
+ // Reset seg.lost as it is already SACKed.
+ if snd.ep.scoreboard.IsSACKED(seg.sackBlock()) {
+ seg.lost = false
+ continue
+ }
+
+ // Check the congestion window after entering recovery.
+ if snd.outstanding >= snd.sndCwnd {
+ break
+ }
+
+ snd.outstanding++
+ dataSent = true
+ snd.sendSegment(seg)
+ }
+
+ // Rearm the RTO.
+ snd.resendTimer.enable(snd.rto)
+ snd.postXmit(dataSent)
+}
diff --git a/pkg/tcpip/transport/tcp/rack_state.go b/pkg/tcpip/transport/tcp/rack_state.go
index 76cad0831..c9dc7e773 100644
--- a/pkg/tcpip/transport/tcp/rack_state.go
+++ b/pkg/tcpip/transport/tcp/rack_state.go
@@ -27,8 +27,3 @@ func (rc *rackControl) saveXmitTime() unixTime {
func (rc *rackControl) loadXmitTime(unix unixTime) {
rc.xmitTime = time.Unix(unix.second, unix.nano)
}
-
-// afterLoad is invoked by stateify.
-func (rc *rackControl) afterLoad() {
- rc.probeTimer.init(&rc.probeWaker)
-}
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 7cca4def5..f27eef6a9 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -83,6 +83,9 @@ type segment struct {
// dataMemSize is the memory used by data initially.
dataMemSize int
+
+ // lost indicates if the segment is marked as lost by RACK.
+ lost bool
}
func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment {
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 463a259b7..d6365b93d 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -191,6 +191,15 @@ type sender struct {
// rc has the fields needed for implementing RACK loss detection
// algorithm.
rc rackControl
+
+ // reorderTimer is the timer used to retransmit the segments after RACK
+ // detects them as lost.
+ reorderTimer timer `state:"nosave"`
+ reorderWaker sleep.Waker `state:"nosave"`
+
+ // probeTimer and probeWaker are used to schedule PTO for RACK TLP algorithm.
+ probeTimer timer `state:"nosave"`
+ probeWaker sleep.Waker `state:"nosave"`
}
// rtt is a synchronization wrapper used to appease stateify. See the comment
@@ -267,7 +276,6 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
}
s.cc = s.initCongestionControl(ep.cc)
-
s.lr = s.initLossRecovery()
s.rc.init(s, iss)
@@ -278,6 +286,8 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
}
s.resendTimer.init(&s.resendWaker)
+ s.reorderTimer.init(&s.reorderWaker)
+ s.probeTimer.init(&s.probeWaker)
s.updateMaxPayloadSize(int(ep.route.MTU()), 0)
@@ -1126,6 +1136,15 @@ func (s *sender) SetPipe() {
func (s *sender) detectLoss(seg *segment) (fastRetransmit bool) {
// We're not in fast recovery yet.
+ // If RACK is enabled and there is no reordering we should honor the
+ // three duplicate ACK rule to enter recovery.
+ // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-4
+ if s.ep.sackPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 {
+ if s.rc.reorderSeen {
+ return false
+ }
+ }
+
if !s.isDupAck(seg) {
s.dupAckCount = 0
return false
@@ -1320,7 +1339,9 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// unacknowledged and also never retransmitted sequence below
// RACK.fack, then the corresponding packet has been
// reordered and RACK.reord is set to TRUE.
- s.walkSACK(rcvdSeg)
+ if s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 {
+ s.walkSACK(rcvdSeg)
+ }
s.SetPipe()
}
@@ -1339,7 +1360,9 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
}
// See if TLP based recovery was successful.
- s.detectTLPRecovery(ack, rcvdSeg)
+ if s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 {
+ s.detectTLPRecovery(ack, rcvdSeg)
+ }
// Stash away the current window size.
s.sndWnd = rcvdSeg.window
@@ -1421,7 +1444,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
}
// Update the RACK fields if SACK is enabled.
- if s.ep.sackPermitted && !seg.acked {
+ if s.ep.sackPermitted && !seg.acked && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 {
s.rc.update(seg, rcvdSeg)
s.rc.detectReorder(seg)
}
@@ -1455,7 +1478,10 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// Update RACK when we are exiting fast or RTO
// recovery as described in the RFC
// draft-ietf-tcpm-rack-08 Section-7.2 Step 4.
- s.rc.exitRecovery()
+ if s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 {
+ s.rc.exitRecovery()
+ }
+ s.reorderTimer.disable()
}
}
@@ -1475,19 +1501,36 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// Reset firstRetransmittedSegXmitTime to the zero value.
s.firstRetransmittedSegXmitTime = time.Time{}
s.resendTimer.disable()
- s.rc.probeTimer.disable()
+ s.probeTimer.disable()
}
}
- // Update RACK reorder window.
- // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
- // * Upon receiving an ACK:
- // * Step 4: Update RACK reordering window
- s.rc.updateRACKReorderWindow(rcvdSeg)
+ if s.ep.sackPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 {
+ // Update RACK reorder window.
+ // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
+ // * Upon receiving an ACK:
+ // * Step 4: Update RACK reordering window
+ s.rc.updateRACKReorderWindow(rcvdSeg)
+
+ // After the reorder window is calculated, detect any loss by checking
+ // if the time elapsed after the segments are sent is greater than the
+ // reorder window.
+ if numLost := s.rc.detectLoss(rcvdSeg.rcvdTime); numLost > 0 && !s.fr.active {
+ // If any segment is marked as lost by
+ // RACK, enter recovery and retransmit
+ // the lost segments.
+ s.cc.HandleLossDetected()
+ s.enterRecovery()
+ }
+
+ if s.fr.active {
+ s.rc.DoRecovery(nil, true)
+ }
+ }
// Now that we've popped all acknowledged data from the retransmit
// queue, retransmit if needed.
- if s.fr.active {
+ if s.fr.active && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection == 0 {
s.lr.DoRecovery(rcvdSeg, fastRetransmit)
// When SACK is enabled data sending is governed by steps in
// RFC 6675 Section 5 recovery steps A-C.
@@ -1515,6 +1558,7 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error {
}
seg.xmitTime = time.Now()
seg.xmitCount++
+ seg.lost = false
err := s.sendSegmentFromView(seg.data, seg.flags, seg.sequenceNumber)
// Every time a packet containing data is sent (including a
diff --git a/pkg/tcpip/transport/tcp/snd_state.go b/pkg/tcpip/transport/tcp/snd_state.go
index 8b20c3455..ba41cff6d 100644
--- a/pkg/tcpip/transport/tcp/snd_state.go
+++ b/pkg/tcpip/transport/tcp/snd_state.go
@@ -47,6 +47,8 @@ func (s *sender) loadRttMeasureTime(unix unixTime) {
// afterLoad is invoked by stateify.
func (s *sender) afterLoad() {
s.resendTimer.init(&s.resendWaker)
+ s.reorderTimer.init(&s.reorderWaker)
+ s.probeTimer.init(&s.probeWaker)
}
// saveFirstRetransmittedSegXmitTime is invoked by stateify.
diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go
index a6a26b705..6da981d80 100644
--- a/pkg/tcpip/transport/tcp/tcp_rack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
+ "gvisor.dev/gvisor/pkg/test/testutil"
)
const (
@@ -34,6 +35,14 @@ const (
mtu = header.TCPMinimumSize + header.IPv4MinimumSize + maxTCPOptionSize + maxPayload
)
+func setStackRACKPermitted(t *testing.T, c *context.Context) {
+ t.Helper()
+ opt := tcpip.TCPRecovery(tcpip.TCPRACKLossDetection)
+ if err := c.Stack().SetTransportProtocolOption(header.TCPProtocolNumber, &opt); err != nil {
+ t.Fatalf("c.s.SetTransportProtocolOption(%d, &%v(%v)): %s", header.TCPProtocolNumber, opt, opt, err)
+ }
+}
+
// TestRACKUpdate tests the RACK related fields are updated when an ACK is
// received on a SACK enabled connection.
func TestRACKUpdate(t *testing.T) {
@@ -60,6 +69,7 @@ func TestRACKUpdate(t *testing.T) {
close(probeDone)
})
setStackSACKPermitted(t, c, true)
+ setStackRACKPermitted(t, c)
createConnectedWithSACKAndTS(c)
data := make([]byte, maxPayload)
@@ -90,6 +100,8 @@ func TestRACKDetectReorder(t *testing.T) {
c := context.New(t, uint32(mtu))
defer c.Cleanup()
+ t.Skipf("Skipping this test as reorder detection does not consider DSACK.")
+
var n int
const ackNumToVerify = 2
probeDone := make(chan struct{})
@@ -116,6 +128,7 @@ func TestRACKDetectReorder(t *testing.T) {
close(probeDone)
})
setStackSACKPermitted(t, c, true)
+ setStackRACKPermitted(t, c)
createConnectedWithSACKAndTS(c)
data := make([]byte, ackNumToVerify*maxPayload)
for i := range data {
@@ -148,6 +161,7 @@ func TestRACKDetectReorder(t *testing.T) {
func sendAndReceive(t *testing.T, c *context.Context, numPackets int) []byte {
setStackSACKPermitted(t, c, true)
+ setStackRACKPermitted(t, c)
createConnectedWithSACKAndTS(c)
data := make([]byte, numPackets*maxPayload)
@@ -580,7 +594,6 @@ func TestRACKCheckReorderWindow(t *testing.T) {
c.SendAck(seq, bytesRead)
// Missing [2-6] packets and SACK #7 packet.
- seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1)
start := c.IRS.Add(1 + seqnum.Size(6*maxPayload))
end := start.Add(seqnum.Size(maxPayload))
c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
@@ -596,3 +609,46 @@ func TestRACKCheckReorderWindow(t *testing.T) {
t.Fatalf("unexpected values for RACK variables: %v", err)
}
}
+
+func TestRACKWithDuplicateACK(t *testing.T) {
+ c := context.New(t, uint32(mtu))
+ defer c.Cleanup()
+
+ const numPackets = 4
+ data := sendAndReceive(t, c, numPackets)
+
+ // Send three duplicate ACKs to trigger fast recovery.
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ start := c.IRS.Add(1 + seqnum.Size(maxPayload))
+ end := start.Add(seqnum.Size(maxPayload))
+ for i := 0; i < 3; i++ {
+ c.SendAckWithSACK(seq, maxPayload, []header.SACKBlock{{start, end}})
+ end = end.Add(seqnum.Size(maxPayload))
+ }
+
+ // Receive the retransmitted packet.
+ c.ReceiveAndCheckPacketWithOptions(data, maxPayload, maxPayload, tsOptionSize)
+
+ metricPollFn := func() error {
+ tcpStats := c.Stack().Stats().TCP
+ stats := []struct {
+ stat *tcpip.StatCounter
+ name string
+ want uint64
+ }{
+ {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1},
+ {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1},
+ {tcpStats.FastRecovery, "stats.TCP.FastRecovery", 0},
+ }
+ for _, s := range stats {
+ if got, want := s.stat.Value(), s.want; got != want {
+ return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
+ }
+ }
+ return nil
+ }
+
+ if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
+ t.Error(err)
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index cd3c4a027..0128c1f7e 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -4393,12 +4393,7 @@ func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
t.Helper()
- s, err := ep.SocketOptions().GetSendBufferSize()
- if err != nil {
- t.Fatalf("GetSendBufferSize failed: %s", err)
- }
-
- if int(s) != v {
+ if s := ep.SocketOptions().GetSendBufferSize(); int(s) != v {
t.Fatalf("got send buffer size = %d, want = %d", s, v)
}
}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index afd8f4d39..807df2bb5 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -938,11 +938,6 @@ func (e *endpoint) Disconnect() tcpip.Error {
// Connect connects the endpoint to its peer. Specifying a NIC is optional.
func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
- if addr.Port == 0 {
- // We don't support connecting to port zero.
- return &tcpip.ErrInvalidEndpointState{}
- }
-
e.mu.Lock()
defer e.mu.Unlock()
@@ -1188,7 +1183,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- if e.EndpointState() != StateConnected {
+ if e.EndpointState() != StateConnected || e.dstPort == 0 {
return tcpip.FullAddress{}, &tcpip.ErrNotConnected{}
}
diff --git a/runsc/boot/filter/config.go b/runsc/boot/filter/config.go
index eacd73531..2a8c916d5 100644
--- a/runsc/boot/filter/config.go
+++ b/runsc/boot/filter/config.go
@@ -100,6 +100,15 @@ var allowedSyscalls = seccomp.SyscallRules{
seccomp.MatchAny{},
},
},
+ // getcpu is used by some versions of the Go runtime and by the hostcpu
+ // package on arm64.
+ unix.SYS_GETCPU: []seccomp.Rule{
+ {
+ seccomp.MatchAny{},
+ seccomp.EqualTo(0),
+ seccomp.EqualTo(0),
+ },
+ },
syscall.SYS_GETPID: {},
unix.SYS_GETRANDOM: {},
syscall.SYS_GETSOCKOPT: []seccomp.Rule{
diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go
index d50bbcd9f..129478505 100644
--- a/runsc/container/container_test.go
+++ b/runsc/container/container_test.go
@@ -777,6 +777,28 @@ func TestExec(t *testing.T) {
}
})
}
+
+ // Test for exec failure with a non-existent file.
+ t.Run("nonexist", func(t *testing.T) {
+ // b/179114837 (found by Syzkaller): nil pointer panic when trying to
+ // dec-ref a unix socket FD.
+ fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer syscall.Close(fds[0])
+
+ _, err = cont.executeSync(&control.ExecArgs{
+ Argv: []string{"/nonexist"},
+ FilePayload: urpc.FilePayload{
+ Files: []*os.File{os.NewFile(uintptr(fds[1]), "sock")},
+ },
+ })
+ want := "failed to load /nonexist"
+ if err == nil || !strings.Contains(err.Error(), want) {
+ t.Errorf("executeSync: want err containing %q; got err = %q", want, err)
+ }
+ })
})
}
}
diff --git a/runsc/fsgofer/filter/config.go b/runsc/fsgofer/filter/config.go
index 39b8a0b1e..f92e2f80e 100644
--- a/runsc/fsgofer/filter/config.go
+++ b/runsc/fsgofer/filter/config.go
@@ -107,6 +107,15 @@ var allowedSyscalls = seccomp.SyscallRules{
seccomp.MatchAny{},
},
},
+ // getcpu is used by some versions of the Go runtime and by the hostcpu
+ // package on arm64.
+ unix.SYS_GETCPU: []seccomp.Rule{
+ {
+ seccomp.MatchAny{},
+ seccomp.EqualTo(0),
+ seccomp.EqualTo(0),
+ },
+ },
syscall.SYS_GETDENTS64: {},
syscall.SYS_GETPID: {},
unix.SYS_GETRANDOM: {},
diff --git a/test/e2e/integration_test.go b/test/e2e/integration_test.go
index aaffabfd0..49cd74887 100644
--- a/test/e2e/integration_test.go
+++ b/test/e2e/integration_test.go
@@ -430,11 +430,44 @@ func TestTmpMount(t *testing.T) {
}
}
+// TestSyntheticDirs checks that submounts can be created inside a readonly
+// mount even if the target path does not exist.
+func TestSyntheticDirs(t *testing.T) {
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ opts := dockerutil.RunOpts{
+ Image: "basic/alpine",
+ // Make the root read-only to force use of synthetic dirs
+ // inside the root gofer mount.
+ ReadOnly: true,
+ Mounts: []mount.Mount{
+ // Mount inside read-only gofer-backed root.
+ {
+ Type: mount.TypeTmpfs,
+ Target: "/foo/bar/baz",
+ },
+ // Mount inside sysfs, which always uses synthetic dirs
+ // for submounts.
+ {
+ Type: mount.TypeTmpfs,
+ Target: "/sys/foo/bar/baz",
+ },
+ },
+ }
+ // Make sure the directories exist.
+ if _, err := d.Run(ctx, opts, "ls", "/foo/bar/baz", "/sys/foo/bar/baz"); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+
+}
+
// TestHostOverlayfsCopyUp tests that the --overlayfs-stale-read option causes
// runsc to hide the incoherence of FDs opened before and after overlayfs
// copy-up on the host.
func TestHostOverlayfsCopyUp(t *testing.T) {
- runIntegrationTest(t, nil, "sh", "-c", "gcc -O2 -o test_copy_up test_copy_up.c && ./test_copy_up")
+ runIntegrationTest(t, nil, "./test_copy_up")
}
// TestHostOverlayfsRewindDir tests that rewinddir() "causes the directory
@@ -449,14 +482,14 @@ func TestHostOverlayfsCopyUp(t *testing.T) {
// automated tests yield newly-added files from readdir() even if the fsgofer
// does not explicitly rewinddir(), but overlayfs does not.
func TestHostOverlayfsRewindDir(t *testing.T) {
- runIntegrationTest(t, nil, "sh", "-c", "gcc -O2 -o test_rewinddir test_rewinddir.c && ./test_rewinddir")
+ runIntegrationTest(t, nil, "./test_rewinddir")
}
// Basic test for linkat(2). Syscall tests requires CAP_DAC_READ_SEARCH and it
// cannot use tricks like userns as root. For this reason, run a basic link test
// to ensure some coverage.
func TestLink(t *testing.T) {
- runIntegrationTest(t, nil, "sh", "-c", "gcc -O2 -o link_test link_test.c && ./link_test")
+ runIntegrationTest(t, nil, "./link_test")
}
// This test ensures we can run ping without errors.
@@ -487,6 +520,20 @@ func TestPing6Loopback(t *testing.T) {
runIntegrationTest(t, []string{"NET_ADMIN"}, "./ping6.sh")
}
+// This test checks that the owner of the sticky directory can delete files
+// inside it belonging to other users. It also checks that the owner of a file
+// can always delete its file when the file is inside a sticky directory owned
+// by another user.
+func TestStickyDir(t *testing.T) {
+ if vfs2Used, err := dockerutil.UsingVFS2(); err != nil {
+ t.Fatalf("failed to read config for runtime %s: %v", dockerutil.Runtime(), err)
+ } else if !vfs2Used {
+ t.Skip("sticky bit test fails on VFS1.")
+ }
+
+ runIntegrationTest(t, nil, "./test_sticky")
+}
+
func runIntegrationTest(t *testing.T, capAdd []string, args ...string) {
ctx := context.Background()
d := dockerutil.MakeContainer(ctx, t)
diff --git a/test/packetimpact/tests/tcp_rack_test.go b/test/packetimpact/tests/tcp_rack_test.go
index fb2a4cc90..ef902c54d 100644
--- a/test/packetimpact/tests/tcp_rack_test.go
+++ b/test/packetimpact/tests/tcp_rack_test.go
@@ -168,9 +168,10 @@ func TestRACKTLPLost(t *testing.T) {
closeSACKConnection(t, dut, conn, acceptFd, listenFd)
}
-// TestRACKTLPWithSACK tests TLP by acknowledging out of order packets.
+// TestRACKWithSACK tests that RACK marks the packets as lost after receiving
+// the ACK for retransmitted packets.
// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-8.1
-func TestRACKTLPWithSACK(t *testing.T) {
+func TestRACKWithSACK(t *testing.T) {
dut, conn, acceptFd, listenFd := createSACKConnection(t)
seqNum1 := *conn.RemoteSeqNum(t)
@@ -180,8 +181,9 @@ func TestRACKTLPWithSACK(t *testing.T) {
// We are not sending ACK for these packets.
const numPkts = 3
- lastSent := sendAndReceive(t, dut, conn, numPkts, acceptFd, false /* sendACK */)
+ sendAndReceive(t, dut, conn, numPkts, acceptFd, false /* sendACK */)
+ time.Sleep(simulatedRTT)
// SACK for #2 packet.
sackBlock := make([]byte, 40)
start := seqNum1.Add(seqnum.Size(payloadSize))
@@ -194,32 +196,25 @@ func TestRACKTLPWithSACK(t *testing.T) {
}}, sackBlock[sbOff:])
conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]})
- // RACK marks #1 packet as lost and retransmits it.
- if _, err := conn.Expect(t, testbench.TCP{SeqNum: testbench.Uint32(uint32(seqNum1))}, time.Second); err != nil {
+ rtt, _ := getRTTAndRTO(t, dut, acceptFd)
+ timeout := 2 * rtt
+ // RACK marks #1 packet as lost after RTT+reorderWindow(RTT/4) and
+ // retransmits it.
+ if _, err := conn.Expect(t, testbench.TCP{SeqNum: testbench.Uint32(uint32(seqNum1))}, timeout); err != nil {
t.Fatalf("expected payload was not received: %s", err)
}
+ time.Sleep(simulatedRTT)
// ACK for #1 packet.
conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(end))})
- // Probe Timeout (PTO) should be two times RTT. TLP will trigger for #3
- // packet. RACK adds an additional timeout of 200ms if the number of
- // outstanding packets is equal to 1.
- rtt, rto := getRTTAndRTO(t, dut, acceptFd)
- pto := rtt*2 + (200 * time.Millisecond)
- if rto < pto {
- pto = rto
- }
- // We expect the 3rd packet (the last unacknowledged packet) to be
- // retransmitted.
- tlpProbe := testbench.Uint32(uint32(seqNum1) + uint32((numPkts-1)*payloadSize))
- if _, err := conn.Expect(t, testbench.TCP{SeqNum: tlpProbe}, time.Second); err != nil {
+ // RACK considers transmission times of the packets to mark them lost.
+ // As the 3rd packet was sent before the retransmitted 1st packet, RACK
+ // marks it as lost and retransmits it.
+ expectedSeqNum := testbench.Uint32(uint32(seqNum1) + uint32((numPkts-1)*payloadSize))
+ if _, err := conn.Expect(t, testbench.TCP{SeqNum: expectedSeqNum}, timeout); err != nil {
t.Fatalf("expected payload was not received: %s", err)
}
- diff := time.Now().Sub(lastSent)
- if diff < pto {
- t.Fatalf("expected payload was received before the probe timeout, got: %v, want: %v", diff, pto)
- }
closeSACKConnection(t, dut, conn, acceptFd, listenFd)
}
diff --git a/test/packetimpact/tests/udp_send_recv_dgram_test.go b/test/packetimpact/tests/udp_send_recv_dgram_test.go
index 6e45cb143..894d156cf 100644
--- a/test/packetimpact/tests/udp_send_recv_dgram_test.go
+++ b/test/packetimpact/tests/udp_send_recv_dgram_test.go
@@ -32,6 +32,7 @@ import (
func init() {
testbench.Initialize(flag.CommandLine)
+ testbench.RPCTimeout = 500 * time.Millisecond
}
type udpConn interface {
diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD
index e43f30ba3..d6658898d 100644
--- a/test/syscalls/BUILD
+++ b/test/syscalls/BUILD
@@ -993,3 +993,7 @@ syscall_test(
syscall_test(
test = "//test/syscalls/linux:proc_net_udp_test",
)
+
+syscall_test(
+ test = "//test/syscalls/linux:processes_test",
+)
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index 80e2837f8..42fc363a2 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -2346,6 +2346,7 @@ cc_library(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ "@com_google_absl//absl/time",
gtest,
"//test/util:test_util",
],
@@ -2360,6 +2361,7 @@ cc_library(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ "@com_google_absl//absl/time",
gtest,
"//test/util:test_util",
],
@@ -2678,6 +2680,7 @@ cc_binary(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ "@com_google_absl//absl/time",
gtest,
"//test/util:test_main",
"//test/util:test_util",
@@ -4160,6 +4163,18 @@ cc_binary(
)
cc_binary(
+ name = "processes_test",
+ testonly = 1,
+ srcs = ["processes.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
name = "xattr_test",
testonly = 1,
srcs = [
diff --git a/test/syscalls/linux/exec_binary.cc b/test/syscalls/linux/exec_binary.cc
index 3797fd4c8..b0fb120c6 100644
--- a/test/syscalls/linux/exec_binary.cc
+++ b/test/syscalls/linux/exec_binary.cc
@@ -951,6 +951,34 @@ TEST(ElfTest, PIEOutOfOrderSegments) {
EXPECT_EQ(execve_errno, ENOEXEC);
}
+TEST(ElfTest, PIEOverflow) {
+ ElfBinary<64> elf = StandardElf();
+
+ elf.header.e_type = ET_DYN;
+
+ // Choose vaddr of the first segment so that the end address overflows if the
+ // segment is mapped with a non-zero offset.
+ elf.phdrs[1].p_vaddr = 0xfffffffffffff000UL - elf.phdrs[1].p_memsz;
+
+ elf.UpdateOffsets();
+
+ TempPath file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(elf));
+
+ pid_t child;
+ int execve_errno;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(file.path(), {file.path()}, {}, &child, &execve_errno));
+ if (IsRunningOnGvisor()) {
+ ASSERT_EQ(execve_errno, EINVAL);
+ } else {
+ ASSERT_EQ(execve_errno, 0);
+ int status;
+ ASSERT_THAT(RetryEINTR(waitpid)(child, &status, 0),
+ SyscallSucceedsWithValue(child));
+ EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGSEGV) << status;
+ }
+}
+
// Standard dynamically linked binary with an ELF interpreter.
TEST(ElfTest, ELFInterpreter) {
ElfBinary<64> interpreter = StandardElf();
diff --git a/test/syscalls/linux/open_create.cc b/test/syscalls/linux/open_create.cc
index f8fbea79e..46f41de50 100644
--- a/test/syscalls/linux/open_create.cc
+++ b/test/syscalls/linux/open_create.cc
@@ -46,8 +46,10 @@ TEST(CreateTest, ExistingFile) {
TEST(CreateTest, CreateAtFile) {
auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
auto dirfd = ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_DIRECTORY, 0666));
- EXPECT_THAT(openat(dirfd.get(), "CreateAtFile", O_RDWR | O_CREAT, 0666),
+ int fd;
+ EXPECT_THAT(fd = openat(dirfd.get(), "CreateAtFile", O_RDWR | O_CREAT, 0666),
SyscallSucceeds());
+ EXPECT_THAT(close(fd), SyscallSucceeds());
}
TEST(CreateTest, HonorsUmask_NoRandomSave) {
diff --git a/test/syscalls/linux/processes.cc b/test/syscalls/linux/processes.cc
new file mode 100644
index 000000000..412582515
--- /dev/null
+++ b/test/syscalls/linux/processes.cc
@@ -0,0 +1,90 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdint.h>
+#include <sys/syscall.h>
+#include <unistd.h>
+
+#include "test/util/capability_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// testSetPGIDOfZombie forks a child that in turn forks a second child, waits
+// for both of them to exit, and then calls setpgid() on the first child (which
+// has exited but has not been reaped yet) before reaping it by process group.
+// It never returns; it terminates the calling process with _exit(0) on
+// success.
+int testSetPGIDOfZombie(void* arg) {
+ int p[2];
+
+ TEST_PCHECK(pipe(p) == 0);
+
+ pid_t pid = fork();
+ if (pid == 0) {
+ pid = fork();
+ // The second child exists to mirror one of the syzkaller reproducers.
+ if (pid == 0) {
+ pid = getpid();
+ TEST_PCHECK(setpgid(pid, 0) == 0);
+ TEST_PCHECK(write(p[1], &pid, sizeof(pid)) == sizeof(pid));
+ _exit(0);
+ }
+ TEST_PCHECK(pid > 0);
+ _exit(0);
+ }
+ close(p[1]);
+ TEST_PCHECK(pid > 0);
+
+ // Get the PID of the second child.
+ pid_t cpid;
+ TEST_PCHECK(read(p[0], &cpid, sizeof(cpid)) == sizeof(cpid));
+
+ // Wait until both child processes have exited: read() returns 0 (EOF) only
+ // once every copy of the pipe's write end has been closed.
+ int c;
+ TEST_PCHECK(read(p[0], &c, sizeof(c)) == 0);
+
+ // Reap the second child's zombie.
+ // NOTE(review): this presumably relies on the second child having been
+ // reparented to this process after the first child exited -- confirm.
+ int status;
+ TEST_PCHECK(RetryEINTR(waitpid)(cpid, &status, 0) == cpid);
+
+ // Move the first child, which has not been reaped yet, into its own process
+ // group.
+ TEST_PCHECK(setpgid(pid, pid) == 0);
+
+ // Reap the first child by waiting on the process group it was just moved to.
+ TEST_PCHECK(RetryEINTR(waitpid)(-pid, &status, 0) == pid);
+
+ TEST_PCHECK(status == 0);
+ _exit(0);
+}
+
+// Runs testSetPGIDOfZombie in a freshly cloned process inside a new PID
+// namespace and expects it to exit with status 0.
+TEST(Processes, SetPGIDOfZombie) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+
+ // Fork a test process in a new PID namespace, because it needs to work with
+ // reparented processes.
+ struct clone_arg {
+ // Reserve some space for clone() to place arguments and the return code
+ // here.
+ char stack[128] __attribute__((aligned(16)));
+ char stack_ptr[0];
+ } ca;
+ pid_t pid;
+ ASSERT_THAT(pid = clone(testSetPGIDOfZombie, ca.stack_ptr,
+ CLONE_NEWPID | SIGCHLD, &ca),
+ SyscallSucceeds());
+
+ int status;
+ ASSERT_THAT(RetryEINTR(waitpid)(pid, &status, 0),
+ SyscallSucceedsWithValue(pid));
+ EXPECT_EQ(status, 0);
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/rename.cc b/test/syscalls/linux/rename.cc
index 5458f54ad..22c8c19cf 100644
--- a/test/syscalls/linux/rename.cc
+++ b/test/syscalls/linux/rename.cc
@@ -391,6 +391,39 @@ TEST(RenameTest, FileWithOpenFd) {
EXPECT_EQ(absl::string_view(buf, sizeof(buf) - 1), kContents);
}
+// Tests that calling rename with file path ending with . or .. causes EBUSY.
+TEST(RenameTest, PathEndingWithDots) {
+ TempPath root_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ TempPath dir1 =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(root_dir.path()));
+ TempPath dir2 =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(root_dir.path()));
+
+ // Try to move dir1 into dir2 but mess up the paths.
+ auto dir1Dot = JoinPath(dir1.path(), ".");
+ auto dir2Dot = JoinPath(dir2.path(), ".");
+ auto dir1DotDot = JoinPath(dir1.path(), "..");
+ auto dir2DotDot = JoinPath(dir2.path(), "..");
+ ASSERT_THAT(rename(dir1.path().c_str(), dir2Dot.c_str()),
+ SyscallFailsWithErrno(EBUSY));
+ ASSERT_THAT(rename(dir1.path().c_str(), dir2DotDot.c_str()),
+ SyscallFailsWithErrno(EBUSY));
+ ASSERT_THAT(rename(dir1Dot.c_str(), dir2.path().c_str()),
+ SyscallFailsWithErrno(EBUSY));
+ ASSERT_THAT(rename(dir1DotDot.c_str(), dir2.path().c_str()),
+ SyscallFailsWithErrno(EBUSY));
+}
+
+// Calling rename with file path ending with . or .. causes EBUSY in sysfs.
+TEST(RenameTest, SysfsPathEndingWithDots) {
+ // If a non-root user tries to rename inside /sys then we get EPERM.
+ SKIP_IF(geteuid() != 0);
+ ASSERT_THAT(rename("/sys/devices/system/cpu/online", "/sys/."),
+ SyscallFailsWithErrno(EBUSY));
+ ASSERT_THAT(rename("/sys/devices/system/cpu/online", "/sys/.."),
+ SyscallFailsWithErrno(EBUSY));
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/sendfile.cc b/test/syscalls/linux/sendfile.cc
index 3924e0001..93b3a94f1 100644
--- a/test/syscalls/linux/sendfile.cc
+++ b/test/syscalls/linux/sendfile.cc
@@ -551,6 +551,29 @@ TEST(SendFileTest, SendPipeEOF) {
SyscallSucceedsWithValue(0));
}
+TEST(SendFileTest, SendToFullPipeReturnsEAGAIN) {
+ // Create and open an empty input file.
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor in_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR));
+
+ // Set up the output pipe.
+ int fds[2];
+ ASSERT_THAT(pipe2(fds, O_NONBLOCK), SyscallSucceeds());
+ const FileDescriptor rfd(fds[0]);
+ const FileDescriptor wfd(fds[1]);
+
+ int pipe_size = -1;
+ ASSERT_THAT(pipe_size = fcntl(wfd.get(), F_GETPIPE_SZ), SyscallSucceeds());
+ int data_size = pipe_size * 8;
+ ASSERT_THAT(ftruncate(in_fd.get(), data_size), SyscallSucceeds());
+
+ ASSERT_THAT(sendfile(wfd.get(), in_fd.get(), 0, data_size),
+ SyscallSucceeds());
+ EXPECT_THAT(sendfile(wfd.get(), in_fd.get(), 0, data_size),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
TEST(SendFileTest, SendPipeBlocks) {
// Create temp file.
constexpr char kData[] =
diff --git a/test/syscalls/linux/socket_generic.cc b/test/syscalls/linux/socket_generic.cc
index de0b8bb11..f70047a09 100644
--- a/test/syscalls/linux/socket_generic.cc
+++ b/test/syscalls/linux/socket_generic.cc
@@ -379,7 +379,7 @@ TEST_P(AllSocketPairTest, RcvBufSucceeds) {
EXPECT_GT(size, 0);
}
-TEST_P(AllSocketPairTest, SndBufSucceeds) {
+TEST_P(AllSocketPairTest, GetSndBufSucceeds) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
int size = 0;
socklen_t size_size = sizeof(size);
diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc
index a11147085..344a5a22c 100644
--- a/test/syscalls/linux/socket_inet_loopback.cc
+++ b/test/syscalls/linux/socket_inet_loopback.cc
@@ -526,7 +526,7 @@ void TestListenWhileConnect(const TestParam& param,
stopListen(listen_fd);
for (auto& client : clients) {
- const int kTimeout = 10000;
+ constexpr int kTimeout = 10000;
struct pollfd pfd = {
.fd = client.get(),
.events = POLLIN,
@@ -942,7 +942,7 @@ void setupTimeWaitClose(const TestAddress* listener,
// shutdown to trigger TIME_WAIT.
ASSERT_THAT(shutdown(active_closefd.get(), SHUT_WR), SyscallSucceeds());
{
- const int kTimeout = 10000;
+ constexpr int kTimeout = 10000;
struct pollfd pfd = {
.fd = passive_closefd.get(),
.events = POLLIN,
@@ -1186,7 +1186,7 @@ TEST_P(SocketInetLoopbackTest, TCPAcceptAfterReset) {
ASSERT_EQ(addrlen, listener.addr_len);
// Wait for accept_fd to process the RST.
- const int kTimeout = 10000;
+ constexpr int kTimeout = 10000;
struct pollfd pfd = {
.fd = accept_fd.get(),
.events = POLLIN,
diff --git a/test/syscalls/linux/socket_unix_dgram.cc b/test/syscalls/linux/socket_unix_dgram.cc
index af0df4fb4..5b0844493 100644
--- a/test/syscalls/linux/socket_unix_dgram.cc
+++ b/test/syscalls/linux/socket_unix_dgram.cc
@@ -18,6 +18,8 @@
#include <sys/un.h>
#include "gtest/gtest.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/syscalls/linux/unix_domain_socket_test_util.h"
#include "test/util/test_util.h"
@@ -39,6 +41,39 @@ TEST_P(DgramUnixSocketPairTest, WriteOneSideClosed) {
SyscallFailsWithErrno(ECONNREFUSED));
}
+TEST_P(DgramUnixSocketPairTest, IncreasedSocketSendBufUnblocksWrites) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ int sock = sockets->first_fd();
+ int buf_size = 0;
+ socklen_t buf_size_len = sizeof(buf_size);
+ ASSERT_THAT(getsockopt(sock, SOL_SOCKET, SO_SNDBUF, &buf_size, &buf_size_len),
+ SyscallSucceeds());
+ int opts;
+ ASSERT_THAT(opts = fcntl(sock, F_GETFL), SyscallSucceeds());
+ opts |= O_NONBLOCK;
+ ASSERT_THAT(fcntl(sock, F_SETFL, opts), SyscallSucceeds());
+
+ std::vector<char> buf(buf_size / 4);
+ // Write till the socket buffer is full.
+ while (RetryEINTR(send)(sock, buf.data(), buf.size(), 0) != -1) {
+ // Sleep to give linux a chance to move data from the send buffer to the
+ // receive buffer.
+ absl::SleepFor(absl::Milliseconds(10)); // 10ms.
+ }
+ // The last error should have been EWOULDBLOCK.
+ ASSERT_EQ(errno, EWOULDBLOCK);
+
+ // Now increase the socket send buffer.
+ buf_size = buf_size * 2;
+ ASSERT_THAT(
+ setsockopt(sock, SOL_SOCKET, SO_SNDBUF, &buf_size, sizeof(buf_size)),
+ SyscallSucceeds());
+
+ // The send should succeed again.
+ ASSERT_THAT(RetryEINTR(send)(sock, buf.data(), buf.size(), 0),
+ SyscallSucceeds());
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/socket_unix_seqpacket.cc b/test/syscalls/linux/socket_unix_seqpacket.cc
index 6d03df4d9..eb373373d 100644
--- a/test/syscalls/linux/socket_unix_seqpacket.cc
+++ b/test/syscalls/linux/socket_unix_seqpacket.cc
@@ -18,6 +18,8 @@
#include <sys/un.h>
#include "gtest/gtest.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/syscalls/linux/unix_domain_socket_test_util.h"
#include "test/util/test_util.h"
@@ -61,6 +63,39 @@ TEST_P(SeqpacketUnixSocketPairTest, Sendto) {
SyscallSucceedsWithValue(3));
}
+TEST_P(SeqpacketUnixSocketPairTest, IncreasedSocketSendBufUnblocksWrites) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ int sock = sockets->first_fd();
+ int buf_size = 0;
+ socklen_t buf_size_len = sizeof(buf_size);
+ ASSERT_THAT(getsockopt(sock, SOL_SOCKET, SO_SNDBUF, &buf_size, &buf_size_len),
+ SyscallSucceeds());
+ int opts;
+ ASSERT_THAT(opts = fcntl(sock, F_GETFL), SyscallSucceeds());
+ opts |= O_NONBLOCK;
+ ASSERT_THAT(fcntl(sock, F_SETFL, opts), SyscallSucceeds());
+
+ std::vector<char> buf(buf_size / 4);
+ // Write till the socket buffer is full.
+ while (RetryEINTR(send)(sock, buf.data(), buf.size(), 0) != -1) {
+ // Sleep to give linux a chance to move data from the send buffer to the
+ // receive buffer.
+ absl::SleepFor(absl::Milliseconds(10)); // 10ms.
+ }
+ // The last error should have been EWOULDBLOCK.
+ ASSERT_EQ(errno, EWOULDBLOCK);
+
+ // Now increase the socket send buffer.
+ buf_size = buf_size * 2;
+ ASSERT_THAT(
+ setsockopt(sock, SOL_SOCKET, SO_SNDBUF, &buf_size, sizeof(buf_size)),
+ SyscallSucceeds());
+
+ // The send should succeed again.
+ ASSERT_THAT(RetryEINTR(send)(sock, buf.data(), buf.size(), 0),
+ SyscallSucceeds());
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/socket_unix_stream.cc b/test/syscalls/linux/socket_unix_stream.cc
index ad9c4bf37..3ff810914 100644
--- a/test/syscalls/linux/socket_unix_stream.cc
+++ b/test/syscalls/linux/socket_unix_stream.cc
@@ -17,6 +17,8 @@
#include <sys/un.h>
#include "gtest/gtest.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/syscalls/linux/unix_domain_socket_test_util.h"
#include "test/util/test_util.h"
@@ -134,6 +136,84 @@ TEST_P(StreamUnixSocketPairTest, GetSocketAcceptConn) {
EXPECT_EQ(got, 0);
}
+// Checks that a SO_SNDBUF value between the discovered minimum and maximum is
+// reported back by getsockopt() as twice the value that was set.
+TEST_P(StreamUnixSocketPairTest, SetSocketSendBuf) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ auto s = sockets->first_fd();
+ int max = 0;
+ int min = 0;
+ {
+ // Discover the maximum send buffer size by setting it to a really large
+ // value. Despite its name, kRcvBufSz is applied to SO_SNDBUF here.
+ constexpr int kRcvBufSz = INT_MAX;
+ ASSERT_THAT(
+ setsockopt(s, SOL_SOCKET, SO_SNDBUF, &kRcvBufSz, sizeof(kRcvBufSz)),
+ SyscallSucceeds());
+
+ max = 0;
+ socklen_t max_len = sizeof(max);
+ ASSERT_THAT(getsockopt(s, SOL_SOCKET, SO_SNDBUF, &max, &max_len),
+ SyscallSucceeds());
+ }
+
+ {
+ // Discover the minimum send buffer size by setting it to zero. Despite its
+ // name, kRcvBufSz is applied to SO_SNDBUF here.
+ constexpr int kRcvBufSz = 0;
+ ASSERT_THAT(
+ setsockopt(s, SOL_SOCKET, SO_SNDBUF, &kRcvBufSz, sizeof(kRcvBufSz)),
+ SyscallSucceeds());
+
+ socklen_t min_len = sizeof(min);
+ ASSERT_THAT(getsockopt(s, SOL_SOCKET, SO_SNDBUF, &min, &min_len),
+ SyscallSucceeds());
+ }
+
+ // Pick a value a quarter of the way between the minimum and the maximum.
+ int quarter_sz = min + (max - min) / 4;
+ ASSERT_THAT(
+ setsockopt(s, SOL_SOCKET, SO_SNDBUF, &quarter_sz, sizeof(quarter_sz)),
+ SyscallSucceeds());
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s, SOL_SOCKET, SO_SNDBUF, &val, &val_len),
+ SyscallSucceeds());
+
+ // Linux doubles the value set by SO_SNDBUF/SO_RCVBUF.
+ quarter_sz *= 2;
+ ASSERT_EQ(quarter_sz, val);
+}
+
+TEST_P(StreamUnixSocketPairTest, IncreasedSocketSendBufUnblocksWrites) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ int sock = sockets->first_fd();
+ int buf_size = 0;
+ socklen_t buf_size_len = sizeof(buf_size);
+ ASSERT_THAT(getsockopt(sock, SOL_SOCKET, SO_SNDBUF, &buf_size, &buf_size_len),
+ SyscallSucceeds());
+ int opts;
+ ASSERT_THAT(opts = fcntl(sock, F_GETFL), SyscallSucceeds());
+ opts |= O_NONBLOCK;
+ ASSERT_THAT(fcntl(sock, F_SETFL, opts), SyscallSucceeds());
+
+ std::vector<char> buf(buf_size / 4);
+ // Write till the socket buffer is full.
+ while (RetryEINTR(send)(sock, buf.data(), buf.size(), 0) != -1) {
+ // Sleep to give linux a chance to move data from the send buffer to the
+ // receive buffer.
+ absl::SleepFor(absl::Milliseconds(10)); // 10ms.
+ }
+ // The last error should have been EWOULDBLOCK.
+ ASSERT_EQ(errno, EWOULDBLOCK);
+
+ // Now increase the socket send buffer.
+ buf_size = buf_size * 2;
+ ASSERT_THAT(
+ setsockopt(sock, SOL_SOCKET, SO_SNDBUF, &buf_size, sizeof(buf_size)),
+ SyscallSucceeds());
+
+ // The send should succeed again.
+ ASSERT_THAT(RetryEINTR(send)(sock, buf.data(), buf.size(), 0),
+ SyscallSucceeds());
+}
+
INSTANTIATE_TEST_SUITE_P(
AllUnixDomainSockets, StreamUnixSocketPairTest,
::testing::ValuesIn(IncludeReversals(VecCat<SocketPairKind>(
diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc
index 650f12350..50f589708 100644
--- a/test/syscalls/linux/udp_socket.cc
+++ b/test/syscalls/linux/udp_socket.cc
@@ -2061,11 +2061,92 @@ TEST_P(UdpSocketTest, SendToZeroPort) {
SyscallSucceedsWithValue(sizeof(buf)));
}
+TEST_P(UdpSocketTest, ConnectToZeroPortUnbound) {
+ struct sockaddr_storage addr = InetLoopbackAddr();
+ SetPort(&addr, 0);
+ ASSERT_THAT(
+ connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen_),
+ SyscallSucceeds());
+}
+
+TEST_P(UdpSocketTest, ConnectToZeroPortBound) {
+ struct sockaddr_storage addr = InetLoopbackAddr();
+ ASSERT_NO_ERRNO(
+ BindSocket(sock_.get(), reinterpret_cast<struct sockaddr*>(&addr)));
+
+ SetPort(&addr, 0);
+ ASSERT_THAT(
+ connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen_),
+ SyscallSucceeds());
+ socklen_t len = sizeof(sockaddr_storage);
+ ASSERT_THAT(
+ getsockname(sock_.get(), reinterpret_cast<struct sockaddr*>(&addr), &len),
+ SyscallSucceeds());
+ ASSERT_EQ(len, addrlen_);
+}
+
+TEST_P(UdpSocketTest, ConnectToZeroPortConnected) {
+ struct sockaddr_storage addr = InetLoopbackAddr();
+ ASSERT_NO_ERRNO(
+ BindSocket(sock_.get(), reinterpret_cast<struct sockaddr*>(&addr)));
+
+ // Connect to an address with non-zero port should succeed.
+ ASSERT_THAT(
+ connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen_),
+ SyscallSucceeds());
+ sockaddr_storage peername;
+ socklen_t peerlen = sizeof(peername);
+ ASSERT_THAT(
+ getpeername(sock_.get(), reinterpret_cast<struct sockaddr*>(&peername),
+ &peerlen),
+ SyscallSucceeds());
+ ASSERT_EQ(peerlen, addrlen_);
+ ASSERT_EQ(memcmp(&peername, &addr, addrlen_), 0);
+
+ // However connect() to an address with port 0 will make the following
+ // getpeername() fail.
+ SetPort(&addr, 0);
+ ASSERT_THAT(
+ connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen_),
+ SyscallSucceeds());
+ ASSERT_THAT(
+ getpeername(sock_.get(), reinterpret_cast<struct sockaddr*>(&peername),
+ &peerlen),
+ SyscallFailsWithErrno(ENOTCONN));
+}
+
INSTANTIATE_TEST_SUITE_P(AllInetTests, UdpSocketTest,
::testing::Values(AddressFamily::kIpv4,
AddressFamily::kIpv6,
AddressFamily::kDualStack));
+// Connects an AF_INET6 UDP socket to an AF_INET loopback sockaddr and checks
+// that the local address reported by getsockname() is an AF_INET6,
+// IPv4-mapped address.
+TEST(UdpInet6SocketTest, ConnectInet4Sockaddr) {
+  // glibc getaddrinfo expects the invariant expressed by this test to hold.
+  const sockaddr_in connect_sockaddr = {
+      .sin_family = AF_INET, .sin_addr = {.s_addr = htonl(INADDR_LOOPBACK)}};
+  auto sock_ =
+      ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP));
+  ASSERT_THAT(
+      connect(sock_.get(),
+              reinterpret_cast<const struct sockaddr*>(&connect_sockaddr),
+              sizeof(sockaddr_in)),
+      SyscallSucceeds());
+  sockaddr_storage sockname;
+  // getsockname(2) treats the length as a value-result argument: on entry it
+  // must hold the size of the buffer, so it cannot be left uninitialized.
+  socklen_t len = sizeof(sockname);
+  ASSERT_THAT(getsockname(sock_.get(),
+                          reinterpret_cast<struct sockaddr*>(&sockname), &len),
+              SyscallSucceeds());
+  ASSERT_EQ(sockname.ss_family, AF_INET6);
+  ASSERT_EQ(len, sizeof(sockaddr_in6));
+  auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&sockname);
+  char addr_buf[INET6_ADDRSTRLEN];
+  const char* addr;
+  // inet_ntop(3) expects a pointer to the in6_addr itself, not to the
+  // enclosing sockaddr; the string is only used in the failure message below.
+  ASSERT_NE(addr = inet_ntop(sockname.ss_family, &sin6->sin6_addr, addr_buf,
+                             sizeof(addr_buf)),
+            nullptr);
+  ASSERT_TRUE(IN6_IS_ADDR_V4MAPPED(sin6->sin6_addr.s6_addr)) << addr;
+}
+
} // namespace
} // namespace testing
diff --git a/tools/bazel_gazelle_noise.patch b/tools/bazel_gazelle_noise.patch
deleted file mode 100644
index e35f38933..000000000
--- a/tools/bazel_gazelle_noise.patch
+++ /dev/null
@@ -1,24 +0,0 @@
-diff -r -u2 a/language/go/resolve.go b/language/go/resolve.go
---- a/language/go/resolve.go 2020-10-02 14:22:18.000000000 -0700
-+++ b/language/go/resolve.go 2020-11-17 19:40:59.770648029 -0800
-@@ -20,5 +20,4 @@
- "fmt"
- "go/build"
-- "log"
- "path"
- "regexp"
-@@ -80,5 +79,5 @@
- resolve = ResolveGo
- }
-- deps, errs := imports.Map(func(imp string) (string, error) {
-+ deps, _ := imports.Map(func(imp string) (string, error) {
- l, err := resolve(c, ix, rc, imp, from)
- if err == skipImportError {
-@@ -95,7 +94,4 @@
- return l.String(), nil
- })
-- for _, err := range errs {
-- log.Print(err)
-- }
- if !deps.IsEmpty() {
- if r.Kind() == "go_proto_library" {
diff --git a/tools/go_marshal/README.md b/tools/go_marshal/README.md
index d8045c295..eddba0c21 100644
--- a/tools/go_marshal/README.md
+++ b/tools/go_marshal/README.md
@@ -98,6 +98,18 @@ for embedded structs that are not aligned.
Because of this, it's generally best to avoid using `marshal:"unaligned"` and
insert explicit padding fields instead.
+## Working with dynamically sized structs
+
+While `go_marshal` seamlessly supports statically sized structs (which most ABI
+structs are), it can also be used for other use cases where marshalling is
+required. There is some provision to partially support dynamically sized structs
+that may not be ABI structs. A user can define a dynamic struct and define
+`SizeBytes()`, `MarshalBytes(dst)` and `UnmarshalBytes(src)` for it. The user
+can then add a comment above the struct like `// +marshal dynamic`, which will
+make `go_marshal` autogenerate the remaining methods required to complete the
+`Marshallable` interface. This feature is currently only available for structs
+and cannot be used alongside the Slice API.
+
## Modifying the `go_marshal` Tool
The following are some guidelines for modifying the `go_marshal` tool:
diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go
index fa642c88a..abd6f69ea 100644
--- a/tools/go_marshal/gomarshal/generator.go
+++ b/tools/go_marshal/gomarshal/generator.go
@@ -211,9 +211,10 @@ type sliceAPI struct {
// marshallableType carries information about a type marked with the '+marshal'
// directive.
type marshallableType struct {
- spec *ast.TypeSpec
- slice *sliceAPI
- recv string
+ spec *ast.TypeSpec
+ slice *sliceAPI
+ recv string
+ dynamic bool
}
func newMarshallableType(fset *token.FileSet, tagLine *ast.Comment, spec *ast.TypeSpec) *marshallableType {
@@ -248,6 +249,9 @@ func newMarshallableType(fset *token.FileSet, tagLine *ast.Comment, spec *ast.Ty
}
continue
+ } else if tag == "dynamic" {
+ mt.dynamic = true
+ continue
}
unhandledTags = append(unhandledTags, tag)
@@ -379,23 +383,38 @@ func (g *Generator) generateOne(t *marshallableType, fset *token.FileSet) *inter
i := newInterfaceGenerator(t.spec, t.recv, fset)
switch ty := t.spec.Type.(type) {
case *ast.StructType:
+ if t.dynamic {
+ // Don't validate because this type is dynamically sized and probably
+ // contains some funky slices which the validation does not allow.
+ i.emitMarshallableForStruct(ty, t.dynamic)
+ if t.slice != nil {
+ abortAt(fset.Position(t.slice.comment.Slash), "Slice API is not supported for dynamic types because it assumes that each slice element is statically sized.")
+ }
+ break
+ }
i.validateStruct(t.spec, ty)
- i.emitMarshallableForStruct(ty)
+ i.emitMarshallableForStruct(ty, t.dynamic)
if t.slice != nil {
i.emitMarshallableSliceForStruct(ty, t.slice)
}
case *ast.Ident:
i.validatePrimitiveNewtype(ty)
+ if t.dynamic {
+ abortAt(fset.Position(t.slice.comment.Slash), "Primitive type marked as '+marshal dynamic', but primitive types can not be dynamic.")
+ }
i.emitMarshallableForPrimitiveNewtype(ty)
if t.slice != nil {
i.emitMarshallableSliceForPrimitiveNewtype(ty, t.slice)
}
case *ast.ArrayType:
i.validateArrayNewtype(t.spec.Name, ty)
+ if t.dynamic {
+ abortAt(fset.Position(t.slice.comment.Slash), "Marking array types as `dynamic` is currently not supported.")
+ }
// After validate, we can safely call arrayLen.
i.emitMarshallableForArrayNewtype(t.spec.Name, ty, ty.Elt.(*ast.Ident))
if t.slice != nil {
- abortAt(fset.Position(t.slice.comment.Slash), fmt.Sprintf("Array type marked as '+marshal slice:...', but this is not supported. Perhaps fold one of the dimensions?"))
+ abortAt(fset.Position(t.slice.comment.Slash), "Array type marked as '+marshal slice:...', but this is not supported. Perhaps fold one of the dimensions?")
}
default:
// This should've been filtered out by collectMarshallabeTypes.
@@ -408,7 +427,7 @@ func (g *Generator) generateOne(t *marshallableType, fset *token.FileSet) *inter
// implementations type t.
func (g *Generator) generateOneTestSuite(t *marshallableType) *testGenerator {
i := newTestGenerator(t.spec, t.recv)
- i.emitTests(t.slice)
+ i.emitTests(t.slice, t.dynamic)
return i
}
diff --git a/tools/go_marshal/gomarshal/generator_interfaces_struct.go b/tools/go_marshal/gomarshal/generator_interfaces_struct.go
index 5f6306b8f..f98e41ed7 100644
--- a/tools/go_marshal/gomarshal/generator_interfaces_struct.go
+++ b/tools/go_marshal/gomarshal/generator_interfaces_struct.go
@@ -69,7 +69,11 @@ func (g *interfaceGenerator) validateStruct(ts *ast.TypeSpec, st *ast.StructType
})
}
-func (g *interfaceGenerator) isStructPacked(st *ast.StructType) bool {
+func (g *interfaceGenerator) isStructPacked(st *ast.StructType, isDynamic bool) bool {
+ if isDynamic {
+ // Dynamic types are not packed because a slice header might be present.
+ return false
+ }
packed := true
forEachStructField(st, func(f *ast.Field) {
if f.Tag != nil {
@@ -85,165 +89,17 @@ func (g *interfaceGenerator) isStructPacked(st *ast.StructType) bool {
return packed
}
-func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
- thisPacked := g.isStructPacked(st)
-
- g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
- g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
- g.inIndent(func() {
- primitiveSize := 0
- var dynamicSizeTerms []string
-
- forEachStructField(st, fieldDispatcher{
- primitive: func(_, t *ast.Ident) {
- if size, dynamic := g.scalarSize(t); !dynamic {
- primitiveSize += size
- } else {
- g.recordUsedMarshallable(t.Name)
- dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", t.Name))
- }
- },
- selector: func(_, tX, tSel *ast.Ident) {
- tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name)
- g.recordUsedImport(tX.Name)
- g.recordUsedMarshallable(tName)
- dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName))
- },
- array: func(_ *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
- lenExpr := g.arrayLenExpr(a)
- if size, dynamic := g.scalarSize(t); !dynamic {
- dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("%d*%s", size, lenExpr))
- } else {
- g.recordUsedMarshallable(t.Name)
- dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%s", t.Name, lenExpr))
- }
- },
- }.dispatch)
- g.emit("return %d", primitiveSize)
- if len(dynamicSizeTerms) > 0 {
- g.incIndent()
- }
- {
- for _, d := range dynamicSizeTerms {
- g.emitNoIndent(" +\n")
- g.emit(d)
- }
- }
- if len(dynamicSizeTerms) > 0 {
- g.decIndent()
- }
- })
- g.emit("\n}\n\n")
+func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType, isDynamic bool) {
+ thisPacked := g.isStructPacked(st, isDynamic)
- g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
- g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- forEachStructField(st, fieldDispatcher{
- primitive: func(n, t *ast.Ident) {
- if n.Name == "_" {
- g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name)
- if len, dynamic := g.scalarSize(t); !dynamic {
- g.shift("dst", len)
- } else {
- // We can't use shiftDynamic here because we don't have
- // an instance of the dynamic type we can reference here
- // (since the version in this struct is anonymous). Use
- // a typed nil pointer to call SizeBytes() instead.
- g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name)
- }
- return
- }
- g.marshalScalar(g.fieldAccessor(n), t.Name, "dst")
- },
- selector: func(n, tX, tSel *ast.Ident) {
- if n.Name == "_" {
- g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", tX.Name, tSel.Name)
- g.emit("dst = dst[(*%s.%s)(nil).SizeBytes():]\n", tX.Name, tSel.Name)
- return
- }
- g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst")
- },
- array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
- lenExpr := g.arrayLenExpr(a)
- if n.Name == "_" {
- g.emit("// Padding: dst[:sizeof(%s)*%s] ~= [%s]%s{0}\n", t.Name, lenExpr, lenExpr, t.Name)
- if size, dynamic := g.scalarSize(t); !dynamic {
- g.emit("dst = dst[%d*(%s):]\n", size, lenExpr)
- } else {
- // We can't use shiftDynamic here because we don't have
- // an instance of the dynamic type we can reference here
- // (since the version in this struct is anonymous). Use
- // a typed nil pointer to call SizeBytes() instead.
- g.emit("dst = dst[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr)
- }
- return
- }
-
- g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
- g.inIndent(func() {
- g.marshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst")
- })
- g.emit("}\n")
- },
- }.dispatch)
- })
- g.emit("}\n\n")
-
- g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
- g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- forEachStructField(st, fieldDispatcher{
- primitive: func(n, t *ast.Ident) {
- if n.Name == "_" {
- g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name)
- if len, dynamic := g.scalarSize(t); !dynamic {
- g.shift("src", len)
- } else {
- // We don't have an instance of the dynamic type we can
- // reference here (since the version in this struct is
- // anonymous). Use a typed nil pointer to call
- // SizeBytes() instead.
- g.shiftDynamic("src", fmt.Sprintf("(*%s)(nil)", t.Name))
- g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name))
- }
- return
- }
- g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src")
- },
- selector: func(n, tX, tSel *ast.Ident) {
- if n.Name == "_" {
- g.emit("// Padding: %s ~= src[:sizeof(%s.%s)]\n", g.fieldAccessor(n), tX.Name, tSel.Name)
- g.emit("src = src[(*%s.%s)(nil).SizeBytes():]\n", tX.Name, tSel.Name)
- g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s.%s)(nil)", tX.Name, tSel.Name))
- return
- }
- g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src")
- },
- array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
- lenExpr := g.arrayLenExpr(a)
- if n.Name == "_" {
- g.emit("// Padding: ~ copy([%s]%s(%s), src[:sizeof(%s)*%s])\n", lenExpr, t.Name, g.fieldAccessor(n), t.Name, lenExpr)
- if size, dynamic := g.scalarSize(t); !dynamic {
- g.emit("src = src[%d*(%s):]\n", size, lenExpr)
- } else {
- // We can't use shiftDynamic here because we don't have
- // an instance of the dynamic type we can referece here
- // (since the version in this struct is anonymous). Use
- // a typed nil pointer to call SizeBytes() instead.
- g.emit("src = src[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr)
- }
- return
- }
-
- g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
- g.inIndent(func() {
- g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src")
- })
- g.emit("}\n")
- },
- }.dispatch)
- })
- g.emit("}\n\n")
+ // Dynamic types are supposed to manually implement SizeBytes, MarshalBytes
+ // and UnmarshalBytes. The rest of the methods are autogenerated and depend on
+ // the implementation of these three.
+ if !isDynamic {
+ g.emitSizeBytesForStruct(st)
+ g.emitMarshalBytesForStruct(st)
+ g.emitUnmarshalBytesForStruct(st)
+ }
g.emit("// Packed implements marshal.Marshallable.Packed.\n")
g.emit("//go:nosplit\n")
@@ -428,8 +284,171 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
g.emit("}\n\n")
}
+func (g *interfaceGenerator) emitSizeBytesForStruct(st *ast.StructType) {
+ g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
+ g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ primitiveSize := 0
+ var dynamicSizeTerms []string
+
+ forEachStructField(st, fieldDispatcher{
+ primitive: func(_, t *ast.Ident) {
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ primitiveSize += size
+ } else {
+ g.recordUsedMarshallable(t.Name)
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", t.Name))
+ }
+ },
+ selector: func(_, tX, tSel *ast.Ident) {
+ tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name)
+ g.recordUsedImport(tX.Name)
+ g.recordUsedMarshallable(tName)
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName))
+ },
+ array: func(_ *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
+ lenExpr := g.arrayLenExpr(a)
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("%d*%s", size, lenExpr))
+ } else {
+ g.recordUsedMarshallable(t.Name)
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%s", t.Name, lenExpr))
+ }
+ },
+ }.dispatch)
+ g.emit("return %d", primitiveSize)
+ if len(dynamicSizeTerms) > 0 {
+ g.incIndent()
+ }
+ {
+ for _, d := range dynamicSizeTerms {
+ g.emitNoIndent(" +\n")
+ g.emit(d)
+ }
+ }
+ if len(dynamicSizeTerms) > 0 {
+ g.decIndent()
+ }
+ })
+ g.emit("\n}\n\n")
+}
+
+func (g *interfaceGenerator) emitMarshalBytesForStruct(st *ast.StructType) {
+ g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
+ g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ forEachStructField(st, fieldDispatcher{
+ primitive: func(n, t *ast.Ident) {
+ if n.Name == "_" {
+ g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name)
+ if len, dynamic := g.scalarSize(t); !dynamic {
+ g.shift("dst", len)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can reference here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name)
+ }
+ return
+ }
+ g.marshalScalar(g.fieldAccessor(n), t.Name, "dst")
+ },
+ selector: func(n, tX, tSel *ast.Ident) {
+ if n.Name == "_" {
+ g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", tX.Name, tSel.Name)
+ g.emit("dst = dst[(*%s.%s)(nil).SizeBytes():]\n", tX.Name, tSel.Name)
+ return
+ }
+ g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst")
+ },
+ array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
+ lenExpr := g.arrayLenExpr(a)
+ if n.Name == "_" {
+ g.emit("// Padding: dst[:sizeof(%s)*%s] ~= [%s]%s{0}\n", t.Name, lenExpr, lenExpr, t.Name)
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ g.emit("dst = dst[%d*(%s):]\n", size, lenExpr)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can reference here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("dst = dst[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr)
+ }
+ return
+ }
+
+ g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
+ g.inIndent(func() {
+ g.marshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst")
+ })
+ g.emit("}\n")
+ },
+ }.dispatch)
+ })
+ g.emit("}\n\n")
+}
+
+func (g *interfaceGenerator) emitUnmarshalBytesForStruct(st *ast.StructType) {
+ g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
+ g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ forEachStructField(st, fieldDispatcher{
+ primitive: func(n, t *ast.Ident) {
+ if n.Name == "_" {
+ g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name)
+ if len, dynamic := g.scalarSize(t); !dynamic {
+ g.shift("src", len)
+ } else {
+ // We don't have an instance of the dynamic type we can
+ // reference here (since the version in this struct is
+ // anonymous). Use a typed nil pointer to call
+ // SizeBytes() instead.
+ g.shiftDynamic("src", fmt.Sprintf("(*%s)(nil)", t.Name))
+ g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name))
+ }
+ return
+ }
+ g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src")
+ },
+ selector: func(n, tX, tSel *ast.Ident) {
+ if n.Name == "_" {
+ g.emit("// Padding: %s ~= src[:sizeof(%s.%s)]\n", g.fieldAccessor(n), tX.Name, tSel.Name)
+ g.emit("src = src[(*%s.%s)(nil).SizeBytes():]\n", tX.Name, tSel.Name)
+ g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s.%s)(nil)", tX.Name, tSel.Name))
+ return
+ }
+ g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src")
+ },
+ array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
+ lenExpr := g.arrayLenExpr(a)
+ if n.Name == "_" {
+ g.emit("// Padding: ~ copy([%s]%s(%s), src[:sizeof(%s)*%s])\n", lenExpr, t.Name, g.fieldAccessor(n), t.Name, lenExpr)
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ g.emit("src = src[%d*(%s):]\n", size, lenExpr)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can reference here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("src = src[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr)
+ }
+ return
+ }
+
+ g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
+ g.inIndent(func() {
+ g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src")
+ })
+ g.emit("}\n")
+ },
+ }.dispatch)
+ })
+ g.emit("}\n\n")
+}
+
func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, slice *sliceAPI) {
- thisPacked := g.isStructPacked(st)
+ thisPacked := g.isStructPacked(st, false /* isDynamic */)
if slice.inner {
abortAt(g.f.Position(slice.comment.Slash), fmt.Sprintf("The ':inner' argument to '+marshal slice:%s:inner' is only applicable to newtypes on primitives. Remove it from this struct declaration.", slice.ident))
diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go
index 6cf00843f..ca3e15c16 100644
--- a/tools/go_marshal/gomarshal/generator_tests.go
+++ b/tools/go_marshal/gomarshal/generator_tests.go
@@ -216,12 +216,16 @@ func (g *testGenerator) emitTestSizeBytesOnTypedNilPtr() {
})
}
-func (g *testGenerator) emitTests(slice *sliceAPI) {
+func (g *testGenerator) emitTests(slice *sliceAPI, isDynamic bool) {
g.emitTestNonZeroSize()
g.emitTestSuspectAlignment()
- g.emitTestMarshalUnmarshalPreservesData()
- g.emitTestWriteToUnmarshalPreservesData()
- g.emitTestSizeBytesOnTypedNilPtr()
+ if !isDynamic {
+ // Do not test these for dynamic structs because they violate some
+ // assumptions that these tests make.
+ g.emitTestMarshalUnmarshalPreservesData()
+ g.emitTestWriteToUnmarshalPreservesData()
+ g.emitTestSizeBytesOnTypedNilPtr()
+ }
if slice != nil {
g.emitTestMarshalUnmarshalSlicePreservesData(slice)
diff --git a/tools/go_marshal/test/BUILD b/tools/go_marshal/test/BUILD
index 4b27773c2..cb2d4e6e3 100644
--- a/tools/go_marshal/test/BUILD
+++ b/tools/go_marshal/test/BUILD
@@ -26,7 +26,10 @@ go_library(
srcs = ["test.go"],
marshal = True,
visibility = ["//tools/go_marshal/test:__subpackages__"],
- deps = ["//tools/go_marshal/test/external"],
+ deps = [
+ "//pkg/marshal/primitive",
+ "//tools/go_marshal/test/external",
+ ],
)
go_test(
@@ -36,6 +39,7 @@ go_test(
deps = [
":test",
"//pkg/marshal",
+ "//pkg/marshal/primitive",
"//pkg/syserror",
"//pkg/usermem",
"//tools/go_marshal/analysis",
diff --git a/tools/go_marshal/test/marshal_test.go b/tools/go_marshal/test/marshal_test.go
index a00f9a684..b0091dc64 100644
--- a/tools/go_marshal/test/marshal_test.go
+++ b/tools/go_marshal/test/marshal_test.go
@@ -28,6 +28,7 @@ import (
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/tools/go_marshal/analysis"
@@ -513,3 +514,21 @@ func TestLimitedSliceMarshalling(t *testing.T) {
})
}
}
+
+func TestDynamicType(t *testing.T) {
+ t12 := test.Type12Dynamic{
+ X: 32,
+ Y: []primitive.Int64{5, 6, 7},
+ }
+
+ var m marshal.Marshallable
+ m = &t12 // Ensure that all methods were generated.
+ b := make([]byte, m.SizeBytes())
+ m.MarshalBytes(b)
+
+ var res test.Type12Dynamic
+ res.UnmarshalBytes(b)
+ if !reflect.DeepEqual(t12, res) {
+ t.Errorf("dynamic type is not same after marshalling and unmarshalling: before = %+v, after = %+v", t12, res)
+ }
+}
diff --git a/tools/go_marshal/test/test.go b/tools/go_marshal/test/test.go
index e7e3ed74a..b8eb989d9 100644
--- a/tools/go_marshal/test/test.go
+++ b/tools/go_marshal/test/test.go
@@ -16,6 +16,8 @@
package test
import (
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
+
// We're intentionally using a package name alias here even though it's not
// necessary to test the code generator's ability to handle package aliases.
ex "gvisor.dev/gvisor/tools/go_marshal/test/external"
@@ -198,3 +200,36 @@ type Type11 struct {
ex.External
y int64
}
+
+// Type12Dynamic is a dynamically sized struct which depends on the autogenerator
+// to generate some Marshallable methods for it.
+//
+// +marshal dynamic
+type Type12Dynamic struct {
+ X primitive.Int64
+ Y []primitive.Int64
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (t *Type12Dynamic) SizeBytes() int {
+ return (len(t.Y) * 8) + t.X.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (t *Type12Dynamic) MarshalBytes(dst []byte) {
+ t.X.MarshalBytes(dst)
+ dst = dst[t.X.SizeBytes():]
+ for i, x := range t.Y {
+ x.MarshalBytes(dst[i*8 : (i+1)*8])
+ }
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (t *Type12Dynamic) UnmarshalBytes(src []byte) {
+ t.X.UnmarshalBytes(src)
+ for i := t.X.SizeBytes(); i < len(src); i += 8 {
+ var x primitive.Int64
+ x.UnmarshalBytes(src[i:])
+ t.Y = append(t.Y, x)
+ }
+}