diff options
147 files changed, 5644 insertions, 4515 deletions
diff --git a/.bazelignore b/.bazelignore new file mode 100644 index 000000000..511b10433 --- /dev/null +++ b/.bazelignore @@ -0,0 +1 @@ +bazel-gvisor diff --git a/.devcontainer.json b/.devcontainer.json new file mode 100644 index 000000000..6f7fe4bf8 --- /dev/null +++ b/.devcontainer.json @@ -0,0 +1,9 @@ +{ + "dockerFile": "images/default/Dockerfile", + "overrideCommand": true, + "mounts": ["source=/var/run/docker.sock,target=/var/run/docker-host.sock,type=bind"], + "runArgs": ["--cap-add=SYS_PTRACE", "--security-opt", "seccomp=unconfined"], + "extensions": [ + "bazelbuild.vscode-bazel" + ] +} diff --git a/.vscode/tasks.json b/.vscode/tasks.json new file mode 100644 index 000000000..42a018434 --- /dev/null +++ b/.vscode/tasks.json @@ -0,0 +1,31 @@ +{ + "version": "2.0.0", + "tasks": [ + { + "label": "Build", + "type": "shell", + "command": "bazel build //...", + "group": { + "kind": "build", + "isDefault": true + }, + "presentation": { + "reveal": "always", + "panel": "new" + } + }, + { + "label": "Test", + "type": "shell", + "command": "bazel test //...", + "group": { + "kind": "test", + "isDefault": true + }, + "presentation": { + "reveal": "always", + "panel": "new" + } + } + ] +} @@ -57,9 +57,6 @@ http_archive( # This is actually a no-op with the hacky patch above, but should # slightly future proof this mechanism. "//tools:bazel_gazelle_generate.patch", - # False positive output complaining about Go logrus versions spam the - # logs. Strip this message in this case. Does not affect control flow. - "//tools:bazel_gazelle_noise.patch", ], sha256 = "222e49f034ca7a1d1231422cdb67066b885819885c356673cb1f72f748a3c9d4", urls = [ @@ -377,8 +374,8 @@ go_repository( name = "org_golang_google_grpc", build_file_proto_mode = "disable", importpath = "google.golang.org/grpc", - sum = "h1:cb+I9RwgcErlwAuOVnGhJ2d3YrcdwGXw+RPArsTWot4=", - version = "v1.36.0-dev.0.20210122012134-2c42474aca0c", + sum = "h1:iGG0ZwQMaxJT/qsL0nnzZCg+4aiWHuQy3MytzLieAjo=", + version = "v1.36.0-dev.0.20210208035533-9280052d3665", ) go_repository( @@ -517,8 +514,8 @@ go_repository( go_repository( name = "com_github_konsorten_go_windows_terminal_sequences", importpath = "github.com/konsorten/go-windows-terminal-sequences", - sum = "h1:CE8S1cTafDpPvMhIxNJKvHsGVBgn1xWYf1NbHQhywc8=", - version = "v1.0.3", + sum = "h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s=", + version = "v1.0.2", ) go_repository( @@ -637,8 +634,8 @@ go_repository( go_repository( name = "com_github_containerd_continuity", importpath = "github.com/containerd/continuity", - sum = "h1:6ejg6Lkk8dskcM7wQ28gONkukbQkM4qpj4RnYbpFzrI=", - version = "v0.0.0-20201208142359-180525291bb7", + sum = "h1:6JKvHHt396/qabvMhnhUZvWaHZzfVfldxE60TK8YLhg=", + version = "v0.0.0-20210208174643-50096c924a4e", ) go_repository( diff --git a/g3doc/user_guide/FAQ.md b/g3doc/user_guide/FAQ.md index 8e5721ad1..26c836ddf 100644 --- a/g3doc/user_guide/FAQ.md +++ b/g3doc/user_guide/FAQ.md @@ -107,11 +107,11 @@ kubeadm to create your cluster please check if Docker is also installed on that system. Kubeadm prefers using Docker if both Docker and containerd are installed. -Please recreate your cluster and set the `--cni-socket` option on kubeadm +Please recreate your cluster and set the `--cri-socket` option on kubeadm commands. For example: ```bash -kubeadm init --cni-socket=/var/run/containerd/containerd.sock ... +kubeadm init --cri-socket=/var/run/containerd/containerd.sock ... ``` To fix an existing cluster edit the `/var/lib/kubelet/kubeadm-flags.env` file @@ -1,18 +1,17 @@ module gvisor.dev/gvisor -go 1.15 - -replace github.com/Sirupsen/logrus => github.com/sirupsen/logrus v1.6.0 +go 1.16 require ( cloud.google.com/go v0.75.0 // indirect + github.com/BurntSushi/toml v0.3.1 // indirect github.com/Microsoft/go-winio v0.4.16 // indirect github.com/Microsoft/hcsshim v0.8.14 // indirect github.com/cenkalti/backoff v1.1.1-0.20190506075156-2146c9339422 // indirect github.com/containerd/cgroups v0.0.0-20201119153540-4cbc285b3327 // indirect github.com/containerd/console v1.0.1 // indirect github.com/containerd/containerd v1.3.9 // indirect - github.com/containerd/continuity v0.0.0-20201208142359-180525291bb7 // indirect + github.com/containerd/continuity v0.0.0-20210208174643-50096c924a4e // indirect github.com/containerd/fifo v0.0.0-20191213151349-ff969a566b00 // indirect github.com/containerd/go-runc v0.0.0-20200220073739-7016d3ce2328 // indirect github.com/containerd/ttrpc v1.0.2 // indirect @@ -23,6 +22,10 @@ require ( github.com/docker/go-events v0.0.0-20190806004212-e31b211e4f1c // indirect github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 // indirect github.com/gogo/googleapis v1.4.0 // indirect + github.com/gogo/protobuf v1.3.1 // indirect + github.com/golang/mock v1.4.4 // indirect + github.com/google/btree v1.0.0 // indirect + github.com/google/go-cmp v0.5.4 // indirect github.com/google/go-github/v32 v32.1.0 // indirect github.com/google/pprof v0.0.0-20210115211752-39141e76b647 // indirect github.com/google/subcommands v1.0.2-0.20190508160503-636abe8753b8 // indirect @@ -32,15 +35,25 @@ require ( github.com/mohae/deepcopy v0.0.0-20170308212314-bb9b5e7adda9 // indirect github.com/opencontainers/image-spec v1.0.1 // indirect github.com/opencontainers/runc v0.1.1 // indirect + github.com/opencontainers/runtime-spec v1.0.2 // indirect github.com/pborman/uuid v1.2.0 // indirect + github.com/sirupsen/logrus v1.7.0 // indirect github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2 // indirect github.com/vishvananda/netlink v1.0.1-0.20190930145447-2ec5bdc52b86 // indirect github.com/vishvananda/netns v0.0.0-20210104183010-2eb08e3e575f // indirect github.com/xeipuuv/gojsonschema v1.2.0 // indirect go.uber.org/multierr v1.6.0 // indirect - google.golang.org/grpc v1.36.0-dev.0.20210122012134-2c42474aca0c // indirect + golang.org/x/net v0.0.0-20201224014010-6772e930b67b // indirect + golang.org/x/oauth2 v0.0.0-20201208152858-08078c50e5b5 // indirect + golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4 // indirect + golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect + golang.org/x/tools v0.1.0 // indirect + google.golang.org/api v0.36.0 // indirect + google.golang.org/grpc v1.36.0-dev.0.20210208035533-9280052d3665 // indirect google.golang.org/protobuf v1.25.1-0.20201020201750-d3470999428b // indirect + gopkg.in/yaml.v2 v2.2.8 // indirect honnef.co/go/tools v0.1.1 // indirect + k8s.io/api v0.16.13 // indirect k8s.io/apimachinery v0.16.14-rc.0 // indirect k8s.io/client-go v0.16.13 // indirect ) @@ -77,8 +77,8 @@ github.com/containerd/containerd v1.3.2/go.mod h1:bC6axHOhabU15QhwfG7w5PipXdVtMX github.com/containerd/containerd v1.3.9 h1:K2U/F4jGAMBqeUssfgJRbFuomLcS2Fxo1vR3UM/Mbh8= github.com/containerd/containerd v1.3.9/go.mod h1:bC6axHOhabU15QhwfG7w5PipXdVtMXFTttgp+kVtyUA= github.com/containerd/continuity v0.0.0-20190426062206-aaeac12a7ffc/go.mod h1:GL3xCUCBDV3CZiTSEKksMWbLE66hEyuu9qyDOOqM47Y= -github.com/containerd/continuity v0.0.0-20201208142359-180525291bb7 h1:6ejg6Lkk8dskcM7wQ28gONkukbQkM4qpj4RnYbpFzrI= -github.com/containerd/continuity v0.0.0-20201208142359-180525291bb7/go.mod h1:kR3BEg7bDFaEddKm54WSmrol1fKWDU1nKYkgrcgZT7Y= +github.com/containerd/continuity v0.0.0-20210208174643-50096c924a4e h1:6JKvHHt396/qabvMhnhUZvWaHZzfVfldxE60TK8YLhg= +github.com/containerd/continuity v0.0.0-20210208174643-50096c924a4e/go.mod h1:EXlVlkqNba9rJe3j7w3Xa924itAMLgZH4UD/Q4PExuQ= github.com/containerd/fifo v0.0.0-20190226154929-a9fb20d87448/go.mod h1:ODA38xgv3Kuk8dQz2ZQXpnv/UZZUHUCL7pnLehbXgQI= github.com/containerd/fifo v0.0.0-20191213151349-ff969a566b00 h1:lsjC5ENBl+Zgf38+B0ymougXFp0BaubeIVETltYZTQw= github.com/containerd/fifo v0.0.0-20191213151349-ff969a566b00/go.mod h1:jPQ2IAeZRCYxpS/Cm1495vGFww6ecHmMk1YJH2Q5ln0= @@ -236,7 +236,6 @@ github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQL github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= -github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.4-0.20190131011033-7dc38fb350b1 h1:zc0R6cOw98cMengLA0fvU55mqbnN7sd/tBMLzSejp+M= @@ -286,7 +285,6 @@ github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQD github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= -github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= github.com/sirupsen/logrus v1.7.0 h1:ShrD1U9pZB12TX0cVy0DtePoCH97K8EtX+mg7ZARUtM= github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/spf13/afero v1.2.2/go.mod h1:9ZxEEn6pIJ8Rxe320qSDBk6AsU0r9pR7Q4OcevTdifk= @@ -612,8 +610,8 @@ google.golang.org/grpc v1.31.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM google.golang.org/grpc v1.31.1/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= google.golang.org/grpc v1.34.0/go.mod h1:WotjhfgOW/POjDeRt8vscBtXq+2VjORFy659qA51WJ8= -google.golang.org/grpc v1.36.0-dev.0.20210122012134-2c42474aca0c h1:cb+I9RwgcErlwAuOVnGhJ2d3YrcdwGXw+RPArsTWot4= -google.golang.org/grpc v1.36.0-dev.0.20210122012134-2c42474aca0c/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= +google.golang.org/grpc v1.36.0-dev.0.20210208035533-9280052d3665 h1:iGG0ZwQMaxJT/qsL0nnzZCg+4aiWHuQy3MytzLieAjo= +google.golang.org/grpc v1.36.0-dev.0.20210208035533-9280052d3665/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= diff --git a/images/basic/integrationtest/Dockerfile.x86_64 b/images/basic/integrationtest/Dockerfile.x86_64 index e80e17527..b9fed05cb 100644 --- a/images/basic/integrationtest/Dockerfile.x86_64 +++ b/images/basic/integrationtest/Dockerfile.x86_64 @@ -5,3 +5,9 @@ COPY . . RUN chmod +x *.sh RUN apt-get update && apt-get install -y gcc iputils-ping iproute2 + +# Compilation Steps. +RUN gcc -O2 -o test_copy_up test_copy_up.c +RUN gcc -O2 -o test_rewinddir test_rewinddir.c +RUN gcc -O2 -o link_test link_test.c +RUN gcc -O2 -o test_sticky test_sticky.c diff --git a/images/basic/integrationtest/test_sticky.c b/images/basic/integrationtest/test_sticky.c new file mode 100644 index 000000000..58dcf91d3 --- /dev/null +++ b/images/basic/integrationtest/test_sticky.c @@ -0,0 +1,96 @@ +#include <err.h> +#include <errno.h> +#include <fcntl.h> +#include <stdlib.h> +#include <sys/stat.h> +#include <sys/types.h> +#include <sys/wait.h> +#include <unistd.h> + +void createFile(const char* path) { + int fd = open(path, O_WRONLY | O_CREAT, 0777); + if (fd < 0) { + err(1, "open(%s)", path); + exit(1); + } else { + close(fd); + } +} + +void waitAndCheckStatus(pid_t child) { + int status; + if (waitpid(child, &status, 0) == -1) { + err(1, "waitpid() failed"); + exit(1); + } + + if (WIFEXITED(status)) { + int es = WEXITSTATUS(status); + if (es) { + err(1, "child exit status %d", es); + exit(1); + } + } else { + err(1, "child did not exit normally"); + exit(1); + } +} + +void deleteFile(uid_t user, const char* path) { + pid_t child = fork(); + if (child == 0) { + if (setuid(user)) { + err(1, "setuid(%d)", user); + exit(1); + } + + if (unlink(path)) { + err(1, "unlink(%s)", path); + exit(1); + } + exit(0); + } + waitAndCheckStatus(child); +} + +int main(int argc, char** argv) { + const char kUser1Dir[] = "/user1dir"; + const char kUser2File[] = "/user1dir/user2file"; + const char kUser2File2[] = "/user1dir/user2file2"; + + const uid_t user1 = 6666; + const uid_t user2 = 6667; + + if (mkdir(kUser1Dir, 0755) != 0) { + err(1, "mkdir(%s)", kUser1Dir); + exit(1); + } + // Enable sticky bit for user1dir. + if (chmod(kUser1Dir, 01777) != 0) { + err(1, "chmod(%s)", kUser1Dir); + exit(1); + } + createFile(kUser2File); + createFile(kUser2File2); + + if (chown(kUser1Dir, user1, getegid())) { + err(1, "chown(%s)", kUser1Dir); + exit(1); + } + if (chown(kUser2File, user2, getegid())) { + err(1, "chown(%s)", kUser2File); + exit(1); + } + if (chown(kUser2File2, user2, getegid())) { + err(1, "chown(%s)", kUser2File2); + exit(1); + } + + // User1 should be able to delete any file inside user1dir, even files of + // other users due to the sticky bit. + deleteFile(user1, kUser2File); + + // User2 should naturally be able to delete its own file even if the file is + // inside a sticky dir owned by someone else. + deleteFile(user2, kUser2File2); +} diff --git a/images/syzkaller/Dockerfile b/images/syzkaller/Dockerfile index df6680f40..9a85ae345 100644 --- a/images/syzkaller/Dockerfile +++ b/images/syzkaller/Dockerfile @@ -1,5 +1,7 @@ FROM gcr.io/syzkaller/env +# This image is mostly for investigating syzkaller crashes, so let's install +# developer tools. RUN apt update && apt install -y git vim strace gdb procps WORKDIR /syzkaller/gopath/src/github.com/google/syzkaller diff --git a/images/syzkaller/README.md b/images/syzkaller/README.md index 1eac474f3..47e309422 100644 --- a/images/syzkaller/README.md +++ b/images/syzkaller/README.md @@ -5,21 +5,54 @@ syzkaller is an unsupervised coverage-guided kernel fuzzer. # How to run syzkaller. -* Build the syzkaller docker image `make load-syzkaller` -* Build runsc and place it in /tmp/syzkaller. `make RUNTIME_DIR=/tmp/syzkaller - refresh` -* Copy the syzkaller config in /tmp/syzkaller `cp - images/syzkaller/default-gvisor-config.cfg /tmp/syzkaller/syzkaller.cfg` -* Run syzkaller `docker run --privileged -it --rm -v - /tmp/syzkaller:/tmp/syzkaller gvisor.dev/images/syzkaller:latest` +First, we need to load a syzkaller docker image: + +```bash +make load-syzkaller +``` + +or we can rebuild it to use an up-to-date version of the master branch: + +```bash +make rebuild-syzkaller +``` + +Then we need to create a directory with all artifacts that we will need to run a +syzkaller. Then we will bind-mount this directory to a docker container. + +We need to build runsc and place it on the artifact directory: + +```bash +make RUNTIME_DIR=/tmp/syzkaller refresh +``` + +The next step is to create a syzkaller config. We can copy the default one and +customize it: + +```bash +cp images/syzkaller/default-gvisor-config.cfg /tmp/syzkaller/syzkaller.cfg +``` + +Now we can start syzkaller in a docker container: + +```bash +docker run --privileged -it --rm \ + -v /tmp/syzkaller:/tmp/syzkaller \ + gvisor.dev/images/syzkaller:latest +``` + +All logs will be in /tmp/syzkaller/workdir. # How to run a syz repro. -* Repeate all steps except the last one from the previous section. +We need to repeat all preparation steps from the previous section and save a +syzkaller repro in /tmp/syzkaller/repro. -* Save a syzkaller repro in /tmp/syzkaller/repro +Now we can run syz-repro to reproduce a crash: -* Run syz-repro `docker run --privileged -it --rm -v +```bash +docker run --privileged -it --rm -v /tmp/syzkaller:/tmp/syzkaller --entrypoint="" gvisor.dev/images/syzkaller:latest ./bin/syz-repro -config - /tmp/syzkaller/syzkaller.cfg /tmp/syzkaller/repro` + /tmp/syzkaller/syzkaller.cfg /tmp/syzkaller/repro +``` @@ -147,6 +147,11 @@ analyzers: - pkg/sentry/fs/fs.go # Intentional. - pkg/sentry/fs/gofer/inode.go # Intentional. - pkg/refs/refcounter_test.go # Intentional. + ST1019: + generated: + exclude: + # package ".../kubeapi/core/v1/v1" is being imported more than once + - generated.gen.pb.go ST1021: internal: suppress: diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index 8fa61d6f7..ecaeb11ac 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -80,7 +80,6 @@ go_library( "//pkg/bits", "//pkg/marshal", "//pkg/marshal/primitive", - "//pkg/usermem", ], ) diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go index b521144d9..378f1baf3 100644 --- a/pkg/abi/linux/netfilter.go +++ b/pkg/abi/linux/netfilter.go @@ -15,11 +15,8 @@ package linux import ( - "io" - "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/marshal/primitive" - "gvisor.dev/gvisor/pkg/usermem" ) // This file contains structures required to support netfilter, specifically @@ -129,8 +126,8 @@ type IPTEntry struct { const SizeOfIPTEntry = 112 // KernelIPTEntry is identical to IPTEntry, but includes the Elems field. -// KernelIPTEntry itself is not Marshallable but it implements some methods of -// marshal.Marshallable that help in other implementations of Marshallable. +// +// +marshal dynamic type KernelIPTEntry struct { Entry IPTEntry @@ -158,6 +155,8 @@ func (ke *KernelIPTEntry) UnmarshalBytes(src []byte) { ke.Elems.UnmarshalBytes(src[ke.Entry.SizeBytes():]) } +var _ marshal.Marshallable = (*KernelIPTEntry)(nil) + // IPTIP contains information for matching a packet's IP header. // It corresponds to struct ipt_ip in // include/uapi/linux/netfilter_ipv4/ip_tables.h. @@ -411,8 +410,9 @@ type IPTGetEntries struct { const SizeOfIPTGetEntries = 40 // KernelIPTGetEntries is identical to IPTGetEntries, but includes the -// Entrytable field. This has been manually made marshal.Marshallable since it -// is dynamically sized. +// Entrytable field. +// +// +marshal dynamic type KernelIPTGetEntries struct { IPTGetEntries Entrytable []KernelIPTEntry @@ -447,65 +447,6 @@ func (ke *KernelIPTGetEntries) UnmarshalBytes(src []byte) { } } -// Packed implements marshal.Marshallable.Packed. -func (ke *KernelIPTGetEntries) Packed() bool { - // KernelIPTGetEntries isn't packed because the ke.Entrytable contains an - // indirection to the actual data we want to marshal (the slice data - // pointer), and the memory for KernelIPTGetEntries contains the slice - // header which we don't want to marshal. - return false -} - -// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. -func (ke *KernelIPTGetEntries) MarshalUnsafe(dst []byte) { - // Fall back to safe Marshal because the type in not packed. - ke.MarshalBytes(dst) -} - -// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. -func (ke *KernelIPTGetEntries) UnmarshalUnsafe(src []byte) { - // Fall back to safe Unmarshal because the type in not packed. - ke.UnmarshalBytes(src) -} - -// CopyIn implements marshal.Marshallable.CopyIn. -func (ke *KernelIPTGetEntries) CopyIn(cc marshal.CopyContext, addr usermem.Addr) (int, error) { - buf := cc.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay. - length, err := cc.CopyInBytes(addr, buf) // escapes: okay. - // Unmarshal unconditionally. If we had a short copy-in, this results in a - // partially unmarshalled struct. - ke.UnmarshalBytes(buf) // escapes: fallback. - return length, err -} - -// CopyOut implements marshal.Marshallable.CopyOut. -func (ke *KernelIPTGetEntries) CopyOut(cc marshal.CopyContext, addr usermem.Addr) (int, error) { - // Type KernelIPTGetEntries doesn't have a packed layout in memory, fall - // back to MarshalBytes. - return cc.CopyOutBytes(addr, ke.marshalAll(cc)) -} - -// CopyOutN implements marshal.Marshallable.CopyOutN. -func (ke *KernelIPTGetEntries) CopyOutN(cc marshal.CopyContext, addr usermem.Addr, limit int) (int, error) { - // Type KernelIPTGetEntries doesn't have a packed layout in memory, fall - // back to MarshalBytes. - return cc.CopyOutBytes(addr, ke.marshalAll(cc)[:limit]) -} - -func (ke *KernelIPTGetEntries) marshalAll(cc marshal.CopyContext) []byte { - buf := cc.CopyScratchBuffer(ke.SizeBytes()) - ke.MarshalBytes(buf) - return buf -} - -// WriteTo implements io.WriterTo.WriteTo. -func (ke *KernelIPTGetEntries) WriteTo(w io.Writer) (int64, error) { - buf := make([]byte, ke.SizeBytes()) - ke.MarshalBytes(buf) - length, err := w.Write(buf) - return int64(length), err -} - var _ marshal.Marshallable = (*KernelIPTGetEntries)(nil) // IPTReplace is the argument for the IPT_SO_SET_REPLACE sockopt. It diff --git a/pkg/abi/linux/netfilter_ipv6.go b/pkg/abi/linux/netfilter_ipv6.go index bcb57642e..b953e62dc 100644 --- a/pkg/abi/linux/netfilter_ipv6.go +++ b/pkg/abi/linux/netfilter_ipv6.go @@ -15,11 +15,8 @@ package linux import ( - "io" - "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/marshal/primitive" - "gvisor.dev/gvisor/pkg/usermem" ) // This file contains structures required to support IPv6 netfilter and @@ -70,8 +67,9 @@ type IP6TReplace struct { const SizeOfIP6TReplace = 96 // KernelIP6TGetEntries is identical to IP6TGetEntries, but includes the -// Entrytable field. This has been manually made marshal.Marshallable since it -// is dynamically sized. +// Entrytable field. +// +// +marshal dynamic type KernelIP6TGetEntries struct { IPTGetEntries Entrytable []KernelIP6TEntry @@ -106,65 +104,6 @@ func (ke *KernelIP6TGetEntries) UnmarshalBytes(src []byte) { } } -// Packed implements marshal.Marshallable.Packed. -func (ke *KernelIP6TGetEntries) Packed() bool { - // KernelIP6TGetEntries isn't packed because the ke.Entrytable contains - // an indirection to the actual data we want to marshal (the slice data - // pointer), and the memory for KernelIP6TGetEntries contains the slice - // header which we don't want to marshal. - return false -} - -// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. -func (ke *KernelIP6TGetEntries) MarshalUnsafe(dst []byte) { - // Fall back to safe Marshal because the type in not packed. - ke.MarshalBytes(dst) -} - -// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. -func (ke *KernelIP6TGetEntries) UnmarshalUnsafe(src []byte) { - // Fall back to safe Unmarshal because the type in not packed. - ke.UnmarshalBytes(src) -} - -// CopyIn implements marshal.Marshallable.CopyIn. -func (ke *KernelIP6TGetEntries) CopyIn(cc marshal.CopyContext, addr usermem.Addr) (int, error) { - buf := cc.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay. - length, err := cc.CopyInBytes(addr, buf) // escapes: okay. - // Unmarshal unconditionally. If we had a short copy-in, this results - // in a partially unmarshalled struct. - ke.UnmarshalBytes(buf) // escapes: fallback. - return length, err -} - -// CopyOut implements marshal.Marshallable.CopyOut. -func (ke *KernelIP6TGetEntries) CopyOut(cc marshal.CopyContext, addr usermem.Addr) (int, error) { - // Type KernelIP6TGetEntries doesn't have a packed layout in memory, - // fall back to MarshalBytes. - return cc.CopyOutBytes(addr, ke.marshalAll(cc)) -} - -// CopyOutN implements marshal.Marshallable.CopyOutN. -func (ke *KernelIP6TGetEntries) CopyOutN(cc marshal.CopyContext, addr usermem.Addr, limit int) (int, error) { - // Type KernelIP6TGetEntries doesn't have a packed layout in memory, fall - // back to MarshalBytes. - return cc.CopyOutBytes(addr, ke.marshalAll(cc)[:limit]) -} - -func (ke *KernelIP6TGetEntries) marshalAll(cc marshal.CopyContext) []byte { - buf := cc.CopyScratchBuffer(ke.SizeBytes()) - ke.MarshalBytes(buf) - return buf -} - -// WriteTo implements io.WriterTo.WriteTo. -func (ke *KernelIP6TGetEntries) WriteTo(w io.Writer) (int64, error) { - buf := make([]byte, ke.SizeBytes()) - ke.MarshalBytes(buf) - length, err := w.Write(buf) - return int64(length), err -} - var _ marshal.Marshallable = (*KernelIP6TGetEntries)(nil) // IP6TEntry is an iptables rule. It corresponds to struct ip6t_entry in diff --git a/pkg/sentry/fs/host/file.go b/pkg/sentry/fs/host/file.go index 86d1a87f0..fd4e057d8 100644 --- a/pkg/sentry/fs/host/file.go +++ b/pkg/sentry/fs/host/file.go @@ -180,16 +180,9 @@ func (f *fileOperations) Readdir(ctx context.Context, file *fs.File, serializer // IterateDir implements fs.DirIterator.IterateDir. func (f *fileOperations) IterateDir(ctx context.Context, d *fs.Dirent, dirCtx *fs.DirCtx, offset int) (int, error) { - if f.dirinfo == nil { - f.dirinfo = new(dirInfo) - f.dirinfo.buf = make([]byte, usermem.PageSize) - } - entries, err := f.iops.readdirAll(f.dirinfo) - if err != nil { - return offset, err - } - count, err := fs.GenericReaddir(dirCtx, fs.NewSortedDentryMap(entries)) - return offset + count, err + // We only support non-directory file descriptors that have been + // imported, so just claim that this isn't a directory, even if it is. + return offset, syscall.ENOTDIR } // Write implements fs.FileOperations.Write. diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go index 2c14aa6d9..df4b265fa 100644 --- a/pkg/sentry/fs/host/inode.go +++ b/pkg/sentry/fs/host/inode.go @@ -411,10 +411,3 @@ func (i *inodeOperations) DropLink() {} // NotifyStatusChange implements fs.InodeOperations.NotifyStatusChange. func (i *inodeOperations) NotifyStatusChange(ctx context.Context) {} - -// readdirAll returns all of the directory entries in i. -func (i *inodeOperations) readdirAll(d *dirInfo) (map[string]fs.DentAttr, error) { - // We only support non-directory file descriptors that have been - // imported, so just claim that this isn't a directory, even if it is. - return nil, syscall.ENOTDIR -} diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go index 07b4fb70f..2b58fc52c 100644 --- a/pkg/sentry/fs/host/socket.go +++ b/pkg/sentry/fs/host/socket.go @@ -16,6 +16,7 @@ package host import ( "fmt" + "sync/atomic" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -206,7 +207,7 @@ func (c *ConnectedEndpoint) Send(ctx context.Context, data [][]byte, controlMess // only as much of the message as fits in the send buffer. truncate := c.stype == linux.SOCK_STREAM - n, totalLen, err := fdWriteVec(c.file.FD(), data, c.sndbuf, truncate) + n, totalLen, err := fdWriteVec(c.file.FD(), data, c.SendMaxQueueSize(), truncate) if n < totalLen && err == nil { // The host only returns a short write if it would otherwise // block (and only for stream sockets). @@ -282,7 +283,7 @@ func (c *ConnectedEndpoint) Recv(ctx context.Context, data [][]byte, creds bool, // N.B. Unix sockets don't have a receive buffer, the send buffer // serves both purposes. - rl, ml, cl, cTrunc, err := fdReadVec(c.file.FD(), data, []byte(cm), peek, c.sndbuf) + rl, ml, cl, cTrunc, err := fdReadVec(c.file.FD(), data, []byte(cm), peek, c.RecvMaxQueueSize()) if rl > 0 && err != nil { // We got some data, so all we need to do on error is return // the data that we got. Short reads are fine, no need to @@ -363,14 +364,14 @@ func (c *ConnectedEndpoint) RecvQueuedSize() int64 { // SendMaxQueueSize implements transport.Receiver.SendMaxQueueSize. func (c *ConnectedEndpoint) SendMaxQueueSize() int64 { - return int64(c.sndbuf) + return atomic.LoadInt64(&c.sndbuf) } // RecvMaxQueueSize implements transport.Receiver.RecvMaxQueueSize. func (c *ConnectedEndpoint) RecvMaxQueueSize() int64 { // N.B. Unix sockets don't use the receive buffer. We'll claim it is // the same size as the send buffer. - return int64(c.sndbuf) + return atomic.LoadInt64(&c.sndbuf) } // Release implements transport.ConnectedEndpoint.Release and transport.Receiver.Release. @@ -381,4 +382,11 @@ func (c *ConnectedEndpoint) Release(ctx context.Context) { // CloseUnread implements transport.ConnectedEndpoint.CloseUnread. func (c *ConnectedEndpoint) CloseUnread() {} +// SetSendBufferSize implements transport.ConnectedEndpoint.SetSendBufferSize. +func (c *ConnectedEndpoint) SetSendBufferSize(v int64) (newSz int64) { + // gVisor does not permit setting of SO_SNDBUF for host backed unix domain + // sockets. + return atomic.LoadInt64(&c.sndbuf) +} + // LINT.ThenChange(../../fsimpl/host/socket.go) diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index 98f7bc52f..094d993a8 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -1216,7 +1216,13 @@ func (d *dentry) checkXattrPermissions(creds *auth.Credentials, name string, ats } func (d *dentry) mayDelete(creds *auth.Credentials, child *dentry) error { - return vfs.CheckDeleteSticky(creds, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&child.uid))) + return vfs.CheckDeleteSticky( + creds, + linux.FileMode(atomic.LoadUint32(&d.mode)), + auth.KUID(atomic.LoadUint32(&d.uid)), + auth.KUID(atomic.LoadUint32(&child.uid)), + auth.KGID(atomic.LoadUint32(&child.gid)), + ) } func dentryUIDFromP9UID(uid p9.UID) uint32 { diff --git a/pkg/sentry/fsimpl/host/socket.go b/pkg/sentry/fsimpl/host/socket.go index 72aa535f8..6763f5b0c 100644 --- a/pkg/sentry/fsimpl/host/socket.go +++ b/pkg/sentry/fsimpl/host/socket.go @@ -16,6 +16,7 @@ package host import ( "fmt" + "sync/atomic" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -111,7 +112,7 @@ func (c *ConnectedEndpoint) init() *syserr.Error { } c.stype = linux.SockType(stype) - c.sndbuf = int64(sndbuf) + atomic.StoreInt64(&c.sndbuf, int64(sndbuf)) return nil } @@ -150,7 +151,7 @@ func (c *ConnectedEndpoint) Send(ctx context.Context, data [][]byte, controlMess // only as much of the message as fits in the send buffer. truncate := c.stype == linux.SOCK_STREAM - n, totalLen, err := fdWriteVec(c.fd, data, c.sndbuf, truncate) + n, totalLen, err := fdWriteVec(c.fd, data, c.SendMaxQueueSize(), truncate) if n < totalLen && err == nil { // The host only returns a short write if it would otherwise // block (and only for stream sockets). @@ -226,7 +227,7 @@ func (c *ConnectedEndpoint) Recv(ctx context.Context, data [][]byte, creds bool, // N.B. Unix sockets don't have a receive buffer, the send buffer // serves both purposes. - rl, ml, cl, cTrunc, err := fdReadVec(c.fd, data, []byte(cm), peek, c.sndbuf) + rl, ml, cl, cTrunc, err := fdReadVec(c.fd, data, []byte(cm), peek, c.RecvMaxQueueSize()) if rl > 0 && err != nil { // We got some data, so all we need to do on error is return // the data that we got. Short reads are fine, no need to @@ -300,14 +301,14 @@ func (c *ConnectedEndpoint) RecvQueuedSize() int64 { // SendMaxQueueSize implements transport.Receiver.SendMaxQueueSize. func (c *ConnectedEndpoint) SendMaxQueueSize() int64 { - return int64(c.sndbuf) + return atomic.LoadInt64(&c.sndbuf) } // RecvMaxQueueSize implements transport.Receiver.RecvMaxQueueSize. func (c *ConnectedEndpoint) RecvMaxQueueSize() int64 { // N.B. Unix sockets don't use the receive buffer. We'll claim it is // the same size as the send buffer. - return int64(c.sndbuf) + return atomic.LoadInt64(&c.sndbuf) } func (c *ConnectedEndpoint) destroyLocked() { @@ -327,6 +328,13 @@ func (c *ConnectedEndpoint) Release(ctx context.Context) { // CloseUnread implements transport.ConnectedEndpoint.CloseUnread. func (c *ConnectedEndpoint) CloseUnread() {} +// SetSendBufferSize implements transport.ConnectedEndpoint.SetSendBufferSize. +func (c *ConnectedEndpoint) SetSendBufferSize(v int64) (newSz int64) { + // gVisor does not permit setting of SO_SNDBUF for host backed unix domain + // sockets. + return atomic.LoadInt64(&c.sndbuf) +} + // SCMConnectedEndpoint represents an endpoint backed by a host fd that was // passed through a gofer Unix socket. It resembles ConnectedEndpoint, with the // following differences: diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go index a7a553619..d6dd6bc41 100644 --- a/pkg/sentry/fsimpl/kernfs/filesystem.go +++ b/pkg/sentry/fsimpl/kernfs/filesystem.go @@ -668,6 +668,12 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa // Can we create the dst dentry? var dst *Dentry pc := rp.Component() + if pc == "." || pc == ".." { + if noReplace { + return syserror.EEXIST + } + return syserror.EBUSY + } switch err := checkCreateLocked(ctx, rp.Credentials(), pc, dstDir); err { case nil: // Ok, continue with rename as replacement. diff --git a/pkg/sentry/fsimpl/kernfs/synthetic_directory.go b/pkg/sentry/fsimpl/kernfs/synthetic_directory.go index 463d77d79..11694c392 100644 --- a/pkg/sentry/fsimpl/kernfs/synthetic_directory.go +++ b/pkg/sentry/fsimpl/kernfs/synthetic_directory.go @@ -42,19 +42,16 @@ type syntheticDirectory struct { var _ Inode = (*syntheticDirectory)(nil) func newSyntheticDirectory(ctx context.Context, creds *auth.Credentials, perm linux.FileMode) Inode { - inode := &syntheticDirectory{} - inode.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, 0 /* ino */, perm) - return inode -} - -func (dir *syntheticDirectory) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) { if perm&^linux.PermissionsMask != 0 { panic(fmt.Sprintf("perm contains non-permission bits: %#o", perm)) } - dir.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.S_IFDIR|perm) + dir := &syntheticDirectory{} + dir.InitRefs() + dir.InodeAttrs.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, 0 /* ino */, linux.S_IFDIR|perm) dir.OrderedChildren.Init(OrderedChildrenOptions{ Writable: true, }) + return dir } // Open implements Inode.Open. diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go index e46f593c7..b36031291 100644 --- a/pkg/sentry/fsimpl/overlay/filesystem.go +++ b/pkg/sentry/fsimpl/overlay/filesystem.go @@ -1068,7 +1068,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa if err != nil { return err } - if err := vfs.CheckDeleteSticky(creds, linux.FileMode(atomic.LoadUint32(&oldParent.mode)), auth.KUID(atomic.LoadUint32(&renamed.uid))); err != nil { + if err := oldParent.mayDelete(creds, renamed); err != nil { return err } if renamed.isDir() { @@ -1317,7 +1317,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error if !child.isDir() { return syserror.ENOTDIR } - if err := vfs.CheckDeleteSticky(rp.Credentials(), linux.FileMode(atomic.LoadUint32(&parent.mode)), auth.KUID(atomic.LoadUint32(&child.uid))); err != nil { + if err := parent.mayDelete(rp.Credentials(), child); err != nil { return err } child.dirMu.Lock() @@ -1584,7 +1584,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error if child.isDir() { return syserror.EISDIR } - if err := vfs.CheckDeleteSticky(rp.Credentials(), linux.FileMode(parentMode), auth.KUID(atomic.LoadUint32(&child.uid))); err != nil { + if err := parent.mayDelete(rp.Credentials(), child); err != nil { return err } if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil { diff --git a/pkg/sentry/fsimpl/overlay/overlay.go b/pkg/sentry/fsimpl/overlay/overlay.go index 082fa6504..acd3684c6 100644 --- a/pkg/sentry/fsimpl/overlay/overlay.go +++ b/pkg/sentry/fsimpl/overlay/overlay.go @@ -760,6 +760,16 @@ func (d *dentry) updateAfterSetStatLocked(opts *vfs.SetStatOptions) { } } +func (d *dentry) mayDelete(creds *auth.Credentials, child *dentry) error { + return vfs.CheckDeleteSticky( + creds, + linux.FileMode(atomic.LoadUint32(&d.mode)), + auth.KUID(atomic.LoadUint32(&d.uid)), + auth.KUID(atomic.LoadUint32(&child.uid)), + auth.KGID(atomic.LoadUint32(&child.gid)), + ) +} + // fileDescription is embedded by overlay implementations of // vfs.FileDescriptionImpl. // diff --git a/pkg/sentry/fsimpl/tmpfs/directory.go b/pkg/sentry/fsimpl/tmpfs/directory.go index e90669cf0..417ac2eff 100644 --- a/pkg/sentry/fsimpl/tmpfs/directory.go +++ b/pkg/sentry/fsimpl/tmpfs/directory.go @@ -84,7 +84,13 @@ func (dir *directory) removeChildLocked(child *dentry) { } func (dir *directory) mayDelete(creds *auth.Credentials, child *dentry) error { - return vfs.CheckDeleteSticky(creds, linux.FileMode(atomic.LoadUint32(&dir.inode.mode)), auth.KUID(atomic.LoadUint32(&child.inode.uid))) + return vfs.CheckDeleteSticky( + creds, + linux.FileMode(atomic.LoadUint32(&dir.inode.mode)), + auth.KUID(atomic.LoadUint32(&dir.inode.uid)), + auth.KUID(atomic.LoadUint32(&child.inode.uid)), + auth.KGID(atomic.LoadUint32(&child.inode.gid)), + ) } // +stateify savable diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go index 6255a7c84..82a743ff3 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go @@ -656,6 +656,9 @@ func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, // Write to that memory as usual. seg, gap = rw.file.data.Insert(gap, gapMR, fr.Start), fsutil.FileRangeGapIterator{} + + default: + panic("unreachable") } } exitLoop: diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go index c551acd99..2c8668fc4 100644 --- a/pkg/sentry/kernel/pipe/pipe.go +++ b/pkg/sentry/kernel/pipe/pipe.go @@ -247,11 +247,15 @@ func (p *Pipe) writeLocked(count int64, f func(safemem.BlockSeq) (uint64, error) return 0, syscall.EPIPE } - // POSIX requires that a write smaller than atomicIOBytes (PIPE_BUF) be - // atomic, but requires no atomicity for writes larger than this. avail := p.max - p.size + if avail == 0 { + return 0, syserror.ErrWouldBlock + } short := false if count > avail { + // POSIX requires that a write smaller than atomicIOBytes + // (PIPE_BUF) be atomic, but requires no atomicity for writes + // larger than this. if count <= atomicIOBytes { return 0, syserror.ErrWouldBlock } diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go index 16986244c..f7765fa3a 100644 --- a/pkg/sentry/kernel/task_exit.go +++ b/pkg/sentry/kernel/task_exit.go @@ -415,6 +415,12 @@ func (tg *ThreadGroup) anyNonExitingTaskLocked() *Task { func (t *Task) reparentLocked(parent *Task) { oldParent := t.parent t.parent = parent + if oldParent != nil { + delete(oldParent.children, t) + } + if parent != nil { + parent.children[t] = struct{}{} + } // If a thread group leader's parent changes, reset the thread group's // termination signal to SIGCHLD and re-check exit notification. (Compare // kernel/exit.c:reparent_leader().) diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go index 98af2cc38..cd9fa4031 100644 --- a/pkg/sentry/loader/elf.go +++ b/pkg/sentry/loader/elf.go @@ -517,12 +517,14 @@ func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, in start, ok = start.AddLength(uint64(offset)) if !ok { - panic(fmt.Sprintf("Start %#x + offset %#x overflows?", start, offset)) + ctx.Infof(fmt.Sprintf("Start %#x + offset %#x overflows?", start, offset)) + return loadedELF{}, syserror.EINVAL } end, ok = end.AddLength(uint64(offset)) if !ok { - panic(fmt.Sprintf("End %#x + offset %#x overflows?", end, offset)) + ctx.Infof(fmt.Sprintf("End %#x + offset %#x overflows?", end, offset)) + return loadedELF{}, syserror.EINVAL } info.entry, ok = info.entry.AddLength(uint64(offset)) diff --git a/pkg/sentry/platform/ptrace/filters.go b/pkg/sentry/platform/ptrace/filters.go index b0970e356..20fc62acb 100644 --- a/pkg/sentry/platform/ptrace/filters.go +++ b/pkg/sentry/platform/ptrace/filters.go @@ -17,14 +17,12 @@ package ptrace import ( "syscall" - "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/seccomp" ) // SyscallFilters returns syscalls made exclusively by the ptrace platform. func (*PTrace) SyscallFilters() seccomp.SyscallRules { return seccomp.SyscallRules{ - unix.SYS_GETCPU: {}, syscall.SYS_PTRACE: {}, syscall.SYS_TGKILL: {}, syscall.SYS_WAIT4: {}, diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go index f82c7c224..dc03ccb47 100644 --- a/pkg/sentry/socket/hostinet/socket_vfs2.go +++ b/pkg/sentry/socket/hostinet/socket_vfs2.go @@ -80,8 +80,7 @@ func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol in // Release implements vfs.FileDescriptionImpl.Release. func (s *socketVFS2) Release(ctx context.Context) { - t := kernel.TaskFromContext(ctx) - t.Kernel().DeleteSocketVFS2(&s.vfsfd) + kernel.KernelFromContext(ctx).DeleteSocketVFS2(&s.vfsfd) s.socketOpsCommon.Release(ctx) } diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 69693f263..cee8120ab 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -855,10 +855,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - size, err := ep.SocketOptions().GetSendBufferSize() - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } + size := ep.SocketOptions().GetSendBufferSize() if size > math.MaxInt32 { size = math.MaxInt32 @@ -1647,13 +1644,6 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam return syserr.ErrInvalidArgument } - family, _, _ := s.Type() - // TODO(gvisor.dev/issue/5132): We currently do not support - // setting this option for unix sockets. - if family == linux.AF_UNIX { - return nil - } - v := usermem.ByteOrder.Uint32(optVal) ep.SocketOptions().SetSendBufferSize(int64(v), true) return nil diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index 24922c400..fc29f8f13 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -79,8 +79,7 @@ func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, qu // Release implements vfs.FileDescriptionImpl.Release. func (s *SocketVFS2) Release(ctx context.Context) { - t := kernel.TaskFromContext(ctx) - t.Kernel().DeleteSocketVFS2(&s.vfsfd) + kernel.KernelFromContext(ctx).DeleteSocketVFS2(&s.vfsfd) s.socketOpsCommon.Release(ctx) } diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index cce0acc33..acf2ab8e7 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -51,6 +51,7 @@ go_library( "//pkg/sentry/fs/fsutil", "//pkg/sentry/fs/lock", "//pkg/sentry/fsimpl/sockfs", + "//pkg/sentry/inet", "//pkg/sentry/kernel", "//pkg/sentry/kernel/time", "//pkg/sentry/socket", diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD index 3ebbd28b0..0d11bb251 100644 --- a/pkg/sentry/socket/unix/transport/BUILD +++ b/pkg/sentry/socket/unix/transport/BUILD @@ -32,6 +32,7 @@ go_library( "connectioned.go", "connectioned_state.go", "connectionless.go", + "connectionless_state.go", "queue.go", "queue_refs.go", "transport_message_list.go", @@ -45,6 +46,7 @@ go_library( "//pkg/log", "//pkg/refs", "//pkg/refsvfs2", + "//pkg/sentry/inet", "//pkg/sync", "//pkg/syserr", "//pkg/tcpip", diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index fc5b823b0..809c95429 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -128,7 +128,9 @@ func newConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProv idGenerator: uid, stype: stype, } - ep.ops.InitHandler(ep, nil, nil) + + ep.ops.SetSendBufferSize(defaultBufferSize, false /* notify */) + ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits) return ep } @@ -137,9 +139,9 @@ func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (E a := newConnectioned(ctx, stype, uid) b := newConnectioned(ctx, stype, uid) - q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: initialLimit} + q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: defaultBufferSize} q1.InitRefs() - q2 := &queue{ReaderQueue: b.Queue, WriterQueue: a.Queue, limit: initialLimit} + q2 := &queue{ReaderQueue: b.Queue, WriterQueue: a.Queue, limit: defaultBufferSize} q2.InitRefs() if stype == linux.SOCK_STREAM { @@ -173,7 +175,8 @@ func NewExternal(ctx context.Context, stype linux.SockType, uid UniqueIDProvider idGenerator: uid, stype: stype, } - ep.ops.InitHandler(ep, nil, nil) + ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits) + ep.ops.SetSendBufferSize(connected.SendMaxQueueSize(), false /* notify */) return ep } @@ -296,16 +299,18 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn idGenerator: e.idGenerator, stype: e.stype, } - ne.ops.InitHandler(ne, nil, nil) + ne.ops.InitHandler(ne, &stackHandler{}, getSendBufferLimits) + ne.ops.SetSendBufferSize(defaultBufferSize, false /* notify */) - readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: initialLimit} + readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: defaultBufferSize} readQueue.InitRefs() ne.connected = &connectedEndpoint{ endpoint: ce, writeQueue: readQueue, } - writeQueue := &queue{ReaderQueue: ne.Queue, WriterQueue: ce.WaiterQueue(), limit: initialLimit} + // Make sure the accepted endpoint inherits this listening socket's SO_SNDBUF. + writeQueue := &queue{ReaderQueue: ne.Queue, WriterQueue: ce.WaiterQueue(), limit: e.ops.GetSendBufferSize()} writeQueue.InitRefs() if e.stype == linux.SOCK_STREAM { ne.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{readQueue: writeQueue}} @@ -357,6 +362,9 @@ func (e *connectionedEndpoint) Connect(ctx context.Context, server BoundEndpoint returnConnect := func(r Receiver, ce ConnectedEndpoint) { e.receiver = r e.connected = ce + // Make sure the newly created connected endpoint's write queue is updated + // to reflect this endpoint's send buffer size. + e.connected.SetSendBufferSize(e.ops.GetSendBufferSize()) } return server.BidirectionalConnect(ctx, e, returnConnect) @@ -495,3 +503,11 @@ func (e *connectionedEndpoint) State() uint32 { } return linux.SS_UNCONNECTED } + +// OnSetSendBufferSize implements tcpip.SocketOptionsHandler.OnSetSendBufferSize. +func (e *connectionedEndpoint) OnSetSendBufferSize(v int64) (newSz int64) { + if e.Connected() { + return e.baseEndpoint.connected.SetSendBufferSize(v) + } + return v +} diff --git a/pkg/sentry/socket/unix/transport/connectioned_state.go b/pkg/sentry/socket/unix/transport/connectioned_state.go index 7e02a5db8..590b0bd01 100644 --- a/pkg/sentry/socket/unix/transport/connectioned_state.go +++ b/pkg/sentry/socket/unix/transport/connectioned_state.go @@ -51,3 +51,8 @@ func (e *connectionedEndpoint) loadAcceptedChan(acceptedSlice []*connectionedEnd } } } + +// afterLoad is invoked by stateify. +func (e *connectionedEndpoint) afterLoad() { + e.ops.InitHandler(e, &stackHandler{}, getSendBufferLimits) +} diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index 20fa8b874..0be78480c 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -41,10 +41,11 @@ var ( // NewConnectionless creates a new unbound dgram endpoint. func NewConnectionless(ctx context.Context) Endpoint { ep := &connectionlessEndpoint{baseEndpoint{Queue: &waiter.Queue{}}} - q := queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: initialLimit} + q := queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: defaultBufferSize} q.InitRefs() ep.receiver = &queueReceiver{readQueue: &q} - ep.ops.InitHandler(ep, nil, nil) + ep.ops.SetSendBufferSize(defaultBufferSize, false /* notify */) + ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits) return ep } @@ -217,3 +218,11 @@ func (e *connectionlessEndpoint) State() uint32 { return linux.SS_DISCONNECTING } } + +// OnSetSendBufferSize implements tcpip.SocketOptionsHandler.OnSetSendBufferSize. +func (e *connectionlessEndpoint) OnSetSendBufferSize(v int64) (newSz int64) { + if e.Connected() { + return e.baseEndpoint.connected.SetSendBufferSize(v) + } + return v +} diff --git a/pkg/sentry/socket/unix/transport/connectionless_state.go b/pkg/sentry/socket/unix/transport/connectionless_state.go new file mode 100644 index 000000000..2ef337ec8 --- /dev/null +++ b/pkg/sentry/socket/unix/transport/connectionless_state.go @@ -0,0 +1,20 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package transport + +// afterLoad is invoked by stateify. +func (e *connectionlessEndpoint) afterLoad() { + e.ops.InitHandler(e, &stackHandler{}, getSendBufferLimits) +} diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go index 342def28f..698a9a82c 100644 --- a/pkg/sentry/socket/unix/transport/queue.go +++ b/pkg/sentry/socket/unix/transport/queue.go @@ -237,9 +237,18 @@ func (q *queue) QueuedSize() int64 { // MaxQueueSize returns the maximum number of bytes storable in the queue. func (q *queue) MaxQueueSize() int64 { + q.mu.Lock() + defer q.mu.Unlock() return q.limit } +// SetMaxQueueSize sets the maximum number of bytes storable in the queue. +func (q *queue) SetMaxQueueSize(v int64) { + q.mu.Lock() + defer q.mu.Unlock() + q.limit = v +} + // CloseUnread sets flag to indicate that the peer is closed (not shutdown) // with unread data. So if read on this queue shall return ECONNRESET error. func (q *queue) CloseUnread() { diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 70227bbd2..ceada54a8 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -26,8 +26,16 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -// initialLimit is the starting limit for the socket buffers. -const initialLimit = 16 * 1024 +const ( + // The minimum size of the send/receive buffers. + minimumBufferSize = 4 << 10 // 4 KiB (match default in linux) + + // The default size of the send/receive buffers. + defaultBufferSize = 208 << 10 // 208 KiB (default in linux for net.core.wmem_default) + + // The maximum permitted size for the send/receive buffers. + maxBufferSize = 4 << 20 // 4 MiB 4 MiB (default in linux for net.core.wmem_max) +) // A RightsControlMessage is a control message containing FDs. // @@ -627,6 +635,10 @@ type ConnectedEndpoint interface { // CloseUnread sets the fact that this end is closed with unread data to // the peer socket. CloseUnread() + + // SetSendBufferSize is called when the endpoint's send buffer size is + // changed. + SetSendBufferSize(v int64) (newSz int64) } // +stateify savable @@ -722,6 +734,14 @@ func (e *connectedEndpoint) CloseUnread() { e.writeQueue.CloseUnread() } +// SetSendBufferSize implements ConnectedEndpoint.SetSendBufferSize. +// SetSendBufferSize sets the send buffer size for the write queue to the +// specified value. +func (e *connectedEndpoint) SetSendBufferSize(v int64) (newSz int64) { + e.writeQueue.SetMaxQueueSize(v) + return v +} + // baseEndpoint is an embeddable unix endpoint base used in both the connected and connectionless // unix domain socket Endpoint implementations. // @@ -849,27 +869,6 @@ func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { return nil } -// IsUnixSocket implements tcpip.SocketOptionsHandler.IsUnixSocket. -func (e *baseEndpoint) IsUnixSocket() bool { - return true -} - -// GetSendBufferSize implements tcpip.SocketOptionsHandler.GetSendBufferSize. -func (e *baseEndpoint) GetSendBufferSize() (int64, tcpip.Error) { - e.Lock() - defer e.Unlock() - - if !e.Connected() { - return -1, &tcpip.ErrNotConnected{} - } - - v := e.connected.SendMaxQueueSize() - if v < 0 { - return -1, &tcpip.ErrQueueSizeNotSupported{} - } - return v, nil -} - func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: @@ -987,3 +986,35 @@ func (e *baseEndpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { func (*baseEndpoint) Release(context.Context) { // Binding a baseEndpoint doesn't take a reference. } + +// stackHandler is just a stub implementation of tcpip.StackHandler to provide +// when initializing socketoptions. +type stackHandler struct { +} + +// Option implements tcpip.StackHandler. +func (h *stackHandler) Option(option interface{}) tcpip.Error { + panic("unimplemented") +} + +// TransportProtocolOption implements tcpip.StackHandler. +func (h *stackHandler) TransportProtocolOption(proto tcpip.TransportProtocolNumber, option tcpip.GettableTransportProtocolOption) tcpip.Error { + panic("unimplemented") +} + +// getSendBufferLimits implements tcpip.GetSendBufferLimits. +// +// AF_UNIX sockets buffer sizes are not tied to the networking stack/namespace +// in linux but are bound by net.core.(wmem|rmem)_(max|default). +// +// In gVisor net.core sysctls today are not exposed or if exposed are currently +// tied to the networking stack in use. This makes it complicated for AF_UNIX +// when we are in a new namespace w/ no networking stack. As a result for now we +// define default/max values here in the unix socket implementation itself. +func getSendBufferLimits(tcpip.StackHandler) tcpip.SendBufferSizeOption { + return tcpip.SendBufferSizeOption{ + Min: minimumBufferSize, + Default: defaultBufferSize, + Max: maxBufferSize, + } +} diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index a7d4d7f1f..9c037cbae 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -95,8 +95,7 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3 // DecRef implements RefCounter.DecRef. func (s *SocketVFS2) DecRef(ctx context.Context) { s.socketVFS2Refs.DecRef(func() { - t := kernel.TaskFromContext(ctx) - t.Kernel().DeleteSocketVFS2(&s.vfsfd) + kernel.KernelFromContext(ctx).DeleteSocketVFS2(&s.vfsfd) s.ep.Close(ctx) if s.abstractNamespace != nil { s.abstractNamespace.Remove(s.abstractName, s) diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index a2e441448..4188502dc 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -62,7 +62,6 @@ go_library( deps = [ "//pkg/abi", "//pkg/abi/linux", - "//pkg/binary", "//pkg/bpf", "//pkg/context", "//pkg/log", diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index fe45225c1..686392cc8 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -18,7 +18,6 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -457,7 +456,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, e.ToError() } - vLen := int32(binary.Size(v)) + vLen := int32(v.SizeBytes()) if _, err := primitive.CopyInt32Out(t, optLenAddr, vLen); err != nil { return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD index 9ee766552..2e59bd5b1 100644 --- a/pkg/sentry/syscalls/linux/vfs2/BUILD +++ b/pkg/sentry/syscalls/linux/vfs2/BUILD @@ -39,7 +39,6 @@ go_library( visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", - "//pkg/binary", "//pkg/bits", "//pkg/context", "//pkg/fspath", diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go index f5795b4a8..7636ca453 100644 --- a/pkg/sentry/syscalls/linux/vfs2/socket.go +++ b/pkg/sentry/syscalls/linux/vfs2/socket.go @@ -18,7 +18,6 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -460,7 +459,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, e.ToError() } - vLen := int32(binary.Size(v)) + vLen := int32(v.SizeBytes()) if _, err := primitive.CopyInt32Out(t, optLenAddr, vLen); err != nil { return 0, nil, err } diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go index eb7d2fd3b..d2050b3f7 100644 --- a/pkg/sentry/vfs/file_description_impl_util.go +++ b/pkg/sentry/vfs/file_description_impl_util.go @@ -238,6 +238,8 @@ func (s *StaticData) Generate(ctx context.Context, buf *bytes.Buffer) error { // WritableDynamicBytesSource extends DynamicBytesSource to allow writes to the // underlying source. +// +// TODO(b/179825241): Make utility for integer-based writable files. type WritableDynamicBytesSource interface { DynamicBytesSource diff --git a/pkg/sentry/vfs/permissions.go b/pkg/sentry/vfs/permissions.go index d48520d58..db6146fd2 100644 --- a/pkg/sentry/vfs/permissions.go +++ b/pkg/sentry/vfs/permissions.go @@ -243,11 +243,13 @@ func CheckSetStat(ctx context.Context, creds *auth.Credentials, opts *SetStatOpt // the given file mode, and if so, checks whether creds has permission to // remove a file owned by childKUID from a directory with the given mode. // CheckDeleteSticky is consistent with fs/linux.h:check_sticky(). -func CheckDeleteSticky(creds *auth.Credentials, parentMode linux.FileMode, childKUID auth.KUID) error { +func CheckDeleteSticky(creds *auth.Credentials, parentMode linux.FileMode, parentKUID auth.KUID, childKUID auth.KUID, childKGID auth.KGID) error { if parentMode&linux.ModeSticky == 0 { return nil } - if CanActAsOwner(creds, childKUID) { + if creds.EffectiveKUID == childKUID || + creds.EffectiveKUID == parentKUID || + HasCapabilityOnFile(creds, linux.CAP_FOWNER, childKUID, childKGID) { return nil } return syserror.EPERM diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go index bbe84f220..21fb87757 100644 --- a/pkg/tcpip/link/pipe/pipe.go +++ b/pkg/tcpip/link/pipe/pipe.go @@ -46,6 +46,10 @@ type Endpoint struct { } func (e *Endpoint) deliverPackets(r stack.RouteInfo, proto tcpip.NetworkProtocolNumber, pkts stack.PacketBufferList) { + if !e.linked.IsAttached() { + return + } + // Note that the local address from the perspective of this endpoint is the // remote address from the perspective of the other end of the pipe // (e.linked). Similarly, the remote address from the perspective of this @@ -54,44 +58,26 @@ func (e *Endpoint) deliverPackets(r stack.RouteInfo, proto tcpip.NetworkProtocol // Deliver the packet in a new goroutine to escape this goroutine's stack and // avoid a deadlock when a packet triggers a response which leads the stack to // try and take a lock it already holds. - // - // As of writing, a deadlock may occur when performing link resolution as the - // neighbor table will send a solicitation while holding a lock and the - // response advertisement will be sent in the same stack that sent the - // solictation. When the response is received, the stack attempts to take the - // same lock it already took before sending the solicitation, leading to a - // deadlock. Basically, we attempt to lock the same lock twice in the same - // call stack. - // - // TODO(gvisor.dev/issue/5289): don't use a new goroutine once we support send - // and receive queues. - go func() { - for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), - })) - } - }() + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), + })) + } } // WritePacket implements stack.LinkEndpoint. func (e *Endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - if e.linked.IsAttached() { - var pkts stack.PacketBufferList - pkts.PushBack(pkt) - e.deliverPackets(r, proto, pkts) - } - + var pkts stack.PacketBufferList + pkts.PushBack(pkt) + e.deliverPackets(r, proto, pkts) return nil } // WritePackets implements stack.LinkEndpoint. func (e *Endpoint) WritePackets(r stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) { - if e.linked.IsAttached() { - e.deliverPackets(r, proto, pkts) - } - - return pkts.Len(), nil + n := pkts.Len() + e.deliverPackets(r, proto, pkts) + return n, nil } // Attach implements stack.LinkEndpoint. diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index 0caa65251..fa8814bac 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -16,7 +16,6 @@ go_test( "//pkg/tcpip/checker", "//pkg/tcpip/faketime", "//pkg/tcpip/header", - "//pkg/tcpip/header/parse", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", "//pkg/tcpip/network/ipv4", diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index 933845269..d59d678b2 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -10,10 +10,12 @@ go_library( ], visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/header/parse", + "//pkg/tcpip/network/internal/ip", "//pkg/tcpip/stack", ], ) @@ -44,7 +46,7 @@ go_test( library = ":arp", deps = [ "//pkg/tcpip", - "//pkg/tcpip/network/testutil", + "//pkg/tcpip/network/internal/testutil", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 0d7fadc31..3fcdea119 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -22,10 +22,12 @@ import ( "reflect" "sync/atomic" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/header/parse" + "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -34,6 +36,7 @@ const ( ProtocolNumber = header.ARPProtocolNumber ) +var _ stack.DuplicateAddressDetector = (*endpoint)(nil) var _ stack.LinkAddressResolver = (*endpoint)(nil) // ARP endpoints need to implement stack.NetworkEndpoint because the stack @@ -52,6 +55,35 @@ type endpoint struct { nic stack.NetworkInterface stats sharedStats + + mu struct { + sync.Mutex + + dad ip.DAD + } +} + +// CheckDuplicateAddress implements stack.DuplicateAddressDetector. +func (e *endpoint) CheckDuplicateAddress(addr tcpip.Address, h stack.DADCompletionHandler) stack.DADCheckAddressDisposition { + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.dad.CheckDuplicateAddressLocked(addr, h) +} + +// SetDADConfigurations implements stack.DuplicateAddressDetector. +func (e *endpoint) SetDADConfigurations(c stack.DADConfigurations) { + e.mu.Lock() + defer e.mu.Unlock() + e.mu.dad.SetConfigsLocked(c) +} + +// DuplicateAddressProtocol implements stack.DuplicateAddressDetector. +func (*endpoint) DuplicateAddressProtocol() tcpip.NetworkProtocolNumber { + return header.IPv4ProtocolNumber +} + +func (e *endpoint) SendDADMessage(addr tcpip.Address) tcpip.Error { + return e.sendARPRequest(header.IPv4Any, addr, header.EthernetBroadcastAddress) } func (e *endpoint) Enable() tcpip.Error { @@ -129,6 +161,11 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } + if _, _, ok := e.protocol.Parse(pkt); !ok { + stats.malformedPacketsReceived.Increment() + return + } + h := header.ARP(pkt.NetworkHeader().View()) if !h.IsValid() { stats.malformedPacketsReceived.Increment() @@ -140,7 +177,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { stats.requestsReceived.Increment() localAddr := tcpip.Address(h.ProtocolAddressTarget()) - if e.protocol.stack.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 { + if !e.nic.CheckLocalAddress(header.IPv4ProtocolNumber, localAddr) { stats.requestsReceivedUnknownTargetAddress.Increment() return // we have no useful answer, ignore the request } @@ -194,6 +231,10 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { addr := tcpip.Address(h.ProtocolAddressSender()) linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) + e.mu.Lock() + e.mu.dad.StopLocked(addr, false /* aborted */) + e.mu.Unlock() + // The solicited, override, and isRouter flags are not available for ARP; // they are only available for IPv6 Neighbor Advertisements. switch err := e.nic.HandleNeighborConfirmation(header.IPv4ProtocolNumber, addr, linkAddr, stack.ReachabilityConfirmationFlags{ @@ -222,9 +263,9 @@ func (e *endpoint) Stats() stack.NetworkEndpointStats { var _ stack.NetworkProtocol = (*protocol)(nil) -// protocol implements stack.NetworkProtocol and stack.LinkAddressResolver. type protocol struct { - stack *stack.Stack + stack *stack.Stack + options Options } func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber } @@ -241,6 +282,14 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.Tran nic: nic, } + e.mu.Lock() + e.mu.dad.Init(&e.mu, p.options.DADConfigs, ip.DADOptions{ + Clock: p.stack.Clock(), + Protocol: e, + NICID: nic.ID(), + }) + e.mu.Unlock() + tcpip.InitStatCounters(reflect.ValueOf(&e.stats.localStats).Elem()) stackStats := p.stack.Stats() @@ -276,13 +325,17 @@ func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot } localAddr = addr.Address - } else if e.protocol.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 { + } else if !e.nic.CheckLocalAddress(header.IPv4ProtocolNumber, localAddr) { stats.outgoingRequestBadLocalAddressErrors.Increment() return &tcpip.ErrBadLocalAddress{} } + return e.sendARPRequest(localAddr, targetAddr, remoteLinkAddr) +} + +func (e *endpoint) sendARPRequest(localAddr, targetAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize, + ReserveHeaderBytes: int(e.MaxHeaderLength()), }) h := header.ARP(pkt.NetworkHeader().Push(header.ARPSize)) pkt.NetworkProtocolNumber = ProtocolNumber @@ -297,6 +350,8 @@ func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot if n := copy(h.ProtocolAddressTarget(), targetAddr); n != header.IPv4AddressSize { panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) } + + stats := e.stats.arp if err := e.nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil { stats.outgoingRequestsDropped.Increment() return err @@ -337,9 +392,24 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu return 0, false, parse.ARP(pkt) } +// Options holds options to configure a protocol. +type Options struct { + // DADConfigs is the default DAD configurations used by ARP endpoints. + DADConfigs stack.DADConfigurations +} + +// NewProtocolWithOptions returns an ARP network protocol factory that +// will return an ARP network protocol with the provided options. +func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { + return func(s *stack.Stack) stack.NetworkProtocol { + return &protocol{ + stack: s, + options: opts, + } + } +} + // NewProtocol returns an ARP network protocol. func NewProtocol(s *stack.Stack) stack.NetworkProtocol { - return &protocol{ - stack: s, - } + return NewProtocolWithOptions(Options{})(s) } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 24357e15d..018d6a578 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -17,7 +17,6 @@ package arp_test import ( "context" "fmt" - "strconv" "testing" "time" @@ -155,7 +154,7 @@ type testContext struct { nudDisp *arpDispatcher } -func newTestContext(t *testing.T, useNeighborCache bool) *testContext { +func newTestContext(t *testing.T) *testContext { c := stack.DefaultNUDConfigurations() // Transition from Reachable to Stale almost immediately to test if receiving // probes refreshes positive reachability. @@ -173,7 +172,6 @@ func newTestContext(t *testing.T, useNeighborCache bool) *testContext { TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, NUDConfigs: c, NUDDisp: &d, - UseNeighborCache: useNeighborCache, }) ep := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr) @@ -191,15 +189,6 @@ func newTestContext(t *testing.T, useNeighborCache bool) *testContext { if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { t.Fatalf("AddAddress for ipv4 failed: %v", err) } - if !useNeighborCache { - // The remote address needs to be assigned to the NIC so we can receive and - // verify outgoing ARP packets. The neighbor cache isn't concerned with - // this; the tests that use linkAddrCache expect the ARP responses to be - // received by the same NIC. - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, remoteAddr); err != nil { - t.Fatalf("AddAddress for ipv4 failed: %v", err) - } - } s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv4EmptySubnet, @@ -217,86 +206,8 @@ func (c *testContext) cleanup() { c.linkEP.Close() } -func TestDirectRequest(t *testing.T) { - c := newTestContext(t, false /* useNeighborCache */) - defer c.cleanup() - - const senderMAC = "\x01\x02\x03\x04\x05\x06" - const senderIPv4 = "\x0a\x00\x00\x02" - - v := make(buffer.View, header.ARPSize) - h := header.ARP(v) - h.SetIPv4OverEthernet() - h.SetOp(header.ARPRequest) - copy(h.HardwareAddressSender(), senderMAC) - copy(h.ProtocolAddressSender(), senderIPv4) - - inject := func(addr tcpip.Address) { - copy(h.ProtocolAddressTarget(), addr) - c.linkEP.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: v.ToVectorisedView(), - })) - } - - for i, address := range []tcpip.Address{stackAddr, remoteAddr} { - t.Run(strconv.Itoa(i), func(t *testing.T) { - expectedPacketsReceived := c.s.Stats().ARP.PacketsReceived.Value() + 1 - expectedRequestsReceived := c.s.Stats().ARP.RequestsReceived.Value() + 1 - expectedRepliesSent := c.s.Stats().ARP.OutgoingRepliesSent.Value() + 1 - - inject(address) - pi, _ := c.linkEP.ReadContext(context.Background()) - if pi.Proto != arp.ProtocolNumber { - t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto) - } - rep := header.ARP(pi.Pkt.NetworkHeader().View()) - if !rep.IsValid() { - t.Fatalf("invalid ARP response: len = %d; response = %x", len(rep), rep) - } - if got := rep.Op(); got != header.ARPReply { - t.Fatalf("got Op = %d, want = %d", got, header.ARPReply) - } - if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want { - t.Errorf("got HardwareAddressSender = %s, want = %s", got, want) - } - if got, want := tcpip.Address(rep.ProtocolAddressSender()), tcpip.Address(h.ProtocolAddressTarget()); got != want { - t.Errorf("got ProtocolAddressSender = %s, want = %s", got, want) - } - if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress(h.HardwareAddressSender()); got != want { - t.Errorf("got HardwareAddressTarget = %s, want = %s", got, want) - } - if got, want := tcpip.Address(rep.ProtocolAddressTarget()), tcpip.Address(h.ProtocolAddressSender()); got != want { - t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, want) - } - - if got := c.s.Stats().ARP.PacketsReceived.Value(); got != expectedPacketsReceived { - t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = %d", got, expectedPacketsReceived) - } - if got := c.s.Stats().ARP.RequestsReceived.Value(); got != expectedRequestsReceived { - t.Errorf("got c.s.Stats().ARP.PacketsReceived.Value() = %d, want = %d", got, expectedRequestsReceived) - } - if got := c.s.Stats().ARP.OutgoingRepliesSent.Value(); got != expectedRepliesSent { - t.Errorf("got c.s.Stats().ARP.OutgoingRepliesSent.Value() = %d, want = %d", got, expectedRepliesSent) - } - }) - } - - inject(unknownAddr) - // Sleep tests are gross, but this will only potentially flake - // if there's a bug. If there is no bug this will reliably - // succeed. - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - if pkt, ok := c.linkEP.ReadContext(ctx); ok { - t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto) - } - if got := c.s.Stats().ARP.RequestsReceivedUnknownTargetAddress.Value(); got != 1 { - t.Errorf("got c.s.Stats().ARP.RequestsReceivedUnKnownTargetAddress.Value() = %d, want = 1", got) - } -} - func TestMalformedPacket(t *testing.T) { - c := newTestContext(t, false) + c := newTestContext(t) defer c.cleanup() v := make(buffer.View, header.ARPSize) @@ -315,7 +226,7 @@ func TestMalformedPacket(t *testing.T) { } func TestDisabledEndpoint(t *testing.T) { - c := newTestContext(t, false) + c := newTestContext(t) defer c.cleanup() ep, err := c.s.GetNetworkEndpoint(nicID, header.ARPProtocolNumber) @@ -340,7 +251,7 @@ func TestDisabledEndpoint(t *testing.T) { } func TestDirectReply(t *testing.T) { - c := newTestContext(t, false) + c := newTestContext(t) defer c.cleanup() const senderMAC = "\x01\x02\x03\x04\x05\x06" @@ -370,8 +281,8 @@ func TestDirectReply(t *testing.T) { } } -func TestDirectRequestWithNeighborCache(t *testing.T) { - c := newTestContext(t, true /* useNeighborCache */) +func TestDirectRequest(t *testing.T) { + c := newTestContext(t) defer c.cleanup() tests := []struct { @@ -748,3 +659,53 @@ func TestLinkAddressRequest(t *testing.T) { }) } } + +func TestDADARPRequestPacket(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocolWithOptions(arp.Options{ + DADConfigs: stack.DADConfigurations{ + DupAddrDetectTransmits: 1, + RetransmitTimer: time.Second, + }, + }), ipv4.NewProtocol}, + }) + e := channel.New(1, defaultMTU, stackLinkAddr) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + + if res, err := s.CheckDuplicateAddress(nicID, header.IPv4ProtocolNumber, remoteAddr, func(stack.DADResult) {}); err != nil { + t.Fatalf("s.CheckDuplicateAddress(%d, %d, %s, _): %s", nicID, header.IPv4ProtocolNumber, remoteAddr, err) + } else if res != stack.DADStarting { + t.Fatalf("got s.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", nicID, header.IPv4ProtocolNumber, remoteAddr, res, stack.DADStarting) + } + + pkt, ok := e.ReadContext(context.Background()) + if !ok { + t.Fatal("expected to send an ARP request") + } + + if pkt.Route.RemoteLinkAddress != header.EthernetBroadcastAddress { + t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, header.EthernetBroadcastAddress) + } + + req := header.ARP(stack.PayloadSince(pkt.Pkt.NetworkHeader())) + if !req.IsValid() { + t.Errorf("got req.IsValid() = false, want = true") + } + if got := req.Op(); got != header.ARPRequest { + t.Errorf("got req.Op() = %d, want = %d", got, header.ARPRequest) + } + if got := tcpip.LinkAddress(req.HardwareAddressSender()); got != stackLinkAddr { + t.Errorf("got req.HardwareAddressSender() = %s, want = %s", got, stackLinkAddr) + } + if got := tcpip.Address(req.ProtocolAddressSender()); got != header.IPv4Any { + t.Errorf("got req.ProtocolAddressSender() = %s, want = %s", got, header.IPv4Any) + } + if got, want := tcpip.LinkAddress(req.HardwareAddressTarget()), tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"); got != want { + t.Errorf("got req.HardwareAddressTarget() = %s, want = %s", got, want) + } + if got := tcpip.Address(req.ProtocolAddressTarget()); got != remoteAddr { + t.Errorf("got req.ProtocolAddressTarget() = %s, want = %s", got, remoteAddr) + } +} diff --git a/pkg/tcpip/network/arp/stats_test.go b/pkg/tcpip/network/arp/stats_test.go index 65c708ac4..e867b3c3f 100644 --- a/pkg/tcpip/network/arp/stats_test.go +++ b/pkg/tcpip/network/arp/stats_test.go @@ -19,7 +19,7 @@ import ( "testing" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/network/testutil" + "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" ) diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/internal/fragmentation/BUILD index 429af69ee..274f09092 100644 --- a/pkg/tcpip/network/fragmentation/BUILD +++ b/pkg/tcpip/network/internal/fragmentation/BUILD @@ -22,7 +22,10 @@ go_library( "reassembler.go", "reassembler_list.go", ], - visibility = ["//visibility:public"], + visibility = [ + "//pkg/tcpip/network/ipv4:__pkg__", + "//pkg/tcpip/network/ipv6:__pkg__", + ], deps = [ "//pkg/log", "//pkg/sync", @@ -44,7 +47,7 @@ go_test( deps = [ "//pkg/tcpip/buffer", "//pkg/tcpip/faketime", - "//pkg/tcpip/network/testutil", + "//pkg/tcpip/network/internal/testutil", "//pkg/tcpip/stack", "@com_github_google_go_cmp//cmp:go_default_library", ], diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/internal/fragmentation/fragmentation.go index 243738951..243738951 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/internal/fragmentation/fragmentation.go diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go index 905bbc19b..47ea3173e 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation_test.go +++ b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go @@ -22,7 +22,7 @@ import ( "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/network/testutil" + "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" ) diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/internal/fragmentation/reassembler.go index 933d63d32..933d63d32 100644 --- a/pkg/tcpip/network/fragmentation/reassembler.go +++ b/pkg/tcpip/network/internal/fragmentation/reassembler.go diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/internal/fragmentation/reassembler_test.go index 214a93709..214a93709 100644 --- a/pkg/tcpip/network/fragmentation/reassembler_test.go +++ b/pkg/tcpip/network/internal/fragmentation/reassembler_test.go diff --git a/pkg/tcpip/network/ip/BUILD b/pkg/tcpip/network/internal/ip/BUILD index 411bca25d..d21b4c7ef 100644 --- a/pkg/tcpip/network/ip/BUILD +++ b/pkg/tcpip/network/internal/ip/BUILD @@ -5,25 +5,35 @@ package(licenses = ["notice"]) go_library( name = "ip", srcs = [ + "duplicate_address_detection.go", "generic_multicast_protocol.go", "stats.go", ], - visibility = ["//visibility:public"], + visibility = [ + "//pkg/tcpip/network/arp:__pkg__", + "//pkg/tcpip/network/ipv4:__pkg__", + "//pkg/tcpip/network/ipv6:__pkg__", + ], deps = [ "//pkg/sync", "//pkg/tcpip", + "//pkg/tcpip/stack", ], ) go_test( - name = "ip_test", + name = "ip_x_test", size = "small", - srcs = ["generic_multicast_protocol_test.go"], + srcs = [ + "duplicate_address_detection_test.go", + "generic_multicast_protocol_test.go", + ], deps = [ ":ip", "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/faketime", + "//pkg/tcpip/stack", "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection.go new file mode 100644 index 000000000..6f89a6a16 --- /dev/null +++ b/pkg/tcpip/network/internal/ip/duplicate_address_detection.go @@ -0,0 +1,172 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package ip holds IPv4/IPv6 common utilities. +package ip + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +type dadState struct { + done *bool + timer tcpip.Timer + + completionHandlers []stack.DADCompletionHandler +} + +// DADProtocol is a protocol whose core state machine can be represented by DAD. +type DADProtocol interface { + // SendDADMessage attempts to send a DAD probe message. + SendDADMessage(tcpip.Address) tcpip.Error +} + +// DADOptions holds options for DAD. +type DADOptions struct { + Clock tcpip.Clock + Protocol DADProtocol + NICID tcpip.NICID +} + +// DAD performs duplicate address detection for addresses. +type DAD struct { + opts DADOptions + configs stack.DADConfigurations + + protocolMU sync.Locker + addresses map[tcpip.Address]dadState +} + +// Init initializes the DAD state. +// +// Must only be called once for the lifetime of d; Init will panic if it is +// called twice. +// +// The lock will only be taken when timers fire. +func (d *DAD) Init(protocolMU sync.Locker, configs stack.DADConfigurations, opts DADOptions) { + if d.addresses != nil { + panic("attempted to initialize DAD state twice") + } + + *d = DAD{ + opts: opts, + configs: configs, + protocolMU: protocolMU, + addresses: make(map[tcpip.Address]dadState), + } +} + +// CheckDuplicateAddressLocked performs DAD for an address, calling the +// completion handler once DAD resolves. +// +// If DAD is already performing for the provided address, h will be called when +// the currently running process completes. +// +// Precondition: d.protocolMU must be locked. +func (d *DAD) CheckDuplicateAddressLocked(addr tcpip.Address, h stack.DADCompletionHandler) stack.DADCheckAddressDisposition { + if d.configs.DupAddrDetectTransmits == 0 { + return stack.DADDisabled + } + + ret := stack.DADAlreadyRunning + s, ok := d.addresses[addr] + if !ok { + ret = stack.DADStarting + + remaining := d.configs.DupAddrDetectTransmits + + // Protected by d.protocolMU. + done := false + + s = dadState{ + done: &done, + timer: d.opts.Clock.AfterFunc(0, func() { + var err tcpip.Error + dadDone := remaining == 0 + if !dadDone { + err = d.opts.Protocol.SendDADMessage(addr) + } + + d.protocolMU.Lock() + defer d.protocolMU.Unlock() + + if done { + return + } + + s, ok := d.addresses[addr] + if !ok { + panic(fmt.Sprintf("dad: timer fired but missing state for %s on NIC(%d)", addr, d.opts.NICID)) + } + + if !dadDone && err == nil { + remaining-- + s.timer.Reset(d.configs.RetransmitTimer) + return + } + + // At this point we know that either DAD has resolved or we hit an error + // sending the last DAD message. Either way, clear the DAD state. + done = false + s.timer.Stop() + delete(d.addresses, addr) + + r := stack.DADResult{Resolved: dadDone, Err: err} + for _, h := range s.completionHandlers { + h(r) + } + }), + } + } + + s.completionHandlers = append(s.completionHandlers, h) + d.addresses[addr] = s + return ret +} + +// StopLocked stops a currently running DAD process. +// +// Precondition: d.protocolMU must be locked. +func (d *DAD) StopLocked(addr tcpip.Address, aborted bool) { + s, ok := d.addresses[addr] + if !ok { + return + } + + *s.done = true + s.timer.Stop() + delete(d.addresses, addr) + + var err tcpip.Error + if aborted { + err = &tcpip.ErrAborted{} + } + + r := stack.DADResult{Resolved: false, Err: err} + for _, h := range s.completionHandlers { + h(r) + } +} + +// SetConfigsLocked sets the DAD configurations. +// +// Precondition: d.protocolMU must be locked. +func (d *DAD) SetConfigsLocked(c stack.DADConfigurations) { + c.Validate() + d.configs = c +} diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go new file mode 100644 index 000000000..18c357b56 --- /dev/null +++ b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go @@ -0,0 +1,279 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ip_test + +import ( + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +type mockDADProtocol struct { + t *testing.T + + mu struct { + sync.Mutex + + dad ip.DAD + sendCount map[tcpip.Address]int + } +} + +func (m *mockDADProtocol) init(t *testing.T, c stack.DADConfigurations, opts ip.DADOptions) { + m.mu.Lock() + defer m.mu.Unlock() + + m.t = t + opts.Protocol = m + m.mu.dad.Init(&m.mu, c, opts) + m.initLocked() +} + +func (m *mockDADProtocol) initLocked() { + m.mu.sendCount = make(map[tcpip.Address]int) +} + +func (m *mockDADProtocol) SendDADMessage(addr tcpip.Address) tcpip.Error { + m.mu.Lock() + defer m.mu.Unlock() + m.mu.sendCount[addr]++ + return nil +} + +func (m *mockDADProtocol) check(addrs []tcpip.Address) string { + m.mu.Lock() + defer m.mu.Unlock() + + sendCount := make(map[tcpip.Address]int) + for _, a := range addrs { + sendCount[a]++ + } + + diff := cmp.Diff(sendCount, m.mu.sendCount) + m.initLocked() + return diff +} + +func (m *mockDADProtocol) checkDuplicateAddress(addr tcpip.Address, h stack.DADCompletionHandler) stack.DADCheckAddressDisposition { + m.mu.Lock() + defer m.mu.Unlock() + return m.mu.dad.CheckDuplicateAddressLocked(addr, h) +} + +func (m *mockDADProtocol) stop(addr tcpip.Address, aborted bool) { + m.mu.Lock() + defer m.mu.Unlock() + m.mu.dad.StopLocked(addr, aborted) +} + +func (m *mockDADProtocol) setConfigs(c stack.DADConfigurations) { + m.mu.Lock() + defer m.mu.Unlock() + m.mu.dad.SetConfigsLocked(c) +} + +const ( + addr1 = tcpip.Address("\x01") + addr2 = tcpip.Address("\x02") + addr3 = tcpip.Address("\x03") + addr4 = tcpip.Address("\x04") +) + +type dadResult struct { + Addr tcpip.Address + R stack.DADResult +} + +func handler(ch chan<- dadResult, a tcpip.Address) func(stack.DADResult) { + return func(r stack.DADResult) { + ch <- dadResult{Addr: a, R: r} + } +} + +func TestDADCheckDuplicateAddress(t *testing.T) { + var dad mockDADProtocol + clock := faketime.NewManualClock() + dad.init(t, stack.DADConfigurations{}, ip.DADOptions{ + Clock: clock, + }) + + ch := make(chan dadResult, 2) + + // DAD should initially be disabled. + if res := dad.checkDuplicateAddress(addr1, handler(nil, "")); res != stack.DADDisabled { + t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADDisabled) + } + // Wait for any initially fired timers to complete. + clock.Advance(0) + if diff := dad.check(nil); diff != "" { + t.Errorf("dad check mismatch (-want +got):\n%s", diff) + } + + // Enable and request DAD. + dadConfigs1 := stack.DADConfigurations{ + DupAddrDetectTransmits: 1, + RetransmitTimer: time.Second, + } + dad.setConfigs(dadConfigs1) + if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting { + t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting) + } + clock.Advance(0) + if diff := dad.check([]tcpip.Address{addr1}); diff != "" { + t.Errorf("dad check mismatch (-want +got):\n%s", diff) + } + // The second request for DAD on the same address should use the original + // request since it has not completed yet. + if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADAlreadyRunning { + t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADAlreadyRunning) + } + clock.Advance(0) + if diff := dad.check(nil); diff != "" { + t.Errorf("dad check mismatch (-want +got):\n%s", diff) + } + + dadConfigs2 := stack.DADConfigurations{ + DupAddrDetectTransmits: 2, + RetransmitTimer: time.Second, + } + dad.setConfigs(dadConfigs2) + // A new address should start a new DAD process. + if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting { + t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting) + } + clock.Advance(0) + if diff := dad.check([]tcpip.Address{addr2}); diff != "" { + t.Errorf("dad check mismatch (-want +got):\n%s", diff) + } + + // Make sure DAD for addr1 only resolves after the expected timeout. + const delta = time.Nanosecond + dadConfig1Duration := time.Duration(dadConfigs1.DupAddrDetectTransmits) * dadConfigs1.RetransmitTimer + clock.Advance(dadConfig1Duration - delta) + select { + case r := <-ch: + t.Fatalf("unexpectedly got a DAD result before the expected timeout of %s; r = %#v", dadConfig1Duration, r) + default: + } + clock.Advance(delta) + for i := 0; i < 2; i++ { + if diff := cmp.Diff(dadResult{Addr: addr1, R: stack.DADResult{Resolved: true, Err: nil}}, <-ch); diff != "" { + t.Errorf("(i=%d) dad result mismatch (-want +got):\n%s", i, diff) + } + } + + // Make sure DAD for addr2 only resolves after the expected timeout. + dadConfig2Duration := time.Duration(dadConfigs2.DupAddrDetectTransmits) * dadConfigs2.RetransmitTimer + clock.Advance(dadConfig2Duration - dadConfig1Duration - delta) + select { + case r := <-ch: + t.Fatalf("unexpectedly got a DAD result before the expected timeout of %s; r = %#v", dadConfig2Duration, r) + default: + } + clock.Advance(delta) + if diff := cmp.Diff(dadResult{Addr: addr2, R: stack.DADResult{Resolved: true, Err: nil}}, <-ch); diff != "" { + t.Errorf("dad result mismatch (-want +got):\n%s", diff) + } + + // Should be able to restart DAD for addr2 after it resolved. + if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting { + t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting) + } + clock.Advance(0) + if diff := dad.check([]tcpip.Address{addr2, addr2}); diff != "" { + t.Errorf("dad check mismatch (-want +got):\n%s", diff) + } + clock.Advance(dadConfig2Duration) + if diff := cmp.Diff(dadResult{Addr: addr2, R: stack.DADResult{Resolved: true, Err: nil}}, <-ch); diff != "" { + t.Errorf("dad result mismatch (-want +got):\n%s", diff) + } + + // Should not have anymore results. + select { + case r := <-ch: + t.Fatalf("unexpectedly got an extra DAD result; r = %#v", r) + default: + } +} + +func TestDADStop(t *testing.T) { + var dad mockDADProtocol + clock := faketime.NewManualClock() + dadConfigs := stack.DADConfigurations{ + DupAddrDetectTransmits: 1, + RetransmitTimer: time.Second, + } + dad.init(t, dadConfigs, ip.DADOptions{ + Clock: clock, + }) + + ch := make(chan dadResult, 1) + + if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting { + t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting) + } + if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting { + t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting) + } + if res := dad.checkDuplicateAddress(addr3, handler(ch, addr3)); res != stack.DADStarting { + t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting) + } + clock.Advance(0) + if diff := dad.check([]tcpip.Address{addr1, addr2, addr3}); diff != "" { + t.Errorf("dad check mismatch (-want +got):\n%s", diff) + } + + dad.stop(addr1, true /* aborted */) + if diff := cmp.Diff(dadResult{Addr: addr1, R: stack.DADResult{Resolved: false, Err: &tcpip.ErrAborted{}}}, <-ch); diff != "" { + t.Errorf("dad result mismatch (-want +got):\n%s", diff) + } + + dad.stop(addr2, false /* aborted */) + if diff := cmp.Diff(dadResult{Addr: addr2, R: stack.DADResult{Resolved: false, Err: nil}}, <-ch); diff != "" { + t.Errorf("dad result mismatch (-want +got):\n%s", diff) + } + + dadResolutionDuration := time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer + clock.Advance(dadResolutionDuration) + if diff := cmp.Diff(dadResult{Addr: addr3, R: stack.DADResult{Resolved: true, Err: nil}}, <-ch); diff != "" { + t.Errorf("dad result mismatch (-want +got):\n%s", diff) + } + + // Should be able to restart DAD for an address we stopped DAD on. + if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting { + t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting) + } + clock.Advance(0) + if diff := dad.check([]tcpip.Address{addr1}); diff != "" { + t.Errorf("dad check mismatch (-want +got):\n%s", diff) + } + clock.Advance(dadResolutionDuration) + if diff := cmp.Diff(dadResult{Addr: addr1, R: stack.DADResult{Resolved: true, Err: nil}}, <-ch); diff != "" { + t.Errorf("dad result mismatch (-want +got):\n%s", diff) + } + + // Should not have anymore updates. + select { + case r := <-ch: + t.Fatalf("unexpectedly got an extra DAD result; r = %#v", r) + default: + } +} diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go index b9f129728..b9f129728 100644 --- a/pkg/tcpip/network/ip/generic_multicast_protocol.go +++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go index 60eaea37e..381460c82 100644 --- a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go +++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go @@ -23,17 +23,10 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/faketime" - "gvisor.dev/gvisor/pkg/tcpip/network/ip" + "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip" ) -const ( - addr1 = tcpip.Address("\x01") - addr2 = tcpip.Address("\x02") - addr3 = tcpip.Address("\x03") - addr4 = tcpip.Address("\x04") - - maxUnsolicitedReportDelay = time.Second -) +const maxUnsolicitedReportDelay = time.Second var _ ip.MulticastGroupProtocol = (*mockMulticastGroupProtocol)(nil) diff --git a/pkg/tcpip/network/ip/stats.go b/pkg/tcpip/network/internal/ip/stats.go index 898f8b356..898f8b356 100644 --- a/pkg/tcpip/network/ip/stats.go +++ b/pkg/tcpip/network/internal/ip/stats.go diff --git a/pkg/tcpip/network/testutil/BUILD b/pkg/tcpip/network/internal/testutil/BUILD index bd62c4482..1c4f583c7 100644 --- a/pkg/tcpip/network/testutil/BUILD +++ b/pkg/tcpip/network/internal/testutil/BUILD @@ -10,7 +10,7 @@ go_library( ], visibility = [ "//pkg/tcpip/network/arp:__pkg__", - "//pkg/tcpip/network/fragmentation:__pkg__", + "//pkg/tcpip/network/internal/fragmentation:__pkg__", "//pkg/tcpip/network/ipv4:__pkg__", "//pkg/tcpip/network/ipv6:__pkg__", ], diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/internal/testutil/testutil.go index f5fa77b65..f5fa77b65 100644 --- a/pkg/tcpip/network/testutil/testutil.go +++ b/pkg/tcpip/network/internal/testutil/testutil.go diff --git a/pkg/tcpip/network/testutil/testutil_unsafe.go b/pkg/tcpip/network/internal/testutil/testutil_unsafe.go index 5ff764800..5ff764800 100644 --- a/pkg/tcpip/network/testutil/testutil_unsafe.go +++ b/pkg/tcpip/network/internal/testutil/testutil_unsafe.go diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 6a1f11a36..90236ed9e 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -24,7 +24,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -315,6 +314,10 @@ func (*testInterface) Promiscuous() bool { return false } +func (*testInterface) Spoofing() bool { + return false +} + func (t *testInterface) setEnabled(v bool) { t.mu.Lock() defer t.mu.Unlock() @@ -333,6 +336,10 @@ func (*testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tc return nil } +func (*testInterface) CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool { + return false +} + func TestSourceAddressValidation(t *testing.T) { rxIPv4ICMP := func(e *channel.Endpoint, src tcpip.Address) { totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize @@ -626,9 +633,6 @@ func TestReceive(t *testing.T) { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: view.ToVectorisedView(), }) - if ok := parse.IPv4(pkt); !ok { - t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) - } ep.HandlePacket(pkt) }, }, @@ -664,9 +668,6 @@ func TestReceive(t *testing.T) { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: view.ToVectorisedView(), }) - if _, _, _, _, ok := parse.IPv6(pkt); !ok { - t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) - } ep.HandlePacket(pkt) }, }, @@ -943,9 +944,6 @@ func TestIPv4FragmentationReceive(t *testing.T) { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: frag1.ToVectorisedView(), }) - if _, _, ok := proto.Parse(pkt); !ok { - t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) - } addressableEndpoint, ok := ep.(stack.AddressableEndpoint) if !ok { @@ -967,9 +965,6 @@ func TestIPv4FragmentationReceive(t *testing.T) { pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: frag2.ToVectorisedView(), }) - if _, _, ok := proto.Parse(pkt); !ok { - t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) - } ep.HandlePacket(pkt) if nic.testObject.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) @@ -1234,7 +1229,6 @@ func truncatedPacket(view buffer.View, trunc, netHdrLen int) *stack.PacketBuffer pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: v.ToVectorisedView(), }) - _, _ = pkt.NetworkHeader().Consume(netHdrLen) return pkt } diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 9713c4448..4b21ee79c 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -17,9 +17,9 @@ go_library( "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/header/parse", - "//pkg/tcpip/network/fragmentation", "//pkg/tcpip/network/hash", - "//pkg/tcpip/network/ip", + "//pkg/tcpip/network/internal/fragmentation", + "//pkg/tcpip/network/internal/ip", "//pkg/tcpip/stack", ], ) @@ -40,8 +40,8 @@ go_test( "//pkg/tcpip/link/channel", "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/arp", + "//pkg/tcpip/network/internal/testutil", "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/testutil", "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/raw", @@ -59,7 +59,7 @@ go_test( library = ":ipv4", deps = [ "//pkg/tcpip", - "//pkg/tcpip/network/testutil", + "//pkg/tcpip/network/internal/testutil", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 74e70e283..b44304cee 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -120,6 +120,18 @@ func (*icmpv4FragmentationNeededSockError) Kind() stack.TransportErrorKind { return stack.PacketTooBigTransportError } +func (e *endpoint) checkLocalAddress(addr tcpip.Address) bool { + if e.nic.Spoofing() { + return true + } + + if addressEndpoint := e.AcquireAssignedAddress(addr, false, stack.NeverPrimaryEndpoint); addressEndpoint != nil { + addressEndpoint.DecRef() + return true + } + return false +} + // handleControl handles the case when an ICMP error packet contains the headers // of the original packet that caused the ICMP one to be sent. This information // is used to find out which transport endpoint must be notified about the ICMP @@ -139,7 +151,7 @@ func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.Packet // Drop packet if it doesn't have the basic IPv4 header or if the // original source address doesn't match an address we own. srcAddr := hdr.SourceAddress() - if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, srcAddr) == 0 { + if !e.checkLocalAddress(srcAddr) { return } @@ -201,7 +213,8 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { } else { op = &optionUsageReceive{} } - tmp, optProblem := e.processIPOptions(pkt, opts, op) + var optProblem *header.IPv4OptParameterProblem + newOptions, optProblem = e.processIPOptions(pkt, opts, op) if optProblem != nil { if optProblem.NeedICMP { _ = e.protocol.returnError(&icmpReasonParamProblem{ @@ -212,7 +225,14 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { } return } - newOptions = tmp + copied := copy(opts, newOptions) + if copied != len(newOptions) { + panic(fmt.Sprintf("copied %d bytes of new options, expected %d bytes", copied, len(newOptions))) + } + for i := copied; i < len(opts); i++ { + // Pad with 0 (EOL). RFC 791 page 23 says "The padding is zero". + opts[i] = byte(header.IPv4OptionListEndType) + } } // TODO(b/112892170): Meaningfully handle all ICMP types. @@ -363,6 +383,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { // icmpReason is a marker interface for IPv4 specific ICMP errors. type icmpReason interface { isICMPReason() + isForwarding() bool } // icmpReasonPortUnreachable is an error where the transport protocol has no @@ -370,12 +391,18 @@ type icmpReason interface { type icmpReasonPortUnreachable struct{} func (*icmpReasonPortUnreachable) isICMPReason() {} +func (*icmpReasonPortUnreachable) isForwarding() bool { + return false +} // icmpReasonProtoUnreachable is an error where the transport protocol is // not supported. type icmpReasonProtoUnreachable struct{} func (*icmpReasonProtoUnreachable) isICMPReason() {} +func (*icmpReasonProtoUnreachable) isForwarding() bool { + return false +} // icmpReasonTTLExceeded is an error where a packet's time to live exceeded in // transit to its final destination, as per RFC 792 page 6, Time Exceeded @@ -383,6 +410,15 @@ func (*icmpReasonProtoUnreachable) isICMPReason() {} type icmpReasonTTLExceeded struct{} func (*icmpReasonTTLExceeded) isICMPReason() {} +func (*icmpReasonTTLExceeded) isForwarding() bool { + // If we hit a TTL Exceeded error, then we know we are operating as a router. + // As per RFC 792 page 6, Time Exceeded Message, + // + // If the gateway processing a datagram finds the time to live field + // is zero it must discard the datagram. The gateway may also notify + // the source host via the time exceeded message. + return true +} // icmpReasonReassemblyTimeout is an error where insufficient fragments are // received to complete reassembly of a packet within a configured time after @@ -390,14 +426,21 @@ func (*icmpReasonTTLExceeded) isICMPReason() {} type icmpReasonReassemblyTimeout struct{} func (*icmpReasonReassemblyTimeout) isICMPReason() {} +func (*icmpReasonReassemblyTimeout) isForwarding() bool { + return false +} // icmpReasonParamProblem is an error to use to request a Parameter Problem // message to be sent. type icmpReasonParamProblem struct { - pointer byte + pointer byte + forwarding bool } func (*icmpReasonParamProblem) isICMPReason() {} +func (r *icmpReasonParamProblem) isForwarding() bool { + return r.forwarding +} // returnError takes an error descriptor and generates the appropriate ICMP // error packet for IPv4 and sends it back to the remote device that sent @@ -436,26 +479,14 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip return nil } - // If we hit a TTL Exceeded error, then we know we are operating as a router. - // As per RFC 792 page 6, Time Exceeded Message, - // - // If the gateway processing a datagram finds the time to live field - // is zero it must discard the datagram. The gateway may also notify - // the source host via the time exceeded message. - // - // ... - // - // Code 0 may be received from a gateway. ... - // - // Note, Code 0 is the TTL exceeded error. - // // If we are operating as a router/gateway, don't use the packet's destination // address as the response's source address as we should not not own the // destination address of a packet we are forwarding. localAddr := origIPHdrDst - if _, ok := reason.(*icmpReasonTTLExceeded); ok { + if reason.isForwarding() { localAddr = "" } + // Even if we were able to receive a packet from some remote, we may not have // a route to it - the remote may be blocked via routing rules. We must always // consult our routing table and find a route to the remote before sending any diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go index acc126c3b..12632aceb 100644 --- a/pkg/tcpip/network/ipv4/igmp.go +++ b/pkg/tcpip/network/ipv4/igmp.go @@ -22,7 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/network/ip" + "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip" "gvisor.dev/gvisor/pkg/tcpip/stack" ) diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index b2d626107..250e4846a 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -27,8 +27,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/header/parse" - "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation" "gvisor.dev/gvisor/pkg/tcpip/network/hash" + "gvisor.dev/gvisor/pkg/tcpip/network/internal/fragmentation" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -130,6 +130,20 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.Tran return e } +func (p *protocol) findEndpointWithAddress(addr tcpip.Address) *endpoint { + p.mu.RLock() + defer p.mu.RUnlock() + + for _, e := range p.mu.eps { + if addressEndpoint := e.AcquireAssignedAddress(addr, false /* allowTemp */, stack.NeverPrimaryEndpoint); addressEndpoint != nil { + addressEndpoint.DecRef() + return e + } + } + + return nil +} + func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { p.mu.Lock() defer p.mu.Unlock() @@ -347,15 +361,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // short circuits broadcasts before they are sent out to other hosts. if pkt.NatDone { netHeader := header.IPv4(pkt.NetworkHeader().View()) - ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()) - if err == nil { - pkt := pkt.CloneToInbound() - if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - // Since we rewrote the packet but it is being routed back to us, we can - // safely assume the checksum is valid. - pkt.RXTransportChecksumValidated = true - ep.(*endpoint).handlePacket(pkt) - } + if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil { + // Since we rewrote the packet but it is being routed back to us, we + // can safely assume the checksum is valid. + ep.handleLocalPacket(pkt, true /* canSkipRXChecksum */) return nil } } @@ -365,14 +374,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, headerIncluded bool) tcpip.Error { if r.Loop&stack.PacketLoop != 0 { - pkt := pkt.CloneToInbound() - if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - // If the packet was generated by the stack (not a raw/packet endpoint - // where a packet may be written with the header included), then we can - // safely assume the checksum is valid. - pkt.RXTransportChecksumValidated = !headerIncluded - e.handlePacket(pkt) - } + // If the packet was generated by the stack (not a raw/packet endpoint + // where a packet may be written with the header included), then we can + // safely assume the checksum is valid. + e.handleLocalPacket(pkt, !headerIncluded /* canSkipRXChecksum */) } if r.Loop&stack.PacketOut == 0 { return nil @@ -450,51 +455,37 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // iptables filtering. All packets that reach here are locally // generated. dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "", outNicName) - if len(dropped) == 0 && len(natPkts) == 0 { - // Fast path: If no packets are to be dropped then we can just invoke the - // faster WritePackets API directly. - n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber) - stats.PacketsSent.IncrementBy(uint64(n)) - if err != nil { - stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) - } - return n, err - } stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped))) + for pkt := range dropped { + pkts.Remove(pkt) + } - // Slow path as we are dropping some packets in the batch degrade to - // emitting one packet at a time. - n := 0 - for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - if _, ok := dropped[pkt]; ok { + // The NAT-ed packets may now be destined for us. + locallyDelivered := 0 + for pkt := range natPkts { + ep := e.protocol.findEndpointWithAddress(header.IPv4(pkt.NetworkHeader().View()).DestinationAddress()) + if ep == nil { + // The NAT-ed packet is still destined for some remote node. continue } - if _, ok := natPkts[pkt]; ok { - netHeader := header.IPv4(pkt.NetworkHeader().View()) - if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { - pkt := pkt.CloneToInbound() - if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - // Since we rewrote the packet but it is being routed back to us, we - // can safely assume the checksum is valid. - pkt.RXTransportChecksumValidated = true - ep.(*endpoint).handlePacket(pkt) - } - n++ - continue - } - } - if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { - stats.PacketsSent.IncrementBy(uint64(n)) - stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n - len(dropped))) - // Dropped packets aren't errors, so include them in - // the return value. - return n + len(dropped), err - } - n++ + + // Do not send the locally destined packet out the NIC. + pkts.Remove(pkt) + + // Deliver the packet locally. + ep.handleLocalPacket(pkt, true /* canSkipRXChecksum */) + locallyDelivered++ + } - stats.PacketsSent.IncrementBy(uint64(n)) + + // The rest of the packets can be delivered to the NIC as a batch. + pktsLen := pkts.Len() + written, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber) + stats.PacketsSent.IncrementBy(uint64(written)) + stats.OutgoingPacketErrors.IncrementBy(uint64(pktsLen - written)) + // Dropped packets aren't errors, so include them in the return value. - return n + len(dropped), nil + return locallyDelivered + written + len(dropped), err } // WriteHeaderIncludedPacket implements stack.NetworkEndpoint. @@ -570,17 +561,41 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { return e.protocol.returnError(&icmpReasonTTLExceeded{}, pkt) } + if opts := h.Options(); len(opts) != 0 { + newOpts, optProblem := e.processIPOptions(pkt, opts, &optionUsageForward{}) + if optProblem != nil { + if optProblem.NeedICMP { + _ = e.protocol.returnError(&icmpReasonParamProblem{ + pointer: optProblem.Pointer, + forwarding: true, + }, pkt) + e.protocol.stack.Stats().MalformedRcvdPackets.Increment() + e.stats.ip.MalformedPacketsReceived.Increment() + } + return nil // option problems are not reported locally. + } + copied := copy(opts, newOpts) + if copied != len(newOpts) { + panic(fmt.Sprintf("copied %d bytes of new options, expected %d bytes", copied, len(newOpts))) + } + // Since in forwarding we handle all options, including copying those we + // do not recognise, the options region should remain the same size which + // simplifies processing. As we MAY receive a packet with a lot of padded + // bytes after the "end of options list" byte, make sure we copy + // them as the legal padding value (0). + for i := copied; i < len(opts); i++ { + // Pad with 0 (EOL). RFC 791 page 23 says "The padding is zero". + opts[i] = byte(header.IPv4OptionListEndType) + } + } + dstAddr := h.DestinationAddress() // Check if the destination is owned by the stack. - networkEndpoint, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr) - if err == nil { - networkEndpoint.(*endpoint).handlePacket(pkt) + if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { + ep.handlePacket(pkt) return nil } - if _, ok := err.(*tcpip.ErrBadAddress); !ok { - return err - } r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) if err != nil { @@ -619,8 +634,26 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } - // Loopback traffic skips the prerouting chain. + if !e.protocol.parse(pkt) { + stats.MalformedPacketsReceived.Increment() + return + } + if !e.nic.IsLoopback() { + if e.protocol.stack.HandleLocal() { + addressEndpoint := e.AcquireAssignedAddress(header.IPv4(pkt.NetworkHeader().View()).SourceAddress(), e.nic.Promiscuous(), stack.CanBePrimaryEndpoint) + if addressEndpoint != nil { + addressEndpoint.DecRef() + + // The source address is one of our own, so we never should have gotten + // a packet like this unless HandleLocal is false or our NIC is the + // loopback interface. + stats.InvalidSourceAddressesReceived.Increment() + return + } + } + + // Loopback traffic skips the prerouting chain. inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. @@ -632,6 +665,21 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { e.handlePacket(pkt) } +func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum bool) { + stats := e.stats.ip + + stats.PacketsReceived.Increment() + + pkt = pkt.CloneToInbound() + if e.protocol.parse(pkt) { + pkt.RXTransportChecksumValidated = canSkipRXChecksum + e.handlePacket(pkt) + return + } + + stats.MalformedPacketsReceived.Increment() +} + // handlePacket is like HandlePacket except it does not perform the prerouting // iptables hook. func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { @@ -691,8 +739,9 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { } } - // The destination address should be an address we own or a group we joined - // for us to receive the packet. Otherwise, attempt to forward the packet. + // Before we do any processing, note if the packet was received as some + // sort of broadcast. The destination address should be an address we own + // or a group we joined. if addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint); addressEndpoint != nil { subnet := addressEndpoint.AddressWithPrefix().Subnet() addressEndpoint.DecRef() @@ -702,7 +751,6 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { stats.ip.InvalidDestinationAddressesReceived.Increment() return } - _ = e.forwardPacket(pkt) return } @@ -724,6 +772,21 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { stats.ip.MalformedFragmentsReceived.Increment() return } + if opts := h.Options(); len(opts) != 0 { + // If there are options we need to check them before we do assembly + // or we could be assembling errant packets. However we do not change the + // options as that could lead to double processing later. + if _, optProblem := e.processIPOptions(pkt, opts, &optionUsageVerify{}); optProblem != nil { + if optProblem.NeedICMP { + _ = e.protocol.returnError(&icmpReasonParamProblem{ + pointer: optProblem.Pointer, + }, pkt) + e.protocol.stack.Stats().MalformedRcvdPackets.Increment() + e.stats.ip.MalformedPacketsReceived.Increment() + } + return + } + } // The packet is a fragment, let's try to reassemble it. start := h.FragmentOffset() // Drop the fragment if the size of the reassembled payload would exceed the @@ -782,17 +845,10 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { e.handleICMP(pkt) return } - if p == header.IGMPProtocolNumber { - e.mu.Lock() - e.mu.igmp.handleIGMP(pkt) - e.mu.Unlock() - return - } + // ICMP handles options itself but do it here for all remaining destinations. if opts := h.Options(); len(opts) != 0 { - // TODO(gvisor.dev/issue/4586): - // When we add forwarding support we should use the verified options - // rather than just throwing them away. - if _, optProblem := e.processIPOptions(pkt, opts, &optionUsageReceive{}); optProblem != nil { + newOpts, optProblem := e.processIPOptions(pkt, opts, &optionUsageReceive{}) + if optProblem != nil { if optProblem.NeedICMP { _ = e.protocol.returnError(&icmpReasonParamProblem{ pointer: optProblem.Pointer, @@ -802,6 +858,20 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { } return } + copied := copy(opts, newOpts) + if copied != len(newOpts) { + panic(fmt.Sprintf("copied %d bytes of new options, expected %d bytes", copied, len(newOpts))) + } + for i := copied; i < len(opts); i++ { + // Pad with 0 (EOL). RFC 791 page 23 says "The padding is zero". + opts[i] = byte(header.IPv4OptionListEndType) + } + } + if p == header.IGMPProtocolNumber { + e.mu.Lock() + e.mu.igmp.handleIGMP(pkt) + e.mu.Unlock() + return } switch res := e.dispatcher.DeliverTransportPacket(p, pkt); res { @@ -1043,6 +1113,29 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} +// parse is like Parse but also attempts to parse the transport layer. +// +// Returns true if the network header was successfully parsed. +func (p *protocol) parse(pkt *stack.PacketBuffer) bool { + transProtoNum, hasTransportHdr, ok := p.Parse(pkt) + if !ok { + return false + } + + if hasTransportHdr { + switch err := p.stack.ParsePacketBufferTransport(transProtoNum, pkt); err { + case stack.ParsedOK: + case stack.UnknownTransportProtocol, stack.TransportLayerParseError: + // The transport layer will handle unknown protocols and transport layer + // parsing errors. + default: + panic(fmt.Sprintf("unexpected error parsing transport header = %d", err)) + } + } + + return true +} + // Parse implements stack.NetworkProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { if ok := parse.IPv4(pkt); !ok { @@ -1211,23 +1304,49 @@ type optionsUsage interface { actions() optionActions } -// optionUsageReceive implements optionsUsage for received packets. -type optionUsageReceive struct{} +// optionUsageVerify implements optionsUsage for when we just want to check +// fragments. Don't change anything, just check and reject if bad. No +// replacement options are generated. +type optionUsageVerify struct{} // actions implements optionsUsage. -func (*optionUsageReceive) actions() optionActions { +func (*optionUsageVerify) actions() optionActions { return optionActions{ timestamp: optionVerify, recordRoute: optionVerify, + unknown: optionRemove, + } +} + +// optionUsageReceive implements optionsUsage for packets we will pass +// to the transport layer (with the exception of Echo requests). +type optionUsageReceive struct{} + +// actions implements optionsUsage. +func (*optionUsageReceive) actions() optionActions { + return optionActions{ + timestamp: optionProcess, + recordRoute: optionProcess, unknown: optionPass, } } -// TODO(gvisor.dev/issue/4586): Add an entry here for forwarding when it -// is enabled (Process, Process, Pass) and for fragmenting (Process, Process, -// Pass for frag1, but Remove,Remove,Remove for all other frags). +// optionUsageForward implements optionsUsage for packets about to be forwarded. +// All options are passed on regardless of whether we recognise them, however +// we do process the Timestamp and Record Route options. +type optionUsageForward struct{} + +// actions implements optionsUsage. +func (*optionUsageForward) actions() optionActions { + return optionActions{ + timestamp: optionProcess, + recordRoute: optionProcess, + unknown: optionPass, + } +} // optionUsageEcho implements optionsUsage for echo packet processing. +// Only Timestamp and RecordRoute are processed and sent back. type optionUsageEcho struct{} // actions implements optionsUsage. diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index a296bed79..dc4db6e5f 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -34,8 +34,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/arp" + "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/raw" @@ -111,10 +111,11 @@ func TestExcludeBroadcast(t *testing.T) { func TestForwarding(t *testing.T) { const ( - nicID1 = 1 - nicID2 = 2 - randomSequence = 123 - randomIdent = 42 + nicID1 = 1 + nicID2 = 2 + randomSequence = 123 + randomIdent = 42 + randomTimeOffset = 0x10203040 ) ipv4Addr1 := tcpip.AddressWithPrefix{ @@ -129,14 +130,20 @@ func TestForwarding(t *testing.T) { remoteIPv4Addr2 := tcpip.Address(net.ParseIP("11.0.0.2").To4()) tests := []struct { - name string - TTL uint8 - expectErrorICMP bool + name string + TTL uint8 + expectErrorICMP bool + options header.IPv4Options + forwardedOptions header.IPv4Options + icmpType header.ICMPv4Type + icmpCode header.ICMPv4Code }{ { name: "TTL of zero", TTL: 0, expectErrorICMP: true, + icmpType: header.ICMPv4TimeExceeded, + icmpCode: header.ICMPv4TTLExceeded, }, { name: "TTL of one", @@ -153,14 +160,78 @@ func TestForwarding(t *testing.T) { TTL: math.MaxUint8, expectErrorICMP: false, }, + { + name: "four EOL options", + TTL: 2, + expectErrorICMP: false, + options: header.IPv4Options{0, 0, 0, 0}, + forwardedOptions: header.IPv4Options{0, 0, 0, 0}, + }, + { + name: "TS type 1 full", + TTL: 2, + options: header.IPv4Options{ + 68, 12, 13, 0xF1, + 192, 168, 1, 12, + 1, 2, 3, 4, + }, + expectErrorICMP: true, + icmpType: header.ICMPv4ParamProblem, + icmpCode: header.ICMPv4UnusedCode, + }, + { + name: "TS type 0", + TTL: 2, + options: header.IPv4Options{ + 68, 24, 21, 0x00, + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 0, 0, 0, 0, + }, + forwardedOptions: header.IPv4Options{ + 68, 24, 25, 0x00, + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock + }, + }, + { + name: "end of options list", + TTL: 2, + options: header.IPv4Options{ + 68, 12, 13, 0x11, + 192, 168, 1, 12, + 1, 2, 3, 4, + 0, 10, 3, 99, // EOL followed by junk + 1, 2, 3, 4, + }, + forwardedOptions: header.IPv4Options{ + 68, 12, 13, 0x21, + 192, 168, 1, 12, + 1, 2, 3, 4, + 0, // End of Options hides following bytes. + 0, 0, 0, // 7 bytes unknown option removed. + 0, 0, 0, 0, + }, + }, } - for _, test := range tests { t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, + Clock: clock, }) + + // Advance the clock by some unimportant amount to make + // it give a more recognisable signature than 00,00,00,00. + clock.Advance(time.Millisecond * randomTimeOffset) + // We expect at most a single packet in response to our ICMP Echo Request. e1 := channel.New(1, ipv4.MaxTotalSize, "") if err := s.CreateNIC(nicID1, e1); err != nil { @@ -195,7 +266,11 @@ func TestForwarding(t *testing.T) { t.Fatalf("SetForwarding(%d, true): %s", header.IPv4ProtocolNumber, err) } - totalLen := uint16(header.IPv4MinimumSize + header.ICMPv4MinimumSize) + ipHeaderLength := header.IPv4MinimumSize + len(test.options) + if ipHeaderLength > header.IPv4MaximumHeaderSize { + t.Fatalf("got ipHeaderLength = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize) + } + totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize) hdr := buffer.NewPrependable(int(totalLen)) icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) icmp.SetIdent(randomIdent) @@ -204,7 +279,7 @@ func TestForwarding(t *testing.T) { icmp.SetCode(header.ICMPv4UnusedCode) icmp.SetChecksum(0) icmp.SetChecksum(^header.Checksum(icmp, 0)) - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip := header.IPv4(hdr.Prepend(ipHeaderLength)) ip.Encode(&header.IPv4Fields{ TotalLength: totalLen, Protocol: uint8(header.ICMPv4ProtocolNumber), @@ -212,6 +287,14 @@ func TestForwarding(t *testing.T) { SrcAddr: remoteIPv4Addr1, DstAddr: remoteIPv4Addr2, }) + if len(test.options) != 0 { + ip.SetHeaderLength(uint8(ipHeaderLength)) + // Copy options manually. We do not use Encode for options so we can + // verify malformed options with handcrafted payloads. + if want, got := copy(ip.Options(), test.options), len(test.options); want != got { + t.Fatalf("got copy(ip.Options(), test.options) = %d, want = %d", got, want) + } + } ip.SetChecksum(0) ip.SetChecksum(^ip.CalculateChecksum()) requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -222,7 +305,7 @@ func TestForwarding(t *testing.T) { if test.expectErrorICMP { reply, ok := e1.Read() if !ok { - t.Fatal("expected ICMP TTL Exceeded packet through incoming NIC") + t.Fatalf("expected ICMP packet type %d through incoming NIC", test.icmpType) } checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())), @@ -231,8 +314,8 @@ func TestForwarding(t *testing.T) { checker.TTL(ipv4.DefaultTTL), checker.ICMPv4( checker.ICMPv4Checksum(), - checker.ICMPv4Type(header.ICMPv4TimeExceeded), - checker.ICMPv4Code(header.ICMPv4TTLExceeded), + checker.ICMPv4Type(test.icmpType), + checker.ICMPv4Code(test.icmpCode), checker.ICMPv4Payload([]byte(hdr.View())), ), ) @@ -250,6 +333,7 @@ func TestForwarding(t *testing.T) { checker.SrcAddr(remoteIPv4Addr1), checker.DstAddr(remoteIPv4Addr2), checker.TTL(test.TTL-1), + checker.IPv4Options(test.forwardedOptions), checker.ICMPv4( checker.ICMPv4Checksum(), checker.ICMPv4Type(header.ICMPv4Echo), @@ -279,6 +363,7 @@ func TestIPv4Sanity(t *testing.T) { // (offset 1). For compatibility we must do the same. Use this constant // to indicate where this happens. pointerOffsetForInvalidLength = 0 + randomTimeOffset = 0x10203040 ) var ( ipv4Addr = tcpip.AddressWithPrefix{ @@ -893,6 +978,10 @@ func TestIPv4Sanity(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, Clock: clock, }) + // Advance the clock by some unimportant amount to make + // it give a more recognisable signature than 00,00,00,00. + clock.Advance(time.Millisecond * randomTimeOffset) + // We expect at most a single packet in response to our ICMP Echo Request. e := channel.New(1, ipv4.MaxTotalSize, "") if err := s.CreateNIC(nicID, e); err != nil { @@ -902,9 +991,6 @@ func TestIPv4Sanity(t *testing.T) { if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err) } - // Advance the clock by some unimportant amount to make - // sure it's all set up. - clock.Advance(time.Millisecond * 0x10203040) // Default routes for IPv4 so ICMP can find a route to the remote // node when attempting to send the ICMP Echo Reply. @@ -1739,12 +1825,19 @@ func TestInvalidFragments(t *testing.T) { ip := header.IPv4(hdr.Prepend(pktSize)) ip.Encode(&f.ipv4fields) + if want, got := len(f.payload), copy(ip[header.IPv4MinimumSize:], f.payload); want != got { + t.Fatalf("copied %d bytes, expected %d bytes.", got, want) + } // Encode sets this up correctly. If we want a different value for // testing then we need to overwrite the good value. if f.overrideIHL != 0 { ip.SetHeaderLength(uint8(f.overrideIHL)) + // If we are asked to add options (type not specified) then pad + // with 0 (EOL). RFC 791 page 23 says "The padding is zero". + for i := header.IPv4MinimumSize; i < f.overrideIHL; i++ { + ip[i] = byte(header.IPv4OptionListEndType) + } } - copy(ip[header.IPv4MinimumSize:], f.payload) if f.autoChecksum { ip.SetChecksum(0) diff --git a/pkg/tcpip/network/ipv4/stats.go b/pkg/tcpip/network/ipv4/stats.go index bee72c649..5ae73fbfb 100644 --- a/pkg/tcpip/network/ipv4/stats.go +++ b/pkg/tcpip/network/ipv4/stats.go @@ -16,7 +16,7 @@ package ipv4 import ( "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/network/ip" + "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip" "gvisor.dev/gvisor/pkg/tcpip/stack" ) diff --git a/pkg/tcpip/network/ipv4/stats_test.go b/pkg/tcpip/network/ipv4/stats_test.go index fbbc6e69c..a637f9d50 100644 --- a/pkg/tcpip/network/ipv4/stats_test.go +++ b/pkg/tcpip/network/ipv4/stats_test.go @@ -19,7 +19,7 @@ import ( "testing" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/network/testutil" + "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" ) diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index 0c5f8d683..bb9a02ed0 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -19,9 +19,9 @@ go_library( "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/header/parse", - "//pkg/tcpip/network/fragmentation", "//pkg/tcpip/network/hash", - "//pkg/tcpip/network/ip", + "//pkg/tcpip/network/internal/fragmentation", + "//pkg/tcpip/network/internal/ip", "//pkg/tcpip/stack", ], ) @@ -43,7 +43,7 @@ go_test( "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/sniffer", - "//pkg/tcpip/network/testutil", + "//pkg/tcpip/network/internal/testutil", "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/tcp", diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index dcfd93bab..2690644d6 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -148,6 +148,18 @@ func (*icmpv6PacketTooBigSockError) Kind() stack.TransportErrorKind { return stack.PacketTooBigTransportError } +func (e *endpoint) checkLocalAddress(addr tcpip.Address) bool { + if e.nic.Spoofing() { + return true + } + + if addressEndpoint := e.AcquireAssignedAddress(addr, false, stack.NeverPrimaryEndpoint); addressEndpoint != nil { + addressEndpoint.DecRef() + return true + } + return false +} + // handleControl handles the case when an ICMP packet contains the headers of // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP @@ -165,8 +177,8 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe // // Drop packet if it doesn't have the basic IPv6 header or if the // original source address doesn't match an address we own. - src := hdr.SourceAddress() - if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, src) == 0 { + srcAddr := hdr.SourceAddress() + if !e.checkLocalAddress(srcAddr) { return } @@ -192,7 +204,7 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe p = fragHdr.TransportProtocol() } - e.dispatcher.DeliverTransportError(src, hdr.DestinationAddress(), ProtocolNumber, p, transErr, pkt) + e.dispatcher.DeliverTransportError(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, transErr, pkt) } // getLinkAddrOption searches NDP options for a given link address option using @@ -377,7 +389,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // section 5.4.3. // Is the NS targeting us? - if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, targetAddr) == 0 { + if !e.checkLocalAddress(targetAddr) { return } @@ -525,6 +537,11 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // NDP datagrams are very small and ToView() will not incur allocations. na := header.NDPNeighborAdvert(payload.ToView()) targetAddr := na.TargetAddress() + + e.dad.mu.Lock() + e.dad.mu.dad.StopLocked(targetAddr, false /* aborted */) + e.dad.mu.Unlock() + if e.hasTentativeAddr(targetAddr) { // We just got an NA from a node that owns an address we are performing // DAD on, implying the address is not unique. In this case we let the @@ -854,37 +871,9 @@ func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot return &tcpip.ErrBadLocalAddress{} } - optsSerializer := header.NDPOptionsSerializer{ + return e.sendNDPNS(localAddr, remoteAddr, targetAddr, remoteLinkAddr, header.NDPOptionsSerializer{ header.NDPSourceLinkLayerAddressOption(e.nic.LinkAddress()), - } - neighborSolicitSize := header.ICMPv6NeighborSolicitMinimumSize + optsSerializer.Length() - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.IPv6FixedHeaderSize + neighborSolicitSize, }) - pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber - packet := header.ICMPv6(pkt.TransportHeader().Push(neighborSolicitSize)) - packet.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(packet.MessageBody()) - ns.SetTargetAddress(targetAddr) - ns.Options().Serialize(optsSerializer) - packet.SetChecksum(header.ICMPv6Checksum(packet, localAddr, remoteAddr, buffer.VectorisedView{})) - - if err := addIPHeader(localAddr, remoteAddr, pkt, stack.NetworkHeaderParams{ - Protocol: header.ICMPv6ProtocolNumber, - TTL: header.NDPHopLimit, - }, header.IPv6ExtHdrSerializer{}); err != nil { - panic(fmt.Sprintf("failed to add IP header: %s", err)) - } - - stat := e.stats.icmp.packetsSent - - if err := e.nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil { - stat.dropped.Increment() - return err - } - - stat.neighborSolicit.Increment() - return nil } // ResolveStaticAddress implements stack.LinkAddressResolver. diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 92f9ee2c2..69c1e4bea 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -124,6 +124,10 @@ func (*testInterface) Promiscuous() bool { return false } +func (*testInterface) Spoofing() bool { + return false +} + func (t *testInterface) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { return t.LinkEndpoint.WritePacket(r.Fields(), gso, protocol, pkt) } @@ -149,185 +153,31 @@ func (t *testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, return nil } -func TestICMPCounts(t *testing.T) { - tests := []struct { - name string - useNeighborCache bool - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, - UseNeighborCache: test.useNeighborCache, - }) - if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_, _) = %s", err) - } - { - subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: nicID, - }}, - ) - } - - netProto := s.NetworkProtocolInstance(ProtocolNumber) - if netProto == nil { - t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) - } - ep := netProto.NewEndpoint(&testInterface{}, &stubDispatcher{}) - defer ep.Close() - - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - - addressableEndpoint, ok := ep.(stack.AddressableEndpoint) - if !ok { - t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") - } - addr := lladdr0.WithPrefix() - if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) - } else { - ep.DecRef() - } - - var tllData [header.NDPLinkLayerAddressSize]byte - header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(linkAddr1), - }) - - types := []struct { - typ header.ICMPv6Type - size int - extraData []byte - }{ - { - typ: header.ICMPv6DstUnreachable, - size: header.ICMPv6DstUnreachableMinimumSize, - }, - { - typ: header.ICMPv6PacketTooBig, - size: header.ICMPv6PacketTooBigMinimumSize, - }, - { - typ: header.ICMPv6TimeExceeded, - size: header.ICMPv6MinimumSize, - }, - { - typ: header.ICMPv6ParamProblem, - size: header.ICMPv6MinimumSize, - }, - { - typ: header.ICMPv6EchoRequest, - size: header.ICMPv6EchoMinimumSize, - }, - { - typ: header.ICMPv6EchoReply, - size: header.ICMPv6EchoMinimumSize, - }, - { - typ: header.ICMPv6RouterSolicit, - size: header.ICMPv6MinimumSize, - }, - { - typ: header.ICMPv6RouterAdvert, - size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, - }, - { - typ: header.ICMPv6NeighborSolicit, - size: header.ICMPv6NeighborSolicitMinimumSize, - }, - { - typ: header.ICMPv6NeighborAdvert, - size: header.ICMPv6NeighborAdvertMinimumSize, - extraData: tllData[:], - }, - { - typ: header.ICMPv6RedirectMsg, - size: header.ICMPv6MinimumSize, - }, - { - typ: header.ICMPv6MulticastListenerQuery, - size: header.MLDMinimumSize + header.ICMPv6HeaderSize, - }, - { - typ: header.ICMPv6MulticastListenerReport, - size: header.MLDMinimumSize + header.ICMPv6HeaderSize, - }, - { - typ: header.ICMPv6MulticastListenerDone, - size: header.MLDMinimumSize + header.ICMPv6HeaderSize, - }, - { - typ: 255, /* Unrecognized */ - size: 50, - }, - } - - handleIPv6Payload := func(icmp header.ICMPv6) { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.IPv6MinimumSize, - Data: buffer.View(icmp).ToVectorisedView(), - }) - ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - ep.HandlePacket(pkt) - } - - for _, typ := range types { - icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) - copy(icmp[typ.size:], typ.extraData) - icmp.SetType(typ.typ) - icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView())) - handleIPv6Payload(icmp) - } - - // Construct an empty ICMP packet so that - // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented. - handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize))) +func (*testInterface) CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool { + return false +} - icmpv6Stats := s.Stats().ICMP.V6.PacketsReceived - visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) { - if got, want := s.Value(), uint64(1); got != want { - t.Errorf("got %s = %d, want = %d", name, got, want) - } - }) - if t.Failed() { - t.Logf("stats:\n%+v", s.Stats()) - } - }) - } +func handleICMPInIPv6(ep stack.NetworkEndpoint, src, dst tcpip.Address, icmp header.ICMPv6) { + ip := buffer.NewView(header.IPv6MinimumSize) + header.IPv6(ip).Encode(&header.IPv6Fields{ + PayloadLength: uint16(len(icmp)), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: src, + DstAddr: dst, + }) + vv := ip.ToVectorisedView() + vv.AppendView(buffer.View(icmp)) + ep.HandlePacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.IPv6MinimumSize, + Data: vv, + })) } -func TestICMPCountsWithNeighborCache(t *testing.T) { +func TestICMPCounts(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, - UseNeighborCache: true, }) if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { t.Fatalf("CreateNIC(_, _) = %s", err) @@ -440,33 +290,17 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { }, } - handleIPv6Payload := func(icmp header.ICMPv6) { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.IPv6MinimumSize, - Data: buffer.View(icmp).ToVectorisedView(), - }) - ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - ep.HandlePacket(pkt) - } - for _, typ := range types { icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) copy(icmp[typ.size:], typ.extraData) icmp.SetType(typ.typ) icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView())) - handleIPv6Payload(icmp) + handleICMPInIPv6(ep, lladdr1, lladdr0, icmp) } // Construct an empty ICMP packet so that // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented. - handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize))) + handleICMPInIPv6(ep, lladdr1, lladdr0, header.ICMPv6(buffer.NewView(header.IPv6MinimumSize))) icmpv6Stats := s.Stats().ICMP.V6.PacketsReceived visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) { @@ -777,135 +611,116 @@ func TestICMPChecksumValidationSimple(t *testing.T) { }, } - tests := []struct { - name string - useNeighborCache bool - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, - }, - } + for _, typ := range types { + for _, isRouter := range []bool{false, true} { + name := typ.name + if isRouter { + name += " (Router)" + } + t.Run(name, func(t *testing.T) { + e := channel.New(0, 1280, linkAddr0) - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - for _, typ := range types { - for _, isRouter := range []bool{false, true} { - name := typ.name - if isRouter { - name += " (Router)" + // Indicate that resolution for link layer addresses is required to + // send packets over this link. This is needed so the NIC knows to + // allocate a neighbor table. + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + }) + if isRouter { + // Enabling forwarding makes the stack act as a router. + s.SetForwarding(ProtocolNumber, true) + } + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(_, _) = %s", err) + } + + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + } + { + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) } - t.Run(name, func(t *testing.T) { - e := channel.New(0, 1280, linkAddr0) - - // Indicate that resolution for link layer addresses is required to - // send packets over this link. This is needed so the NIC knows to - // allocate a neighbor table. - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - UseNeighborCache: test.useNeighborCache, - }) - if isRouter { - // Enabling forwarding makes the stack act as a router. - s.SetForwarding(ProtocolNumber, true) - } - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(_, _) = %s", err) - } - - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) - } - { - subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: nicID, - }}, - ) - } - - handleIPv6Payload := func(checksum bool) { - icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) - copy(icmp[typ.size:], typ.extraData) - icmp.SetType(typ.typ) - if checksum { - icmp.SetChecksum(header.ICMPv6Checksum(icmp, lladdr1, lladdr0, buffer.View{}.ToVectorisedView())) - } - ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}), - }) - e.InjectInbound(ProtocolNumber, pkt) - } - - stats := s.Stats().ICMP.V6.PacketsReceived - invalid := stats.Invalid - routerOnly := stats.RouterOnlyPacketsDroppedByHost - typStat := typ.statCounter(stats) - - // Initial stat counts should be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - if got := routerOnly.Value(); got != 0 { - t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) - } - if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) - } - - // Without setting checksum, the incoming packet should - // be invalid. - handleIPv6Payload(false) - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - // Router only count should not have increased. - if got := routerOnly.Value(); got != 0 { - t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) - } - // Rx count of type typ.typ should not have increased. - if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) - } - - // When checksum is set, it should be received. - handleIPv6Payload(true) - if got := typStat.Value(); got != 1 { - t.Fatalf("got %s = %d, want = 1", typ.name, got) - } - // Invalid count should not have increased again. - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - if !isRouter && typ.routerOnly && test.useNeighborCache { - // Router only count should have increased. - if got := routerOnly.Value(); got != 1 { - t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 1", got) - } - } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: nicID, + }}, + ) + } + + handleIPv6Payload := func(checksum bool) { + icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) + copy(icmp[typ.size:], typ.extraData) + icmp.SetType(typ.typ) + if checksum { + icmp.SetChecksum(header.ICMPv6Checksum(icmp, lladdr1, lladdr0, buffer.View{}.ToVectorisedView())) + } + ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(len(icmp)), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}), }) + e.InjectInbound(ProtocolNumber, pkt) } - } - }) + + stats := s.Stats().ICMP.V6.PacketsReceived + invalid := stats.Invalid + routerOnly := stats.RouterOnlyPacketsDroppedByHost + typStat := typ.statCounter(stats) + + // Initial stat counts should be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := routerOnly.Value(); got != 0 { + t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) + } + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // Without setting checksum, the incoming packet should + // be invalid. + handleIPv6Payload(false) + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + // Router only count should not have increased. + if got := routerOnly.Value(); got != 0 { + t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) + } + // Rx count of type typ.typ should not have increased. + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // When checksum is set, it should be received. + handleIPv6Payload(true) + if got := typStat.Value(); got != 1 { + t.Fatalf("got %s = %d, want = 1", typ.name, got) + } + // Invalid count should not have increased again. + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + if !isRouter && typ.routerOnly { + // Router only count should have increased. + if got := routerOnly.Value(); got != 1 { + t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 1", got) + } + } + }) + } } } @@ -1769,7 +1584,6 @@ func TestCallsToNeighborCache(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, - UseNeighborCache: true, }) { if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { @@ -1818,19 +1632,7 @@ func TestCallsToNeighborCache(t *testing.T) { icmp := test.createPacket() icmp.SetChecksum(header.ICMPv6Checksum(icmp, test.source, test.destination, buffer.VectorisedView{})) - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.IPv6MinimumSize, - Data: buffer.View(icmp).ToVectorisedView(), - }) - ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: header.NDPHopLimit, - SrcAddr: test.source, - DstAddr: test.destination, - }) - ep.HandlePacket(pkt) + handleICMPInIPv6(ep, test.source, test.destination, icmp) // Confirm the endpoint calls the correct NUDHandler method. if testInterface.probeCount != test.wantProbeCount { diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index c2e8c3ea7..c5c3ef882 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -30,8 +30,9 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/header/parse" - "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation" "gvisor.dev/gvisor/pkg/tcpip/network/hash" + "gvisor.dev/gvisor/pkg/tcpip/network/internal/fragmentation" + "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -164,6 +165,7 @@ func getLabel(addr tcpip.Address) uint8 { panic(fmt.Sprintf("should have a label for address = %s", addr)) } +var _ stack.DuplicateAddressDetector = (*endpoint)(nil) var _ stack.LinkAddressResolver = (*endpoint)(nil) var _ stack.LinkResolvableNetworkEndpoint = (*endpoint)(nil) var _ stack.GroupAddressableEndpoint = (*endpoint)(nil) @@ -192,6 +194,23 @@ type endpoint struct { ndp ndpState mld mldState } + + // dad is used to check if an arbitrary address is already assigned to some + // neighbor. + // + // Note: this is different from mu.ndp.dad which is used to perform DAD for + // addresses that are assigned to the interface. Removing an address aborts + // DAD; if we had used the same state, handlers for a removed address would + // not be called with the actual DAD result. + // + // LOCK ORDERING: mu > dad.mu. + dad struct { + mu struct { + sync.Mutex + + dad ip.DAD + } + } } // NICNameFromID is a function that returns a stable name for the specified NIC, @@ -226,6 +245,29 @@ type OpaqueInterfaceIdentifierOptions struct { SecretKey []byte } +// CheckDuplicateAddress implements stack.DuplicateAddressDetector. +func (e *endpoint) CheckDuplicateAddress(addr tcpip.Address, h stack.DADCompletionHandler) stack.DADCheckAddressDisposition { + e.dad.mu.Lock() + defer e.dad.mu.Unlock() + return e.dad.mu.dad.CheckDuplicateAddressLocked(addr, h) +} + +// SetDADConfigurations implements stack.DuplicateAddressDetector. +func (e *endpoint) SetDADConfigurations(c stack.DADConfigurations) { + e.mu.Lock() + defer e.mu.Unlock() + e.dad.mu.Lock() + defer e.dad.mu.Unlock() + + e.mu.ndp.dad.SetConfigsLocked(c) + e.dad.mu.dad.SetConfigsLocked(c) +} + +// DuplicateAddressProtocol implements stack.DuplicateAddressDetector. +func (*endpoint) DuplicateAddressProtocol() tcpip.NetworkProtocolNumber { + return ProtocolNumber +} + // HandleLinkResolutionFailure implements stack.LinkResolvableNetworkEndpoint. func (e *endpoint) HandleLinkResolutionFailure(pkt *stack.PacketBuffer) { // handleControl expects the entire offending packet to be in the packet @@ -321,7 +363,7 @@ func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address) tcpip.Error { // If the address is a SLAAC address, do not invalidate its SLAAC prefix as an // attempt will be made to generate a new address for it. - if err := e.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */); err != nil { + if err := e.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */, true /* dadFailure */); err != nil { return err } @@ -525,7 +567,7 @@ func (e *endpoint) stopDADForPermanentAddressesLocked() { addr := addressEndpoint.AddressWithPrefix().Address if header.IsV6UnicastAddress(addr) { - e.mu.ndp.stopDuplicateAddressDetection(addr) + e.mu.ndp.stopDuplicateAddressDetection(addr, false /* failed */) } return true @@ -648,14 +690,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // short circuits broadcasts before they are sent out to other hosts. if pkt.NatDone { netHeader := header.IPv6(pkt.NetworkHeader().View()) - if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { - pkt := pkt.CloneToInbound() - if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - // Since we rewrote the packet but it is being routed back to us, we can - // safely assume the checksum is valid. - pkt.RXTransportChecksumValidated = true - ep.(*endpoint).handlePacket(pkt) - } + if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil { + // Since we rewrote the packet but it is being routed back to us, we + // can safely assume the checksum is valid. + ep.handleLocalPacket(pkt, true /* canSkipRXChecksum */) return nil } } @@ -665,14 +703,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber, headerIncluded bool) tcpip.Error { if r.Loop&stack.PacketLoop != 0 { - pkt := pkt.CloneToInbound() - if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - // If the packet was generated by the stack (not a raw/packet endpoint - // where a packet may be written with the header included), then we can - // safely assume the checksum is valid. - pkt.RXTransportChecksumValidated = !headerIncluded - e.handlePacket(pkt) - } + // If the packet was generated by the stack (not a raw/packet endpoint + // where a packet may be written with the header included), then we can + // safely assume the checksum is valid. + e.handleLocalPacket(pkt, !headerIncluded /* canSkipRXChecksum */) } if r.Loop&stack.PacketOut == 0 { return nil @@ -750,52 +784,36 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // generated. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "" /* inNicName */, outNicName) - if len(dropped) == 0 && len(natPkts) == 0 { - // Fast path: If no packets are to be dropped then we can just invoke the - // faster WritePackets API directly. - n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber) - stats.PacketsSent.IncrementBy(uint64(n)) - if err != nil { - stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) - } - return n, err - } stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped))) + for pkt := range dropped { + pkts.Remove(pkt) + } - // Slow path as we are dropping some packets in the batch degrade to - // emitting one packet at a time. - n := 0 - for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - if _, ok := dropped[pkt]; ok { + // The NAT-ed packets may now be destined for us. + locallyDelivered := 0 + for pkt := range natPkts { + ep := e.protocol.findEndpointWithAddress(header.IPv6(pkt.NetworkHeader().View()).DestinationAddress()) + if ep == nil { + // The NAT-ed packet is still destined for some remote node. continue } - if _, ok := natPkts[pkt]; ok { - netHeader := header.IPv6(pkt.NetworkHeader().View()) - if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { - pkt := pkt.CloneToInbound() - if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - // Since we rewrote the packet but it is being routed back to us, we - // can safely assume the checksum is valid. - pkt.RXTransportChecksumValidated = true - ep.(*endpoint).handlePacket(pkt) - } - n++ - continue - } - } - if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { - stats.PacketsSent.IncrementBy(uint64(n)) - stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n + len(dropped))) - // Dropped packets aren't errors, so include them in - // the return value. - return n + len(dropped), err - } - n++ + + // Do not send the locally destined packet out the NIC. + pkts.Remove(pkt) + + // Deliver the packet locally. + ep.handleLocalPacket(pkt, true /* canSkipRXChecksum */) + locallyDelivered++ } - stats.PacketsSent.IncrementBy(uint64(n)) + // The rest of the packets can be delivered to the NIC as a batch. + pktsLen := pkts.Len() + written, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber) + stats.PacketsSent.IncrementBy(uint64(written)) + stats.OutgoingPacketErrors.IncrementBy(uint64(pktsLen - written)) + // Dropped packets aren't errors, so include them in the return value. - return n + len(dropped), nil + return locallyDelivered + written + len(dropped), err } // WriteHeaderIncludedPacket implements stack.NetworkEndpoint. @@ -852,14 +870,11 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { dstAddr := h.DestinationAddress() // Check if the destination is owned by the stack. - networkEndpoint, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr) - if err == nil { - networkEndpoint.(*endpoint).handlePacket(pkt) + + if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { + ep.handlePacket(pkt) return nil } - if _, ok := err.(*tcpip.ErrBadAddress); !ok { - return err - } r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) if err != nil { @@ -896,8 +911,26 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } - // Loopback traffic skips the prerouting chain. + if !e.protocol.parse(pkt) { + stats.MalformedPacketsReceived.Increment() + return + } + if !e.nic.IsLoopback() { + if e.protocol.stack.HandleLocal() { + addressEndpoint := e.AcquireAssignedAddress(header.IPv6(pkt.NetworkHeader().View()).SourceAddress(), e.nic.Promiscuous(), stack.CanBePrimaryEndpoint) + if addressEndpoint != nil { + addressEndpoint.DecRef() + + // The source address is one of our own, so we never should have gotten + // a packet like this unless HandleLocal is false or our NIC is the + // loopback interface. + stats.InvalidSourceAddressesReceived.Increment() + return + } + } + + // Loopback traffic skips the prerouting chain. inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. @@ -909,6 +942,21 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { e.handlePacket(pkt) } +func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum bool) { + stats := e.stats.ip + + stats.PacketsReceived.Increment() + + pkt = pkt.CloneToInbound() + if e.protocol.parse(pkt) { + pkt.RXTransportChecksumValidated = canSkipRXChecksum + e.handlePacket(pkt) + return + } + + stats.MalformedPacketsReceived.Increment() +} + // handlePacket is like HandlePacket except it does not perform the prerouting // iptables hook. func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { @@ -1384,18 +1432,18 @@ func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error { return &tcpip.ErrBadLocalAddress{} } - return e.removePermanentEndpointLocked(addressEndpoint, true) + return e.removePermanentEndpointLocked(addressEndpoint, true /* allowSLAACInvalidation */, false /* dadFailure */) } // removePermanentEndpointLocked is like removePermanentAddressLocked except // it works with a stack.AddressEndpoint. // // Precondition: e.mu must be write locked. -func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEndpoint, allowSLAACInvalidation bool) tcpip.Error { +func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEndpoint, allowSLAACInvalidation, dadFailure bool) tcpip.Error { addr := addressEndpoint.AddressWithPrefix() unicast := header.IsV6UnicastAddress(addr.Address) if unicast { - e.mu.ndp.stopDuplicateAddressDetection(addr.Address) + e.mu.ndp.stopDuplicateAddressDetection(addr.Address, dadFailure) // If we are removing an address generated via SLAAC, cleanup // its SLAAC resources and notify the integrator. @@ -1741,6 +1789,13 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.Tran e.mu.addressableEndpointState.Init(e) e.mu.ndp.init(e) e.mu.mld.init(e) + e.dad.mu.Lock() + e.dad.mu.dad.Init(&e.dad.mu, p.options.DADConfigs, ip.DADOptions{ + Clock: p.stack.Clock(), + Protocol: &e.mu.ndp, + NICID: nic.ID(), + }) + e.dad.mu.Unlock() e.mu.Unlock() stackStats := p.stack.Stats() @@ -1754,6 +1809,20 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.Tran return e } +func (p *protocol) findEndpointWithAddress(addr tcpip.Address) *endpoint { + p.mu.RLock() + defer p.mu.RUnlock() + + for _, e := range p.mu.eps { + if addressEndpoint := e.AcquireAssignedAddress(addr, false /* allowTemp */, stack.NeverPrimaryEndpoint); addressEndpoint != nil { + addressEndpoint.DecRef() + return e + } + } + + return nil +} + func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { p.mu.Lock() defer p.mu.Unlock() @@ -1798,6 +1867,29 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} +// parse is like Parse but also attempts to parse the transport layer. +// +// Returns true if the network header was successfully parsed. +func (p *protocol) parse(pkt *stack.PacketBuffer) bool { + transProtoNum, hasTransportHdr, ok := p.Parse(pkt) + if !ok { + return false + } + + if hasTransportHdr { + switch err := p.stack.ParsePacketBufferTransport(transProtoNum, pkt); err { + case stack.ParsedOK: + case stack.UnknownTransportProtocol, stack.TransportLayerParseError: + // The transport layer will handle unknown protocols and transport layer + // parsing errors. + default: + panic(fmt.Sprintf("unexpected error parsing transport header = %d", err)) + } + } + + return true +} + // Parse implements stack.NetworkProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { proto, _, fragOffset, fragMore, ok := parse.IPv6(pkt) @@ -1906,6 +1998,9 @@ type Options struct { // MLD holds options for MLD. MLD MLDOptions + + // DADConfigs holds the default DAD configurations used by IPv6 endpoints. + DADConfigs stack.DADConfigurations } // NewProtocolWithOptions returns an IPv6 network protocol. diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 1c6c37c91..7e714b50e 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -31,7 +31,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/network/testutil" + "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go index 2cc0dfebd..205e36cdd 100644 --- a/pkg/tcpip/network/ipv6/mld.go +++ b/pkg/tcpip/network/ipv6/mld.go @@ -21,7 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/network/ip" + "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip" "gvisor.dev/gvisor/pkg/tcpip/stack" ) diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go index f6ffa7133..fe39555e0 100644 --- a/pkg/tcpip/network/ipv6/mld_test.go +++ b/pkg/tcpip/network/ipv6/mld_test.go @@ -126,7 +126,7 @@ func TestSendQueuedMLDReports(t *testing.T) { clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ + DADConfigs: stack.DADConfigurations{ DupAddrDetectTransmits: test.dadTransmits, RetransmitTimer: test.retransmitTimer, }, diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index d7dde1767..53c043dcd 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -16,7 +16,6 @@ package ipv6 import ( "fmt" - "log" "math/rand" "time" @@ -24,34 +23,11 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip" "gvisor.dev/gvisor/pkg/tcpip/stack" ) const ( - // defaultRetransmitTimer is the default amount of time to wait between - // sending reachability probes. - // - // Default taken from RETRANS_TIMER of RFC 4861 section 10. - defaultRetransmitTimer = time.Second - - // minimumRetransmitTimer is the minimum amount of time to wait between - // sending reachability probes. - // - // Note, RFC 4861 does not impose a minimum Retransmit Timer, but we do here - // to make sure the messages are not sent all at once. We also come to this - // value because in the RetransmitTimer field of a Router Advertisement, a - // value of 0 means unspecified, so the smallest valid value is 1. Note, the - // unit of the RetransmitTimer field in the Router Advertisement is - // milliseconds. - minimumRetransmitTimer = time.Millisecond - - // defaultDupAddrDetectTransmits is the default number of NDP Neighbor - // Solicitation messages to send when doing Duplicate Address Detection - // for a tentative address. - // - // Default = 1 (from RFC 4862 section 5.1) - defaultDupAddrDetectTransmits = 1 - // defaultMaxRtrSolicitations is the default number of Router // Solicitation messages to send when an IPv6 endpoint becomes enabled. // @@ -331,18 +307,6 @@ type NDPDispatcher interface { // NDPConfigurations is the NDP configurations for the netstack. type NDPConfigurations struct { - // The number of Neighbor Solicitation messages to send when doing - // Duplicate Address Detection for a tentative address. - // - // Note, a value of zero effectively disables DAD. - DupAddrDetectTransmits uint8 - - // The amount of time to wait between sending Neighbor solicitation - // messages. - // - // Must be greater than or equal to 1ms. - RetransmitTimer time.Duration - // The number of Router Solicitation messages to send when the IPv6 endpoint // becomes enabled. MaxRtrSolicitations uint8 @@ -414,8 +378,6 @@ type NDPConfigurations struct { // default values. func DefaultNDPConfigurations() NDPConfigurations { return NDPConfigurations{ - DupAddrDetectTransmits: defaultDupAddrDetectTransmits, - RetransmitTimer: defaultRetransmitTimer, MaxRtrSolicitations: defaultMaxRtrSolicitations, RtrSolicitationInterval: defaultRtrSolicitationInterval, MaxRtrSolicitationDelay: defaultMaxRtrSolicitationDelay, @@ -433,10 +395,6 @@ func DefaultNDPConfigurations() NDPConfigurations { // validate modifies an NDPConfigurations with valid values. If invalid values // are present in c, the corresponding default values are used instead. func (c *NDPConfigurations) validate() { - if c.RetransmitTimer < minimumRetransmitTimer { - c.RetransmitTimer = defaultRetransmitTimer - } - if c.RtrSolicitationInterval < minimumRtrSolicitationInterval { c.RtrSolicitationInterval = defaultRtrSolicitationInterval } @@ -458,7 +416,14 @@ func (c *NDPConfigurations) validate() { } } -// ndpState is the per-interface NDP state. +type timer struct { + // done indicates to the timer that the timer was stopped. + done *bool + + timer tcpip.Timer +} + +// ndpState is the per-Interface NDP state. type ndpState struct { // Do not allow overwriting this state. _ sync.NoCopy @@ -469,14 +434,17 @@ type ndpState struct { // configs is the per-interface NDP configurations. configs NDPConfigurations - // The DAD state to send the next NS message, or resolve the address. - dad map[tcpip.Address]dadState + // The DAD timers to send the next NS message, or resolve the address. + dad ip.DAD // The default routers discovered through Router Advertisements. defaultRouters map[tcpip.Address]defaultRouterState - // The job used to send the next router solicitation message. - rtrSolicitJob *tcpip.Job + // rtrSolicitTimer is the timer used to send the next router solicitation + // message. + // + // rtrSolicitTimer is the zero value when NDP is not soliciting routers. + rtrSolicitTimer timer // The on-link prefixes discovered through Router Advertisements' Prefix // Information option. @@ -498,19 +466,6 @@ type ndpState struct { temporaryAddressDesyncFactor time.Duration } -// dadState holds the Duplicate Address Detection timer and channel to signal -// to the DAD goroutine that DAD should stop. -type dadState struct { - // The DAD timer to send the next NS message, or resolve the address. - job *tcpip.Job - - // Used to let the DAD timer know that it has been stopped. - // - // Must only be read from or written to while protected by the lock of - // the IPv6 endpoint this dadState is associated with. - done *bool -} - // defaultRouterState holds data associated with a default router discovered by // a Router Advertisement (RA). type defaultRouterState struct { @@ -625,125 +580,47 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.ep.nic.ID())) } - // Should not attempt to perform DAD on an address that is currently in the - // DAD process. - if _, ok := ndp.dad[addr]; ok { - // Should never happen because we should only ever call this function for - // newly created addresses. If we attemped to "add" an address that already - // existed, we would get an error since we attempted to add a duplicate - // address, or its reference count would have been increased without doing - // the work that would have been done for an address that was brand new. - // See endpoint.addAddressLocked. - panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.ep.nic.ID())) - } + ret := ndp.dad.CheckDuplicateAddressLocked(addr, func(r stack.DADResult) { + if addressEndpoint.GetKind() != stack.PermanentTentative { + // The endpoint should still be marked as tentative since we are still + // performing DAD on it. + panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID())) + } - remaining := ndp.configs.DupAddrDetectTransmits - if remaining == 0 { - addressEndpoint.SetKind(stack.Permanent) + if r.Resolved { + addressEndpoint.SetKind(stack.Permanent) + } - // Consider DAD to have resolved even if no DAD messages were actually - // transmitted. if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { - ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, true, nil) + ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, r.Resolved, r.Err) } - ndp.ep.onAddressAssignedLocked(addr) - return nil - } - - state := dadState{ - job: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { - state, ok := ndp.dad[addr] - if !ok { - panic(fmt.Sprintf("ndpdad: DAD timer fired but missing state for %s on NIC(%d)", addr, ndp.ep.nic.ID())) - } - - if addressEndpoint.GetKind() != stack.PermanentTentative { - // The endpoint should still be marked as tentative since we are still - // performing DAD on it. - panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID())) - } - - dadDone := remaining == 0 - - var err tcpip.Error - if !dadDone { - err = ndp.sendDADPacket(addr, addressEndpoint) - } - - if dadDone { - // DAD has resolved. - addressEndpoint.SetKind(stack.Permanent) - } else if err == nil { - // DAD is not done and we had no errors when sending the last NDP NS, - // schedule the next DAD timer. - remaining-- - state.job.Schedule(ndp.configs.RetransmitTimer) - return - } - - // At this point we know that either DAD is done or we hit an error - // sending the last NDP NS. Either way, clean up addr's DAD state and let - // the integrator know DAD has completed. - delete(ndp.dad, addr) - - if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { - ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err) + if r.Resolved { + if addressEndpoint.ConfigType() == stack.AddressConfigSlaac { + // Reset the generation attempts counter as we are starting the + // generation of a new address for the SLAAC prefix. + ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */) } - if dadDone { - if addressEndpoint.ConfigType() == stack.AddressConfigSlaac { - // Reset the generation attempts counter as we are starting the - // generation of a new address for the SLAAC prefix. - ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */) - } - - ndp.ep.onAddressAssignedLocked(addr) - } - }), - } - - // We initially start a timer to fire immediately because some of the DAD work - // cannot be done while holding the IPv6 endpoint's lock. This is effectively - // the same as starting a goroutine but we use a timer that fires immediately - // so we can reset it for the next DAD iteration. - state.job.Schedule(0) - ndp.dad[addr] = state - - return nil -} - -// sendDADPacket sends a NS message to see if any nodes on ndp's NIC's link owns -// addr. -// -// addr must be a tentative IPv6 address on ndp's IPv6 endpoint. -func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) tcpip.Error { - snmc := header.SolicitedNodeAddr(addr) - - icmp := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize)) - icmp.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(icmp.MessageBody()) - ns.SetTargetAddress(addr) - icmp.SetChecksum(header.ICMPv6Checksum(icmp, header.IPv6Any, snmc, buffer.VectorisedView{})) - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(ndp.ep.MaxHeaderLength()), - Data: buffer.View(icmp).ToVectorisedView(), + ndp.ep.onAddressAssignedLocked(addr) + } }) - sent := ndp.ep.stats.icmp.packetsSent - if err := addIPHeader(header.IPv6Any, snmc, pkt, stack.NetworkHeaderParams{ - Protocol: header.ICMPv6ProtocolNumber, - TTL: header.NDPHopLimit, - }, nil /* extensionHeaders */); err != nil { - panic(fmt.Sprintf("failed to add IP header: %s", err)) - } + switch ret { + case stack.DADStarting: + case stack.DADAlreadyRunning: + panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.ep.nic.ID())) + case stack.DADDisabled: + addressEndpoint.SetKind(stack.Permanent) - if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(snmc), nil /* gso */, ProtocolNumber, pkt); err != nil { - sent.dropped.Increment() - return err + // Consider DAD to have resolved even if no DAD messages were actually + // transmitted. + if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { + ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, true, nil) + } + + ndp.ep.onAddressAssignedLocked(addr) } - sent.neighborSolicit.Increment() return nil } @@ -756,20 +633,8 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.Add // of this function to handle such a scenario. // // The IPv6 endpoint that ndp belongs to MUST be locked. -func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) { - dad, ok := ndp.dad[addr] - if !ok { - // Not currently performing DAD on addr, just return. - return - } - - dad.job.Cancel() - delete(ndp.dad, addr) - - // Let the integrator know DAD did not resolve. - if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { - ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, false, nil) - } +func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address, failed bool) { + ndp.dad.StopLocked(addr, !failed) } // handleRA handles a Router Advertisement message that arrived on the NIC @@ -1634,7 +1499,7 @@ func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefi if addressEndpoint := state.stableAddr.addressEndpoint; addressEndpoint != nil { // Since we are already invalidating the prefix, do not invalidate the // prefix when removing the address. - if err := ndp.ep.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */); err != nil { + if err := ndp.ep.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */, false /* dadFailure */); err != nil { panic(fmt.Sprintf("ndp: error removing stable SLAAC address %s: %s", addressEndpoint.AddressWithPrefix(), err)) } } @@ -1693,7 +1558,7 @@ func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaa func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) { // Since we are already invalidating the address, do not invalidate the // address when removing the address. - if err := ndp.ep.removePermanentEndpointLocked(tempAddrState.addressEndpoint, false /* allowSLAACInvalidation */); err != nil { + if err := ndp.ep.removePermanentEndpointLocked(tempAddrState.addressEndpoint, false /* allowSLAACInvalidation */, false /* dadFailure */); err != nil { panic(fmt.Sprintf("error removing temporary SLAAC address %s: %s", tempAddrState.addressEndpoint.AddressWithPrefix(), err)) } @@ -1803,7 +1668,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) { // // The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) startSolicitingRouters() { - if ndp.rtrSolicitJob != nil { + if ndp.rtrSolicitTimer.timer != nil { // We are already soliciting routers. return } @@ -1820,65 +1685,85 @@ func (ndp *ndpState) startSolicitingRouters() { delay = time.Duration(rand.Int63n(int64(ndp.configs.MaxRtrSolicitationDelay))) } - ndp.rtrSolicitJob = ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { - // As per RFC 4861 section 4.1, the source of the RS is an address assigned - // to the sending interface, or the unspecified address if no address is - // assigned to the sending interface. - localAddr := header.IPv6Any - if addressEndpoint := ndp.ep.acquireOutgoingPrimaryAddressRLocked(header.IPv6AllRoutersMulticastAddress, false); addressEndpoint != nil { - localAddr = addressEndpoint.AddressWithPrefix().Address - addressEndpoint.DecRef() - } + // Protected by ndp.ep.mu. + done := false - // As per RFC 4861 section 4.1, an NDP RS SHOULD include the source - // link-layer address option if the source address of the NDP RS is - // specified. This option MUST NOT be included if the source address is - // unspecified. - // - // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by - // LinkEndpoint.LinkAddress) before reaching this point. - var optsSerializer header.NDPOptionsSerializer - linkAddress := ndp.ep.nic.LinkAddress() - if localAddr != header.IPv6Any && header.IsValidUnicastEthernetAddress(linkAddress) { - optsSerializer = header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(linkAddress), + ndp.rtrSolicitTimer = timer{ + done: &done, + timer: ndp.ep.protocol.stack.Clock().AfterFunc(delay, func() { + // As per RFC 4861 section 4.1: + // + // IP Fields: + // Source Address + // An IP address assigned to the sending interface, or + // the unspecified address if no address is assigned + // to the sending interface. + localAddr := header.IPv6Any + if addressEndpoint := ndp.ep.AcquireOutgoingPrimaryAddress(header.IPv6AllRoutersMulticastAddress, false); addressEndpoint != nil { + localAddr = addressEndpoint.AddressWithPrefix().Address + addressEndpoint.DecRef() } - } - payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length()) - icmpData := header.ICMPv6(buffer.NewView(payloadSize)) - icmpData.SetType(header.ICMPv6RouterSolicit) - rs := header.NDPRouterSolicit(icmpData.MessageBody()) - rs.Options().Serialize(optsSerializer) - icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, localAddr, header.IPv6AllRoutersMulticastAddress, buffer.VectorisedView{})) - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(ndp.ep.MaxHeaderLength()), - Data: buffer.View(icmpData).ToVectorisedView(), - }) - - sent := ndp.ep.stats.icmp.packetsSent - if err := addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{ - Protocol: header.ICMPv6ProtocolNumber, - TTL: header.NDPHopLimit, - }, nil /* extensionHeaders */); err != nil { - panic(fmt.Sprintf("failed to add IP header: %s", err)) - } - if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { - sent.dropped.Increment() - log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.ep.nic.ID(), err) - // Don't send any more messages if we had an error. - remaining = 0 - } else { - sent.routerSolicit.Increment() - remaining-- - } - if remaining != 0 { - ndp.rtrSolicitJob.Schedule(ndp.configs.RtrSolicitationInterval) - } - }) + // As per RFC 4861 section 4.1, an NDP RS SHOULD include the source + // link-layer address option if the source address of the NDP RS is + // specified. This option MUST NOT be included if the source address is + // unspecified. + // + // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by + // LinkEndpoint.LinkAddress) before reaching this point. + var optsSerializer header.NDPOptionsSerializer + linkAddress := ndp.ep.nic.LinkAddress() + if localAddr != header.IPv6Any && header.IsValidUnicastEthernetAddress(linkAddress) { + optsSerializer = header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(linkAddress), + } + } + payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length()) + icmpData := header.ICMPv6(buffer.NewView(payloadSize)) + icmpData.SetType(header.ICMPv6RouterSolicit) + rs := header.NDPRouterSolicit(icmpData.MessageBody()) + rs.Options().Serialize(optsSerializer) + icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, localAddr, header.IPv6AllRoutersMulticastAddress, buffer.VectorisedView{})) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(ndp.ep.MaxHeaderLength()), + Data: buffer.View(icmpData).ToVectorisedView(), + }) + + sent := ndp.ep.stats.icmp.packetsSent + if err := addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: header.NDPHopLimit, + }, nil /* extensionHeaders */); err != nil { + panic(fmt.Sprintf("failed to add IP header: %s", err)) + } - ndp.rtrSolicitJob.Schedule(delay) + if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { + sent.dropped.Increment() + // Don't send any more messages if we had an error. + remaining = 0 + } else { + sent.routerSolicit.Increment() + remaining-- + } + + ndp.ep.mu.Lock() + defer ndp.ep.mu.Unlock() + + if done { + // Router solicitation was stopped. + return + } + + if remaining == 0 { + // We are done soliciting routers. + ndp.stopSolicitingRouters() + return + } + + ndp.rtrSolicitTimer.timer.Reset(ndp.configs.RtrSolicitationInterval) + }), + } } // stopSolicitingRouters stops soliciting routers. If routers are not currently @@ -1886,23 +1771,28 @@ func (ndp *ndpState) startSolicitingRouters() { // // The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) stopSolicitingRouters() { - if ndp.rtrSolicitJob == nil { + if ndp.rtrSolicitTimer.timer == nil { // Nothing to do. return } - ndp.rtrSolicitJob.Cancel() - ndp.rtrSolicitJob = nil + ndp.rtrSolicitTimer.timer.Stop() + *ndp.rtrSolicitTimer.done = true + ndp.rtrSolicitTimer = timer{} } func (ndp *ndpState) init(ep *endpoint) { - if ndp.dad != nil { + if ndp.defaultRouters != nil { panic("attempted to initialize NDP state twice") } ndp.ep = ep ndp.configs = ep.protocol.options.NDPConfigs - ndp.dad = make(map[tcpip.Address]dadState) + ndp.dad.Init(&ndp.ep.mu, ep.protocol.options.DADConfigs, ip.DADOptions{ + Clock: ep.protocol.stack.Clock(), + Protocol: ndp, + NICID: ep.nic.ID(), + }) ndp.defaultRouters = make(map[tcpip.Address]defaultRouterState) ndp.onLinkPrefixes = make(map[tcpip.Subnet]onLinkPrefixState) ndp.slaacPrefixes = make(map[tcpip.Subnet]slaacPrefixState) @@ -1912,3 +1802,38 @@ func (ndp *ndpState) init(ep *endpoint) { ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor))) } } + +func (ndp *ndpState) SendDADMessage(addr tcpip.Address) tcpip.Error { + snmc := header.SolicitedNodeAddr(addr) + return ndp.ep.sendNDPNS(header.IPv6Any, snmc, addr, header.EthernetAddressFromMulticastIPv6Address(snmc), nil /* opts */) +} + +func (e *endpoint) sendNDPNS(srcAddr, dstAddr, targetAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, opts header.NDPOptionsSerializer) tcpip.Error { + icmp := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize + opts.Length())) + icmp.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(icmp.MessageBody()) + ns.SetTargetAddress(targetAddr) + ns.Options().Serialize(opts) + icmp.SetChecksum(header.ICMPv6Checksum(icmp, srcAddr, dstAddr, buffer.VectorisedView{})) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(e.MaxHeaderLength()), + Data: buffer.View(icmp).ToVectorisedView(), + }) + + if err := addIPHeader(srcAddr, dstAddr, pkt, stack.NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: header.NDPHopLimit, + }, nil /* extensionHeaders */); err != nil { + panic(fmt.Sprintf("failed to add IP header: %s", err)) + } + + sent := e.stats.icmp.packetsSent + err := e.nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt) + if err != nil { + sent.dropped.Increment() + } else { + sent.neighborSolicit.Increment() + } + return err +} diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 8edaa9508..ce20af0e3 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -33,13 +34,12 @@ import ( // setupStackAndEndpoint creates a stack with a single NIC with a link-local // address llladdr and an IPv6 endpoint to a remote with link-local address // rlladdr -func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeighborCache bool) (*stack.Stack, stack.NetworkEndpoint) { +func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) { t.Helper() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, - UseNeighborCache: useNeighborCache, }) if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { @@ -237,107 +237,6 @@ func TestNeighborSolicitationWithSourceLinkLayerOption(t *testing.T) { Data: hdr.View().ToVectorisedView(), })) - ch := make(chan stack.LinkResolutionResult, 1) - err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, func(r stack.LinkResolutionResult) { - ch <- r - }) - - wantInvalid := uint64(0) - wantSucccess := true - if len(test.expectedLinkAddr) == 0 { - wantInvalid = 1 - wantSucccess = false - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = %s", nicID, lladdr1, lladdr0, ProtocolNumber, err, &tcpip.ErrWouldBlock{}) - } - } else { - if err != nil { - t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = nil", nicID, lladdr1, lladdr0, ProtocolNumber, err) - } - } - - if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Success: wantSucccess}, <-ch); diff != "" { - t.Errorf("linkResolutionResult mismatch (-want +got):\n%s", diff) - } - if got := invalid.Value(); got != wantInvalid { - t.Errorf("got invalid = %d, want = %d", got, wantInvalid) - } - }) - } -} - -// TestNeighborSolicitationWithSourceLinkLayerOptionUsingNeighborCache tests -// that receiving a valid NDP NS message with the Source Link Layer Address -// option results in a new entry in the link address cache for the sender of -// the message. -func TestNeighborSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - optsBuf []byte - expectedLinkAddr tcpip.LinkAddress - }{ - { - name: "Valid", - optsBuf: []byte{1, 1, 2, 3, 4, 5, 6, 7}, - expectedLinkAddr: "\x02\x03\x04\x05\x06\x07", - }, - { - name: "Too Small", - optsBuf: []byte{1, 1, 2, 3, 4, 5, 6}, - }, - { - name: "Invalid Length", - optsBuf: []byte{1, 2, 2, 3, 4, 5, 6, 7}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - UseNeighborCache: true, - }) - e := channel.New(0, 1280, linkAddr0) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) - } - - ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + len(test.optsBuf) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) - pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.MessageBody()) - ns.SetTargetAddress(lladdr0) - opts := ns.Options() - copy(opts, test.optsBuf) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - - invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid - - // Invalid count should initially be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - - e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - neighbors, err := s.Neighbors(nicID, ProtocolNumber) if err != nil { t.Fatalf("s.Neighbors(%d, %d): %s", nicID, ProtocolNumber, err) @@ -392,20 +291,6 @@ func TestNeighborSolicitationResponse(t *testing.T) { remoteLinkAddr0 := linkAddr1 remoteLinkAddr1 := linkAddr2 - stacks := []struct { - name string - useNeighborCache bool - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, - }, - } - tests := []struct { name string nsOpts header.NDPOptionsSerializer @@ -564,229 +449,44 @@ func TestNeighborSolicitationResponse(t *testing.T) { }, } - for _, stackTyp := range stacks { - t.Run(stackTyp.name, func(t *testing.T) { - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - UseNeighborCache: stackTyp.useNeighborCache, - }) - e := channel.New(1, 1280, nicLinkAddr) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err) - } - - s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv6EmptySubnet, - NIC: 1, - }, - }) - - ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length() - hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) - pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.MessageBody()) - ns.SetTargetAddress(nicAddr) - opts := ns.Options() - opts.Serialize(test.nsOpts) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, test.nsDst, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: 255, - SrcAddr: test.nsSrc, - DstAddr: test.nsDst, - }) - - invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid - - // Invalid count should initially be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - - e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - - if test.nsInvalid { - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - - if p, got := e.Read(); got { - t.Fatalf("unexpected response to an invalid NS = %+v", p.Pkt) - } - - // If we expected the NS to be invalid, we have nothing else to check. - return - } - - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - - if test.performsLinkResolution { - p, got := e.ReadContext(context.Background()) - if !got { - t.Fatal("expected an NDP NS response") - } - - respNSDst := header.SolicitedNodeAddr(test.nsSrc) - var want stack.RouteInfo - want.NetProto = ProtocolNumber - want.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(respNSDst) - if diff := cmp.Diff(want, p.Route, cmp.AllowUnexported(want)); diff != "" { - t.Errorf("route info mismatch (-want +got):\n%s", diff) - } - - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(nicAddr), - checker.DstAddr(respNSDst), - checker.TTL(header.NDPHopLimit), - checker.NDPNS( - checker.NDPNSTargetAddress(test.nsSrc), - checker.NDPNSOptions([]header.NDPOption{ - header.NDPSourceLinkLayerAddressOption(nicLinkAddr), - }), - )) - - ser := header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(linkAddr1), - } - ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + ser.Length() - hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) - pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) - pkt.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(pkt.MessageBody()) - na.SetSolicitedFlag(true) - na.SetOverrideFlag(true) - na.SetTargetAddress(test.nsSrc) - na.Options().Serialize(ser) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, nicAddr, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: header.NDPHopLimit, - SrcAddr: test.nsSrc, - DstAddr: nicAddr, - }) - e.InjectLinkAddr(ProtocolNumber, "", stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - } - - p, got := e.ReadContext(context.Background()) - if !got { - t.Fatal("expected an NDP NA response") - } - - if p.Route.LocalAddress != test.naSrc { - t.Errorf("got p.Route.LocalAddress = %s, want = %s", p.Route.LocalAddress, test.naSrc) - } - if p.Route.LocalLinkAddress != nicLinkAddr { - t.Errorf("p.Route.LocalLinkAddress = %s, want = %s", p.Route.LocalLinkAddress, nicLinkAddr) - } - if p.Route.RemoteAddress != test.naDst { - t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, test.naDst) - } - if p.Route.RemoteLinkAddress != test.naDstLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr) - } - - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(test.naSrc), - checker.DstAddr(test.naDst), - checker.TTL(header.NDPHopLimit), - checker.NDPNA( - checker.NDPNASolicitedFlag(test.naSolicited), - checker.NDPNATargetAddress(nicAddr), - checker.NDPNAOptions([]header.NDPOption{ - header.NDPTargetLinkLayerAddressOption(nicLinkAddr[:]), - }), - )) - }) - } - }) - } -} - -// TestNeighborAdvertisementWithTargetLinkLayerOption tests that receiving a -// valid NDP NA message with the Target Link Layer Address option results in a -// new entry in the link address cache for the target of the message. -func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) { - const nicID = 1 - - tests := []struct { - name string - optsBuf []byte - expectedLinkAddr tcpip.LinkAddress - }{ - { - name: "Valid", - optsBuf: []byte{2, 1, 2, 3, 4, 5, 6, 7}, - expectedLinkAddr: "\x02\x03\x04\x05\x06\x07", - }, - { - name: "Too Small", - optsBuf: []byte{2, 1, 2, 3, 4, 5, 6}, - }, - { - name: "Invalid Length", - optsBuf: []byte{2, 2, 2, 3, 4, 5, 6, 7}, - }, - { - name: "Multiple", - optsBuf: []byte{ - 2, 1, 2, 3, 4, 5, 6, 7, - 2, 1, 2, 3, 4, 5, 6, 8, - }, - }, - } - for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - UseLinkAddrCache: true, }) - e := channel.New(0, 1280, linkAddr0) + e := channel.New(1, 1280, nicLinkAddr) e.LinkEPCapabilities |= stack.CapabilityResolutionRequired if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err) } - ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + len(test.optsBuf) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) - pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) - pkt.SetType(header.ICMPv6NeighborAdvert) - ns := header.NDPNeighborAdvert(pkt.MessageBody()) - ns.SetTargetAddress(lladdr1) + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv6EmptySubnet, + NIC: 1, + }, + }) + + ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length() + hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) + pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) + pkt.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(pkt.MessageBody()) + ns.SetTargetAddress(nicAddr) opts := ns.Options() - copy(opts, test.optsBuf) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + opts.Serialize(test.nsOpts) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, test.nsDst, buffer.VectorisedView{})) payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(payloadLength), TransportProtocol: header.ICMPv6ProtocolNumber, HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: lladdr0, + SrcAddr: test.nsSrc, + DstAddr: test.nsDst, }) invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid @@ -796,44 +496,116 @@ func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), })) - ch := make(chan stack.LinkResolutionResult, 1) - err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, func(r stack.LinkResolutionResult) { - ch <- r - }) + if test.nsInvalid { + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } - wantInvalid := uint64(0) - wantSucccess := true - if len(test.expectedLinkAddr) == 0 { - wantInvalid = 1 - wantSucccess = false - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = %s", nicID, lladdr1, lladdr0, ProtocolNumber, err, &tcpip.ErrWouldBlock{}) + if p, got := e.Read(); got { + t.Fatalf("unexpected response to an invalid NS = %+v", p.Pkt) } - } else { - if err != nil { - t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = nil", nicID, lladdr1, lladdr0, ProtocolNumber, err) + + // If we expected the NS to be invalid, we have nothing else to check. + return + } + + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + + if test.performsLinkResolution { + p, got := e.ReadContext(context.Background()) + if !got { + t.Fatal("expected an NDP NS response") + } + + respNSDst := header.SolicitedNodeAddr(test.nsSrc) + var want stack.RouteInfo + want.NetProto = ProtocolNumber + want.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(respNSDst) + if diff := cmp.Diff(want, p.Route, cmp.AllowUnexported(want)); diff != "" { + t.Errorf("route info mismatch (-want +got):\n%s", diff) } + + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(nicAddr), + checker.DstAddr(respNSDst), + checker.TTL(header.NDPHopLimit), + checker.NDPNS( + checker.NDPNSTargetAddress(test.nsSrc), + checker.NDPNSOptions([]header.NDPOption{ + header.NDPSourceLinkLayerAddressOption(nicLinkAddr), + }), + )) + + ser := header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(linkAddr1), + } + ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + ser.Length() + hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) + pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) + pkt.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(pkt.MessageBody()) + na.SetSolicitedFlag(true) + na.SetOverrideFlag(true) + na.SetTargetAddress(test.nsSrc) + na.Options().Serialize(ser) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, nicAddr, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: test.nsSrc, + DstAddr: nicAddr, + }) + e.InjectLinkAddr(ProtocolNumber, "", stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) } - if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Success: wantSucccess}, <-ch); diff != "" { - t.Errorf("linkResolutionResult mismatch (-want +got):\n%s", diff) + p, got := e.ReadContext(context.Background()) + if !got { + t.Fatal("expected an NDP NA response") } - if got := invalid.Value(); got != wantInvalid { - t.Errorf("got invalid = %d, want = %d", got, wantInvalid) + + if p.Route.LocalAddress != test.naSrc { + t.Errorf("got p.Route.LocalAddress = %s, want = %s", p.Route.LocalAddress, test.naSrc) + } + if p.Route.LocalLinkAddress != nicLinkAddr { + t.Errorf("p.Route.LocalLinkAddress = %s, want = %s", p.Route.LocalLinkAddress, nicLinkAddr) + } + if p.Route.RemoteAddress != test.naDst { + t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, test.naDst) } + if p.Route.RemoteLinkAddress != test.naDstLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr) + } + + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(test.naSrc), + checker.DstAddr(test.naDst), + checker.TTL(header.NDPHopLimit), + checker.NDPNA( + checker.NDPNASolicitedFlag(test.naSolicited), + checker.NDPNATargetAddress(nicAddr), + checker.NDPNAOptions([]header.NDPOption{ + header.NDPTargetLinkLayerAddressOption(nicLinkAddr[:]), + }), + )) }) } } -// TestNeighborAdvertisementWithTargetLinkLayerOptionUsingNeighborCache tests -// that receiving a valid NDP NA message with the Target Link Layer Address -// option does not result in a new entry in the neighbor cache for the target -// of the message. -func TestNeighborAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *testing.T) { +// TestNeighborAdvertisementWithTargetLinkLayerOption tests that receiving a +// valid NDP NA message with the Target Link Layer Address option does not +// result in a new entry in the neighbor cache for the target of the message. +func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) { const nicID = 1 tests := []struct { @@ -867,7 +639,6 @@ func TestNeighborAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *tes t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - UseNeighborCache: true, }) e := channel.New(0, 1280, linkAddr0) e.LinkEPCapabilities |= stack.CapabilityResolutionRequired @@ -944,235 +715,216 @@ func TestNeighborAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *tes } func TestNDPValidation(t *testing.T) { - stacks := []struct { - name string - useNeighborCache bool + setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint) { + t.Helper() + + // Create a stack with the assigned link-local address lladdr0 + // and an endpoint to lladdr1. + s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1) + + return s, ep + } + + handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) { + var extHdrs header.IPv6ExtHdrSerializer + if atomicFragment { + extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{}) + } + extHdrsLen := extHdrs.Length() + + ip := buffer.NewView(header.IPv6MinimumSize + extHdrsLen) + header.IPv6(ip).Encode(&header.IPv6Fields{ + PayloadLength: uint16(len(payload) + extHdrsLen), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: hopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, + ExtensionHeaders: extHdrs, + }) + vv := ip.ToVectorisedView() + vv.AppendView(payload) + ep.HandlePacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + })) + } + + var tllData [header.NDPLinkLayerAddressSize]byte + header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(linkAddr1), + }) + + var sllData [header.NDPLinkLayerAddressSize]byte + header.NDPOptions(sllData[:]).Serialize(header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(linkAddr1), + }) + + types := []struct { + name string + typ header.ICMPv6Type + size int + extraData []byte + statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + routerOnly bool }{ { - name: "linkAddrCache", - useNeighborCache: false, + name: "RouterSolicit", + typ: header.ICMPv6RouterSolicit, + size: header.ICMPv6MinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RouterSolicit + }, + routerOnly: true, + }, + { + name: "RouterAdvert", + typ: header.ICMPv6RouterAdvert, + size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RouterAdvert + }, }, { - name: "neighborCache", - useNeighborCache: true, + name: "NeighborSolicit", + typ: header.ICMPv6NeighborSolicit, + size: header.ICMPv6NeighborSolicitMinimumSize, + extraData: sllData[:], + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.NeighborSolicit + }, + }, + { + name: "NeighborAdvert", + typ: header.ICMPv6NeighborAdvert, + size: header.ICMPv6NeighborAdvertMinimumSize, + extraData: tllData[:], + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.NeighborAdvert + }, + }, + { + name: "RedirectMsg", + typ: header.ICMPv6RedirectMsg, + size: header.ICMPv6MinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RedirectMsg + }, }, } - for _, stackTyp := range stacks { - t.Run(stackTyp.name, func(t *testing.T) { - setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint) { - t.Helper() - - // Create a stack with the assigned link-local address lladdr0 - // and an endpoint to lladdr1. - s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1, stackTyp.useNeighborCache) + subTests := []struct { + name string + atomicFragment bool + hopLimit uint8 + code header.ICMPv6Code + valid bool + }{ + { + name: "Valid", + atomicFragment: false, + hopLimit: header.NDPHopLimit, + code: 0, + valid: true, + }, + { + name: "Fragmented", + atomicFragment: true, + hopLimit: header.NDPHopLimit, + code: 0, + valid: false, + }, + { + name: "Invalid hop limit", + atomicFragment: false, + hopLimit: header.NDPHopLimit - 1, + code: 0, + valid: false, + }, + { + name: "Invalid ICMPv6 code", + atomicFragment: false, + hopLimit: header.NDPHopLimit, + code: 1, + valid: false, + }, + } - return s, ep + for _, typ := range types { + for _, isRouter := range []bool{false, true} { + name := typ.name + if isRouter { + name += " (Router)" } - handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) { - var extHdrs header.IPv6ExtHdrSerializer - if atomicFragment { - extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{}) - } - extHdrsLen := extHdrs.Length() + t.Run(name, func(t *testing.T) { + for _, test := range subTests { + t.Run(test.name, func(t *testing.T) { + s, ep := setup(t) - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.IPv6MinimumSize + extHdrsLen, - Data: payload.ToVectorisedView(), - }) - ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(payload) + extHdrsLen), - TransportProtocol: header.ICMPv6ProtocolNumber, - HopLimit: hopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, - ExtensionHeaders: extHdrs, - }) - ep.HandlePacket(pkt) - } + if isRouter { + // Enabling forwarding makes the stack act as a router. + s.SetForwarding(ProtocolNumber, true) + } - var tllData [header.NDPLinkLayerAddressSize]byte - header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(linkAddr1), - }) + stats := s.Stats().ICMP.V6.PacketsReceived + invalid := stats.Invalid + routerOnly := stats.RouterOnlyPacketsDroppedByHost + typStat := typ.statCounter(stats) - var sllData [header.NDPLinkLayerAddressSize]byte - header.NDPOptions(sllData[:]).Serialize(header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(linkAddr1), - }) + icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) + copy(icmp[typ.size:], typ.extraData) + icmp.SetType(typ.typ) + icmp.SetCode(test.code) + icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView())) - types := []struct { - name string - typ header.ICMPv6Type - size int - extraData []byte - statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter - routerOnly bool - }{ - { - name: "RouterSolicit", - typ: header.ICMPv6RouterSolicit, - size: header.ICMPv6MinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RouterSolicit - }, - routerOnly: true, - }, - { - name: "RouterAdvert", - typ: header.ICMPv6RouterAdvert, - size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RouterAdvert - }, - }, - { - name: "NeighborSolicit", - typ: header.ICMPv6NeighborSolicit, - size: header.ICMPv6NeighborSolicitMinimumSize, - extraData: sllData[:], - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.NeighborSolicit - }, - }, - { - name: "NeighborAdvert", - typ: header.ICMPv6NeighborAdvert, - size: header.ICMPv6NeighborAdvertMinimumSize, - extraData: tllData[:], - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.NeighborAdvert - }, - }, - { - name: "RedirectMsg", - typ: header.ICMPv6RedirectMsg, - size: header.ICMPv6MinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RedirectMsg - }, - }, - } + // Rx count of the NDP message should initially be 0. + if got := typStat.Value(); got != 0 { + t.Errorf("got %s = %d, want = 0", typ.name, got) + } - subTests := []struct { - name string - atomicFragment bool - hopLimit uint8 - code header.ICMPv6Code - valid bool - }{ - { - name: "Valid", - atomicFragment: false, - hopLimit: header.NDPHopLimit, - code: 0, - valid: true, - }, - { - name: "Fragmented", - atomicFragment: true, - hopLimit: header.NDPHopLimit, - code: 0, - valid: false, - }, - { - name: "Invalid hop limit", - atomicFragment: false, - hopLimit: header.NDPHopLimit - 1, - code: 0, - valid: false, - }, - { - name: "Invalid ICMPv6 code", - atomicFragment: false, - hopLimit: header.NDPHopLimit, - code: 1, - valid: false, - }, - } + // Invalid count should initially be 0. + if got := invalid.Value(); got != 0 { + t.Errorf("got invalid = %d, want = 0", got) + } - for _, typ := range types { - for _, isRouter := range []bool{false, true} { - name := typ.name - if isRouter { - name += " (Router)" - } + // RouterOnlyPacketsReceivedByHost count should initially be 0. + if got := routerOnly.Value(); got != 0 { + t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) + } + + if t.Failed() { + t.FailNow() + } + + handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep) + + // Rx count of the NDP packet should have increased. + if got := typStat.Value(); got != 1 { + t.Errorf("got %s = %d, want = 1", typ.name, got) + } + + want := uint64(0) + if !test.valid { + // Invalid count should have increased. + want = 1 + } + if got := invalid.Value(); got != want { + t.Errorf("got invalid = %d, want = %d", got, want) + } - t.Run(name, func(t *testing.T) { - for _, test := range subTests { - t.Run(test.name, func(t *testing.T) { - s, ep := setup(t) - - if isRouter { - // Enabling forwarding makes the stack act as a router. - s.SetForwarding(ProtocolNumber, true) - } - - stats := s.Stats().ICMP.V6.PacketsReceived - invalid := stats.Invalid - routerOnly := stats.RouterOnlyPacketsDroppedByHost - typStat := typ.statCounter(stats) - - icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) - copy(icmp[typ.size:], typ.extraData) - icmp.SetType(typ.typ) - icmp.SetCode(test.code) - icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView())) - - // Rx count of the NDP message should initially be 0. - if got := typStat.Value(); got != 0 { - t.Errorf("got %s = %d, want = 0", typ.name, got) - } - - // Invalid count should initially be 0. - if got := invalid.Value(); got != 0 { - t.Errorf("got invalid = %d, want = 0", got) - } - - // RouterOnlyPacketsReceivedByHost count should initially be 0. - if got := routerOnly.Value(); got != 0 { - t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) - } - - if t.Failed() { - t.FailNow() - } - - handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep) - - // Rx count of the NDP packet should have increased. - if got := typStat.Value(); got != 1 { - t.Errorf("got %s = %d, want = 1", typ.name, got) - } - - want := uint64(0) - if !test.valid { - // Invalid count should have increased. - want = 1 - } - if got := invalid.Value(); got != want { - t.Errorf("got invalid = %d, want = %d", got, want) - } - - want = 0 - if test.valid && !isRouter && typ.routerOnly { - // RouterOnlyPacketsReceivedByHost count should have increased. - want = 1 - } - if got := routerOnly.Value(); got != want { - t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = %d", got, want) - } - - }) + want = 0 + if test.valid && !isRouter && typ.routerOnly { + // RouterOnlyPacketsReceivedByHost count should have increased. + want = 1 + } + if got := routerOnly.Value(); got != want { + t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = %d", got, want) } + }) } - } - }) + }) + } } - } // TestNeighborAdvertisementValidation tests that the NIC validates received @@ -1218,7 +970,6 @@ func TestNeighborAdvertisementValidation(t *testing.T) { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - UseNeighborCache: true, }) e := channel.New(0, header.IPv6MinimumMTU, linkAddr0) e.LinkEPCapabilities |= stack.CapabilityResolutionRequired @@ -1291,20 +1042,6 @@ func TestNeighborAdvertisementValidation(t *testing.T) { // NDP Router Advertisement packets, it validates the Router Advertisement // properly before handling them. func TestRouterAdvertValidation(t *testing.T) { - stacks := []struct { - name string - useNeighborCache bool - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, - }, - } - tests := []struct { name string src tcpip.Address @@ -1426,68 +1163,170 @@ func TestRouterAdvertValidation(t *testing.T) { }, } - for _, stackTyp := range stacks { - t.Run(stackTyp.name, func(t *testing.T) { - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e := channel.New(10, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - UseNeighborCache: stackTyp.useNeighborCache, - }) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e := channel.New(10, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } - icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) - pkt := header.ICMPv6(hdr.Prepend(icmpSize)) - pkt.SetType(header.ICMPv6RouterAdvert) - pkt.SetCode(test.code) - copy(pkt.MessageBody(), test.ndpPayload) - payloadLength := hdr.UsedLength() - pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - TransportProtocol: icmp.ProtocolNumber6, - HopLimit: test.hopLimit, - SrcAddr: test.src, - DstAddr: header.IPv6AllNodesMulticastAddress, - }) + icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) + pkt := header.ICMPv6(hdr.Prepend(icmpSize)) + pkt.SetType(header.ICMPv6RouterAdvert) + pkt.SetCode(test.code) + copy(pkt.MessageBody(), test.ndpPayload) + payloadLength := hdr.UsedLength() + pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: test.hopLimit, + SrcAddr: test.src, + DstAddr: header.IPv6AllNodesMulticastAddress, + }) - stats := s.Stats().ICMP.V6.PacketsReceived - invalid := stats.Invalid - rxRA := stats.RouterAdvert + stats := s.Stats().ICMP.V6.PacketsReceived + invalid := stats.Invalid + rxRA := stats.RouterAdvert - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - if got := rxRA.Value(); got != 0 { - t.Fatalf("got rxRA = %d, want = 0", got) - } + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := rxRA.Value(); got != 0 { + t.Fatalf("got rxRA = %d, want = 0", got) + } - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) - if got := rxRA.Value(); got != 1 { - t.Fatalf("got rxRA = %d, want = 1", got) - } + if got := rxRA.Value(); got != 1 { + t.Fatalf("got rxRA = %d, want = 1", got) + } - if test.expectedSuccess { - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - } else { - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - } - }) + if test.expectedSuccess { + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + } else { + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } } }) } } + +// TestCheckDuplicateAddress checks that calls to CheckDuplicateAddress and DAD +// performed when adding new addresses do not interfere with each other. +func TestCheckDuplicateAddress(t *testing.T) { + const nicID = 1 + + clock := faketime.NewManualClock() + dadConfigs := stack.DADConfigurations{ + DupAddrDetectTransmits: 1, + RetransmitTimer: time.Second, + } + s := stack.New(stack.Options{ + Clock: clock, + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocolWithOptions(Options{ + DADConfigs: dadConfigs, + })}, + }) + // This test is expected to send at max 2 DAD messages. We allow an extra + // packet to be stored to catch unexpected packets. + e := channel.New(3, header.IPv6MinimumMTU, linkAddr0) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + dadPacketsSent := 1 + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + } + + // Start DAD for the address we just added. + // + // Even though the stack will perform DAD before the added address transitions + // from tentative to assigned, this DAD request should be independent of that. + ch := make(chan stack.DADResult, 3) + dadRequestsMade := 1 + dadPacketsSent++ + if res, err := s.CheckDuplicateAddress(nicID, ProtocolNumber, lladdr0, func(r stack.DADResult) { + ch <- r + }); err != nil { + t.Fatalf("s.CheckDuplicateAddress(%d, %d, %s, _): %s", nicID, ProtocolNumber, lladdr0, err) + } else if res != stack.DADStarting { + t.Fatalf("got s.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", nicID, ProtocolNumber, lladdr0, res, stack.DADStarting) + } + + // Remove the address and make sure our DAD request was not stopped. + if err := s.RemoveAddress(nicID, lladdr0); err != nil { + t.Fatalf("RemoveAddress(%d, %s): %s", nicID, lladdr0, err) + } + // Should not restart DAD since we already requested DAD above - the handler + // should be called when the original request compeletes so we should not send + // an extra DAD message here. + dadRequestsMade++ + if res, err := s.CheckDuplicateAddress(nicID, ProtocolNumber, lladdr0, func(r stack.DADResult) { + ch <- r + }); err != nil { + t.Fatalf("s.CheckDuplicateAddress(%d, %d, %s, _): %s", nicID, ProtocolNumber, lladdr0, err) + } else if res != stack.DADAlreadyRunning { + t.Fatalf("got s.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", nicID, ProtocolNumber, lladdr0, res, stack.DADAlreadyRunning) + } + + // Wait for DAD to resolve. + clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer) + for i := 0; i < dadRequestsMade; i++ { + if diff := cmp.Diff(stack.DADResult{Resolved: true}, <-ch); diff != "" { + t.Errorf("(i=%d) DAD result mismatch (-want +got):\n%s", i, diff) + } + } + // Should have no more results. + select { + case r := <-ch: + t.Errorf("unexpectedly got an extra DAD result; r = %#v", r) + default: + } + + snmc := header.SolicitedNodeAddr(lladdr0) + remoteLinkAddr := header.EthernetAddressFromMulticastIPv6Address(snmc) + + for i := 0; i < dadPacketsSent; i++ { + p, ok := e.Read() + if !ok { + t.Fatalf("expected %d-th DAD message", i) + } + + if p.Proto != header.IPv6ProtocolNumber { + t.Errorf("(i=%d) got p.Proto = %d, want = %d", i, p.Proto, header.IPv6ProtocolNumber) + } + + if p.Route.RemoteLinkAddress != remoteLinkAddr { + t.Errorf("(i=%d) got p.Route.RemoteLinkAddress = %s, want = %s", i, p.Route.RemoteLinkAddress, remoteLinkAddr) + } + + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(snmc), + checker.TTL(header.NDPHopLimit), + checker.NDPNS( + checker.NDPNSTargetAddress(lladdr0), + checker.NDPNSOptions(nil), + )) + } + + // Should have no more packets. + if p, ok := e.Read(); ok { + t.Errorf("got unexpected packet = %#v", p) + } +} diff --git a/pkg/tcpip/network/ipv6/stats.go b/pkg/tcpip/network/ipv6/stats.go index 0839be3cd..c2758352f 100644 --- a/pkg/tcpip/network/ipv6/stats.go +++ b/pkg/tcpip/network/ipv6/stats.go @@ -16,7 +16,7 @@ package ipv6 import ( "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/network/ip" + "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip" "gvisor.dev/gvisor/pkg/tcpip/stack" ) diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index 1e00144a5..dc37e61a4 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -54,11 +54,10 @@ type SocketOptionsHandler interface { // HasNIC is invoked to check if the NIC is valid for SO_BINDTODEVICE. HasNIC(v int32) bool - // GetSendBufferSize is invoked to get the SO_SNDBUFSIZE. - GetSendBufferSize() (int64, Error) - - // IsUnixSocket is invoked to check if the socket is of unix domain. - IsUnixSocket() bool + // OnSetSendBufferSize is invoked when the send buffer size for an endpoint is + // changed. The handler is invoked with the new value for the socket send + // buffer size. It also returns the newly set value. + OnSetSendBufferSize(v int64) (newSz int64) } // DefaultSocketOptionsHandler is an embeddable type that implements no-op @@ -95,14 +94,9 @@ func (*DefaultSocketOptionsHandler) HasNIC(int32) bool { return false } -// GetSendBufferSize implements SocketOptionsHandler.GetSendBufferSize. -func (*DefaultSocketOptionsHandler) GetSendBufferSize() (int64, Error) { - return 0, nil -} - -// IsUnixSocket implements SocketOptionsHandler.IsUnixSocket. -func (*DefaultSocketOptionsHandler) IsUnixSocket() bool { - return false +// OnSetSendBufferSize implements SocketOptionsHandler.OnSetSendBufferSize. +func (*DefaultSocketOptionsHandler) OnSetSendBufferSize(v int64) (newSz int64) { + return v } // StackHandler holds methods to access the stack options. These must be @@ -600,42 +594,41 @@ func (so *SocketOptions) SetBindToDevice(bindToDevice int32) Error { } // GetSendBufferSize gets value for SO_SNDBUF option. -func (so *SocketOptions) GetSendBufferSize() (int64, Error) { - if so.handler.IsUnixSocket() { - return so.handler.GetSendBufferSize() - } - return atomic.LoadInt64(&so.sendBufferSize), nil +func (so *SocketOptions) GetSendBufferSize() int64 { + return atomic.LoadInt64(&so.sendBufferSize) } // SetSendBufferSize sets value for SO_SNDBUF option. notify indicates if the // stack handler should be invoked to set the send buffer size. func (so *SocketOptions) SetSendBufferSize(sendBufferSize int64, notify bool) { - if so.handler.IsUnixSocket() { + v := sendBufferSize + + if !notify { + atomic.StoreInt64(&so.sendBufferSize, v) return } - v := sendBufferSize - if notify { - // TODO(b/176170271): Notify waiters after size has grown. - // Make sure the send buffer size is within the min and max - // allowed. - ss := so.getSendBufferLimits(so.stackHandler) - min := int64(ss.Min) - max := int64(ss.Max) - // Validate the send buffer size with min and max values. - // Multiply it by factor of 2. - if v > max { - v = max - } + // Make sure the send buffer size is within the min and max + // allowed. + ss := so.getSendBufferLimits(so.stackHandler) + min := int64(ss.Min) + max := int64(ss.Max) + // Validate the send buffer size with min and max values. + // Multiply it by factor of 2. + if v > max { + v = max + } - if v < math.MaxInt32/PacketOverheadFactor { - v *= PacketOverheadFactor - if v < min { - v = min - } - } else { - v = math.MaxInt32 + if v < math.MaxInt32/PacketOverheadFactor { + v *= PacketOverheadFactor + if v < min { + v = min } + } else { + v = math.MaxInt32 } - atomic.StoreInt64(&so.sendBufferSize, v) + + // Notify endpoint about change in buffer size. + newSz := so.handler.OnSetSendBufferSize(v) + atomic.StoreInt64(&so.sendBufferSize, newSz) } diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index ee23c9b98..49362333a 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -4,18 +4,6 @@ load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) go_template_instance( - name = "linkaddrentry_list", - out = "linkaddrentry_list.go", - package = "stack", - prefix = "linkAddrEntry", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*linkAddrEntry", - "Linker": "*linkAddrEntry", - }, -) - -go_template_instance( name = "neighbor_entry_list", out = "neighbor_entry_list.go", package = "stack", @@ -62,8 +50,6 @@ go_library( "iptables_state.go", "iptables_targets.go", "iptables_types.go", - "linkaddrcache.go", - "linkaddrentry_list.go", "neighbor_cache.go", "neighbor_entry.go", "neighbor_entry_list.go", @@ -141,7 +127,6 @@ go_test( size = "small", srcs = [ "forwarding_test.go", - "linkaddrcache_test.go", "neighbor_cache_test.go", "neighbor_entry_test.go", "nic_test.go", diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 54617f2e6..cdb435644 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -231,6 +231,12 @@ func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn { return &conn } +func (ct *ConnTrack) init() { + ct.mu.Lock() + defer ct.mu.Unlock() + ct.buckets = make([]bucket, numBuckets) +} + // connFor gets the conn for pkt if it exists, or returns nil // if it does not. It returns an error when pkt does not contain a valid TCP // header. diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index c24f56ece..c987c1851 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -75,6 +75,10 @@ func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 { } func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) { + if _, _, ok := f.proto.Parse(pkt); !ok { + return + } + netHdr := pkt.NetworkHeader().View() _, dst := f.proto.ParseAddresses(netHdr) @@ -161,9 +165,9 @@ var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil) type fwdTestNetworkProtocol struct { stack *Stack - neighborTable neighborTable + neigh *neighborCache addrResolveDelay time.Duration - onLinkAddressResolved func(neighborTable, tcpip.Address, tcpip.LinkAddress) + onLinkAddressResolved func(*neighborCache, tcpip.Address, tcpip.LinkAddress) onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool) mu struct { @@ -221,7 +225,7 @@ func (*fwdTestNetworkProtocol) Wait() {} func (f *fwdTestNetworkEndpoint) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error { if fn := f.proto.onLinkAddressResolved; fn != nil { time.AfterFunc(f.proto.addrResolveDelay, func() { - fn(f.proto.neighborTable, addr, remoteLinkAddr) + fn(f.proto.neigh, addr, remoteLinkAddr) }) } return nil @@ -357,14 +361,13 @@ func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protoco panic("not implemented") } -func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborCache bool) (ep1, ep2 *fwdTestLinkEndpoint) { +func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) { // Create a stack with the network protocol and two NICs. s := New(Options{ NetworkProtocols: []NetworkProtocolFactory{func(s *Stack) NetworkProtocol { proto.stack = s return proto }}, - UseNeighborCache: useNeighborCache, }) // Enable forwarding. @@ -402,7 +405,7 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborC } if l, ok := nic.linkAddrResolvers[fwdTestNetNumber]; ok { - proto.neighborTable = l.neighborTable + proto.neigh = &l.neigh } // Route all packets to NIC 2. @@ -418,121 +421,85 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborC } func TestForwardingWithStaticResolver(t *testing.T) { - tests := []struct { - name string - useNeighborCache bool - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, + // Create a network protocol with a static resolver. + proto := &fwdTestNetworkProtocol{ + onResolveStaticAddress: + // The network address 3 is resolved to the link address "c". + func(addr tcpip.Address) (tcpip.LinkAddress, bool) { + if addr == "\x03" { + return "c", true + } + return "", false }, } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - // Create a network protocol with a static resolver. - proto := &fwdTestNetworkProtocol{ - onResolveStaticAddress: - // The network address 3 is resolved to the link address "c". - func(addr tcpip.Address) (tcpip.LinkAddress, bool) { - if addr == "\x03" { - return "c", true - } - return "", false - }, - } - - ep1, ep2 := fwdTestNetFactory(t, proto, test.useNeighborCache) + ep1, ep2 := fwdTestNetFactory(t, proto) - // Inject an inbound packet to address 3 on NIC 1, and see if it is - // forwarded to NIC 2. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) + // Inject an inbound packet to address 3 on NIC 1, and see if it is + // forwarded to NIC 2. + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) - var p fwdTestPacketInfo + var p fwdTestPacketInfo - select { - case p = <-ep2.C: - default: - t.Fatal("packet not forwarded") - } + select { + case p = <-ep2.C: + default: + t.Fatal("packet not forwarded") + } - // Test that the static address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } - }) + // Test that the static address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) } } func TestForwardingWithFakeResolver(t *testing.T) { - tests := []struct { - name string - useNeighborCache bool - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, + proto := fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) { + t.Helper() + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) + } + // Any address will be resolved to the link address "c". + neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) }, } + ep1, ep2 := fwdTestNetFactory(t, &proto) - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - proto := fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - t.Helper() - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) - } - // Any address will be resolved to the link address "c". - neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - }, - } - ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache) - - // Inject an inbound packet to address 3 on NIC 1, and see if it is - // forwarded to NIC 2. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } + // Inject an inbound packet to address 3 on NIC 1, and see if it is + // forwarded to NIC 2. + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } - }) + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) } } @@ -542,7 +509,7 @@ func TestForwardingWithNoResolver(t *testing.T) { // Whether or not we use the neighbor cache here does not matter since // neither linkAddrCache nor neighborCache will be used. - ep1, ep2 := fwdTestNetFactory(t, proto, false /* useNeighborCache */) + ep1, ep2 := fwdTestNetFactory(t, proto) // inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. @@ -562,12 +529,12 @@ func TestForwardingWithNoResolver(t *testing.T) { func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) { proto := &fwdTestNetworkProtocol{ addrResolveDelay: 50 * time.Millisecond, - onLinkAddressResolved: func(neighborTable, tcpip.Address, tcpip.LinkAddress) { + onLinkAddressResolved: func(*neighborCache, tcpip.Address, tcpip.LinkAddress) { // Don't resolve the link address. }, } - ep1, ep2 := fwdTestNetFactory(t, proto, true /* useNeighborCache */) + ep1, ep2 := fwdTestNetFactory(t, proto) const numPackets int = 5 // These packets will all be enqueued in the packet queue to wait for link @@ -592,300 +559,227 @@ func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) { } func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { - tests := []struct { - name string - useNeighborCache bool - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, + proto := fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) { + t.Helper() + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) + } + // Only packets to address 3 will be resolved to the + // link address "c". + if addr == "\x03" { + neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + } }, } + ep1, ep2 := fwdTestNetFactory(t, &proto) - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - proto := fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - t.Helper() - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) - } - // Only packets to address 3 will be resolved to the - // link address "c". - if addr == "\x03" { - neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - } - }, - } - ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache) - - // Inject an inbound packet to address 4 on NIC 1. This packet should - // not be forwarded. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 4 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - // Inject an inbound packet to address 3 on NIC 1, and see if it is - // forwarded to NIC 2. - buf = buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } + // Inject an inbound packet to address 4 on NIC 1. This packet should + // not be forwarded. + buf := buffer.NewView(30) + buf[dstAddrOffset] = 4 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) - if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) - } + // Inject an inbound packet to address 3 on NIC 1, and see if it is + // forwarded to NIC 2. + buf = buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } - }) + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) } } func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { - tests := []struct { - name string - useNeighborCache bool - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, + proto := fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) { + t.Helper() + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) + } + // Any packets will be resolved to the link address "c". + neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) }, } + ep1, ep2 := fwdTestNetFactory(t, &proto) - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - proto := fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - t.Helper() - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) - } - // Any packets will be resolved to the link address "c". - neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - }, - } - ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache) - - // Inject two inbound packets to address 3 on NIC 1. - for i := 0; i < 2; i++ { - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - } + // Inject two inbound packets to address 3 on NIC 1. + for i := 0; i < 2; i++ { + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + } - for i := 0; i < 2; i++ { - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } - - if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) - } - - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } - } - }) + for i := 0; i < 2; i++ { + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } } } func TestForwardingWithFakeResolverManyPackets(t *testing.T) { - tests := []struct { - name string - useNeighborCache bool - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, + proto := fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) { + t.Helper() + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) + } + // Any packets will be resolved to the link address "c". + neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) }, } + ep1, ep2 := fwdTestNetFactory(t, &proto) - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - proto := fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - t.Helper() - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) - } - // Any packets will be resolved to the link address "c". - neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - }, - } - ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache) - - for i := 0; i < maxPendingPacketsPerResolution+5; i++ { - // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - // Set the packet sequence number. - binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i)) - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - } + for i := 0; i < maxPendingPacketsPerResolution+5; i++ { + // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1. + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + // Set the packet sequence number. + binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i)) + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + } - for i := 0; i < maxPendingPacketsPerResolution; i++ { - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } - - b := PayloadSince(p.Pkt.NetworkHeader()) - if b[dstAddrOffset] != 3 { - t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset]) - } - if len(b) < fwdTestNetHeaderLen+2 { - t.Fatalf("packet is too short to hold a sequence number: len(b) = %d", b) - } - seqNumBuf := b[fwdTestNetHeaderLen:] - - // The first 5 packets should not be forwarded so the sequence number should - // start with 5. - want := uint16(i + 5) - if n := binary.BigEndian.Uint16(seqNumBuf); n != want { - t.Fatalf("got the packet #%d, want = #%d", n, want) - } - - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } - } - }) + for i := 0; i < maxPendingPacketsPerResolution; i++ { + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + b := PayloadSince(p.Pkt.NetworkHeader()) + if b[dstAddrOffset] != 3 { + t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset]) + } + if len(b) < fwdTestNetHeaderLen+2 { + t.Fatalf("packet is too short to hold a sequence number: len(b) = %d", b) + } + seqNumBuf := b[fwdTestNetHeaderLen:] + + // The first 5 packets should not be forwarded so the sequence number should + // start with 5. + want := uint16(i + 5) + if n := binary.BigEndian.Uint16(seqNumBuf); n != want { + t.Fatalf("got the packet #%d, want = #%d", n, want) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } } } func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { - tests := []struct { - name string - useNeighborCache bool - proto *fwdTestNetworkProtocol - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, + proto := fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(neigh *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress) { + t.Helper() + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) + } + // Any packets will be resolved to the link address "c". + neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) }, } + ep1, ep2 := fwdTestNetFactory(t, &proto) - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - proto := fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - t.Helper() - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) - } - // Any packets will be resolved to the link address "c". - neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - }, - } - ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache) - - for i := 0; i < maxPendingResolutions+5; i++ { - // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1. - // Each packet has a different destination address (3 to - // maxPendingResolutions + 7). - buf := buffer.NewView(30) - buf[dstAddrOffset] = byte(3 + i) - ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - } + for i := 0; i < maxPendingResolutions+5; i++ { + // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1. + // Each packet has a different destination address (3 to + // maxPendingResolutions + 7). + buf := buffer.NewView(30) + buf[dstAddrOffset] = byte(3 + i) + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + } - for i := 0; i < maxPendingResolutions; i++ { - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } - - // The first 5 packets (address 3 to 7) should not be forwarded - // because their address resolutions are interrupted. - if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] < 8 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", nh[dstAddrOffset]) - } - - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } - } - }) + for i := 0; i < maxPendingResolutions; i++ { + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + // The first 5 packets (address 3 to 7) should not be forwarded + // because their address resolutions are interrupted. + if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] < 8 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", nh[dstAddrOffset]) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } } } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 63832c200..52890f6eb 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -235,7 +235,7 @@ func (it *IPTables) ReplaceTable(id TableID, table Table, ipv6 bool) tcpip.Error // If iptables is being enabled, initialize the conntrack table and // reaper. if !it.modified { - it.connections.buckets = make([]bucket, numBuckets) + it.connections.init() it.startReaper(reaperDelay) } it.modified = true diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go deleted file mode 100644 index 5b6b58b1d..000000000 --- a/pkg/tcpip/stack/linkaddrcache.go +++ /dev/null @@ -1,359 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack - -import ( - "fmt" - "time" - - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" -) - -const linkAddrCacheSize = 512 // max cache entries - -// linkAddrCache is a fixed-sized cache mapping IP addresses to link addresses. -// -// The entries are stored in a ring buffer, oldest entry replaced first. -// -// This struct is safe for concurrent use. -type linkAddrCache struct { - nic *NIC - - linkRes LinkAddressResolver - - // ageLimit is how long a cache entry is valid for. - ageLimit time.Duration - - // resolutionTimeout is the amount of time to wait for a link request to - // resolve an address. - resolutionTimeout time.Duration - - // resolutionAttempts is the number of times an address is attempted to be - // resolved before failing. - resolutionAttempts int - - mu struct { - sync.Mutex - table map[tcpip.Address]*linkAddrEntry - lru linkAddrEntryList - } -} - -// entryState controls the state of a single entry in the cache. -type entryState int - -const ( - // incomplete means that there is an outstanding request to resolve the - // address. This is the initial state. - incomplete entryState = iota - // ready means that the address has been resolved and can be used. - ready -) - -// String implements Stringer. -func (s entryState) String() string { - switch s { - case incomplete: - return "incomplete" - case ready: - return "ready" - default: - return fmt.Sprintf("unknown(%d)", s) - } -} - -// A linkAddrEntry is an entry in the linkAddrCache. -// This struct is thread-compatible. -type linkAddrEntry struct { - // linkAddrEntryEntry access is synchronized by the linkAddrCache lock. - linkAddrEntryEntry - - cache *linkAddrCache - - mu struct { - sync.RWMutex - - addr tcpip.Address - linkAddr tcpip.LinkAddress - expiration time.Time - s entryState - - // done is closed when address resolution is complete. It is nil iff s is - // incomplete and resolution is not yet in progress. - done chan struct{} - - // onResolve is called with the result of address resolution. - onResolve []func(LinkResolutionResult) - } -} - -func (e *linkAddrEntry) notifyCompletionLocked(linkAddr tcpip.LinkAddress) { - res := LinkResolutionResult{LinkAddress: linkAddr, Success: len(linkAddr) != 0} - for _, callback := range e.mu.onResolve { - callback(res) - } - e.mu.onResolve = nil - if ch := e.mu.done; ch != nil { - close(ch) - e.mu.done = nil - // Dequeue the pending packets in a new goroutine to not hold up the current - // goroutine as writing packets may be a costly operation. - // - // At the time of writing, when writing packets, a neighbor's link address - // is resolved (which ends up obtaining the entry's lock) while holding the - // link resolution queue's lock. Dequeuing packets in a new goroutine avoids - // a lock ordering violation. - go e.cache.nic.linkResQueue.dequeue(ch, linkAddr, len(linkAddr) != 0) - } -} - -// changeStateLocked sets the entry's state to ns. -// -// The entry's expiration is bumped up to the greater of itself and the passed -// expiration; the zero value indicates immediate expiration, and is set -// unconditionally - this is an implementation detail that allows for entries -// to be reused. -// -// Precondition: e.mu must be locked -func (e *linkAddrEntry) changeStateLocked(ns entryState, expiration time.Time) { - if e.mu.s == incomplete && ns == ready { - e.notifyCompletionLocked(e.mu.linkAddr) - } - - if expiration.IsZero() || expiration.After(e.mu.expiration) { - e.mu.expiration = expiration - } - e.mu.s = ns -} - -// add adds a k -> v mapping to the cache. -func (c *linkAddrCache) add(k tcpip.Address, v tcpip.LinkAddress) { - // Calculate expiration time before acquiring the lock, since expiration is - // relative to the time when information was learned, rather than when it - // happened to be inserted into the cache. - expiration := time.Now().Add(c.ageLimit) - - c.mu.Lock() - entry := c.getOrCreateEntryLocked(k) - entry.mu.Lock() - defer entry.mu.Unlock() - c.mu.Unlock() - - entry.mu.linkAddr = v - entry.changeStateLocked(ready, expiration) -} - -// getOrCreateEntryLocked retrieves a cache entry associated with k. The -// returned entry is always refreshed in the cache (it is reachable via the -// map, and its place is bumped in LRU). -// -// If a matching entry exists in the cache, it is returned. If no matching -// entry exists and the cache is full, an existing entry is evicted via LRU, -// reset to state incomplete, and returned. If no matching entry exists and the -// cache is not full, a new entry with state incomplete is allocated and -// returned. -func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.Address) *linkAddrEntry { - if entry, ok := c.mu.table[k]; ok { - c.mu.lru.Remove(entry) - c.mu.lru.PushFront(entry) - return entry - } - var entry *linkAddrEntry - if len(c.mu.table) == linkAddrCacheSize { - entry = c.mu.lru.Back() - entry.mu.Lock() - - delete(c.mu.table, entry.mu.addr) - c.mu.lru.Remove(entry) - - // Wake waiters and mark the soon-to-be-reused entry as expired. - entry.notifyCompletionLocked("" /* linkAddr */) - entry.mu.Unlock() - } else { - entry = new(linkAddrEntry) - } - - *entry = linkAddrEntry{ - cache: c, - } - entry.mu.Lock() - entry.mu.addr = k - entry.mu.s = incomplete - entry.mu.Unlock() - c.mu.table[k] = entry - c.mu.lru.PushFront(entry) - return entry -} - -// get reports any known link address for addr. -func (c *linkAddrCache) get(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { - c.mu.Lock() - entry := c.getOrCreateEntryLocked(addr) - entry.mu.Lock() - defer entry.mu.Unlock() - c.mu.Unlock() - - switch s := entry.mu.s; s { - case ready: - if !time.Now().After(entry.mu.expiration) { - // Not expired. - if onResolve != nil { - onResolve(LinkResolutionResult{LinkAddress: entry.mu.linkAddr, Success: true}) - } - return entry.mu.linkAddr, nil, nil - } - - entry.changeStateLocked(incomplete, time.Time{}) - fallthrough - case incomplete: - if onResolve != nil { - entry.mu.onResolve = append(entry.mu.onResolve, onResolve) - } - if entry.mu.done == nil { - entry.mu.done = make(chan struct{}) - go c.startAddressResolution(addr, localAddr, entry.mu.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. - } - return entry.mu.linkAddr, entry.mu.done, &tcpip.ErrWouldBlock{} - default: - panic(fmt.Sprintf("invalid cache entry state: %s", s)) - } -} - -func (c *linkAddrCache) startAddressResolution(k tcpip.Address, localAddr tcpip.Address, done <-chan struct{}) { - for i := 0; ; i++ { - // Send link request, then wait for the timeout limit and check - // whether the request succeeded. - c.linkRes.LinkAddressRequest(k, localAddr, "" /* linkAddr */) - - select { - case now := <-time.After(c.resolutionTimeout): - if stop := c.checkLinkRequest(now, k, i); stop { - return - } - case <-done: - return - } - } -} - -// checkLinkRequest checks whether previous attempt to resolve address has -// succeeded and mark the entry accordingly. Returns true if request can stop, -// false if another request should be sent. -func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.Address, attempt int) bool { - c.mu.Lock() - defer c.mu.Unlock() - entry, ok := c.mu.table[k] - if !ok { - // Entry was evicted from the cache. - return true - } - entry.mu.Lock() - defer entry.mu.Unlock() - - switch s := entry.mu.s; s { - case ready: - // Entry was made ready by resolver. - case incomplete: - if attempt+1 < c.resolutionAttempts { - // No response yet, need to send another ARP request. - return false - } - // Max number of retries reached, delete entry. - entry.notifyCompletionLocked("" /* linkAddr */) - delete(c.mu.table, k) - default: - panic(fmt.Sprintf("invalid cache entry state: %s", s)) - } - return true -} - -func (c *linkAddrCache) init(nic *NIC, ageLimit, resolutionTimeout time.Duration, resolutionAttempts int, linkRes LinkAddressResolver) { - *c = linkAddrCache{ - nic: nic, - linkRes: linkRes, - ageLimit: ageLimit, - resolutionTimeout: resolutionTimeout, - resolutionAttempts: resolutionAttempts, - } - - c.mu.Lock() - c.mu.table = make(map[tcpip.Address]*linkAddrEntry, linkAddrCacheSize) - c.mu.Unlock() -} - -var _ neighborTable = (*linkAddrCache)(nil) - -func (*linkAddrCache) neighbors() ([]NeighborEntry, tcpip.Error) { - return nil, &tcpip.ErrNotSupported{} -} - -func (c *linkAddrCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAddress) { - c.add(addr, linkAddr) -} - -func (*linkAddrCache) remove(addr tcpip.Address) tcpip.Error { - return &tcpip.ErrNotSupported{} -} - -func (*linkAddrCache) removeAll() tcpip.Error { - return &tcpip.ErrNotSupported{} -} - -func (c *linkAddrCache) handleProbe(addr tcpip.Address, linkAddr tcpip.LinkAddress) { - if len(linkAddr) != 0 { - // NUD allows probes without a link address but linkAddrCache - // is a simple neighbor table which does not implement NUD. - // - // As per RFC 4861 section 4.3, - // - // Source link-layer address - // The link-layer address for the sender. MUST NOT be - // included when the source IP address is the - // unspecified address. Otherwise, on link layers - // that have addresses this option MUST be included in - // multicast solicitations and SHOULD be included in - // unicast solicitations. - c.add(addr, linkAddr) - } -} - -func (c *linkAddrCache) handleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) { - if len(linkAddr) != 0 { - // NUD allows confirmations without a link address but linkAddrCache - // is a simple neighbor table which does not implement NUD. - // - // As per RFC 4861 section 4.4, - // - // Target link-layer address - // The link-layer address for the target, i.e., the - // sender of the advertisement. This option MUST be - // included on link layers that have addresses when - // responding to multicast solicitations. When - // responding to a unicast Neighbor Solicitation this - // option SHOULD be included. - c.add(addr, linkAddr) - } -} - -func (c *linkAddrCache) handleUpperLevelConfirmation(tcpip.Address) {} - -func (*linkAddrCache) nudConfig() (NUDConfigurations, tcpip.Error) { - return NUDConfigurations{}, &tcpip.ErrNotSupported{} -} - -func (*linkAddrCache) setNUDConfig(NUDConfigurations) tcpip.Error { - return &tcpip.ErrNotSupported{} -} diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go deleted file mode 100644 index 9e7f331c9..000000000 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ /dev/null @@ -1,291 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack - -import ( - "fmt" - "math" - "sync/atomic" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" -) - -type testaddr struct { - addr tcpip.Address - linkAddr tcpip.LinkAddress -} - -var testAddrs = func() []testaddr { - var addrs []testaddr - for i := 0; i < 4*linkAddrCacheSize; i++ { - addr := fmt.Sprintf("Addr%06d", i) - addrs = append(addrs, testaddr{ - addr: tcpip.Address(addr), - linkAddr: tcpip.LinkAddress("Link" + addr), - }) - } - return addrs -}() - -type testLinkAddressResolver struct { - cache *linkAddrCache - delay time.Duration - onLinkAddressRequest func() -} - -func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress) tcpip.Error { - // TODO(gvisor.dev/issue/5141): Use a fake clock. - time.AfterFunc(r.delay, func() { r.fakeRequest(targetAddr) }) - if f := r.onLinkAddressRequest; f != nil { - f() - } - return nil -} - -func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) { - for _, ta := range testAddrs { - if ta.addr == addr { - r.cache.add(ta.addr, ta.linkAddr) - break - } - } -} - -func (*testLinkAddressResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { - if addr == "broadcast" { - return "mac_broadcast", true - } - return "", false -} - -func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { - return 1 -} - -func getBlocking(c *linkAddrCache, addr tcpip.Address) (tcpip.LinkAddress, tcpip.Error) { - var attemptedResolution bool - for { - got, ch, err := c.get(addr, "", nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { - if attemptedResolution { - return got, &tcpip.ErrTimeout{} - } - attemptedResolution = true - <-ch - continue - } - return got, err - } -} - -func newEmptyNIC() *NIC { - n := &NIC{} - n.linkResQueue.init(n) - return n -} - -func TestCacheOverflow(t *testing.T) { - var c linkAddrCache - c.init(newEmptyNIC(), 1<<63-1, 1*time.Second, 3, nil) - for i := len(testAddrs) - 1; i >= 0; i-- { - e := testAddrs[i] - c.add(e.addr, e.linkAddr) - got, _, err := c.get(e.addr, "", nil) - if err != nil { - t.Errorf("insert %d, c.get(%s, '', nil): %s", i, e.addr, err) - } - if got != e.linkAddr { - t.Errorf("insert %d, got c.get(%s, '', nil) = %s, want = %s", i, e.addr, got, e.linkAddr) - } - } - // Expect to find at least half of the most recent entries. - for i := 0; i < linkAddrCacheSize/2; i++ { - e := testAddrs[i] - got, _, err := c.get(e.addr, "", nil) - if err != nil { - t.Errorf("check %d, c.get(%s, '', nil): %s", i, e.addr, err) - } - if got != e.linkAddr { - t.Errorf("check %d, got c.get(%s, '', nil) = %s, want = %s", i, e.addr, got, e.linkAddr) - } - } - // The earliest entries should no longer be in the cache. - c.mu.Lock() - defer c.mu.Unlock() - for i := len(testAddrs) - 1; i >= len(testAddrs)-linkAddrCacheSize; i-- { - e := testAddrs[i] - if entry, ok := c.mu.table[e.addr]; ok { - t.Errorf("unexpected entry at c.mu.table[%s]: %#v", e.addr, entry) - } - } -} - -func TestCacheConcurrent(t *testing.T) { - var c linkAddrCache - linkRes := &testLinkAddressResolver{cache: &c} - c.init(newEmptyNIC(), 1<<63-1, 1*time.Second, 3, linkRes) - - var wg sync.WaitGroup - for r := 0; r < 16; r++ { - wg.Add(1) - go func() { - for _, e := range testAddrs { - c.add(e.addr, e.linkAddr) - } - wg.Done() - }() - } - wg.Wait() - - // All goroutines add in the same order and add more values than - // can fit in the cache, so our eviction strategy requires that - // the last entry be present and the first be missing. - e := testAddrs[len(testAddrs)-1] - got, _, err := c.get(e.addr, "", nil) - if err != nil { - t.Errorf("c.get(%s, '', nil): %s", e.addr, err) - } - if got != e.linkAddr { - t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, e.linkAddr) - } - - e = testAddrs[0] - c.mu.Lock() - defer c.mu.Unlock() - if entry, ok := c.mu.table[e.addr]; ok { - t.Errorf("unexpected entry at c.mu.table[%s]: %#v", e.addr, entry) - } -} - -func TestCacheAgeLimit(t *testing.T) { - var c linkAddrCache - linkRes := &testLinkAddressResolver{cache: &c} - c.init(newEmptyNIC(), 1*time.Millisecond, 1*time.Second, 3, linkRes) - - e := testAddrs[0] - c.add(e.addr, e.linkAddr) - time.Sleep(50 * time.Millisecond) - _, _, err := c.get(e.addr, "", nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got c.get(%s, '', nil) = %s, want = ErrWouldBlock", e.addr, err) - } -} - -func TestCacheReplace(t *testing.T) { - var c linkAddrCache - c.init(newEmptyNIC(), 1<<63-1, 1*time.Second, 3, nil) - e := testAddrs[0] - l2 := e.linkAddr + "2" - c.add(e.addr, e.linkAddr) - got, _, err := c.get(e.addr, "", nil) - if err != nil { - t.Errorf("c.get(%s, '', nil): %s", e.addr, err) - } - if got != e.linkAddr { - t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, e.linkAddr) - } - - c.add(e.addr, l2) - got, _, err = c.get(e.addr, "", nil) - if err != nil { - t.Errorf("c.get(%s, '', nil): %s", e.addr, err) - } - if got != l2 { - t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, l2) - } -} - -func TestCacheResolution(t *testing.T) { - // There is a race condition causing this test to fail when the executor - // takes longer than the resolution timeout to call linkAddrCache.get. This - // is especially common when this test is run with gotsan. - // - // Using a large resolution timeout decreases the probability of experiencing - // this race condition and does not affect how long this test takes to run. - var c linkAddrCache - linkRes := &testLinkAddressResolver{cache: &c} - c.init(newEmptyNIC(), 1<<63-1, math.MaxInt64, 1, linkRes) - for i, ta := range testAddrs { - got, err := getBlocking(&c, ta.addr) - if err != nil { - t.Errorf("check %d, getBlocking(_, %s): %s", i, ta.addr, err) - } - if got != ta.linkAddr { - t.Errorf("check %d, got getBlocking(_, %s) = %s, want = %s", i, ta.addr, got, ta.linkAddr) - } - } - - // Check that after resolved, address stays in the cache and never returns WouldBlock. - for i := 0; i < 10; i++ { - e := testAddrs[len(testAddrs)-1] - got, _, err := c.get(e.addr, "", nil) - if err != nil { - t.Errorf("c.get(%s, '', nil): %s", e.addr, err) - } - if got != e.linkAddr { - t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, e.linkAddr) - } - } -} - -func TestCacheResolutionFailed(t *testing.T) { - var c linkAddrCache - linkRes := &testLinkAddressResolver{cache: &c} - c.init(newEmptyNIC(), 1<<63-1, 10*time.Millisecond, 5, linkRes) - - var requestCount uint32 - linkRes.onLinkAddressRequest = func() { - atomic.AddUint32(&requestCount, 1) - } - - // First, sanity check that resolution is working... - e := testAddrs[0] - got, err := getBlocking(&c, e.addr) - if err != nil { - t.Errorf("getBlocking(_, %s): %s", e.addr, err) - } - if got != e.linkAddr { - t.Errorf("got getBlocking(_, %s) = %s, want = %s", e.addr, got, e.linkAddr) - } - - before := atomic.LoadUint32(&requestCount) - - e.addr += "2" - a, err := getBlocking(&c, e.addr) - if _, ok := err.(*tcpip.ErrTimeout); !ok { - t.Errorf("got getBlocking(_, %s) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{}) - } - - if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want { - t.Errorf("got link address request count = %d, want = %d", got, want) - } -} - -func TestCacheResolutionTimeout(t *testing.T) { - resolverDelay := 500 * time.Millisecond - expiration := resolverDelay / 10 - var c linkAddrCache - linkRes := &testLinkAddressResolver{cache: &c, delay: resolverDelay} - c.init(newEmptyNIC(), expiration, 1*time.Millisecond, 3, linkRes) - - e := testAddrs[0] - a, err := getBlocking(&c, e.addr) - if _, ok := err.(*tcpip.ErrTimeout); !ok { - t.Errorf("got getBlocking(_, %s) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{}) - } -} diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 0238605af..3b6ba9509 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -424,7 +424,7 @@ func TestDADResolve(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPDisp: &ndpDisp, - NDPConfigs: ipv6.NDPConfigurations{ + DADConfigs: stack.DADConfigurations{ RetransmitTimer: test.retransTimer, DupAddrDetectTransmits: test.dupAddrDetectTransmits, }, @@ -642,14 +642,14 @@ func TestDADFail(t *testing.T) { ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent, 1), } - ndpConfigs := ipv6.DefaultNDPConfigurations() - ndpConfigs.RetransmitTimer = time.Second * 2 + dadConfigs := stack.DefaultDADConfigurations() + dadConfigs.RetransmitTimer = time.Second * 2 e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPDisp: &ndpDisp, - NDPConfigs: ndpConfigs, + DADConfigs: dadConfigs, })}, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -677,7 +677,7 @@ func TestDADFail(t *testing.T) { // Wait for DAD to fail and make sure the address did // not get resolved. select { - case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): + case <-time.After(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer + time.Second): // If we don't get a failure event after the // expected resolution time + extra 1s buffer, // something is wrong. @@ -748,7 +748,7 @@ func TestDADStop(t *testing.T) { dadC: make(chan ndpDADEvent, 1), } - ndpConfigs := ipv6.NDPConfigurations{ + dadConfigs := stack.DADConfigurations{ RetransmitTimer: time.Second, DupAddrDetectTransmits: 2, } @@ -757,7 +757,7 @@ func TestDADStop(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPDisp: &ndpDisp, - NDPConfigs: ndpConfigs, + DADConfigs: dadConfigs, })}, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -777,12 +777,12 @@ func TestDADStop(t *testing.T) { // Wait for DAD to fail (since the address was removed during DAD). select { - case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): + case <-time.After(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer + time.Second): // If we don't get a failure event after the expected resolution // time + extra 1s buffer, something is wrong. t.Fatal("timed out waiting for DAD failure") case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" { + if diff := checkDADEvent(e, nicID, addr1, false, &tcpip.ErrAborted{}); diff != "" { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } @@ -865,16 +865,15 @@ func TestSetNDPConfigurations(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err) } - // Update the NDP configurations on NIC(1) to use DAD. - configs := ipv6.NDPConfigurations{ - DupAddrDetectTransmits: test.dupAddrDetectTransmits, - RetransmitTimer: test.retransmitTimer, - } + // Update the configurations on NIC(1) to use DAD. if ipv6Ep, err := s.GetNetworkEndpoint(nicID1, header.IPv6ProtocolNumber); err != nil { t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, header.IPv6ProtocolNumber, err) } else { - ndpEP := ipv6Ep.(ipv6.NDPEndpoint) - ndpEP.SetNDPConfigurations(configs) + dad := ipv6Ep.(stack.DuplicateAddressDetector) + dad.SetDADConfigurations(stack.DADConfigurations{ + DupAddrDetectTransmits: test.dupAddrDetectTransmits, + RetransmitTimer: test.retransmitTimer, + }) } // Created after updating NIC(1)'s NDP configurations @@ -1903,9 +1902,11 @@ func TestAutoGenTempAddr(t *testing.T) { e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + DADConfigs: stack.DADConfigurations{ + DupAddrDetectTransmits: test.dupAddrTransmits, + RetransmitTimer: test.retransmitTimer, + }, NDPConfigs: ipv6.NDPConfigurations{ - DupAddrDetectTransmits: test.dupAddrTransmits, - RetransmitTimer: test.retransmitTimer, HandleRAs: true, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, @@ -2202,9 +2203,11 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + DADConfigs: stack.DADConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + }, NDPConfigs: ipv6.NDPConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, HandleRAs: true, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, @@ -2635,16 +2638,15 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), } e := channel.New(0, 1280, linkAddr1) - ndpConfigs := ipv6.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: test.tempAddrs, - AutoGenAddressConflictRetries: 1, - } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: test.tempAddrs, + AutoGenAddressConflictRetries: 1, + }, + NDPDisp: &ndpDisp, OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: test.nicNameFromID, }, @@ -2718,13 +2720,14 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { // Enable DAD. ndpDisp.dadC = make(chan ndpDADEvent, 2) - ndpConfigs.DupAddrDetectTransmits = dupAddrTransmits - ndpConfigs.RetransmitTimer = retransmitTimer if ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber); err != nil { t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) } else { - ndpEP := ipv6Ep.(ipv6.NDPEndpoint) - ndpEP.SetNDPConfigurations(ndpConfigs) + ndpEP := ipv6Ep.(stack.DuplicateAddressDetector) + ndpEP.SetDADConfigurations(stack.DADConfigurations{ + DupAddrDetectTransmits: dupAddrTransmits, + RetransmitTimer: retransmitTimer, + }) } // Do SLAAC for prefix. @@ -2769,7 +2772,7 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { // stack.Stack will have a default route through the router (llAddr3) installed // and a static link-address (linkAddr3) added to the link address cache for the // router. -func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useNeighborCache bool) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) { +func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) { t.Helper() ndpDisp := &ndpDispatcher{ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), @@ -2785,7 +2788,6 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useN NDPDisp: ndpDisp, })}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - UseNeighborCache: useNeighborCache, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -2868,126 +2870,108 @@ func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullA // TestAutoGenAddrDeprecateFromPI tests deprecating a SLAAC address when // receiving a PI with 0 preferred lifetime. func TestAutoGenAddrDeprecateFromPI(t *testing.T) { - stacks := []struct { - name string - useNeighborCache bool - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, - }, - } - - for _, stackTyp := range stacks { - t.Run(stackTyp.name, func(t *testing.T) { - const nicID = 1 + const nicID = 1 - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID, stackTyp.useNeighborCache) + ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } + default: + t.Fatal("expected addr auto gen event") + } + } - expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { - t.Helper() + expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { + t.Helper() - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatal(err) - } + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { + t.Fatal(err) + } - if got := addrForNewConnection(t, s); got != addr.Address { - t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) - } - } + if got := addrForNewConnection(t, s); got != addr.Address { + t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) + } + } - // Receive PI for prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) - expectAutoGenAddrEvent(addr1, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should have %s in the list of addresses", addr1) - } - expectPrimaryAddr(addr1) + // Receive PI for prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + expectAutoGenAddrEvent(addr1, newAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + expectPrimaryAddr(addr1) - // Deprecate addr for prefix1 immedaitely. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - expectAutoGenAddrEvent(addr1, deprecatedAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should have %s in the list of addresses", addr1) - } - // addr should still be the primary endpoint as there are no other addresses. - expectPrimaryAddr(addr1) + // Deprecate addr for prefix1 immedaitely. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, deprecatedAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + // addr should still be the primary endpoint as there are no other addresses. + expectPrimaryAddr(addr1) - // Refresh lifetimes of addr generated from prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr1) + // Refresh lifetimes of addr generated from prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr1) - // Receive PI for prefix2. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) - expectAutoGenAddrEvent(addr2, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr2) + // Receive PI for prefix2. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + expectAutoGenAddrEvent(addr2, newAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr2) - // Deprecate addr for prefix2 immedaitely. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) - expectAutoGenAddrEvent(addr2, deprecatedAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - // addr1 should be the primary endpoint now since addr2 is deprecated but - // addr1 is not. - expectPrimaryAddr(addr1) - // addr2 is deprecated but if explicitly requested, it should be used. - fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID} - if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) - } + // Deprecate addr for prefix2 immedaitely. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + expectAutoGenAddrEvent(addr2, deprecatedAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + // addr1 should be the primary endpoint now since addr2 is deprecated but + // addr1 is not. + expectPrimaryAddr(addr1) + // addr2 is deprecated but if explicitly requested, it should be used. + fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID} + if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) + } - // Another PI w/ 0 preferred lifetime should not result in a deprecation - // event. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr1) - if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) - } + // Another PI w/ 0 preferred lifetime should not result in a deprecation + // event. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr1) + if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) + } - // Refresh lifetimes of addr generated from prefix2. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr2) - }) + // Refresh lifetimes of addr generated from prefix2. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: } + expectPrimaryAddr(addr2) } // TestAutoGenAddrJobDeprecation tests that an address is properly deprecated @@ -2997,232 +2981,214 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { const newMinVL = 2 newMinVLDuration := newMinVL * time.Second - stacks := []struct { - name string - useNeighborCache bool - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, - }, - } - - for _, stackTyp := range stacks { - t.Run(stackTyp.name, func(t *testing.T) { - saved := ipv6.MinPrefixInformationValidLifetimeForUpdate - defer func() { - ipv6.MinPrefixInformationValidLifetimeForUpdate = saved - }() - ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration + saved := ipv6.MinPrefixInformationValidLifetimeForUpdate + defer func() { + ipv6.MinPrefixInformationValidLifetimeForUpdate = saved + }() + ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID, stackTyp.useNeighborCache) + ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } + default: + t.Fatal("expected addr auto gen event") + } + } - expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { - t.Helper() + expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { + t.Helper() - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(timeout): - t.Fatal("timed out waiting for addr auto gen event") - } + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } + case <-time.After(timeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } - expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { - t.Helper() + expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { + t.Helper() - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatal(err) - } + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { + t.Fatal(err) + } - if got := addrForNewConnection(t, s); got != addr.Address { - t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) - } - } + if got := addrForNewConnection(t, s); got != addr.Address { + t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) + } + } - // Receive PI for prefix2. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) - expectAutoGenAddrEvent(addr2, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr2) + // Receive PI for prefix2. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + expectAutoGenAddrEvent(addr2, newAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr2) - // Receive a PI for prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90)) - expectAutoGenAddrEvent(addr1, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr1) + // Receive a PI for prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90)) + expectAutoGenAddrEvent(addr1, newAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr1) - // Refresh lifetime for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr1) + // Refresh lifetime for addr of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr1) - // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - // addr2 should be the primary endpoint now since addr1 is deprecated but - // addr2 is not. - expectPrimaryAddr(addr2) - // addr1 is deprecated but if explicitly requested, it should be used. - fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID} - if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) - } + // Wait for addr of prefix1 to be deprecated. + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + // addr2 should be the primary endpoint now since addr1 is deprecated but + // addr2 is not. + expectPrimaryAddr(addr2) + // addr1 is deprecated but if explicitly requested, it should be used. + fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID} + if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) + } - // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make - // sure we do not get a deprecation event again. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr2) - if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) - } + // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make + // sure we do not get a deprecation event again. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr2) + if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) + } - // Refresh lifetimes for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - // addr1 is the primary endpoint again since it is non-deprecated now. - expectPrimaryAddr(addr1) + // Refresh lifetimes for addr of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + // addr1 is the primary endpoint again since it is non-deprecated now. + expectPrimaryAddr(addr1) - // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - // addr2 should be the primary endpoint now since it is not deprecated. - expectPrimaryAddr(addr2) - if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) - } + // Wait for addr of prefix1 to be deprecated. + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + // addr2 should be the primary endpoint now since it is not deprecated. + expectPrimaryAddr(addr2) + if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) + } - // Wait for addr of prefix1 to be invalidated. - expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout) - if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr2) + // Wait for addr of prefix1 to be invalidated. + expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout) + if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr2) - // Refresh both lifetimes for addr of prefix2 to the same value. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } + // Refresh both lifetimes for addr of prefix2 to the same value. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } - // Wait for a deprecation then invalidation events, or just an invalidation - // event. We need to cover both cases but cannot deterministically hit both - // cases because the deprecation and invalidation handlers could be handled in - // either deprecation then invalidation, or invalidation then deprecation - // (which should be cancelled by the invalidation handler). + // Wait for a deprecation then invalidation events, or just an invalidation + // event. We need to cover both cases but cannot deterministically hit both + // cases because the deprecation and invalidation handlers could be handled in + // either deprecation then invalidation, or invalidation then deprecation + // (which should be cancelled by the invalidation handler). + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" { + // If we get a deprecation event first, we should get an invalidation + // event almost immediately after. select { case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" { - // If we get a deprecation event first, we should get an invalidation - // event almost immediately after. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" { - // If we get an invalidation event first, we should not get a deprecation - // event after. - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - case <-time.After(defaultAsyncNegativeEventTimeout): - } - } else { - t.Fatalf("got unexpected auto-generated event") + if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): + case <-time.After(defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } - if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should not have %s in the list of addresses", addr2) - } - // Should not have any primary endpoints. - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - defer close(ch) - ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) + } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" { + // If we get an invalidation event first, we should not get a deprecation + // event after. + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + case <-time.After(defaultAsyncNegativeEventTimeout): } - defer ep.Close() - ep.SocketOptions().SetV6Only(true) + } else { + t.Fatalf("got unexpected auto-generated event") + } + case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should not have %s in the list of addresses", addr2) + } + // Should not have any primary endpoints. + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { + t.Fatal(err) + } + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) + } + defer ep.Close() + ep.SocketOptions().SetV6Only(true) - { - err := ep.Connect(dstAddr) - if _, ok := err.(*tcpip.ErrNoRoute); !ok { - t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, &tcpip.ErrNoRoute{}) - } - } - }) + { + err := ep.Connect(dstAddr) + if _, ok := err.(*tcpip.ErrNoRoute); !ok { + t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, &tcpip.ErrNoRoute{}) + } } } @@ -3543,126 +3509,108 @@ func TestAutoGenAddrRemoval(t *testing.T) { func TestAutoGenAddrAfterRemoval(t *testing.T) { const nicID = 1 - stacks := []struct { - name string - useNeighborCache bool - }{ - { - name: "linkAddrCache", - useNeighborCache: false, - }, - { - name: "neighborCache", - useNeighborCache: true, - }, + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } } - for _, stackTyp := range stacks { - t.Run(stackTyp.name, func(t *testing.T) { - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID, stackTyp.useNeighborCache) + expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { + t.Helper() - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() + if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { + t.Fatal(err) + } - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } + if got := addrForNewConnection(t, s); got != addr.Address { + t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) + } + } - expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { - t.Helper() + // Receive a PI to auto-generate addr1 with a large valid and preferred + // lifetime. + const largeLifetimeSeconds = 999 + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix1, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) + expectAutoGenAddrEvent(addr1, newAddr) + expectPrimaryAddr(addr1) - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatal(err) - } + // Add addr2 as a static address. + protoAddr2 := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addr2, + } + if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil { + t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err) + } + // addr2 should be more preferred now since it is at the front of the primary + // list. + expectPrimaryAddr(addr2) - if got := addrForNewConnection(t, s); got != addr.Address { - t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) - } - } + // Get a route using addr2 to increment its reference count then remove it + // to leave it in the permanentExpired state. + r, err := s.FindRoute(nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, false) + if err != nil { + t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, err) + } + defer r.Release() + if err := s.RemoveAddress(nicID, addr2.Address); err != nil { + t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, addr2.Address, err) + } + // addr1 should be preferred again since addr2 is in the expired state. + expectPrimaryAddr(addr1) - // Receive a PI to auto-generate addr1 with a large valid and preferred - // lifetime. - const largeLifetimeSeconds = 999 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix1, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) - expectAutoGenAddrEvent(addr1, newAddr) - expectPrimaryAddr(addr1) - - // Add addr2 as a static address. - protoAddr2 := tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: addr2, - } - if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err) - } - // addr2 should be more preferred now since it is at the front of the primary - // list. - expectPrimaryAddr(addr2) - - // Get a route using addr2 to increment its reference count then remove it - // to leave it in the permanentExpired state. - r, err := s.FindRoute(nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, false) - if err != nil { - t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, err) - } - defer r.Release() - if err := s.RemoveAddress(nicID, addr2.Address); err != nil { - t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, addr2.Address, err) - } - // addr1 should be preferred again since addr2 is in the expired state. - expectPrimaryAddr(addr1) - - // Receive a PI to auto-generate addr2 as valid and preferred. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) - expectAutoGenAddrEvent(addr2, newAddr) - // addr2 should be more preferred now that it is closer to the front of the - // primary list and not deprecated. - expectPrimaryAddr(addr2) - - // Removing the address should result in an invalidation event immediately. - // It should still be in the permanentExpired state because r is still held. - // - // We remove addr2 here to make sure addr2 was marked as a SLAAC address - // (it was previously marked as a static address). - if err := s.RemoveAddress(1, addr2.Address); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err) - } - expectAutoGenAddrEvent(addr2, invalidatedAddr) - // addr1 should be more preferred since addr2 is in the expired state. - expectPrimaryAddr(addr1) - - // Receive a PI to auto-generate addr2 as valid and deprecated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, 0)) - expectAutoGenAddrEvent(addr2, newAddr) - // addr1 should still be more preferred since addr2 is deprecated, even though - // it is closer to the front of the primary list. - expectPrimaryAddr(addr1) - - // Receive a PI to refresh addr2's preferred lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto gen addr event") - default: - } - // addr2 should be more preferred now that it is not deprecated. - expectPrimaryAddr(addr2) + // Receive a PI to auto-generate addr2 as valid and preferred. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) + expectAutoGenAddrEvent(addr2, newAddr) + // addr2 should be more preferred now that it is closer to the front of the + // primary list and not deprecated. + expectPrimaryAddr(addr2) - if err := s.RemoveAddress(1, addr2.Address); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err) - } - expectAutoGenAddrEvent(addr2, invalidatedAddr) - expectPrimaryAddr(addr1) - }) + // Removing the address should result in an invalidation event immediately. + // It should still be in the permanentExpired state because r is still held. + // + // We remove addr2 here to make sure addr2 was marked as a SLAAC address + // (it was previously marked as a static address). + if err := s.RemoveAddress(1, addr2.Address); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err) + } + expectAutoGenAddrEvent(addr2, invalidatedAddr) + // addr1 should be more preferred since addr2 is in the expired state. + expectPrimaryAddr(addr1) + + // Receive a PI to auto-generate addr2 as valid and deprecated. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, 0)) + expectAutoGenAddrEvent(addr2, newAddr) + // addr1 should still be more preferred since addr2 is deprecated, even though + // it is closer to the front of the primary list. + expectPrimaryAddr(addr1) + + // Receive a PI to refresh addr2's preferred lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto gen addr event") + default: + } + // addr2 should be more preferred now that it is not deprecated. + expectPrimaryAddr(addr2) + + if err := s.RemoveAddress(1, addr2.Address); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err) } + expectAutoGenAddrEvent(addr2, invalidatedAddr) + expectPrimaryAddr(addr1) } // TestAutoGenAddrStaticConflict tests that if SLAAC generates an address that @@ -3885,12 +3833,12 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { } } - expectDADEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, resolved bool) { + expectDADEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, resolved bool, err tcpip.Error) { t.Helper() select { case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr, resolved, nil); diff != "" { + if diff := checkDADEvent(e, nicID, addr, resolved, err); diff != "" { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } default: @@ -3923,8 +3871,6 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { { name: "Global address", ndpConfigs: ipv6.NDPConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, HandleRAs: true, AutoGenGlobalAddresses: true, }, @@ -3939,11 +3885,8 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { }, }, { - name: "LinkLocal address", - ndpConfigs: ipv6.NDPConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, - }, + name: "LinkLocal address", + ndpConfigs: ipv6.NDPConfigurations{}, autoGenLinkLocal: true, prepareFn: func(*testing.T, *ndpDispatcher, *channel.Endpoint, []byte) []tcpip.AddressWithPrefix { return nil @@ -3955,8 +3898,6 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { { name: "Temporary address", ndpConfigs: ipv6.NDPConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, HandleRAs: true, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, @@ -4008,8 +3949,12 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ AutoGenLinkLocal: addrType.autoGenLinkLocal, - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, + DADConfigs: stack.DADConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + }, + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: func(_ tcpip.NICID, nicName string) string { return nicName @@ -4039,7 +3984,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { // Simulate a DAD conflict. rxNDPSolicit(e, addr.Address) expectAutoGenAddrEvent(t, &ndpDisp, addr, invalidatedAddr) - expectDADEvent(t, &ndpDisp, addr.Address, false) + expectDADEvent(t, &ndpDisp, addr.Address, false, nil) // Attempting to add the address manually should not fail if the // address's state was cleaned up when DAD failed. @@ -4049,7 +3994,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { if err := s.RemoveAddress(nicID, addr.Address); err != nil { t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err) } - expectDADEvent(t, &ndpDisp, addr.Address, false) + expectDADEvent(t, &ndpDisp, addr.Address, false, &tcpip.ErrAborted{}) } // Should not have any new addresses assigned to the NIC. @@ -4103,8 +4048,6 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { { name: "Global address", ndpConfigs: ipv6.NDPConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, HandleRAs: true, AutoGenGlobalAddresses: true, AutoGenAddressConflictRetries: maxRetries, @@ -4119,8 +4062,6 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { { name: "LinkLocal address", ndpConfigs: ipv6.NDPConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, AutoGenAddressConflictRetries: maxRetries, }, autoGenLinkLocal: true, @@ -4145,6 +4086,10 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { AutoGenLinkLocal: addrType.autoGenLinkLocal, NDPConfigs: addrType.ndpConfigs, NDPDisp: &ndpDisp, + DADConfigs: stack.DADConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + }, })}, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -4226,9 +4171,11 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + DADConfigs: stack.DADConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + }, NDPConfigs: ipv6.NDPConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, HandleRAs: true, AutoGenGlobalAddresses: true, AutoGenAddressConflictRetries: maxRetries, diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go index 7e3132058..533287c4c 100644 --- a/pkg/tcpip/stack/neighbor_cache.go +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -25,9 +25,13 @@ const neighborCacheSize = 512 // max entries per interface // NeighborStats holds metrics for the neighbor table. type NeighborStats struct { - // FailedEntryLookups counts the number of lookups performed on an entry in - // Failed state. + // FailedEntryLookups is deprecated; UnreachableEntryLookups should be used + // instead. FailedEntryLookups *tcpip.StatCounter + + // UnreachableEntryLookups counts the number of lookups performed on an + // entry in Unreachable state. + UnreachableEntryLookups *tcpip.StatCounter } // neighborCache maps IP addresses to link addresses. It uses the Least @@ -43,21 +47,22 @@ type NeighborStats struct { // Their state is always Static. The amount of static entries stored in the // cache is unbounded. type neighborCache struct { - nic *NIC + nic *nic state *NUDState linkRes LinkAddressResolver - // mu protects the fields below. - mu sync.RWMutex + mu struct { + sync.RWMutex - cache map[tcpip.Address]*neighborEntry - dynamic struct { - lru neighborEntryList + cache map[tcpip.Address]*neighborEntry + dynamic struct { + lru neighborEntryList - // count tracks the amount of dynamic entries in the cache. This is - // needed since static entries do not count towards the LRU cache - // eviction strategy. - count uint16 + // count tracks the amount of dynamic entries in the cache. This is + // needed since static entries do not count towards the LRU cache + // eviction strategy. + count uint16 + } } } @@ -74,11 +79,11 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address) *neighborEntr n.mu.Lock() defer n.mu.Unlock() - if entry, ok := n.cache[remoteAddr]; ok { + if entry, ok := n.mu.cache[remoteAddr]; ok { entry.mu.RLock() - if entry.neigh.State != Static { - n.dynamic.lru.Remove(entry) - n.dynamic.lru.PushFront(entry) + if entry.mu.neigh.State != Static { + n.mu.dynamic.lru.Remove(entry) + n.mu.dynamic.lru.PushFront(entry) } entry.mu.RUnlock() return entry @@ -87,20 +92,20 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address) *neighborEntr // The entry that needs to be created must be dynamic since all static // entries are directly added to the cache via addStaticEntry. entry := newNeighborEntry(n, remoteAddr, n.state) - if n.dynamic.count == neighborCacheSize { - e := n.dynamic.lru.Back() + if n.mu.dynamic.count == neighborCacheSize { + e := n.mu.dynamic.lru.Back() e.mu.Lock() - delete(n.cache, e.neigh.Addr) - n.dynamic.lru.Remove(e) - n.dynamic.count-- + delete(n.mu.cache, e.mu.neigh.Addr) + n.mu.dynamic.lru.Remove(e) + n.mu.dynamic.count-- e.removeLocked() e.mu.Unlock() } - n.cache[remoteAddr] = entry - n.dynamic.lru.PushFront(entry) - n.dynamic.count++ + n.mu.cache[remoteAddr] = entry + n.mu.dynamic.lru.PushFront(entry) + n.mu.dynamic.count++ return entry } @@ -128,7 +133,7 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, onResolve fun entry.mu.Lock() defer entry.mu.Unlock() - switch s := entry.neigh.State; s { + switch s := entry.mu.neigh.State; s { case Stale: entry.handlePacketQueuedLocked(localAddr) fallthrough @@ -139,19 +144,19 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, onResolve fun // a node continues sending packets to that neighbor using the cached // link-layer address." if onResolve != nil { - onResolve(LinkResolutionResult{LinkAddress: entry.neigh.LinkAddr, Success: true}) + onResolve(LinkResolutionResult{LinkAddress: entry.mu.neigh.LinkAddr, Success: true}) } - return entry.neigh, nil, nil - case Unknown, Incomplete, Failed: + return entry.mu.neigh, nil, nil + case Unknown, Incomplete, Unreachable: if onResolve != nil { - entry.onResolve = append(entry.onResolve, onResolve) + entry.mu.onResolve = append(entry.mu.onResolve, onResolve) } - if entry.done == nil { + if entry.mu.done == nil { // Address resolution needs to be initiated. - entry.done = make(chan struct{}) + entry.mu.done = make(chan struct{}) } entry.handlePacketQueuedLocked(localAddr) - return entry.neigh, entry.done, &tcpip.ErrWouldBlock{} + return entry.mu.neigh, entry.mu.done, &tcpip.ErrWouldBlock{} default: panic(fmt.Sprintf("Invalid cache entry state: %s", s)) } @@ -162,10 +167,10 @@ func (n *neighborCache) entries() []NeighborEntry { n.mu.RLock() defer n.mu.RUnlock() - entries := make([]NeighborEntry, 0, len(n.cache)) - for _, entry := range n.cache { + entries := make([]NeighborEntry, 0, len(n.mu.cache)) + for _, entry := range n.mu.cache { entry.mu.RLock() - entries = append(entries, entry.neigh) + entries = append(entries, entry.mu.neigh) entry.mu.RUnlock() } return entries @@ -181,19 +186,19 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd n.mu.Lock() defer n.mu.Unlock() - if entry, ok := n.cache[addr]; ok { + if entry, ok := n.mu.cache[addr]; ok { entry.mu.Lock() - if entry.neigh.State != Static { + if entry.mu.neigh.State != Static { // Dynamic entry found with the same address. - n.dynamic.lru.Remove(entry) - n.dynamic.count-- - } else if entry.neigh.LinkAddr == linkAddr { + n.mu.dynamic.lru.Remove(entry) + n.mu.dynamic.count-- + } else if entry.mu.neigh.LinkAddr == linkAddr { // Static entry found with the same address and link address. entry.mu.Unlock() return } else { // Static entry found with the same address but different link address. - entry.neigh.LinkAddr = linkAddr + entry.mu.neigh.LinkAddr = linkAddr entry.dispatchChangeEventLocked() entry.mu.Unlock() return @@ -203,7 +208,12 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd entry.mu.Unlock() } - n.cache[addr] = newStaticNeighborEntry(n, addr, linkAddr, n.state) + entry := newStaticNeighborEntry(n, addr, linkAddr, n.state) + n.mu.cache[addr] = entry + + entry.mu.Lock() + defer entry.mu.Unlock() + entry.dispatchAddEventLocked() } // removeEntry removes a dynamic or static entry by address from the neighbor @@ -212,7 +222,7 @@ func (n *neighborCache) removeEntry(addr tcpip.Address) bool { n.mu.Lock() defer n.mu.Unlock() - entry, ok := n.cache[addr] + entry, ok := n.mu.cache[addr] if !ok { return false } @@ -220,13 +230,13 @@ func (n *neighborCache) removeEntry(addr tcpip.Address) bool { entry.mu.Lock() defer entry.mu.Unlock() - if entry.neigh.State != Static { - n.dynamic.lru.Remove(entry) - n.dynamic.count-- + if entry.mu.neigh.State != Static { + n.mu.dynamic.lru.Remove(entry) + n.mu.dynamic.count-- } entry.removeLocked() - delete(n.cache, entry.neigh.Addr) + delete(n.mu.cache, entry.mu.neigh.Addr) return true } @@ -235,15 +245,15 @@ func (n *neighborCache) clear() { n.mu.Lock() defer n.mu.Unlock() - for _, entry := range n.cache { + for _, entry := range n.mu.cache { entry.mu.Lock() entry.removeLocked() entry.mu.Unlock() } - n.dynamic.lru = neighborEntryList{} - n.cache = make(map[tcpip.Address]*neighborEntry) - n.dynamic.count = 0 + n.mu.dynamic.lru = neighborEntryList{} + n.mu.cache = make(map[tcpip.Address]*neighborEntry) + n.mu.dynamic.count = 0 } // config returns the NUD configuration. @@ -260,30 +270,6 @@ func (n *neighborCache) setConfig(config NUDConfigurations) { n.state.SetConfig(config) } -var _ neighborTable = (*neighborCache)(nil) - -func (n *neighborCache) neighbors() ([]NeighborEntry, tcpip.Error) { - return n.entries(), nil -} - -func (n *neighborCache) get(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { - entry, ch, err := n.entry(addr, localAddr, onResolve) - return entry.LinkAddr, ch, err -} - -func (n *neighborCache) remove(addr tcpip.Address) tcpip.Error { - if !n.removeEntry(addr) { - return &tcpip.ErrBadAddress{} - } - - return nil -} - -func (n *neighborCache) removeAll() tcpip.Error { - n.clear() - return nil -} - // handleProbe handles a neighbor probe as defined by RFC 4861 section 7.2.3. // // Validation of the probe is expected to be handled by the caller. @@ -300,7 +286,7 @@ func (n *neighborCache) handleProbe(remoteAddr tcpip.Address, remoteLinkAddr tcp // Validation of the confirmation is expected to be handled by the caller. func (n *neighborCache) handleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) { n.mu.RLock() - entry, ok := n.cache[addr] + entry, ok := n.mu.cache[addr] n.mu.RUnlock() if ok { entry.mu.Lock() @@ -316,7 +302,7 @@ func (n *neighborCache) handleConfirmation(addr tcpip.Address, linkAddr tcpip.Li // some protocol that operates at a layer above the IP/link layer. func (n *neighborCache) handleUpperLevelConfirmation(addr tcpip.Address) { n.mu.RLock() - entry, ok := n.cache[addr] + entry, ok := n.mu.cache[addr] n.mu.RUnlock() if ok { entry.mu.Lock() @@ -325,11 +311,13 @@ func (n *neighborCache) handleUpperLevelConfirmation(addr tcpip.Address) { } } -func (n *neighborCache) nudConfig() (NUDConfigurations, tcpip.Error) { - return n.config(), nil -} - -func (n *neighborCache) setNUDConfig(c NUDConfigurations) tcpip.Error { - n.setConfig(c) - return nil +func (n *neighborCache) init(nic *nic, r LinkAddressResolver) { + *n = neighborCache{ + nic: nic, + state: NewNUDState(nic.stack.nudConfigs, nic.stack.randomGenerator), + linkRes: r, + } + n.mu.Lock() + n.mu.cache = make(map[tcpip.Address]*neighborEntry, neighborCacheSize) + n.mu.Unlock() } diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index b489b5e08..909912662 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -84,19 +84,16 @@ func newTestNeighborResolver(nudDisp NUDDispatcher, config NUDConfigurations, cl entries: newTestEntryStore(), delay: typicalLatency, } - linkRes.neigh = &neighborCache{ - nic: &NIC{ - stack: &Stack{ - clock: clock, - nudDisp: nudDisp, - }, - id: 1, - stats: makeNICStats(), + linkRes.neigh.init(&nic{ + stack: &Stack{ + clock: clock, + nudDisp: nudDisp, + nudConfigs: config, + randomGenerator: rng, }, - state: NewNUDState(config, rng), - linkRes: linkRes, - cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), - } + id: 1, + stats: makeNICStats(), + }, linkRes) return linkRes } @@ -190,7 +187,7 @@ func (s *testEntryStore) set(i int, linkAddr tcpip.LinkAddress) { // neighbor probe. type testNeighborResolver struct { clock tcpip.Clock - neigh *neighborCache + neigh neighborCache entries *testEntryStore delay time.Duration onLinkAddressRequest func() @@ -1613,12 +1610,11 @@ func TestNeighborCacheRetryResolution(t *testing.T) { } } - // Verify the entry is in Failed state. wantEntries := []NeighborEntry{ { Addr: entry.Addr, LinkAddr: "", - State: Failed, + State: Unreachable, }, } if diff := cmp.Diff(linkRes.neigh.entries(), wantEntries, entryDiffOptsWithSort()...); diff != "" { diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index b05f96d4f..03fef52ee 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -38,7 +38,8 @@ type NeighborEntry struct { } // NeighborState defines the state of a NeighborEntry within the Neighbor -// Unreachability Detection state machine, as per RFC 4861 section 7.3.2. +// Unreachability Detection state machine, as per RFC 4861 section 7.3.2 and +// RFC 7048. type NeighborState uint8 const ( @@ -61,15 +62,33 @@ const ( Delay // Probe means a reachability confirmation is actively being sought by // periodically retransmitting reachability probes until a reachability - // confirmation is received, or until the max amount of probes has been sent. + // confirmation is received, or until the maximum number of probes has been + // sent. Probe // Static describes entries that have been explicitly added by the user. They // do not expire and are not deleted until explicitly removed. Static - // Failed means recent attempts of reachability have returned inconclusive. + // Failed is deprecated and should no longer be used. + // + // TODO(gvisor.dev/issue/4667): Remove this once all references to Failed + // are removed from Fuchsia. Failed + // Unreachable means reachability confirmation failed; the maximum number of + // reachability probes has been sent and no replies have been received. + // + // TODO(gvisor.dev/issue/5472): Add the following sentence when we implement + // RFC 7048: "Packets continue to be sent to the neighbor while + // re-attempting to resolve the address." + Unreachable ) +type timer struct { + // done indicates to the timer that the timer was stopped. + done *bool + + timer tcpip.Timer +} + // neighborEntry implements a neighbor entry's individual node behavior, as per // RFC 4861 section 7.3.3. Neighbor Unreachability Detection operates in // parallel with the sending of packets to a neighbor, necessitating the @@ -82,20 +101,22 @@ type neighborEntry struct { // nudState points to the Neighbor Unreachability Detection configuration. nudState *NUDState - // mu protects the fields below. - mu sync.RWMutex + mu struct { + sync.RWMutex + + neigh NeighborEntry - neigh NeighborEntry + // done is closed when address resolution is complete. It is nil iff s is + // incomplete and resolution is not yet in progress. + done chan struct{} - // done is closed when address resolution is complete. It is nil iff s is - // incomplete and resolution is not yet in progress. - done chan struct{} + // onResolve is called with the result of address resolution. + onResolve []func(LinkResolutionResult) - // onResolve is called with the result of address resolution. - onResolve []func(LinkResolutionResult) + isRouter bool - isRouter bool - job *tcpip.Job + timer timer + } } // newNeighborEntry creates a neighbor cache entry starting at the default @@ -103,14 +124,18 @@ type neighborEntry struct { // `handlePacketQueuedLocked` or `handleProbeLocked` on the newly created // neighborEntry. func newNeighborEntry(cache *neighborCache, remoteAddr tcpip.Address, nudState *NUDState) *neighborEntry { - return &neighborEntry{ + n := &neighborEntry{ cache: cache, nudState: nudState, - neigh: NeighborEntry{ - Addr: remoteAddr, - State: Unknown, - }, } + n.mu.Lock() + n.mu.neigh = NeighborEntry{ + Addr: remoteAddr, + State: Unknown, + } + n.mu.Unlock() + return n + } // newStaticNeighborEntry creates a neighbor cache entry starting at the @@ -123,14 +148,14 @@ func newStaticNeighborEntry(cache *neighborCache, addr tcpip.Address, linkAddr t State: Static, UpdatedAtNanos: cache.nic.stack.clock.NowNanoseconds(), } - if nudDisp := cache.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborAdded(cache.nic.id, entry) - } - return &neighborEntry{ + n := &neighborEntry{ cache: cache, nudState: state, - neigh: entry, } + n.mu.Lock() + n.mu.neigh = entry + n.mu.Unlock() + return n } // notifyCompletionLocked notifies those waiting for address resolution, with @@ -138,14 +163,14 @@ func newStaticNeighborEntry(cache *neighborCache, addr tcpip.Address, linkAddr t // // Precondition: e.mu MUST be locked. func (e *neighborEntry) notifyCompletionLocked(succeeded bool) { - res := LinkResolutionResult{LinkAddress: e.neigh.LinkAddr, Success: succeeded} - for _, callback := range e.onResolve { + res := LinkResolutionResult{LinkAddress: e.mu.neigh.LinkAddr, Success: succeeded} + for _, callback := range e.mu.onResolve { callback(res) } - e.onResolve = nil - if ch := e.done; ch != nil { + e.mu.onResolve = nil + if ch := e.mu.done; ch != nil { close(ch) - e.done = nil + e.mu.done = nil // Dequeue the pending packets in a new goroutine to not hold up the current // goroutine as writing packets may be a costly operation. // @@ -153,7 +178,7 @@ func (e *neighborEntry) notifyCompletionLocked(succeeded bool) { // is resolved (which ends up obtaining the entry's lock) while holding the // link resolution queue's lock. Dequeuing packets in a new goroutine avoids // a lock ordering violation. - go e.cache.nic.linkResQueue.dequeue(ch, e.neigh.LinkAddr, succeeded) + go e.cache.nic.linkResQueue.dequeue(ch, e.mu.neigh.LinkAddr, succeeded) } } @@ -163,7 +188,7 @@ func (e *neighborEntry) notifyCompletionLocked(succeeded bool) { // Precondition: e.mu MUST be locked. func (e *neighborEntry) dispatchAddEventLocked() { if nudDisp := e.cache.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborAdded(e.cache.nic.id, e.neigh) + nudDisp.OnNeighborAdded(e.cache.nic.id, e.mu.neigh) } } @@ -173,7 +198,7 @@ func (e *neighborEntry) dispatchAddEventLocked() { // Precondition: e.mu MUST be locked. func (e *neighborEntry) dispatchChangeEventLocked() { if nudDisp := e.cache.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborChanged(e.cache.nic.id, e.neigh) + nudDisp.OnNeighborChanged(e.cache.nic.id, e.mu.neigh) } } @@ -183,17 +208,20 @@ func (e *neighborEntry) dispatchChangeEventLocked() { // Precondition: e.mu MUST be locked. func (e *neighborEntry) dispatchRemoveEventLocked() { if nudDisp := e.cache.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborRemoved(e.cache.nic.id, e.neigh) + nudDisp.OnNeighborRemoved(e.cache.nic.id, e.mu.neigh) } } -// cancelJobLocked cancels the currently scheduled action, if there is one. +// cancelTimerLocked cancels the currently scheduled action, if there is one. // Entries in Unknown, Stale, or Static state do not have a scheduled action. // // Precondition: e.mu MUST be locked. -func (e *neighborEntry) cancelJobLocked() { - if job := e.job; job != nil { - job.Cancel() +func (e *neighborEntry) cancelTimerLocked() { + if e.mu.timer.timer != nil { + e.mu.timer.timer.Stop() + *e.mu.timer.done = true + + e.mu.timer = timer{} } } @@ -201,9 +229,9 @@ func (e *neighborEntry) cancelJobLocked() { // // Precondition: e.mu MUST be locked. func (e *neighborEntry) removeLocked() { - e.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() + e.mu.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() e.dispatchRemoveEventLocked() - e.cancelJobLocked() + e.cancelTimerLocked() e.notifyCompletionLocked(false /* succeeded */) } @@ -213,61 +241,98 @@ func (e *neighborEntry) removeLocked() { // // Precondition: e.mu MUST be locked. func (e *neighborEntry) setStateLocked(next NeighborState) { - e.cancelJobLocked() + e.cancelTimerLocked() - prev := e.neigh.State - e.neigh.State = next - e.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() + prev := e.mu.neigh.State + e.mu.neigh.State = next + e.mu.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() config := e.nudState.Config() switch next { case Incomplete: - panic(fmt.Sprintf("should never transition to Incomplete with setStateLocked; neigh = %#v, prev state = %s", e.neigh, prev)) + panic(fmt.Sprintf("should never transition to Incomplete with setStateLocked; neigh = %#v, prev state = %s", e.mu.neigh, prev)) case Reachable: - e.job = e.cache.nic.stack.newJob(&e.mu, func() { - e.setStateLocked(Stale) - e.dispatchChangeEventLocked() - }) - e.job.Schedule(e.nudState.ReachableTime()) + // Protected by e.mu. + done := false + + e.mu.timer = timer{ + done: &done, + timer: e.cache.nic.stack.Clock().AfterFunc(e.nudState.ReachableTime(), func() { + e.mu.Lock() + defer e.mu.Unlock() + + if done { + // The timer was stopped because the entry changed state. + return + } + + e.setStateLocked(Stale) + e.dispatchChangeEventLocked() + }), + } case Delay: - e.job = e.cache.nic.stack.newJob(&e.mu, func() { - e.setStateLocked(Probe) - e.dispatchChangeEventLocked() - }) - e.job.Schedule(config.DelayFirstProbeTime) + // Protected by e.mu. + done := false + + e.mu.timer = timer{ + done: &done, + timer: e.cache.nic.stack.Clock().AfterFunc(config.DelayFirstProbeTime, func() { + e.mu.Lock() + defer e.mu.Unlock() + + if done { + // The timer was stopped because the entry changed state. + return + } - case Probe: - var retryCounter uint32 - var sendUnicastProbe func() - - sendUnicastProbe = func() { - if retryCounter == config.MaxUnicastProbes { - e.dispatchRemoveEventLocked() - e.setStateLocked(Failed) - return - } + e.setStateLocked(Probe) + e.dispatchChangeEventLocked() + }), + } - if err := e.cache.linkRes.LinkAddressRequest(e.neigh.Addr, "" /* localAddr */, e.neigh.LinkAddr); err != nil { - e.dispatchRemoveEventLocked() - e.setStateLocked(Failed) - return - } + case Probe: + // Protected by e.mu. + done := false - retryCounter++ - e.job = e.cache.nic.stack.newJob(&e.mu, sendUnicastProbe) - e.job.Schedule(config.RetransmitTimer) - } + remaining := config.MaxUnicastProbes + addr := e.mu.neigh.Addr + linkAddr := e.mu.neigh.LinkAddr // Send a probe in another gorountine to free this thread of execution - // for finishing the state transition. This is necessary to avoid - // deadlock where sending and processing probes are done synchronously, - // such as loopback and integration tests. - e.job = e.cache.nic.stack.newJob(&e.mu, sendUnicastProbe) - e.job.Schedule(immediateDuration) + // for finishing the state transition. This is necessary to escape the + // currently held lock so we can send the probe message without holding + // a shared lock. + e.mu.timer = timer{ + done: &done, + timer: e.cache.nic.stack.Clock().AfterFunc(0, func() { + var err tcpip.Error + timedoutResolution := remaining == 0 + if !timedoutResolution { + err = e.cache.linkRes.LinkAddressRequest(addr, "" /* localAddr */, linkAddr) + } + + e.mu.Lock() + defer e.mu.Unlock() + + if done { + // The timer was stopped because the entry changed state. + return + } - case Failed: + if timedoutResolution || err != nil { + e.setStateLocked(Unreachable) + e.dispatchChangeEventLocked() + return + } + + remaining-- + e.mu.timer.timer.Reset(config.RetransmitTimer) + }), + } + + case Unreachable: e.notifyCompletionLocked(false /* succeeded */) case Unknown, Stale, Static: @@ -285,76 +350,66 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { // // Precondition: e.mu MUST be locked. func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { - switch e.neigh.State { - case Failed: - e.cache.nic.stats.Neighbor.FailedEntryLookups.Increment() + switch e.mu.neigh.State { + case Unknown, Unreachable: + prev := e.mu.neigh.State + e.mu.neigh.State = Incomplete + e.mu.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() + + switch prev { + case Unknown: + e.dispatchAddEventLocked() + case Unreachable: + e.dispatchChangeEventLocked() + e.cache.nic.stats.Neighbor.UnreachableEntryLookups.Increment() + } - fallthrough - case Unknown: - e.neigh.State = Incomplete - e.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() + config := e.nudState.Config() - e.dispatchAddEventLocked() + // Protected by e.mu. + done := false - config := e.nudState.Config() + remaining := config.MaxMulticastProbes + addr := e.mu.neigh.Addr - var retryCounter uint32 - var sendMulticastProbe func() - - sendMulticastProbe = func() { - if retryCounter == config.MaxMulticastProbes { - // "If no Neighbor Advertisement is received after - // MAX_MULTICAST_SOLICIT solicitations, address resolution has failed. - // The sender MUST return ICMP destination unreachable indications with - // code 3 (Address Unreachable) for each packet queued awaiting address - // resolution." - RFC 4861 section 7.2.2 - // - // There is no need to send an ICMP destination unreachable indication - // since the failure to resolve the address is expected to only occur - // on this node. Thus, redirecting traffic is currently not supported. - // - // "If the error occurs on a node other than the node originating the - // packet, an ICMP error message is generated. If the error occurs on - // the originating node, an implementation is not required to actually - // create and send an ICMP error packet to the source, as long as the - // upper-layer sender is notified through an appropriate mechanism - // (e.g. return value from a procedure call). Note, however, that an - // implementation may find it convenient in some cases to return errors - // to the sender by taking the offending packet, generating an ICMP - // error message, and then delivering it (locally) through the generic - // error-handling routines." - RFC 4861 section 2.1 - e.dispatchRemoveEventLocked() - e.setStateLocked(Failed) - return - } + // Send a probe in another gorountine to free this thread of execution + // for finishing the state transition. This is necessary to escape the + // currently held lock so we can send the probe message without holding + // a shared lock. + e.mu.timer = timer{ + done: &done, + timer: e.cache.nic.stack.Clock().AfterFunc(0, func() { + var err tcpip.Error + timedoutResolution := remaining == 0 + if !timedoutResolution { + // As per RFC 4861 section 7.2.2: + // + // If the source address of the packet prompting the solicitation is + // the same as one of the addresses assigned to the outgoing interface, + // that address SHOULD be placed in the IP Source Address of the + // outgoing solicitation. + // + err = e.cache.linkRes.LinkAddressRequest(addr, localAddr, "" /* linkAddr */) + } - // As per RFC 4861 section 7.2.2: - // - // If the source address of the packet prompting the solicitation is the - // same as one of the addresses assigned to the outgoing interface, that - // address SHOULD be placed in the IP Source Address of the outgoing - // solicitation. - // - if err := e.cache.linkRes.LinkAddressRequest(e.neigh.Addr, localAddr, ""); err != nil { - // There is no need to log the error here; the NUD implementation may - // assume a working link. A valid link should be the responsibility of - // the NIC/stack.LinkEndpoint. - e.dispatchRemoveEventLocked() - e.setStateLocked(Failed) - return - } + e.mu.Lock() + defer e.mu.Unlock() - retryCounter++ - e.job = e.cache.nic.stack.newJob(&e.mu, sendMulticastProbe) - e.job.Schedule(config.RetransmitTimer) - } + if done { + // The timer was stopped because the entry changed state. + return + } - // Send a probe in another gorountine to free this thread of execution - // for finishing the state transition. This is necessary to avoid - // deadlock where sending and processing probes are done synchronously, - // such as loopback and integration tests. - e.job = e.cache.nic.stack.newJob(&e.mu, sendMulticastProbe) - e.job.Schedule(immediateDuration) + if timedoutResolution || err != nil { + e.setStateLocked(Unreachable) + e.dispatchChangeEventLocked() + return + } + + remaining-- + e.mu.timer.timer.Reset(config.RetransmitTimer) + }), + } case Stale: e.setStateLocked(Delay) @@ -363,7 +418,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { case Incomplete, Reachable, Delay, Probe, Static: // Do nothing default: - panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) + panic(fmt.Sprintf("Invalid cache entry state: %s", e.mu.neigh.State)) } } @@ -378,9 +433,9 @@ func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) { // not exist, or not bound to the NIC as per RFC 4861 section 7.2.3. These // checks MUST be done by the NetworkEndpoint. - switch e.neigh.State { - case Unknown, Failed: - e.neigh.LinkAddr = remoteLinkAddr + switch e.mu.neigh.State { + case Unknown: + e.mu.neigh.LinkAddr = remoteLinkAddr e.setStateLocked(Stale) e.dispatchAddEventLocked() @@ -390,29 +445,36 @@ func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) { // cached address should be replaced by the received address, and the // entry's reachability state MUST be set to STALE." // - RFC 4861 section 7.2.3 - e.neigh.LinkAddr = remoteLinkAddr + e.mu.neigh.LinkAddr = remoteLinkAddr e.setStateLocked(Stale) e.notifyCompletionLocked(true /* succeeded */) e.dispatchChangeEventLocked() case Reachable, Delay, Probe: - if e.neigh.LinkAddr != remoteLinkAddr { - e.neigh.LinkAddr = remoteLinkAddr + if e.mu.neigh.LinkAddr != remoteLinkAddr { + e.mu.neigh.LinkAddr = remoteLinkAddr e.setStateLocked(Stale) e.dispatchChangeEventLocked() } case Stale: - if e.neigh.LinkAddr != remoteLinkAddr { - e.neigh.LinkAddr = remoteLinkAddr + if e.mu.neigh.LinkAddr != remoteLinkAddr { + e.mu.neigh.LinkAddr = remoteLinkAddr e.dispatchChangeEventLocked() } + case Unreachable: + // TODO(gvisor.dev/issue/5472): Do not change the entry if the link + // address is the same, as per RFC 7048. + e.mu.neigh.LinkAddr = remoteLinkAddr + e.setStateLocked(Stale) + e.dispatchChangeEventLocked() + case Static: // Do nothing default: - panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) + panic(fmt.Sprintf("Invalid cache entry state: %s", e.mu.neigh.State)) } } @@ -430,7 +492,7 @@ func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) { // // Precondition: e.mu MUST be locked. func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) { - switch e.neigh.State { + switch e.mu.neigh.State { case Incomplete: if len(linkAddr) == 0 { // "If the link layer has addresses and no Target Link-Layer Address @@ -439,35 +501,35 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla break } - e.neigh.LinkAddr = linkAddr + e.mu.neigh.LinkAddr = linkAddr if flags.Solicited { e.setStateLocked(Reachable) } else { e.setStateLocked(Stale) } e.dispatchChangeEventLocked() - e.isRouter = flags.IsRouter + e.mu.isRouter = flags.IsRouter e.notifyCompletionLocked(true /* succeeded */) // "Note that the Override flag is ignored if the entry is in the // INCOMPLETE state." - RFC 4861 section 7.2.5 case Reachable, Stale, Delay, Probe: - isLinkAddrDifferent := len(linkAddr) != 0 && e.neigh.LinkAddr != linkAddr + isLinkAddrDifferent := len(linkAddr) != 0 && e.mu.neigh.LinkAddr != linkAddr if isLinkAddrDifferent { if !flags.Override { - if e.neigh.State == Reachable { + if e.mu.neigh.State == Reachable { e.setStateLocked(Stale) e.dispatchChangeEventLocked() } break } - e.neigh.LinkAddr = linkAddr + e.mu.neigh.LinkAddr = linkAddr if !flags.Solicited { - if e.neigh.State != Stale { + if e.mu.neigh.State != Stale { e.setStateLocked(Stale) e.dispatchChangeEventLocked() } else { @@ -479,7 +541,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla } if flags.Solicited && (flags.Override || !isLinkAddrDifferent) { - wasReachable := e.neigh.State == Reachable + wasReachable := e.mu.neigh.State == Reachable // Set state to Reachable again to refresh timers. e.setStateLocked(Reachable) e.notifyCompletionLocked(true /* succeeded */) @@ -488,7 +550,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla } } - if e.isRouter && !flags.IsRouter && header.IsV6UnicastAddress(e.neigh.Addr) { + if e.mu.isRouter && !flags.IsRouter && header.IsV6UnicastAddress(e.mu.neigh.Addr) { // "In those cases where the IsRouter flag changes from TRUE to FALSE as // a result of this update, the node MUST remove that router from the // Default Router List and update the Destination Cache entries for all @@ -505,16 +567,16 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla } if ndpEP, ok := ep.(NDPEndpoint); ok { - ndpEP.InvalidateDefaultRouter(e.neigh.Addr) + ndpEP.InvalidateDefaultRouter(e.mu.neigh.Addr) } } - e.isRouter = flags.IsRouter + e.mu.isRouter = flags.IsRouter - case Unknown, Failed, Static: + case Unknown, Unreachable, Static: // Do nothing default: - panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) + panic(fmt.Sprintf("Invalid cache entry state: %s", e.mu.neigh.State)) } } @@ -523,19 +585,19 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla // // Precondition: e.mu MUST be locked. func (e *neighborEntry) handleUpperLevelConfirmationLocked() { - switch e.neigh.State { + switch e.mu.neigh.State { case Reachable, Stale, Delay, Probe: - wasReachable := e.neigh.State == Reachable + wasReachable := e.mu.neigh.State == Reachable // Set state to Reachable again to refresh timers. e.setStateLocked(Reachable) if !wasReachable { e.dispatchChangeEventLocked() } - case Unknown, Incomplete, Failed, Static: + case Unknown, Incomplete, Unreachable, Static: // Do nothing default: - panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) + panic(fmt.Sprintf("Invalid cache entry state: %s", e.mu.neigh.State)) } } diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index 57cfbdb8b..47a9e2448 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -70,38 +70,39 @@ func eventDiffOptsWithSort() []cmp.Option { } // The following unit tests exercise every state transition and verify its -// behavior with RFC 4681. +// behavior with RFC 4681 and RFC 7048. // -// | From | To | Cause | Update | Action | Event | -// | ========== | ========== | ========================================== | ======== | ===========| ======= | -// | Unknown | Unknown | Confirmation w/ unknown address | | | Added | -// | Unknown | Incomplete | Packet queued to unknown address | | Send probe | Added | -// | Unknown | Stale | Probe | | | Added | -// | Incomplete | Incomplete | Retransmit timer expired | | Send probe | Changed | -// | Incomplete | Reachable | Solicited confirmation | LinkAddr | Notify | Changed | -// | Incomplete | Stale | Unsolicited confirmation | LinkAddr | Notify | Changed | -// | Incomplete | Stale | Probe | LinkAddr | Notify | Changed | -// | Incomplete | Failed | Max probes sent without reply | | Notify | Removed | -// | Reachable | Reachable | Confirmation w/ different isRouter flag | IsRouter | | | -// | Reachable | Stale | Reachable timer expired | | | Changed | -// | Reachable | Stale | Probe or confirmation w/ different address | | | Changed | -// | Stale | Reachable | Solicited override confirmation | LinkAddr | | Changed | -// | Stale | Reachable | Solicited confirmation w/o address | | Notify | Changed | -// | Stale | Stale | Override confirmation | LinkAddr | | Changed | -// | Stale | Stale | Probe w/ different address | LinkAddr | | Changed | -// | Stale | Delay | Packet sent | | | Changed | -// | Delay | Reachable | Upper-layer confirmation | | | Changed | -// | Delay | Reachable | Solicited override confirmation | LinkAddr | | Changed | -// | Delay | Reachable | Solicited confirmation w/o address | | Notify | Changed | -// | Delay | Stale | Probe or confirmation w/ different address | | | Changed | -// | Delay | Probe | Delay timer expired | | Send probe | Changed | -// | Probe | Reachable | Solicited override confirmation | LinkAddr | | Changed | -// | Probe | Reachable | Solicited confirmation w/ same address | | Notify | Changed | -// | Probe | Reachable | Solicited confirmation w/o address | | Notify | Changed | -// | Probe | Stale | Probe or confirmation w/ different address | | | Changed | -// | Probe | Probe | Retransmit timer expired | | | Changed | -// | Probe | Failed | Max probes sent without reply | | Notify | Removed | -// | Failed | Incomplete | Packet queued | | Send probe | Added | +// | From | To | Cause | Update | Action | Event | +// | =========== | =========== | ========================================== | ======== | ===========| ======= | +// | Unknown | Unknown | Confirmation w/ unknown address | | | Added | +// | Unknown | Incomplete | Packet queued to unknown address | | Send probe | Added | +// | Unknown | Stale | Probe | | | Added | +// | Incomplete | Incomplete | Retransmit timer expired | | Send probe | Changed | +// | Incomplete | Reachable | Solicited confirmation | LinkAddr | Notify | Changed | +// | Incomplete | Stale | Unsolicited confirmation | LinkAddr | Notify | Changed | +// | Incomplete | Stale | Probe | LinkAddr | Notify | Changed | +// | Incomplete | Unreachable | Max probes sent without reply | | Notify | Changed | +// | Reachable | Reachable | Confirmation w/ different isRouter flag | IsRouter | | | +// | Reachable | Stale | Reachable timer expired | | | Changed | +// | Reachable | Stale | Probe or confirmation w/ different address | | | Changed | +// | Stale | Reachable | Solicited override confirmation | LinkAddr | | Changed | +// | Stale | Reachable | Solicited confirmation w/o address | | Notify | Changed | +// | Stale | Stale | Override confirmation | LinkAddr | | Changed | +// | Stale | Stale | Probe w/ different address | LinkAddr | | Changed | +// | Stale | Delay | Packet sent | | | Changed | +// | Delay | Reachable | Upper-layer confirmation | | | Changed | +// | Delay | Reachable | Solicited override confirmation | LinkAddr | | Changed | +// | Delay | Reachable | Solicited confirmation w/o address | | Notify | Changed | +// | Delay | Stale | Probe or confirmation w/ different address | | | Changed | +// | Delay | Probe | Delay timer expired | | Send probe | Changed | +// | Probe | Reachable | Solicited override confirmation | LinkAddr | | Changed | +// | Probe | Reachable | Solicited confirmation w/ same address | | Notify | Changed | +// | Probe | Reachable | Solicited confirmation w/o address | | Notify | Changed | +// | Probe | Stale | Probe or confirmation w/ different address | | | Changed | +// | Probe | Probe | Retransmit timer expired | | | Changed | +// | Probe | Unreachable | Max probes sent without reply | | Notify | Changed | +// | Unreachable | Incomplete | Packet queued | | Send probe | Changed | +// | Unreachable | Stale | Probe w/ different address | LinkAddr | | Changed | type testEntryEventType uint8 @@ -220,13 +221,15 @@ func (r *entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumbe func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *entryTestLinkResolver, *faketime.ManualClock) { clock := faketime.NewManualClock() disp := testNUDDispatcher{} - nic := NIC{ + nic := nic{ LinkEndpoint: nil, // entryTestLinkResolver doesn't use a LinkEndpoint id: entryTestNICID, stack: &Stack{ - clock: clock, - nudDisp: &disp, + clock: clock, + nudDisp: &disp, + nudConfigs: c, + randomGenerator: rand.New(rand.NewSource(time.Now().UnixNano())), }, stats: makeNICStats(), } @@ -235,23 +238,18 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e header.IPv6ProtocolNumber: netEP, } - rng := rand.New(rand.NewSource(time.Now().UnixNano())) - nudState := NewNUDState(c, rng) var linkRes entryTestLinkResolver // Stub out the neighbor cache to verify deletion from the cache. - neigh := &neighborCache{ - nic: &nic, - state: nudState, - linkRes: &linkRes, - cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), - } - l := linkResolver{ - resolver: &linkRes, - neighborTable: neigh, - } - entry := newNeighborEntry(neigh, entryTestAddr1 /* remoteAddr */, nudState) - neigh.cache[entryTestAddr1] = entry - nic.linkAddrResolvers = map[tcpip.NetworkProtocolNumber]linkResolver{ + l := &linkResolver{ + resolver: &linkRes, + } + l.neigh.init(&nic, &linkRes) + + entry := newNeighborEntry(&l.neigh, entryTestAddr1 /* remoteAddr */, l.neigh.state) + l.neigh.mu.Lock() + l.neigh.mu.cache[entryTestAddr1] = entry + l.neigh.mu.Unlock() + nic.linkAddrResolvers = map[tcpip.NetworkProtocolNumber]*linkResolver{ header.IPv6ProtocolNumber: l, } @@ -265,8 +263,8 @@ func TestEntryInitiallyUnknown(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - if e.neigh.State != Unknown { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Unknown) + if e.mu.neigh.State != Unknown { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unknown) } e.mu.Unlock() @@ -298,8 +296,8 @@ func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) { Override: false, IsRouter: false, }) - if e.neigh.State != Unknown { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Unknown) + if e.mu.neigh.State != Unknown { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unknown) } e.mu.Unlock() @@ -327,8 +325,8 @@ func TestEntryUnknownToIncomplete(t *testing.T) { e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - if e.neigh.State != Incomplete { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) + if e.mu.neigh.State != Incomplete { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete) } e.mu.Unlock() @@ -374,8 +372,8 @@ func TestEntryUnknownToStale(t *testing.T) { e.mu.Lock() e.handleProbeLocked(entryTestLinkAddr1) - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } e.mu.Unlock() @@ -413,10 +411,10 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - if e.neigh.State != Incomplete { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) + if e.mu.neigh.State != Incomplete { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete) } - updatedAtNanos := e.neigh.UpdatedAtNanos + updatedAtNanos := e.mu.neigh.UpdatedAtNanos e.mu.Unlock() clock.Advance(c.RetransmitTimer) @@ -443,15 +441,15 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { } e.mu.Lock() - if got, want := e.neigh.UpdatedAtNanos, updatedAtNanos; got != want { - t.Errorf("got e.neigh.UpdatedAt = %q, want = %q", got, want) + if got, want := e.mu.neigh.UpdatedAtNanos, updatedAtNanos; got != want { + t.Errorf("got e.mu.neigh.UpdatedAt = %q, want = %q", got, want) } e.mu.Unlock() clock.Advance(c.RetransmitTimer) // UpdatedAt should change after failing address resolution. Timing out after - // sending the last probe transitions the entry to Failed. + // sending the last probe transitions the entry to Unreachable. { wantProbes := []entryTestProbeInfo{ { @@ -481,12 +479,12 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { }, }, { - EventType: entryTestRemoved, + EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ Addr: entryTestAddr1, LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + State: Unreachable, }, }, } @@ -497,8 +495,8 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { nudDisp.mu.Unlock() e.mu.Lock() - if got, notWant := e.neigh.UpdatedAtNanos, updatedAtNanos; got == notWant { - t.Errorf("expected e.neigh.UpdatedAt to change, got = %q", got) + if got, notWant := e.mu.neigh.UpdatedAtNanos, updatedAtNanos; got == notWant { + t.Errorf("expected e.mu.neigh.UpdatedAt to change, got = %q", got) } e.mu.Unlock() } @@ -509,8 +507,8 @@ func TestEntryIncompleteToReachable(t *testing.T) { e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - if e.neigh.State != Incomplete { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) + if e.mu.neigh.State != Incomplete { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete) } e.mu.Unlock() @@ -535,8 +533,8 @@ func TestEntryIncompleteToReachable(t *testing.T) { Override: false, IsRouter: false, }) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + if e.mu.neigh.State != Reachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) } e.mu.Unlock() @@ -573,8 +571,8 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - if e.neigh.State != Incomplete { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) + if e.mu.neigh.State != Incomplete { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete) } e.mu.Unlock() @@ -599,11 +597,11 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { Override: false, IsRouter: true, }) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + if e.mu.neigh.State != Reachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) } - if !e.isRouter { - t.Errorf("got e.isRouter = %t, want = true", e.isRouter) + if !e.mu.isRouter { + t.Errorf("got e.mu.isRouter = %t, want = true", e.mu.isRouter) } e.mu.Unlock() @@ -640,8 +638,8 @@ func TestEntryIncompleteToStaleWhenUnsolicitedConfirmation(t *testing.T) { e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - if e.neigh.State != Incomplete { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) + if e.mu.neigh.State != Incomplete { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete) } e.mu.Unlock() @@ -666,8 +664,8 @@ func TestEntryIncompleteToStaleWhenUnsolicitedConfirmation(t *testing.T) { Override: false, IsRouter: false, }) - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } e.mu.Unlock() @@ -704,8 +702,8 @@ func TestEntryIncompleteToStaleWhenProbe(t *testing.T) { e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - if e.neigh.State != Incomplete { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) + if e.mu.neigh.State != Incomplete { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete) } e.mu.Unlock() @@ -726,8 +724,8 @@ func TestEntryIncompleteToStaleWhenProbe(t *testing.T) { e.mu.Lock() e.handleProbeLocked(entryTestLinkAddr1) - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } e.mu.Unlock() @@ -758,15 +756,15 @@ func TestEntryIncompleteToStaleWhenProbe(t *testing.T) { nudDisp.mu.Unlock() } -func TestEntryIncompleteToFailed(t *testing.T) { +func TestEntryIncompleteToUnreachable(t *testing.T) { c := DefaultNUDConfigurations() c.MaxMulticastProbes = 3 e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - if e.neigh.State != Incomplete { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) + if e.mu.neigh.State != Incomplete { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete) } e.mu.Unlock() @@ -810,12 +808,12 @@ func TestEntryIncompleteToFailed(t *testing.T) { }, }, { - EventType: entryTestRemoved, + EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ Addr: entryTestAddr1, LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + State: Unreachable, }, }, } @@ -826,8 +824,8 @@ func TestEntryIncompleteToFailed(t *testing.T) { nudDisp.mu.Unlock() e.mu.Lock() - if e.neigh.State != Failed { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Failed) + if e.mu.neigh.State != Unreachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unreachable) } e.mu.Unlock() } @@ -870,11 +868,11 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { Override: false, IsRouter: true, }) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + if e.mu.neigh.State != Reachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) } - if got, want := e.isRouter, true; got != want { - t.Errorf("got e.isRouter = %t, want = %t", got, want) + if got, want := e.mu.isRouter, true; got != want { + t.Errorf("got e.mu.isRouter = %t, want = %t", got, want) } e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ @@ -882,11 +880,11 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { Override: false, IsRouter: false, }) - if got, want := e.isRouter, false; got != want { - t.Errorf("got e.isRouter = %t, want = %t", got, want) + if got, want := e.mu.isRouter, false; got != want { + t.Errorf("got e.mu.isRouter = %t, want = %t", got, want) } - if ipv6EP.invalidatedRtr != e.neigh.Addr { - t.Errorf("got ipv6EP.invalidatedRtr = %s, want = %s", ipv6EP.invalidatedRtr, e.neigh.Addr) + if ipv6EP.invalidatedRtr != e.mu.neigh.Addr { + t.Errorf("got ipv6EP.invalidatedRtr = %s, want = %s", ipv6EP.invalidatedRtr, e.mu.neigh.Addr) } e.mu.Unlock() @@ -917,8 +915,8 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { nudDisp.mu.Unlock() e.mu.Lock() - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + if e.mu.neigh.State != Reachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) } e.mu.Unlock() } @@ -952,15 +950,15 @@ func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) { Override: false, IsRouter: false, }) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + if e.mu.neigh.State != Reachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) } e.handleProbeLocked(entryTestLinkAddr1) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + if e.mu.neigh.State != Reachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) } - if e.neigh.LinkAddr != entryTestLinkAddr1 { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1) + if e.mu.neigh.LinkAddr != entryTestLinkAddr1 { + t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr1) } e.mu.Unlock() @@ -1025,8 +1023,8 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { Override: false, IsRouter: false, }) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + if e.mu.neigh.State != Reachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) } e.mu.Unlock() @@ -1068,8 +1066,8 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { nudDisp.mu.Unlock() e.mu.Lock() - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } e.mu.Unlock() } @@ -1103,12 +1101,12 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { Override: false, IsRouter: false, }) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + if e.mu.neigh.State != Reachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) } e.handleProbeLocked(entryTestLinkAddr2) - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } e.mu.Unlock() @@ -1177,16 +1175,16 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) Override: false, IsRouter: false, }) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + if e.mu.neigh.State != Reachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) } e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } e.mu.Unlock() @@ -1255,16 +1253,16 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t Override: false, IsRouter: false, }) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + if e.mu.neigh.State != Reachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) } e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ Solicited: false, Override: true, IsRouter: false, }) - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } e.mu.Unlock() @@ -1333,15 +1331,15 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) { Override: false, IsRouter: false, }) - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } e.handleProbeLocked(entryTestLinkAddr1) - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } - if e.neigh.LinkAddr != entryTestLinkAddr1 { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1) + if e.mu.neigh.LinkAddr != entryTestLinkAddr1 { + t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr1) } e.mu.Unlock() @@ -1401,19 +1399,19 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { Override: false, IsRouter: false, }) - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ Solicited: true, Override: true, IsRouter: false, }) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + if e.mu.neigh.State != Reachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) } - if e.neigh.LinkAddr != entryTestLinkAddr2 { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2) + if e.mu.neigh.LinkAddr != entryTestLinkAddr2 { + t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr2) } e.mu.Unlock() @@ -1482,19 +1480,19 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing Override: false, IsRouter: false, }) - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } e.handleConfirmationLocked("" /* linkAddr */, ReachabilityConfirmationFlags{ Solicited: true, Override: false, IsRouter: false, }) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + if e.mu.neigh.State != Reachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) } - if e.neigh.LinkAddr != entryTestLinkAddr1 { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1) + if e.mu.neigh.LinkAddr != entryTestLinkAddr1 { + t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr1) } e.mu.Unlock() @@ -1563,19 +1561,19 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { Override: false, IsRouter: false, }) - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ Solicited: false, Override: true, IsRouter: false, }) - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } - if e.neigh.LinkAddr != entryTestLinkAddr2 { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2) + if e.mu.neigh.LinkAddr != entryTestLinkAddr2 { + t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr2) } e.mu.Unlock() @@ -1644,15 +1642,15 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { Override: false, IsRouter: false, }) - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } e.handleProbeLocked(entryTestLinkAddr2) - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } - if e.neigh.LinkAddr != entryTestLinkAddr2 { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2) + if e.mu.neigh.LinkAddr != entryTestLinkAddr2 { + t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr2) } e.mu.Unlock() @@ -1721,12 +1719,12 @@ func TestEntryStaleToDelay(t *testing.T) { Override: false, IsRouter: false, }) - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } e.handlePacketQueuedLocked(entryTestAddr2) - if e.neigh.State != Delay { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Delay { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } e.mu.Unlock() @@ -1801,12 +1799,12 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { IsRouter: false, }) e.handlePacketQueuedLocked(entryTestAddr2) - if e.neigh.State != Delay { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + if e.mu.neigh.State != Delay { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Delay) } e.handleUpperLevelConfirmationLocked() - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + if e.mu.neigh.State != Reachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) } e.mu.Unlock() @@ -1901,19 +1899,19 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { IsRouter: false, }) e.handlePacketQueuedLocked(entryTestAddr2) - if e.neigh.State != Delay { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + if e.mu.neigh.State != Delay { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Delay) } e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ Solicited: true, Override: true, IsRouter: false, }) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + if e.mu.neigh.State != Reachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) } - if e.neigh.LinkAddr != entryTestLinkAddr2 { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2) + if e.mu.neigh.LinkAddr != entryTestLinkAddr2 { + t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr2) } e.mu.Unlock() @@ -2008,19 +2006,19 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing IsRouter: false, }) e.handlePacketQueuedLocked(entryTestAddr2) - if e.neigh.State != Delay { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + if e.mu.neigh.State != Delay { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Delay) } e.handleConfirmationLocked("" /* linkAddr */, ReachabilityConfirmationFlags{ Solicited: true, Override: false, IsRouter: false, }) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + if e.mu.neigh.State != Reachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) } - if e.neigh.LinkAddr != entryTestLinkAddr1 { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1) + if e.mu.neigh.LinkAddr != entryTestLinkAddr1 { + t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr1) } e.mu.Unlock() @@ -2109,19 +2107,19 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { IsRouter: false, }) e.handlePacketQueuedLocked(entryTestAddr2) - if e.neigh.State != Delay { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + if e.mu.neigh.State != Delay { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Delay) } e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: true, IsRouter: false, }) - if e.neigh.State != Delay { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + if e.mu.neigh.State != Delay { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Delay) } - if e.neigh.LinkAddr != entryTestLinkAddr1 { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1) + if e.mu.neigh.LinkAddr != entryTestLinkAddr1 { + t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", e.mu.neigh.LinkAddr, entryTestLinkAddr1) } e.mu.Unlock() @@ -2191,12 +2189,12 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { IsRouter: false, }) e.handlePacketQueuedLocked(entryTestAddr2) - if e.neigh.State != Delay { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + if e.mu.neigh.State != Delay { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Delay) } e.handleProbeLocked(entryTestLinkAddr2) - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } e.mu.Unlock() @@ -2275,16 +2273,16 @@ func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { IsRouter: false, }) e.handlePacketQueuedLocked(entryTestAddr2) - if e.neigh.State != Delay { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + if e.mu.neigh.State != Delay { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Delay) } e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ Solicited: false, Override: true, IsRouter: false, }) - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } e.mu.Unlock() @@ -2366,8 +2364,8 @@ func TestEntryDelayToProbe(t *testing.T) { IsRouter: false, }) e.handlePacketQueuedLocked(entryTestAddr2) - if e.neigh.State != Delay { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + if e.mu.neigh.State != Delay { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Delay) } e.mu.Unlock() @@ -2432,8 +2430,8 @@ func TestEntryDelayToProbe(t *testing.T) { nudDisp.mu.Unlock() e.mu.Lock() - if e.neigh.State != Probe { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) + if e.mu.neigh.State != Probe { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe) } e.mu.Unlock() } @@ -2490,12 +2488,12 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { } e.mu.Lock() - if e.neigh.State != Probe { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) + if e.mu.neigh.State != Probe { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe) } e.handleProbeLocked(entryTestLinkAddr2) - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } e.mu.Unlock() @@ -2605,16 +2603,16 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { } e.mu.Lock() - if e.neigh.State != Probe { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) + if e.mu.neigh.State != Probe { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe) } e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ Solicited: false, Override: true, IsRouter: false, }) - if e.neigh.State != Stale { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) } e.mu.Unlock() @@ -2725,19 +2723,19 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { } e.mu.Lock() - if e.neigh.State != Probe { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) + if e.mu.neigh.State != Probe { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe) } e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: true, IsRouter: false, }) - if e.neigh.State != Probe { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) + if e.mu.neigh.State != Probe { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe) } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + if got, want := e.mu.neigh.LinkAddr, entryTestLinkAddr1; got != want { + t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", got, want) } e.mu.Unlock() @@ -2821,19 +2819,19 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { } e.mu.Lock() - if e.neigh.State != Probe { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) + if e.mu.neigh.State != Probe { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe) } e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ Solicited: true, Override: true, IsRouter: false, }) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + if e.mu.neigh.State != Reachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + if got, want := e.mu.neigh.LinkAddr, entryTestLinkAddr2; got != want { + t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", got, want) } e.mu.Unlock() @@ -2949,19 +2947,19 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { } e.mu.Lock() - if e.neigh.State != Probe { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) + if e.mu.neigh.State != Probe { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe) } e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ Solicited: true, Override: true, IsRouter: false, }) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + if e.mu.neigh.State != Reachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + if got, want := e.mu.neigh.LinkAddr, entryTestLinkAddr2; got != want { + t.Errorf("got e.mu.neigh.LinkAddr = %q, want = %q", got, want) } e.mu.Unlock() @@ -3086,16 +3084,16 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin } e.mu.Lock() - if e.neigh.State != Probe { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) + if e.mu.neigh.State != Probe { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe) } e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: true, Override: false, IsRouter: false, }) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + if e.mu.neigh.State != Reachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) } e.mu.Unlock() @@ -3220,16 +3218,16 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing } e.mu.Lock() - if e.neigh.State != Probe { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) + if e.mu.neigh.State != Probe { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe) } e.handleConfirmationLocked("" /* linkAddr */, ReachabilityConfirmationFlags{ Solicited: true, Override: false, IsRouter: false, }) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + if e.mu.neigh.State != Reachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Reachable) } e.mu.Unlock() @@ -3297,7 +3295,7 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing nudDisp.mu.Unlock() } -func TestEntryProbeToFailed(t *testing.T) { +func TestEntryProbeToUnreachable(t *testing.T) { c := DefaultNUDConfigurations() c.MaxMulticastProbes = 3 c.MaxUnicastProbes = 3 @@ -3352,17 +3350,17 @@ func TestEntryProbeToFailed(t *testing.T) { } e.mu.Lock() - if e.neigh.State != Probe { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) + if e.mu.neigh.State != Probe { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Probe) } e.mu.Unlock() } - // Wait for the last probe to expire, causing a transition to Failed. + // Wait for the last probe to expire, causing a transition to Unreachable. clock.Advance(c.RetransmitTimer) e.mu.Lock() - if e.neigh.State != Failed { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Failed) + if e.mu.neigh.State != Unreachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unreachable) } e.mu.Unlock() @@ -3404,12 +3402,12 @@ func TestEntryProbeToFailed(t *testing.T) { }, }, { - EventType: entryTestRemoved, + EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ Addr: entryTestAddr1, LinkAddr: entryTestLinkAddr1, - State: Probe, + State: Unreachable, }, }, } @@ -3420,7 +3418,7 @@ func TestEntryProbeToFailed(t *testing.T) { nudDisp.mu.Unlock() } -func TestEntryFailedToIncomplete(t *testing.T) { +func TestEntryUnreachableToIncomplete(t *testing.T) { c := DefaultNUDConfigurations() c.MaxMulticastProbes = 3 e, nudDisp, linkRes, clock := entryTestSetup(c) @@ -3429,8 +3427,8 @@ func TestEntryFailedToIncomplete(t *testing.T) { // their expected state. e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - if e.neigh.State != Incomplete { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) + if e.mu.neigh.State != Incomplete { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete) } e.mu.Unlock() @@ -3464,15 +3462,15 @@ func TestEntryFailedToIncomplete(t *testing.T) { } e.mu.Lock() - if e.neigh.State != Failed { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Failed) + if e.mu.neigh.State != Unreachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unreachable) } e.mu.Unlock() e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - if e.neigh.State != Incomplete { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) + if e.mu.neigh.State != Incomplete { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete) } e.mu.Unlock() @@ -3487,7 +3485,16 @@ func TestEntryFailedToIncomplete(t *testing.T) { }, }, { - EventType: entryTestRemoved, + EventType: entryTestChanged, + NICID: entryTestNICID, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Unreachable, + }, + }, + { + EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ Addr: entryTestAddr1, @@ -3495,6 +3502,72 @@ func TestEntryFailedToIncomplete(t *testing.T) { State: Incomplete, }, }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryUnreachableToStale(t *testing.T) { + wantProbes := []entryTestProbeInfo{ + // The Incomplete-to-Incomplete state transition is tested here by + // verifying that 3 reachability probes were sent. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + + c := DefaultNUDConfigurations() + c.MaxMulticastProbes = uint32(len(wantProbes)) + e, nudDisp, linkRes, clock := entryTestSetup(c) + + // TODO(gvisor.dev/issue/4872): Use helper functions to start entry tests in + // their expected state. + e.mu.Lock() + e.handlePacketQueuedLocked(entryTestAddr2) + if e.mu.neigh.State != Incomplete { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Incomplete) + } + e.mu.Unlock() + + waitFor := c.RetransmitTimer * time.Duration(c.MaxMulticastProbes) + clock.Advance(waitFor) + + linkRes.mu.Lock() + diff := cmp.Diff(wantProbes, linkRes.probes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) + } + + e.mu.Lock() + if e.mu.neigh.State != Unreachable { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Unreachable) + } + e.mu.Unlock() + + e.mu.Lock() + e.handleProbeLocked(entryTestLinkAddr2) + if e.mu.neigh.State != Stale { + t.Errorf("got e.mu.neigh.State = %q, want = %q", e.mu.neigh.State, Stale) + } + e.mu.Unlock() + + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, @@ -3504,6 +3577,24 @@ func TestEntryFailedToIncomplete(t *testing.T) { State: Incomplete, }, }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Unreachable, + }, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + }, } nudDisp.mu.Lock() if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { diff --git a/pkg/tcpip/stack/neighborstate_string.go b/pkg/tcpip/stack/neighborstate_string.go index aa7311ec6..765df4d7a 100644 --- a/pkg/tcpip/stack/neighborstate_string.go +++ b/pkg/tcpip/stack/neighborstate_string.go @@ -1,4 +1,4 @@ -// Copyright 2020 The gVisor Authors. +// Copyright 2021 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -30,11 +30,12 @@ func _() { _ = x[Probe-5] _ = x[Static-6] _ = x[Failed-7] + _ = x[Unreachable-8] } -const _NeighborState_name = "UnknownIncompleteReachableStaleDelayProbeStaticFailed" +const _NeighborState_name = "UnknownIncompleteReachableStaleDelayProbeStaticFailedUnreachable" -var _NeighborState_index = [...]uint8{0, 7, 17, 26, 31, 36, 41, 47, 53} +var _NeighborState_index = [...]uint8{0, 7, 17, 26, 31, 36, 41, 47, 53, 64} func (i NeighborState) String() string { if i >= NeighborState(len(_NeighborState_index)-1) { diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 41a489047..f66db16a7 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -24,40 +24,26 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" ) -type neighborTable interface { - neighbors() ([]NeighborEntry, tcpip.Error) - addStaticEntry(tcpip.Address, tcpip.LinkAddress) - get(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) - remove(tcpip.Address) tcpip.Error - removeAll() tcpip.Error - - handleProbe(tcpip.Address, tcpip.LinkAddress) - handleConfirmation(tcpip.Address, tcpip.LinkAddress, ReachabilityConfirmationFlags) - handleUpperLevelConfirmation(tcpip.Address) - - nudConfig() (NUDConfigurations, tcpip.Error) - setNUDConfig(NUDConfigurations) tcpip.Error -} - -var _ NetworkInterface = (*NIC)(nil) - type linkResolver struct { resolver LinkAddressResolver - neighborTable neighborTable + neigh neighborCache } func (l *linkResolver) getNeighborLinkAddress(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { - return l.neighborTable.get(addr, localAddr, onResolve) + entry, ch, err := l.neigh.entry(addr, localAddr, onResolve) + return entry.LinkAddr, ch, err } func (l *linkResolver) confirmReachable(addr tcpip.Address) { - l.neighborTable.handleUpperLevelConfirmation(addr) + l.neigh.handleUpperLevelConfirmation(addr) } -// NIC represents a "network interface card" to which the networking stack is +var _ NetworkInterface = (*nic)(nil) + +// nic represents a "network interface card" to which the networking stack is // attached. -type NIC struct { +type nic struct { LinkEndpoint stack *Stack @@ -69,8 +55,9 @@ type NIC struct { // The network endpoints themselves may be modified by calling the interface's // methods, but the map reference and entries must be constant. - networkEndpoints map[tcpip.NetworkProtocolNumber]NetworkEndpoint - linkAddrResolvers map[tcpip.NetworkProtocolNumber]linkResolver + networkEndpoints map[tcpip.NetworkProtocolNumber]NetworkEndpoint + linkAddrResolvers map[tcpip.NetworkProtocolNumber]*linkResolver + duplicateAddressDetectors map[tcpip.NetworkProtocolNumber]DuplicateAddressDetector // enabled is set to 1 when the NIC is enabled and 0 when it is disabled. // @@ -147,7 +134,7 @@ func (p *packetEndpointList) forEach(fn func(PacketEndpoint)) { } // newNIC returns a new NIC using the default NDP configurations from stack. -func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *NIC { +func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *nic { // TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For // example, make sure that the link address it provides is a valid // unicast ethernet address. @@ -156,16 +143,17 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC // observe an MTU of at least 1280 bytes. Ensure that this requirement // of IPv6 is supported on this endpoint's LinkEndpoint. - nic := &NIC{ + nic := &nic{ LinkEndpoint: ep, - stack: stack, - id: id, - name: name, - context: ctx, - stats: makeNICStats(), - networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), - linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]linkResolver), + stack: stack, + id: id, + name: name, + context: ctx, + stats: makeNICStats(), + networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), + linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]*linkResolver), + duplicateAddressDetectors: make(map[tcpip.NetworkProtocolNumber]DuplicateAddressDetector), } nic.linkResQueue.init(nic) nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList) @@ -185,26 +173,15 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC if resolutionRequired { if r, ok := netEP.(LinkAddressResolver); ok { - l := linkResolver{ - resolver: r, - } - - if stack.useNeighborCache { - l.neighborTable = &neighborCache{ - nic: nic, - state: NewNUDState(stack.nudConfigs, stack.randomGenerator), - linkRes: r, - - cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), - } - } else { - cache := new(linkAddrCache) - cache.init(nic, ageLimit, resolutionTimeout, resolutionAttempts, r) - l.neighborTable = cache - } + l := &linkResolver{resolver: r} + l.neigh.init(nic, r) nic.linkAddrResolvers[r.LinkAddressProtocol()] = l } } + + if d, ok := netEP.(DuplicateAddressDetector); ok { + nic.duplicateAddressDetectors[d.DuplicateAddressProtocol()] = d + } } nic.LinkEndpoint.Attach(nic) @@ -212,19 +189,19 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC return nic } -func (n *NIC) getNetworkEndpoint(proto tcpip.NetworkProtocolNumber) NetworkEndpoint { +func (n *nic) getNetworkEndpoint(proto tcpip.NetworkProtocolNumber) NetworkEndpoint { return n.networkEndpoints[proto] } // Enabled implements NetworkInterface. -func (n *NIC) Enabled() bool { +func (n *nic) Enabled() bool { return atomic.LoadUint32(&n.enabled) == 1 } // setEnabled sets the enabled status for the NIC. // // Returns true if the enabled status was updated. -func (n *NIC) setEnabled(v bool) bool { +func (n *nic) setEnabled(v bool) bool { if v { return atomic.SwapUint32(&n.enabled, 1) == 0 } @@ -234,7 +211,7 @@ func (n *NIC) setEnabled(v bool) bool { // disable disables n. // // It undoes the work done by enable. -func (n *NIC) disable() { +func (n *nic) disable() { n.mu.Lock() n.disableLocked() n.mu.Unlock() @@ -245,7 +222,7 @@ func (n *NIC) disable() { // It undoes the work done by enable. // // n MUST be locked. -func (n *NIC) disableLocked() { +func (n *nic) disableLocked() { if !n.Enabled() { return } @@ -283,7 +260,7 @@ func (n *NIC) disableLocked() { // address (ff02::1), start DAD for permanent addresses, and start soliciting // routers if the stack is not operating as a router. If the stack is also // configured to auto-generate a link-local address, one will be generated. -func (n *NIC) enable() tcpip.Error { +func (n *nic) enable() tcpip.Error { n.mu.Lock() defer n.mu.Unlock() @@ -303,7 +280,7 @@ func (n *NIC) enable() tcpip.Error { // remove detaches NIC from the link endpoint and releases network endpoint // resources. This guarantees no packets between this NIC and the network // stack. -func (n *NIC) remove() tcpip.Error { +func (n *nic) remove() tcpip.Error { n.mu.Lock() defer n.mu.Unlock() @@ -319,14 +296,14 @@ func (n *NIC) remove() tcpip.Error { } // setPromiscuousMode enables or disables promiscuous mode. -func (n *NIC) setPromiscuousMode(enable bool) { +func (n *nic) setPromiscuousMode(enable bool) { n.mu.Lock() n.mu.promiscuous = enable n.mu.Unlock() } // Promiscuous implements NetworkInterface. -func (n *NIC) Promiscuous() bool { +func (n *nic) Promiscuous() bool { n.mu.RLock() rv := n.mu.promiscuous n.mu.RUnlock() @@ -334,17 +311,17 @@ func (n *NIC) Promiscuous() bool { } // IsLoopback implements NetworkInterface. -func (n *NIC) IsLoopback() bool { +func (n *nic) IsLoopback() bool { return n.LinkEndpoint.Capabilities()&CapabilityLoopback != 0 } // WritePacket implements NetworkLinkEndpoint. -func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { +func (n *nic) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { _, err := n.enqueuePacketBuffer(r, gso, protocol, pkt) return err } -func (n *NIC) writePacketBuffer(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) { +func (n *nic) writePacketBuffer(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) { switch pkt := pkt.(type) { case *PacketBuffer: if err := n.writePacket(r, gso, protocol, pkt); err != nil { @@ -358,7 +335,7 @@ func (n *NIC) writePacketBuffer(r RouteInfo, gso *GSO, protocol tcpip.NetworkPro } } -func (n *NIC) enqueuePacketBuffer(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) { +func (n *nic) enqueuePacketBuffer(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) { routeInfo, _, err := r.resolvedFields(nil) switch err.(type) { case nil: @@ -388,14 +365,14 @@ func (n *NIC) enqueuePacketBuffer(r *Route, gso *GSO, protocol tcpip.NetworkProt } // WritePacketToRemote implements NetworkInterface. -func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { +func (n *nic) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { var r RouteInfo r.NetProto = protocol r.RemoteLinkAddress = remoteLinkAddr return n.writePacket(r, gso, protocol, pkt) } -func (n *NIC) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { +func (n *nic) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { // WritePacket takes ownership of pkt, calculate numBytes first. numBytes := pkt.Size() @@ -412,11 +389,11 @@ func (n *NIC) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolN } // WritePackets implements NetworkLinkEndpoint. -func (n *NIC) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { +func (n *nic) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { return n.enqueuePacketBuffer(r, gso, protocol, &pkts) } -func (n *NIC) writePackets(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkts PacketBufferList) (int, tcpip.Error) { +func (n *nic) writePackets(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkts PacketBufferList) (int, tcpip.Error) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { pkt.EgressRoute = r pkt.GSOOptions = gso @@ -435,15 +412,22 @@ func (n *NIC) writePackets(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocol } // setSpoofing enables or disables address spoofing. -func (n *NIC) setSpoofing(enable bool) { +func (n *nic) setSpoofing(enable bool) { n.mu.Lock() n.mu.spoofing = enable n.mu.Unlock() } +// Spoofing implements NetworkInterface. +func (n *nic) Spoofing() bool { + n.mu.RLock() + defer n.mu.RUnlock() + return n.mu.spoofing +} + // primaryAddress returns an address that can be used to communicate with // remoteAddr. -func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) AssignableAddressEndpoint { +func (n *nic) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) AssignableAddressEndpoint { ep, ok := n.networkEndpoints[protocol] if !ok { return nil @@ -473,11 +457,11 @@ const ( promiscuous ) -func (n *NIC) getAddress(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) AssignableAddressEndpoint { +func (n *nic) getAddress(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) AssignableAddressEndpoint { return n.getAddressOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, promiscuous) } -func (n *NIC) hasAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { +func (n *nic) hasAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { ep := n.getAddressOrCreateTempInner(protocol, addr, false, NeverPrimaryEndpoint) if ep != nil { ep.DecRef() @@ -488,7 +472,7 @@ func (n *NIC) hasAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Addres } // findEndpoint finds the endpoint, if any, with the given address. -func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) AssignableAddressEndpoint { +func (n *nic) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) AssignableAddressEndpoint { return n.getAddressOrCreateTemp(protocol, address, peb, spoofing) } @@ -501,7 +485,7 @@ func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.A // // If the address is the IPv4 broadcast address for an endpoint's network, that // endpoint will be returned. -func (n *NIC) getAddressOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getAddressBehaviour) AssignableAddressEndpoint { +func (n *nic) getAddressOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getAddressBehaviour) AssignableAddressEndpoint { n.mu.RLock() var spoofingOrPromiscuous bool switch tempRef { @@ -516,7 +500,7 @@ func (n *NIC) getAddressOrCreateTemp(protocol tcpip.NetworkProtocolNumber, addre // getAddressOrCreateTempInner is like getAddressEpOrCreateTemp except a boolean // is passed to indicate whether or not we should generate temporary endpoints. -func (n *NIC) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, createTemp bool, peb PrimaryEndpointBehavior) AssignableAddressEndpoint { +func (n *nic) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, createTemp bool, peb PrimaryEndpointBehavior) AssignableAddressEndpoint { ep, ok := n.networkEndpoints[protocol] if !ok { return nil @@ -532,7 +516,7 @@ func (n *NIC) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber, // addAddress adds a new address to n, so that it starts accepting packets // targeted at the given address (and network protocol). -func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error { +func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error { ep, ok := n.networkEndpoints[protocolAddress.Protocol] if !ok { return &tcpip.ErrUnknownProtocol{} @@ -553,7 +537,7 @@ func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo // allPermanentAddresses returns all permanent addresses associated with // this NIC. -func (n *NIC) allPermanentAddresses() []tcpip.ProtocolAddress { +func (n *nic) allPermanentAddresses() []tcpip.ProtocolAddress { var addrs []tcpip.ProtocolAddress for p, ep := range n.networkEndpoints { addressableEndpoint, ok := ep.(AddressableEndpoint) @@ -569,7 +553,7 @@ func (n *NIC) allPermanentAddresses() []tcpip.ProtocolAddress { } // primaryAddresses returns the primary addresses associated with this NIC. -func (n *NIC) primaryAddresses() []tcpip.ProtocolAddress { +func (n *nic) primaryAddresses() []tcpip.ProtocolAddress { var addrs []tcpip.ProtocolAddress for p, ep := range n.networkEndpoints { addressableEndpoint, ok := ep.(AddressableEndpoint) @@ -589,7 +573,7 @@ func (n *NIC) primaryAddresses() []tcpip.ProtocolAddress { // primaryAddress will return the first non-deprecated address if such an // address exists. If no non-deprecated address exists, the first deprecated // address will be returned. -func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWithPrefix { +func (n *nic) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWithPrefix { ep, ok := n.networkEndpoints[proto] if !ok { return tcpip.AddressWithPrefix{} @@ -604,7 +588,7 @@ func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWit } // removeAddress removes an address from n. -func (n *NIC) removeAddress(addr tcpip.Address) tcpip.Error { +func (n *nic) removeAddress(addr tcpip.Address) tcpip.Error { for _, ep := range n.networkEndpoints { addressableEndpoint, ok := ep.(AddressableEndpoint) if !ok { @@ -622,7 +606,7 @@ func (n *NIC) removeAddress(addr tcpip.Address) tcpip.Error { return &tcpip.ErrBadLocalAddress{} } -func (n *NIC) getLinkAddress(addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(LinkResolutionResult)) tcpip.Error { +func (n *nic) getLinkAddress(addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(LinkResolutionResult)) tcpip.Error { linkRes, ok := n.linkAddrResolvers[protocol] if !ok { return &tcpip.ErrNotSupported{} @@ -637,34 +621,38 @@ func (n *NIC) getLinkAddress(addr, localAddr tcpip.Address, protocol tcpip.Netwo return err } -func (n *NIC) neighbors(protocol tcpip.NetworkProtocolNumber) ([]NeighborEntry, tcpip.Error) { +func (n *nic) neighbors(protocol tcpip.NetworkProtocolNumber) ([]NeighborEntry, tcpip.Error) { if linkRes, ok := n.linkAddrResolvers[protocol]; ok { - return linkRes.neighborTable.neighbors() + return linkRes.neigh.entries(), nil } return nil, &tcpip.ErrNotSupported{} } -func (n *NIC) addStaticNeighbor(addr tcpip.Address, protocol tcpip.NetworkProtocolNumber, linkAddress tcpip.LinkAddress) tcpip.Error { +func (n *nic) addStaticNeighbor(addr tcpip.Address, protocol tcpip.NetworkProtocolNumber, linkAddress tcpip.LinkAddress) tcpip.Error { if linkRes, ok := n.linkAddrResolvers[protocol]; ok { - linkRes.neighborTable.addStaticEntry(addr, linkAddress) + linkRes.neigh.addStaticEntry(addr, linkAddress) return nil } return &tcpip.ErrNotSupported{} } -func (n *NIC) removeNeighbor(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { +func (n *nic) removeNeighbor(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { if linkRes, ok := n.linkAddrResolvers[protocol]; ok { - return linkRes.neighborTable.remove(addr) + if !linkRes.neigh.removeEntry(addr) { + return &tcpip.ErrBadAddress{} + } + return nil } return &tcpip.ErrNotSupported{} } -func (n *NIC) clearNeighbors(protocol tcpip.NetworkProtocolNumber) tcpip.Error { +func (n *nic) clearNeighbors(protocol tcpip.NetworkProtocolNumber) tcpip.Error { if linkRes, ok := n.linkAddrResolvers[protocol]; ok { - return linkRes.neighborTable.removeAll() + linkRes.neigh.clear() + return nil } return &tcpip.ErrNotSupported{} @@ -672,7 +660,7 @@ func (n *NIC) clearNeighbors(protocol tcpip.NetworkProtocolNumber) tcpip.Error { // joinGroup adds a new endpoint for the given multicast address, if none // exists yet. Otherwise it just increments its count. -func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { +func (n *nic) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { // TODO(b/143102137): When implementing MLD, make sure MLD packets are // not sent unless a valid link-local address is available for use on n // as an MLD packet's source address must be a link-local address as @@ -693,7 +681,7 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address // leaveGroup decrements the count for the given multicast address, and when it // reaches zero removes the endpoint for this address. -func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { +func (n *nic) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { ep, ok := n.networkEndpoints[protocol] if !ok { return &tcpip.ErrNotSupported{} @@ -708,7 +696,7 @@ func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Addres } // isInGroup returns true if n has joined the multicast group addr. -func (n *NIC) isInGroup(addr tcpip.Address) bool { +func (n *nic) isInGroup(addr tcpip.Address) bool { for _, ep := range n.networkEndpoints { gep, ok := ep.(GroupAddressableEndpoint) if !ok { @@ -729,7 +717,7 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool { // Note that the ownership of the slice backing vv is retained by the caller. // This rule applies only to the slice itself, not to the items of the slice; // the ownership of the items is not retained by the caller. -func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { +func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { n.mu.RLock() enabled := n.Enabled() // If the NIC is not yet enabled, don't receive any packets. @@ -777,41 +765,11 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp anyEPs.forEach(deliverPacketEPs) } - // Parse headers. - netProto := n.stack.NetworkProtocolInstance(protocol) - transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt) - if !ok { - // The packet is too small to contain a network header. - n.stack.stats.MalformedRcvdPackets.Increment() - return - } - if hasTransportHdr { - pkt.TransportProtocolNumber = transProtoNum - // Parse the transport header if present. - if state, ok := n.stack.transportProtocols[transProtoNum]; ok { - state.proto.Parse(pkt) - } - } - - if n.stack.handleLocal && !n.IsLoopback() { - src, _ := netProto.ParseAddresses(pkt.NetworkHeader().View()) - if r := n.getAddress(protocol, src); r != nil { - r.DecRef() - - // The source address is one of our own, so we never should have gotten a - // packet like this unless handleLocal is false. Loopback also calls this - // function even though the packets didn't come from the physical interface - // so don't drop those. - n.stack.stats.IP.InvalidSourceAddressesReceived.Increment() - return - } - } - networkEndpoint.HandlePacket(pkt) } // DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket. -func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { +func (n *nic) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { n.mu.RLock() // We do not deliver to protocol specific packet endpoints as on Linux // only ETH_P_ALL endpoints get outbound packets. @@ -831,7 +789,7 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. -func (n *NIC) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition { +func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition { state, ok := n.stack.transportProtocols[protocol] if !ok { n.stack.stats.UnknownProtocolRcvdPackets.Increment() @@ -912,7 +870,7 @@ func (n *NIC) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt } // DeliverTransportError implements TransportDispatcher. -func (n *NIC) DeliverTransportError(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr TransportError, pkt *PacketBuffer) { +func (n *nic) DeliverTransportError(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr TransportError, pkt *PacketBuffer) { state, ok := n.stack.transportProtocols[trans] if !ok { return @@ -940,19 +898,19 @@ func (n *NIC) DeliverTransportError(local, remote tcpip.Address, net tcpip.Netwo } // ID implements NetworkInterface. -func (n *NIC) ID() tcpip.NICID { +func (n *nic) ID() tcpip.NICID { return n.id } // Name implements NetworkInterface. -func (n *NIC) Name() string { +func (n *nic) Name() string { return n.name } // nudConfigs gets the NUD configurations for n. -func (n *NIC) nudConfigs(protocol tcpip.NetworkProtocolNumber) (NUDConfigurations, tcpip.Error) { +func (n *nic) nudConfigs(protocol tcpip.NetworkProtocolNumber) (NUDConfigurations, tcpip.Error) { if linkRes, ok := n.linkAddrResolvers[protocol]; ok { - return linkRes.neighborTable.nudConfig() + return linkRes.neigh.config(), nil } return NUDConfigurations{}, &tcpip.ErrNotSupported{} @@ -962,16 +920,17 @@ func (n *NIC) nudConfigs(protocol tcpip.NetworkProtocolNumber) (NUDConfiguration // // Note, if c contains invalid NUD configuration values, it will be fixed to // use default values for the erroneous values. -func (n *NIC) setNUDConfigs(protocol tcpip.NetworkProtocolNumber, c NUDConfigurations) tcpip.Error { +func (n *nic) setNUDConfigs(protocol tcpip.NetworkProtocolNumber, c NUDConfigurations) tcpip.Error { if linkRes, ok := n.linkAddrResolvers[protocol]; ok { c.resetInvalidFields() - return linkRes.neighborTable.setNUDConfig(c) + linkRes.neigh.setConfig(c) + return nil } return &tcpip.ErrNotSupported{} } -func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) tcpip.Error { +func (n *nic) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) tcpip.Error { n.mu.Lock() defer n.mu.Unlock() @@ -984,7 +943,7 @@ func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep Pa return nil } -func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) { +func (n *nic) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) { n.mu.Lock() defer n.mu.Unlock() @@ -998,7 +957,7 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep // isValidForOutgoing returns true if the endpoint can be used to send out a // packet. It requires the endpoint to not be marked expired (i.e., its address // has been removed) unless the NIC is in spoofing mode, or temporary. -func (n *NIC) isValidForOutgoing(ep AssignableAddressEndpoint) bool { +func (n *nic) isValidForOutgoing(ep AssignableAddressEndpoint) bool { n.mu.RLock() spoofing := n.mu.spoofing n.mu.RUnlock() @@ -1006,9 +965,9 @@ func (n *NIC) isValidForOutgoing(ep AssignableAddressEndpoint) bool { } // HandleNeighborProbe implements NetworkInterface. -func (n *NIC) HandleNeighborProbe(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error { +func (n *nic) HandleNeighborProbe(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error { if l, ok := n.linkAddrResolvers[protocol]; ok { - l.neighborTable.handleProbe(addr, linkAddr) + l.neigh.handleProbe(addr, linkAddr) return nil } @@ -1016,11 +975,34 @@ func (n *NIC) HandleNeighborProbe(protocol tcpip.NetworkProtocolNumber, addr tcp } // HandleNeighborConfirmation implements NetworkInterface. -func (n *NIC) HandleNeighborConfirmation(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) tcpip.Error { +func (n *nic) HandleNeighborConfirmation(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) tcpip.Error { if l, ok := n.linkAddrResolvers[protocol]; ok { - l.neighborTable.handleConfirmation(addr, linkAddr, flags) + l.neigh.handleConfirmation(addr, linkAddr, flags) return nil } return &tcpip.ErrNotSupported{} } + +// CheckLocalAddress implements NetworkInterface. +func (n *nic) CheckLocalAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { + if n.Spoofing() { + return true + } + + if addressEndpoint := n.getAddressOrCreateTempInner(protocol, addr, false /* createTemp */, NeverPrimaryEndpoint); addressEndpoint != nil { + addressEndpoint.DecRef() + return true + } + + return false +} + +func (n *nic) checkDuplicateAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, h DADCompletionHandler) (DADCheckAddressDisposition, tcpip.Error) { + d, ok := n.duplicateAddressDetectors[protocol] + if !ok { + return 0, &tcpip.ErrNotSupported{} + } + + return d.CheckDuplicateAddress(addr, h), nil +} diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 9992d6eb4..c0f956e53 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -170,7 +170,7 @@ func (*testIPv6Protocol) Parse(*PacketBuffer) (tcpip.TransportProtocolNumber, bo func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { // When the NIC is disabled, the only field that matters is the stats field. // This test is limited to stats counter checks. - nic := NIC{ + nic := nic{ stats: makeNICStats(), } diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go index e9acef6a2..e1253f310 100644 --- a/pkg/tcpip/stack/nud_test.go +++ b/pkg/tcpip/stack/nud_test.go @@ -101,7 +101,6 @@ func TestNUDFunctions(t *testing.T) { clock := faketime.NewManualClock() s := stack.New(stack.Options{ NUDConfigs: stack.DefaultNUDConfigurations(), - UseNeighborCache: true, NetworkProtocols: test.netProtoFactory, Clock: clock, }) @@ -206,7 +205,6 @@ func TestDefaultNUDConfigurations(t *testing.T) { // address resolution is specified (e.g. ARP or IPv6). NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: stack.DefaultNUDConfigurations(), - UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -261,7 +259,6 @@ func TestNUDConfigurationsBaseReachableTime(t *testing.T) { // providing link address resolution is specified (e.g. ARP or IPv6). NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, - UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -318,7 +315,6 @@ func TestNUDConfigurationsMinRandomFactor(t *testing.T) { // providing link address resolution is specified (e.g. ARP or IPv6). NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, - UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -398,7 +394,6 @@ func TestNUDConfigurationsMaxRandomFactor(t *testing.T) { // providing link address resolution is specified (e.g. ARP or IPv6). NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, - UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -460,7 +455,6 @@ func TestNUDConfigurationsRetransmitTimer(t *testing.T) { // providing link address resolution is specified (e.g. ARP or IPv6). NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, - UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -512,7 +506,6 @@ func TestNUDConfigurationsDelayFirstProbeTime(t *testing.T) { // providing link address resolution is specified (e.g. ARP or IPv6). NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, - UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -564,7 +557,6 @@ func TestNUDConfigurationsMaxMulticastProbes(t *testing.T) { // providing link address resolution is specified (e.g. ARP or IPv6). NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, - UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -616,7 +608,6 @@ func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) { // providing link address resolution is specified (e.g. ARP or IPv6). NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, - UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go index 1c651e216..dc139ebb2 100644 --- a/pkg/tcpip/stack/pending_packets.go +++ b/pkg/tcpip/stack/pending_packets.go @@ -55,7 +55,7 @@ type pendingPacket struct { // // Once link resolution completes successfully, the packets will be written. type packetsPendingLinkResolution struct { - nic *NIC + nic *nic mu struct { sync.Mutex @@ -82,7 +82,7 @@ func (f *packetsPendingLinkResolution) incrementOutgoingPacketErrors(proto tcpip } } -func (f *packetsPendingLinkResolution) init(nic *NIC) { +func (f *packetsPendingLinkResolution) init(nic *nic) { f.mu.Lock() defer f.mu.Unlock() f.nic = nic diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index d589f798d..43e9e4beb 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -16,6 +16,7 @@ package stack import ( "fmt" + "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -514,8 +515,19 @@ type NetworkInterface interface { Enabled() bool // Promiscuous returns true if the interface is in promiscuous mode. + // + // When in promiscuous mode, the interface should accept all packets. Promiscuous() bool + // Spoofing returns true if the interface is in spoofing mode. + // + // When in spoofing mode, the interface should consider all addresses as + // assigned to it. + Spoofing() bool + + // CheckLocalAddress returns true if the address exists on the interface. + CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool + // WritePacketToRemote writes the packet to the given remote link address. WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error @@ -840,7 +852,97 @@ type InjectableLinkEndpoint interface { InjectOutbound(dest tcpip.Address, packet []byte) tcpip.Error } -// A LinkAddressResolver handles link address resolution for a network protocol. +// DADResult is the result of a duplicate address detection process. +type DADResult struct { + // Resolved is true when DAD completed without detecting a duplicate address + // on the link. + // + // Ignored when Err is non-nil. + Resolved bool + + // Err is an error encountered while performing DAD. + Err tcpip.Error +} + +// DADCompletionHandler is a handler for DAD completion. +type DADCompletionHandler func(DADResult) + +// DADCheckAddressDisposition enumerates the possible return values from +// DAD.CheckDuplicateAddress. +type DADCheckAddressDisposition int + +const ( + _ DADCheckAddressDisposition = iota + + // DADDisabled indicates that DAD is disabled. + DADDisabled + + // DADStarting indicates that DAD is starting for an address. + DADStarting + + // DADAlreadyRunning indicates that DAD was already started for an address. + DADAlreadyRunning +) + +const ( + // defaultDupAddrDetectTransmits is the default number of NDP Neighbor + // Solicitation messages to send when doing Duplicate Address Detection + // for a tentative address. + // + // Default = 1 (from RFC 4862 section 5.1) + defaultDupAddrDetectTransmits = 1 +) + +// DADConfigurations holds configurations for duplicate address detection. +type DADConfigurations struct { + // The number of Neighbor Solicitation messages to send when doing + // Duplicate Address Detection for a tentative address. + // + // Note, a value of zero effectively disables DAD. + DupAddrDetectTransmits uint8 + + // The amount of time to wait between sending Neighbor Solicitation + // messages. + // + // Must be greater than or equal to 1ms. + RetransmitTimer time.Duration +} + +// DefaultDADConfigurations returns the default DAD configurations. +func DefaultDADConfigurations() DADConfigurations { + return DADConfigurations{ + DupAddrDetectTransmits: defaultDupAddrDetectTransmits, + RetransmitTimer: defaultRetransmitTimer, + } +} + +// Validate modifies the configuration with valid values. If invalid values are +// present in the configurations, the corresponding default values are used +// instead. +func (c *DADConfigurations) Validate() { + if c.RetransmitTimer < minimumRetransmitTimer { + c.RetransmitTimer = defaultRetransmitTimer + } +} + +// DuplicateAddressDetector handles checking if an address is already assigned +// to some neighboring node on the link. +type DuplicateAddressDetector interface { + // CheckDuplicateAddress checks if an address is assigned to a neighbor. + // + // If DAD is already being performed for the address, the handler will be + // called with the result of the original DAD request. + CheckDuplicateAddress(tcpip.Address, DADCompletionHandler) DADCheckAddressDisposition + + // SetDADConfiguations sets the configurations for DAD. + SetDADConfigurations(c DADConfigurations) + + // DuplicateAddressProtocol returns the network protocol the receiver can + // perform duplicate address detection for. + DuplicateAddressProtocol() tcpip.NetworkProtocolNumber +} + +// LinkAddressResolver handles link address resolution for a network protocol. type LinkAddressResolver interface { // LinkAddressRequest sends a request for the link address of the target // address. The request is broadcasted on the local network if a remote link diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index bab55ce49..e946f9fe3 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -35,7 +35,7 @@ type Route struct { // localAddressNIC is the interface the address is associated with. // TODO(gvisor.dev/issue/4548): Remove this field once we can query the // address's assigned status without the NIC. - localAddressNIC *NIC + localAddressNIC *nic mu struct { sync.RWMutex @@ -49,11 +49,11 @@ type Route struct { } // outgoingNIC is the interface this route uses to write packets. - outgoingNIC *NIC + outgoingNIC *nic // linkRes is set if link address resolution is enabled for this protocol on // the route's NIC. - linkRes linkResolver + linkRes *linkResolver } type routeInfo struct { @@ -108,7 +108,7 @@ func (r *Route) fieldsLocked() RouteInfo { // ownership of the provided local address. // // Returns an empty route if validation fails. -func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, localAddr, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) *Route { +func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *nic, gateway, localAddr, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) *Route { if len(localAddr) == 0 { localAddr = addressEndpoint.AddressWithPrefix().Address } @@ -140,7 +140,7 @@ func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndp // makeRoute initializes a new route. It takes ownership of the provided // AssignableAddressEndpoint. -func makeRoute(netProto tcpip.NetworkProtocolNumber, gateway, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) *Route { +func makeRoute(netProto tcpip.NetworkProtocolNumber, gateway, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *nic, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) *Route { if localAddressNIC.stack != outgoingNIC.stack { panic(fmt.Sprintf("cannot create a route with NICs from different stacks")) } @@ -184,7 +184,7 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, gateway, localAddr, remoteA return r } - if r.linkRes.resolver == nil { + if r.linkRes == nil { return r } @@ -206,7 +206,7 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, gateway, localAddr, remoteA return r } -func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) *Route { +func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *nic, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) *Route { r := &Route{ routeInfo: routeInfo{ NetProto: netProto, @@ -230,7 +230,7 @@ func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr // provided AssignableAddressEndpoint. // // A local route is a route to a destination that is local to the stack. -func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint) *Route { +func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *nic, localAddressEndpoint AssignableAddressEndpoint) *Route { loop := PacketLoop // TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the // link endpoint level. We can remove this check once loopback interfaces @@ -400,7 +400,7 @@ func (r *Route) IsResolutionRequired() bool { } func (r *Route) isResolutionRequiredRLocked() bool { - return len(r.mu.remoteLinkAddress) == 0 && r.linkRes.resolver != nil && r.isValidForOutgoingRLocked() && !r.local() + return len(r.mu.remoteLinkAddress) == 0 && r.linkRes != nil && r.isValidForOutgoingRLocked() && !r.local() } func (r *Route) isValidForOutgoing() bool { @@ -528,7 +528,7 @@ func (r *Route) IsOutboundBroadcast() bool { // "Reachable" is defined as having full-duplex communication between the // local and remote ends of the route. func (r *Route) ConfirmReachable() { - if r.linkRes.resolver != nil { + if r.linkRes != nil { r.linkRes.confirmReachable(r.nextHop()) } } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index a51d758d0..674c9a1ff 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -395,7 +395,7 @@ type Stack struct { } mu sync.RWMutex - nics map[tcpip.NICID]*NIC + nics map[tcpip.NICID]*nic // cleanupEndpointsMu protects cleanupEndpoints. cleanupEndpointsMu sync.Mutex @@ -434,12 +434,6 @@ type Stack struct { // nudConfigs is the default NUD configurations used by interfaces. nudConfigs NUDConfigurations - // useNeighborCache indicates whether ARP and NDP packets should be handled - // by the NIC's neighborCache instead of linkAddrCache. - // - // TODO(gvisor.dev/issue/4658): Remove this field. - useNeighborCache bool - // nudDisp is the NUD event dispatcher that is used to send the netstack // integrator NUD related events. nudDisp NUDDispatcher @@ -516,17 +510,6 @@ type Options struct { // NUDConfigs is the default NUD configurations used by interfaces. NUDConfigs NUDConfigurations - // UseNeighborCache is unused. - // - // TODO(gvisor.dev/issue/4658): Remove this field. - UseNeighborCache bool - - // UseLinkAddrCache indicates that the legacy link address cache should be - // used for link resolution. - // - // TODO(gvisor.dev/issue/4658): Remove this field. - UseLinkAddrCache bool - // NUDDisp is the NUD event dispatcher that an integrator can provide to // receive NUD related events. NUDDisp NUDDispatcher @@ -656,7 +639,7 @@ func New(opts Options) *Stack { s := &Stack{ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), - nics: make(map[tcpip.NICID]*NIC), + nics: make(map[tcpip.NICID]*nic), cleanupEndpoints: make(map[TransportEndpoint]struct{}), PortManager: ports.NewPortManager(), clock: clock, @@ -666,7 +649,6 @@ func New(opts Options) *Stack { icmpRateLimiter: NewICMPRateLimiter(), seed: generateRandUint32(), nudConfigs: opts.NUDConfigs, - useNeighborCache: !opts.UseLinkAddrCache, uniqueIDGenerator: opts.UniqueID, nudDisp: opts.NUDDisp, randomGenerator: mathrand.New(randSrc), @@ -1233,7 +1215,7 @@ func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocol return nic.primaryAddress(protocol), true } -func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) AssignableAddressEndpoint { +func (s *Stack) getAddressEP(nic *nic, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) AssignableAddressEndpoint { if len(localAddr) == 0 { return nic.primaryEndpoint(netProto, remoteAddr) } @@ -1244,13 +1226,13 @@ func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netP // from the specified NIC. // // Precondition: s.mu must be read locked. -func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) *Route { +func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *nic, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) *Route { localAddressEndpoint := localAddressNIC.getAddressOrCreateTempInner(netProto, localAddr, false /* createTemp */, NeverPrimaryEndpoint) if localAddressEndpoint == nil { return nil } - var outgoingNIC *NIC + var outgoingNIC *nic // Prefer a local route to the same interface as the local address. if localAddressNIC.hasAddress(netProto, remoteAddr) { outgoingNIC = localAddressNIC @@ -1319,6 +1301,11 @@ func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, return nil } +// HandleLocal returns true if non-loopback interfaces are allowed to loop packets. +func (s *Stack) HandleLocal() bool { + return s.handleLocal +} + // FindRoute creates a route to the given destination address, leaving through // the given NIC and local address (if provided). // @@ -1479,6 +1466,17 @@ func (s *Stack) CheckNetworkProtocol(protocol tcpip.NetworkProtocolNumber) bool return ok } +// CheckDuplicateAddress performs duplicate address detection for the address on +// the specified interface. +func (s *Stack) CheckDuplicateAddress(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, h DADCompletionHandler) (DADCheckAddressDisposition, tcpip.Error) { + nic, ok := s.nics[nicID] + if !ok { + return 0, &tcpip.ErrUnknownNICID{} + } + + return nic.checkDuplicateAddress(protocol, addr, h) +} + // CheckLocalAddress determines if the given local address exists, and if it // does, returns the id of the NIC it's bound to. Returns 0 if the address // does not exist. @@ -1493,20 +1491,16 @@ func (s *Stack) CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProto return 0 } - addressEndpoint := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint) - if addressEndpoint == nil { - return 0 + if nic.CheckLocalAddress(protocol, addr) { + return nic.id } - addressEndpoint.DecRef() - - return nic.id + return 0 } // Go through all the NICs. for _, nic := range s.nics { - if addressEndpoint := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint); addressEndpoint != nil { - addressEndpoint.DecRef() + if nic.CheckLocalAddress(protocol, addr) { return nic.id } } @@ -2062,22 +2056,6 @@ func generateRandInt64() int64 { return v } -// FindNetworkEndpoint returns the network endpoint for the given address. -func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) (NetworkEndpoint, tcpip.Error) { - s.mu.RLock() - defer s.mu.RUnlock() - - for _, nic := range s.nics { - addressEndpoint := nic.getAddressOrCreateTempInner(netProto, address, false /* createTemp */, NeverPrimaryEndpoint) - if addressEndpoint == nil { - continue - } - addressEndpoint.DecRef() - return nic.getNetworkEndpoint(netProto), nil - } - return nil, &tcpip.ErrBadAddress{} -} - // FindNICNameFromID returns the name of the NIC for the given NICID. func (s *Stack) FindNICNameFromID(id tcpip.NICID) string { s.mu.RLock() @@ -2103,13 +2081,6 @@ const ( // ParsedOK indicates that a packet was successfully parsed. ParsedOK ParseResult = iota - // UnknownNetworkProtocol indicates that the network protocol is unknown. - UnknownNetworkProtocol - - // NetworkLayerParseError indicates that the network packet was not - // successfully parsed. - NetworkLayerParseError - // UnknownTransportProtocol indicates that the transport protocol is unknown. UnknownTransportProtocol @@ -2118,31 +2089,19 @@ const ( TransportLayerParseError ) -// ParsePacketBuffer parses the provided packet buffer. -func (s *Stack) ParsePacketBuffer(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) ParseResult { - netProto, ok := s.networkProtocols[protocol] - if !ok { - return UnknownNetworkProtocol - } - - transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt) - if !ok { - return NetworkLayerParseError - } - if !hasTransportHdr { - return ParsedOK - } - +// ParsePacketBufferTransport parses the provided packet buffer's transport +// header. +func (s *Stack) ParsePacketBufferTransport(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) ParseResult { // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a // full explanation. - if transProtoNum == header.ICMPv4ProtocolNumber || transProtoNum == header.ICMPv6ProtocolNumber { + if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { return ParsedOK } - pkt.TransportProtocolNumber = transProtoNum + pkt.TransportProtocolNumber = protocol // Parse the transport header if present. - state, ok := s.transportProtocols[transProtoNum] + state, ok := s.transportProtocols[protocol] if !ok { return UnknownTransportProtocol } @@ -2164,7 +2123,7 @@ func (s *Stack) networkProtocolNumbers() []tcpip.NetworkProtocolNumber { return protos } -func isSubnetBroadcastOnNIC(nic *NIC, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { +func isSubnetBroadcastOnNIC(nic *nic, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { addressEndpoint := nic.getAddressOrCreateTempInner(protocol, addr, false /* createTemp */, NeverPrimaryEndpoint) if addressEndpoint == nil { return false diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index b641a4aaa..92a0cb401 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -119,6 +119,10 @@ func (*fakeNetworkEndpoint) DefaultTTL() uint8 { } func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { + if _, _, ok := f.proto.Parse(pkt); !ok { + return + } + // Increment the received packet count in the protocol descriptor. netHdr := pkt.NetworkHeader().View() @@ -2569,16 +2573,16 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent), } - ndpConfigs := ipv6.DefaultNDPConfigurations() + dadConfigs := stack.DefaultDADConfigurations() opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ndpConfigs, AutoGenLinkLocal: true, NDPDisp: &ndpDisp, + DADConfigs: dadConfigs, })}, } - e := channel.New(int(ndpConfigs.DupAddrDetectTransmits), 1280, linkAddr1) + e := channel.New(int(dadConfigs.DupAddrDetectTransmits), 1280, linkAddr1) s := stack.New(opts) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -2594,7 +2598,7 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { // Wait for DAD to resolve. select { - case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): + case <-time.After(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer + time.Second): // We should get a resolution event after 1s (default time to // resolve as per default NDP configurations). Waiting for that // resolution time + an extra 1s without a resolution event @@ -3231,7 +3235,7 @@ func TestDoDADWhenNICEnabled(t *testing.T) { } opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ + DADConfigs: stack.DADConfigurations{ DupAddrDetectTransmits: dadTransmits, RetransmitTimer: retransmitTimer, }, @@ -4320,7 +4324,6 @@ func TestClearNeighborCacheOnNICDisable(t *testing.T) { clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - UseNeighborCache: true, Clock: clock, }) e := channel.New(0, 0, "") diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 7d8d0851e..292e51d20 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -183,7 +183,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *Packet } // handleError delivers an error to the transport endpoint identified by id. -func (epsByNIC *endpointsByNIC) handleError(n *NIC, id TransportEndpointID, transErr TransportError, pkt *PacketBuffer) { +func (epsByNIC *endpointsByNIC) handleError(n *nic, id TransportEndpointID, transErr TransportError, pkt *PacketBuffer) { epsByNIC.mu.RLock() defer epsByNIC.mu.RUnlock() @@ -599,7 +599,7 @@ func (d *transportDemuxer) deliverRawPacket(protocol tcpip.TransportProtocolNumb // endpoint. // // Returns true if the error was delivered. -func (d *transportDemuxer) deliverError(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr TransportError, pkt *PacketBuffer, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverError(n *nic, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr TransportError, pkt *PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{net, trans}] if !ok { return false diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go index 21a8dd291..b56706357 100644 --- a/pkg/tcpip/tests/integration/iptables_test.go +++ b/pkg/tcpip/tests/integration/iptables_test.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" ) type inputIfNameMatcher struct { @@ -334,3 +335,312 @@ func TestIPTablesStatsForInput(t *testing.T) { }) } } + +var _ stack.LinkEndpoint = (*channelEndpointWithoutWritePacket)(nil) + +// channelEndpointWithoutWritePacket is a channel endpoint that does not support +// stack.LinkEndpoint.WritePacket. +type channelEndpointWithoutWritePacket struct { + *channel.Endpoint + + t *testing.T +} + +func (c *channelEndpointWithoutWritePacket) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { + c.t.Error("unexpectedly called WritePacket; all writes should go through WritePackets") + return &tcpip.ErrNotSupported{} +} + +var _ stack.Matcher = (*udpSourcePortMatcher)(nil) + +type udpSourcePortMatcher struct { + port uint16 +} + +func (*udpSourcePortMatcher) Name() string { + return "udpSourcePortMatcher" +} + +func (m *udpSourcePortMatcher) Match(_ stack.Hook, pkt *stack.PacketBuffer, _, _ string) (matches, hotdrop bool) { + udp := header.UDP(pkt.TransportHeader().View()) + if len(udp) < header.UDPMinimumSize { + // Drop immediately as the packet is invalid. + return false, true + } + + return udp.SourcePort() == m.port, false +} + +func TestIPTableWritePackets(t *testing.T) { + const ( + nicID = 1 + + dropLocalPort = localPort - 1 + acceptPackets = 2 + dropPackets = 3 + ) + + udpHdr := func(hdr buffer.View, srcAddr, dstAddr tcpip.Address, srcPort, dstPort uint16) { + u := header.UDP(hdr) + u.Encode(&header.UDPFields{ + SrcPort: srcPort, + DstPort: dstPort, + Length: header.UDPMinimumSize, + }) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, srcAddr, dstAddr, header.UDPMinimumSize) + sum = header.Checksum(hdr, sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + } + + tests := []struct { + name string + setupFilter func(*testing.T, *stack.Stack) + genPacket func(*stack.Route) stack.PacketBufferList + proto tcpip.NetworkProtocolNumber + remoteAddr tcpip.Address + expectSent uint64 + expectOutputDropped uint64 + }{ + { + name: "IPv4 Accept", + setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ }, + genPacket: func(r *stack.Route) stack.PacketBufferList { + var pkts stack.PacketBufferList + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), + }) + hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) + udpHdr(hdr, r.LocalAddress, r.RemoteAddress, localPort, remotePort) + pkts.PushFront(pkt) + + return pkts + }, + proto: header.IPv4ProtocolNumber, + remoteAddr: dstAddrV4, + expectSent: 1, + expectOutputDropped: 0, + }, + { + name: "IPv4 Drop Other Port", + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + + table := stack.Table{ + Rules: []stack.Rule{ + { + Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}, + }, + { + Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}, + }, + { + Matchers: []stack.Matcher{&udpSourcePortMatcher{port: dropLocalPort}}, + Target: &stack.DropTarget{NetworkProtocol: header.IPv4ProtocolNumber}, + }, + { + Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}, + }, + { + Target: &stack.ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}, + }, + }, + BuiltinChains: [stack.NumHooks]int{ + stack.Prerouting: stack.HookUnset, + stack.Input: 0, + stack.Forward: 1, + stack.Output: 2, + stack.Postrouting: stack.HookUnset, + }, + Underflows: [stack.NumHooks]int{ + stack.Prerouting: stack.HookUnset, + stack.Input: 0, + stack.Forward: 1, + stack.Output: 2, + stack.Postrouting: stack.HookUnset, + }, + } + + if err := s.IPTables().ReplaceTable(stack.FilterID, table, false /* ipv4 */); err != nil { + t.Fatalf("RelaceTable(%d, _, false): %s", stack.FilterID, err) + } + }, + genPacket: func(r *stack.Route) stack.PacketBufferList { + var pkts stack.PacketBufferList + + for i := 0; i < acceptPackets; i++ { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), + }) + hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) + udpHdr(hdr, r.LocalAddress, r.RemoteAddress, localPort, remotePort) + pkts.PushFront(pkt) + } + for i := 0; i < dropPackets; i++ { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), + }) + hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) + udpHdr(hdr, r.LocalAddress, r.RemoteAddress, dropLocalPort, remotePort) + pkts.PushFront(pkt) + } + + return pkts + }, + proto: header.IPv4ProtocolNumber, + remoteAddr: dstAddrV4, + expectSent: acceptPackets, + expectOutputDropped: dropPackets, + }, + { + name: "IPv6 Accept", + setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ }, + genPacket: func(r *stack.Route) stack.PacketBufferList { + var pkts stack.PacketBufferList + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), + }) + hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) + udpHdr(hdr, r.LocalAddress, r.RemoteAddress, localPort, remotePort) + pkts.PushFront(pkt) + + return pkts + }, + proto: header.IPv6ProtocolNumber, + remoteAddr: dstAddrV6, + expectSent: 1, + expectOutputDropped: 0, + }, + { + name: "IPv6 Drop Other Port", + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + + table := stack.Table{ + Rules: []stack.Rule{ + { + Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}, + }, + { + Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}, + }, + { + Matchers: []stack.Matcher{&udpSourcePortMatcher{port: dropLocalPort}}, + Target: &stack.DropTarget{NetworkProtocol: header.IPv6ProtocolNumber}, + }, + { + Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}, + }, + { + Target: &stack.ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}, + }, + }, + BuiltinChains: [stack.NumHooks]int{ + stack.Prerouting: stack.HookUnset, + stack.Input: 0, + stack.Forward: 1, + stack.Output: 2, + stack.Postrouting: stack.HookUnset, + }, + Underflows: [stack.NumHooks]int{ + stack.Prerouting: stack.HookUnset, + stack.Input: 0, + stack.Forward: 1, + stack.Output: 2, + stack.Postrouting: stack.HookUnset, + }, + } + + if err := s.IPTables().ReplaceTable(stack.FilterID, table, true /* ipv6 */); err != nil { + t.Fatalf("RelaceTable(%d, _, true): %s", stack.FilterID, err) + } + }, + genPacket: func(r *stack.Route) stack.PacketBufferList { + var pkts stack.PacketBufferList + + for i := 0; i < acceptPackets; i++ { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), + }) + hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) + udpHdr(hdr, r.LocalAddress, r.RemoteAddress, localPort, remotePort) + pkts.PushFront(pkt) + } + for i := 0; i < dropPackets; i++ { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), + }) + hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) + udpHdr(hdr, r.LocalAddress, r.RemoteAddress, dropLocalPort, remotePort) + pkts.PushFront(pkt) + } + + return pkts + }, + proto: header.IPv6ProtocolNumber, + remoteAddr: dstAddrV6, + expectSent: acceptPackets, + expectOutputDropped: dropPackets, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + }) + e := channelEndpointWithoutWritePacket{ + Endpoint: channel.New(4, header.IPv6MinimumMTU, linkAddr), + t: t, + } + if err := s.CreateNIC(nicID, &e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, srcAddrV6); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, srcAddrV6, err) + } + if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, srcAddrV4); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, srcAddrV4, err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }, + { + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }, + }) + + test.setupFilter(t, s) + + r, err := s.FindRoute(nicID, "", test.remoteAddr, test.proto, false) + if err != nil { + t.Fatalf("FindRoute(%d, '', %s, %d, false): %s", nicID, test.remoteAddr, test.proto, err) + } + defer r.Release() + + pkts := test.genPacket(r) + pktsLen := pkts.Len() + if n, err := r.WritePackets(nil /* gso */, pkts, stack.NetworkHeaderParams{ + Protocol: header.UDPProtocolNumber, + TTL: 64, + }); err != nil { + t.Fatalf("WritePackets(...): %s", err) + } else if n != pktsLen { + t.Fatalf("got WritePackets(...) = %d, want = %d", n, pktsLen) + } + + if got := s.Stats().IP.PacketsSent.Value(); got != test.expectSent { + t.Errorf("got PacketSent = %d, want = %d", got, test.expectSent) + } + if got := s.Stats().IP.IPTablesOutputDropped.Value(); got != test.expectOutputDropped { + t.Errorf("got IPTablesOutputDropped = %d, want = %d", got, test.expectOutputDropped) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index f2301a9e6..824f81a42 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -487,30 +487,25 @@ func TestGetLinkAddress(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - for _, useNeighborCache := range []bool{true, false} { - t.Run(fmt.Sprintf("UseNeighborCache=%t", useNeighborCache), func(t *testing.T) { - stackOpts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - UseNeighborCache: useNeighborCache, - } + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + } - host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) + host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) - ch := make(chan stack.LinkResolutionResult, 1) - err := host1Stack.GetLinkAddress(host1NICID, test.remoteAddr, "", test.netProto, func(r stack.LinkResolutionResult) { - ch <- r - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", host1NICID, test.remoteAddr, test.netProto, err, &tcpip.ErrWouldBlock{}) - } - wantRes := stack.LinkResolutionResult{Success: test.expectedOk} - if test.expectedOk { - wantRes.LinkAddress = linkAddr2 - } - if diff := cmp.Diff(wantRes, <-ch); diff != "" { - t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff) - } - }) + ch := make(chan stack.LinkResolutionResult, 1) + err := host1Stack.GetLinkAddress(host1NICID, test.remoteAddr, "", test.netProto, func(r stack.LinkResolutionResult) { + ch <- r + }) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", host1NICID, test.remoteAddr, test.netProto, err, &tcpip.ErrWouldBlock{}) + } + wantRes := stack.LinkResolutionResult{Success: test.expectedOk} + if test.expectedOk { + wantRes.LinkAddress = linkAddr2 + } + if diff := cmp.Diff(wantRes, <-ch); diff != "" { + t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff) } }) } @@ -587,66 +582,61 @@ func TestRouteResolvedFields(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - for _, useNeighborCache := range []bool{true, false} { - t.Run(fmt.Sprintf("UseNeighborCache=%t", useNeighborCache), func(t *testing.T) { - stackOpts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, - UseNeighborCache: useNeighborCache, - } + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + } - host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) - r, err := host1Stack.FindRoute(host1NICID, "", test.remoteAddr, test.netProto, false /* multicastLoop */) - if err != nil { - t.Fatalf("host1Stack.FindRoute(%d, '', %s, %d, false): %s", host1NICID, test.remoteAddr, test.netProto, err) - } - defer r.Release() - - var wantRouteInfo stack.RouteInfo - wantRouteInfo.LocalLinkAddress = linkAddr1 - wantRouteInfo.LocalAddress = test.localAddr - wantRouteInfo.RemoteAddress = test.remoteAddr - wantRouteInfo.NetProto = test.netProto - wantRouteInfo.Loop = stack.PacketOut - wantRouteInfo.RemoteLinkAddress = test.expectedLinkAddr - - ch := make(chan stack.ResolvedFieldsResult, 1) - - if !test.immediatelyResolvable { - wantUnresolvedRouteInfo := wantRouteInfo - wantUnresolvedRouteInfo.RemoteLinkAddress = "" - - err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) { - ch <- r - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got r.ResolvedFields(_) = %s, want = %s", err, &tcpip.ErrWouldBlock{}) - } - if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Success: test.expectedSuccess}, <-ch, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" { - t.Errorf("route resolve result mismatch (-want +got):\n%s", diff) - } + host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) + r, err := host1Stack.FindRoute(host1NICID, "", test.remoteAddr, test.netProto, false /* multicastLoop */) + if err != nil { + t.Fatalf("host1Stack.FindRoute(%d, '', %s, %d, false): %s", host1NICID, test.remoteAddr, test.netProto, err) + } + defer r.Release() - if !test.expectedSuccess { - return - } + var wantRouteInfo stack.RouteInfo + wantRouteInfo.LocalLinkAddress = linkAddr1 + wantRouteInfo.LocalAddress = test.localAddr + wantRouteInfo.RemoteAddress = test.remoteAddr + wantRouteInfo.NetProto = test.netProto + wantRouteInfo.Loop = stack.PacketOut + wantRouteInfo.RemoteLinkAddress = test.expectedLinkAddr - // At this point the neighbor table should be populated so the route - // should be immediately resolvable. - } + ch := make(chan stack.ResolvedFieldsResult, 1) - if err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) { - ch <- r - }); err != nil { - t.Errorf("r.ResolvedFields(_): %s", err) - } - select { - case routeResolveRes := <-ch: - if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Success: true}, routeResolveRes, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" { - t.Errorf("route resolve result from resolved route mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected route to be immediately resolvable") - } + if !test.immediatelyResolvable { + wantUnresolvedRouteInfo := wantRouteInfo + wantUnresolvedRouteInfo.RemoteLinkAddress = "" + + err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) { + ch <- r }) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got r.ResolvedFields(_) = %s, want = %s", err, &tcpip.ErrWouldBlock{}) + } + if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Success: test.expectedSuccess}, <-ch, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" { + t.Errorf("route resolve result mismatch (-want +got):\n%s", diff) + } + + if !test.expectedSuccess { + return + } + + // At this point the neighbor table should be populated so the route + // should be immediately resolvable. + } + + if err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) { + ch <- r + }); err != nil { + t.Errorf("r.ResolvedFields(_): %s", err) + } + select { + case routeResolveRes := <-ch: + if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Success: true}, routeResolveRes, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" { + t.Errorf("route resolve result from resolved route mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected route to be immediately resolvable") } }) } @@ -1065,7 +1055,6 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, Clock: clock, - UseNeighborCache: true, } host1StackOpts := stackOpts host1StackOpts.NUDDisp = &nudDisp @@ -1210,3 +1199,148 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { }) } } + +func TestDAD(t *testing.T) { + const ( + host1NICID = 1 + host2NICID = 4 + ) + + dadConfigs := stack.DADConfigurations{ + DupAddrDetectTransmits: 1, + RetransmitTimer: time.Second, + } + + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + dadNetProto tcpip.NetworkProtocolNumber + remoteAddr tcpip.Address + expectedResolved bool + }{ + { + name: "IPv4 own address", + netProto: ipv4.ProtocolNumber, + dadNetProto: arp.ProtocolNumber, + remoteAddr: ipv4Addr1.AddressWithPrefix.Address, + expectedResolved: true, + }, + { + name: "IPv6 own address", + netProto: ipv6.ProtocolNumber, + dadNetProto: ipv6.ProtocolNumber, + remoteAddr: ipv6Addr1.AddressWithPrefix.Address, + expectedResolved: true, + }, + { + name: "IPv4 duplicate address", + netProto: ipv4.ProtocolNumber, + dadNetProto: arp.ProtocolNumber, + remoteAddr: ipv4Addr2.AddressWithPrefix.Address, + expectedResolved: false, + }, + { + name: "IPv6 duplicate address", + netProto: ipv6.ProtocolNumber, + dadNetProto: ipv6.ProtocolNumber, + remoteAddr: ipv6Addr2.AddressWithPrefix.Address, + expectedResolved: false, + }, + { + name: "IPv4 no duplicate address", + netProto: ipv4.ProtocolNumber, + dadNetProto: arp.ProtocolNumber, + remoteAddr: ipv4Addr3.AddressWithPrefix.Address, + expectedResolved: true, + }, + { + name: "IPv6 no duplicate address", + netProto: ipv6.ProtocolNumber, + dadNetProto: ipv6.ProtocolNumber, + remoteAddr: ipv6Addr3.AddressWithPrefix.Address, + expectedResolved: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + stackOpts := stack.Options{ + Clock: clock, + NetworkProtocols: []stack.NetworkProtocolFactory{ + arp.NewProtocol, + ipv4.NewProtocol, + ipv6.NewProtocol, + }, + } + + host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) + + // DAD should be disabled by default. + if res, err := host1Stack.CheckDuplicateAddress(host1NICID, test.netProto, test.remoteAddr, func(r stack.DADResult) { + t.Errorf("unexpectedly called DAD completion handler when DAD was supposed to be disabled") + }); err != nil { + t.Fatalf("host1Stack.CheckDuplicateAddress(%d, %d, %s, _): %s", host1NICID, test.netProto, test.remoteAddr, err) + } else if res != stack.DADDisabled { + t.Errorf("got host1Stack.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", host1NICID, test.netProto, test.remoteAddr, res, stack.DADDisabled) + } + + // Enable DAD then attempt to check if an address is duplicated. + netEP, err := host1Stack.GetNetworkEndpoint(host1NICID, test.dadNetProto) + if err != nil { + t.Fatalf("host1Stack.GetNetworkEndpoint(%d, %d): %s", host1NICID, test.dadNetProto, err) + } + dad, ok := netEP.(stack.DuplicateAddressDetector) + if !ok { + t.Fatalf("expected %T to implement stack.DuplicateAddressDetector", netEP) + } + dad.SetDADConfigurations(dadConfigs) + ch := make(chan stack.DADResult, 3) + if res, err := host1Stack.CheckDuplicateAddress(host1NICID, test.netProto, test.remoteAddr, func(r stack.DADResult) { + ch <- r + }); err != nil { + t.Fatalf("host1Stack.CheckDuplicateAddress(%d, %d, %s, _): %s", host1NICID, test.netProto, test.remoteAddr, err) + } else if res != stack.DADStarting { + t.Errorf("got host1Stack.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", host1NICID, test.netProto, test.remoteAddr, res, stack.DADStarting) + } + + expectResults := 1 + if test.expectedResolved { + const delta = time.Nanosecond + clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer - delta) + select { + case r := <-ch: + t.Fatalf("unexpectedly got DAD result before the DAD timeout; r = %#v", r) + default: + } + + // If we expect the resolve to succeed try requesting DAD again on the + // same address. The handler for the new request should be called once + // the original DAD request completes. + expectResults = 2 + if res, err := host1Stack.CheckDuplicateAddress(host1NICID, test.netProto, test.remoteAddr, func(r stack.DADResult) { + ch <- r + }); err != nil { + t.Fatalf("host1Stack.CheckDuplicateAddress(%d, %d, %s, _): %s", host1NICID, test.netProto, test.remoteAddr, err) + } else if res != stack.DADAlreadyRunning { + t.Errorf("got host1Stack.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", host1NICID, test.netProto, test.remoteAddr, res, stack.DADAlreadyRunning) + } + + clock.Advance(delta) + } + + for i := 0; i < expectResults; i++ { + if diff := cmp.Diff(stack.DADResult{Resolved: test.expectedResolved}, <-ch); diff != "" { + t.Errorf("(i=%d) DAD result mismatch (-want +got):\n%s", i, diff) + } + } + + // Should have no more results. + select { + case r := <-ch: + t.Errorf("unexpectedly got an extra DAD result; r = %#v", r) + default: + } + }) + } +} diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 34a631b53..461b1a9d7 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -1301,7 +1301,8 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ // e.mu is expected to be hold upon entering this section. if e.snd != nil { e.snd.resendTimer.cleanup() - e.snd.rc.probeTimer.cleanup() + e.snd.probeTimer.cleanup() + e.snd.reorderTimer.cleanup() } if closeTimer != nil { @@ -1396,7 +1397,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ }, }, { - w: &e.snd.rc.probeWaker, + w: &e.snd.probeWaker, f: e.snd.probeTimerExpired, }, { @@ -1475,6 +1476,10 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ return nil }, }, + { + w: &e.snd.reorderWaker, + f: e.snd.rc.reorderTimerExpired, + }, } // Initialize the sleeper based on the wakers in funcs. diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 4e5a6089f..8c5be0586 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1698,11 +1698,7 @@ func (e *endpoint) OnCorkOptionSet(v bool) { } func (e *endpoint) getSendBufferSize() int { - sndBufSize, err := e.ops.GetSendBufferSize() - if err != nil { - panic(fmt.Sprintf("e.ops.GetSendBufferSize() = %s", err)) - } - return int(sndBufSize) + return int(e.ops.GetSendBufferSize()) } // SetSockOptInt sets a socket option. diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go index e862f159e..9959b60b8 100644 --- a/pkg/tcpip/transport/tcp/rack.go +++ b/pkg/tcpip/transport/tcp/rack.go @@ -17,7 +17,6 @@ package tcp import ( "time" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/seqnum" ) @@ -50,7 +49,8 @@ type rackControl struct { // dsackSeen indicates if the connection has seen a DSACK. dsackSeen bool - // endSequence is the ending TCP sequence number of rackControl.seg. + // endSequence is the ending TCP sequence number of the most recent + // acknowledged segment. endSequence seqnum.Value // exitedRecovery indicates if the connection is exiting loss recovery. @@ -90,13 +90,10 @@ type rackControl struct { // rttSeq is the SND.NXT when rtt is updated. rttSeq seqnum.Value - // xmitTime is the latest transmission timestamp of rackControl.seg. + // xmitTime is the latest transmission timestamp of the most recent + // acknowledged segment. xmitTime time.Time `state:".(unixTime)"` - // probeTimer and probeWaker are used to schedule PTO for RACK TLP algorithm. - probeTimer timer `state:"nosave"` - probeWaker sleep.Waker `state:"nosave"` - // tlpRxtOut indicates whether there is an unacknowledged // TLP retransmission. tlpRxtOut bool @@ -114,7 +111,6 @@ func (rc *rackControl) init(snd *sender, iss seqnum.Value) { rc.fack = iss rc.reoWndIncr = 1 rc.snd = snd - rc.probeTimer.init(&rc.probeWaker) } // update will update the RACK related fields when an ACK has been received. @@ -223,13 +219,13 @@ func (s *sender) schedulePTO() { s.resendTimer.disable() } - s.rc.probeTimer.enable(pto) + s.probeTimer.enable(pto) } // probeTimerExpired is the same as TLP_send_probe() as defined in // https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.5.2. func (s *sender) probeTimerExpired() tcpip.Error { - if !s.rc.probeTimer.checkExpiration() { + if !s.probeTimer.checkExpiration() { return nil } @@ -386,3 +382,102 @@ func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) { func (rc *rackControl) exitRecovery() { rc.exitedRecovery = true } + +// detectLoss marks the segment as lost if the reordering window has elapsed +// and the ACK is not received. It will also arm the reorder timer. +// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 Step 5. +func (rc *rackControl) detectLoss(rcvTime time.Time) int { + var timeout time.Duration + numLost := 0 + for seg := rc.snd.writeList.Front(); seg != nil && seg.xmitCount != 0; seg = seg.Next() { + if rc.snd.ep.scoreboard.IsSACKED(seg.sackBlock()) { + continue + } + + if seg.lost && seg.xmitCount == 1 { + numLost++ + continue + } + + endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) + if seg.xmitTime.Before(rc.xmitTime) || (seg.xmitTime.Equal(rc.xmitTime) && rc.endSequence.LessThan(endSeq)) { + timeRemaining := seg.xmitTime.Sub(rcvTime) + rc.rtt + rc.reoWnd + if timeRemaining <= 0 { + seg.lost = true + numLost++ + } else if timeRemaining > timeout { + timeout = timeRemaining + } + } + } + + if timeout != 0 && !rc.snd.reorderTimer.enabled() { + rc.snd.reorderTimer.enable(timeout) + } + return numLost +} + +// reorderTimerExpired will retransmit the segments which have not been acked +// before the reorder timer expired. +func (rc *rackControl) reorderTimerExpired() tcpip.Error { + // Check if the timer actually expired or if it's a spurious wake due + // to a previously orphaned runtime timer. + if !rc.snd.reorderTimer.checkExpiration() { + return nil + } + + numLost := rc.detectLoss(time.Now()) + if numLost == 0 { + return nil + } + + fastRetransmit := false + if !rc.snd.fr.active { + rc.snd.cc.HandleLossDetected() + rc.snd.enterRecovery() + fastRetransmit = true + } + + rc.DoRecovery(nil, fastRetransmit) + return nil +} + +// DoRecovery implements lossRecovery.DoRecovery. +func (rc *rackControl) DoRecovery(_ *segment, fastRetransmit bool) { + snd := rc.snd + if fastRetransmit { + snd.resendSegment() + } + + var dataSent bool + // Iterate the writeList and retransmit the segments which are marked + // as lost by RACK. + for seg := snd.writeList.Front(); seg != nil && seg.xmitCount > 0; seg = seg.Next() { + if seg == snd.writeNext { + break + } + + if !seg.lost { + continue + } + + // Reset seg.lost as it is already SACKed. + if snd.ep.scoreboard.IsSACKED(seg.sackBlock()) { + seg.lost = false + continue + } + + // Check the congestion window after entering recovery. + if snd.outstanding >= snd.sndCwnd { + break + } + + snd.outstanding++ + dataSent = true + snd.sendSegment(seg) + } + + // Rearm the RTO. + snd.resendTimer.enable(snd.rto) + snd.postXmit(dataSent) +} diff --git a/pkg/tcpip/transport/tcp/rack_state.go b/pkg/tcpip/transport/tcp/rack_state.go index 76cad0831..c9dc7e773 100644 --- a/pkg/tcpip/transport/tcp/rack_state.go +++ b/pkg/tcpip/transport/tcp/rack_state.go @@ -27,8 +27,3 @@ func (rc *rackControl) saveXmitTime() unixTime { func (rc *rackControl) loadXmitTime(unix unixTime) { rc.xmitTime = time.Unix(unix.second, unix.nano) } - -// afterLoad is invoked by stateify. -func (rc *rackControl) afterLoad() { - rc.probeTimer.init(&rc.probeWaker) -} diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 7cca4def5..f27eef6a9 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -83,6 +83,9 @@ type segment struct { // dataMemSize is the memory used by data initially. dataMemSize int + + // lost indicates if the segment is marked as lost by RACK. + lost bool } func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment { diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 463a259b7..d6365b93d 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -191,6 +191,15 @@ type sender struct { // rc has the fields needed for implementing RACK loss detection // algorithm. rc rackControl + + // reorderTimer is the timer used to retransmit the segments after RACK + // detects them as lost. + reorderTimer timer `state:"nosave"` + reorderWaker sleep.Waker `state:"nosave"` + + // probeTimer and probeWaker are used to schedule PTO for RACK TLP algorithm. + probeTimer timer `state:"nosave"` + probeWaker sleep.Waker `state:"nosave"` } // rtt is a synchronization wrapper used to appease stateify. See the comment @@ -267,7 +276,6 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint } s.cc = s.initCongestionControl(ep.cc) - s.lr = s.initLossRecovery() s.rc.init(s, iss) @@ -278,6 +286,8 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint } s.resendTimer.init(&s.resendWaker) + s.reorderTimer.init(&s.reorderWaker) + s.probeTimer.init(&s.probeWaker) s.updateMaxPayloadSize(int(ep.route.MTU()), 0) @@ -1126,6 +1136,15 @@ func (s *sender) SetPipe() { func (s *sender) detectLoss(seg *segment) (fastRetransmit bool) { // We're not in fast recovery yet. + // If RACK is enabled and there is no reordering we should honor the + // three duplicate ACK rule to enter recovery. + // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-4 + if s.ep.sackPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 { + if s.rc.reorderSeen { + return false + } + } + if !s.isDupAck(seg) { s.dupAckCount = 0 return false @@ -1320,7 +1339,9 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // unacknowledged and also never retransmitted sequence below // RACK.fack, then the corresponding packet has been // reordered and RACK.reord is set to TRUE. - s.walkSACK(rcvdSeg) + if s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 { + s.walkSACK(rcvdSeg) + } s.SetPipe() } @@ -1339,7 +1360,9 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { } // See if TLP based recovery was successful. - s.detectTLPRecovery(ack, rcvdSeg) + if s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 { + s.detectTLPRecovery(ack, rcvdSeg) + } // Stash away the current window size. s.sndWnd = rcvdSeg.window @@ -1421,7 +1444,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { } // Update the RACK fields if SACK is enabled. - if s.ep.sackPermitted && !seg.acked { + if s.ep.sackPermitted && !seg.acked && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 { s.rc.update(seg, rcvdSeg) s.rc.detectReorder(seg) } @@ -1455,7 +1478,10 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // Update RACK when we are exiting fast or RTO // recovery as described in the RFC // draft-ietf-tcpm-rack-08 Section-7.2 Step 4. - s.rc.exitRecovery() + if s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 { + s.rc.exitRecovery() + } + s.reorderTimer.disable() } } @@ -1475,19 +1501,36 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // Reset firstRetransmittedSegXmitTime to the zero value. s.firstRetransmittedSegXmitTime = time.Time{} s.resendTimer.disable() - s.rc.probeTimer.disable() + s.probeTimer.disable() } } - // Update RACK reorder window. - // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 - // * Upon receiving an ACK: - // * Step 4: Update RACK reordering window - s.rc.updateRACKReorderWindow(rcvdSeg) + if s.ep.sackPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 { + // Update RACK reorder window. + // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 + // * Upon receiving an ACK: + // * Step 4: Update RACK reordering window + s.rc.updateRACKReorderWindow(rcvdSeg) + + // After the reorder window is calculated, detect any loss by checking + // if the time elapsed after the segments are sent is greater than the + // reorder window. + if numLost := s.rc.detectLoss(rcvdSeg.rcvdTime); numLost > 0 && !s.fr.active { + // If any segment is marked as lost by + // RACK, enter recovery and retransmit + // the lost segments. + s.cc.HandleLossDetected() + s.enterRecovery() + } + + if s.fr.active { + s.rc.DoRecovery(nil, true) + } + } // Now that we've popped all acknowledged data from the retransmit // queue, retransmit if needed. - if s.fr.active { + if s.fr.active && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection == 0 { s.lr.DoRecovery(rcvdSeg, fastRetransmit) // When SACK is enabled data sending is governed by steps in // RFC 6675 Section 5 recovery steps A-C. @@ -1515,6 +1558,7 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error { } seg.xmitTime = time.Now() seg.xmitCount++ + seg.lost = false err := s.sendSegmentFromView(seg.data, seg.flags, seg.sequenceNumber) // Every time a packet containing data is sent (including a diff --git a/pkg/tcpip/transport/tcp/snd_state.go b/pkg/tcpip/transport/tcp/snd_state.go index 8b20c3455..ba41cff6d 100644 --- a/pkg/tcpip/transport/tcp/snd_state.go +++ b/pkg/tcpip/transport/tcp/snd_state.go @@ -47,6 +47,8 @@ func (s *sender) loadRttMeasureTime(unix unixTime) { // afterLoad is invoked by stateify. func (s *sender) afterLoad() { s.resendTimer.init(&s.resendWaker) + s.reorderTimer.init(&s.reorderWaker) + s.probeTimer.init(&s.probeWaker) } // saveFirstRetransmittedSegXmitTime is invoked by stateify. diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go index a6a26b705..6da981d80 100644 --- a/pkg/tcpip/transport/tcp/tcp_rack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" + "gvisor.dev/gvisor/pkg/test/testutil" ) const ( @@ -34,6 +35,14 @@ const ( mtu = header.TCPMinimumSize + header.IPv4MinimumSize + maxTCPOptionSize + maxPayload ) +func setStackRACKPermitted(t *testing.T, c *context.Context) { + t.Helper() + opt := tcpip.TCPRecovery(tcpip.TCPRACKLossDetection) + if err := c.Stack().SetTransportProtocolOption(header.TCPProtocolNumber, &opt); err != nil { + t.Fatalf("c.s.SetTransportProtocolOption(%d, &%v(%v)): %s", header.TCPProtocolNumber, opt, opt, err) + } +} + // TestRACKUpdate tests the RACK related fields are updated when an ACK is // received on a SACK enabled connection. func TestRACKUpdate(t *testing.T) { @@ -60,6 +69,7 @@ func TestRACKUpdate(t *testing.T) { close(probeDone) }) setStackSACKPermitted(t, c, true) + setStackRACKPermitted(t, c) createConnectedWithSACKAndTS(c) data := make([]byte, maxPayload) @@ -90,6 +100,8 @@ func TestRACKDetectReorder(t *testing.T) { c := context.New(t, uint32(mtu)) defer c.Cleanup() + t.Skipf("Skipping this test as reorder detection does not consider DSACK.") + var n int const ackNumToVerify = 2 probeDone := make(chan struct{}) @@ -116,6 +128,7 @@ func TestRACKDetectReorder(t *testing.T) { close(probeDone) }) setStackSACKPermitted(t, c, true) + setStackRACKPermitted(t, c) createConnectedWithSACKAndTS(c) data := make([]byte, ackNumToVerify*maxPayload) for i := range data { @@ -148,6 +161,7 @@ func TestRACKDetectReorder(t *testing.T) { func sendAndReceive(t *testing.T, c *context.Context, numPackets int) []byte { setStackSACKPermitted(t, c, true) + setStackRACKPermitted(t, c) createConnectedWithSACKAndTS(c) data := make([]byte, numPackets*maxPayload) @@ -580,7 +594,6 @@ func TestRACKCheckReorderWindow(t *testing.T) { c.SendAck(seq, bytesRead) // Missing [2-6] packets and SACK #7 packet. - seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1) start := c.IRS.Add(1 + seqnum.Size(6*maxPayload)) end := start.Add(seqnum.Size(maxPayload)) c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) @@ -596,3 +609,46 @@ func TestRACKCheckReorderWindow(t *testing.T) { t.Fatalf("unexpected values for RACK variables: %v", err) } } + +func TestRACKWithDuplicateACK(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + const numPackets = 4 + data := sendAndReceive(t, c, numPackets) + + // Send three duplicate ACKs to trigger fast recovery. + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + start := c.IRS.Add(1 + seqnum.Size(maxPayload)) + end := start.Add(seqnum.Size(maxPayload)) + for i := 0; i < 3; i++ { + c.SendAckWithSACK(seq, maxPayload, []header.SACKBlock{{start, end}}) + end = end.Add(seqnum.Size(maxPayload)) + } + + // Receive the retransmitted packet. + c.ReceiveAndCheckPacketWithOptions(data, maxPayload, maxPayload, tsOptionSize) + + metricPollFn := func() error { + tcpStats := c.Stack().Stats().TCP + stats := []struct { + stat *tcpip.StatCounter + name string + want uint64 + }{ + {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1}, + {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, + {tcpStats.FastRecovery, "stats.TCP.FastRecovery", 0}, + } + for _, s := range stats { + if got, want := s.stat.Value(), s.want; got != want { + return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want) + } + } + return nil + } + + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) + } +} diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index cd3c4a027..0128c1f7e 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -4393,12 +4393,7 @@ func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { t.Helper() - s, err := ep.SocketOptions().GetSendBufferSize() - if err != nil { - t.Fatalf("GetSendBufferSize failed: %s", err) - } - - if int(s) != v { + if s := ep.SocketOptions().GetSendBufferSize(); int(s) != v { t.Fatalf("got send buffer size = %d, want = %d", s, v) } } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index afd8f4d39..807df2bb5 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -938,11 +938,6 @@ func (e *endpoint) Disconnect() tcpip.Error { // Connect connects the endpoint to its peer. Specifying a NIC is optional. func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { - if addr.Port == 0 { - // We don't support connecting to port zero. - return &tcpip.ErrInvalidEndpointState{} - } - e.mu.Lock() defer e.mu.Unlock() @@ -1188,7 +1183,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - if e.EndpointState() != StateConnected { + if e.EndpointState() != StateConnected || e.dstPort == 0 { return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} } diff --git a/runsc/boot/filter/config.go b/runsc/boot/filter/config.go index eacd73531..2a8c916d5 100644 --- a/runsc/boot/filter/config.go +++ b/runsc/boot/filter/config.go @@ -100,6 +100,15 @@ var allowedSyscalls = seccomp.SyscallRules{ seccomp.MatchAny{}, }, }, + // getcpu is used by some versions of the Go runtime and by the hostcpu + // package on arm64. + unix.SYS_GETCPU: []seccomp.Rule{ + { + seccomp.MatchAny{}, + seccomp.EqualTo(0), + seccomp.EqualTo(0), + }, + }, syscall.SYS_GETPID: {}, unix.SYS_GETRANDOM: {}, syscall.SYS_GETSOCKOPT: []seccomp.Rule{ diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go index d50bbcd9f..129478505 100644 --- a/runsc/container/container_test.go +++ b/runsc/container/container_test.go @@ -777,6 +777,28 @@ func TestExec(t *testing.T) { } }) } + + // Test for exec failure with an non-existent file. + t.Run("nonexist", func(t *testing.T) { + // b/179114837 found by Syzkaller that causes nil pointer panic when + // trying to dec-ref an unix socket FD. + fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) + if err != nil { + t.Fatal(err) + } + defer syscall.Close(fds[0]) + + _, err = cont.executeSync(&control.ExecArgs{ + Argv: []string{"/nonexist"}, + FilePayload: urpc.FilePayload{ + Files: []*os.File{os.NewFile(uintptr(fds[1]), "sock")}, + }, + }) + want := "failed to load /nonexist" + if err == nil || !strings.Contains(err.Error(), want) { + t.Errorf("executeSync: want err containing %q; got err = %q", want, err) + } + }) }) } } diff --git a/runsc/fsgofer/filter/config.go b/runsc/fsgofer/filter/config.go index 39b8a0b1e..f92e2f80e 100644 --- a/runsc/fsgofer/filter/config.go +++ b/runsc/fsgofer/filter/config.go @@ -107,6 +107,15 @@ var allowedSyscalls = seccomp.SyscallRules{ seccomp.MatchAny{}, }, }, + // getcpu is used by some versions of the Go runtime and by the hostcpu + // package on arm64. + unix.SYS_GETCPU: []seccomp.Rule{ + { + seccomp.MatchAny{}, + seccomp.EqualTo(0), + seccomp.EqualTo(0), + }, + }, syscall.SYS_GETDENTS64: {}, syscall.SYS_GETPID: {}, unix.SYS_GETRANDOM: {}, diff --git a/test/e2e/integration_test.go b/test/e2e/integration_test.go index aaffabfd0..49cd74887 100644 --- a/test/e2e/integration_test.go +++ b/test/e2e/integration_test.go @@ -430,11 +430,44 @@ func TestTmpMount(t *testing.T) { } } +// TestSyntheticDirs checks that submounts can be created inside a readonly +// mount even if the target path does not exist. +func TestSyntheticDirs(t *testing.T) { + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) + + opts := dockerutil.RunOpts{ + Image: "basic/alpine", + // Make the root read-only to force use of synthetic dirs + // inside the root gofer mount. + ReadOnly: true, + Mounts: []mount.Mount{ + // Mount inside read-only gofer-backed root. + { + Type: mount.TypeTmpfs, + Target: "/foo/bar/baz", + }, + // Mount inside sysfs, which always uses synthetic dirs + // for submounts. + { + Type: mount.TypeTmpfs, + Target: "/sys/foo/bar/baz", + }, + }, + } + // Make sure the directories exist. + if _, err := d.Run(ctx, opts, "ls", "/foo/bar/baz", "/sys/foo/bar/baz"); err != nil { + t.Fatalf("docker run failed: %v", err) + } + +} + // TestHostOverlayfsCopyUp tests that the --overlayfs-stale-read option causes // runsc to hide the incoherence of FDs opened before and after overlayfs // copy-up on the host. func TestHostOverlayfsCopyUp(t *testing.T) { - runIntegrationTest(t, nil, "sh", "-c", "gcc -O2 -o test_copy_up test_copy_up.c && ./test_copy_up") + runIntegrationTest(t, nil, "./test_copy_up") } // TestHostOverlayfsRewindDir tests that rewinddir() "causes the directory @@ -449,14 +482,14 @@ func TestHostOverlayfsCopyUp(t *testing.T) { // automated tests yield newly-added files from readdir() even if the fsgofer // does not explicitly rewinddir(), but overlayfs does not. func TestHostOverlayfsRewindDir(t *testing.T) { - runIntegrationTest(t, nil, "sh", "-c", "gcc -O2 -o test_rewinddir test_rewinddir.c && ./test_rewinddir") + runIntegrationTest(t, nil, "./test_rewinddir") } // Basic test for linkat(2). Syscall tests requires CAP_DAC_READ_SEARCH and it // cannot use tricks like userns as root. For this reason, run a basic link test // to ensure some coverage. func TestLink(t *testing.T) { - runIntegrationTest(t, nil, "sh", "-c", "gcc -O2 -o link_test link_test.c && ./link_test") + runIntegrationTest(t, nil, "./link_test") } // This test ensures we can run ping without errors. @@ -487,6 +520,20 @@ func TestPing6Loopback(t *testing.T) { runIntegrationTest(t, []string{"NET_ADMIN"}, "./ping6.sh") } +// This test checks that the owner of the sticky directory can delete files +// inside it belonging to other users. It also checks that the owner of a file +// can always delete its file when the file is inside a sticky directory owned +// by another user. +func TestStickyDir(t *testing.T) { + if vfs2Used, err := dockerutil.UsingVFS2(); err != nil { + t.Fatalf("failed to read config for runtime %s: %v", dockerutil.Runtime(), err) + } else if !vfs2Used { + t.Skip("sticky bit test fails on VFS1.") + } + + runIntegrationTest(t, nil, "./test_sticky") +} + func runIntegrationTest(t *testing.T, capAdd []string, args ...string) { ctx := context.Background() d := dockerutil.MakeContainer(ctx, t) diff --git a/test/packetimpact/tests/tcp_rack_test.go b/test/packetimpact/tests/tcp_rack_test.go index fb2a4cc90..ef902c54d 100644 --- a/test/packetimpact/tests/tcp_rack_test.go +++ b/test/packetimpact/tests/tcp_rack_test.go @@ -168,9 +168,10 @@ func TestRACKTLPLost(t *testing.T) { closeSACKConnection(t, dut, conn, acceptFd, listenFd) } -// TestRACKTLPWithSACK tests TLP by acknowledging out of order packets. +// TestRACKWithSACK tests that RACK marks the packets as lost after receiving +// the ACK for retransmitted packets. // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-8.1 -func TestRACKTLPWithSACK(t *testing.T) { +func TestRACKWithSACK(t *testing.T) { dut, conn, acceptFd, listenFd := createSACKConnection(t) seqNum1 := *conn.RemoteSeqNum(t) @@ -180,8 +181,9 @@ func TestRACKTLPWithSACK(t *testing.T) { // We are not sending ACK for these packets. const numPkts = 3 - lastSent := sendAndReceive(t, dut, conn, numPkts, acceptFd, false /* sendACK */) + sendAndReceive(t, dut, conn, numPkts, acceptFd, false /* sendACK */) + time.Sleep(simulatedRTT) // SACK for #2 packet. sackBlock := make([]byte, 40) start := seqNum1.Add(seqnum.Size(payloadSize)) @@ -194,32 +196,25 @@ func TestRACKTLPWithSACK(t *testing.T) { }}, sackBlock[sbOff:]) conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]}) - // RACK marks #1 packet as lost and retransmits it. - if _, err := conn.Expect(t, testbench.TCP{SeqNum: testbench.Uint32(uint32(seqNum1))}, time.Second); err != nil { + rtt, _ := getRTTAndRTO(t, dut, acceptFd) + timeout := 2 * rtt + // RACK marks #1 packet as lost after RTT+reorderWindow(RTT/4) and + // retransmits it. + if _, err := conn.Expect(t, testbench.TCP{SeqNum: testbench.Uint32(uint32(seqNum1))}, timeout); err != nil { t.Fatalf("expected payload was not received: %s", err) } + time.Sleep(simulatedRTT) // ACK for #1 packet. conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(end))}) - // Probe Timeout (PTO) should be two times RTT. TLP will trigger for #3 - // packet. RACK adds an additional timeout of 200ms if the number of - // outstanding packets is equal to 1. - rtt, rto := getRTTAndRTO(t, dut, acceptFd) - pto := rtt*2 + (200 * time.Millisecond) - if rto < pto { - pto = rto - } - // We expect the 3rd packet (the last unacknowledged packet) to be - // retransmitted. - tlpProbe := testbench.Uint32(uint32(seqNum1) + uint32((numPkts-1)*payloadSize)) - if _, err := conn.Expect(t, testbench.TCP{SeqNum: tlpProbe}, time.Second); err != nil { + // RACK considers transmission times of the packets to mark them lost. + // As the 3rd packet was sent before the retransmitted 1st packet, RACK + // marks it as lost and retransmits it.. + expectedSeqNum := testbench.Uint32(uint32(seqNum1) + uint32((numPkts-1)*payloadSize)) + if _, err := conn.Expect(t, testbench.TCP{SeqNum: expectedSeqNum}, timeout); err != nil { t.Fatalf("expected payload was not received: %s", err) } - diff := time.Now().Sub(lastSent) - if diff < pto { - t.Fatalf("expected payload was received before the probe timeout, got: %v, want: %v", diff, pto) - } closeSACKConnection(t, dut, conn, acceptFd, listenFd) } diff --git a/test/packetimpact/tests/udp_send_recv_dgram_test.go b/test/packetimpact/tests/udp_send_recv_dgram_test.go index 6e45cb143..894d156cf 100644 --- a/test/packetimpact/tests/udp_send_recv_dgram_test.go +++ b/test/packetimpact/tests/udp_send_recv_dgram_test.go @@ -32,6 +32,7 @@ import ( func init() { testbench.Initialize(flag.CommandLine) + testbench.RPCTimeout = 500 * time.Millisecond } type udpConn interface { diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index e43f30ba3..d6658898d 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -993,3 +993,7 @@ syscall_test( syscall_test( test = "//test/syscalls/linux:proc_net_udp_test", ) + +syscall_test( + test = "//test/syscalls/linux:processes_test", +) diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 80e2837f8..42fc363a2 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -2346,6 +2346,7 @@ cc_library( deps = [ ":socket_test_util", ":unix_domain_socket_test_util", + "@com_google_absl//absl/time", gtest, "//test/util:test_util", ], @@ -2360,6 +2361,7 @@ cc_library( deps = [ ":socket_test_util", ":unix_domain_socket_test_util", + "@com_google_absl//absl/time", gtest, "//test/util:test_util", ], @@ -2678,6 +2680,7 @@ cc_binary( deps = [ ":socket_test_util", ":unix_domain_socket_test_util", + "@com_google_absl//absl/time", gtest, "//test/util:test_main", "//test/util:test_util", @@ -4160,6 +4163,18 @@ cc_binary( ) cc_binary( + name = "processes_test", + testonly = 1, + srcs = ["processes.cc"], + linkstatic = 1, + deps = [ + "//test/util:capability_util", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( name = "xattr_test", testonly = 1, srcs = [ diff --git a/test/syscalls/linux/exec_binary.cc b/test/syscalls/linux/exec_binary.cc index 3797fd4c8..b0fb120c6 100644 --- a/test/syscalls/linux/exec_binary.cc +++ b/test/syscalls/linux/exec_binary.cc @@ -951,6 +951,34 @@ TEST(ElfTest, PIEOutOfOrderSegments) { EXPECT_EQ(execve_errno, ENOEXEC); } +TEST(ElfTest, PIEOverflow) { + ElfBinary<64> elf = StandardElf(); + + elf.header.e_type = ET_DYN; + + // Choose vaddr of the first segment so that the end address overflows if the + // segment is mapped with a non-zero offset. + elf.phdrs[1].p_vaddr = 0xfffffffffffff000UL - elf.phdrs[1].p_memsz; + + elf.UpdateOffsets(); + + TempPath file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(elf)); + + pid_t child; + int execve_errno; + auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( + ForkAndExec(file.path(), {file.path()}, {}, &child, &execve_errno)); + if (IsRunningOnGvisor()) { + ASSERT_EQ(execve_errno, EINVAL); + } else { + ASSERT_EQ(execve_errno, 0); + int status; + ASSERT_THAT(RetryEINTR(waitpid)(child, &status, 0), + SyscallSucceedsWithValue(child)); + EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGSEGV) << status; + } +} + // Standard dynamically linked binary with an ELF interpreter. TEST(ElfTest, ELFInterpreter) { ElfBinary<64> interpreter = StandardElf(); diff --git a/test/syscalls/linux/open_create.cc b/test/syscalls/linux/open_create.cc index f8fbea79e..46f41de50 100644 --- a/test/syscalls/linux/open_create.cc +++ b/test/syscalls/linux/open_create.cc @@ -46,8 +46,10 @@ TEST(CreateTest, ExistingFile) { TEST(CreateTest, CreateAtFile) { auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); auto dirfd = ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_DIRECTORY, 0666)); - EXPECT_THAT(openat(dirfd.get(), "CreateAtFile", O_RDWR | O_CREAT, 0666), + int fd; + EXPECT_THAT(fd = openat(dirfd.get(), "CreateAtFile", O_RDWR | O_CREAT, 0666), SyscallSucceeds()); + EXPECT_THAT(close(fd), SyscallSucceeds()); } TEST(CreateTest, HonorsUmask_NoRandomSave) { diff --git a/test/syscalls/linux/processes.cc b/test/syscalls/linux/processes.cc new file mode 100644 index 000000000..412582515 --- /dev/null +++ b/test/syscalls/linux/processes.cc @@ -0,0 +1,90 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <stdint.h> +#include <sys/syscall.h> +#include <unistd.h> + +#include "test/util/capability_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +int testSetPGIDOfZombie(void* arg) { + int p[2]; + + TEST_PCHECK(pipe(p) == 0); + + pid_t pid = fork(); + if (pid == 0) { + pid = fork(); + // Create a second child to repeat one of syzkaller reproducers. + if (pid == 0) { + pid = getpid(); + TEST_PCHECK(setpgid(pid, 0) == 0); + TEST_PCHECK(write(p[1], &pid, sizeof(pid)) == sizeof(pid)); + _exit(0); + } + TEST_PCHECK(pid > 0); + _exit(0); + } + close(p[1]); + TEST_PCHECK(pid > 0); + + // Get PID of the second child. + pid_t cpid; + TEST_PCHECK(read(p[0], &cpid, sizeof(cpid)) == sizeof(cpid)); + + // Wait when both child processes will die. + int c; + TEST_PCHECK(read(p[0], &c, sizeof(c)) == 0); + + // Wait the second child process to collect its zombie. + int status; + TEST_PCHECK(RetryEINTR(waitpid)(cpid, &status, 0) == cpid); + + // Set the child's group. + TEST_PCHECK(setpgid(pid, pid) == 0); + + TEST_PCHECK(RetryEINTR(waitpid)(-pid, &status, 0) == pid); + + TEST_PCHECK(status == 0); + _exit(0); +} + +TEST(Processes, SetPGIDOfZombie) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); + + // Fork a test process in a new PID namespace, because it needs to manipulate + // with reparanted processes. + struct clone_arg { + // Reserve some space for clone() to locate arguments and retcode in this + // place. + char stack[128] __attribute__((aligned(16))); + char stack_ptr[0]; + } ca; + pid_t pid; + ASSERT_THAT(pid = clone(testSetPGIDOfZombie, ca.stack_ptr, + CLONE_NEWPID | SIGCHLD, &ca), + SyscallSucceeds()); + + int status; + ASSERT_THAT(RetryEINTR(waitpid)(pid, &status, 0), + SyscallSucceedsWithValue(pid)); + EXPECT_EQ(status, 0); +} + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/rename.cc b/test/syscalls/linux/rename.cc index 5458f54ad..22c8c19cf 100644 --- a/test/syscalls/linux/rename.cc +++ b/test/syscalls/linux/rename.cc @@ -391,6 +391,39 @@ TEST(RenameTest, FileWithOpenFd) { EXPECT_EQ(absl::string_view(buf, sizeof(buf) - 1), kContents); } +// Tests that calling rename with file path ending with . or .. causes EBUSY. +TEST(RenameTest, PathEndingWithDots) { + TempPath root_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + TempPath dir1 = + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(root_dir.path())); + TempPath dir2 = + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(root_dir.path())); + + // Try to move dir1 into dir2 but mess up the paths. + auto dir1Dot = JoinPath(dir1.path(), "."); + auto dir2Dot = JoinPath(dir2.path(), "."); + auto dir1DotDot = JoinPath(dir1.path(), ".."); + auto dir2DotDot = JoinPath(dir2.path(), ".."); + ASSERT_THAT(rename(dir1.path().c_str(), dir2Dot.c_str()), + SyscallFailsWithErrno(EBUSY)); + ASSERT_THAT(rename(dir1.path().c_str(), dir2DotDot.c_str()), + SyscallFailsWithErrno(EBUSY)); + ASSERT_THAT(rename(dir1Dot.c_str(), dir2.path().c_str()), + SyscallFailsWithErrno(EBUSY)); + ASSERT_THAT(rename(dir1DotDot.c_str(), dir2.path().c_str()), + SyscallFailsWithErrno(EBUSY)); +} + +// Calling rename with file path ending with . or .. causes EBUSY in sysfs. +TEST(RenameTest, SysfsPathEndingWithDots) { + // If a non-root user tries to rename inside /sys then we get EPERM. + SKIP_IF(geteuid() != 0); + ASSERT_THAT(rename("/sys/devices/system/cpu/online", "/sys/."), + SyscallFailsWithErrno(EBUSY)); + ASSERT_THAT(rename("/sys/devices/system/cpu/online", "/sys/.."), + SyscallFailsWithErrno(EBUSY)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/sendfile.cc b/test/syscalls/linux/sendfile.cc index 3924e0001..93b3a94f1 100644 --- a/test/syscalls/linux/sendfile.cc +++ b/test/syscalls/linux/sendfile.cc @@ -551,6 +551,29 @@ TEST(SendFileTest, SendPipeEOF) { SyscallSucceedsWithValue(0)); } +TEST(SendFileTest, SendToFullPipeReturnsEAGAIN) { + // Create and open an empty input file. + const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor in_fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR)); + + // Set up the output pipe. + int fds[2]; + ASSERT_THAT(pipe2(fds, O_NONBLOCK), SyscallSucceeds()); + const FileDescriptor rfd(fds[0]); + const FileDescriptor wfd(fds[1]); + + int pipe_size = -1; + ASSERT_THAT(pipe_size = fcntl(wfd.get(), F_GETPIPE_SZ), SyscallSucceeds()); + int data_size = pipe_size * 8; + ASSERT_THAT(ftruncate(in_fd.get(), data_size), SyscallSucceeds()); + + ASSERT_THAT(sendfile(wfd.get(), in_fd.get(), 0, data_size), + SyscallSucceeds()); + EXPECT_THAT(sendfile(wfd.get(), in_fd.get(), 0, data_size), + SyscallFailsWithErrno(EAGAIN)); +} + TEST(SendFileTest, SendPipeBlocks) { // Create temp file. constexpr char kData[] = diff --git a/test/syscalls/linux/socket_generic.cc b/test/syscalls/linux/socket_generic.cc index de0b8bb11..f70047a09 100644 --- a/test/syscalls/linux/socket_generic.cc +++ b/test/syscalls/linux/socket_generic.cc @@ -379,7 +379,7 @@ TEST_P(AllSocketPairTest, RcvBufSucceeds) { EXPECT_GT(size, 0); } -TEST_P(AllSocketPairTest, SndBufSucceeds) { +TEST_P(AllSocketPairTest, GetSndBufSucceeds) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); int size = 0; socklen_t size_size = sizeof(size); diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index a11147085..344a5a22c 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -526,7 +526,7 @@ void TestListenWhileConnect(const TestParam& param, stopListen(listen_fd); for (auto& client : clients) { - const int kTimeout = 10000; + constexpr int kTimeout = 10000; struct pollfd pfd = { .fd = client.get(), .events = POLLIN, @@ -942,7 +942,7 @@ void setupTimeWaitClose(const TestAddress* listener, // shutdown to trigger TIME_WAIT. ASSERT_THAT(shutdown(active_closefd.get(), SHUT_WR), SyscallSucceeds()); { - const int kTimeout = 10000; + constexpr int kTimeout = 10000; struct pollfd pfd = { .fd = passive_closefd.get(), .events = POLLIN, @@ -1186,7 +1186,7 @@ TEST_P(SocketInetLoopbackTest, TCPAcceptAfterReset) { ASSERT_EQ(addrlen, listener.addr_len); // Wait for accept_fd to process the RST. - const int kTimeout = 10000; + constexpr int kTimeout = 10000; struct pollfd pfd = { .fd = accept_fd.get(), .events = POLLIN, diff --git a/test/syscalls/linux/socket_unix_dgram.cc b/test/syscalls/linux/socket_unix_dgram.cc index af0df4fb4..5b0844493 100644 --- a/test/syscalls/linux/socket_unix_dgram.cc +++ b/test/syscalls/linux/socket_unix_dgram.cc @@ -18,6 +18,8 @@ #include <sys/un.h> #include "gtest/gtest.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" @@ -39,6 +41,39 @@ TEST_P(DgramUnixSocketPairTest, WriteOneSideClosed) { SyscallFailsWithErrno(ECONNREFUSED)); } +TEST_P(DgramUnixSocketPairTest, IncreasedSocketSendBufUnblocksWrites) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + int sock = sockets->first_fd(); + int buf_size = 0; + socklen_t buf_size_len = sizeof(buf_size); + ASSERT_THAT(getsockopt(sock, SOL_SOCKET, SO_SNDBUF, &buf_size, &buf_size_len), + SyscallSucceeds()); + int opts; + ASSERT_THAT(opts = fcntl(sock, F_GETFL), SyscallSucceeds()); + opts |= O_NONBLOCK; + ASSERT_THAT(fcntl(sock, F_SETFL, opts), SyscallSucceeds()); + + std::vector<char> buf(buf_size / 4); + // Write till the socket buffer is full. + while (RetryEINTR(send)(sock, buf.data(), buf.size(), 0) != -1) { + // Sleep to give linux a chance to move data from the send buffer to the + // receive buffer. + absl::SleepFor(absl::Milliseconds(10)); // 10ms. + } + // The last error should have been EWOULDBLOCK. + ASSERT_EQ(errno, EWOULDBLOCK); + + // Now increase the socket send buffer. + buf_size = buf_size * 2; + ASSERT_THAT( + setsockopt(sock, SOL_SOCKET, SO_SNDBUF, &buf_size, sizeof(buf_size)), + SyscallSucceeds()); + + // The send should succeed again. + ASSERT_THAT(RetryEINTR(send)(sock, buf.data(), buf.size(), 0), + SyscallSucceeds()); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/socket_unix_seqpacket.cc b/test/syscalls/linux/socket_unix_seqpacket.cc index 6d03df4d9..eb373373d 100644 --- a/test/syscalls/linux/socket_unix_seqpacket.cc +++ b/test/syscalls/linux/socket_unix_seqpacket.cc @@ -18,6 +18,8 @@ #include <sys/un.h> #include "gtest/gtest.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" @@ -61,6 +63,39 @@ TEST_P(SeqpacketUnixSocketPairTest, Sendto) { SyscallSucceedsWithValue(3)); } +TEST_P(SeqpacketUnixSocketPairTest, IncreasedSocketSendBufUnblocksWrites) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + int sock = sockets->first_fd(); + int buf_size = 0; + socklen_t buf_size_len = sizeof(buf_size); + ASSERT_THAT(getsockopt(sock, SOL_SOCKET, SO_SNDBUF, &buf_size, &buf_size_len), + SyscallSucceeds()); + int opts; + ASSERT_THAT(opts = fcntl(sock, F_GETFL), SyscallSucceeds()); + opts |= O_NONBLOCK; + ASSERT_THAT(fcntl(sock, F_SETFL, opts), SyscallSucceeds()); + + std::vector<char> buf(buf_size / 4); + // Write till the socket buffer is full. + while (RetryEINTR(send)(sock, buf.data(), buf.size(), 0) != -1) { + // Sleep to give linux a chance to move data from the send buffer to the + // receive buffer. + absl::SleepFor(absl::Milliseconds(10)); // 10ms. + } + // The last error should have been EWOULDBLOCK. + ASSERT_EQ(errno, EWOULDBLOCK); + + // Now increase the socket send buffer. + buf_size = buf_size * 2; + ASSERT_THAT( + setsockopt(sock, SOL_SOCKET, SO_SNDBUF, &buf_size, sizeof(buf_size)), + SyscallSucceeds()); + + // The send should succeed again. + ASSERT_THAT(RetryEINTR(send)(sock, buf.data(), buf.size(), 0), + SyscallSucceeds()); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/socket_unix_stream.cc b/test/syscalls/linux/socket_unix_stream.cc index ad9c4bf37..3ff810914 100644 --- a/test/syscalls/linux/socket_unix_stream.cc +++ b/test/syscalls/linux/socket_unix_stream.cc @@ -17,6 +17,8 @@ #include <sys/un.h> #include "gtest/gtest.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" @@ -134,6 +136,84 @@ TEST_P(StreamUnixSocketPairTest, GetSocketAcceptConn) { EXPECT_EQ(got, 0); } +TEST_P(StreamUnixSocketPairTest, SetSocketSendBuf) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + auto s = sockets->first_fd(); + int max = 0; + int min = 0; + { + // Discover maxmimum buffer size by setting to a really large value. + constexpr int kRcvBufSz = INT_MAX; + ASSERT_THAT( + setsockopt(s, SOL_SOCKET, SO_SNDBUF, &kRcvBufSz, sizeof(kRcvBufSz)), + SyscallSucceeds()); + + max = 0; + socklen_t max_len = sizeof(max); + ASSERT_THAT(getsockopt(s, SOL_SOCKET, SO_SNDBUF, &max, &max_len), + SyscallSucceeds()); + } + + { + // Discover minimum buffer size by setting it to zero. + constexpr int kRcvBufSz = 0; + ASSERT_THAT( + setsockopt(s, SOL_SOCKET, SO_SNDBUF, &kRcvBufSz, sizeof(kRcvBufSz)), + SyscallSucceeds()); + + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(s, SOL_SOCKET, SO_SNDBUF, &min, &min_len), + SyscallSucceeds()); + } + + int quarter_sz = min + (max - min) / 4; + ASSERT_THAT( + setsockopt(s, SOL_SOCKET, SO_SNDBUF, &quarter_sz, sizeof(quarter_sz)), + SyscallSucceeds()); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s, SOL_SOCKET, SO_SNDBUF, &val, &val_len), + SyscallSucceeds()); + + // Linux doubles the value set by SO_SNDBUF/SO_SNDBUF. + quarter_sz *= 2; + ASSERT_EQ(quarter_sz, val); +} + +TEST_P(StreamUnixSocketPairTest, IncreasedSocketSendBufUnblocksWrites) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + int sock = sockets->first_fd(); + int buf_size = 0; + socklen_t buf_size_len = sizeof(buf_size); + ASSERT_THAT(getsockopt(sock, SOL_SOCKET, SO_SNDBUF, &buf_size, &buf_size_len), + SyscallSucceeds()); + int opts; + ASSERT_THAT(opts = fcntl(sock, F_GETFL), SyscallSucceeds()); + opts |= O_NONBLOCK; + ASSERT_THAT(fcntl(sock, F_SETFL, opts), SyscallSucceeds()); + + std::vector<char> buf(buf_size / 4); + // Write till the socket buffer is full. + while (RetryEINTR(send)(sock, buf.data(), buf.size(), 0) != -1) { + // Sleep to give linux a chance to move data from the send buffer to the + // receive buffer. + absl::SleepFor(absl::Milliseconds(10)); // 10ms. + } + // The last error should have been EWOULDBLOCK. + ASSERT_EQ(errno, EWOULDBLOCK); + + // Now increase the socket send buffer. + buf_size = buf_size * 2; + ASSERT_THAT( + setsockopt(sock, SOL_SOCKET, SO_SNDBUF, &buf_size, sizeof(buf_size)), + SyscallSucceeds()); + + // The send should succeed again. + ASSERT_THAT(RetryEINTR(send)(sock, buf.data(), buf.size(), 0), + SyscallSucceeds()); +} + INSTANTIATE_TEST_SUITE_P( AllUnixDomainSockets, StreamUnixSocketPairTest, ::testing::ValuesIn(IncludeReversals(VecCat<SocketPairKind>( diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc index 650f12350..50f589708 100644 --- a/test/syscalls/linux/udp_socket.cc +++ b/test/syscalls/linux/udp_socket.cc @@ -2061,11 +2061,92 @@ TEST_P(UdpSocketTest, SendToZeroPort) { SyscallSucceedsWithValue(sizeof(buf))); } +TEST_P(UdpSocketTest, ConnectToZeroPortUnbound) { + struct sockaddr_storage addr = InetLoopbackAddr(); + SetPort(&addr, 0); + ASSERT_THAT( + connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen_), + SyscallSucceeds()); +} + +TEST_P(UdpSocketTest, ConnectToZeroPortBound) { + struct sockaddr_storage addr = InetLoopbackAddr(); + ASSERT_NO_ERRNO( + BindSocket(sock_.get(), reinterpret_cast<struct sockaddr*>(&addr))); + + SetPort(&addr, 0); + ASSERT_THAT( + connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen_), + SyscallSucceeds()); + socklen_t len = sizeof(sockaddr_storage); + ASSERT_THAT( + getsockname(sock_.get(), reinterpret_cast<struct sockaddr*>(&addr), &len), + SyscallSucceeds()); + ASSERT_EQ(len, addrlen_); +} + +TEST_P(UdpSocketTest, ConnectToZeroPortConnected) { + struct sockaddr_storage addr = InetLoopbackAddr(); + ASSERT_NO_ERRNO( + BindSocket(sock_.get(), reinterpret_cast<struct sockaddr*>(&addr))); + + // Connect to an address with non-zero port should succeed. + ASSERT_THAT( + connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen_), + SyscallSucceeds()); + sockaddr_storage peername; + socklen_t peerlen = sizeof(peername); + ASSERT_THAT( + getpeername(sock_.get(), reinterpret_cast<struct sockaddr*>(&peername), + &peerlen), + SyscallSucceeds()); + ASSERT_EQ(peerlen, addrlen_); + ASSERT_EQ(memcmp(&peername, &addr, addrlen_), 0); + + // However connect() to an address with port 0 will make the following + // getpeername() fail. + SetPort(&addr, 0); + ASSERT_THAT( + connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen_), + SyscallSucceeds()); + ASSERT_THAT( + getpeername(sock_.get(), reinterpret_cast<struct sockaddr*>(&peername), + &peerlen), + SyscallFailsWithErrno(ENOTCONN)); +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, UdpSocketTest, ::testing::Values(AddressFamily::kIpv4, AddressFamily::kIpv6, AddressFamily::kDualStack)); +TEST(UdpInet6SocketTest, ConnectInet4Sockaddr) { + // glibc getaddrinfo expects the invariant expressed by this test to be held. + const sockaddr_in connect_sockaddr = { + .sin_family = AF_INET, .sin_addr = {.s_addr = htonl(INADDR_LOOPBACK)}}; + auto sock_ = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP)); + ASSERT_THAT( + connect(sock_.get(), + reinterpret_cast<const struct sockaddr*>(&connect_sockaddr), + sizeof(sockaddr_in)), + SyscallSucceeds()); + socklen_t len; + sockaddr_storage sockname; + ASSERT_THAT(getsockname(sock_.get(), + reinterpret_cast<struct sockaddr*>(&sockname), &len), + SyscallSucceeds()); + ASSERT_EQ(sockname.ss_family, AF_INET6); + ASSERT_EQ(len, sizeof(sockaddr_in6)); + auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&sockname); + char addr_buf[INET6_ADDRSTRLEN]; + const char* addr; + ASSERT_NE(addr = inet_ntop(sockname.ss_family, &sockname, addr_buf, + sizeof(addr_buf)), + nullptr); + ASSERT_TRUE(IN6_IS_ADDR_V4MAPPED(sin6->sin6_addr.s6_addr)) << addr; +} + } // namespace } // namespace testing diff --git a/tools/bazel_gazelle_noise.patch b/tools/bazel_gazelle_noise.patch deleted file mode 100644 index e35f38933..000000000 --- a/tools/bazel_gazelle_noise.patch +++ /dev/null @@ -1,24 +0,0 @@ -diff -r -u2 a/language/go/resolve.go b/language/go/resolve.go ---- a/language/go/resolve.go 2020-10-02 14:22:18.000000000 -0700 -+++ b/language/go/resolve.go 2020-11-17 19:40:59.770648029 -0800 -@@ -20,5 +20,4 @@ - "fmt" - "go/build" -- "log" - "path" - "regexp" -@@ -80,5 +79,5 @@ - resolve = ResolveGo - } -- deps, errs := imports.Map(func(imp string) (string, error) { -+ deps, _ := imports.Map(func(imp string) (string, error) { - l, err := resolve(c, ix, rc, imp, from) - if err == skipImportError { -@@ -95,7 +94,4 @@ - return l.String(), nil - }) -- for _, err := range errs { -- log.Print(err) -- } - if !deps.IsEmpty() { - if r.Kind() == "go_proto_library" { diff --git a/tools/go_marshal/README.md b/tools/go_marshal/README.md index d8045c295..eddba0c21 100644 --- a/tools/go_marshal/README.md +++ b/tools/go_marshal/README.md @@ -98,6 +98,18 @@ for embedded structs that are not aligned. Because of this, it's generally best to avoid using `marshal:"unaligned"` and insert explicit padding fields instead. +## Working with dynamically sized structs + +While `go_marshal` seamlessly supports statically sized structs (which most ABI +structs are), it can also used for other uses cases where marshalling is +required. There is some provision to partially support dynamically sized structs +that may not be ABI structs. A user can define a dynamic struct and define +`SizeBytes()`, `MarshalBytes(dst)` and `UnmarshalBytes(src)` for it. Then user +can then add a comment above the struct like `// +marshal dynamic` while will +make `go_marshal` autogenerate the remaining methods required to complete the +`Marshallable` interface. This feature is currently only available for structs +and can not be used alongside the Slice API. + ## Modifying the `go_marshal` Tool The following are some guidelines for modifying the `go_marshal` tool: diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go index fa642c88a..abd6f69ea 100644 --- a/tools/go_marshal/gomarshal/generator.go +++ b/tools/go_marshal/gomarshal/generator.go @@ -211,9 +211,10 @@ type sliceAPI struct { // marshallableType carries information about a type marked with the '+marshal' // directive. type marshallableType struct { - spec *ast.TypeSpec - slice *sliceAPI - recv string + spec *ast.TypeSpec + slice *sliceAPI + recv string + dynamic bool } func newMarshallableType(fset *token.FileSet, tagLine *ast.Comment, spec *ast.TypeSpec) *marshallableType { @@ -248,6 +249,9 @@ func newMarshallableType(fset *token.FileSet, tagLine *ast.Comment, spec *ast.Ty } continue + } else if tag == "dynamic" { + mt.dynamic = true + continue } unhandledTags = append(unhandledTags, tag) @@ -379,23 +383,38 @@ func (g *Generator) generateOne(t *marshallableType, fset *token.FileSet) *inter i := newInterfaceGenerator(t.spec, t.recv, fset) switch ty := t.spec.Type.(type) { case *ast.StructType: + if t.dynamic { + // Don't validate because this type is dynamically sized and probably + // contains some funky slices which the validation does not allow. + i.emitMarshallableForStruct(ty, t.dynamic) + if t.slice != nil { + abortAt(fset.Position(t.slice.comment.Slash), "Slice API is not supported for dynamic types because it assumes that each slice element is statically sized.") + } + break + } i.validateStruct(t.spec, ty) - i.emitMarshallableForStruct(ty) + i.emitMarshallableForStruct(ty, t.dynamic) if t.slice != nil { i.emitMarshallableSliceForStruct(ty, t.slice) } case *ast.Ident: i.validatePrimitiveNewtype(ty) + if t.dynamic { + abortAt(fset.Position(t.slice.comment.Slash), "Primitive type marked as '+marshal dynamic', but primitive types can not be dynamic.") + } i.emitMarshallableForPrimitiveNewtype(ty) if t.slice != nil { i.emitMarshallableSliceForPrimitiveNewtype(ty, t.slice) } case *ast.ArrayType: i.validateArrayNewtype(t.spec.Name, ty) + if t.dynamic { + abortAt(fset.Position(t.slice.comment.Slash), "Marking array types as `dynamic` is currently not supported.") + } // After validate, we can safely call arrayLen. i.emitMarshallableForArrayNewtype(t.spec.Name, ty, ty.Elt.(*ast.Ident)) if t.slice != nil { - abortAt(fset.Position(t.slice.comment.Slash), fmt.Sprintf("Array type marked as '+marshal slice:...', but this is not supported. Perhaps fold one of the dimensions?")) + abortAt(fset.Position(t.slice.comment.Slash), "Array type marked as '+marshal slice:...', but this is not supported. Perhaps fold one of the dimensions?") } default: // This should've been filtered out by collectMarshallabeTypes. @@ -408,7 +427,7 @@ func (g *Generator) generateOne(t *marshallableType, fset *token.FileSet) *inter // implementations type t. func (g *Generator) generateOneTestSuite(t *marshallableType) *testGenerator { i := newTestGenerator(t.spec, t.recv) - i.emitTests(t.slice) + i.emitTests(t.slice, t.dynamic) return i } diff --git a/tools/go_marshal/gomarshal/generator_interfaces_struct.go b/tools/go_marshal/gomarshal/generator_interfaces_struct.go index 5f6306b8f..f98e41ed7 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces_struct.go +++ b/tools/go_marshal/gomarshal/generator_interfaces_struct.go @@ -69,7 +69,11 @@ func (g *interfaceGenerator) validateStruct(ts *ast.TypeSpec, st *ast.StructType }) } -func (g *interfaceGenerator) isStructPacked(st *ast.StructType) bool { +func (g *interfaceGenerator) isStructPacked(st *ast.StructType, isDynamic bool) bool { + if isDynamic { + // Dynamic types are not packed because a slice header might be present. + return false + } packed := true forEachStructField(st, func(f *ast.Field) { if f.Tag != nil { @@ -85,165 +89,17 @@ func (g *interfaceGenerator) isStructPacked(st *ast.StructType) bool { return packed } -func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { - thisPacked := g.isStructPacked(st) - - g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") - g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) - g.inIndent(func() { - primitiveSize := 0 - var dynamicSizeTerms []string - - forEachStructField(st, fieldDispatcher{ - primitive: func(_, t *ast.Ident) { - if size, dynamic := g.scalarSize(t); !dynamic { - primitiveSize += size - } else { - g.recordUsedMarshallable(t.Name) - dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", t.Name)) - } - }, - selector: func(_, tX, tSel *ast.Ident) { - tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name) - g.recordUsedImport(tX.Name) - g.recordUsedMarshallable(tName) - dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName)) - }, - array: func(_ *ast.Ident, a *ast.ArrayType, t *ast.Ident) { - lenExpr := g.arrayLenExpr(a) - if size, dynamic := g.scalarSize(t); !dynamic { - dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("%d*%s", size, lenExpr)) - } else { - g.recordUsedMarshallable(t.Name) - dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%s", t.Name, lenExpr)) - } - }, - }.dispatch) - g.emit("return %d", primitiveSize) - if len(dynamicSizeTerms) > 0 { - g.incIndent() - } - { - for _, d := range dynamicSizeTerms { - g.emitNoIndent(" +\n") - g.emit(d) - } - } - if len(dynamicSizeTerms) > 0 { - g.decIndent() - } - }) - g.emit("\n}\n\n") +func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType, isDynamic bool) { + thisPacked := g.isStructPacked(st, isDynamic) - g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") - g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - forEachStructField(st, fieldDispatcher{ - primitive: func(n, t *ast.Ident) { - if n.Name == "_" { - g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name) - if len, dynamic := g.scalarSize(t); !dynamic { - g.shift("dst", len) - } else { - // We can't use shiftDynamic here because we don't have - // an instance of the dynamic type we can reference here - // (since the version in this struct is anonymous). Use - // a typed nil pointer to call SizeBytes() instead. - g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name) - } - return - } - g.marshalScalar(g.fieldAccessor(n), t.Name, "dst") - }, - selector: func(n, tX, tSel *ast.Ident) { - if n.Name == "_" { - g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", tX.Name, tSel.Name) - g.emit("dst = dst[(*%s.%s)(nil).SizeBytes():]\n", tX.Name, tSel.Name) - return - } - g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst") - }, - array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) { - lenExpr := g.arrayLenExpr(a) - if n.Name == "_" { - g.emit("// Padding: dst[:sizeof(%s)*%s] ~= [%s]%s{0}\n", t.Name, lenExpr, lenExpr, t.Name) - if size, dynamic := g.scalarSize(t); !dynamic { - g.emit("dst = dst[%d*(%s):]\n", size, lenExpr) - } else { - // We can't use shiftDynamic here because we don't have - // an instance of the dynamic type we can reference here - // (since the version in this struct is anonymous). Use - // a typed nil pointer to call SizeBytes() instead. - g.emit("dst = dst[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr) - } - return - } - - g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr) - g.inIndent(func() { - g.marshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst") - }) - g.emit("}\n") - }, - }.dispatch) - }) - g.emit("}\n\n") - - g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") - g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - forEachStructField(st, fieldDispatcher{ - primitive: func(n, t *ast.Ident) { - if n.Name == "_" { - g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name) - if len, dynamic := g.scalarSize(t); !dynamic { - g.shift("src", len) - } else { - // We don't have an instance of the dynamic type we can - // reference here (since the version in this struct is - // anonymous). Use a typed nil pointer to call - // SizeBytes() instead. - g.shiftDynamic("src", fmt.Sprintf("(*%s)(nil)", t.Name)) - g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name)) - } - return - } - g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src") - }, - selector: func(n, tX, tSel *ast.Ident) { - if n.Name == "_" { - g.emit("// Padding: %s ~= src[:sizeof(%s.%s)]\n", g.fieldAccessor(n), tX.Name, tSel.Name) - g.emit("src = src[(*%s.%s)(nil).SizeBytes():]\n", tX.Name, tSel.Name) - g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s.%s)(nil)", tX.Name, tSel.Name)) - return - } - g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src") - }, - array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) { - lenExpr := g.arrayLenExpr(a) - if n.Name == "_" { - g.emit("// Padding: ~ copy([%s]%s(%s), src[:sizeof(%s)*%s])\n", lenExpr, t.Name, g.fieldAccessor(n), t.Name, lenExpr) - if size, dynamic := g.scalarSize(t); !dynamic { - g.emit("src = src[%d*(%s):]\n", size, lenExpr) - } else { - // We can't use shiftDynamic here because we don't have - // an instance of the dynamic type we can referece here - // (since the version in this struct is anonymous). Use - // a typed nil pointer to call SizeBytes() instead. - g.emit("src = src[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr) - } - return - } - - g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr) - g.inIndent(func() { - g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src") - }) - g.emit("}\n") - }, - }.dispatch) - }) - g.emit("}\n\n") + // Dynamic types are supposed to manually implement SizeBytes, MarshalBytes + // and UnmarshalBytes. The rest of the methos are autogenerated and depend on + // the implementation of these three. + if !isDynamic { + g.emitSizeBytesForStruct(st) + g.emitMarshalBytesForStruct(st) + g.emitUnmarshalBytesForStruct(st) + } g.emit("// Packed implements marshal.Marshallable.Packed.\n") g.emit("//go:nosplit\n") @@ -428,8 +284,171 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("}\n\n") } +func (g *interfaceGenerator) emitSizeBytesForStruct(st *ast.StructType) { + g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") + g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) + g.inIndent(func() { + primitiveSize := 0 + var dynamicSizeTerms []string + + forEachStructField(st, fieldDispatcher{ + primitive: func(_, t *ast.Ident) { + if size, dynamic := g.scalarSize(t); !dynamic { + primitiveSize += size + } else { + g.recordUsedMarshallable(t.Name) + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", t.Name)) + } + }, + selector: func(_, tX, tSel *ast.Ident) { + tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name) + g.recordUsedImport(tX.Name) + g.recordUsedMarshallable(tName) + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName)) + }, + array: func(_ *ast.Ident, a *ast.ArrayType, t *ast.Ident) { + lenExpr := g.arrayLenExpr(a) + if size, dynamic := g.scalarSize(t); !dynamic { + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("%d*%s", size, lenExpr)) + } else { + g.recordUsedMarshallable(t.Name) + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%s", t.Name, lenExpr)) + } + }, + }.dispatch) + g.emit("return %d", primitiveSize) + if len(dynamicSizeTerms) > 0 { + g.incIndent() + } + { + for _, d := range dynamicSizeTerms { + g.emitNoIndent(" +\n") + g.emit(d) + } + } + if len(dynamicSizeTerms) > 0 { + g.decIndent() + } + }) + g.emit("\n}\n\n") +} + +func (g *interfaceGenerator) emitMarshalBytesForStruct(st *ast.StructType) { + g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") + g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + forEachStructField(st, fieldDispatcher{ + primitive: func(n, t *ast.Ident) { + if n.Name == "_" { + g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name) + if len, dynamic := g.scalarSize(t); !dynamic { + g.shift("dst", len) + } else { + // We can't use shiftDynamic here because we don't have + // an instance of the dynamic type we can reference here + // (since the version in this struct is anonymous). Use + // a typed nil pointer to call SizeBytes() instead. + g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name) + } + return + } + g.marshalScalar(g.fieldAccessor(n), t.Name, "dst") + }, + selector: func(n, tX, tSel *ast.Ident) { + if n.Name == "_" { + g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", tX.Name, tSel.Name) + g.emit("dst = dst[(*%s.%s)(nil).SizeBytes():]\n", tX.Name, tSel.Name) + return + } + g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst") + }, + array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) { + lenExpr := g.arrayLenExpr(a) + if n.Name == "_" { + g.emit("// Padding: dst[:sizeof(%s)*%s] ~= [%s]%s{0}\n", t.Name, lenExpr, lenExpr, t.Name) + if size, dynamic := g.scalarSize(t); !dynamic { + g.emit("dst = dst[%d*(%s):]\n", size, lenExpr) + } else { + // We can't use shiftDynamic here because we don't have + // an instance of the dynamic type we can reference here + // (since the version in this struct is anonymous). Use + // a typed nil pointer to call SizeBytes() instead. + g.emit("dst = dst[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr) + } + return + } + + g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr) + g.inIndent(func() { + g.marshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst") + }) + g.emit("}\n") + }, + }.dispatch) + }) + g.emit("}\n\n") +} + +func (g *interfaceGenerator) emitUnmarshalBytesForStruct(st *ast.StructType) { + g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") + g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + forEachStructField(st, fieldDispatcher{ + primitive: func(n, t *ast.Ident) { + if n.Name == "_" { + g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name) + if len, dynamic := g.scalarSize(t); !dynamic { + g.shift("src", len) + } else { + // We don't have an instance of the dynamic type we can + // reference here (since the version in this struct is + // anonymous). Use a typed nil pointer to call + // SizeBytes() instead. + g.shiftDynamic("src", fmt.Sprintf("(*%s)(nil)", t.Name)) + g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name)) + } + return + } + g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src") + }, + selector: func(n, tX, tSel *ast.Ident) { + if n.Name == "_" { + g.emit("// Padding: %s ~= src[:sizeof(%s.%s)]\n", g.fieldAccessor(n), tX.Name, tSel.Name) + g.emit("src = src[(*%s.%s)(nil).SizeBytes():]\n", tX.Name, tSel.Name) + g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s.%s)(nil)", tX.Name, tSel.Name)) + return + } + g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src") + }, + array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) { + lenExpr := g.arrayLenExpr(a) + if n.Name == "_" { + g.emit("// Padding: ~ copy([%s]%s(%s), src[:sizeof(%s)*%s])\n", lenExpr, t.Name, g.fieldAccessor(n), t.Name, lenExpr) + if size, dynamic := g.scalarSize(t); !dynamic { + g.emit("src = src[%d*(%s):]\n", size, lenExpr) + } else { + // We can't use shiftDynamic here because we don't have + // an instance of the dynamic type we can referece here + // (since the version in this struct is anonymous). Use + // a typed nil pointer to call SizeBytes() instead. + g.emit("src = src[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr) + } + return + } + + g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr) + g.inIndent(func() { + g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src") + }) + g.emit("}\n") + }, + }.dispatch) + }) + g.emit("}\n\n") +} + func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, slice *sliceAPI) { - thisPacked := g.isStructPacked(st) + thisPacked := g.isStructPacked(st, false /* isDynamic */) if slice.inner { abortAt(g.f.Position(slice.comment.Slash), fmt.Sprintf("The ':inner' argument to '+marshal slice:%s:inner' is only applicable to newtypes on primitives. Remove it from this struct declaration.", slice.ident)) diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go index 6cf00843f..ca3e15c16 100644 --- a/tools/go_marshal/gomarshal/generator_tests.go +++ b/tools/go_marshal/gomarshal/generator_tests.go @@ -216,12 +216,16 @@ func (g *testGenerator) emitTestSizeBytesOnTypedNilPtr() { }) } -func (g *testGenerator) emitTests(slice *sliceAPI) { +func (g *testGenerator) emitTests(slice *sliceAPI, isDynamic bool) { g.emitTestNonZeroSize() g.emitTestSuspectAlignment() - g.emitTestMarshalUnmarshalPreservesData() - g.emitTestWriteToUnmarshalPreservesData() - g.emitTestSizeBytesOnTypedNilPtr() + if !isDynamic { + // Do not test these for dynamic structs because they violate some + // assumptions that these tests make. + g.emitTestMarshalUnmarshalPreservesData() + g.emitTestWriteToUnmarshalPreservesData() + g.emitTestSizeBytesOnTypedNilPtr() + } if slice != nil { g.emitTestMarshalUnmarshalSlicePreservesData(slice) diff --git a/tools/go_marshal/test/BUILD b/tools/go_marshal/test/BUILD index 4b27773c2..cb2d4e6e3 100644 --- a/tools/go_marshal/test/BUILD +++ b/tools/go_marshal/test/BUILD @@ -26,7 +26,10 @@ go_library( srcs = ["test.go"], marshal = True, visibility = ["//tools/go_marshal/test:__subpackages__"], - deps = ["//tools/go_marshal/test/external"], + deps = [ + "//pkg/marshal/primitive", + "//tools/go_marshal/test/external", + ], ) go_test( @@ -36,6 +39,7 @@ go_test( deps = [ ":test", "//pkg/marshal", + "//pkg/marshal/primitive", "//pkg/syserror", "//pkg/usermem", "//tools/go_marshal/analysis", diff --git a/tools/go_marshal/test/marshal_test.go b/tools/go_marshal/test/marshal_test.go index a00f9a684..b0091dc64 100644 --- a/tools/go_marshal/test/marshal_test.go +++ b/tools/go_marshal/test/marshal_test.go @@ -28,6 +28,7 @@ import ( "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/tools/go_marshal/analysis" @@ -513,3 +514,21 @@ func TestLimitedSliceMarshalling(t *testing.T) { }) } } + +func TestDynamicType(t *testing.T) { + t12 := test.Type12Dynamic{ + X: 32, + Y: []primitive.Int64{5, 6, 7}, + } + + var m marshal.Marshallable + m = &t12 // Ensure that all methods were generated. + b := make([]byte, m.SizeBytes()) + m.MarshalBytes(b) + + var res test.Type12Dynamic + res.UnmarshalBytes(b) + if !reflect.DeepEqual(t12, res) { + t.Errorf("dynamic type is not same after marshalling and unmarshalling: before = %+v, after = %+v", t12, res) + } +} diff --git a/tools/go_marshal/test/test.go b/tools/go_marshal/test/test.go index e7e3ed74a..b8eb989d9 100644 --- a/tools/go_marshal/test/test.go +++ b/tools/go_marshal/test/test.go @@ -16,6 +16,8 @@ package test import ( + "gvisor.dev/gvisor/pkg/marshal/primitive" + // We're intentionally using a package name alias here even though it's not // necessary to test the code generator's ability to handle package aliases. ex "gvisor.dev/gvisor/tools/go_marshal/test/external" @@ -198,3 +200,36 @@ type Type11 struct { ex.External y int64 } + +// Type12Dynamic is a dynamically sized struct which depends on the autogenerator +// to generate some Marshallable methods for it. +// +// +marshal dynamic +type Type12Dynamic struct { + X primitive.Int64 + Y []primitive.Int64 +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (t *Type12Dynamic) SizeBytes() int { + return (len(t.Y) * 8) + t.X.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (t *Type12Dynamic) MarshalBytes(dst []byte) { + t.X.MarshalBytes(dst) + dst = dst[t.X.SizeBytes():] + for i, x := range t.Y { + x.MarshalBytes(dst[i*8 : (i+1)*8]) + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (t *Type12Dynamic) UnmarshalBytes(src []byte) { + t.X.UnmarshalBytes(src) + for i := t.X.SizeBytes(); i < len(src); i += 8 { + var x primitive.Int64 + x.UnmarshalBytes(src[i:]) + t.Y = append(t.Y, x) + } +} |