diff options
author | Ian Lewis <ianmlewis@gmail.com> | 2020-08-17 21:44:31 -0400 |
---|---|---|
committer | Ian Lewis <ianmlewis@gmail.com> | 2020-08-17 21:44:31 -0400 |
commit | ac324f646ee3cb7955b0b45a7453aeb9671cbdf1 (patch) | |
tree | 0cbc5018e8807421d701d190dc20525726c7ca76 /test | |
parent | 352ae1022ce19de28fc72e034cc469872ad79d06 (diff) | |
parent | 6d0c5803d557d453f15ac6f683697eeb46dab680 (diff) |
Merge branch 'master' into ip-forwarding
- Merges aleksej-paschenko's with HEAD
- Adds vfs2 support for ip_forward
Diffstat (limited to 'test')
444 files changed, 43025 insertions, 6382 deletions
diff --git a/test/BUILD b/test/BUILD index 01fa01f2e..34b950644 100644 --- a/test/BUILD +++ b/test/BUILD @@ -1,44 +1 @@ -package(licenses = ["notice"]) # Apache 2.0 - -# We need to define a bazel platform and toolchain to specify dockerPrivileged -# and dockerRunAsRoot options, they are required to run tests on the RBE -# cluster in Kokoro. -alias( - name = "rbe_ubuntu1604", - actual = ":rbe_ubuntu1604_r346485", -) - -platform( - name = "rbe_ubuntu1604_r346485", - constraint_values = [ - "@bazel_tools//platforms:x86_64", - "@bazel_tools//platforms:linux", - "@bazel_tools//tools/cpp:clang", - "@bazel_toolchains//constraints:xenial", - "@bazel_toolchains//constraints/sanitizers:support_msan", - ], - remote_execution_properties = """ - properties: { - name: "container-image" - value:"docker://gcr.io/cloud-marketplace/google/rbe-ubuntu16-04@sha256:69c9f1652941d64a46f6f7358a44c1718f25caa5cb1ced4a58ccc5281cd183b5" - } - properties: { - name: "dockerAddCapabilities" - value: "SYS_ADMIN" - } - properties: { - name: "dockerPrivileged" - value: "true" - } - """, -) - -toolchain( - name = "cc-toolchain-clang-x86_64-default", - exec_compatible_with = [ - ], - target_compatible_with = [ - ], - toolchain = "@bazel_toolchains//configs/ubuntu16_04_clang/9.0.0/bazel_0.28.0/cc:cc-compiler-k8", - toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", -) +package(licenses = ["notice"]) diff --git a/test/README.md b/test/README.md index 97fe7ea04..02bbf42ff 100644 --- a/test/README.md +++ b/test/README.md @@ -24,11 +24,11 @@ also used to run these tests in `kokoro`. To run image and integration tests, run: -`./scripts/docker_test.sh` +`./scripts/docker_tests.sh` To run root tests, run: -`./scripts/root_test.sh` +`./scripts/root_tests.sh` There are a few other interesting variations for image and integration tests: diff --git a/test/benchmarks/README.md b/test/benchmarks/README.md new file mode 100644 index 000000000..d1bbabf6f --- /dev/null +++ b/test/benchmarks/README.md @@ -0,0 +1,157 @@ +# Benchmark tools + +This package and subpackages are for running macro benchmarks on `runsc`. They +are meant to replace the previous //benchmarks benchmark-tools written in +python. + +Benchmarks are meant to look like regular golang benchmarks using the testing.B +library. + +## Setup + +To run benchmarks you will need: + +* Docker installed (17.09.0 or greater). + +The easiest way to setup runsc for running benchmarks is to use the make file. +From the root directory: + +* Download images: `make load-all-images` +* Install runsc suitable for benchmarking, which should probably not have + strace or debug logs enabled. For example:`make configure RUNTIME=myrunsc + ARGS=--platform=kvm`. +* Restart docker: `sudo service docker restart` + +You should now have a runtime with the following options configured in +`/etc/docker/daemon.json` + +``` +"myrunsc": { + "path": "/tmp/myrunsc/runsc", + "runtimeArgs": [ + "--debug-log", + "/tmp/bench/logs/runsc.log.%TEST%.%TIMESTAMP%.%COMMAND%", + "--platform=kvm" + ] + }, + +``` + +This runtime has been configured with a debugging off and strace logs off and is +using kvm for demonstration. + +## Running benchmarks + +Given the runtime above runtime `myrunsc`, run benchmarks with the following: + +``` +make sudo TARGETS=//path/to:target ARGS="--runtime=myrunsc -test.v \ + -test.bench=." OPTIONS="-c opt +``` + +For example, to run only the Iperf tests: + +``` +make sudo TARGETS=//test/benchmarks/network:network_test \ + ARGS="--runtime=myrunsc -test.v -test.bench=Iperf" OPTIONS="-c opt" +``` + +Benchmarks are run with root as some benchmarks require root privileges to do +things like drop caches. + +## Writing benchmarks + +Benchmarks consist of docker images as Dockerfiles and golang testing.B +benchmarks. + +### Dockerfiles: + +* Are stored at //images. +* New Dockerfiles go in an appropriately named directory at + `//images/benchmarks/my-cool-dockerfile`. +* Dockerfiles for benchmarks should: + * Use explicitly versioned packages. + * Not use ENV and CMD statements...it is easy to add these in the API. +* Note: A common pattern for getting access to a tmpfs mount is to copy files + there after container start. See: //test/benchmarks/build/bazel_test.go. You + can also make your own with `RunOpts.Mounts`. + +### testing.B packages + +In general, benchmarks should look like this: + +```golang + +var h harness.Harness + +func BenchmarkMyCoolOne(b *testing.B) { + machine, err := h.GetMachine() + // check err + defer machine.CleanUp() + + ctx := context.Background() + container := machine.GetContainer(ctx, b) + defer container.CleanUp(ctx) + + b.ResetTimer() + + //Respect b.N. + for i := 0; i < b.N; i++ { + out, err := container.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/my-cool-image", + Env: []string{"MY_VAR=awesome"}, + other options...see dockerutil + }, "sh", "-c", "echo MY_VAR") + //check err + b.StopTimer() + + // Do parsing and reporting outside of the timer. + number := parseMyMetric(out) + b.ReportMetric(number, "my-cool-custom-metric") + + b.StartTimer() + } +} + +func TestMain(m *testing.M) { + h.Init() + os.Exit(m.Run()) +} +``` + +Some notes on the above: + +* The harness is initiated in the TestMain method and made global to test + module. The harness will handle any presetup that needs to happen with + flags, remote virtual machines (eventually), and other services. +* Respect `b.N` in that users of the benchmark may want to "run for an hour" + or something of the sort. +* Use the `b.ReportMetric()` method to report custom metrics. +* Set the timer if time is useful for reporting. There isn't a way to turn off + default metrics in testing.B (B/op, allocs/op, ns/op). +* Take a look at dockerutil at //pkg/test/dockerutil to see all methods + available from containers. The API is based on the "official" + [docker API for golang](https://pkg.go.dev/mod/github.com/docker/docker). +* `harness.GetMachine()` marks how many machines this tests needs. If you have + a client and server and to mark them as multiple machines, call + `harness.GetMachine()` twice. + +## Profiling + +For profiling, the runtime is required to have the `--profile` flag enabled. +This flag loosens seccomp filters so that the runtime can write profile data to +disk. This configuration is not recommended for production. + +* Install runsc with the `--profile` flag: `make configure RUNTIME=myrunsc + ARGS="--profile --platform=kvm --vfs2"`. The kvm and vfs2 flags are not + required, but are included for demonstration. +* Restart docker: `sudo service docker restart` + +To run and generate CPU profiles fs_test test run: + +``` +make sudo TARGETS=//test/benchmarks/fs:fs_test \ + ARGS="--runtime=myrunsc -test.v -test.bench=. --pprof-cpu" OPTIONS="-c opt" +``` + +Profiles would be at: `/tmp/profile/myrunsc/CONTAINERNAME/cpu.pprof` diff --git a/test/benchmarks/base/BUILD b/test/benchmarks/base/BUILD new file mode 100644 index 000000000..32c139204 --- /dev/null +++ b/test/benchmarks/base/BUILD @@ -0,0 +1,34 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "base", + testonly = 1, + srcs = [ + "base.go", + ], + deps = ["//test/benchmarks/harness"], +) + +go_test( + name = "base_test", + size = "large", + srcs = [ + "size_test.go", + "startup_test.go", + "sysbench_test.go", + ], + library = ":base", + tags = [ + # Requires docker and runsc to be configured before test runs. + "manual", + "local", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/test/dockerutil", + "//test/benchmarks/harness", + "//test/benchmarks/tools", + ], +) diff --git a/test/root/testdata/sandbox.go b/test/benchmarks/base/base.go index 0db210370..7bac52ff1 100644 --- a/test/root/testdata/sandbox.go +++ b/test/benchmarks/base/base.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2020 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,19 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -package testdata +// Package base holds base performance benchmarks. +package base -// Sandbox is a default JSON config for a sandbox. -const Sandbox = ` -{ - "metadata": { - "name": "default-sandbox", - "namespace": "default", - "attempt": 1, - "uid": "hdishd83djaidwnduwk28bcsb" - }, - "linux": { - }, - "log_directory": "/tmp" +import ( + "os" + "testing" + + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +var testHarness harness.Harness + +// TestMain is the main method for package network. +func TestMain(m *testing.M) { + testHarness.Init() + os.Exit(m.Run()) } -` diff --git a/test/benchmarks/base/size_test.go b/test/benchmarks/base/size_test.go new file mode 100644 index 000000000..3c1364faf --- /dev/null +++ b/test/benchmarks/base/size_test.go @@ -0,0 +1,220 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package base + +import ( + "context" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/harness" + "gvisor.dev/gvisor/test/benchmarks/tools" +) + +// BenchmarkSizeEmpty creates N empty containers and reads memory usage from +// /proc/meminfo. +func BenchmarkSizeEmpty(b *testing.B) { + machine, err := testHarness.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer machine.CleanUp() + meminfo := tools.Meminfo{} + ctx := context.Background() + containers := make([]*dockerutil.Container, 0, b.N) + + // DropCaches before the test. + harness.DropCaches(machine) + + // Check available memory on 'machine'. + cmd, args := meminfo.MakeCmd() + before, err := machine.RunCommand(cmd, args...) + if err != nil { + b.Fatalf("failed to get meminfo: %v", err) + } + + // Make N containers. + for i := 0; i < b.N; i++ { + container := machine.GetContainer(ctx, b) + containers = append(containers, container) + if err := container.Spawn(ctx, dockerutil.RunOpts{ + Image: "benchmarks/alpine", + }, "sh", "-c", "echo Hello && sleep 1000"); err != nil { + cleanUpContainers(ctx, containers) + b.Fatalf("failed to run container: %v", err) + } + if _, err := container.WaitForOutputSubmatch(ctx, "Hello", 5*time.Second); err != nil { + cleanUpContainers(ctx, containers) + b.Fatalf("failed to read container output: %v", err) + } + } + + // Drop caches again before second measurement. + harness.DropCaches(machine) + + // Check available memory after containers are up. + after, err := machine.RunCommand(cmd, args...) + cleanUpContainers(ctx, containers) + if err != nil { + b.Fatalf("failed to get meminfo: %v", err) + } + meminfo.Report(b, before, after) +} + +// BenchmarkSizeNginx starts N containers running Nginx, checks that they're +// serving, and checks memory used based on /proc/meminfo. +func BenchmarkSizeNginx(b *testing.B) { + machine, err := testHarness.GetMachine() + if err != nil { + b.Fatalf("failed to get machine with: %v", err) + } + defer machine.CleanUp() + + // DropCaches for the first measurement. + harness.DropCaches(machine) + + // Measure MemAvailable before creating containers. + meminfo := tools.Meminfo{} + cmd, args := meminfo.MakeCmd() + before, err := machine.RunCommand(cmd, args...) + if err != nil { + b.Fatalf("failed to run meminfo command: %v", err) + } + + // Make N Nginx containers. + ctx := context.Background() + runOpts := dockerutil.RunOpts{ + Image: "benchmarks/nginx", + } + const port = 80 + servers := startServers(ctx, b, + serverArgs{ + machine: machine, + port: port, + runOpts: runOpts, + }) + defer cleanUpContainers(ctx, servers) + + // DropCaches after servers are created. + harness.DropCaches(machine) + // Take after measurement. + after, err := machine.RunCommand(cmd, args...) + if err != nil { + b.Fatalf("failed to run meminfo command: %v", err) + } + meminfo.Report(b, before, after) +} + +// BenchmarkSizeNode starts N containers running a Node app, checks that +// they're serving, and checks memory used based on /proc/meminfo. +func BenchmarkSizeNode(b *testing.B) { + machine, err := testHarness.GetMachine() + if err != nil { + b.Fatalf("failed to get machine with: %v", err) + } + defer machine.CleanUp() + + // Make a redis instance for Node to connect. + ctx := context.Background() + redis, redisIP := redisInstance(ctx, b, machine) + defer redis.CleanUp(ctx) + + // DropCaches after redis is created. + harness.DropCaches(machine) + + // Take before measurement. + meminfo := tools.Meminfo{} + cmd, args := meminfo.MakeCmd() + before, err := machine.RunCommand(cmd, args...) + if err != nil { + b.Fatalf("failed to run meminfo commend: %v", err) + } + + // Create N Node servers. + runOpts := dockerutil.RunOpts{ + Image: "benchmarks/node", + WorkDir: "/usr/src/app", + Links: []string{redis.MakeLink("redis")}, + } + nodeCmd := []string{"node", "index.js", redisIP.String()} + const port = 8080 + servers := startServers(ctx, b, + serverArgs{ + machine: machine, + port: port, + runOpts: runOpts, + cmd: nodeCmd, + }) + defer cleanUpContainers(ctx, servers) + + // DropCaches after servers are created. + harness.DropCaches(machine) + // Take after measurement. + cmd, args = meminfo.MakeCmd() + after, err := machine.RunCommand(cmd, args...) + if err != nil { + b.Fatalf("failed to run meminfo command: %v", err) + } + meminfo.Report(b, before, after) +} + +// serverArgs wraps args for startServers and runServerWorkload. +type serverArgs struct { + machine harness.Machine + port int + runOpts dockerutil.RunOpts + cmd []string +} + +// startServers starts b.N containers defined by 'runOpts' and 'cmd' and uses +// 'machine' to check that each is up. +func startServers(ctx context.Context, b *testing.B, args serverArgs) []*dockerutil.Container { + b.Helper() + servers := make([]*dockerutil.Container, 0, b.N) + + // Create N servers and wait until each of them is serving. + for i := 0; i < b.N; i++ { + server := args.machine.GetContainer(ctx, b) + servers = append(servers, server) + if err := server.Spawn(ctx, args.runOpts, args.cmd...); err != nil { + cleanUpContainers(ctx, servers) + b.Fatalf("failed to spawn node instance: %v", err) + } + + // Get the container IP. + servingIP, err := server.FindIP(ctx, false) + if err != nil { + cleanUpContainers(ctx, servers) + b.Fatalf("failed to get ip from server: %v", err) + } + + // Wait until the server is up. + if err := harness.WaitUntilServing(ctx, args.machine, servingIP, args.port); err != nil { + cleanUpContainers(ctx, servers) + b.Fatalf("failed to wait for serving") + } + } + return servers +} + +// cleanUpContainers cleans up a slice of containers. +func cleanUpContainers(ctx context.Context, containers []*dockerutil.Container) { + for _, c := range containers { + if c != nil { + c.CleanUp(ctx) + } + } +} diff --git a/test/benchmarks/base/startup_test.go b/test/benchmarks/base/startup_test.go new file mode 100644 index 000000000..4628a0a41 --- /dev/null +++ b/test/benchmarks/base/startup_test.go @@ -0,0 +1,156 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package base + +import ( + "context" + "fmt" + "net" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +// BenchmarkStartEmpty times startup time for an empty container. +func BenchmarkStartupEmpty(b *testing.B) { + machine, err := testHarness.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer machine.CleanUp() + + ctx := context.Background() + for i := 0; i < b.N; i++ { + container := machine.GetContainer(ctx, b) + defer container.CleanUp(ctx) + if _, err := container.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/alpine", + }, "true"); err != nil { + b.Fatalf("failed to run container: %v", err) + } + } +} + +// BenchmarkStartupNginx times startup for a Nginx instance. +// Time is measured from start until the first request is served. +func BenchmarkStartupNginx(b *testing.B) { + // The machine to hold Nginx and the Node Server. + machine, err := testHarness.GetMachine() + if err != nil { + b.Fatalf("failed to get machine with: %v", err) + } + defer machine.CleanUp() + + ctx := context.Background() + runOpts := dockerutil.RunOpts{ + Image: "benchmarks/nginx", + } + runServerWorkload(ctx, b, + serverArgs{ + machine: machine, + runOpts: runOpts, + port: 80, + }) +} + +// BenchmarkStartupNode times startup for a Node application instance. +// Time is measured from start until the first request is served. +// Note that the Node app connects to a Redis instance before serving. +func BenchmarkStartupNode(b *testing.B) { + machine, err := testHarness.GetMachine() + if err != nil { + b.Fatalf("failed to get machine with: %v", err) + } + defer machine.CleanUp() + + ctx := context.Background() + redis, redisIP := redisInstance(ctx, b, machine) + defer redis.CleanUp(ctx) + runOpts := dockerutil.RunOpts{ + Image: "benchmarks/node", + WorkDir: "/usr/src/app", + Links: []string{redis.MakeLink("redis")}, + } + + cmd := []string{"node", "index.js", redisIP.String()} + runServerWorkload(ctx, b, + serverArgs{ + machine: machine, + port: 8080, + runOpts: runOpts, + cmd: cmd, + }) +} + +// redisInstance returns a Redis container and its reachable IP. +func redisInstance(ctx context.Context, b *testing.B, machine harness.Machine) (*dockerutil.Container, net.IP) { + b.Helper() + // Spawn a redis instance for the app to use. + redis := machine.GetNativeContainer(ctx, b) + if err := redis.Spawn(ctx, dockerutil.RunOpts{ + Image: "benchmarks/redis", + }); err != nil { + redis.CleanUp(ctx) + b.Fatalf("failed to spwan redis instance: %v", err) + } + + if out, err := redis.WaitForOutput(ctx, "Ready to accept connections", 3*time.Second); err != nil { + redis.CleanUp(ctx) + b.Fatalf("failed to start redis server: %v %s", err, out) + } + redisIP, err := redis.FindIP(ctx, false) + if err != nil { + redis.CleanUp(ctx) + b.Fatalf("failed to get IP from redis instance: %v", err) + } + return redis, redisIP +} + +// runServerWorkload runs a server workload defined by 'runOpts' and 'cmd'. +// 'clientMachine' is used to connect to the server on 'serverMachine'. +func runServerWorkload(ctx context.Context, b *testing.B, args serverArgs) { + b.Helper() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := func() error { + server := args.machine.GetContainer(ctx, b) + defer func() { + b.StopTimer() + // Cleanup servers as we run so that we can go indefinitely. + server.CleanUp(ctx) + b.StartTimer() + }() + if err := server.Spawn(ctx, args.runOpts, args.cmd...); err != nil { + return fmt.Errorf("failed to spawn node instance: %v", err) + } + + servingIP, err := server.FindIP(ctx, false) + if err != nil { + return fmt.Errorf("failed to get ip from server: %v", err) + } + + // Wait until the Client sees the server as up. + if err := harness.WaitUntilServing(ctx, args.machine, servingIP, args.port); err != nil { + return fmt.Errorf("failed to wait for serving: %v", err) + } + return nil + }(); err != nil { + b.Fatal(err) + } + } +} diff --git a/test/benchmarks/base/sysbench_test.go b/test/benchmarks/base/sysbench_test.go new file mode 100644 index 000000000..6fb813640 --- /dev/null +++ b/test/benchmarks/base/sysbench_test.go @@ -0,0 +1,89 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package base + +import ( + "context" + "testing" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/tools" +) + +type testCase struct { + name string + test tools.Sysbench +} + +// BenchmarSysbench runs sysbench on the runtime. +func BenchmarkSysbench(b *testing.B) { + + testCases := []testCase{ + testCase{ + name: "CPU", + test: &tools.SysbenchCPU{ + Base: tools.SysbenchBase{ + Threads: 1, + Time: 5, + }, + MaxPrime: 50000, + }, + }, + testCase{ + name: "Memory", + test: &tools.SysbenchMemory{ + Base: tools.SysbenchBase{ + Threads: 1, + }, + BlockSize: "1M", + TotalSize: "500G", + }, + }, + testCase{ + name: "Mutex", + test: &tools.SysbenchMutex{ + Base: tools.SysbenchBase{ + Threads: 8, + }, + Loops: 1, + Locks: 10000000, + Num: 4, + }, + }, + } + + machine, err := testHarness.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer machine.CleanUp() + + for _, tc := range testCases { + b.Run(tc.name, func(b *testing.B) { + + ctx := context.Background() + sysbench := machine.GetContainer(ctx, b) + defer sysbench.CleanUp(ctx) + + out, err := sysbench.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/sysbench", + }, tc.test.MakeCmd()...) + if err != nil { + b.Fatalf("failed to run sysbench: %v: logs:%s", err, out) + } + tc.test.Report(b, out) + }) + } +} diff --git a/test/benchmarks/database/BUILD b/test/benchmarks/database/BUILD new file mode 100644 index 000000000..93b380e8a --- /dev/null +++ b/test/benchmarks/database/BUILD @@ -0,0 +1,28 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "database", + testonly = 1, + srcs = ["database.go"], + deps = ["//test/benchmarks/harness"], +) + +go_test( + name = "database_test", + size = "enormous", + srcs = ["redis_test.go"], + library = ":database", + tags = [ + # Requires docker and runsc to be configured before test runs. + "manual", + "local", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/test/dockerutil", + "//test/benchmarks/harness", + "//test/benchmarks/tools", + ], +) diff --git a/test/root/testdata/busybox.go b/test/benchmarks/database/database.go index e4dbd2843..9eeb59f9a 100644 --- a/test/root/testdata/busybox.go +++ b/test/benchmarks/database/database.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2020 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,21 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -package testdata +// Package database holds benchmarks around database applications. +package database -// MountOverSymlink is a JSON config for a container that /etc/resolv.conf is a -// symlink to /tmp/resolv.conf. -var MountOverSymlink = ` -{ - "metadata": { - "name": "busybox" - }, - "image": { - "image": "k8s.gcr.io/busybox" - }, - "command": [ - "sleep", - "1000" - ] +import ( + "os" + "testing" + + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +var h harness.Harness + +// TestMain is the main method for package database. +func TestMain(m *testing.M) { + h.Init() + os.Exit(m.Run()) } -` diff --git a/test/benchmarks/database/redis_test.go b/test/benchmarks/database/redis_test.go new file mode 100644 index 000000000..394fce820 --- /dev/null +++ b/test/benchmarks/database/redis_test.go @@ -0,0 +1,123 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package database + +import ( + "context" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/harness" + "gvisor.dev/gvisor/test/benchmarks/tools" +) + +// All possible operations from redis. Note: "ping" will +// run both PING_INLINE and PING_BUILD. +var operations []string = []string{ + "PING_INLINE", + "PING_BULK", + "SET", + "GET", + "INCR", + "LPUSH", + "RPUSH", + "LPOP", + "RPOP", + "SADD", + "HSET", + "SPOP", + "LRANGE_100", + "LRANGE_300", + "LRANGE_500", + "LRANGE_600", + "MSET", +} + +// BenchmarkRedis runs redis-benchmark against a redis instance and reports +// data in queries per second. Each is reported by named operation (e.g. LPUSH). +func BenchmarkRedis(b *testing.B) { + clientMachine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer clientMachine.CleanUp() + + serverMachine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer serverMachine.CleanUp() + + // Redis runs on port 6379 by default. + port := 6379 + ctx := context.Background() + + for _, operation := range operations { + b.Run(operation, func(b *testing.B) { + server := serverMachine.GetContainer(ctx, b) + defer server.CleanUp(ctx) + + // The redis docker container takes no arguments to run a redis server. + if err := server.Spawn(ctx, dockerutil.RunOpts{ + Image: "benchmarks/redis", + Ports: []int{port}, + }); err != nil { + b.Fatalf("failed to start redis server with: %v", err) + } + + if out, err := server.WaitForOutput(ctx, "Ready to accept connections", 3*time.Second); err != nil { + b.Fatalf("failed to start redis server: %v %s", err, out) + } + + ip, err := serverMachine.IPAddress() + if err != nil { + b.Fatal("failed to get IP from server: %v", err) + } + + serverPort, err := server.FindPort(ctx, port) + if err != nil { + b.Fatal("failed to get IP from server: %v", err) + } + + if err = harness.WaitUntilServing(ctx, clientMachine, ip, serverPort); err != nil { + b.Fatalf("failed to start redis with: %v", err) + } + + redis := tools.Redis{ + Operation: operation, + } + + // Reset profiles and timer to begin the measurement. + server.RestartProfiles() + b.ResetTimer() + for i := 0; i < b.N; i++ { + client := clientMachine.GetNativeContainer(ctx, b) + defer client.CleanUp(ctx) + out, err := client.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/redis", + }, redis.MakeCmd(ip, serverPort)...) + if err != nil { + b.Fatalf("redis-benchmark failed with: %v", err) + } + + // Stop time while we parse results. + b.StopTimer() + redis.Report(b, out) + b.StartTimer() + } + }) + } +} diff --git a/test/benchmarks/fs/BUILD b/test/benchmarks/fs/BUILD new file mode 100644 index 000000000..45f11372b --- /dev/null +++ b/test/benchmarks/fs/BUILD @@ -0,0 +1,32 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "fs", + testonly = 1, + srcs = ["fs.go"], + deps = ["//test/benchmarks/harness"], +) + +go_test( + name = "fs_test", + size = "large", + srcs = [ + "bazel_test.go", + "fio_test.go", + ], + library = ":fs", + tags = [ + # Requires docker and runsc to be configured before test runs. + "local", + "manual", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/test/dockerutil", + "//test/benchmarks/harness", + "//test/benchmarks/tools", + "@com_github_docker_docker//api/types/mount:go_default_library", + ], +) diff --git a/test/benchmarks/fs/bazel_test.go b/test/benchmarks/fs/bazel_test.go new file mode 100644 index 000000000..f4236ba37 --- /dev/null +++ b/test/benchmarks/fs/bazel_test.go @@ -0,0 +1,119 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package fs + +import ( + "context" + "fmt" + "strings" + "testing" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +// Note: CleanCache versions of this test require running with root permissions. +func BenchmarkBuildABSL(b *testing.B) { + runBuildBenchmark(b, "benchmarks/absl", "/abseil-cpp", "absl/base/...") +} + +// Note: CleanCache versions of this test require running with root permissions. +// Note: This test takes on the order of 10m per permutation for runsc on kvm. +func BenchmarkBuildRunsc(b *testing.B) { + runBuildBenchmark(b, "benchmarks/runsc", "/gvisor", "runsc:runsc") +} + +func runBuildBenchmark(b *testing.B, image, workdir, target string) { + b.Helper() + // Get a machine from the Harness on which to run. + machine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer machine.CleanUp() + + // Dimensions here are clean/dirty cache (do or don't drop caches) + // and if the mount on which we are compiling is a tmpfs/bind mount. + benchmarks := []struct { + name string + clearCache bool // clearCache drops caches before running. + tmpfs bool // tmpfs will run compilation on a tmpfs. + }{ + {name: "CleanCache", clearCache: true, tmpfs: false}, + {name: "DirtyCache", clearCache: false, tmpfs: false}, + {name: "CleanCacheTmpfs", clearCache: true, tmpfs: true}, + {name: "DirtyCacheTmpfs", clearCache: false, tmpfs: true}, + } + for _, bm := range benchmarks { + b.Run(bm.name, func(b *testing.B) { + // Grab a container. + ctx := context.Background() + container := machine.GetContainer(ctx, b) + defer container.CleanUp(ctx) + + // Start a container and sleep by an order of b.N. + if err := container.Spawn(ctx, dockerutil.RunOpts{ + Image: image, + }, "sleep", fmt.Sprintf("%d", 1000000)); err != nil { + b.Fatalf("run failed with: %v", err) + } + + // If we are running on a tmpfs, copy to /tmp which is a tmpfs. + if bm.tmpfs { + if out, err := container.Exec(ctx, dockerutil.ExecOpts{}, + "cp", "-r", workdir, "/tmp/."); err != nil { + b.Fatal("failed to copy directory: %v %s", err, out) + } + workdir = "/tmp" + workdir + } + + // Restart profiles after the copy. + container.RestartProfiles() + b.ResetTimer() + // Drop Caches and bazel clean should happen inside the loop as we may use + // time options with b.N. (e.g. Run for an hour.) + for i := 0; i < b.N; i++ { + b.StopTimer() + // Drop Caches for clear cache runs. + if bm.clearCache { + if err := harness.DropCaches(machine); err != nil { + b.Skipf("failed to drop caches: %v. You probably need root.", err) + } + } + b.StartTimer() + + got, err := container.Exec(ctx, dockerutil.ExecOpts{ + WorkDir: workdir, + }, "bazel", "build", "-c", "opt", target) + if err != nil { + b.Fatalf("build failed with: %v", err) + } + b.StopTimer() + + want := "Build completed successfully" + if !strings.Contains(got, want) { + b.Fatalf("string %s not in: %s", want, got) + } + // Clean bazel in case we use b.N. + _, err = container.Exec(ctx, dockerutil.ExecOpts{ + WorkDir: workdir, + }, "bazel", "clean") + if err != nil { + b.Fatalf("build failed with: %v", err) + } + b.StartTimer() + } + }) + } +} diff --git a/test/benchmarks/fs/fio_test.go b/test/benchmarks/fs/fio_test.go new file mode 100644 index 000000000..65874ed8b --- /dev/null +++ b/test/benchmarks/fs/fio_test.go @@ -0,0 +1,170 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package fs + +import ( + "context" + "fmt" + "path/filepath" + "strings" + "testing" + + "github.com/docker/docker/api/types/mount" + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/harness" + "gvisor.dev/gvisor/test/benchmarks/tools" +) + +// BenchmarkFio runs fio on the runtime under test. There are 4 basic test +// cases each run on a tmpfs mount and a bind mount. Fio requires root so that +// caches can be dropped. +func BenchmarkFio(b *testing.B) { + testCases := []tools.Fio{ + tools.Fio{ + Test: "write", + Size: "5G", + Blocksize: "1M", + Iodepth: 4, + }, + tools.Fio{ + Test: "read", + Size: "5G", + Blocksize: "1M", + Iodepth: 4, + }, + tools.Fio{ + Test: "randwrite", + Size: "5G", + Blocksize: "4K", + Iodepth: 4, + Time: 30, + }, + tools.Fio{ + Test: "randread", + Size: "5G", + Blocksize: "4K", + Iodepth: 4, + Time: 30, + }, + } + + machine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine with: %v", err) + } + defer machine.CleanUp() + + for _, fsType := range []mount.Type{mount.TypeBind, mount.TypeTmpfs} { + for _, tc := range testCases { + testName := strings.Title(tc.Test) + strings.Title(string(fsType)) + b.Run(testName, func(b *testing.B) { + ctx := context.Background() + container := machine.GetContainer(ctx, b) + defer container.CleanUp(ctx) + + // Directory and filename inside container where fio will read/write. + outdir := "/data" + outfile := filepath.Join(outdir, "test.txt") + + // Make the required mount and grab a cleanup for bind mounts + // as they are backed by a temp directory (mktemp). + mnt, mountCleanup, err := makeMount(machine, fsType, outdir) + if err != nil { + b.Fatalf("failed to make mount: %v", err) + } + defer mountCleanup() + + // Start the container with the mount. + if err := container.Spawn( + ctx, + dockerutil.RunOpts{ + Image: "benchmarks/fio", + Mounts: []mount.Mount{ + mnt, + }, + }, + // Sleep on the order of b.N. + "sleep", fmt.Sprintf("%d", 1000*b.N), + ); err != nil { + b.Fatalf("failed to start fio container with: %v", err) + } + + // For reads, we need a file to read so make one inside the container. + if strings.Contains(tc.Test, "read") { + fallocateCmd := fmt.Sprintf("fallocate -l %s %s", tc.Size, outfile) + if out, err := container.Exec(ctx, dockerutil.ExecOpts{}, + strings.Split(fallocateCmd, " ")...); err != nil { + b.Fatalf("failed to create readable file on mount: %v, %s", err, out) + } + } + + // Drop caches just before running. + if err := harness.DropCaches(machine); err != nil { + b.Skipf("failed to drop caches with %v. You probably need root.", err) + } + cmd := tc.MakeCmd(outfile) + container.RestartProfiles() + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Run fio. + data, err := container.Exec(ctx, dockerutil.ExecOpts{}, cmd...) + if err != nil { + b.Fatalf("failed to run cmd %v: %v", cmd, err) + } + b.StopTimer() + tc.Report(b, data) + // If b.N is used (i.e. we run for an hour), we should drop caches + // after each run. + if err := harness.DropCaches(machine); err != nil { + b.Fatalf("failed to drop caches: %v", err) + } + b.StartTimer() + } + }) + } + } +} + +// makeMount makes a mount and cleanup based on the requested type. Bind +// and volume mounts are backed by a temp directory made with mktemp. +// tmpfs mounts require no such backing and are just made. +// It is up to the caller to call the returned cleanup. +func makeMount(machine harness.Machine, mountType mount.Type, target string) (mount.Mount, func(), error) { + switch mountType { + case mount.TypeVolume, mount.TypeBind: + dir, err := machine.RunCommand("mktemp", "-d") + if err != nil { + return mount.Mount{}, func() {}, fmt.Errorf("failed to create tempdir: %v", err) + } + dir = strings.TrimSuffix(dir, "\n") + + out, err := machine.RunCommand("chmod", "777", dir) + if err != nil { + machine.RunCommand("rm", "-rf", dir) + return mount.Mount{}, func() {}, fmt.Errorf("failed modify directory: %v %s", err, out) + } + return mount.Mount{ + Target: target, + Source: dir, + Type: mount.TypeBind, + }, func() { machine.RunCommand("rm", "-rf", dir) }, nil + case mount.TypeTmpfs: + return mount.Mount{ + Target: target, + Type: mount.TypeTmpfs, + }, func() {}, nil + default: + return mount.Mount{}, func() {}, fmt.Errorf("illegal mount time not supported: %v", mountType) + } +} diff --git a/test/benchmarks/fs/fs.go b/test/benchmarks/fs/fs.go new file mode 100644 index 000000000..e5ca28c3b --- /dev/null +++ b/test/benchmarks/fs/fs.go @@ -0,0 +1,31 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package fs holds benchmarks around filesystem performance. +package fs + +import ( + "os" + "testing" + + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +var h harness.Harness + +// TestMain is the main method for package fs. +func TestMain(m *testing.M) { + h.Init() + os.Exit(m.Run()) +} diff --git a/test/benchmarks/harness/BUILD b/test/benchmarks/harness/BUILD new file mode 100644 index 000000000..c2e316709 --- /dev/null +++ b/test/benchmarks/harness/BUILD @@ -0,0 +1,18 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "harness", + testonly = 1, + srcs = [ + "harness.go", + "machine.go", + "util.go", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/test/dockerutil", + "//pkg/test/testutil", + ], +) diff --git a/test/root/testdata/simple.go b/test/benchmarks/harness/harness.go index 1cca53f0c..68bd7b4cf 100644 --- a/test/root/testdata/simple.go +++ b/test/benchmarks/harness/harness.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2020 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,30 +12,27 @@ // See the License for the specific language governing permissions and // limitations under the License. -package testdata +// Package harness holds utility code for running benchmarks on Docker. +package harness import ( - "encoding/json" - "fmt" + "flag" + + "gvisor.dev/gvisor/pkg/test/dockerutil" ) -// SimpleSpec returns a JSON config for a simple container that runs the -// specified command in the specified image. -func SimpleSpec(name, image string, cmd []string) string { - cmds, err := json.Marshal(cmd) - if err != nil { - // This shouldn't happen. - panic(err) - } - return fmt.Sprintf(` -{ - "metadata": { - "name": %q - }, - "image": { - "image": %q - }, - "command": %s - } -`, name, image, cmds) +// Harness is a handle for managing state in benchmark runs. +type Harness struct { +} + +// Init performs any harness initilialization before runs. +func (h *Harness) Init() error { + flag.Parse() + dockerutil.EnsureSupportedDockerVersion() + return nil +} + +// GetMachine returns this run's implementation of machine. +func (h *Harness) GetMachine() (Machine, error) { + return &localMachine{}, nil } diff --git a/test/benchmarks/harness/machine.go b/test/benchmarks/harness/machine.go new file mode 100644 index 000000000..88e5e841b --- /dev/null +++ b/test/benchmarks/harness/machine.go @@ -0,0 +1,81 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package harness + +import ( + "context" + "net" + "os/exec" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/pkg/test/testutil" +) + +// Machine describes a real machine for use in benchmarks. +type Machine interface { + // GetContainer gets a container from the machine. The container uses the + // runtime under test and is profiled if requested by flags. + GetContainer(ctx context.Context, log testutil.Logger) *dockerutil.Container + + // GetNativeContainer gets a native container from the machine. Native containers + // use runc by default and are not profiled. + GetNativeContainer(ctx context.Context, log testutil.Logger) *dockerutil.Container + + // RunCommand runs cmd on this machine. + RunCommand(cmd string, args ...string) (string, error) + + // Returns IP Address for the machine. + IPAddress() (net.IP, error) + + // CleanUp cleans up this machine. + CleanUp() +} + +// localMachine describes this machine. +type localMachine struct { +} + +// GetContainer implements Machine.GetContainer for localMachine. +func (l *localMachine) GetContainer(ctx context.Context, logger testutil.Logger) *dockerutil.Container { + return dockerutil.MakeContainer(ctx, logger) +} + +// GetContainer implements Machine.GetContainer for localMachine. +func (l *localMachine) GetNativeContainer(ctx context.Context, logger testutil.Logger) *dockerutil.Container { + return dockerutil.MakeNativeContainer(ctx, logger) +} + +// RunCommand implements Machine.RunCommand for localMachine. +func (l *localMachine) RunCommand(cmd string, args ...string) (string, error) { + c := exec.Command(cmd, args...) + out, err := c.CombinedOutput() + return string(out), err +} + +// IPAddress implements Machine.IPAddress. +func (l *localMachine) IPAddress() (net.IP, error) { + conn, err := net.Dial("udp", "8.8.8.8:80") + if err != nil { + return nil, err + } + defer conn.Close() + + addr := conn.LocalAddr().(*net.UDPAddr) + return addr.IP, nil +} + +// CleanUp implements Machine.CleanUp and does nothing for localMachine. +func (*localMachine) CleanUp() { +} diff --git a/test/benchmarks/harness/util.go b/test/benchmarks/harness/util.go new file mode 100644 index 000000000..86b863f78 --- /dev/null +++ b/test/benchmarks/harness/util.go @@ -0,0 +1,48 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package harness + +import ( + "context" + "fmt" + "net" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/pkg/test/testutil" +) + +//TODO(gvisor.dev/issue/3535): move to own package or move methods to harness struct. + +// WaitUntilServing grabs a container from `machine` and waits for a server at +// IP:port. +func WaitUntilServing(ctx context.Context, machine Machine, server net.IP, port int) error { + var logger testutil.DefaultLogger = "util" + netcat := machine.GetNativeContainer(ctx, logger) + defer netcat.CleanUp(ctx) + + cmd := fmt.Sprintf("while ! wget -q --spider http://%s:%d; do true; done", server, port) + _, err := netcat.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/util", + }, "sh", "-c", cmd) + return err +} + +// DropCaches drops caches on the provided machine. Requires root. +func DropCaches(machine Machine) error { + if out, err := machine.RunCommand("/bin/sh", "-c", "sync && sysctl vm.drop_caches=3"); err != nil { + return fmt.Errorf("failed to drop caches: %v logs: %s", err, out) + } + return nil +} diff --git a/test/benchmarks/media/BUILD b/test/benchmarks/media/BUILD new file mode 100644 index 000000000..bb242d385 --- /dev/null +++ b/test/benchmarks/media/BUILD @@ -0,0 +1,22 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "media", + testonly = 1, + srcs = ["media.go"], + deps = ["//test/benchmarks/harness"], +) + +go_test( + name = "media_test", + size = "large", + srcs = ["ffmpeg_test.go"], + library = ":media", + visibility = ["//:sandbox"], + deps = [ + "//pkg/test/dockerutil", + "//test/benchmarks/harness", + ], +) diff --git a/test/benchmarks/media/ffmpeg_test.go b/test/benchmarks/media/ffmpeg_test.go new file mode 100644 index 000000000..7822dfad7 --- /dev/null +++ b/test/benchmarks/media/ffmpeg_test.go @@ -0,0 +1,53 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package media + +import ( + "context" + "strings" + "testing" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +// BenchmarkFfmpeg runs ffmpeg in a container and records runtime. +// BenchmarkFfmpeg should run as root to drop caches. +func BenchmarkFfmpeg(b *testing.B) { + machine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer machine.CleanUp() + + ctx := context.Background() + container := machine.GetContainer(ctx, b) + defer container.CleanUp(ctx) + cmd := strings.Split("ffmpeg -i video.mp4 -c:v libx264 -preset veryslow output.mp4", " ") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StopTimer() + if err := harness.DropCaches(machine); err != nil { + b.Skipf("failed to drop caches: %v. You probably need root.", err) + } + b.StartTimer() + + if _, err := container.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/ffmpeg", + }, cmd...); err != nil { + b.Fatalf("failed to run container: %v", err) + } + } +} diff --git a/test/benchmarks/media/media.go b/test/benchmarks/media/media.go new file mode 100644 index 000000000..c7b35b758 --- /dev/null +++ b/test/benchmarks/media/media.go @@ -0,0 +1,31 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package media holds benchmarks around media processing applications. +package media + +import ( + "os" + "testing" + + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +var h harness.Harness + +// TestMain is the main method for package media. +func TestMain(m *testing.M) { + h.Init() + os.Exit(m.Run()) +} diff --git a/test/benchmarks/ml/BUILD b/test/benchmarks/ml/BUILD new file mode 100644 index 000000000..970f52706 --- /dev/null +++ b/test/benchmarks/ml/BUILD @@ -0,0 +1,22 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "ml", + testonly = 1, + srcs = ["ml.go"], + deps = ["//test/benchmarks/harness"], +) + +go_test( + name = "ml_test", + size = "large", + srcs = ["tensorflow_test.go"], + library = ":ml", + visibility = ["//:sandbox"], + deps = [ + "//pkg/test/dockerutil", + "//test/benchmarks/harness", + ], +) diff --git a/test/benchmarks/ml/ml.go b/test/benchmarks/ml/ml.go new file mode 100644 index 000000000..13282d7bb --- /dev/null +++ b/test/benchmarks/ml/ml.go @@ -0,0 +1,31 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package ml holds benchmarks around machine learning performance. +package ml + +import ( + "os" + "testing" + + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +var h harness.Harness + +// TestMain is the main method for package ml. +func TestMain(m *testing.M) { + h.Init() + os.Exit(m.Run()) +} diff --git a/test/benchmarks/ml/tensorflow_test.go b/test/benchmarks/ml/tensorflow_test.go new file mode 100644 index 000000000..f7746897d --- /dev/null +++ b/test/benchmarks/ml/tensorflow_test.go @@ -0,0 +1,69 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package ml + +import ( + "context" + "testing" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +// BenchmarkTensorflow runs workloads from a TensorFlow tutorial. +// See: https://github.com/aymericdamien/TensorFlow-Examples +func BenchmarkTensorflow(b *testing.B) { + workloads := map[string]string{ + "GradientDecisionTree": "2_BasicModels/gradient_boosted_decision_tree.py", + "Kmeans": "2_BasicModels/kmeans.py", + "LogisticRegression": "2_BasicModels/logistic_regression.py", + "NearestNeighbor": "2_BasicModels/nearest_neighbor.py", + "RandomForest": "2_BasicModels/random_forest.py", + "ConvolutionalNetwork": "3_NeuralNetworks/convolutional_network.py", + "MultilayerPerceptron": "3_NeuralNetworks/multilayer_perceptron.py", + "NeuralNetwork": "3_NeuralNetworks/neural_network.py", + } + + machine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer machine.CleanUp() + + for name, workload := range workloads { + b.Run(name, func(b *testing.B) { + ctx := context.Background() + container := machine.GetContainer(ctx, b) + defer container.CleanUp(ctx) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StopTimer() + if err := harness.DropCaches(machine); err != nil { + b.Skipf("failed to drop caches: %v. You probably need root.", err) + } + b.StartTimer() + + if out, err := container.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/tensorflow", + Env: []string{"PYTHONPATH=$PYTHONPATH:/TensorFlow-Examples/examples"}, + WorkDir: "/TensorFlow-Examples/examples", + }, "python", workload); err != nil { + b.Fatalf("failed to run container: %v logs: %s", err, out) + } + } + }) + } + +} diff --git a/test/benchmarks/network/BUILD b/test/benchmarks/network/BUILD new file mode 100644 index 000000000..bd3f6245c --- /dev/null +++ b/test/benchmarks/network/BUILD @@ -0,0 +1,35 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "network", + testonly = 1, + srcs = ["network.go"], + deps = ["//test/benchmarks/harness"], +) + +go_test( + name = "network_test", + size = "large", + srcs = [ + "httpd_test.go", + "iperf_test.go", + "nginx_test.go", + "node_test.go", + "ruby_test.go", + ], + library = ":network", + tags = [ + # Requires docker and runsc to be configured before test runs. + "manual", + "local", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/test/dockerutil", + "//pkg/test/testutil", + "//test/benchmarks/harness", + "//test/benchmarks/tools", + ], +) diff --git a/test/benchmarks/network/httpd_test.go b/test/benchmarks/network/httpd_test.go new file mode 100644 index 000000000..336e04c91 --- /dev/null +++ b/test/benchmarks/network/httpd_test.go @@ -0,0 +1,181 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package network + +import ( + "context" + "fmt" + "testing" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/harness" + "gvisor.dev/gvisor/test/benchmarks/tools" +) + +// see Dockerfile '//images/benchmarks/httpd'. +var docs = map[string]string{ + "notfound": "notfound", + "1Kb": "latin1k.txt", + "10Kb": "latin10k.txt", + "100Kb": "latin100k.txt", + "1000Kb": "latin1000k.txt", + "1Mb": "latin1024k.txt", + "10Mb": "latin10240k.txt", +} + +// BenchmarkHttpdConcurrency iterates the concurrency argument and tests +// how well the runtime under test handles requests in parallel. +func BenchmarkHttpdConcurrency(b *testing.B) { + // Grab a machine for the client and server. + clientMachine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get client: %v", err) + } + defer clientMachine.CleanUp() + + serverMachine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get server: %v", err) + } + defer serverMachine.CleanUp() + + // The test iterates over client concurrency, so set other parameters. + concurrency := []int{1, 25, 50, 100, 1000} + + for _, c := range concurrency { + b.Run(fmt.Sprintf("%d", c), func(b *testing.B) { + hey := &tools.Hey{ + Requests: 10000, + Concurrency: c, + Doc: docs["10Kb"], + } + runHttpd(b, clientMachine, serverMachine, hey, false /* reverse */) + }) + } +} + +// BenchmarkHttpdDocSize iterates over different sized payloads, testing how +// well the runtime handles sending different payload sizes. +func BenchmarkHttpdDocSize(b *testing.B) { + benchmarkHttpdDocSize(b, false /* reverse */) +} + +// BenchmarkReverseHttpdDocSize iterates over different sized payloads, testing +// how well the runtime handles receiving different payload sizes. +func BenchmarkReverseHttpdDocSize(b *testing.B) { + benchmarkHttpdDocSize(b, true /* reverse */) +} + +func benchmarkHttpdDocSize(b *testing.B, reverse bool) { + b.Helper() + + clientMachine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer clientMachine.CleanUp() + + serverMachine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer serverMachine.CleanUp() + + for name, filename := range docs { + concurrency := []int{1, 25, 50, 100, 1000} + for _, c := range concurrency { + b.Run(fmt.Sprintf("%s_%d", name, c), func(b *testing.B) { + hey := &tools.Hey{ + Requests: 10000, + Concurrency: c, + Doc: filename, + } + runHttpd(b, clientMachine, serverMachine, hey, reverse) + }) + } + } +} + +// runHttpd runs a single test run. +func runHttpd(b *testing.B, clientMachine, serverMachine harness.Machine, hey *tools.Hey, reverse bool) { + b.Helper() + + // Grab a container from the server. + ctx := context.Background() + var server *dockerutil.Container + if reverse { + server = serverMachine.GetNativeContainer(ctx, b) + } else { + server = serverMachine.GetContainer(ctx, b) + } + + defer server.CleanUp(ctx) + + // Copy the docs to /tmp and serve from there. + cmd := "mkdir -p /tmp/html; cp -r /local/* /tmp/html/.; apache2 -X" + port := 80 + + // Start the server. + if err := server.Spawn(ctx, dockerutil.RunOpts{ + Image: "benchmarks/httpd", + Ports: []int{port}, + Env: []string{ + // Standard environmental variables for httpd. + "APACHE_RUN_DIR=/tmp", + "APACHE_RUN_USER=nobody", + "APACHE_RUN_GROUP=nogroup", + "APACHE_LOG_DIR=/tmp", + "APACHE_PID_FILE=/tmp/apache.pid", + }, + }, "sh", "-c", cmd); err != nil { + b.Fatalf("failed to start server: %v", err) + } + + ip, err := serverMachine.IPAddress() + if err != nil { + b.Fatalf("failed to find server ip: %v", err) + } + + servingPort, err := server.FindPort(ctx, port) + if err != nil { + b.Fatalf("failed to find server port %d: %v", port, err) + } + + // Check the server is serving. + harness.WaitUntilServing(ctx, clientMachine, ip, servingPort) + + var client *dockerutil.Container + // Grab a client. + if reverse { + client = clientMachine.GetContainer(ctx, b) + } else { + client = clientMachine.GetNativeContainer(ctx, b) + } + defer client.CleanUp(ctx) + + b.ResetTimer() + server.RestartProfiles() + for i := 0; i < b.N; i++ { + out, err := client.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/hey", + }, hey.MakeCmd(ip, servingPort)...) + if err != nil { + b.Fatalf("run failed with: %v", err) + } + + b.StopTimer() + hey.Report(b, out) + b.StartTimer() + } +} diff --git a/test/benchmarks/network/iperf_test.go b/test/benchmarks/network/iperf_test.go new file mode 100644 index 000000000..b8ab7dfb8 --- /dev/null +++ b/test/benchmarks/network/iperf_test.go @@ -0,0 +1,113 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package network + +import ( + "context" + "testing" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/pkg/test/testutil" + "gvisor.dev/gvisor/test/benchmarks/harness" + "gvisor.dev/gvisor/test/benchmarks/tools" +) + +func BenchmarkIperf(b *testing.B) { + iperf := tools.Iperf{ + Time: 10, // time in seconds to run client. + } + + clientMachine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer clientMachine.CleanUp() + + serverMachine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer serverMachine.CleanUp() + ctx := context.Background() + for _, bm := range []struct { + name string + clientFunc func(context.Context, testutil.Logger) *dockerutil.Container + serverFunc func(context.Context, testutil.Logger) *dockerutil.Container + }{ + // We are either measuring the server or the client. The other should be + // runc. e.g. Upload sees how fast the runtime under test uploads to a native + // server. + { + name: "Upload", + clientFunc: clientMachine.GetContainer, + serverFunc: serverMachine.GetNativeContainer, + }, + { + name: "Download", + clientFunc: clientMachine.GetNativeContainer, + serverFunc: serverMachine.GetContainer, + }, + } { + b.Run(bm.name, func(b *testing.B) { + // Set up the containers. + server := bm.serverFunc(ctx, b) + defer server.CleanUp(ctx) + client := bm.clientFunc(ctx, b) + defer client.CleanUp(ctx) + + // iperf serves on port 5001 by default. + port := 5001 + + // Start the server. + if err := server.Spawn(ctx, dockerutil.RunOpts{ + Image: "benchmarks/iperf", + Ports: []int{port}, + }, "iperf", "-s"); err != nil { + b.Fatalf("failed to start server with: %v", err) + } + + ip, err := serverMachine.IPAddress() + if err != nil { + b.Fatalf("failed to find server ip: %v", err) + } + + servingPort, err := server.FindPort(ctx, port) + if err != nil { + b.Fatalf("failed to find port %d: %v", port, err) + } + + // Make sure the server is up and serving before we run. + if err := harness.WaitUntilServing(ctx, clientMachine, ip, servingPort); err != nil { + b.Fatalf("failed to wait for server: %v", err) + } + // Run the client. + b.ResetTimer() + + // Restart the server profiles. If the server isn't being profiled + // this does nothing. + server.RestartProfiles() + for i := 0; i < b.N; i++ { + out, err := client.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/iperf", + }, iperf.MakeCmd(ip, servingPort)...) + if err != nil { + b.Fatalf("failed to run client: %v", err) + } + b.StopTimer() + iperf.Report(b, out) + b.StartTimer() + } + }) + } +} diff --git a/test/benchmarks/network/network.go b/test/benchmarks/network/network.go new file mode 100644 index 000000000..ce17ddb94 --- /dev/null +++ b/test/benchmarks/network/network.go @@ -0,0 +1,31 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package network holds benchmarks around raw network performance. +package network + +import ( + "os" + "testing" + + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +var h harness.Harness + +// TestMain is the main method for package network. +func TestMain(m *testing.M) { + h.Init() + os.Exit(m.Run()) +} diff --git a/test/benchmarks/network/nginx_test.go b/test/benchmarks/network/nginx_test.go new file mode 100644 index 000000000..2bf1a3624 --- /dev/null +++ b/test/benchmarks/network/nginx_test.go @@ -0,0 +1,104 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package network + +import ( + "context" + "fmt" + "testing" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/harness" + "gvisor.dev/gvisor/test/benchmarks/tools" +) + +// BenchmarkNginxConcurrency iterates the concurrency argument and tests +// how well the runtime under test handles requests in parallel. +// TODO(gvisor.dev/issue/3536): Update with different doc sizes like Httpd. +func BenchmarkNginxConcurrency(b *testing.B) { + // Grab a machine for the client and server. + clientMachine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get client: %v", err) + } + defer clientMachine.CleanUp() + + serverMachine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get server: %v", err) + } + defer serverMachine.CleanUp() + + concurrency := []int{1, 5, 10, 25} + for _, c := range concurrency { + b.Run(fmt.Sprintf("%d", c), func(b *testing.B) { + hey := &tools.Hey{ + Requests: 10000, + Concurrency: c, + } + runNginx(b, clientMachine, serverMachine, hey) + }) + } +} + +// runHttpd runs a single test run. +func runNginx(b *testing.B, clientMachine, serverMachine harness.Machine, hey *tools.Hey) { + b.Helper() + + // Grab a container from the server. + ctx := context.Background() + server := serverMachine.GetContainer(ctx, b) + defer server.CleanUp(ctx) + + port := 80 + // Start the server. + if err := server.Spawn(ctx, + dockerutil.RunOpts{ + Image: "benchmarks/nginx", + Ports: []int{port}, + }); err != nil { + b.Fatalf("server failed to start: %v", err) + } + + ip, err := serverMachine.IPAddress() + if err != nil { + b.Fatalf("failed to find server ip: %v", err) + } + + servingPort, err := server.FindPort(ctx, port) + if err != nil { + b.Fatalf("failed to find server port %d: %v", port, err) + } + + // Check the server is serving. + harness.WaitUntilServing(ctx, clientMachine, ip, servingPort) + + // Grab a client. + client := clientMachine.GetNativeContainer(ctx, b) + defer client.CleanUp(ctx) + + b.ResetTimer() + server.RestartProfiles() + for i := 0; i < b.N; i++ { + out, err := client.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/hey", + }, hey.MakeCmd(ip, servingPort)...) + if err != nil { + b.Fatalf("run failed with: %v", err) + } + b.StopTimer() + hey.Report(b, out) + b.StartTimer() + } +} diff --git a/test/benchmarks/network/node_test.go b/test/benchmarks/network/node_test.go new file mode 100644 index 000000000..52eb794c4 --- /dev/null +++ b/test/benchmarks/network/node_test.go @@ -0,0 +1,127 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package network + +import ( + "context" + "fmt" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/harness" + "gvisor.dev/gvisor/test/benchmarks/tools" +) + +// BenchmarkNode runs requests using 'hey' against a Node server run on +// 'runtime'. The server responds to requests by grabbing some data in a +// redis instance and returns the data in its reponse. The test loops through +// increasing amounts of concurency for requests. +func BenchmarkNode(b *testing.B) { + concurrency := []int{1, 5, 10, 25} + for _, c := range concurrency { + b.Run(fmt.Sprintf("Concurrency%d", c), func(b *testing.B) { + hey := &tools.Hey{ + Requests: b.N * c, // Requests b.N requests per thread. + Concurrency: c, + } + runNode(b, hey) + }) + } +} + +// runNode runs the test for a given # of requests and concurrency. +func runNode(b *testing.B, hey *tools.Hey) { + b.Helper() + + // The machine to hold Redis and the Node Server. + serverMachine, err := h.GetMachine() + if err != nil { + b.Fatal("failed to get machine with: %v", err) + } + defer serverMachine.CleanUp() + + // The machine to run 'hey'. + clientMachine, err := h.GetMachine() + if err != nil { + b.Fatal("failed to get machine with: %v", err) + } + defer clientMachine.CleanUp() + + ctx := context.Background() + + // Spawn a redis instance for the app to use. + redis := serverMachine.GetNativeContainer(ctx, b) + if err := redis.Spawn(ctx, dockerutil.RunOpts{ + Image: "benchmarks/redis", + }); err != nil { + b.Fatalf("failed to spwan redis instance: %v", err) + } + defer redis.CleanUp(ctx) + + if out, err := redis.WaitForOutput(ctx, "Ready to accept connections", 3*time.Second); err != nil { + b.Fatalf("failed to start redis server: %v %s", err, out) + } + redisIP, err := redis.FindIP(ctx, false) + if err != nil { + b.Fatalf("failed to get IP from redis instance: %v", err) + } + + // Node runs on port 8080. + port := 8080 + + // Start-up the Node server. + nodeApp := serverMachine.GetContainer(ctx, b) + if err := nodeApp.Spawn(ctx, dockerutil.RunOpts{ + Image: "benchmarks/node", + WorkDir: "/usr/src/app", + Links: []string{redis.MakeLink("redis")}, + Ports: []int{port}, + }, "node", "index.js", redisIP.String()); err != nil { + b.Fatalf("failed to spawn node instance: %v", err) + } + defer nodeApp.CleanUp(ctx) + + servingIP, err := serverMachine.IPAddress() + if err != nil { + b.Fatalf("failed to get ip from server: %v", err) + } + + servingPort, err := nodeApp.FindPort(ctx, port) + if err != nil { + b.Fatalf("failed to port from node instance: %v", err) + } + + // Wait until the Client sees the server as up. + harness.WaitUntilServing(ctx, clientMachine, servingIP, servingPort) + + heyCmd := hey.MakeCmd(servingIP, servingPort) + + nodeApp.RestartProfiles() + b.ResetTimer() + + // the client should run on Native. + client := clientMachine.GetNativeContainer(ctx, b) + out, err := client.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/hey", + }, heyCmd...) + if err != nil { + b.Fatalf("hey container failed: %v logs: %s", err, out) + } + + // Stop the timer to parse the data and report stats. + b.StopTimer() + hey.Report(b, out) + b.StartTimer() +} diff --git a/test/benchmarks/network/ruby_test.go b/test/benchmarks/network/ruby_test.go new file mode 100644 index 000000000..5e0b2b724 --- /dev/null +++ b/test/benchmarks/network/ruby_test.go @@ -0,0 +1,134 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package network + +import ( + "context" + "fmt" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/harness" + "gvisor.dev/gvisor/test/benchmarks/tools" +) + +// BenchmarkRuby runs requests using 'hey' against a ruby application server. +// On start, ruby app generates some random data and pushes it to a redis +// instance. On a request, the app grabs for random entries from the redis +// server, publishes it to a document, and returns the doc to the request. +func BenchmarkRuby(b *testing.B) { + concurrency := []int{1, 5, 10, 25} + for _, c := range concurrency { + b.Run(fmt.Sprintf("Concurrency%d", c), func(b *testing.B) { + hey := &tools.Hey{ + Requests: b.N * c, // b.N requests per thread. + Concurrency: c, + } + runRuby(b, hey) + }) + } +} + +// runRuby runs the test for a given # of requests and concurrency. +func runRuby(b *testing.B, hey *tools.Hey) { + b.Helper() + // The machine to hold Redis and the Ruby Server. + serverMachine, err := h.GetMachine() + if err != nil { + b.Fatal("failed to get machine with: %v", err) + } + defer serverMachine.CleanUp() + + // The machine to run 'hey'. + clientMachine, err := h.GetMachine() + if err != nil { + b.Fatal("failed to get machine with: %v", err) + } + defer clientMachine.CleanUp() + ctx := context.Background() + + // Spawn a redis instance for the app to use. + redis := serverMachine.GetNativeContainer(ctx, b) + if err := redis.Spawn(ctx, dockerutil.RunOpts{ + Image: "benchmarks/redis", + }); err != nil { + b.Fatalf("failed to spwan redis instance: %v", err) + } + defer redis.CleanUp(ctx) + + if out, err := redis.WaitForOutput(ctx, "Ready to accept connections", 3*time.Second); err != nil { + b.Fatalf("failed to start redis server: %v %s", err, out) + } + redisIP, err := redis.FindIP(ctx, false) + if err != nil { + b.Fatalf("failed to get IP from redis instance: %v", err) + } + + // Ruby runs on port 9292. + const port = 9292 + + // Start-up the Ruby server. + rubyApp := serverMachine.GetContainer(ctx, b) + if err := rubyApp.Spawn(ctx, dockerutil.RunOpts{ + Image: "benchmarks/ruby", + WorkDir: "/app", + Links: []string{redis.MakeLink("redis")}, + Ports: []int{port}, + Env: []string{ + fmt.Sprintf("PORT=%d", port), + "WEB_CONCURRENCY=20", + "WEB_MAX_THREADS=20", + "RACK_ENV=production", + fmt.Sprintf("HOST=%s", redisIP), + }, + User: "nobody", + }, "sh", "-c", "/usr/bin/puma"); err != nil { + b.Fatalf("failed to spawn node instance: %v", err) + } + defer rubyApp.CleanUp(ctx) + + servingIP, err := serverMachine.IPAddress() + if err != nil { + b.Fatalf("failed to get ip from server: %v", err) + } + + servingPort, err := rubyApp.FindPort(ctx, port) + if err != nil { + b.Fatalf("failed to port from node instance: %v", err) + } + + // Wait until the Client sees the server as up. + if err := harness.WaitUntilServing(ctx, clientMachine, servingIP, servingPort); err != nil { + b.Fatalf("failed to wait until serving: %v", err) + } + heyCmd := hey.MakeCmd(servingIP, servingPort) + rubyApp.RestartProfiles() + b.ResetTimer() + + // the client should run on Native. + client := clientMachine.GetNativeContainer(ctx, b) + defer client.CleanUp(ctx) + out, err := client.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/hey", + }, heyCmd...) + if err != nil { + b.Fatalf("hey container failed: %v logs: %s", err, out) + } + + // Stop the timer to parse the data and report stats. + b.StopTimer() + hey.Report(b, out) + b.StartTimer() +} diff --git a/test/benchmarks/tcp/BUILD b/test/benchmarks/tcp/BUILD new file mode 100644 index 000000000..6dde7d9e6 --- /dev/null +++ b/test/benchmarks/tcp/BUILD @@ -0,0 +1,41 @@ +load("//tools:defs.bzl", "cc_binary", "go_binary") + +package(licenses = ["notice"]) + +go_binary( + name = "tcp_proxy", + srcs = ["tcp_proxy.go"], + visibility = ["//:sandbox"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/adapters/gonet", + "//pkg/tcpip/link/fdbased", + "//pkg/tcpip/link/qdisc/fifo", + "//pkg/tcpip/network/arp", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/stack", + "//pkg/tcpip/transport/tcp", + "//pkg/tcpip/transport/udp", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +# nsjoin is a trivial replacement for nsenter. This is used because nsenter is +# not available on all systems where this benchmark is run (and we aim to +# minimize external dependencies.) + +cc_binary( + name = "nsjoin", + srcs = ["nsjoin.c"], + visibility = ["//:sandbox"], +) + +sh_binary( + name = "tcp_benchmark", + srcs = ["tcp_benchmark.sh"], + data = [ + ":nsjoin", + ":tcp_proxy", + ], + visibility = ["//:sandbox"], +) diff --git a/test/benchmarks/tcp/README.md b/test/benchmarks/tcp/README.md new file mode 100644 index 000000000..38e6e69f0 --- /dev/null +++ b/test/benchmarks/tcp/README.md @@ -0,0 +1,87 @@ +# TCP Benchmarks + +This directory contains a standardized TCP benchmark. This helps to evaluate the +performance of netstack and native networking stacks under various conditions. + +## `tcp_benchmark` + +This benchmark allows TCP throughput testing under various conditions. The setup +consists of an iperf client, a client proxy, a server proxy and an iperf server. +The client proxy and server proxy abstract the network mechanism used to +communicate between the iperf client and server. + +The setup looks like the following: + +``` + +--------------+ (native) +--------------+ + | iperf client |[lo @ 10.0.0.1]------>| client proxy | + +--------------+ +--------------+ + [client.0 @ 10.0.0.2] + (netstack) | | (native) + +------+-----+ + | + [br0] + | + Network emulation applied ---> [wan.0:wan.1] + | + [br1] + | + +------+-----+ + (netstack) | | (native) + [server.0 @ 10.0.0.3] + +--------------+ +--------------+ + | iperf server |<------[lo @ 10.0.0.4]| server proxy | + +--------------+ (native) +--------------+ +``` + +Different configurations can be run using different arguments. For example: + +* Native test under normal internet conditions: `tcp_benchmark` +* Native test under ideal conditions: `tcp_benchmark --ideal` +* Netstack client under ideal conditions: `tcp_benchmark --client --ideal` +* Netstack client with 5% packet loss: `tcp_benchmark --client --ideal --loss + 5` + +Use `tcp_benchmark --help` for full arguments. + +This tool may be used to easily generate data for graphing. For example, to +generate a CSV for various latencies, you might do: + +``` +rm -f /tmp/netstack_latency.csv /tmp/native_latency.csv +latencies=$(seq 0 5 50; + seq 60 10 100; + seq 125 25 250; + seq 300 50 500) +for latency in $latencies; do + read throughput client_cpu server_cpu <<< \ + $(./tcp_benchmark --duration 30 --client --ideal --latency $latency) + echo $latency,$throughput,$client_cpu >> /tmp/netstack_latency.csv +done +for latency in $latencies; do + read throughput client_cpu server_cpu <<< \ + $(./tcp_benchmark --duration 30 --ideal --latency $latency) + echo $latency,$throughput,$client_cpu >> /tmp/native_latency.csv +done +``` + +Similarly, to generate a CSV for various levels of packet loss, the following +would be appropriate: + +``` +rm -f /tmp/netstack_loss.csv /tmp/native_loss.csv +losses=$(seq 0 0.1 1.0; + seq 1.2 0.2 2.0; + seq 2.5 0.5 5.0; + seq 6.0 1.0 10.0) +for loss in $losses; do + read throughput client_cpu server_cpu <<< \ + $(./tcp_benchmark --duration 30 --client --ideal --latency 10 --loss $loss) + echo $loss,$throughput,$client_cpu >> /tmp/netstack_loss.csv +done +for loss in $losses; do + read throughput client_cpu server_cpu <<< \ + $(./tcp_benchmark --duration 30 --ideal --latency 10 --loss $loss) + echo $loss,$throughput,$client_cpu >> /tmp/native_loss.csv +done +``` diff --git a/test/benchmarks/tcp/nsjoin.c b/test/benchmarks/tcp/nsjoin.c new file mode 100644 index 000000000..524b4d549 --- /dev/null +++ b/test/benchmarks/tcp/nsjoin.c @@ -0,0 +1,47 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + +#include <errno.h> +#include <fcntl.h> +#include <sched.h> +#include <stdio.h> +#include <string.h> +#include <sys/stat.h> +#include <sys/types.h> +#include <unistd.h> + +int main(int argc, char** argv) { + if (argc <= 2) { + fprintf(stderr, "error: must provide a namespace file.\n"); + fprintf(stderr, "usage: %s <file> [arguments...]\n", argv[0]); + return 1; + } + + int fd = open(argv[1], O_RDONLY); + if (fd < 0) { + fprintf(stderr, "error opening %s: %s\n", argv[1], strerror(errno)); + return 1; + } + if (setns(fd, 0) < 0) { + fprintf(stderr, "error joining %s: %s\n", argv[1], strerror(errno)); + return 1; + } + + execvp(argv[2], &argv[2]); + return 1; +} diff --git a/test/benchmarks/tcp/tcp_benchmark.sh b/test/benchmarks/tcp/tcp_benchmark.sh new file mode 100755 index 000000000..ef04b4ace --- /dev/null +++ b/test/benchmarks/tcp/tcp_benchmark.sh @@ -0,0 +1,392 @@ +#!/bin/bash + +# Copyright 2018 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TCP benchmark; see README.md for documentation. + +# Fixed parameters. +iperf_port=45201 # Not likely to be privileged. +proxy_port=44000 # Ditto. +client_addr=10.0.0.1 +client_proxy_addr=10.0.0.2 +server_proxy_addr=10.0.0.3 +server_addr=10.0.0.4 +mask=8 + +# Defaults; this provides a reasonable approximation of a decent internet link. +# Parameters can be varied independently from this set to see response to +# various changes in the kind of link available. +client=false +server=false +verbose=false +gso=0 +swgso=false +mtu=1280 # 1280 is a reasonable lowest-common-denominator. +latency=10 # 10ms approximates a fast, dedicated connection. +latency_variation=1 # +/- 1ms is a relatively low amount of jitter. +loss=0.1 # 0.1% loss is non-zero, but not extremely high. +duplicate=0.1 # 0.1% means duplicates are 1/10x as frequent as losses. +duration=30 # 30s is enough time to consistent results (experimentally). +helper_dir=$(dirname $0) +netstack_opts= +disable_linux_gso= +num_client_threads=1 + +# Check for netem support. +lsmod_output=$(lsmod | grep sch_netem) +if [ "$?" != "0" ]; then + echo "warning: sch_netem may not be installed." >&2 +fi + +while [ $# -gt 0 ]; do + case "$1" in + --client) + client=true + ;; + --client_tcp_probe_file) + shift + netstack_opts="${netstack_opts} -client_tcp_probe_file=$1" + ;; + --server) + server=true + ;; + --verbose) + verbose=true + ;; + --gso) + shift + gso=$1 + ;; + --swgso) + swgso=true + ;; + --server_tcp_probe_file) + shift + netstack_opts="${netstack_opts} -server_tcp_probe_file=$1" + ;; + --ideal) + mtu=1500 # Standard ethernet. + latency=0 # No latency. + latency_variation=0 # No jitter. + loss=0 # No loss. + duplicate=0 # No duplicates. + ;; + --mtu) + shift + [ "$#" -le 0 ] && echo "no mtu provided" && exit 1 + mtu=$1 + ;; + --sack) + netstack_opts="${netstack_opts} -sack" + ;; + --cubic) + netstack_opts="${netstack_opts} -cubic" + ;; + --moderate-recv-buf) + netstack_opts="${netstack_opts} -moderate_recv_buf" + ;; + --duration) + shift + [ "$#" -le 0 ] && echo "no duration provided" && exit 1 + duration=$1 + ;; + --latency) + shift + [ "$#" -le 0 ] && echo "no latency provided" && exit 1 + latency=$1 + ;; + --latency-variation) + shift + [ "$#" -le 0 ] && echo "no latency variation provided" && exit 1 + latency_variation=$1 + ;; + --loss) + shift + [ "$#" -le 0 ] && echo "no loss probability provided" && exit 1 + loss=$1 + ;; + --duplicate) + shift + [ "$#" -le 0 ] && echo "no duplicate provided" && exit 1 + duplicate=$1 + ;; + --cpuprofile) + shift + netstack_opts="${netstack_opts} -cpuprofile=$1" + ;; + --memprofile) + shift + netstack_opts="${netstack_opts} -memprofile=$1" + ;; + --disable-linux-gso) + disable_linux_gso=1 + ;; + --num-client-threads) + shift + num_client_threads=$1 + ;; + --helpers) + shift + [ "$#" -le 0 ] && echo "no helper dir provided" && exit 1 + helper_dir=$1 + ;; + *) + echo "usage: $0 [options]" + echo "options:" + echo " --help show this message" + echo " --verbose verbose output" + echo " --client use netstack as the client" + echo " --ideal reset all network emulation" + echo " --server use netstack as the server" + echo " --mtu set the mtu (bytes)" + echo " --sack enable SACK support" + echo " --moderate-recv-buf enable TCP receive buffer auto-tuning" + echo " --cubic enable CUBIC congestion control for Netstack" + echo " --duration set the test duration (s)" + echo " --latency set the latency (ms)" + echo " --latency-variation set the latency variation" + echo " --loss set the loss probability (%)" + echo " --duplicate set the duplicate probability (%)" + echo " --helpers set the helper directory" + echo " --num-client-threads number of parallel client threads to run" + echo " --disable-linux-gso disable segmentation offload in the Linux network stack" + echo "" + echo "The output will of the script will be:" + echo " <throughput> <client-cpu-usage> <server-cpu-usage>" + exit 1 + esac + shift +done + +if [ ${verbose} == "true" ]; then + set -x +fi + +# Latency needs to be halved, since it's applied on both ways. +half_latency=$(echo ${latency}/2 | bc -l | awk '{printf "%1.2f", $0}') +half_loss=$(echo ${loss}/2 | bc -l | awk '{printf "%1.6f", $0}') +half_duplicate=$(echo ${duplicate}/2 | bc -l | awk '{printf "%1.6f", $0}') +helper_dir=${helper_dir#$(pwd)/} # Use relative paths. +proxy_binary=${helper_dir}/tcp_proxy +nsjoin_binary=${helper_dir}/nsjoin + +if [ ! -e ${proxy_binary} ]; then + echo "Could not locate ${proxy_binary}, please make sure you've built the binary" + exit 1 +fi + +if [ ! -e ${nsjoin_binary} ]; then + echo "Could not locate ${nsjoin_binary}, please make sure you've built the binary" + exit 1 +fi + +if [ $(echo ${latency_variation} | awk '{printf "%1.2f", $0}') != "0.00" ]; then + # As long as there's some jitter, then we use the paretonormal distribution. + # This will preserve the minimum RTT, but add a realistic amount of jitter to + # the connection and cause re-ordering, etc. The regular pareto distribution + # appears to an unreasonable level of delay (we want only small spikes.) + distribution="distribution paretonormal" +else + distribution="" +fi + +# Client proxy that will listen on the client's iperf target forward traffic +# using the host networking stack. +client_args="${proxy_binary} -port ${proxy_port} -forward ${server_proxy_addr}:${proxy_port}" +if ${client}; then + # Client proxy that will listen on the client's iperf target + # and forward traffic using netstack. + client_args="${proxy_binary} ${netstack_opts} -port ${proxy_port} -client \\ + -mtu ${mtu} -iface client.0 -addr ${client_proxy_addr} -mask ${mask} \\ + -forward ${server_proxy_addr}:${proxy_port} -gso=${gso} -swgso=${swgso}" +fi + +# Server proxy that will listen on the proxy port and forward to the server's +# iperf server using the host networking stack. +server_args="${proxy_binary} -port ${proxy_port} -forward ${server_addr}:${iperf_port}" +if ${server}; then + # Server proxy that will listen on the proxy port and forward to the servers' + # iperf server using netstack. + server_args="${proxy_binary} ${netstack_opts} -port ${proxy_port} -server \\ + -mtu ${mtu} -iface server.0 -addr ${server_proxy_addr} -mask ${mask} \\ + -forward ${server_addr}:${iperf_port} -gso=${gso} -swgso=${swgso}" +fi + +# Specify loss and duplicate parameters only if they are non-zero +loss_opt="" +if [ "$(echo $half_loss | bc -q)" != "0" ]; then + loss_opt="loss random ${half_loss}%" +fi +duplicate_opt="" +if [ "$(echo $half_duplicate | bc -q)" != "0" ]; then + duplicate_opt="duplicate ${half_duplicate}%" +fi + +exec unshare -U -m -n -r -f -p --mount-proc /bin/bash << EOF +set -e -m + +if [ ${verbose} == "true" ]; then + set -x +fi + +mount -t tmpfs netstack-bench /tmp + +# We may have reset the path in the unshare if the shell loaded some public +# profiles. Ensure that tools are discoverable via the parent's PATH. +export PATH=${PATH} + +# Add client, server interfaces. +ip link add client.0 type veth peer name client.1 +ip link add server.0 type veth peer name server.1 + +# Add network emulation devices. +ip link add wan.0 type veth peer name wan.1 +ip link set wan.0 up +ip link set wan.1 up + +# Enroll on the bridge. +ip link add name br0 type bridge +ip link add name br1 type bridge +ip link set client.1 master br0 +ip link set server.1 master br1 +ip link set wan.0 master br0 +ip link set wan.1 master br1 +ip link set br0 up +ip link set br1 up + +# Set the MTU appropriately. +ip link set client.0 mtu ${mtu} +ip link set server.0 mtu ${mtu} +ip link set wan.0 mtu ${mtu} +ip link set wan.1 mtu ${mtu} + +# Add appropriate latency, loss and duplication. +# +# This is added in at the point of bridge connection. +for device in wan.0 wan.1; do + # NOTE: We don't support a loss correlation as testing has shown that it + # actually doesn't work. The man page actually has a small comment about this + # "It is also possible to add a correlation, but this option is now deprecated + # due to the noticed bad behavior." For more information see netem(8). + tc qdisc add dev \$device root netem \\ + delay ${half_latency}ms ${latency_variation}ms ${distribution} \\ + ${loss_opt} ${duplicate_opt} +done + +# Start a client proxy. +touch /tmp/client.netns +unshare -n mount --bind /proc/self/ns/net /tmp/client.netns + +# Move the endpoint into the namespace. +while ip link | grep client.0 > /dev/null; do + ip link set dev client.0 netns /tmp/client.netns +done + +if ! ${client}; then + # Only add the address to NIC if netstack is not in use. Otherwise the host + # will also process the inbound SYN and send a RST back. + ${nsjoin_binary} /tmp/client.netns ip addr add ${client_proxy_addr}/${mask} dev client.0 +fi + +# Start a server proxy. +touch /tmp/server.netns +unshare -n mount --bind /proc/self/ns/net /tmp/server.netns +# Move the endpoint into the namespace. +while ip link | grep server.0 > /dev/null; do + ip link set dev server.0 netns /tmp/server.netns +done +if ! ${server}; then + # Only add the address to NIC if netstack is not in use. Otherwise the host + # will also process the inbound SYN and send a RST back. + ${nsjoin_binary} /tmp/server.netns ip addr add ${server_proxy_addr}/${mask} dev server.0 +fi + +# Add client and server addresses, and bring everything up. +${nsjoin_binary} /tmp/client.netns ip addr add ${client_addr}/${mask} dev client.0 +${nsjoin_binary} /tmp/server.netns ip addr add ${server_addr}/${mask} dev server.0 +if [ "${disable_linux_gso}" == "1" ]; then + ${nsjoin_binary} /tmp/client.netns ethtool -K client.0 tso off + ${nsjoin_binary} /tmp/client.netns ethtool -K client.0 gro off + ${nsjoin_binary} /tmp/client.netns ethtool -K client.0 gso off + ${nsjoin_binary} /tmp/server.netns ethtool -K server.0 tso off + ${nsjoin_binary} /tmp/server.netns ethtool -K server.0 gso off + ${nsjoin_binary} /tmp/server.netns ethtool -K server.0 gro off +fi +${nsjoin_binary} /tmp/client.netns ip link set client.0 up +${nsjoin_binary} /tmp/client.netns ip link set lo up +${nsjoin_binary} /tmp/server.netns ip link set server.0 up +${nsjoin_binary} /tmp/server.netns ip link set lo up +ip link set dev client.1 up +ip link set dev server.1 up + +${nsjoin_binary} /tmp/client.netns ${client_args} & +client_pid=\$! +${nsjoin_binary} /tmp/server.netns ${server_args} & +server_pid=\$! + +# Start the iperf server. +${nsjoin_binary} /tmp/server.netns iperf -p ${iperf_port} -s >&2 & +iperf_pid=\$! + +# Show traffic information. +if ! ${client} && ! ${server}; then + ${nsjoin_binary} /tmp/client.netns ping -c 100 -i 0.001 -W 1 ${server_addr} >&2 || true +fi + +results_file=\$(mktemp) +function cleanup { + rm -f \$results_file + kill -TERM \$client_pid + kill -TERM \$server_pid + wait \$client_pid + wait \$server_pid + kill -9 \$iperf_pid 2>/dev/null +} + +# Allow failure from this point. +set +e +trap cleanup EXIT + +# Run the benchmark, recording the results file. +while ${nsjoin_binary} /tmp/client.netns iperf \\ + -p ${proxy_port} -c ${client_addr} -t ${duration} -f m -P ${num_client_threads} 2>&1 \\ + | tee \$results_file \\ + | grep "connect failed" >/dev/null; do + sleep 0.1 # Wait for all services. +done + +# Unlink all relevant devices from the bridge. This is because when the bridge +# is deleted, the kernel may hang. It appears that this problem is fixed in +# upstream commit 1ce5cce895309862d2c35d922816adebe094fe4a. +ip link set client.1 nomaster +ip link set server.1 nomaster +ip link set wan.0 nomaster +ip link set wan.1 nomaster + +# Emit raw results. +cat \$results_file >&2 + +# Emit a useful result (final throughput). +mbits=\$(grep Mbits/sec \$results_file \\ + | sed -n -e 's/^.*[[:space:]]\\([[:digit:]]\\+\\(\\.[[:digit:]]\\+\\)\\?\\)[[:space:]]*Mbits\\/sec.*/\\1/p') +client_cpu_ticks=\$(cat /proc/\$client_pid/stat \\ + | awk '{print (\$14+\$15);}') +server_cpu_ticks=\$(cat /proc/\$server_pid/stat \\ + | awk '{print (\$14+\$15);}') +ticks_per_sec=\$(getconf CLK_TCK) +client_cpu_load=\$(bc -l <<< \$client_cpu_ticks/\$ticks_per_sec/${duration}) +server_cpu_load=\$(bc -l <<< \$server_cpu_ticks/\$ticks_per_sec/${duration}) +echo \$mbits \$client_cpu_load \$server_cpu_load +EOF diff --git a/test/benchmarks/tcp/tcp_proxy.go b/test/benchmarks/tcp/tcp_proxy.go new file mode 100644 index 000000000..4b7ca7a14 --- /dev/null +++ b/test/benchmarks/tcp/tcp_proxy.go @@ -0,0 +1,451 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Binary tcp_proxy is a simple TCP proxy. +package main + +import ( + "encoding/gob" + "flag" + "fmt" + "io" + "log" + "math/rand" + "net" + "os" + "os/signal" + "regexp" + "runtime" + "runtime/pprof" + "strconv" + "syscall" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" + "gvisor.dev/gvisor/pkg/tcpip/link/qdisc/fifo" + "gvisor.dev/gvisor/pkg/tcpip/network/arp" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" +) + +var ( + port = flag.Int("port", 0, "bind port (all addresses)") + forward = flag.String("forward", "", "forwarding target") + client = flag.Bool("client", false, "use netstack for listen") + server = flag.Bool("server", false, "use netstack for dial") + + // Netstack-specific options. + mtu = flag.Int("mtu", 1280, "mtu for network stack") + addr = flag.String("addr", "", "address for tap-based netstack") + mask = flag.Int("mask", 8, "mask size for address") + iface = flag.String("iface", "", "network interface name to bind for netstack") + sack = flag.Bool("sack", false, "enable SACK support for netstack") + moderateRecvBuf = flag.Bool("moderate_recv_buf", false, "enable TCP Receive Buffer Auto-tuning") + cubic = flag.Bool("cubic", false, "enable use of CUBIC congestion control for netstack") + gso = flag.Int("gso", 0, "GSO maximum size") + swgso = flag.Bool("swgso", false, "software-level GSO") + clientTCPProbeFile = flag.String("client_tcp_probe_file", "", "if specified, installs a tcp probe to dump endpoint state to the specified file.") + serverTCPProbeFile = flag.String("server_tcp_probe_file", "", "if specified, installs a tcp probe to dump endpoint state to the specified file.") + cpuprofile = flag.String("cpuprofile", "", "write cpu profile to the specified file.") + memprofile = flag.String("memprofile", "", "write memory profile to the specified file.") +) + +type impl interface { + dial(address string) (net.Conn, error) + listen(port int) (net.Listener, error) + printStats() +} + +type netImpl struct{} + +func (netImpl) dial(address string) (net.Conn, error) { + return net.Dial("tcp", address) +} + +func (netImpl) listen(port int) (net.Listener, error) { + return net.Listen("tcp", fmt.Sprintf(":%d", port)) +} + +func (netImpl) printStats() { +} + +const ( + nicID = 1 // Fixed. + bufSize = 4 << 20 // 4MB. +) + +type netstackImpl struct { + s *stack.Stack + addr tcpip.Address + mode string +} + +func setupNetwork(ifaceName string, numChannels int) (fds []int, err error) { + // Get all interfaces in the namespace. + ifaces, err := net.Interfaces() + if err != nil { + return nil, fmt.Errorf("querying interfaces: %v", err) + } + + for _, iface := range ifaces { + if iface.Name != ifaceName { + continue + } + // Create the socket. + const protocol = 0x0300 // htons(ETH_P_ALL) + fds := make([]int, numChannels) + for i := range fds { + fd, err := syscall.Socket(syscall.AF_PACKET, syscall.SOCK_RAW, protocol) + if err != nil { + return nil, fmt.Errorf("unable to create raw socket: %v", err) + } + + // Bind to the appropriate device. + ll := syscall.SockaddrLinklayer{ + Protocol: protocol, + Ifindex: iface.Index, + Pkttype: syscall.PACKET_HOST, + } + if err := syscall.Bind(fd, &ll); err != nil { + return nil, fmt.Errorf("unable to bind to %q: %v", iface.Name, err) + } + + // RAW Sockets by default have a very small SO_RCVBUF of 256KB, + // up it to at least 4MB to reduce packet drops. + if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, bufSize); err != nil { + return nil, fmt.Errorf("setsockopt(..., SO_RCVBUF, %v,..) = %v", bufSize, err) + } + + if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, bufSize); err != nil { + return nil, fmt.Errorf("setsockopt(..., SO_SNDBUF, %v,..) = %v", bufSize, err) + } + + if !*swgso && *gso != 0 { + if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_VNET_HDR, 1); err != nil { + return nil, fmt.Errorf("unable to enable the PACKET_VNET_HDR option: %v", err) + } + } + fds[i] = fd + } + return fds, nil + } + return nil, fmt.Errorf("failed to find interface: %v", ifaceName) +} + +func newNetstackImpl(mode string) (impl, error) { + fds, err := setupNetwork(*iface, runtime.GOMAXPROCS(-1)) + if err != nil { + return nil, err + } + + // Parse details. + parsedAddr := tcpip.Address(net.ParseIP(*addr).To4()) + parsedDest := tcpip.Address("") // Filled in below. + parsedMask := tcpip.AddressMask("") // Filled in below. + switch *mask { + case 8: + parsedDest = tcpip.Address([]byte{parsedAddr[0], 0, 0, 0}) + parsedMask = tcpip.AddressMask([]byte{0xff, 0, 0, 0}) + case 16: + parsedDest = tcpip.Address([]byte{parsedAddr[0], parsedAddr[1], 0, 0}) + parsedMask = tcpip.AddressMask([]byte{0xff, 0xff, 0, 0}) + case 24: + parsedDest = tcpip.Address([]byte{parsedAddr[0], parsedAddr[1], parsedAddr[2], 0}) + parsedMask = tcpip.AddressMask([]byte{0xff, 0xff, 0xff, 0}) + default: + // This is just laziness; we don't expect a different mask. + return nil, fmt.Errorf("mask %d not supported", mask) + } + + // Create a new network stack. + netProtos := []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()} + transProtos := []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol()} + s := stack.New(stack.Options{ + NetworkProtocols: netProtos, + TransportProtocols: transProtos, + }) + + // Generate a new mac for the eth device. + mac := make(net.HardwareAddr, 6) + rand.Read(mac) // Fill with random data. + mac[0] &^= 0x1 // Clear multicast bit. + mac[0] |= 0x2 // Set local assignment bit (IEEE802). + ep, err := fdbased.New(&fdbased.Options{ + FDs: fds, + MTU: uint32(*mtu), + EthernetHeader: true, + Address: tcpip.LinkAddress(mac), + // Enable checksum generation as we need to generate valid + // checksums for the veth device to deliver our packets to the + // peer. But we do want to disable checksum verification as veth + // devices do perform GRO and the linux host kernel may not + // regenerate valid checksums after GRO. + TXChecksumOffload: false, + RXChecksumOffload: true, + PacketDispatchMode: fdbased.RecvMMsg, + GSOMaxSize: uint32(*gso), + SoftwareGSOEnabled: *swgso, + }) + if err != nil { + return nil, fmt.Errorf("failed to create FD endpoint: %v", err) + } + if err := s.CreateNIC(nicID, fifo.New(ep, runtime.GOMAXPROCS(0), 1000)); err != nil { + return nil, fmt.Errorf("error creating NIC %q: %v", *iface, err) + } + if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { + return nil, fmt.Errorf("error adding ARP address to %q: %v", *iface, err) + } + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, parsedAddr); err != nil { + return nil, fmt.Errorf("error adding IP address to %q: %v", *iface, err) + } + + subnet, err := tcpip.NewSubnet(parsedDest, parsedMask) + if err != nil { + return nil, fmt.Errorf("tcpip.Subnet(%s, %s): %s", parsedDest, parsedMask, err) + } + // Add default route; we only support + s.SetRouteTable([]tcpip.Route{ + { + Destination: subnet, + NIC: nicID, + }, + }) + + // Set protocol options. + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(*sack)); err != nil { + return nil, fmt.Errorf("SetTransportProtocolOption for SACKEnabled failed: %s", err) + } + + // Enable Receive Buffer Auto-Tuning. + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(*moderateRecvBuf)); err != nil { + return nil, fmt.Errorf("SetTransportProtocolOption failed: %s", err) + } + + // Set Congestion Control to cubic if requested. + if *cubic { + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.CongestionControlOption("cubic")); err != nil { + return nil, fmt.Errorf("SetTransportProtocolOption for CongestionControlOption(cubic) failed: %s", err) + } + } + + return netstackImpl{ + s: s, + addr: parsedAddr, + mode: mode, + }, nil +} + +func (n netstackImpl) dial(address string) (net.Conn, error) { + host, port, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + if host == "" { + // A host must be provided for the dial. + return nil, fmt.Errorf("no host provided") + } + portNumber, err := strconv.Atoi(port) + if err != nil { + return nil, err + } + addr := tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.Address(net.ParseIP(host).To4()), + Port: uint16(portNumber), + } + conn, err := gonet.DialTCP(n.s, addr, ipv4.ProtocolNumber) + if err != nil { + return nil, err + } + return conn, nil +} + +func (n netstackImpl) listen(port int) (net.Listener, error) { + addr := tcpip.FullAddress{ + NIC: nicID, + Port: uint16(port), + } + listener, err := gonet.ListenTCP(n.s, addr, ipv4.ProtocolNumber) + if err != nil { + return nil, err + } + return listener, nil +} + +var zeroFieldsRegexp = regexp.MustCompile(`\s*[a-zA-Z0-9]*:0`) + +func (n netstackImpl) printStats() { + // Don't show zero fields. + stats := zeroFieldsRegexp.ReplaceAllString(fmt.Sprintf("%+v", n.s.Stats()), "") + log.Printf("netstack %s Stats: %+v\n", n.mode, stats) +} + +// installProbe installs a TCP Probe function that will dump endpoint +// state to the specified file. It also returns a close func() that +// can be used to close the probeFile. +func (n netstackImpl) installProbe(probeFileName string) (close func()) { + // Install Probe to dump out end point state. + probeFile, err := os.Create(probeFileName) + if err != nil { + log.Fatalf("failed to create tcp_probe file %s: %v", probeFileName, err) + } + probeEncoder := gob.NewEncoder(probeFile) + // Install a TCP Probe. + n.s.AddTCPProbe(func(state stack.TCPEndpointState) { + probeEncoder.Encode(state) + }) + return func() { probeFile.Close() } +} + +func main() { + flag.Parse() + if *port == 0 { + log.Fatalf("no port provided") + } + if *forward == "" { + log.Fatalf("no forward provided") + } + // Seed the random number generator to ensure that we are given MAC addresses that don't + // for the case of the client and server stack. + rand.Seed(time.Now().UTC().UnixNano()) + + if *cpuprofile != "" { + f, err := os.Create(*cpuprofile) + if err != nil { + log.Fatal("could not create CPU profile: ", err) + } + defer func() { + if err := f.Close(); err != nil { + log.Print("error closing CPU profile: ", err) + } + }() + if err := pprof.StartCPUProfile(f); err != nil { + log.Fatal("could not start CPU profile: ", err) + } + defer pprof.StopCPUProfile() + } + + var ( + in impl + out impl + err error + ) + if *server { + in, err = newNetstackImpl("server") + if *serverTCPProbeFile != "" { + defer in.(netstackImpl).installProbe(*serverTCPProbeFile)() + } + + } else { + in = netImpl{} + } + if err != nil { + log.Fatalf("netstack error: %v", err) + } + if *client { + out, err = newNetstackImpl("client") + if *clientTCPProbeFile != "" { + defer out.(netstackImpl).installProbe(*clientTCPProbeFile)() + } + } else { + out = netImpl{} + } + if err != nil { + log.Fatalf("netstack error: %v", err) + } + + // Dial forward before binding. + var next net.Conn + for { + next, err = out.dial(*forward) + if err == nil { + break + } + time.Sleep(50 * time.Millisecond) + log.Printf("connect failed retrying: %v", err) + } + + // Bind once to the server socket. + listener, err := in.listen(*port) + if err != nil { + // Should not happen, everything must be bound by this time + // this proxy is started. + log.Fatalf("unable to listen: %v", err) + } + log.Printf("client=%v, server=%v, ready.", *client, *server) + + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGTERM) + go func() { + <-sigs + if *cpuprofile != "" { + pprof.StopCPUProfile() + } + if *memprofile != "" { + f, err := os.Create(*memprofile) + if err != nil { + log.Fatal("could not create memory profile: ", err) + } + defer func() { + if err := f.Close(); err != nil { + log.Print("error closing memory profile: ", err) + } + }() + runtime.GC() // get up-to-date statistics + if err := pprof.WriteHeapProfile(f); err != nil { + log.Fatalf("Unable to write heap profile: %v", err) + } + } + os.Exit(0) + }() + + for { + // Forward all connections. + inConn, err := listener.Accept() + if err != nil { + // This should not happen; we are listening + // successfully. Exhausted all available FDs? + log.Fatalf("accept error: %v", err) + } + log.Printf("incoming connection established.") + + // Copy both ways. + go io.Copy(inConn, next) + go io.Copy(next, inConn) + + // Print stats every second. + go func() { + t := time.NewTicker(time.Second) + defer t.Stop() + for { + <-t.C + in.printStats() + out.printStats() + } + }() + + for { + // Dial again. + next, err = out.dial(*forward) + if err == nil { + break + } + } + } +} diff --git a/test/benchmarks/tools/BUILD b/test/benchmarks/tools/BUILD new file mode 100644 index 000000000..e5734d85c --- /dev/null +++ b/test/benchmarks/tools/BUILD @@ -0,0 +1,33 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "tools", + srcs = [ + "ab.go", + "fio.go", + "hey.go", + "iperf.go", + "meminfo.go", + "redis.go", + "sysbench.go", + "tools.go", + ], + visibility = ["//:sandbox"], +) + +go_test( + name = "tools_test", + size = "small", + srcs = [ + "ab_test.go", + "fio_test.go", + "hey_test.go", + "iperf_test.go", + "meminfo_test.go", + "redis_test.go", + "sysbench_test.go", + ], + library = ":tools", +) diff --git a/test/benchmarks/tools/ab.go b/test/benchmarks/tools/ab.go new file mode 100644 index 000000000..4cc9c3bce --- /dev/null +++ b/test/benchmarks/tools/ab.go @@ -0,0 +1,94 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "fmt" + "net" + "regexp" + "strconv" + "testing" +) + +// ApacheBench is for the client application ApacheBench. +type ApacheBench struct { + Requests int + Concurrency int + Doc string + // TODO(zkoopmans): support KeepAlive and pass option to enable. +} + +// MakeCmd makes an ApacheBench command. +func (a *ApacheBench) MakeCmd(ip net.IP, port int) []string { + path := fmt.Sprintf("http://%s:%d/%s", ip, port, a.Doc) + // See apachebench (ab) for flags. + cmd := fmt.Sprintf("ab -n %d -c %d %s", a.Requests, a.Concurrency, path) + return []string{"sh", "-c", cmd} +} + +// Report parses and reports metrics from ApacheBench output. +func (a *ApacheBench) Report(b *testing.B, output string) { + // Parse and report custom metrics. + transferRate, err := a.parseTransferRate(output) + if err != nil { + b.Logf("failed to parse transferrate: %v", err) + } + b.ReportMetric(transferRate*1024, "transfer_rate_b/s") // Convert from Kb/s to b/s. + + latency, err := a.parseLatency(output) + if err != nil { + b.Logf("failed to parse latency: %v", err) + } + b.ReportMetric(latency/1000, "mean_latency_secs") // Convert from ms to s. + + reqPerSecond, err := a.parseRequestsPerSecond(output) + if err != nil { + b.Logf("failed to parse requests per second: %v", err) + } + b.ReportMetric(reqPerSecond, "requests_per_second") +} + +var transferRateRE = regexp.MustCompile(`Transfer rate:\s+(\d+\.?\d+?)\s+\[Kbytes/sec\]\s+received`) + +// parseTransferRate parses transfer rate from ApacheBench output. +func (a *ApacheBench) parseTransferRate(data string) (float64, error) { + match := transferRateRE.FindStringSubmatch(data) + if len(match) < 2 { + return 0, fmt.Errorf("failed get bandwidth: %s", data) + } + return strconv.ParseFloat(match[1], 64) +} + +var latencyRE = regexp.MustCompile(`Total:\s+\d+\s+(\d+)\s+(\d+\.?\d+?)\s+\d+\s+\d+\s`) + +// parseLatency parses latency from ApacheBench output. +func (a *ApacheBench) parseLatency(data string) (float64, error) { + match := latencyRE.FindStringSubmatch(data) + if len(match) < 2 { + return 0, fmt.Errorf("failed get bandwidth: %s", data) + } + return strconv.ParseFloat(match[1], 64) +} + +var requestsPerSecondRE = regexp.MustCompile(`Requests per second:\s+(\d+\.?\d+?)\s+`) + +// parseRequestsPerSecond parses requests per second from ApacheBench output. +func (a *ApacheBench) parseRequestsPerSecond(data string) (float64, error) { + match := requestsPerSecondRE.FindStringSubmatch(data) + if len(match) < 2 { + return 0, fmt.Errorf("failed get bandwidth: %s", data) + } + return strconv.ParseFloat(match[1], 64) +} diff --git a/test/benchmarks/tools/ab_test.go b/test/benchmarks/tools/ab_test.go new file mode 100644 index 000000000..28ee66ec1 --- /dev/null +++ b/test/benchmarks/tools/ab_test.go @@ -0,0 +1,90 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import "testing" + +// TestApacheBench checks the ApacheBench parsers on sample output. +func TestApacheBench(t *testing.T) { + // Sample output from apachebench. + sampleData := `This is ApacheBench, Version 2.3 <$Revision: 1826891 $> +Copyright 1996 Adam Twiss, Zeus Technology Ltd, http://www.zeustech.net/ +Licensed to The Apache Software Foundation, http://www.apache.org/ + +Benchmarking 10.10.10.10 (be patient).....done + + +Server Software: Apache/2.4.38 +Server Hostname: 10.10.10.10 +Server Port: 80 + +Document Path: /latin10k.txt +Document Length: 210 bytes + +Concurrency Level: 1 +Time taken for tests: 0.180 seconds +Complete requests: 100 +Failed requests: 0 +Non-2xx responses: 100 +Total transferred: 38800 bytes +HTML transferred: 21000 bytes +Requests per second: 556.44 [#/sec] (mean) +Time per request: 1.797 [ms] (mean) +Time per request: 1.797 [ms] (mean, across all concurrent requests) +Transfer rate: 210.84 [Kbytes/sec] received + +Connection Times (ms) + min mean[+/-sd] median max +Connect: 0 0 0.2 0 2 +Processing: 1 2 1.0 1 8 +Waiting: 1 1 1.0 1 7 +Total: 1 2 1.2 1 10 + +Percentage of the requests served within a certain time (ms) + 50% 1 + 66% 2 + 75% 2 + 80% 2 + 90% 2 + 95% 3 + 98% 7 + 99% 10 + 100% 10 (longest request)` + + ab := ApacheBench{} + want := 210.84 + got, err := ab.parseTransferRate(sampleData) + if err != nil { + t.Fatalf("failed to parse transfer rate with error: %v", err) + } else if got != want { + t.Fatalf("parseTransferRate got: %f, want: %f", got, want) + } + + want = 2.0 + got, err = ab.parseLatency(sampleData) + if err != nil { + t.Fatalf("failed to parse transfer rate with error: %v", err) + } else if got != want { + t.Fatalf("parseLatency got: %f, want: %f", got, want) + } + + want = 556.44 + got, err = ab.parseRequestsPerSecond(sampleData) + if err != nil { + t.Fatalf("failed to parse transfer rate with error: %v", err) + } else if got != want { + t.Fatalf("parseRequestsPerSecond got: %f, want: %f", got, want) + } +} diff --git a/test/benchmarks/tools/fio.go b/test/benchmarks/tools/fio.go new file mode 100644 index 000000000..20000db16 --- /dev/null +++ b/test/benchmarks/tools/fio.go @@ -0,0 +1,124 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "encoding/json" + "fmt" + "strconv" + "strings" + "testing" +) + +// Fio makes 'fio' commands and parses their output. +type Fio struct { + Test string // test to run: read, write, randread, randwrite. + Size string // total size to be read/written of format N[GMK] (e.g. 5G). + Blocksize string // blocksize to be read/write of format N[GMK] (e.g. 4K). + Iodepth int // iodepth for reads/writes. + Time int // time to run the test in seconds, usually for rand(read/write). +} + +// MakeCmd makes a 'fio' command. +func (f *Fio) MakeCmd(filename string) []string { + cmd := []string{"fio", "--output-format=json", "--ioengine=sync"} + cmd = append(cmd, fmt.Sprintf("--name=%s", f.Test)) + cmd = append(cmd, fmt.Sprintf("--size=%s", f.Size)) + cmd = append(cmd, fmt.Sprintf("--blocksize=%s", f.Blocksize)) + cmd = append(cmd, fmt.Sprintf("--filename=%s", filename)) + cmd = append(cmd, fmt.Sprintf("--iodepth=%d", f.Iodepth)) + cmd = append(cmd, fmt.Sprintf("--rw=%s", f.Test)) + if f.Time != 0 { + cmd = append(cmd, "--time_based") + cmd = append(cmd, fmt.Sprintf("--runtime=%d", f.Time)) + } + return cmd +} + +// Report reports metrics based on output from an 'fio' command. +func (f *Fio) Report(b *testing.B, output string) { + b.Helper() + // Parse the output and report the metrics. + isRead := strings.Contains(f.Test, "read") + bw, err := f.parseBandwidth(output, isRead) + if err != nil { + b.Fatalf("failed to parse bandwidth from %s with: %v", output, err) + } + b.ReportMetric(bw, "bandwidth_b/s") // in b/s. + + iops, err := f.parseIOps(output, isRead) + if err != nil { + b.Fatalf("failed to parse iops from %s with: %v", output, err) + } + b.ReportMetric(iops, "iops") +} + +// parseBandwidth reports the bandwidth in b/s. +func (f *Fio) parseBandwidth(data string, isRead bool) (float64, error) { + if isRead { + result, err := f.parseFioJSON(data, "read", "bw") + if err != nil { + return 0, err + } + return 1024 * result, nil + } + result, err := f.parseFioJSON(data, "write", "bw") + if err != nil { + return 0, err + } + return 1024 * result, nil +} + +// parseIOps reports the write IO per second metric. +func (f *Fio) parseIOps(data string, isRead bool) (float64, error) { + if isRead { + return f.parseFioJSON(data, "read", "iops") + } + return f.parseFioJSON(data, "write", "iops") +} + +// fioResult is for parsing FioJSON. +type fioResult struct { + Jobs []fioJob +} + +// fioJob is for parsing FioJSON. +type fioJob map[string]json.RawMessage + +// fioMetrics is for parsing FioJSON. +type fioMetrics map[string]json.RawMessage + +// parseFioJSON parses data and grabs "op" (read or write) and "metric" +// (bw or iops) from the JSON. +func (f *Fio) parseFioJSON(data, op, metric string) (float64, error) { + var result fioResult + if err := json.Unmarshal([]byte(data), &result); err != nil { + return 0, fmt.Errorf("could not unmarshal data: %v", err) + } + + if len(result.Jobs) < 1 { + return 0, fmt.Errorf("no jobs present to parse") + } + + var metrics fioMetrics + if err := json.Unmarshal(result.Jobs[0][op], &metrics); err != nil { + return 0, fmt.Errorf("could not unmarshal jobs: %v", err) + } + + if _, ok := metrics[metric]; !ok { + return 0, fmt.Errorf("no metric found for op: %s", op) + } + return strconv.ParseFloat(string(metrics[metric]), 64) +} diff --git a/test/benchmarks/tools/fio_test.go b/test/benchmarks/tools/fio_test.go new file mode 100644 index 000000000..a98277150 --- /dev/null +++ b/test/benchmarks/tools/fio_test.go @@ -0,0 +1,122 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import "testing" + +// TestFio checks the Fio parsers on sample output. +func TestFio(t *testing.T) { + sampleData := ` +{ + "fio version" : "fio-3.1", + "timestamp" : 1554837456, + "timestamp_ms" : 1554837456621, + "time" : "Tue Apr 9 19:17:36 2019", + "jobs" : [ + { + "jobname" : "test", + "groupid" : 0, + "error" : 0, + "eta" : 2147483647, + "elapsed" : 1, + "job options" : { + "name" : "test", + "ioengine" : "sync", + "size" : "1073741824", + "filename" : "/disk/file.dat", + "iodepth" : "4", + "bs" : "4096", + "rw" : "write" + }, + "read" : { + "io_bytes" : 0, + "io_kbytes" : 0, + "bw" : 123456, + "iops" : 1234.5678, + "runtime" : 0, + "total_ios" : 0, + "short_ios" : 0, + "bw_min" : 0, + "bw_max" : 0, + "bw_agg" : 0.000000, + "bw_mean" : 0.000000, + "bw_dev" : 0.000000, + "bw_samples" : 0, + "iops_min" : 0, + "iops_max" : 0, + "iops_mean" : 0.000000, + "iops_stddev" : 0.000000, + "iops_samples" : 0 + }, + "write" : { + "io_bytes" : 1073741824, + "io_kbytes" : 1048576, + "bw" : 1753471, + "iops" : 438367.892977, + "runtime" : 598, + "total_ios" : 262144, + "bw_min" : 1731120, + "bw_max" : 1731120, + "bw_agg" : 98.725328, + "bw_mean" : 1731120.000000, + "bw_dev" : 0.000000, + "bw_samples" : 1, + "iops_min" : 432780, + "iops_max" : 432780, + "iops_mean" : 432780.000000, + "iops_stddev" : 0.000000, + "iops_samples" : 1 + } + } + ] +} +` + fio := Fio{} + // WriteBandwidth. + got, err := fio.parseBandwidth(sampleData, false) + var want float64 = 1753471.0 * 1024 + if err != nil { + t.Fatalf("parse failed with err: %v", err) + } else if got != want { + t.Fatalf("got: %f, want: %f", got, want) + } + + // ReadBandwidth. + got, err = fio.parseBandwidth(sampleData, true) + want = 123456 * 1024 + if err != nil { + t.Fatalf("parse failed with err: %v", err) + } else if got != want { + t.Fatalf("got: %f, want: %f", got, want) + } + + // WriteIOps. + got, err = fio.parseIOps(sampleData, false) + want = 438367.892977 + if err != nil { + t.Fatalf("parse failed with err: %v", err) + } else if got != want { + t.Fatalf("got: %f, want: %f", got, want) + } + + // ReadIOps. + got, err = fio.parseIOps(sampleData, true) + want = 1234.5678 + if err != nil { + t.Fatalf("parse failed with err: %v", err) + } else if got != want { + t.Fatalf("got: %f, want: %f", got, want) + } +} diff --git a/test/benchmarks/tools/hey.go b/test/benchmarks/tools/hey.go new file mode 100644 index 000000000..b1e20e356 --- /dev/null +++ b/test/benchmarks/tools/hey.go @@ -0,0 +1,75 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "fmt" + "net" + "regexp" + "strconv" + "strings" + "testing" +) + +// Hey is for the client application 'hey'. +type Hey struct { + Requests int // Note: requests cannot be less than concurrency. + Concurrency int + Doc string +} + +// MakeCmd returns a 'hey' command. +func (h *Hey) MakeCmd(ip net.IP, port int) []string { + return strings.Split(fmt.Sprintf("hey -n %d -c %d http://%s:%d/%s", + h.Requests, h.Concurrency, ip, port, h.Doc), " ") +} + +// Report parses output from 'hey' and reports metrics. +func (h *Hey) Report(b *testing.B, output string) { + b.Helper() + requests, err := h.parseRequestsPerSecond(output) + if err != nil { + b.Fatalf("failed to parse requests per second: %v", err) + } + b.ReportMetric(requests, "requests_per_second") + + ave, err := h.parseAverageLatency(output) + if err != nil { + b.Fatalf("failed to parse average latency: %v", err) + } + b.ReportMetric(ave, "average_latency_secs") +} + +var heyReqPerSecondRE = regexp.MustCompile(`Requests/sec:\s*(\d+\.?\d+?)\s+`) + +// parseRequestsPerSecond finds requests per second from 'hey' output. +func (h *Hey) parseRequestsPerSecond(data string) (float64, error) { + match := heyReqPerSecondRE.FindStringSubmatch(data) + if len(match) < 2 { + return 0, fmt.Errorf("failed get bandwidth: %s", data) + } + return strconv.ParseFloat(match[1], 64) +} + +var heyAverageLatencyRE = regexp.MustCompile(`Average:\s*(\d+\.?\d+?)\s+secs`) + +// parseHeyAverageLatency finds Average Latency in seconds form 'hey' output. +func (h *Hey) parseAverageLatency(data string) (float64, error) { + match := heyAverageLatencyRE.FindStringSubmatch(data) + if len(match) < 2 { + return 0, fmt.Errorf("failed get average latency match%d : %s", len(match), data) + } + return strconv.ParseFloat(match[1], 64) +} diff --git a/test/benchmarks/tools/hey_test.go b/test/benchmarks/tools/hey_test.go new file mode 100644 index 000000000..e0cab1f52 --- /dev/null +++ b/test/benchmarks/tools/hey_test.go @@ -0,0 +1,81 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import "testing" + +// TestHey checks the Hey parsers on sample output. +func TestHey(t *testing.T) { + sampleData := ` + Summary: + Total: 2.2391 secs + Slowest: 1.6292 secs + Fastest: 0.0066 secs + Average: 0.5351 secs + Requests/sec: 89.3202 + + Total data: 841200 bytes + Size/request: 4206 bytes + + Response time histogram: + 0.007 [1] | + 0.169 [0] | + 0.331 [149] |â– â– â– â– â– â– â– â– â– â– â– â– â– â– â– â– â– â– â– â– â– â– â– â– â– â– â– â– â– â– â– â– â– â– â– â– â– â– â– â– + 0.493 [0] | + 0.656 [0] | + 0.818 [0] | + 0.980 [0] | + 1.142 [0] | + 1.305 [0] | + 1.467 [49] |â– â– â– â– â– â– â– â– â– â– â– â– â– + 1.629 [1] | + + + Latency distribution: + 10% in 0.2149 secs + 25% in 0.2449 secs + 50% in 0.2703 secs + 75% in 1.3315 secs + 90% in 1.4045 secs + 95% in 1.4232 secs + 99% in 1.4362 secs + + Details (average, fastest, slowest): + DNS+dialup: 0.0002 secs, 0.0066 secs, 1.6292 secs + DNS-lookup: 0.0000 secs, 0.0000 secs, 0.0000 secs + req write: 0.0000 secs, 0.0000 secs, 0.0012 secs + resp wait: 0.5225 secs, 0.0064 secs, 1.4346 secs + resp read: 0.0122 secs, 0.0001 secs, 0.2006 secs + + Status code distribution: + [200] 200 responses + ` + hey := Hey{} + want := 89.3202 + got, err := hey.parseRequestsPerSecond(sampleData) + if err != nil { + t.Fatalf("failed to parse request per second with: %v", err) + } else if got != want { + t.Fatalf("got: %f, want: %f", got, want) + } + + want = 0.5351 + got, err = hey.parseAverageLatency(sampleData) + if err != nil { + t.Fatalf("failed to parse average latency with: %v", err) + } else if got != want { + t.Fatalf("got: %f, want: %f", got, want) + } +} diff --git a/test/benchmarks/tools/iperf.go b/test/benchmarks/tools/iperf.go new file mode 100644 index 000000000..df3d9349b --- /dev/null +++ b/test/benchmarks/tools/iperf.go @@ -0,0 +1,56 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "fmt" + "net" + "regexp" + "strconv" + "strings" + "testing" +) + +// Iperf is for the client side of `iperf`. +type Iperf struct { + Time int +} + +// MakeCmd returns a iperf client command. +func (i *Iperf) MakeCmd(ip net.IP, port int) []string { + // iperf report in Kb realtime + return strings.Split(fmt.Sprintf("iperf -f K --realtime --time %d -c %s -p %d", i.Time, ip, port), " ") +} + +// Report parses output from iperf client and reports metrics. +func (i *Iperf) Report(b *testing.B, output string) { + b.Helper() + // Parse bandwidth and report it. + bW, err := i.bandwidth(output) + if err != nil { + b.Fatalf("failed to parse bandwitdth from %s: %v", output, err) + } + b.ReportMetric(bW*1024, "bandwidth_b/s") // Convert from Kb/s to b/s. +} + +// bandwidth parses the Bandwidth number from an iperf report. A sample is below. +func (i *Iperf) bandwidth(data string) (float64, error) { + re := regexp.MustCompile(`\[\s*\d+\][^\n]+\s+(\d+\.?\d*)\s+KBytes/sec`) + match := re.FindStringSubmatch(data) + if len(match) < 1 { + return 0, fmt.Errorf("failed get bandwidth: %s", data) + } + return strconv.ParseFloat(match[1], 64) +} diff --git a/test/benchmarks/tools/iperf_test.go b/test/benchmarks/tools/iperf_test.go new file mode 100644 index 000000000..03bb30d05 --- /dev/null +++ b/test/benchmarks/tools/iperf_test.go @@ -0,0 +1,34 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package tools + +import "testing" + +// TestIperf checks the Iperf parsers on sample output. +func TestIperf(t *testing.T) { + sampleData := ` +------------------------------------------------------------ +Client connecting to 10.138.15.215, TCP port 32779 +TCP window size: 45.0 KByte (default) +------------------------------------------------------------ +[ 3] local 10.138.15.216 port 32866 connected with 10.138.15.215 port 32779 +[ ID] Interval Transfer Bandwidth +[ 3] 0.0-10.0 sec 459520 KBytes 45900 KBytes/sec +` + i := Iperf{} + bandwidth, err := i.bandwidth(sampleData) + if err != nil || bandwidth != 45900 { + t.Fatalf("failed with: %v and %f", err, bandwidth) + } +} diff --git a/test/benchmarks/tools/meminfo.go b/test/benchmarks/tools/meminfo.go new file mode 100644 index 000000000..2414a96a7 --- /dev/null +++ b/test/benchmarks/tools/meminfo.go @@ -0,0 +1,60 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "fmt" + "regexp" + "strconv" + "testing" +) + +// Meminfo wraps measurements of MemAvailable using /proc/meminfo. +type Meminfo struct { +} + +// MakeCmd returns a command for checking meminfo. +func (*Meminfo) MakeCmd() (string, []string) { + return "cat", []string{"/proc/meminfo"} +} + +// Report takes two reads of meminfo, parses them, and reports the difference +// divided by b.N. +func (*Meminfo) Report(b *testing.B, before, after string) { + b.Helper() + + beforeVal, err := parseMemAvailable(before) + if err != nil { + b.Fatalf("could not parse before value %s: %v", before, err) + } + + afterVal, err := parseMemAvailable(after) + if err != nil { + b.Fatalf("could not parse before value %s: %v", before, err) + } + val := 1024 * ((beforeVal - afterVal) / float64(b.N)) + b.ReportMetric(val, "average_container_size_bytes") +} + +var memInfoRE = regexp.MustCompile(`MemAvailable:\s*(\d+)\skB\n`) + +// parseMemAvailable grabs the MemAvailable number from /proc/meminfo. +func parseMemAvailable(data string) (float64, error) { + match := memInfoRE.FindStringSubmatch(data) + if len(match) < 2 { + return 0, fmt.Errorf("couldn't find MemAvailable in %s", data) + } + return strconv.ParseFloat(match[1], 64) +} diff --git a/test/benchmarks/tools/meminfo_test.go b/test/benchmarks/tools/meminfo_test.go new file mode 100644 index 000000000..ba803540f --- /dev/null +++ b/test/benchmarks/tools/meminfo_test.go @@ -0,0 +1,84 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "testing" +) + +// TestMeminfo checks the Meminfo parser on sample output. +func TestMeminfo(t *testing.T) { + sampleData := ` +MemTotal: 16337408 kB +MemFree: 3742696 kB +MemAvailable: 9319948 kB +Buffers: 1433884 kB +Cached: 4607036 kB +SwapCached: 45284 kB +Active: 8288376 kB +Inactive: 2685928 kB +Active(anon): 4724912 kB +Inactive(anon): 1047940 kB +Active(file): 3563464 kB +Inactive(file): 1637988 kB +Unevictable: 326940 kB +Mlocked: 48 kB +SwapTotal: 33292284 kB +SwapFree: 32865736 kB +Dirty: 708 kB +Writeback: 0 kB +AnonPages: 4304204 kB +Mapped: 975424 kB +Shmem: 910292 kB +KReclaimable: 744532 kB +Slab: 1058448 kB +SReclaimable: 744532 kB +SUnreclaim: 313916 kB +KernelStack: 25188 kB +PageTables: 65300 kB +NFS_Unstable: 0 kB +Bounce: 0 kB +WritebackTmp: 0 kB +CommitLimit: 41460988 kB +Committed_AS: 22859492 kB +VmallocTotal: 34359738367 kB +VmallocUsed: 63088 kB +VmallocChunk: 0 kB +Percpu: 9248 kB +HardwareCorrupted: 0 kB +AnonHugePages: 786432 kB +ShmemHugePages: 0 kB +ShmemPmdMapped: 0 kB +FileHugePages: 0 kB +FilePmdMapped: 0 kB +HugePages_Total: 0 +HugePages_Free: 0 +HugePages_Rsvd: 0 +HugePages_Surp: 0 +Hugepagesize: 2048 kB +Hugetlb: 0 kB +DirectMap4k: 5408532 kB +DirectMap2M: 11241472 kB +DirectMap1G: 1048576 kB +` + want := 9319948.0 + got, err := parseMemAvailable(sampleData) + if err != nil { + t.Fatalf("parseMemAvailable failed: %v", err) + } + if got != want { + t.Fatalf("parseMemAvailable got %f, want %f", got, want) + } +} diff --git a/test/benchmarks/tools/redis.go b/test/benchmarks/tools/redis.go new file mode 100644 index 000000000..c899ae0d4 --- /dev/null +++ b/test/benchmarks/tools/redis.go @@ -0,0 +1,63 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "fmt" + "net" + "regexp" + "strconv" + "strings" + "testing" +) + +// Redis is for the client 'redis-benchmark'. +type Redis struct { + Operation string +} + +// MakeCmd returns a redis-benchmark client command. +func (r *Redis) MakeCmd(ip net.IP, port int) []string { + // There is no -t PING_BULK for redis-benchmark, so adjust the command in that case. + // Note that "ping" will run both PING_INLINE and PING_BULK. + if r.Operation == "PING_BULK" { + return strings.Split( + fmt.Sprintf("redis-benchmark --csv -t ping -h %s -p %d", ip, port), " ") + } + + // runs redis-benchmark -t operation for 100K requests against server. + return strings.Split( + fmt.Sprintf("redis-benchmark --csv -t %s -h %s -p %d", r.Operation, ip, port), " ") +} + +// Report parses output from redis-benchmark client and reports metrics. +func (r *Redis) Report(b *testing.B, output string) { + b.Helper() + result, err := r.parseOperation(output) + if err != nil { + b.Fatalf("parsing result %s failed with err: %v", output, err) + } + b.ReportMetric(result, r.Operation) // operations per second +} + +// parseOperation grabs the metric operations per second from redis-benchmark output. +func (r *Redis) parseOperation(data string) (float64, error) { + re := regexp.MustCompile(fmt.Sprintf(`"%s( .*)?","(\d*\.\d*)"`, r.Operation)) + match := re.FindStringSubmatch(data) + if len(match) < 3 { + return 0.0, fmt.Errorf("could not find %s in %s", r.Operation, data) + } + return strconv.ParseFloat(match[2], 64) +} diff --git a/test/benchmarks/tools/redis_test.go b/test/benchmarks/tools/redis_test.go new file mode 100644 index 000000000..4bafda66f --- /dev/null +++ b/test/benchmarks/tools/redis_test.go @@ -0,0 +1,87 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "testing" +) + +// TestRedis checks the Redis parsers on sample output. +func TestRedis(t *testing.T) { + sampleData := ` + "PING_INLINE","48661.80" + "PING_BULK","50301.81" + "SET","48923.68" + "GET","49382.71" + "INCR","49975.02" + "LPUSH","49875.31" + "RPUSH","50276.52" + "LPOP","50327.12" + "RPOP","50556.12" + "SADD","49504.95" + "HSET","49504.95" + "SPOP","50025.02" + "LPUSH (needed to benchmark LRANGE)","48875.86" + "LRANGE_100 (first 100 elements)","33955.86" + "LRANGE_300 (first 300 elements)","16550.81"// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + + "LRANGE_500 (first 450 elements)","13653.74" + "LRANGE_600 (first 600 elements)","11219.57" + "MSET (10 keys)","44682.75" + ` + wants := map[string]float64{ + "PING_INLINE": 48661.80, + "PING_BULK": 50301.81, + "SET": 48923.68, + "GET": 49382.71, + "INCR": 49975.02, + "LPUSH": 49875.31, + "RPUSH": 50276.52, + "LPOP": 50327.12, + "RPOP": 50556.12, + "SADD": 49504.95, + "HSET": 49504.95, + "SPOP": 50025.02, + "LRANGE_100": 33955.86, + "LRANGE_300": 16550.81, + "LRANGE_500": 13653.74, + "LRANGE_600": 11219.57, + "MSET": 44682.75, + } + for op, want := range wants { + redis := Redis{ + Operation: op, + } + if got, err := redis.parseOperation(sampleData); err != nil { + t.Fatalf("failed to parse %s: %v", op, err) + } else if want != got { + t.Fatalf("wanted %f for op %s, got %f", want, op, got) + } + } +} diff --git a/test/benchmarks/tools/sysbench.go b/test/benchmarks/tools/sysbench.go new file mode 100644 index 000000000..6b2f75ca2 --- /dev/null +++ b/test/benchmarks/tools/sysbench.go @@ -0,0 +1,245 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "fmt" + "regexp" + "strconv" + "strings" + "testing" +) + +var warmup = "sysbench --threads=8 --memory-total-size=5G memory run > /dev/null &&" + +// Sysbench represents a 'sysbench' command. +type Sysbench interface { + MakeCmd() []string // Makes a sysbench command. + flags() []string + Report(*testing.B, string) // Reports results contained in string. +} + +// SysbenchBase is the top level struct for sysbench and holds top-level arguments +// for sysbench. See: 'sysbench --help' +type SysbenchBase struct { + Threads int // number of Threads for the test. + Time int // time limit for test in seconds. +} + +// baseFlags returns top level flags. +func (s *SysbenchBase) baseFlags() []string { + var ret []string + if s.Threads > 0 { + ret = append(ret, fmt.Sprintf("--threads=%d", s.Threads)) + } + if s.Time > 0 { + ret = append(ret, fmt.Sprintf("--time=%d", s.Time)) + } + return ret +} + +// SysbenchCPU is for 'sysbench [flags] cpu run' and holds CPU specific arguments. +type SysbenchCPU struct { + Base SysbenchBase + MaxPrime int // upper limit for primes generator [10000]. +} + +// MakeCmd makes commands for SysbenchCPU. +func (s *SysbenchCPU) MakeCmd() []string { + cmd := []string{warmup, "sysbench"} + cmd = append(cmd, s.flags()...) + cmd = append(cmd, "cpu run") + return []string{"sh", "-c", strings.Join(cmd, " ")} +} + +// flags makes flags for SysbenchCPU cmds. +func (s *SysbenchCPU) flags() []string { + cmd := s.Base.baseFlags() + if s.MaxPrime > 0 { + return append(cmd, fmt.Sprintf("--cpu-max-prime=%d", s.MaxPrime)) + } + return cmd +} + +// Report reports the relevant metrics for SysbenchCPU. +func (s *SysbenchCPU) Report(b *testing.B, output string) { + b.Helper() + result, err := s.parseEvents(output) + if err != nil { + b.Fatalf("parsing CPU events from %s failed: %v", output, err) + } + b.ReportMetric(result, "cpu_events_per_second") +} + +var cpuEventsPerSecondRE = regexp.MustCompile(`events per second:\s*(\d*.?\d*)\n`) + +// parseEvents parses cpu events per second. +func (s *SysbenchCPU) parseEvents(data string) (float64, error) { + match := cpuEventsPerSecondRE.FindStringSubmatch(data) + if len(match) < 2 { + return 0.0, fmt.Errorf("could not find events per second: %s", data) + } + return strconv.ParseFloat(match[1], 64) +} + +// SysbenchMemory is for 'sysbench [FLAGS] memory run' and holds Memory specific arguments. +type SysbenchMemory struct { + Base SysbenchBase + BlockSize string // size of test memory block [1K]. + TotalSize string // size of data to transfer [100G]. + Scope string // memory access scope {global, local} [global]. + HugeTLB bool // allocate memory from HugeTLB [off]. + OperationType string // type of memory ops {read, write, none} [write]. + AccessMode string // access mode {seq, rnd} [seq]. +} + +// MakeCmd makes commands for SysbenchMemory. +func (s *SysbenchMemory) MakeCmd() []string { + cmd := []string{warmup, "sysbench"} + cmd = append(cmd, s.flags()...) + cmd = append(cmd, "memory run") + return []string{"sh", "-c", strings.Join(cmd, " ")} +} + +// flags makes flags for SysbenchMemory cmds. +func (s *SysbenchMemory) flags() []string { + cmd := s.Base.baseFlags() + if s.BlockSize != "" { + cmd = append(cmd, fmt.Sprintf("--memory-block-size=%s", s.BlockSize)) + } + if s.TotalSize != "" { + cmd = append(cmd, fmt.Sprintf("--memory-total-size=%s", s.TotalSize)) + } + if s.Scope != "" { + cmd = append(cmd, fmt.Sprintf("--memory-scope=%s", s.Scope)) + } + if s.HugeTLB { + cmd = append(cmd, "--memory-hugetlb=on") + } + if s.OperationType != "" { + cmd = append(cmd, fmt.Sprintf("--memory-oper=%s", s.OperationType)) + } + if s.AccessMode != "" { + cmd = append(cmd, fmt.Sprintf("--memory-access-mode=%s", s.AccessMode)) + } + return cmd +} + +// Report reports the relevant metrics for SysbenchMemory. +func (s *SysbenchMemory) Report(b *testing.B, output string) { + b.Helper() + result, err := s.parseOperations(output) + if err != nil { + b.Fatalf("parsing result %s failed with err: %v", output, err) + } + b.ReportMetric(result, "operations_per_second") +} + +var memoryOperationsRE = regexp.MustCompile(`Total\soperations:\s+\d*\s*\((\d*\.\d*)\sper\ssecond\)`) + +// parseOperations parses memory operations per second form sysbench memory ouput. +func (s *SysbenchMemory) parseOperations(data string) (float64, error) { + match := memoryOperationsRE.FindStringSubmatch(data) + if len(match) < 2 { + return 0.0, fmt.Errorf("couldn't find memory operations per second: %s", data) + } + return strconv.ParseFloat(match[1], 64) +} + +// SysbenchMutex is for 'sysbench [FLAGS] mutex run' and holds Mutex specific arguments. +type SysbenchMutex struct { + Base SysbenchBase + Num int // total size of mutex array [4096]. + Locks int // number of mutex locks per thread [50K]. + Loops int // number of loops to do outside mutex lock [10K]. +} + +// MakeCmd makes commands for SysbenchMutex. +func (s *SysbenchMutex) MakeCmd() []string { + cmd := []string{warmup, "sysbench"} + cmd = append(cmd, s.flags()...) + cmd = append(cmd, "mutex run") + return []string{"sh", "-c", strings.Join(cmd, " ")} +} + +// flags makes flags for SysbenchMutex commands. +func (s *SysbenchMutex) flags() []string { + var cmd []string + cmd = append(cmd, s.Base.baseFlags()...) + if s.Num > 0 { + cmd = append(cmd, fmt.Sprintf("--mutex-num=%d", s.Num)) + } + if s.Locks > 0 { + cmd = append(cmd, fmt.Sprintf("--mutex-locks=%d", s.Locks)) + } + if s.Loops > 0 { + cmd = append(cmd, fmt.Sprintf("--mutex-loops=%d", s.Loops)) + } + return cmd +} + +// Report parses and reports relevant sysbench mutex metrics. +func (s *SysbenchMutex) Report(b *testing.B, output string) { + b.Helper() + + result, err := s.parseExecutionTime(output) + if err != nil { + b.Fatalf("parsing result %s failed with err: %v", output, err) + } + b.ReportMetric(result, "average_execution_time_secs") + + result, err = s.parseDeviation(output) + if err != nil { + b.Fatalf("parsing result %s failed with err: %v", output, err) + } + b.ReportMetric(result, "stdev_execution_time_secs") + + result, err = s.parseLatency(output) + if err != nil { + b.Fatalf("parsing result %s failed with err: %v", output, err) + } + b.ReportMetric(result/1000, "average_latency_secs") +} + +var executionTimeRE = regexp.MustCompile(`execution time \(avg/stddev\):\s*(\d*.?\d*)/(\d*.?\d*)`) + +// parseExecutionTime parses threads fairness average execution time from sysbench output. +func (s *SysbenchMutex) parseExecutionTime(data string) (float64, error) { + match := executionTimeRE.FindStringSubmatch(data) + if len(match) < 2 { + return 0.0, fmt.Errorf("could not find execution time average: %s", data) + } + return strconv.ParseFloat(match[1], 64) +} + +// parseDeviation parses threads fairness stddev time from sysbench output. +func (s *SysbenchMutex) parseDeviation(data string) (float64, error) { + match := executionTimeRE.FindStringSubmatch(data) + if len(match) < 3 { + return 0.0, fmt.Errorf("could not find execution time deviation: %s", data) + } + return strconv.ParseFloat(match[2], 64) +} + +var averageLatencyRE = regexp.MustCompile(`avg:[^\n^\d]*(\d*\.?\d*)`) + +// parseLatency parses latency from sysbench output. +func (s *SysbenchMutex) parseLatency(data string) (float64, error) { + match := averageLatencyRE.FindStringSubmatch(data) + if len(match) < 2 { + return 0.0, fmt.Errorf("could not find average latency: %s", data) + } + return strconv.ParseFloat(match[1], 64) +} diff --git a/test/benchmarks/tools/sysbench_test.go b/test/benchmarks/tools/sysbench_test.go new file mode 100644 index 000000000..850d1939e --- /dev/null +++ b/test/benchmarks/tools/sysbench_test.go @@ -0,0 +1,169 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "testing" +) + +// TestSysbenchCpu tests parses on sample 'sysbench cpu' output. +func TestSysbenchCpu(t *testing.T) { + sampleData := ` +sysbench 1.0.11 (using system LuaJIT 2.1.0-beta3) + +Running the test with following options: +Number of threads: 8 +Initializing random number generator from current time + + +Prime numbers limit: 10000 + +Initializing worker threads... + +Threads started! + +CPU speed: + events per second: 9093.38 + +General statistics: + total time: 10.0007s + total number of events: 90949 + +Latency (ms): + min: 0.64 + avg: 0.88 + max: 24.65 + 95th percentile: 1.55 + sum: 79936.91 + +Threads fairness: + events (avg/stddev): 11368.6250/831.38 + execution time (avg/stddev): 9.9921/0.01 +` + sysbench := SysbenchCPU{} + want := 9093.38 + if got, err := sysbench.parseEvents(sampleData); err != nil { + t.Fatalf("parse cpu events failed: %v", err) + } else if want != got { + t.Fatalf("got: %f want: %f", got, want) + } +} + +// TestSysbenchMemory tests parsers on sample 'sysbench memory' output. +func TestSysbenchMemory(t *testing.T) { + sampleData := ` +sysbench 1.0.11 (using system LuaJIT 2.1.0-beta3) + +Running the test with following options: +Number of threads: 8 +Initializing random number generator from current time + + +Running memory speed test with the following options: + block size: 1KiB + total size: 102400MiB + operation: write + scope: global + +Initializing worker threads... + +Threads started! + +Total operations: 47999046 (9597428.64 per second) + +46874.07 MiB transferred (9372.49 MiB/sec) + + +General statistics: + total time: 5.0001s + total number of events: 47999046 + +Latency (ms): + min: 0.00 + avg: 0.00 + max: 0.21 + 95th percentile: 0.00 + sum: 33165.91 + +Threads fairness: + events (avg/stddev): 5999880.7500/111242.52 + execution time (avg/stddev): 4.1457/0.09 +` + sysbench := SysbenchMemory{} + want := 9597428.64 + if got, err := sysbench.parseOperations(sampleData); err != nil { + t.Fatalf("parse memory ops failed: %v", err) + } else if want != got { + t.Fatalf("got: %f want: %f", got, want) + } +} + +// TestSysbenchMutex tests parsers on sample 'sysbench mutex' output. +func TestSysbenchMutex(t *testing.T) { + sampleData := ` +sysbench 1.0.11 (using system LuaJIT 2.1.0-beta3) + +The 'mutex' test requires a command argument. See 'sysbench mutex help' +root@ec078132e294:/# sysbench mutex --threads=8 run +sysbench 1.0.11 (using system LuaJIT 2.1.0-beta3) + +Running the test with following options: +Number of threads: 8 +Initializing random number generator from current time + + +Initializing worker threads... + +Threads started! + + +General statistics: + total time: 0.2320s + total number of events: 8 + +Latency (ms): + min: 152.35 + avg: 192.48 + max: 231.41 + 95th percentile: 231.53 + sum: 1539.83 + +Threads fairness: + events (avg/stddev): 1.0000/0.00 + execution time (avg/stddev): 0.1925/0.04 +` + + sysbench := SysbenchMutex{} + want := .1925 + if got, err := sysbench.parseExecutionTime(sampleData); err != nil { + t.Fatalf("parse mutex time failed: %v", err) + } else if want != got { + t.Fatalf("got: %f want: %f", got, want) + } + + want = 0.04 + if got, err := sysbench.parseDeviation(sampleData); err != nil { + t.Fatalf("parse mutex deviation failed: %v", err) + } else if want != got { + t.Fatalf("got: %f want: %f", got, want) + } + + want = 192.48 + if got, err := sysbench.parseLatency(sampleData); err != nil { + t.Fatalf("parse mutex time failed: %v", err) + } else if want != got { + t.Fatalf("got: %f want: %f", got, want) + } +} diff --git a/test/root/testdata/httpd.go b/test/benchmarks/tools/tools.go index 45d5e33d4..eb61c0136 100644 --- a/test/root/testdata/httpd.go +++ b/test/benchmarks/tools/tools.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2020 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,21 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -package testdata - -// Httpd is a JSON config for an httpd container. -const Httpd = ` -{ - "metadata": { - "name": "httpd" - }, - "image":{ - "image": "httpd" - }, - "mounts": [ - ], - "linux": { - }, - "log_path": "httpd.log" -} -` +// Package tools holds tooling to couple command formatting and output parsers +// together. +package tools diff --git a/test/cmd/test_app/BUILD b/test/cmd/test_app/BUILD new file mode 100644 index 000000000..98ba5a3d9 --- /dev/null +++ b/test/cmd/test_app/BUILD @@ -0,0 +1,21 @@ +load("//tools:defs.bzl", "go_binary") + +package(licenses = ["notice"]) + +go_binary( + name = "test_app", + testonly = 1, + srcs = [ + "fds.go", + "test_app.go", + ], + pure = True, + visibility = ["//runsc/container:__pkg__"], + deps = [ + "//pkg/test/testutil", + "//pkg/unet", + "//runsc/flag", + "@com_github_google_subcommands//:go_default_library", + "@com_github_kr_pty//:go_default_library", + ], +) diff --git a/test/cmd/test_app/fds.go b/test/cmd/test_app/fds.go new file mode 100644 index 000000000..a7658eefd --- /dev/null +++ b/test/cmd/test_app/fds.go @@ -0,0 +1,185 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "io/ioutil" + "log" + "os" + "time" + + "github.com/google/subcommands" + "gvisor.dev/gvisor/pkg/test/testutil" + "gvisor.dev/gvisor/pkg/unet" + "gvisor.dev/gvisor/runsc/flag" +) + +const fileContents = "foobarbaz" + +// fdSender will open a file and send the FD over a unix domain socket. +type fdSender struct { + socketPath string +} + +// Name implements subcommands.Command.Name. +func (*fdSender) Name() string { + return "fd_sender" +} + +// Synopsis implements subcommands.Command.Synopsys. +func (*fdSender) Synopsis() string { + return "creates a file and sends the FD over the socket" +} + +// Usage implements subcommands.Command.Usage. +func (*fdSender) Usage() string { + return "fd_sender <flags>" +} + +// SetFlags implements subcommands.Command.SetFlags. +func (fds *fdSender) SetFlags(f *flag.FlagSet) { + f.StringVar(&fds.socketPath, "socket", "", "path to socket") +} + +// Execute implements subcommands.Command.Execute. +func (fds *fdSender) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { + if fds.socketPath == "" { + log.Fatalf("socket flag must be set") + } + + dir, err := ioutil.TempDir("", "") + if err != nil { + log.Fatalf("TempDir failed: %v", err) + } + + fileToSend, err := ioutil.TempFile(dir, "") + if err != nil { + log.Fatalf("TempFile failed: %v", err) + } + defer fileToSend.Close() + + if _, err := fileToSend.WriteString(fileContents); err != nil { + log.Fatalf("Write(%q) failed: %v", fileContents, err) + } + + // Receiver may not be started yet, so try connecting in a poll loop. + var s *unet.Socket + if err := testutil.Poll(func() error { + var err error + s, err = unet.Connect(fds.socketPath, true /* SEQPACKET, so we can send empty message with FD */) + return err + }, 10*time.Second); err != nil { + log.Fatalf("Error connecting to socket %q: %v", fds.socketPath, err) + } + defer s.Close() + + w := s.Writer(true) + w.ControlMessage.PackFDs(int(fileToSend.Fd())) + if _, err := w.WriteVec([][]byte{[]byte{'a'}}); err != nil { + log.Fatalf("Error sending FD %q over socket %q: %v", fileToSend.Fd(), fds.socketPath, err) + } + + log.Print("FD SENDER exiting successfully") + return subcommands.ExitSuccess +} + +// fdReceiver receives an FD from a unix domain socket and does things to it. +type fdReceiver struct { + socketPath string +} + +// Name implements subcommands.Command.Name. +func (*fdReceiver) Name() string { + return "fd_receiver" +} + +// Synopsis implements subcommands.Command.Synopsys. +func (*fdReceiver) Synopsis() string { + return "reads an FD from a unix socket, and then does things to it" +} + +// Usage implements subcommands.Command.Usage. +func (*fdReceiver) Usage() string { + return "fd_receiver <flags>" +} + +// SetFlags implements subcommands.Command.SetFlags. +func (fdr *fdReceiver) SetFlags(f *flag.FlagSet) { + f.StringVar(&fdr.socketPath, "socket", "", "path to socket") +} + +// Execute implements subcommands.Command.Execute. +func (fdr *fdReceiver) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { + if fdr.socketPath == "" { + log.Fatalf("Flags cannot be empty, given: socket: %q", fdr.socketPath) + } + + ss, err := unet.BindAndListen(fdr.socketPath, true /* packet */) + if err != nil { + log.Fatalf("BindAndListen(%q) failed: %v", fdr.socketPath, err) + } + defer ss.Close() + + var s *unet.Socket + c := make(chan error, 1) + go func() { + var err error + s, err = ss.Accept() + c <- err + }() + + select { + case err := <-c: + if err != nil { + log.Fatalf("Accept() failed: %v", err) + } + case <-time.After(10 * time.Second): + log.Fatalf("Timeout waiting for accept") + } + + r := s.Reader(true) + r.EnableFDs(1) + b := [][]byte{{'a'}} + if n, err := r.ReadVec(b); n != 1 || err != nil { + log.Fatalf("ReadVec got n=%d err %v (wanted 0, nil)", n, err) + } + + fds, err := r.ExtractFDs() + if err != nil { + log.Fatalf("ExtractFD() got err %v", err) + } + if len(fds) != 1 { + log.Fatalf("ExtractFD() got %d FDs, wanted 1", len(fds)) + } + fd := fds[0] + + file := os.NewFile(uintptr(fd), "received file") + defer file.Close() + if _, err := file.Seek(0, os.SEEK_SET); err != nil { + log.Fatalf("Seek(0, 0) failed: %v", err) + } + + got, err := ioutil.ReadAll(file) + if err != nil { + log.Fatalf("ReadAll failed: %v", err) + } + if string(got) != fileContents { + log.Fatalf("ReadAll got %q want %q", string(got), fileContents) + } + + log.Print("FD RECEIVER exiting successfully") + return subcommands.ExitSuccess +} diff --git a/test/cmd/test_app/test_app.go b/test/cmd/test_app/test_app.go new file mode 100644 index 000000000..3ba4f38f8 --- /dev/null +++ b/test/cmd/test_app/test_app.go @@ -0,0 +1,394 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Binary test_app is like a swiss knife for tests that need to run anything +// inside the sandbox. New functionality can be added with new commands. +package main + +import ( + "context" + "fmt" + "io" + "io/ioutil" + "log" + "net" + "os" + "os/exec" + "regexp" + "strconv" + sys "syscall" + "time" + + "github.com/google/subcommands" + "github.com/kr/pty" + "gvisor.dev/gvisor/pkg/test/testutil" + "gvisor.dev/gvisor/runsc/flag" +) + +func main() { + subcommands.Register(subcommands.HelpCommand(), "") + subcommands.Register(subcommands.FlagsCommand(), "") + subcommands.Register(new(capability), "") + subcommands.Register(new(fdReceiver), "") + subcommands.Register(new(fdSender), "") + subcommands.Register(new(forkBomb), "") + subcommands.Register(new(ptyRunner), "") + subcommands.Register(new(reaper), "") + subcommands.Register(new(syscall), "") + subcommands.Register(new(taskTree), "") + subcommands.Register(new(uds), "") + + flag.Parse() + + exitCode := subcommands.Execute(context.Background()) + os.Exit(int(exitCode)) +} + +type uds struct { + fileName string + socketPath string +} + +// Name implements subcommands.Command.Name. +func (*uds) Name() string { + return "uds" +} + +// Synopsis implements subcommands.Command.Synopsys. +func (*uds) Synopsis() string { + return "creates unix domain socket client and server. Client sends a contant flow of sequential numbers. Server prints them to --file" +} + +// Usage implements subcommands.Command.Usage. +func (*uds) Usage() string { + return "uds <flags>" +} + +// SetFlags implements subcommands.Command.SetFlags. +func (c *uds) SetFlags(f *flag.FlagSet) { + f.StringVar(&c.fileName, "file", "", "name of output file") + f.StringVar(&c.socketPath, "socket", "", "path to socket") +} + +// Execute implements subcommands.Command.Execute. +func (c *uds) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { + if c.fileName == "" || c.socketPath == "" { + log.Fatalf("Flags cannot be empty, given: fileName: %q, socketPath: %q", c.fileName, c.socketPath) + return subcommands.ExitFailure + } + outputFile, err := os.OpenFile(c.fileName, os.O_WRONLY|os.O_CREATE, 0666) + if err != nil { + log.Fatal("error opening output file:", err) + } + + defer os.Remove(c.socketPath) + + listener, err := net.Listen("unix", c.socketPath) + if err != nil { + log.Fatalf("error listening on socket %q: %v", c.socketPath, err) + } + + go server(listener, outputFile) + for i := 0; ; i++ { + conn, err := net.Dial("unix", c.socketPath) + if err != nil { + log.Fatal("error dialing:", err) + } + if _, err := conn.Write([]byte(strconv.Itoa(i))); err != nil { + log.Fatal("error writing:", err) + } + conn.Close() + time.Sleep(100 * time.Millisecond) + } +} + +func server(listener net.Listener, out *os.File) { + buf := make([]byte, 16) + + for { + c, err := listener.Accept() + if err != nil { + log.Fatal("error accepting connection:", err) + } + nr, err := c.Read(buf) + if err != nil { + log.Fatal("error reading from buf:", err) + } + data := buf[0:nr] + fmt.Fprint(out, string(data)+"\n") + } +} + +type taskTree struct { + depth int + width int + pause bool +} + +// Name implements subcommands.Command. +func (*taskTree) Name() string { + return "task-tree" +} + +// Synopsis implements subcommands.Command. +func (*taskTree) Synopsis() string { + return "creates a tree of tasks" +} + +// Usage implements subcommands.Command. +func (*taskTree) Usage() string { + return "task-tree <flags>" +} + +// SetFlags implements subcommands.Command. +func (c *taskTree) SetFlags(f *flag.FlagSet) { + f.IntVar(&c.depth, "depth", 1, "number of levels to create") + f.IntVar(&c.width, "width", 1, "number of tasks at each level") + f.BoolVar(&c.pause, "pause", false, "whether the tasks should pause perpetually") +} + +// Execute implements subcommands.Command. +func (c *taskTree) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { + stop := testutil.StartReaper() + defer stop() + + if c.depth == 0 { + log.Printf("Child sleeping, PID: %d\n", os.Getpid()) + select {} + } + log.Printf("Parent %d sleeping, PID: %d\n", c.depth, os.Getpid()) + + var cmds []*exec.Cmd + for i := 0; i < c.width; i++ { + cmd := exec.Command( + "/proc/self/exe", c.Name(), + "--depth", strconv.Itoa(c.depth-1), + "--width", strconv.Itoa(c.width), + "--pause", strconv.FormatBool(c.pause)) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + if err := cmd.Start(); err != nil { + log.Fatal("failed to call self:", err) + } + cmds = append(cmds, cmd) + } + + for _, c := range cmds { + c.Wait() + } + + if c.pause { + select {} + } + + return subcommands.ExitSuccess +} + +type forkBomb struct { + delay time.Duration +} + +// Name implements subcommands.Command. +func (*forkBomb) Name() string { + return "fork-bomb" +} + +// Synopsis implements subcommands.Command. +func (*forkBomb) Synopsis() string { + return "creates child process until the end of times" +} + +// Usage implements subcommands.Command. +func (*forkBomb) Usage() string { + return "fork-bomb <flags>" +} + +// SetFlags implements subcommands.Command. +func (c *forkBomb) SetFlags(f *flag.FlagSet) { + f.DurationVar(&c.delay, "delay", 100*time.Millisecond, "amount of time to delay creation of child") +} + +// Execute implements subcommands.Command. +func (c *forkBomb) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { + time.Sleep(c.delay) + + cmd := exec.Command("/proc/self/exe", c.Name()) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + log.Fatal("failed to call self:", err) + } + return subcommands.ExitSuccess +} + +type reaper struct{} + +// Name implements subcommands.Command. +func (*reaper) Name() string { + return "reaper" +} + +// Synopsis implements subcommands.Command. +func (*reaper) Synopsis() string { + return "reaps all children in a loop" +} + +// Usage implements subcommands.Command. +func (*reaper) Usage() string { + return "reaper <flags>" +} + +// SetFlags implements subcommands.Command. +func (*reaper) SetFlags(*flag.FlagSet) {} + +// Execute implements subcommands.Command. +func (c *reaper) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { + stop := testutil.StartReaper() + defer stop() + select {} +} + +type syscall struct { + sysno uint64 +} + +// Name implements subcommands.Command. +func (*syscall) Name() string { + return "syscall" +} + +// Synopsis implements subcommands.Command. +func (*syscall) Synopsis() string { + return "syscall makes a syscall" +} + +// Usage implements subcommands.Command. +func (*syscall) Usage() string { + return "syscall <flags>" +} + +// SetFlags implements subcommands.Command. +func (s *syscall) SetFlags(f *flag.FlagSet) { + f.Uint64Var(&s.sysno, "syscall", 0, "syscall to call") +} + +// Execute implements subcommands.Command. +func (s *syscall) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { + if _, _, errno := sys.Syscall(uintptr(s.sysno), 0, 0, 0); errno != 0 { + fmt.Printf("syscall(%d, 0, 0...) failed: %v\n", s.sysno, errno) + } else { + fmt.Printf("syscall(%d, 0, 0...) success\n", s.sysno) + } + return subcommands.ExitSuccess +} + +type capability struct { + enabled uint64 + disabled uint64 +} + +// Name implements subcommands.Command. +func (*capability) Name() string { + return "capability" +} + +// Synopsis implements subcommands.Command. +func (*capability) Synopsis() string { + return "checks if effective capabilities are set/unset" +} + +// Usage implements subcommands.Command. +func (*capability) Usage() string { + return "capability [--enabled=number] [--disabled=number]" +} + +// SetFlags implements subcommands.Command. +func (c *capability) SetFlags(f *flag.FlagSet) { + f.Uint64Var(&c.enabled, "enabled", 0, "") + f.Uint64Var(&c.disabled, "disabled", 0, "") +} + +// Execute implements subcommands.Command. +func (c *capability) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { + if c.enabled == 0 && c.disabled == 0 { + fmt.Println("One of the flags must be set") + return subcommands.ExitUsageError + } + + status, err := ioutil.ReadFile("/proc/self/status") + if err != nil { + fmt.Printf("Error reading %q: %v\n", "proc/self/status", err) + return subcommands.ExitFailure + } + re := regexp.MustCompile("CapEff:\t([0-9a-f]+)\n") + matches := re.FindStringSubmatch(string(status)) + if matches == nil || len(matches) != 2 { + fmt.Printf("Effective capabilities not found in\n%s\n", status) + return subcommands.ExitFailure + } + caps, err := strconv.ParseUint(matches[1], 16, 64) + if err != nil { + fmt.Printf("failed to convert capabilities %q: %v\n", matches[1], err) + return subcommands.ExitFailure + } + + if c.enabled != 0 && (caps&c.enabled) != c.enabled { + fmt.Printf("Missing capabilities, want: %#x: got: %#x\n", c.enabled, caps) + return subcommands.ExitFailure + } + if c.disabled != 0 && (caps&c.disabled) != 0 { + fmt.Printf("Extra capabilities found, dont_want: %#x: got: %#x\n", c.disabled, caps) + return subcommands.ExitFailure + } + + return subcommands.ExitSuccess +} + +type ptyRunner struct{} + +// Name implements subcommands.Command. +func (*ptyRunner) Name() string { + return "pty-runner" +} + +// Synopsis implements subcommands.Command. +func (*ptyRunner) Synopsis() string { + return "runs the given command with an open pty terminal" +} + +// Usage implements subcommands.Command. +func (*ptyRunner) Usage() string { + return "pty-runner [command]" +} + +// SetFlags implements subcommands.Command.SetFlags. +func (*ptyRunner) SetFlags(f *flag.FlagSet) {} + +// Execute implements subcommands.Command. +func (*ptyRunner) Execute(_ context.Context, fs *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus { + c := exec.Command(fs.Args()[0], fs.Args()[1:]...) + f, err := pty.Start(c) + if err != nil { + fmt.Printf("pty.Start failed: %v", err) + return subcommands.ExitFailure + } + defer f.Close() + + // Copy stdout from the command to keep this process alive until the + // subprocess exits. + io.Copy(os.Stdout, f) + + return subcommands.ExitSuccess +} diff --git a/test/e2e/BUILD b/test/e2e/BUILD index 4fe03a220..29a84f184 100644 --- a/test/e2e/BUILD +++ b/test/e2e/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -10,7 +10,7 @@ go_test( "integration_test.go", "regression_test.go", ], - embed = [":integration"], + library = ":integration", tags = [ # Requires docker and runsc to be configured before the test runs. "manual", @@ -20,14 +20,14 @@ go_test( deps = [ "//pkg/abi/linux", "//pkg/bits", - "//runsc/dockerutil", + "//pkg/test/dockerutil", + "//pkg/test/testutil", "//runsc/specutils", - "//runsc/testutil", + "@com_github_docker_docker//api/types/mount:go_default_library", ], ) go_library( name = "integration", srcs = ["integration.go"], - importpath = "gvisor.dev/gvisor/test/integration", ) diff --git a/test/e2e/exec_test.go b/test/e2e/exec_test.go index 4074d2285..b47df447c 100644 --- a/test/e2e/exec_test.go +++ b/test/e2e/exec_test.go @@ -22,33 +22,34 @@ package integration import ( + "context" "fmt" "strconv" "strings" - "syscall" "testing" "time" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bits" - "gvisor.dev/gvisor/runsc/dockerutil" + "gvisor.dev/gvisor/pkg/test/dockerutil" "gvisor.dev/gvisor/runsc/specutils" ) // Test that exec uses the exact same capability set as the container. func TestExecCapabilities(t *testing.T) { - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("exec-capabilities-test") + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Start the container. - if err := d.Run("alpine", "sh", "-c", "cat /proc/self/status; sleep 100"); err != nil { + if err := d.Spawn(ctx, dockerutil.RunOpts{ + Image: "basic/alpine", + }, "sh", "-c", "cat /proc/self/status; sleep 100"); err != nil { t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() - matches, err := d.WaitForOutputSubmatch("CapEff:\t([0-9a-f]+)\n", 5*time.Second) + // Check that capability. + matches, err := d.WaitForOutputSubmatch(ctx, "CapEff:\t([0-9a-f]+)\n", 5*time.Second) if err != nil { t.Fatalf("WaitForOutputSubmatch() timeout: %v", err) } @@ -59,7 +60,7 @@ func TestExecCapabilities(t *testing.T) { t.Log("Root capabilities:", want) // Now check that exec'd process capabilities match the root. - got, err := d.Exec("grep", "CapEff:", "/proc/self/status") + got, err := d.Exec(ctx, dockerutil.ExecOpts{}, "grep", "CapEff:", "/proc/self/status") if err != nil { t.Fatalf("docker exec failed: %v", err) } @@ -72,19 +73,20 @@ func TestExecCapabilities(t *testing.T) { // Test that 'exec --privileged' adds all capabilities, except for CAP_NET_RAW // which is removed from the container when --net-raw=false. func TestExecPrivileged(t *testing.T) { - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("exec-privileged-test") + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Start the container with all capabilities dropped. - if err := d.Run("--cap-drop=all", "alpine", "sh", "-c", "cat /proc/self/status; sleep 100"); err != nil { + if err := d.Spawn(ctx, dockerutil.RunOpts{ + Image: "basic/alpine", + CapDrop: []string{"all"}, + }, "sh", "-c", "cat /proc/self/status; sleep 100"); err != nil { t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() // Check that all capabilities where dropped from container. - matches, err := d.WaitForOutputSubmatch("CapEff:\t([0-9a-f]+)\n", 5*time.Second) + matches, err := d.WaitForOutputSubmatch(ctx, "CapEff:\t([0-9a-f]+)\n", 5*time.Second) if err != nil { t.Fatalf("WaitForOutputSubmatch() timeout: %v", err) } @@ -100,9 +102,11 @@ func TestExecPrivileged(t *testing.T) { t.Fatalf("Container should have no capabilities: %x", containerCaps) } - // Check that 'exec --privileged' adds all capabilities, except - // for CAP_NET_RAW. - got, err := d.ExecWithFlags([]string{"--privileged"}, "grep", "CapEff:", "/proc/self/status") + // Check that 'exec --privileged' adds all capabilities, except for + // CAP_NET_RAW. + got, err := d.Exec(ctx, dockerutil.ExecOpts{ + Privileged: true, + }, "grep", "CapEff:", "/proc/self/status") if err != nil { t.Fatalf("docker exec failed: %v", err) } @@ -114,97 +118,83 @@ func TestExecPrivileged(t *testing.T) { } func TestExecJobControl(t *testing.T) { - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("exec-job-control-test") + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Start the container. - if err := d.Run("alpine", "sleep", "1000"); err != nil { + if err := d.Spawn(ctx, dockerutil.RunOpts{ + Image: "basic/alpine", + }, "sleep", "1000"); err != nil { t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() - // Exec 'sh' with an attached pty. - cmd, ptmx, err := d.ExecWithTerminal("sh") + p, err := d.ExecProcess(ctx, dockerutil.ExecOpts{UseTTY: true}, "/bin/sh") if err != nil { t.Fatalf("docker exec failed: %v", err) } - defer ptmx.Close() - // Call "sleep 100 | cat" in the shell. We pipe to cat so that there - // will be two processes in the foreground process group. - if _, err := ptmx.Write([]byte("sleep 100 | cat\n")); err != nil { - t.Fatalf("error writing to pty: %v", err) + if _, err = p.Write(time.Second, []byte("sleep 100 | cat\n")); err != nil { + t.Fatalf("error exit: %v", err) } + time.Sleep(time.Second) - // Give shell a few seconds to start executing the sleep. - time.Sleep(2 * time.Second) - - // Send a ^C to the pty, which should kill sleep and cat, but not the - // shell. \x03 is ASCII "end of text", which is the same as ^C. - if _, err := ptmx.Write([]byte{'\x03'}); err != nil { - t.Fatalf("error writing to pty: %v", err) + if _, err = p.Write(time.Second, []byte{0x03}); err != nil { + t.Fatalf("error exit: %v", err) } - // The shell should still be alive at this point. Sleep should have - // exited with code 2+128=130. We'll exit with 10 plus that number, so - // that we can be sure that the shell did not get signalled. - if _, err := ptmx.Write([]byte("exit $(expr $? + 10)\n")); err != nil { - t.Fatalf("error writing to pty: %v", err) + if _, err = p.Write(time.Second, []byte("exit $(expr $? + 10)\n")); err != nil { + t.Fatalf("error exit: %v", err) } - // Exec process should exit with code 10+130=140. - ps, err := cmd.Process.Wait() + want := 140 + got, err := p.WaitExitStatus(ctx) if err != nil { - t.Fatalf("error waiting for exec process: %v", err) - } - ws := ps.Sys().(syscall.WaitStatus) - if !ws.Exited() { - t.Errorf("ws.Exited got false, want true") - } - if got, want := ws.ExitStatus(), 140; got != want { - t.Errorf("ws.ExitedStatus got %d, want %d", got, want) + t.Fatalf("wait for exit failed with: %v", err) + } else if got != want { + t.Fatalf("wait for exit returned: %d want: %d", got, want) } } // Test that failure to exec returns proper error message. func TestExecError(t *testing.T) { - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("exec-error-test") + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Start the container. - if err := d.Run("alpine", "sleep", "1000"); err != nil { + if err := d.Spawn(ctx, dockerutil.RunOpts{ + Image: "basic/alpine", + }, "sleep", "1000"); err != nil { t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() - _, err := d.Exec("no_can_find") + // Attempt to exec a binary that doesn't exist. + out, err := d.Exec(ctx, dockerutil.ExecOpts{}, "no_can_find") if err == nil { t.Fatalf("docker exec didn't fail") } - if want := `error finding executable "no_can_find" in PATH`; !strings.Contains(err.Error(), want) { - t.Fatalf("docker exec wrong error, got: %s, want: .*%s.*", err.Error(), want) + if want := `error finding executable "no_can_find" in PATH`; !strings.Contains(out, want) { + t.Fatalf("docker exec wrong error, got: %s, want: .*%s.*", out, want) } } // Test that exec inherits environment from run. func TestExecEnv(t *testing.T) { - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("exec-env-test") + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Start the container with env FOO=BAR. - if err := d.Run("-e", "FOO=BAR", "alpine", "sleep", "1000"); err != nil { + if err := d.Spawn(ctx, dockerutil.RunOpts{ + Image: "basic/alpine", + Env: []string{"FOO=BAR"}, + }, "sleep", "1000"); err != nil { t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() // Exec "echo $FOO". - got, err := d.Exec("/bin/sh", "-c", "echo $FOO") + got, err := d.Exec(ctx, dockerutil.ExecOpts{}, "/bin/sh", "-c", "echo $FOO") if err != nil { t.Fatalf("docker exec failed: %v", err) } @@ -216,17 +206,20 @@ func TestExecEnv(t *testing.T) { // TestRunEnvHasHome tests that run always has HOME environment set. func TestRunEnvHasHome(t *testing.T) { // Base alpine image does not have any environment variables set. - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("run-env-test") + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Exec "echo $HOME". The 'bin' user's home dir is '/bin'. - got, err := d.RunFg("--user", "bin", "alpine", "/bin/sh", "-c", "echo $HOME") + got, err := d.Run(ctx, dockerutil.RunOpts{ + Image: "basic/alpine", + User: "bin", + }, "/bin/sh", "-c", "echo $HOME") if err != nil { t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() + + // Check that the directory matches. if got, want := strings.TrimSpace(got), "/bin"; got != want { t.Errorf("bad output from 'docker run'. Got %q; Want %q.", got, want) } @@ -235,28 +228,18 @@ func TestRunEnvHasHome(t *testing.T) { // Test that exec always has HOME environment set, even when not set in run. func TestExecEnvHasHome(t *testing.T) { // Base alpine image does not have any environment variables set. - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("exec-env-home-test") + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) - // We will check that HOME is set for root user, and also for a new - // non-root user we will create. - newUID := 1234 - newHome := "/foo/bar" - - // Create a new user with a home directory, and then sleep. - script := fmt.Sprintf(` - mkdir -p -m 777 %s && \ - adduser foo -D -u %d -h %s && \ - sleep 1000`, newHome, newUID, newHome) - if err := d.Run("alpine", "/bin/sh", "-c", script); err != nil { + if err := d.Spawn(ctx, dockerutil.RunOpts{ + Image: "basic/alpine", + }, "sleep", "1000"); err != nil { t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() // Exec "echo $HOME", and expect to see "/root". - got, err := d.Exec("/bin/sh", "-c", "echo $HOME") + got, err := d.Exec(ctx, dockerutil.ExecOpts{}, "/bin/sh", "-c", "echo $HOME") if err != nil { t.Fatalf("docker exec failed: %v", err) } @@ -264,8 +247,18 @@ func TestExecEnvHasHome(t *testing.T) { t.Errorf("wanted exec output to contain %q, got %q", want, got) } - // Execute the same as uid 123 and expect newHome. - got, err = d.ExecAsUser(strconv.Itoa(newUID), "/bin/sh", "-c", "echo $HOME") + // Create a new user with a home directory. + newUID := 1234 + newHome := "/foo/bar" + cmd := fmt.Sprintf("mkdir -p -m 777 %q && adduser foo -D -u %d -h %q", newHome, newUID, newHome) + if _, err := d.Exec(ctx, dockerutil.ExecOpts{}, "/bin/sh", "-c", cmd); err != nil { + t.Fatalf("docker exec failed: %v", err) + } + + // Execute the same as the new user and expect newHome. + got, err = d.Exec(ctx, dockerutil.ExecOpts{ + User: strconv.Itoa(newUID), + }, "/bin/sh", "-c", "echo $HOME") if err != nil { t.Fatalf("docker exec failed: %v", err) } diff --git a/test/e2e/integration_test.go b/test/e2e/integration_test.go index 7cc0de129..809244bab 100644 --- a/test/e2e/integration_test.go +++ b/test/e2e/integration_test.go @@ -22,21 +22,27 @@ package integration import ( + "context" "flag" "fmt" + "io/ioutil" "net" "net/http" "os" + "path/filepath" "strconv" "strings" - "syscall" "testing" "time" - "gvisor.dev/gvisor/runsc/dockerutil" - "gvisor.dev/gvisor/runsc/testutil" + "github.com/docker/docker/api/types/mount" + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/pkg/test/testutil" ) +// defaultWait is the default wait time used for tests. +const defaultWait = time.Minute + // httpRequestSucceeds sends a request to a given url and checks that the status is OK. func httpRequestSucceeds(client http.Client, server string, port int) error { url := fmt.Sprintf("http://%s:%d", server, port) @@ -53,78 +59,82 @@ func httpRequestSucceeds(client http.Client, server string, port int) error { // TestLifeCycle tests a basic Create/Start/Stop docker container life cycle. func TestLifeCycle(t *testing.T) { - if err := dockerutil.Pull("nginx"); err != nil { - t.Fatal("docker pull failed:", err) - } - d := dockerutil.MakeDocker("lifecycle-test") - if err := d.Create("-p", "80", "nginx"); err != nil { - t.Fatal("docker create failed:", err) + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) + + // Start the container. + if err := d.Create(ctx, dockerutil.RunOpts{ + Image: "basic/nginx", + Ports: []int{80}, + }); err != nil { + t.Fatalf("docker create failed: %v", err) } - if err := d.Start(); err != nil { - d.CleanUp() - t.Fatal("docker start failed:", err) + if err := d.Start(ctx); err != nil { + t.Fatalf("docker start failed: %v", err) } - // Test that container is working - port, err := d.FindPort(80) + // Test that container is working. + port, err := d.FindPort(ctx, 80) if err != nil { - t.Fatal("docker.FindPort(80) failed: ", err) + t.Fatalf("docker.FindPort(80) failed: %v", err) } - if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil { - t.Fatal("WaitForHTTP() timeout:", err) + if err := testutil.WaitForHTTP(port, defaultWait); err != nil { + t.Fatalf("WaitForHTTP() timeout: %v", err) } - client := http.Client{Timeout: time.Duration(2 * time.Second)} + client := http.Client{Timeout: defaultWait} if err := httpRequestSucceeds(client, "localhost", port); err != nil { - t.Error("http request failed:", err) + t.Errorf("http request failed: %v", err) } - if err := d.Stop(); err != nil { - d.CleanUp() - t.Fatal("docker stop failed:", err) + if err := d.Stop(ctx); err != nil { + t.Fatalf("docker stop failed: %v", err) } - if err := d.Remove(); err != nil { - t.Fatal("docker rm failed:", err) + if err := d.Remove(ctx); err != nil { + t.Fatalf("docker rm failed: %v", err) } } func TestPauseResume(t *testing.T) { - const img = "gcr.io/gvisor-presubmit/python-hello" if !testutil.IsCheckpointSupported() { - t.Log("Checkpoint is not supported, skipping test.") - return + t.Skip("Checkpoint is not supported.") } - if err := dockerutil.Pull(img); err != nil { - t.Fatal("docker pull failed:", err) - } - d := dockerutil.MakeDocker("pause-resume-test") - if err := d.Run("-p", "8080", img); err != nil { + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) + + // Start the container. + if err := d.Spawn(ctx, dockerutil.RunOpts{ + Image: "basic/python", + Ports: []int{8080}, // See Dockerfile. + }); err != nil { t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() // Find where port 8080 is mapped to. - port, err := d.FindPort(8080) + port, err := d.FindPort(ctx, 8080) if err != nil { - t.Fatal("docker.FindPort(8080) failed:", err) + t.Fatalf("docker.FindPort(8080) failed: %v", err) } // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil { - t.Fatal("WaitForHTTP() timeout:", err) + if err := testutil.WaitForHTTP(port, defaultWait); err != nil { + t.Fatalf("WaitForHTTP() timeout: %v", err) } // Check that container is working. - client := http.Client{Timeout: time.Duration(2 * time.Second)} + client := http.Client{Timeout: defaultWait} if err := httpRequestSucceeds(client, "localhost", port); err != nil { t.Error("http request failed:", err) } - if err := d.Pause(); err != nil { - t.Fatal("docker pause failed:", err) + if err := d.Pause(ctx); err != nil { + t.Fatalf("docker pause failed: %v", err) } // Check if container is paused. + client = http.Client{Timeout: 10 * time.Millisecond} // Don't wait a minute. switch _, err := client.Get(fmt.Sprintf("http://localhost:%d", port)); v := err.(type) { case nil: t.Errorf("http req expected to fail but it succeeded") @@ -136,62 +146,72 @@ func TestPauseResume(t *testing.T) { t.Errorf("http req got unexpected error %v", v) } - if err := d.Unpause(); err != nil { - t.Fatal("docker unpause failed:", err) + if err := d.Unpause(ctx); err != nil { + t.Fatalf("docker unpause failed: %v", err) } // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil { - t.Fatal("WaitForHTTP() timeout:", err) + if err := testutil.WaitForHTTP(port, defaultWait); err != nil { + t.Fatalf("WaitForHTTP() timeout: %v", err) } // Check if container is working again. + client = http.Client{Timeout: defaultWait} if err := httpRequestSucceeds(client, "localhost", port); err != nil { t.Error("http request failed:", err) } } func TestCheckpointRestore(t *testing.T) { - const img = "gcr.io/gvisor-presubmit/python-hello" if !testutil.IsCheckpointSupported() { - t.Log("Pause/resume is not supported, skipping test.") - return + t.Skip("Pause/resume is not supported.") } - if err := dockerutil.Pull(img); err != nil { - t.Fatal("docker pull failed:", err) + // TODO(gvisor.dev/issue/3373): Remove after implementing. + if usingVFS2, err := dockerutil.UsingVFS2(); usingVFS2 { + t.Skip("CheckpointRestore not implemented in VFS2.") + } else if err != nil { + t.Fatalf("failed to read config for runtime %s: %v", dockerutil.Runtime(), err) } - d := dockerutil.MakeDocker("save-restore-test") - if err := d.Run("-p", "8080", img); err != nil { + + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) + + // Start the container. + if err := d.Spawn(ctx, dockerutil.RunOpts{ + Image: "basic/python", + Ports: []int{8080}, // See Dockerfile. + }); err != nil { t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() - if err := d.Checkpoint("test"); err != nil { - t.Fatal("docker checkpoint failed:", err) + // Create a snapshot. + if err := d.Checkpoint(ctx, "test"); err != nil { + t.Fatalf("docker checkpoint failed: %v", err) } - - if _, err := d.Wait(30 * time.Second); err != nil { - t.Fatal(err) + if err := d.WaitTimeout(ctx, defaultWait); err != nil { + t.Fatalf("wait failed: %v", err) } - if err := d.Restore("test"); err != nil { - t.Fatal("docker restore failed:", err) + // TODO(b/143498576): Remove Poll after github.com/moby/moby/issues/38963 is fixed. + if err := testutil.Poll(func() error { return d.Restore(ctx, "test") }, defaultWait); err != nil { + t.Fatalf("docker restore failed: %v", err) } // Find where port 8080 is mapped to. - port, err := d.FindPort(8080) + port, err := d.FindPort(ctx, 8080) if err != nil { - t.Fatal("docker.FindPort(8080) failed:", err) + t.Fatalf("docker.FindPort(8080) failed: %v", err) } // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil { - t.Fatal("WaitForHTTP() timeout:", err) + if err := testutil.WaitForHTTP(port, defaultWait); err != nil { + t.Fatalf("WaitForHTTP() timeout: %v", err) } // Check if container is working again. - client := http.Client{Timeout: time.Duration(2 * time.Second)} + client := http.Client{Timeout: defaultWait} if err := httpRequestSucceeds(client, "localhost", port); err != nil { t.Error("http request failed:", err) } @@ -199,48 +219,55 @@ func TestCheckpointRestore(t *testing.T) { // Create client and server that talk to each other using the local IP. func TestConnectToSelf(t *testing.T) { - d := dockerutil.MakeDocker("connect-to-self-test") + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Creates server that replies "server" and exists. Sleeps at the end because // 'docker exec' gets killed if the init process exists before it can finish. - if err := d.Run("ubuntu:trusty", "/bin/sh", "-c", "echo server | nc -l -p 8080 && sleep 1"); err != nil { - t.Fatal("docker run failed:", err) + if err := d.Spawn(ctx, dockerutil.RunOpts{ + Image: "basic/ubuntu", + }, "/bin/sh", "-c", "echo server | nc -l -p 8080 && sleep 1"); err != nil { + t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() // Finds IP address for host. - ip, err := d.Exec("/bin/sh", "-c", "cat /etc/hosts | grep ${HOSTNAME} | awk '{print $1}'") + ip, err := d.Exec(ctx, dockerutil.ExecOpts{}, "/bin/sh", "-c", "cat /etc/hosts | grep ${HOSTNAME} | awk '{print $1}'") if err != nil { - t.Fatal("docker exec failed:", err) + t.Fatalf("docker exec failed: %v", err) } ip = strings.TrimRight(ip, "\n") // Runs client that sends "client" to the server and exits. - reply, err := d.Exec("/bin/sh", "-c", fmt.Sprintf("echo client | nc %s 8080", ip)) + reply, err := d.Exec(ctx, dockerutil.ExecOpts{}, "/bin/sh", "-c", fmt.Sprintf("echo client | nc %s 8080", ip)) if err != nil { - t.Fatal("docker exec failed:", err) + t.Fatalf("docker exec failed: %v", err) } // Ensure both client and server got the message from each other. if want := "server\n"; reply != want { t.Errorf("Error on server, want: %q, got: %q", want, reply) } - if _, err := d.WaitForOutput("^client\n$", 1*time.Second); err != nil { - t.Fatal("docker.WaitForOutput(client) timeout:", err) + if _, err := d.WaitForOutput(ctx, "^client\n$", defaultWait); err != nil { + t.Fatalf("docker.WaitForOutput(client) timeout: %v", err) } } func TestMemLimit(t *testing.T) { - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatal("docker pull failed:", err) - } - d := dockerutil.MakeDocker("cgroup-test") - cmd := "cat /proc/meminfo | grep MemTotal: | awk '{print $2}'" - out, err := d.RunFg("--memory=500MB", "alpine", "sh", "-c", cmd) + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) + + // N.B. Because the size of the memory file may grow in large chunks, + // there is a minimum threshold of 1GB for the MemTotal figure. + allocMemory := 1024 * 1024 // In kb. + out, err := d.Run(ctx, dockerutil.RunOpts{ + Image: "basic/alpine", + Memory: allocMemory * 1024, // In bytes. + }, "sh", "-c", "cat /proc/meminfo | grep MemTotal: | awk '{print $2}'") if err != nil { - t.Fatal("docker run failed:", err) + t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() // Remove warning message that swap isn't present. if strings.HasPrefix(out, "WARNING") { @@ -251,27 +278,31 @@ func TestMemLimit(t *testing.T) { out = lines[1] } + // Ensure the memory matches what we want. got, err := strconv.ParseUint(strings.TrimSpace(out), 10, 64) if err != nil { t.Fatalf("failed to parse %q: %v", out, err) } - if want := uint64(500 * 1024); got != want { + if want := uint64(allocMemory); got != want { t.Errorf("MemTotal got: %d, want: %d", got, want) } } func TestNumCPU(t *testing.T) { - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatal("docker pull failed:", err) - } - d := dockerutil.MakeDocker("cgroup-test") - cmd := "cat /proc/cpuinfo | grep 'processor.*:' | wc -l" - out, err := d.RunFg("--cpuset-cpus=0", "alpine", "sh", "-c", cmd) + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) + + // Read how many cores are in the container. + out, err := d.Run(ctx, dockerutil.RunOpts{ + Image: "basic/alpine", + CpusetCpus: "0", + }, "sh", "-c", "cat /proc/cpuinfo | grep 'processor.*:' | wc -l") if err != nil { - t.Fatal("docker run failed:", err) + t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() + // Ensure it matches what we want. got, err := strconv.Atoi(strings.TrimSpace(out)) if err != nil { t.Fatalf("failed to parse %q: %v", out, err) @@ -283,62 +314,182 @@ func TestNumCPU(t *testing.T) { // TestJobControl tests that job control characters are handled properly. func TestJobControl(t *testing.T) { - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("job-control-test") + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Start the container with an attached PTY. - _, ptmx, err := d.RunWithPty("alpine", "sh") + p, err := d.SpawnProcess(ctx, dockerutil.RunOpts{ + Image: "basic/alpine", + }, "sh", "-c", "sleep 100 | cat") if err != nil { t.Fatalf("docker run failed: %v", err) } - defer ptmx.Close() - defer d.CleanUp() - - // Call "sleep 100" in the shell. - if _, err := ptmx.Write([]byte("sleep 100\n")); err != nil { - t.Fatalf("error writing to pty: %v", err) - } - // Give shell a few seconds to start executing the sleep. time.Sleep(2 * time.Second) - // Send a ^C to the pty, which should kill sleep, but not the shell. - // \x03 is ASCII "end of text", which is the same as ^C. - if _, err := ptmx.Write([]byte{'\x03'}); err != nil { - t.Fatalf("error writing to pty: %v", err) + if _, err := p.Write(time.Second, []byte{0x03}); err != nil { + t.Fatalf("error exit: %v", err) } - // The shell should still be alive at this point. Sleep should have - // exited with code 2+128=130. We'll exit with 10 plus that number, so - // that we can be sure that the shell did not get signalled. - if _, err := ptmx.Write([]byte("exit $(expr $? + 10)\n")); err != nil { - t.Fatalf("error writing to pty: %v", err) + if err := d.WaitTimeout(ctx, 3*time.Second); err != nil { + t.Fatalf("WaitTimeout failed: %v", err) } - // Wait for the container to exit. - got, err := d.Wait(5 * time.Second) + want := 130 + got, err := p.WaitExitStatus(ctx) if err != nil { - t.Fatalf("error getting exit code: %v", err) + t.Fatalf("wait for exit failed with: %v", err) + } else if got != want { + t.Fatalf("got: %d want: %d", got, want) } - // Container should exit with code 10+130=140. - if want := syscall.WaitStatus(140); got != want { - t.Errorf("container exited with code %d want %d", got, want) +} + +// TestWorkingDirCreation checks that working dir is created if it doesn't exit. +func TestWorkingDirCreation(t *testing.T) { + for _, tc := range []struct { + name string + workingDir string + }{ + {name: "root", workingDir: "/foo"}, + {name: "tmp", workingDir: "/tmp/foo"}, + } { + for _, readonly := range []bool{true, false} { + name := tc.name + if readonly { + name += "-readonly" + } + t.Run(name, func(t *testing.T) { + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) + + opts := dockerutil.RunOpts{ + Image: "basic/alpine", + WorkDir: tc.workingDir, + ReadOnly: readonly, + } + got, err := d.Run(ctx, opts, "sh", "-c", "echo ${PWD}") + if err != nil { + t.Fatalf("docker run failed: %v", err) + } + if want := tc.workingDir + "\n"; want != got { + t.Errorf("invalid working dir, want: %q, got: %q", want, got) + } + }) + } } } -// TestTmpFile checks that files inside '/tmp' are not overridden. In addition, -// it checks that working dir is created if it doesn't exit. +// TestTmpFile checks that files inside '/tmp' are not overridden. func TestTmpFile(t *testing.T) { - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatal("docker pull failed:", err) + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) + + opts := dockerutil.RunOpts{Image: "basic/tmpfile"} + got, err := d.Run(ctx, opts, "cat", "/tmp/foo/file.txt") + if err != nil { + t.Fatalf("docker run failed: %v", err) + } + if want := "123\n"; want != got { + t.Errorf("invalid file content, want: %q, got: %q", want, got) + } +} + +// TestTmpMount checks that mounts inside '/tmp' are not overridden. +func TestTmpMount(t *testing.T) { + ctx := context.Background() + dir, err := ioutil.TempDir(testutil.TmpDir(), "tmp-mount") + if err != nil { + t.Fatalf("TempDir(): %v", err) + } + want := "123" + if err := ioutil.WriteFile(filepath.Join(dir, "file.txt"), []byte("123"), 0666); err != nil { + t.Fatalf("WriteFile(): %v", err) + } + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) + + opts := dockerutil.RunOpts{ + Image: "basic/alpine", + Mounts: []mount.Mount{ + { + Type: mount.TypeBind, + Source: dir, + Target: "/tmp/foo", + }, + }, + } + got, err := d.Run(ctx, opts, "cat", "/tmp/foo/file.txt") + if err != nil { + t.Fatalf("docker run failed: %v", err) + } + if want != got { + t.Errorf("invalid file content, want: %q, got: %q", want, got) + } +} + +// TestHostOverlayfsCopyUp tests that the --overlayfs-stale-read option causes +// runsc to hide the incoherence of FDs opened before and after overlayfs +// copy-up on the host. +func TestHostOverlayfsCopyUp(t *testing.T) { + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) + + if got, err := d.Run(ctx, dockerutil.RunOpts{ + Image: "basic/hostoverlaytest", + WorkDir: "/root", + }, "./test_copy_up"); err != nil { + t.Fatalf("docker run failed: %v", err) + } else if got != "" { + t.Errorf("test failed:\n%s", got) } - d := dockerutil.MakeDocker("tmp-file-test") - if err := d.Run("-w=/tmp/foo/bar", "--read-only", "alpine", "touch", "/tmp/foo/bar/file"); err != nil { - t.Fatal("docker run failed:", err) +} + +// TestHostOverlayfsRewindDir tests that rewinddir() "causes the directory +// stream to refer to the current state of the corresponding directory, as a +// call to opendir() would have done" as required by POSIX, when the directory +// in question is host overlayfs. +// +// This test specifically targets host overlayfs because, per POSIX, "if a file +// is removed from or added to the directory after the most recent call to +// opendir() or rewinddir(), whether a subsequent call to readdir() returns an +// entry for that file is unspecified"; the host filesystems used by other +// automated tests yield newly-added files from readdir() even if the fsgofer +// does not explicitly rewinddir(), but overlayfs does not. +func TestHostOverlayfsRewindDir(t *testing.T) { + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) + + if got, err := d.Run(ctx, dockerutil.RunOpts{ + Image: "basic/hostoverlaytest", + WorkDir: "/root", + }, "./test_rewinddir"); err != nil { + t.Fatalf("docker run failed: %v", err) + } else if got != "" { + t.Errorf("test failed:\n%s", got) + } +} + +// Basic test for linkat(2). Syscall tests requires CAP_DAC_READ_SEARCH and it +// cannot use tricks like userns as root. For this reason, run a basic link test +// to ensure some coverage. +func TestLink(t *testing.T) { + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) + + if got, err := d.Run(ctx, dockerutil.RunOpts{ + Image: "basic/linktest", + WorkDir: "/root", + }, "./link_test"); err != nil { + t.Fatalf("docker run failed: %v", err) + } else if got != "" { + t.Errorf("test failed:\n%s", got) } - defer d.CleanUp() } func TestMain(m *testing.M) { diff --git a/test/e2e/regression_test.go b/test/e2e/regression_test.go index 2488be383..70bbe5121 100644 --- a/test/e2e/regression_test.go +++ b/test/e2e/regression_test.go @@ -15,10 +15,11 @@ package integration import ( + "context" "strings" "testing" - "gvisor.dev/gvisor/runsc/dockerutil" + "gvisor.dev/gvisor/pkg/test/dockerutil" ) // Test that UDS can be created using overlay when parent directory is in lower @@ -27,19 +28,20 @@ import ( // Prerequisite: the directory where the socket file is created must not have // been open for write before bind(2) is called. func TestBindOverlay(t *testing.T) { - if err := dockerutil.Pull("ubuntu:trusty"); err != nil { - t.Fatal("docker pull failed:", err) - } - d := dockerutil.MakeDocker("bind-overlay-test") + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) - cmd := "nc -l -U /var/run/sock & p=$! && sleep 1 && echo foobar-asdf | nc -U /var/run/sock && wait $p" - got, err := d.RunFg("ubuntu:trusty", "bash", "-c", cmd) + // Run the container. + got, err := d.Run(ctx, dockerutil.RunOpts{ + Image: "basic/ubuntu", + }, "bash", "-c", "nc -l -U /var/run/sock & p=$! && sleep 1 && echo foobar-asdf | nc -U /var/run/sock && wait $p") if err != nil { - t.Fatal("docker run failed:", err) + t.Fatalf("docker run failed: %v", err) } + // Check the output contains what we want. if want := "foobar-asdf"; !strings.Contains(got, want) { t.Fatalf("docker run output is missing %q: %s", want, got) } - defer d.CleanUp() } diff --git a/test/fuse/BUILD b/test/fuse/BUILD new file mode 100644 index 000000000..56157c96b --- /dev/null +++ b/test/fuse/BUILD @@ -0,0 +1,9 @@ +load("//test/runner:defs.bzl", "syscall_test") + +package(licenses = ["notice"]) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:stat_test", + vfs2 = "True", +) diff --git a/test/fuse/README.md b/test/fuse/README.md new file mode 100644 index 000000000..734c3a4e3 --- /dev/null +++ b/test/fuse/README.md @@ -0,0 +1,103 @@ +# gVisor FUSE Test Suite + +This is an integration test suite for fuse(4) filesystem. It runs under both +gVisor and Linux, and ensures compatibility between the two. This test suite is +based on system calls test. + +This document describes the framework of fuse integration test and the +guidelines that should be followed when adding new fuse tests. + +## Integration Test Framework + +Please refer to the figure below. `>` is entering the function, `<` is leaving +the function, and `=` indicates sequentially entering and leaving. + +``` + | Client (Test Main Process) | Server (FUSE Daemon) + | | + | >TEST_F() | + | >SetUp() | + | =MountFuse() | + | >SetUpFuseServer() | + | [create communication pipes] | + | =fork() | =fork() + | >WaitCompleted() | + | [wait for MarkDone()] | + | | =ConsumeFuseInit() + | | =MarkDone() + | <WaitCompleted() | + | <SetUpFuseServer() | + | <SetUp() | + | >SetExpected() | + | [construct expected reaction] | + | | >FuseLoop() + | | >ReceiveExpected() + | | [wait data from pipe] + | [write data to pipe] | + | [wait for MarkDone()] | + | | [save data to memory] + | | =MarkDone() + | <SetExpected() | + | | <ReceiveExpected() + | | >read() + | | [wait for fs operation] + | >[Do fs operation] | + | [wait for fs response] | + | | <read() + | | =CompareRequest() + | | =write() [write fs response] + | <[Do fs operation] | + | =[Test fs operation result] | + | =[wait for MarkDone()] | + | | =MarkDone() + | >TearDown() | + | =UnmountFuse() | + | <TearDown() | + | <TEST_F() | +``` + +## Running the tests + +Based on syscall tests, fuse tests can run in different environments. To enable +fuse testing environment, the test targets should be appended with `_fuse`. + +For example, to run fuse test in `stat_test.cc`: + +```bash +$ bazel test //test/fuse:stat_test_runsc_ptrace_vfs2_fuse +``` + +Test all targets tagged with fuse: + +```bash +$ bazel test --test_tag_filters=fuse //test/fuse/... +``` + +## Writing a new FUSE test + +1. Add test targets in `BUILD` and `linux/BUILD`. +2. Inherit your test from `FuseTest` base class. It allows you to: + - Run a fake FUSE server in background during each test setup. + - Create pipes for communication and provide utility functions. + - Stop FUSE server after test completes. +3. Customize your comparison function for request assessment in FUSE server. +4. Add the mapping of the size of structs if you are working on new FUSE + opcode. + - Please update `FuseTest::GetPayloadSize()` for each new FUSE opcode. +5. Build the expected request-response pair of your FUSE operation. +6. Call `SetExpected()` function to inject the expected reaction. +7. Check the response and/or errors. +8. Finally call `WaitCompleted()` to ensure the FUSE server acts correctly. + +A few customized matchers used in syscalls test are encouraged to test the +outcome of filesystem operations. Such as: + +```cc +SyscallSucceeds() +SyscallSucceedsWithValue(...) +SyscallFails() +SyscallFailsWithErrno(...) +``` + +Please refer to [test/syscalls/README.md](../syscalls/README.md) for further +details. diff --git a/test/fuse/linux/BUILD b/test/fuse/linux/BUILD new file mode 100644 index 000000000..4871bb531 --- /dev/null +++ b/test/fuse/linux/BUILD @@ -0,0 +1,32 @@ +load("//tools:defs.bzl", "cc_binary", "cc_library", "gtest") + +package( + default_visibility = ["//:sandbox"], + licenses = ["notice"], +) + +cc_binary( + name = "stat_test", + testonly = 1, + srcs = ["stat_test.cc"], + deps = [ + gtest, + ":fuse_base", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_library( + name = "fuse_base", + testonly = 1, + srcs = ["fuse_base.cc"], + hdrs = ["fuse_base.h"], + deps = [ + gtest, + "//test/util:posix_error", + "//test/util:temp_path", + "//test/util:test_util", + "@com_google_absl//absl/strings:str_format", + ], +) diff --git a/test/fuse/linux/fuse_base.cc b/test/fuse/linux/fuse_base.cc new file mode 100644 index 000000000..9c3124472 --- /dev/null +++ b/test/fuse/linux/fuse_base.cc @@ -0,0 +1,208 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/fuse/linux/fuse_base.h" + +#include <fcntl.h> +#include <linux/fuse.h> +#include <string.h> +#include <sys/mount.h> +#include <sys/stat.h> +#include <sys/types.h> +#include <sys/uio.h> +#include <unistd.h> + +#include <iostream> + +#include "gtest/gtest.h" +#include "absl/strings/str_format.h" +#include "test/util/posix_error.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +void FuseTest::SetUp() { + MountFuse(); + SetUpFuseServer(); +} + +void FuseTest::TearDown() { UnmountFuse(); } + +// Since CompareRequest is running in background thread, gTest assertions and +// expectations won't directly reflect the test result. However, the FUSE +// background server still connects to the same standard I/O as testing main +// thread. So EXPECT_XX can still be used to show different results. To +// ensure failed testing result is observable, return false and the result +// will be sent to test main thread via pipe. +bool FuseTest::CompareRequest(void* expected_mem, size_t expected_len, + void* real_mem, size_t real_len) { + if (expected_len != real_len) return false; + return memcmp(expected_mem, real_mem, expected_len) == 0; +} + +// SetExpected is called by the testing main thread to set expected request- +// response pair of a single FUSE operation. +void FuseTest::SetExpected(struct iovec* iov_in, int iov_in_cnt, + struct iovec* iov_out, int iov_out_cnt) { + EXPECT_THAT(RetryEINTR(writev)(set_expected_[1], iov_in, iov_in_cnt), + SyscallSucceedsWithValue(::testing::Gt(0))); + WaitCompleted(); + + EXPECT_THAT(RetryEINTR(writev)(set_expected_[1], iov_out, iov_out_cnt), + SyscallSucceedsWithValue(::testing::Gt(0))); + WaitCompleted(); +} + +// WaitCompleted waits for the FUSE server to finish its job and check if it +// completes without errors. +void FuseTest::WaitCompleted() { + char success; + EXPECT_THAT(RetryEINTR(read)(done_[0], &success, sizeof(success)), + SyscallSucceedsWithValue(1)); +} + +void FuseTest::MountFuse() { + EXPECT_THAT(dev_fd_ = open("/dev/fuse", O_RDWR), SyscallSucceeds()); + + std::string mount_opts = absl::StrFormat("fd=%d,%s", dev_fd_, kMountOpts); + mount_point_ = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + EXPECT_THAT(mount("fuse", mount_point_.path().c_str(), "fuse", + MS_NODEV | MS_NOSUID, mount_opts.c_str()), + SyscallSucceedsWithValue(0)); +} + +void FuseTest::UnmountFuse() { + EXPECT_THAT(umount(mount_point_.path().c_str()), SyscallSucceeds()); + // TODO(gvisor.dev/issue/3330): ensure the process is terminated successfully. +} + +// ConsumeFuseInit consumes the first FUSE request and returns the +// corresponding PosixError. +PosixError FuseTest::ConsumeFuseInit() { + RETURN_ERROR_IF_SYSCALL_FAIL( + RetryEINTR(read)(dev_fd_, buf_.data(), buf_.size())); + + struct iovec iov_out[2]; + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_init_out), + .error = 0, + .unique = 2, + }; + // Returns a fake fuse_init_out with 7.0 version to avoid ECONNREFUSED + // error in the initialization of FUSE connection. + struct fuse_init_out out_payload = { + .major = 7, + }; + iov_out[0].iov_len = sizeof(out_header); + iov_out[0].iov_base = &out_header; + iov_out[1].iov_len = sizeof(out_payload); + iov_out[1].iov_base = &out_payload; + + RETURN_ERROR_IF_SYSCALL_FAIL(RetryEINTR(writev)(dev_fd_, iov_out, 2)); + return NoError(); +} + +// ReceiveExpected reads 1 pair of expected fuse request-response `iovec`s +// from pipe and save them into member variables of this testing instance. +void FuseTest::ReceiveExpected() { + // Set expected fuse_in request. + EXPECT_THAT(len_in_ = RetryEINTR(read)(set_expected_[0], mem_in_.data(), + mem_in_.size()), + SyscallSucceedsWithValue(::testing::Gt(0))); + MarkDone(len_in_ > 0); + + // Set expected fuse_out response. + EXPECT_THAT(len_out_ = RetryEINTR(read)(set_expected_[0], mem_out_.data(), + mem_out_.size()), + SyscallSucceedsWithValue(::testing::Gt(0))); + MarkDone(len_out_ > 0); +} + +// MarkDone writes 1 byte of success indicator through pipe. +void FuseTest::MarkDone(bool success) { + char data = success ? 1 : 0; + EXPECT_THAT(RetryEINTR(write)(done_[1], &data, sizeof(data)), + SyscallSucceedsWithValue(1)); +} + +// FuseLoop is the implementation of the fake FUSE server. Read from /dev/fuse, +// compare the request by CompareRequest (use derived function if specified), +// and write the expected response to /dev/fuse. +void FuseTest::FuseLoop() { + bool success = true; + ssize_t len = 0; + while (true) { + ReceiveExpected(); + + EXPECT_THAT(len = RetryEINTR(read)(dev_fd_, buf_.data(), buf_.size()), + SyscallSucceedsWithValue(len_in_)); + if (len != len_in_) success = false; + + if (!CompareRequest(buf_.data(), len_in_, mem_in_.data(), len_in_)) { + std::cerr << "the FUSE request is not expected" << std::endl; + success = false; + } + + EXPECT_THAT(len = RetryEINTR(write)(dev_fd_, mem_out_.data(), len_out_), + SyscallSucceedsWithValue(len_out_)); + if (len != len_out_) success = false; + MarkDone(success); + } +} + +// SetUpFuseServer creates 2 pipes. First is for testing client to send the +// expected request-response pair, and the other acts as a checkpoint for the +// FUSE server to notify the client that it can proceed. +void FuseTest::SetUpFuseServer() { + ASSERT_THAT(pipe(set_expected_), SyscallSucceedsWithValue(0)); + ASSERT_THAT(pipe(done_), SyscallSucceedsWithValue(0)); + + switch (fork()) { + case -1: + GTEST_FAIL(); + return; + case 0: + break; + default: + ASSERT_THAT(close(set_expected_[0]), SyscallSucceedsWithValue(0)); + ASSERT_THAT(close(done_[1]), SyscallSucceedsWithValue(0)); + WaitCompleted(); + return; + } + + ASSERT_THAT(close(set_expected_[1]), SyscallSucceedsWithValue(0)); + ASSERT_THAT(close(done_[0]), SyscallSucceedsWithValue(0)); + + MarkDone(ConsumeFuseInit().ok()); + + FuseLoop(); + _exit(0); +} + +// GetPayloadSize is a helper function to get the number of bytes of a +// specific FUSE operation struct. +size_t FuseTest::GetPayloadSize(uint32_t opcode, bool in) { + switch (opcode) { + case FUSE_INIT: + return in ? sizeof(struct fuse_init_in) : sizeof(struct fuse_init_out); + default: + break; + } + return 0; +} + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/fuse_base.h b/test/fuse/linux/fuse_base.h new file mode 100644 index 000000000..3a2f255a9 --- /dev/null +++ b/test/fuse/linux/fuse_base.h @@ -0,0 +1,99 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_FUSE_FUSE_BASE_H_ +#define GVISOR_TEST_FUSE_FUSE_BASE_H_ + +#include <linux/fuse.h> +#include <sys/uio.h> + +#include <vector> + +#include "gtest/gtest.h" +#include "test/util/posix_error.h" +#include "test/util/temp_path.h" + +namespace gvisor { +namespace testing { + +constexpr char kMountOpts[] = "rootmode=755,user_id=0,group_id=0"; + +class FuseTest : public ::testing::Test { + public: + FuseTest() { + buf_.resize(FUSE_MIN_READ_BUFFER); + mem_in_.resize(FUSE_MIN_READ_BUFFER); + mem_out_.resize(FUSE_MIN_READ_BUFFER); + } + void SetUp() override; + void TearDown() override; + + // CompareRequest is used by the FUSE server and should be implemented to + // compare different FUSE operations. It compares the actual FUSE input + // request with the expected one set by `SetExpected()`. + virtual bool CompareRequest(void* expected_mem, size_t expected_len, + void* real_mem, size_t real_len); + + // SetExpected is called by the testing main thread. Writes a request- + // response pair into FUSE server's member variables via pipe. + void SetExpected(struct iovec* iov_in, int iov_in_cnt, struct iovec* iov_out, + int iov_out_cnt); + + // WaitCompleted waits for FUSE server to complete its processing. It + // complains if the FUSE server responds failure during tests. + void WaitCompleted(); + + protected: + TempPath mount_point_; + + private: + void MountFuse(); + void UnmountFuse(); + + // ConsumeFuseInit is only used during FUSE server setup. + PosixError ConsumeFuseInit(); + + // ReceiveExpected is the FUSE server side's corresponding code of + // `SetExpected()`. Save the request-response pair into its memory. + void ReceiveExpected(); + + // MarkDone is used by the FUSE server to tell testing main if it's OK to + // proceed next command. + void MarkDone(bool success); + + // FuseLoop is where the FUSE server stay until it is terminated. + void FuseLoop(); + + // SetUpFuseServer creates 2 pipes for communication and forks FUSE server. + void SetUpFuseServer(); + + // GetPayloadSize is a helper function to get the number of bytes of a + // specific FUSE operation struct. + size_t GetPayloadSize(uint32_t opcode, bool in); + + int dev_fd_; + int set_expected_[2]; + int done_[2]; + + std::vector<char> buf_; + std::vector<char> mem_in_; + std::vector<char> mem_out_; + ssize_t len_in_; + ssize_t len_out_; +}; + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_FUSE_FUSE_BASE_H_ diff --git a/test/fuse/linux/stat_test.cc b/test/fuse/linux/stat_test.cc new file mode 100644 index 000000000..172e09867 --- /dev/null +++ b/test/fuse/linux/stat_test.cc @@ -0,0 +1,169 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <unistd.h> + +#include <vector> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_base.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class StatTest : public FuseTest { + public: + bool CompareRequest(void* expected_mem, size_t expected_len, void* real_mem, + size_t real_len) override { + if (expected_len != real_len) return false; + struct fuse_in_header* real_header = + reinterpret_cast<fuse_in_header*>(real_mem); + + if (real_header->opcode != FUSE_GETATTR) { + std::cerr << "expect header opcode " << FUSE_GETATTR << " but got " + << real_header->opcode << std::endl; + return false; + } + return true; + } + + bool StatsAreEqual(struct stat expected, struct stat actual) { + // device number will be dynamically allocated by kernel, we cannot know + // in advance + actual.st_dev = expected.st_dev; + return memcmp(&expected, &actual, sizeof(struct stat)) == 0; + } +}; + +TEST_F(StatTest, StatNormal) { + struct iovec iov_in[2]; + struct iovec iov_out[2]; + + struct fuse_in_header in_header = { + .len = sizeof(struct fuse_in_header) + sizeof(struct fuse_getattr_in), + .opcode = FUSE_GETATTR, + .unique = 4, + .nodeid = 1, + .uid = 0, + .gid = 0, + .pid = 4, + .padding = 0, + }; + struct fuse_getattr_in in_payload = {0}; + iov_in[0].iov_len = sizeof(in_header); + iov_in[0].iov_base = &in_header; + iov_in[1].iov_len = sizeof(in_payload); + iov_in[1].iov_base = &in_payload; + + mode_t expected_mode = S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH; + struct timespec atime = {.tv_sec = 1595436289, .tv_nsec = 134150844}; + struct timespec mtime = {.tv_sec = 1595436290, .tv_nsec = 134150845}; + struct timespec ctime = {.tv_sec = 1595436291, .tv_nsec = 134150846}; + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out), + .error = 0, + .unique = 4, + }; + struct fuse_attr attr = { + .ino = 1, + .size = 512, + .blocks = 4, + .atime = static_cast<uint64_t>(atime.tv_sec), + .mtime = static_cast<uint64_t>(mtime.tv_sec), + .ctime = static_cast<uint64_t>(ctime.tv_sec), + .atimensec = static_cast<uint32_t>(atime.tv_nsec), + .mtimensec = static_cast<uint32_t>(mtime.tv_nsec), + .ctimensec = static_cast<uint32_t>(ctime.tv_nsec), + .mode = expected_mode, + .nlink = 2, + .uid = 1234, + .gid = 4321, + .rdev = 12, + .blksize = 4096, + }; + struct fuse_attr_out out_payload = { + .attr = attr, + }; + iov_out[0].iov_len = sizeof(out_header); + iov_out[0].iov_base = &out_header; + iov_out[1].iov_len = sizeof(out_payload); + iov_out[1].iov_base = &out_payload; + + SetExpected(iov_in, 2, iov_out, 2); + + struct stat stat_buf; + EXPECT_THAT(stat(mount_point_.path().c_str(), &stat_buf), SyscallSucceeds()); + + struct stat expected_stat = { + .st_ino = attr.ino, + .st_nlink = attr.nlink, + .st_mode = expected_mode, + .st_uid = attr.uid, + .st_gid = attr.gid, + .st_rdev = attr.rdev, + .st_size = static_cast<off_t>(attr.size), + .st_blksize = attr.blksize, + .st_blocks = static_cast<blkcnt_t>(attr.blocks), + .st_atim = atime, + .st_mtim = mtime, + .st_ctim = ctime, + }; + EXPECT_TRUE(StatsAreEqual(stat_buf, expected_stat)); + WaitCompleted(); +} + +TEST_F(StatTest, StatNotFound) { + struct iovec iov_in[2]; + struct iovec iov_out[2]; + + struct fuse_in_header in_header = { + .len = sizeof(struct fuse_in_header) + sizeof(struct fuse_getattr_in), + .opcode = FUSE_GETATTR, + .unique = 4, + }; + struct fuse_getattr_in in_payload = {0}; + iov_in[0].iov_len = sizeof(in_header); + iov_in[0].iov_base = &in_header; + iov_in[1].iov_len = sizeof(in_payload); + iov_in[1].iov_base = &in_payload; + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header), + .error = -ENOENT, + .unique = 4, + }; + iov_out[0].iov_len = sizeof(out_header); + iov_out[0].iov_base = &out_header; + + SetExpected(iov_in, 2, iov_out, 1); + + struct stat stat_buf; + EXPECT_THAT(stat(mount_point_.path().c_str(), &stat_buf), + SyscallFailsWithErrno(ENOENT)); + WaitCompleted(); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/image/BUILD b/test/image/BUILD index 09b0a0ad5..e749e47d4 100644 --- a/test/image/BUILD +++ b/test/image/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -14,7 +14,7 @@ go_test( "ruby.rb", "ruby.sh", ], - embed = [":image"], + library = ":image", tags = [ # Requires docker and runsc to be configured before the test runs. "manual", @@ -22,13 +22,12 @@ go_test( ], visibility = ["//:sandbox"], deps = [ - "//runsc/dockerutil", - "//runsc/testutil", + "//pkg/test/dockerutil", + "//pkg/test/testutil", ], ) go_library( name = "image", srcs = ["image.go"], - importpath = "gvisor.dev/gvisor/test/image", ) diff --git a/test/image/image_test.go b/test/image/image_test.go index d0dcb1861..ac6186688 100644 --- a/test/image/image_test.go +++ b/test/image/image_test.go @@ -22,30 +22,44 @@ package image import ( + "context" "flag" "fmt" "io/ioutil" "log" "net/http" "os" - "path/filepath" "strings" "testing" "time" - "gvisor.dev/gvisor/runsc/dockerutil" - "gvisor.dev/gvisor/runsc/testutil" + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/pkg/test/testutil" ) +// defaultWait defines how long to wait for progress. +// +// See BUILD: This is at least a "large" test, so allow up to 1 minute for any +// given "wait" step. Note that all tests are run in parallel, which may cause +// individual slow-downs (but a huge speed-up in aggregate). +const defaultWait = time.Minute + func TestHelloWorld(t *testing.T) { - d := dockerutil.MakeDocker("hello-test") - if err := d.Run("hello-world"); err != nil { + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) + + // Run the basic container. + out, err := d.Run(ctx, dockerutil.RunOpts{ + Image: "basic/alpine", + }, "echo", "Hello world!") + if err != nil { t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() - if _, err := d.WaitForOutput("Hello from Docker!", 5*time.Second); err != nil { - t.Fatalf("docker didn't say hello: %v", err) + // Check the output. + if !strings.Contains(out, "Hello world!") { + t.Fatalf("docker didn't say hello: got %s", out) } } @@ -102,31 +116,28 @@ func testHTTPServer(t *testing.T, port int) { } func TestHttpd(t *testing.T) { - if err := dockerutil.Pull("httpd"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("http-test") - - dir, err := dockerutil.PrepareFiles("latin10k.txt") - if err != nil { - t.Fatalf("PrepareFiles() failed: %v", err) - } + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Start the container. - mountArg := dockerutil.MountArg(dir, "/usr/local/apache2/htdocs", dockerutil.ReadOnly) - if err := d.Run("-p", "80", mountArg, "httpd"); err != nil { + opts := dockerutil.RunOpts{ + Image: "basic/httpd", + Ports: []int{80}, + } + d.CopyFiles(&opts, "/usr/local/apache2/htdocs", "test/image/latin10k.txt") + if err := d.Spawn(ctx, opts); err != nil { t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() // Find where port 80 is mapped to. - port, err := d.FindPort(80) + port, err := d.FindPort(ctx, 80) if err != nil { - t.Fatalf("docker.FindPort(80) failed: %v", err) + t.Fatalf("FindPort(80) failed: %v", err) } // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil { + if err := testutil.WaitForHTTP(port, defaultWait); err != nil { t.Errorf("WaitForHTTP() timeout: %v", err) } @@ -134,31 +145,28 @@ func TestHttpd(t *testing.T) { } func TestNginx(t *testing.T) { - if err := dockerutil.Pull("nginx"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("net-test") - - dir, err := dockerutil.PrepareFiles("latin10k.txt") - if err != nil { - t.Fatalf("PrepareFiles() failed: %v", err) - } + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Start the container. - mountArg := dockerutil.MountArg(dir, "/usr/share/nginx/html", dockerutil.ReadOnly) - if err := d.Run("-p", "80", mountArg, "nginx"); err != nil { + opts := dockerutil.RunOpts{ + Image: "basic/nginx", + Ports: []int{80}, + } + d.CopyFiles(&opts, "/usr/share/nginx/html", "test/image/latin10k.txt") + if err := d.Spawn(ctx, opts); err != nil { t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() // Find where port 80 is mapped to. - port, err := d.FindPort(80) + port, err := d.FindPort(ctx, 80) if err != nil { - t.Fatalf("docker.FindPort(80) failed: %v", err) + t.Fatalf("FindPort(80) failed: %v", err) } // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil { + if err := testutil.WaitForHTTP(port, defaultWait); err != nil { t.Errorf("WaitForHTTP() timeout: %v", err) } @@ -166,103 +174,65 @@ func TestNginx(t *testing.T) { } func TestMysql(t *testing.T) { - if err := dockerutil.Pull("mysql"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("mysql-test") + ctx := context.Background() + server := dockerutil.MakeContainer(ctx, t) + defer server.CleanUp(ctx) // Start the container. - if err := d.Run("-e", "MYSQL_ROOT_PASSWORD=foobar123", "mysql"); err != nil { + if err := server.Spawn(ctx, dockerutil.RunOpts{ + Image: "basic/mysql", + Env: []string{"MYSQL_ROOT_PASSWORD=foobar123"}, + }); err != nil { t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() // Wait until it's up and running. - if _, err := d.WaitForOutput("port: 3306 MySQL Community Server", 3*time.Minute); err != nil { - t.Fatalf("docker.WaitForOutput() timeout: %v", err) + if _, err := server.WaitForOutput(ctx, "port: 3306 MySQL Community Server", defaultWait); err != nil { + t.Fatalf("WaitForOutput() timeout: %v", err) } - client := dockerutil.MakeDocker("mysql-client-test") - dir, err := dockerutil.PrepareFiles("mysql.sql") - if err != nil { - t.Fatalf("PrepareFiles() failed: %v", err) - } + // Generate the client and copy in the SQL payload. + client := dockerutil.MakeContainer(ctx, t) + defer client.CleanUp(ctx) - // Tell mysql client to connect to the server and execute the file in verbose - // mode to verify the output. - args := []string{ - dockerutil.LinkArg(&d, "mysql"), - dockerutil.MountArg(dir, "/sql", dockerutil.ReadWrite), - "mysql", - "mysql", "-hmysql", "-uroot", "-pfoobar123", "-v", "-e", "source /sql/mysql.sql", + // Tell mysql client to connect to the server and execute the file in + // verbose mode to verify the output. + opts := dockerutil.RunOpts{ + Image: "basic/mysql", + Links: []string{server.MakeLink("mysql")}, } - if err := client.Run(args...); err != nil { + client.CopyFiles(&opts, "/sql", "test/image/mysql.sql") + if _, err := client.Run(ctx, opts, "mysql", "-hmysql", "-uroot", "-pfoobar123", "-v", "-e", "source /sql/mysql.sql"); err != nil { t.Fatalf("docker run failed: %v", err) } - defer client.CleanUp() // Ensure file executed to the end and shutdown mysql. - if _, err := client.WaitForOutput("--------------\nshutdown\n--------------", 15*time.Second); err != nil { - t.Fatalf("docker.WaitForOutput() timeout: %v", err) - } - if _, err := d.WaitForOutput("mysqld: Shutdown complete", 30*time.Second); err != nil { - t.Fatalf("docker.WaitForOutput() timeout: %v", err) - } -} - -func TestPythonHello(t *testing.T) { - // TODO(b/136503277): Once we have more complete python runtime tests, - // we can drop this one. - const img = "gcr.io/gvisor-presubmit/python-hello" - if err := dockerutil.Pull(img); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("python-hello-test") - if err := d.Run("-p", "8080", img); err != nil { - t.Fatalf("docker run failed: %v", err) - } - defer d.CleanUp() - - // Find where port 8080 is mapped to. - port, err := d.FindPort(8080) - if err != nil { - t.Fatalf("docker.FindPort(8080) failed: %v", err) - } - - // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil { - t.Fatalf("WaitForHTTP() timeout: %v", err) - } - - // Ensure that content is being served. - url := fmt.Sprintf("http://localhost:%d", port) - resp, err := http.Get(url) - if err != nil { - t.Errorf("Error reaching http server: %v", err) - } - if want := http.StatusOK; resp.StatusCode != want { - t.Errorf("Wrong response code, got: %d, want: %d", resp.StatusCode, want) + if _, err := server.WaitForOutput(ctx, "mysqld: Shutdown complete", defaultWait); err != nil { + t.Fatalf("WaitForOutput() timeout: %v", err) } } func TestTomcat(t *testing.T) { - if err := dockerutil.Pull("tomcat:8.0"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("tomcat-test") - if err := d.Run("-p", "8080", "tomcat:8.0"); err != nil { + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) + + // Start the server. + if err := d.Spawn(ctx, dockerutil.RunOpts{ + Image: "basic/tomcat", + Ports: []int{8080}, + }); err != nil { t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() // Find where port 8080 is mapped to. - port, err := d.FindPort(8080) + port, err := d.FindPort(ctx, 8080) if err != nil { - t.Fatalf("docker.FindPort(8080) failed: %v", err) + t.Fatalf("FindPort(8080) failed: %v", err) } // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil { + if err := testutil.WaitForHTTP(port, defaultWait); err != nil { t.Fatalf("WaitForHTTP() timeout: %v", err) } @@ -278,32 +248,28 @@ func TestTomcat(t *testing.T) { } func TestRuby(t *testing.T) { - if err := dockerutil.Pull("ruby"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("ruby-test") + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) - dir, err := dockerutil.PrepareFiles("ruby.rb", "ruby.sh") - if err != nil { - t.Fatalf("PrepareFiles() failed: %v", err) - } - if err := os.Chmod(filepath.Join(dir, "ruby.sh"), 0333); err != nil { - t.Fatalf("os.Chmod(%q, 0333) failed: %v", dir, err) + // Execute the ruby workload. + opts := dockerutil.RunOpts{ + Image: "basic/ruby", + Ports: []int{8080}, } - - if err := d.Run("-p", "8080", dockerutil.MountArg(dir, "/src", dockerutil.ReadOnly), "ruby", "/src/ruby.sh"); err != nil { + d.CopyFiles(&opts, "/src", "test/image/ruby.rb", "test/image/ruby.sh") + if err := d.Spawn(ctx, opts, "/src/ruby.sh"); err != nil { t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() // Find where port 8080 is mapped to. - port, err := d.FindPort(8080) + port, err := d.FindPort(ctx, 8080) if err != nil { - t.Fatalf("docker.FindPort(8080) failed: %v", err) + t.Fatalf("FindPort(8080) failed: %v", err) } // Wait until it's up and running, 'gem install' can take some time. - if err := testutil.WaitForHTTP(port, 1*time.Minute); err != nil { + if err := testutil.WaitForHTTP(port, time.Minute); err != nil { t.Fatalf("WaitForHTTP() timeout: %v", err) } @@ -326,21 +292,21 @@ func TestRuby(t *testing.T) { } func TestStdio(t *testing.T) { - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("stdio-test") + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) wantStdout := "hello stdout" wantStderr := "bonjour stderr" cmd := fmt.Sprintf("echo %q; echo %q 1>&2;", wantStdout, wantStderr) - if err := d.Run("alpine", "/bin/sh", "-c", cmd); err != nil { + if err := d.Spawn(ctx, dockerutil.RunOpts{ + Image: "basic/alpine", + }, "/bin/sh", "-c", cmd); err != nil { t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() for _, want := range []string{wantStdout, wantStderr} { - if _, err := d.WaitForOutput(want, 5*time.Second); err != nil { + if _, err := d.WaitForOutput(ctx, want, defaultWait); err != nil { t.Fatalf("docker didn't get output %q : %v", want, err) } } diff --git a/test/image/ruby.sh b/test/image/ruby.sh index ebe8d5b0e..ebe8d5b0e 100644..100755 --- a/test/image/ruby.sh +++ b/test/image/ruby.sh diff --git a/test/iptables/BUILD b/test/iptables/BUILD new file mode 100644 index 000000000..66453772a --- /dev/null +++ b/test/iptables/BUILD @@ -0,0 +1,38 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "iptables", + testonly = 1, + srcs = [ + "filter_input.go", + "filter_output.go", + "iptables.go", + "iptables_unsafe.go", + "iptables_util.go", + "nat.go", + ], + visibility = ["//test/iptables:__subpackages__"], + deps = [ + "//pkg/test/testutil", + ], +) + +go_test( + name = "iptables_test", + size = "large", + srcs = [ + "iptables_test.go", + ], + data = ["//test/iptables/runner"], + library = ":iptables", + tags = [ + "local", + "manual", + ], + deps = [ + "//pkg/test/dockerutil", + "//pkg/test/testutil", + ], +) diff --git a/test/iptables/README.md b/test/iptables/README.md new file mode 100644 index 000000000..b9f44bd40 --- /dev/null +++ b/test/iptables/README.md @@ -0,0 +1,54 @@ +# iptables Tests + +iptables tests are run via `scripts/iptables_test.sh`. + +iptables requires raw socket support, so you must add the `--net-raw=true` flag +to `/etc/docker/daemon.json` in order to use it. + +## Test Structure + +Each test implements `TestCase`, providing (1) a function to run inside the +container and (2) a function to run locally. Those processes are given each +others' IP addresses. The test succeeds when both functions succeed. + +The function inside the container (`ContainerAction`) typically sets some +iptables rules and then tries to send or receive packets. The local function +(`LocalAction`) will typically just send or receive packets. + +### Adding Tests + +1) Add your test to the `iptables` package. + +2) Register the test in an `init` function via `RegisterTestCase` (see +`filter_input.go` as an example). + +3) Add it to `iptables_test.go` (see the other tests in that file). + +Your test is now runnable with bazel! + +## Run individual tests + +Build and install `runsc`. Re-run this when you modify gVisor: + +```bash +$ bazel build //runsc && sudo cp bazel-bin/runsc/linux_amd64_pure_stripped/runsc $(which runsc) +``` + +Build the testing Docker container. Re-run this when you modify the test code in +this directory: + +```bash +$ make load-iptables +``` + +Run an individual test via: + +```bash +$ bazel test //test/iptables:iptables_test --test_filter=<TESTNAME> +``` + +To run an individual test with `runc`: + +```bash +$ bazel test //test/iptables:iptables_test --test_filter=<TESTNAME> --test_arg=--runtime=runc +``` diff --git a/test/iptables/filter_input.go b/test/iptables/filter_input.go new file mode 100644 index 000000000..b45d448b8 --- /dev/null +++ b/test/iptables/filter_input.go @@ -0,0 +1,745 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package iptables + +import ( + "context" + "errors" + "fmt" + "net" + "time" +) + +const ( + dropPort = 2401 + acceptPort = 2402 + sendloopDuration = 2 * time.Second + chainName = "foochain" +) + +func init() { + RegisterTestCase(FilterInputDropAll{}) + RegisterTestCase(FilterInputDropDifferentUDPPort{}) + RegisterTestCase(FilterInputDropOnlyUDP{}) + RegisterTestCase(FilterInputDropTCPDestPort{}) + RegisterTestCase(FilterInputDropTCPSrcPort{}) + RegisterTestCase(FilterInputDropUDPPort{}) + RegisterTestCase(FilterInputDropUDP{}) + RegisterTestCase(FilterInputCreateUserChain{}) + RegisterTestCase(FilterInputDefaultPolicyAccept{}) + RegisterTestCase(FilterInputDefaultPolicyDrop{}) + RegisterTestCase(FilterInputReturnUnderflow{}) + RegisterTestCase(FilterInputSerializeJump{}) + RegisterTestCase(FilterInputJumpBasic{}) + RegisterTestCase(FilterInputJumpReturn{}) + RegisterTestCase(FilterInputJumpReturnDrop{}) + RegisterTestCase(FilterInputJumpBuiltin{}) + RegisterTestCase(FilterInputJumpTwice{}) + RegisterTestCase(FilterInputDestination{}) + RegisterTestCase(FilterInputInvertDestination{}) + RegisterTestCase(FilterInputSource{}) + RegisterTestCase(FilterInputInvertSource{}) +} + +// FilterInputDropUDP tests that we can drop UDP traffic. +type FilterInputDropUDP struct{ containerCase } + +// Name implements TestCase.Name. +func (FilterInputDropUDP) Name() string { + return "FilterInputDropUDP" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputDropUDP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "INPUT", "-p", "udp", "-j", "DROP"); err != nil { + return err + } + + // Listen for UDP packets on dropPort. + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, dropPort); err == nil { + return fmt.Errorf("packets on port %d should have been dropped, but got a packet", dropPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) + } + + // At this point we know that reading timed out and never received a + // packet. + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputDropUDP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, dropPort) +} + +// FilterInputDropOnlyUDP tests that "-p udp -j DROP" only affects UDP traffic. +type FilterInputDropOnlyUDP struct{ baseCase } + +// Name implements TestCase.Name. +func (FilterInputDropOnlyUDP) Name() string { + return "FilterInputDropOnlyUDP" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputDropOnlyUDP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "INPUT", "-p", "udp", "-j", "DROP"); err != nil { + return err + } + + // Listen for a TCP connection, which should be allowed. + if err := listenTCP(ctx, acceptPort); err != nil { + return fmt.Errorf("failed to establish a connection %v", err) + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputDropOnlyUDP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // Try to establish a TCP connection with the container, which should + // succeed. + return connectTCP(ctx, ip, acceptPort) +} + +// FilterInputDropUDPPort tests that we can drop UDP traffic by port. +type FilterInputDropUDPPort struct{ containerCase } + +// Name implements TestCase.Name. +func (FilterInputDropUDPPort) Name() string { + return "FilterInputDropUDPPort" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputDropUDPPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { + return err + } + + // Listen for UDP packets on dropPort. + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, dropPort); err == nil { + return fmt.Errorf("packets on port %d should have been dropped, but got a packet", dropPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) + } + + // At this point we know that reading timed out and never received a + // packet. + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputDropUDPPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, dropPort) +} + +// FilterInputDropDifferentUDPPort tests that dropping traffic for a single UDP port +// doesn't drop packets on other ports. +type FilterInputDropDifferentUDPPort struct{ containerCase } + +// Name implements TestCase.Name. +func (FilterInputDropDifferentUDPPort) Name() string { + return "FilterInputDropDifferentUDPPort" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputDropDifferentUDPPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { + return err + } + + // Listen for UDP packets on another port. + if err := listenUDP(ctx, acceptPort); err != nil { + return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %v", acceptPort, err) + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputDropDifferentUDPPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) +} + +// FilterInputDropTCPDestPort tests that connections are not accepted on specified source ports. +type FilterInputDropTCPDestPort struct{ baseCase } + +// Name implements TestCase.Name. +func (FilterInputDropTCPDestPort) Name() string { + return "FilterInputDropTCPDestPort" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputDropTCPDestPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "INPUT", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { + return err + } + + // Listen for TCP packets on drop port. + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenTCP(timedCtx, dropPort); err == nil { + return fmt.Errorf("connection on port %d should not be accepted, but got accepted", dropPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputDropTCPDestPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // Ensure we cannot connect to the container. + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, dropPort); err == nil { + return fmt.Errorf("expected not to connect, but was able to connect on port %d", dropPort) + } + return nil +} + +// FilterInputDropTCPSrcPort tests that connections are not accepted on specified source ports. +type FilterInputDropTCPSrcPort struct{ baseCase } + +// Name implements TestCase.Name. +func (FilterInputDropTCPSrcPort) Name() string { + return "FilterInputDropTCPSrcPort" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputDropTCPSrcPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // Drop anything from an ephemeral port. + if err := filterTable(ipv6, "-A", "INPUT", "-p", "tcp", "-m", "tcp", "--sport", "1024:65535", "-j", "DROP"); err != nil { + return err + } + + // Listen for TCP packets on accept port. + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenTCP(timedCtx, acceptPort); err == nil { + return fmt.Errorf("connection destined to port %d should not be accepted, but was", dropPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputDropTCPSrcPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // Ensure we cannot connect to the container. + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, dropPort); err == nil { + return fmt.Errorf("expected not to connect, but was able to connect on port %d", acceptPort) + } + return nil +} + +// FilterInputDropAll tests that we can drop all traffic to the INPUT chain. +type FilterInputDropAll struct{ containerCase } + +// Name implements TestCase.Name. +func (FilterInputDropAll) Name() string { + return "FilterInputDropAll" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputDropAll) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "INPUT", "-j", "DROP"); err != nil { + return err + } + + // Listen for all packets on dropPort. + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, dropPort); err == nil { + return fmt.Errorf("packets should have been dropped, but got a packet") + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) + } + + // At this point we know that reading timed out and never received a + // packet. + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputDropAll) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, dropPort) +} + +// FilterInputMultiUDPRules verifies that multiple UDP rules are applied +// correctly. This has the added benefit of testing whether we're serializing +// rules correctly -- if we do it incorrectly, the iptables tool will +// misunderstand and save the wrong tables. +type FilterInputMultiUDPRules struct{ baseCase } + +// Name implements TestCase.Name. +func (FilterInputMultiUDPRules) Name() string { + return "FilterInputMultiUDPRules" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputMultiUDPRules) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + rules := [][]string{ + {"-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"}, + {"-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", acceptPort), "-j", "ACCEPT"}, + {"-L"}, + } + return filterTableRules(ipv6, rules) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputMultiUDPRules) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // No-op. + return nil +} + +// FilterInputRequireProtocolUDP checks that "-m udp" requires "-p udp" to be +// specified. +type FilterInputRequireProtocolUDP struct{ baseCase } + +// Name implements TestCase.Name. +func (FilterInputRequireProtocolUDP) Name() string { + return "FilterInputRequireProtocolUDP" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputRequireProtocolUDP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "INPUT", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err == nil { + return errors.New("expected iptables to fail with out \"-p udp\", but succeeded") + } + return nil +} + +func (FilterInputRequireProtocolUDP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // No-op. + return nil +} + +// FilterInputCreateUserChain tests chain creation. +type FilterInputCreateUserChain struct{ baseCase } + +// Name implements TestCase.Name. +func (FilterInputCreateUserChain) Name() string { + return "FilterInputCreateUserChain" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputCreateUserChain) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + rules := [][]string{ + // Create a chain. + {"-N", chainName}, + // Add a simple rule to the chain. + {"-A", chainName, "-j", "DROP"}, + } + return filterTableRules(ipv6, rules) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputCreateUserChain) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // No-op. + return nil +} + +// FilterInputDefaultPolicyAccept tests the default ACCEPT policy. +type FilterInputDefaultPolicyAccept struct{ containerCase } + +// Name implements TestCase.Name. +func (FilterInputDefaultPolicyAccept) Name() string { + return "FilterInputDefaultPolicyAccept" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputDefaultPolicyAccept) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // Set the default policy to accept, then receive a packet. + if err := filterTable(ipv6, "-P", "INPUT", "ACCEPT"); err != nil { + return err + } + return listenUDP(ctx, acceptPort) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputDefaultPolicyAccept) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) +} + +// FilterInputDefaultPolicyDrop tests the default DROP policy. +type FilterInputDefaultPolicyDrop struct{ containerCase } + +// Name implements TestCase.Name. +func (FilterInputDefaultPolicyDrop) Name() string { + return "FilterInputDefaultPolicyDrop" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputDefaultPolicyDrop) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-P", "INPUT", "DROP"); err != nil { + return err + } + + // Listen for UDP packets on dropPort. + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, dropPort); err == nil { + return fmt.Errorf("packets on port %d should have been dropped, but got a packet", dropPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) + } + + // At this point we know that reading timed out and never received a + // packet. + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputDefaultPolicyDrop) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) +} + +// FilterInputReturnUnderflow tests that -j RETURN in a built-in chain causes +// the underflow rule (i.e. default policy) to be executed. +type FilterInputReturnUnderflow struct{ containerCase } + +// Name implements TestCase.Name. +func (FilterInputReturnUnderflow) Name() string { + return "FilterInputReturnUnderflow" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputReturnUnderflow) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // Add a RETURN rule followed by an unconditional accept, and set the + // default policy to DROP. + rules := [][]string{ + {"-A", "INPUT", "-j", "RETURN"}, + {"-A", "INPUT", "-j", "DROP"}, + {"-P", "INPUT", "ACCEPT"}, + } + if err := filterTableRules(ipv6, rules); err != nil { + return err + } + + // We should receive packets, as the RETURN rule will trigger the default + // ACCEPT policy. + return listenUDP(ctx, acceptPort) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputReturnUnderflow) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) +} + +// FilterInputSerializeJump verifies that we can serialize jumps. +type FilterInputSerializeJump struct{ baseCase } + +// Name implements TestCase.Name. +func (FilterInputSerializeJump) Name() string { + return "FilterInputSerializeJump" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputSerializeJump) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // Write a JUMP rule, the serialize it with `-L`. + rules := [][]string{ + {"-N", chainName}, + {"-A", "INPUT", "-j", chainName}, + {"-L"}, + } + return filterTableRules(ipv6, rules) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputSerializeJump) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // No-op. + return nil +} + +// FilterInputJumpBasic jumps to a chain and executes a rule there. +type FilterInputJumpBasic struct{ containerCase } + +// Name implements TestCase.Name. +func (FilterInputJumpBasic) Name() string { + return "FilterInputJumpBasic" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputJumpBasic) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + rules := [][]string{ + {"-P", "INPUT", "DROP"}, + {"-N", chainName}, + {"-A", "INPUT", "-j", chainName}, + {"-A", chainName, "-j", "ACCEPT"}, + } + if err := filterTableRules(ipv6, rules); err != nil { + return err + } + + // Listen for UDP packets on acceptPort. + return listenUDP(ctx, acceptPort) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputJumpBasic) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) +} + +// FilterInputJumpReturn jumps, returns, and executes a rule. +type FilterInputJumpReturn struct{ containerCase } + +// Name implements TestCase.Name. +func (FilterInputJumpReturn) Name() string { + return "FilterInputJumpReturn" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputJumpReturn) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + rules := [][]string{ + {"-N", chainName}, + {"-P", "INPUT", "ACCEPT"}, + {"-A", "INPUT", "-j", chainName}, + {"-A", chainName, "-j", "RETURN"}, + {"-A", chainName, "-j", "DROP"}, + } + if err := filterTableRules(ipv6, rules); err != nil { + return err + } + + // Listen for UDP packets on acceptPort. + return listenUDP(ctx, acceptPort) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputJumpReturn) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) +} + +// FilterInputJumpReturnDrop jumps to a chain, returns, and DROPs packets. +type FilterInputJumpReturnDrop struct{ containerCase } + +// Name implements TestCase.Name. +func (FilterInputJumpReturnDrop) Name() string { + return "FilterInputJumpReturnDrop" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputJumpReturnDrop) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + rules := [][]string{ + {"-N", chainName}, + {"-A", "INPUT", "-j", chainName}, + {"-A", "INPUT", "-j", "DROP"}, + {"-A", chainName, "-j", "RETURN"}, + } + if err := filterTableRules(ipv6, rules); err != nil { + return err + } + + // Listen for UDP packets on dropPort. + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, dropPort); err == nil { + return fmt.Errorf("packets on port %d should have been dropped, but got a packet", dropPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) + } + + // At this point we know that reading timed out and never received a + // packet. + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputJumpReturnDrop) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, dropPort) +} + +// FilterInputJumpBuiltin verifies that jumping to a top-levl chain is illegal. +type FilterInputJumpBuiltin struct{ baseCase } + +// Name implements TestCase.Name. +func (FilterInputJumpBuiltin) Name() string { + return "FilterInputJumpBuiltin" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputJumpBuiltin) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "INPUT", "-j", "OUTPUT"); err == nil { + return fmt.Errorf("iptables should be unable to jump to a built-in chain") + } + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputJumpBuiltin) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // No-op. + return nil +} + +// FilterInputJumpTwice jumps twice, then returns twice and executes a rule. +type FilterInputJumpTwice struct{ containerCase } + +// Name implements TestCase.Name. +func (FilterInputJumpTwice) Name() string { + return "FilterInputJumpTwice" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputJumpTwice) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + const chainName2 = chainName + "2" + rules := [][]string{ + {"-P", "INPUT", "DROP"}, + {"-N", chainName}, + {"-N", chainName2}, + {"-A", "INPUT", "-j", chainName}, + {"-A", chainName, "-j", chainName2}, + {"-A", "INPUT", "-j", "ACCEPT"}, + } + if err := filterTableRules(ipv6, rules); err != nil { + return err + } + + // UDP packets should jump and return twice, eventually hitting the + // ACCEPT rule. + return listenUDP(ctx, acceptPort) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputJumpTwice) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) +} + +// FilterInputDestination verifies that we can filter packets via `-d +// <ipaddr>`. +type FilterInputDestination struct{ containerCase } + +// Name implements TestCase.Name. +func (FilterInputDestination) Name() string { + return "FilterInputDestination" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputDestination) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + addrs, err := localAddrs(ipv6) + if err != nil { + return err + } + + // Make INPUT's default action DROP, then ACCEPT all packets bound for + // this machine. + rules := [][]string{{"-P", "INPUT", "DROP"}} + for _, addr := range addrs { + rules = append(rules, []string{"-A", "INPUT", "-d", addr, "-j", "ACCEPT"}) + } + if err := filterTableRules(ipv6, rules); err != nil { + return err + } + + return listenUDP(ctx, acceptPort) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputDestination) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) +} + +// FilterInputInvertDestination verifies that we can filter packets via `! -d +// <ipaddr>`. +type FilterInputInvertDestination struct{ containerCase } + +// Name implements TestCase.Name. +func (FilterInputInvertDestination) Name() string { + return "FilterInputInvertDestination" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputInvertDestination) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // Make INPUT's default action DROP, then ACCEPT all packets not bound + // for 127.0.0.1. + rules := [][]string{ + {"-P", "INPUT", "DROP"}, + {"-A", "INPUT", "!", "-d", localIP(ipv6), "-j", "ACCEPT"}, + } + if err := filterTableRules(ipv6, rules); err != nil { + return err + } + + return listenUDP(ctx, acceptPort) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputInvertDestination) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) +} + +// FilterInputSource verifies that we can filter packets via `-s +// <ipaddr>`. +type FilterInputSource struct{ containerCase } + +// Name implements TestCase.Name. +func (FilterInputSource) Name() string { + return "FilterInputSource" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputSource) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // Make INPUT's default action DROP, then ACCEPT all packets from this + // machine. + rules := [][]string{ + {"-P", "INPUT", "DROP"}, + {"-A", "INPUT", "-s", fmt.Sprintf("%v", ip), "-j", "ACCEPT"}, + } + if err := filterTableRules(ipv6, rules); err != nil { + return err + } + + return listenUDP(ctx, acceptPort) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputSource) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) +} + +// FilterInputInvertSource verifies that we can filter packets via `! -s +// <ipaddr>`. +type FilterInputInvertSource struct{ containerCase } + +// Name implements TestCase.Name. +func (FilterInputInvertSource) Name() string { + return "FilterInputInvertSource" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputInvertSource) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // Make INPUT's default action DROP, then ACCEPT all packets not bound + // for 127.0.0.1. + rules := [][]string{ + {"-P", "INPUT", "DROP"}, + {"-A", "INPUT", "!", "-s", localIP(ipv6), "-j", "ACCEPT"}, + } + if err := filterTableRules(ipv6, rules); err != nil { + return err + } + + return listenUDP(ctx, acceptPort) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputInvertSource) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) +} diff --git a/test/iptables/filter_output.go b/test/iptables/filter_output.go new file mode 100644 index 000000000..32bf2a992 --- /dev/null +++ b/test/iptables/filter_output.go @@ -0,0 +1,663 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package iptables + +import ( + "context" + "errors" + "fmt" + "net" +) + +func init() { + RegisterTestCase(FilterOutputDropTCPDestPort{}) + RegisterTestCase(FilterOutputDropTCPSrcPort{}) + RegisterTestCase(FilterOutputDestination{}) + RegisterTestCase(FilterOutputInvertDestination{}) + RegisterTestCase(FilterOutputAcceptTCPOwner{}) + RegisterTestCase(FilterOutputDropTCPOwner{}) + RegisterTestCase(FilterOutputAcceptUDPOwner{}) + RegisterTestCase(FilterOutputDropUDPOwner{}) + RegisterTestCase(FilterOutputOwnerFail{}) + RegisterTestCase(FilterOutputAcceptGIDOwner{}) + RegisterTestCase(FilterOutputDropGIDOwner{}) + RegisterTestCase(FilterOutputInvertGIDOwner{}) + RegisterTestCase(FilterOutputInvertUIDOwner{}) + RegisterTestCase(FilterOutputInvertUIDAndGIDOwner{}) + RegisterTestCase(FilterOutputInterfaceAccept{}) + RegisterTestCase(FilterOutputInterfaceDrop{}) + RegisterTestCase(FilterOutputInterface{}) + RegisterTestCase(FilterOutputInterfaceBeginsWith{}) + RegisterTestCase(FilterOutputInterfaceInvertDrop{}) + RegisterTestCase(FilterOutputInterfaceInvertAccept{}) +} + +// FilterOutputDropTCPDestPort tests that connections are not accepted on +// specified source ports. +type FilterOutputDropTCPDestPort struct{ baseCase } + +// Name implements TestCase.Name. +func (FilterOutputDropTCPDestPort) Name() string { + return "FilterOutputDropTCPDestPort" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputDropTCPDestPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--dport", "1024:65535", "-j", "DROP"); err != nil { + return err + } + + // Listen for TCP packets on accept port. + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenTCP(timedCtx, acceptPort); err == nil { + return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", dropPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputDropTCPDestPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, acceptPort); err == nil { + return fmt.Errorf("connection on port %d should not be accepted, but got accepted", dropPort) + } + + return nil +} + +// FilterOutputDropTCPSrcPort tests that connections are not accepted on +// specified source ports. +type FilterOutputDropTCPSrcPort struct{ baseCase } + +// Name implements TestCase.Name. +func (FilterOutputDropTCPSrcPort) Name() string { + return "FilterOutputDropTCPSrcPort" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputDropTCPSrcPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--sport", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { + return err + } + + // Listen for TCP packets on drop port. + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenTCP(timedCtx, dropPort); err == nil { + return fmt.Errorf("connection on port %d should not be accepted, but got accepted", dropPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputDropTCPSrcPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, dropPort); err == nil { + return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", dropPort) + } + + return nil +} + +// FilterOutputAcceptTCPOwner tests that TCP connections from uid owner are accepted. +type FilterOutputAcceptTCPOwner struct{ baseCase } + +// Name implements TestCase.Name. +func (FilterOutputAcceptTCPOwner) Name() string { + return "FilterOutputAcceptTCPOwner" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputAcceptTCPOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--uid-owner", "root", "-j", "ACCEPT"); err != nil { + return err + } + + // Listen for TCP packets on accept port. + return listenTCP(ctx, acceptPort) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputAcceptTCPOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return connectTCP(ctx, ip, acceptPort) +} + +// FilterOutputDropTCPOwner tests that TCP connections from uid owner are dropped. +type FilterOutputDropTCPOwner struct{ baseCase } + +// Name implements TestCase.Name. +func (FilterOutputDropTCPOwner) Name() string { + return "FilterOutputDropTCPOwner" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputDropTCPOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--uid-owner", "root", "-j", "DROP"); err != nil { + return err + } + + // Listen for TCP packets on accept port. + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenTCP(timedCtx, acceptPort); err == nil { + return fmt.Errorf("connection on port %d should be dropped, but got accepted", acceptPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputDropTCPOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, acceptPort); err == nil { + return fmt.Errorf("connection destined to port %d should be dropped, but got accepted", acceptPort) + } + + return nil +} + +// FilterOutputAcceptUDPOwner tests that UDP packets from uid owner are accepted. +type FilterOutputAcceptUDPOwner struct{ localCase } + +// Name implements TestCase.Name. +func (FilterOutputAcceptUDPOwner) Name() string { + return "FilterOutputAcceptUDPOwner" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputAcceptUDPOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-m", "owner", "--uid-owner", "root", "-j", "ACCEPT"); err != nil { + return err + } + + // Send UDP packets on acceptPort. + return sendUDPLoop(ctx, ip, acceptPort) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputAcceptUDPOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // Listen for UDP packets on acceptPort. + return listenUDP(ctx, acceptPort) +} + +// FilterOutputDropUDPOwner tests that UDP packets from uid owner are dropped. +type FilterOutputDropUDPOwner struct{ localCase } + +// Name implements TestCase.Name. +func (FilterOutputDropUDPOwner) Name() string { + return "FilterOutputDropUDPOwner" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputDropUDPOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-m", "owner", "--uid-owner", "root", "-j", "DROP"); err != nil { + return err + } + + // Send UDP packets on dropPort. + return sendUDPLoop(ctx, ip, dropPort) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputDropUDPOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // Listen for UDP packets on dropPort. + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, dropPort); err == nil { + return fmt.Errorf("packets should not be received") + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) + } + + return nil +} + +// FilterOutputOwnerFail tests that without uid/gid option, owner rule +// will fail. +type FilterOutputOwnerFail struct{ baseCase } + +// Name implements TestCase.Name. +func (FilterOutputOwnerFail) Name() string { + return "FilterOutputOwnerFail" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputOwnerFail) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-m", "owner", "-j", "ACCEPT"); err == nil { + return fmt.Errorf("Invalid argument") + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputOwnerFail) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // no-op. + return nil +} + +// FilterOutputAcceptGIDOwner tests that TCP connections from gid owner are accepted. +type FilterOutputAcceptGIDOwner struct{ baseCase } + +// Name implements TestCase.Name. +func (FilterOutputAcceptGIDOwner) Name() string { + return "FilterOutputAcceptGIDOwner" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputAcceptGIDOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--gid-owner", "root", "-j", "ACCEPT"); err != nil { + return err + } + + // Listen for TCP packets on accept port. + return listenTCP(ctx, acceptPort) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputAcceptGIDOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return connectTCP(ctx, ip, acceptPort) +} + +// FilterOutputDropGIDOwner tests that TCP connections from gid owner are dropped. +type FilterOutputDropGIDOwner struct{ baseCase } + +// Name implements TestCase.Name. +func (FilterOutputDropGIDOwner) Name() string { + return "FilterOutputDropGIDOwner" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputDropGIDOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--gid-owner", "root", "-j", "DROP"); err != nil { + return err + } + + // Listen for TCP packets on accept port. + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenTCP(timedCtx, acceptPort); err == nil { + return fmt.Errorf("connection on port %d should not be accepted, but got accepted", acceptPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputDropGIDOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, acceptPort); err == nil { + return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", acceptPort) + } + + return nil +} + +// FilterOutputInvertGIDOwner tests that TCP connections from gid owner are dropped. +type FilterOutputInvertGIDOwner struct{ baseCase } + +// Name implements TestCase.Name. +func (FilterOutputInvertGIDOwner) Name() string { + return "FilterOutputInvertGIDOwner" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputInvertGIDOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + rules := [][]string{ + {"-A", "OUTPUT", "-p", "tcp", "-m", "owner", "!", "--gid-owner", "root", "-j", "ACCEPT"}, + {"-A", "OUTPUT", "-p", "tcp", "-j", "DROP"}, + } + if err := filterTableRules(ipv6, rules); err != nil { + return err + } + + // Listen for TCP packets on accept port. + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenTCP(timedCtx, acceptPort); err == nil { + return fmt.Errorf("connection on port %d should not be accepted, but got accepted", acceptPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputInvertGIDOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, acceptPort); err == nil { + return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", acceptPort) + } + + return nil +} + +// FilterOutputInvertUIDOwner tests that TCP connections from gid owner are dropped. +type FilterOutputInvertUIDOwner struct{ baseCase } + +// Name implements TestCase.Name. +func (FilterOutputInvertUIDOwner) Name() string { + return "FilterOutputInvertUIDOwner" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputInvertUIDOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + rules := [][]string{ + {"-A", "OUTPUT", "-p", "tcp", "-m", "owner", "!", "--uid-owner", "root", "-j", "DROP"}, + {"-A", "OUTPUT", "-p", "tcp", "-j", "ACCEPT"}, + } + if err := filterTableRules(ipv6, rules); err != nil { + return err + } + + // Listen for TCP packets on accept port. + return listenTCP(ctx, acceptPort) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputInvertUIDOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return connectTCP(ctx, ip, acceptPort) +} + +// FilterOutputInvertUIDAndGIDOwner tests that TCP connections from uid and gid +// owner are dropped. +type FilterOutputInvertUIDAndGIDOwner struct{ baseCase } + +// Name implements TestCase.Name. +func (FilterOutputInvertUIDAndGIDOwner) Name() string { + return "FilterOutputInvertUIDAndGIDOwner" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputInvertUIDAndGIDOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + rules := [][]string{ + {"-A", "OUTPUT", "-p", "tcp", "-m", "owner", "!", "--uid-owner", "root", "!", "--gid-owner", "root", "-j", "ACCEPT"}, + {"-A", "OUTPUT", "-p", "tcp", "-j", "DROP"}, + } + if err := filterTableRules(ipv6, rules); err != nil { + return err + } + + // Listen for TCP packets on accept port. + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenTCP(timedCtx, acceptPort); err == nil { + return fmt.Errorf("connection on port %d should not be accepted, but got accepted", acceptPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputInvertUIDAndGIDOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, acceptPort); err == nil { + return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", acceptPort) + } + + return nil +} + +// FilterOutputDestination tests that we can selectively allow packets to +// certain destinations. +type FilterOutputDestination struct{ localCase } + +// Name implements TestCase.Name. +func (FilterOutputDestination) Name() string { + return "FilterOutputDestination" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputDestination) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + rules := [][]string{ + {"-A", "OUTPUT", "-d", ip.String(), "-j", "ACCEPT"}, + {"-P", "OUTPUT", "DROP"}, + } + if err := filterTableRules(ipv6, rules); err != nil { + return err + } + + return sendUDPLoop(ctx, ip, acceptPort) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputDestination) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return listenUDP(ctx, acceptPort) +} + +// FilterOutputInvertDestination tests that we can selectively allow packets +// not headed for a particular destination. +type FilterOutputInvertDestination struct{ localCase } + +// Name implements TestCase.Name. +func (FilterOutputInvertDestination) Name() string { + return "FilterOutputInvertDestination" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputInvertDestination) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + rules := [][]string{ + {"-A", "OUTPUT", "!", "-d", localIP(ipv6), "-j", "ACCEPT"}, + {"-P", "OUTPUT", "DROP"}, + } + if err := filterTableRules(ipv6, rules); err != nil { + return err + } + + return sendUDPLoop(ctx, ip, acceptPort) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputInvertDestination) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return listenUDP(ctx, acceptPort) +} + +// FilterOutputInterfaceAccept tests that packets are sent via interface +// matching the iptables rule. +type FilterOutputInterfaceAccept struct{ localCase } + +// Name implements TestCase.Name. +func (FilterOutputInterfaceAccept) Name() string { + return "FilterOutputInterfaceAccept" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputInterfaceAccept) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + ifname, ok := getInterfaceName() + if !ok { + return fmt.Errorf("no interface is present, except loopback") + } + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-o", ifname, "-j", "ACCEPT"); err != nil { + return err + } + + return sendUDPLoop(ctx, ip, acceptPort) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputInterfaceAccept) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return listenUDP(ctx, acceptPort) +} + +// FilterOutputInterfaceDrop tests that packets are not sent via interface +// matching the iptables rule. +type FilterOutputInterfaceDrop struct{ localCase } + +// Name implements TestCase.Name. +func (FilterOutputInterfaceDrop) Name() string { + return "FilterOutputInterfaceDrop" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputInterfaceDrop) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + ifname, ok := getInterfaceName() + if !ok { + return fmt.Errorf("no interface is present, except loopback") + } + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-o", ifname, "-j", "DROP"); err != nil { + return err + } + + return sendUDPLoop(ctx, ip, acceptPort) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputInterfaceDrop) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, acceptPort); err == nil { + return fmt.Errorf("packets should not be received on port %v, but are received", acceptPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) + } + + return nil +} + +// FilterOutputInterface tests that packets are sent via interface which is +// not matching the interface name in the iptables rule. +type FilterOutputInterface struct{ localCase } + +// Name implements TestCase.Name. +func (FilterOutputInterface) Name() string { + return "FilterOutputInterface" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputInterface) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-o", "lo", "-j", "DROP"); err != nil { + return err + } + + return sendUDPLoop(ctx, ip, acceptPort) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputInterface) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return listenUDP(ctx, acceptPort) +} + +// FilterOutputInterfaceBeginsWith tests that packets are not sent via an +// interface which begins with the given interface name. +type FilterOutputInterfaceBeginsWith struct{ localCase } + +// Name implements TestCase.Name. +func (FilterOutputInterfaceBeginsWith) Name() string { + return "FilterOutputInterfaceBeginsWith" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputInterfaceBeginsWith) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-o", "e+", "-j", "DROP"); err != nil { + return err + } + + return sendUDPLoop(ctx, ip, acceptPort) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputInterfaceBeginsWith) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, acceptPort); err == nil { + return fmt.Errorf("packets should not be received on port %v, but are received", acceptPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) + } + + return nil +} + +// FilterOutputInterfaceInvertDrop tests that we selectively do not send +// packets via interface not matching the interface name. +type FilterOutputInterfaceInvertDrop struct{ baseCase } + +// Name implements TestCase.Name. +func (FilterOutputInterfaceInvertDrop) Name() string { + return "FilterOutputInterfaceInvertDrop" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputInterfaceInvertDrop) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "!", "-o", "lo", "-j", "DROP"); err != nil { + return err + } + + // Listen for TCP packets on accept port. + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenTCP(timedCtx, acceptPort); err == nil { + return fmt.Errorf("connection on port %d should not be accepted, but got accepted", acceptPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputInterfaceInvertDrop) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, acceptPort); err == nil { + return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", acceptPort) + } + + return nil +} + +// FilterOutputInterfaceInvertAccept tests that we can selectively send packets +// not matching the specific outgoing interface. +type FilterOutputInterfaceInvertAccept struct{ baseCase } + +// Name implements TestCase.Name. +func (FilterOutputInterfaceInvertAccept) Name() string { + return "FilterOutputInterfaceInvertAccept" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputInterfaceInvertAccept) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "!", "-o", "lo", "-j", "ACCEPT"); err != nil { + return err + } + + // Listen for TCP packets on accept port. + return listenTCP(ctx, acceptPort) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputInterfaceInvertAccept) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return connectTCP(ctx, ip, acceptPort) +} diff --git a/test/iptables/iptables.go b/test/iptables/iptables.go new file mode 100644 index 000000000..c2a03f54c --- /dev/null +++ b/test/iptables/iptables.go @@ -0,0 +1,115 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package iptables contains a set of iptables tests implemented as TestCases +package iptables + +import ( + "context" + "fmt" + "net" + "time" +) + +// IPExchangePort is the port the container listens on to receive the IP +// address of the local process. +const IPExchangePort = 2349 + +// TerminalStatement is the last statement in the test runner. +const TerminalStatement = "Finished!" + +// TestTimeout is the timeout used for all tests. +const TestTimeout = 10 * time.Second + +// NegativeTimeout is the time tests should wait to establish the negative +// case, i.e. that connections are not made. +const NegativeTimeout = 2 * time.Second + +// A TestCase contains one action to run in the container and one to run +// locally. The actions run concurrently and each must succeed for the test +// pass. +type TestCase interface { + // Name returns the name of the test. + Name() string + + // ContainerAction runs inside the container. It receives the IP of the + // local process. + ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error + + // LocalAction runs locally. It receives the IP of the container. + LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error + + // ContainerSufficient indicates whether ContainerAction's return value + // alone indicates whether the test succeeded. + ContainerSufficient() bool + + // LocalSufficient indicates whether LocalAction's return value alone + // indicates whether the test succeeded. + LocalSufficient() bool +} + +// baseCase provides defaults for ContainerSufficient and LocalSufficient when +// both actions are required to finish. +type baseCase struct{} + +// ContainerSufficient implements TestCase.ContainerSufficient. +func (baseCase) ContainerSufficient() bool { + return false +} + +// LocalSufficient implements TestCase.LocalSufficient. +func (baseCase) LocalSufficient() bool { + return false +} + +// localCase provides defaults for ContainerSufficient and LocalSufficient when +// only the local action is required to finish. +type localCase struct{} + +// ContainerSufficient implements TestCase.ContainerSufficient. +func (localCase) ContainerSufficient() bool { + return false +} + +// LocalSufficient implements TestCase.LocalSufficient. +func (localCase) LocalSufficient() bool { + return true +} + +// containerCase provides defaults for ContainerSufficient and LocalSufficient +// when only the container action is required to finish. +type containerCase struct{} + +// ContainerSufficient implements TestCase.ContainerSufficient. +func (containerCase) ContainerSufficient() bool { + return true +} + +// LocalSufficient implements TestCase.LocalSufficient. +func (containerCase) LocalSufficient() bool { + return false +} + +// Tests maps test names to TestCase. +// +// New TestCases are added by calling RegisterTestCase in an init function. +var Tests = map[string]TestCase{} + +// RegisterTestCase registers tc so it can be run. +func RegisterTestCase(tc TestCase) { + if _, ok := Tests[tc.Name()]; ok { + panic(fmt.Sprintf("TestCase %s already registered.", tc.Name())) + } + Tests[tc.Name()] = tc +} diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go new file mode 100644 index 000000000..e2beb30d5 --- /dev/null +++ b/test/iptables/iptables_test.go @@ -0,0 +1,427 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package iptables + +import ( + "context" + "errors" + "fmt" + "net" + "reflect" + "sync" + "testing" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/pkg/test/testutil" +) + +// singleTest runs a TestCase. Each test follows a pattern: +// - Create a container. +// - Get the container's IP. +// - Send the container our IP. +// - Start a new goroutine running the local action of the test. +// - Wait for both the container and local actions to finish. +// +// Container output is logged to $TEST_UNDECLARED_OUTPUTS_DIR if it exists, or +// to stderr. +func singleTest(t *testing.T, test TestCase) { + for _, tc := range []bool{false, true} { + subtest := "IPv4" + if tc { + subtest = "IPv6" + } + t.Run(subtest, func(t *testing.T) { + iptablesTest(t, test, tc) + }) + } +} + +func iptablesTest(t *testing.T, test TestCase, ipv6 bool) { + if _, ok := Tests[test.Name()]; !ok { + t.Fatalf("no test found with name %q. Has it been registered?", test.Name()) + } + + // Wait for the local and container goroutines to finish. + var wg sync.WaitGroup + defer wg.Wait() + + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + d := dockerutil.MakeContainer(ctx, t) + defer func() { + if logs, err := d.Logs(context.Background()); err != nil { + t.Logf("Failed to retrieve container logs.") + } else { + t.Logf("=== Container logs: ===\n%s", logs) + } + // Use a new context, as cleanup should run even when we + // timeout. + d.CleanUp(context.Background()) + }() + + // TODO(gvisor.dev/issue/170): Skipping IPv6 gVisor tests. + if ipv6 && dockerutil.Runtime() != "runc" { + t.Skip("gVisor ip6tables not yet implemented") + } + + // Create and start the container. + opts := dockerutil.RunOpts{ + Image: "iptables", + CapAdd: []string{"NET_ADMIN"}, + } + d.CopyFiles(&opts, "/runner", "test/iptables/runner/runner") + args := []string{"/runner/runner", "-name", test.Name()} + if ipv6 { + args = append(args, "-ipv6") + } + if err := d.Spawn(ctx, opts, args...); err != nil { + t.Fatalf("docker run failed: %v", err) + } + + // Get the container IP. + ip, err := d.FindIP(ctx, ipv6) + if err != nil { + t.Fatalf("failed to get container IP: %v", err) + } + + // Give the container our IP. + if err := sendIP(ip); err != nil { + t.Fatalf("failed to send IP to container: %v", err) + } + + // Run our side of the test. + errCh := make(chan error, 2) + wg.Add(1) + go func() { + defer wg.Done() + if err := test.LocalAction(ctx, ip, ipv6); err != nil && !errors.Is(err, context.Canceled) { + errCh <- fmt.Errorf("LocalAction failed: %v", err) + } else { + errCh <- nil + } + if test.LocalSufficient() { + errCh <- nil + } + }() + + // Run the container side. + wg.Add(1) + go func() { + defer wg.Done() + // Wait for the final statement. This structure has the side + // effect that all container logs will appear within the + // individual test context. + if _, err := d.WaitForOutput(ctx, TerminalStatement, TestTimeout); err != nil && !errors.Is(err, context.Canceled) { + errCh <- fmt.Errorf("ContainerAction failed: %v", err) + } else { + errCh <- nil + } + if test.ContainerSufficient() { + errCh <- nil + } + }() + + for i := 0; i < 2; i++ { + select { + case err := <-errCh: + if err != nil { + t.Fatal(err) + } + } + } +} + +func sendIP(ip net.IP) error { + contAddr := net.TCPAddr{ + IP: ip, + Port: IPExchangePort, + } + var conn *net.TCPConn + // The container may not be listening when we first connect, so retry + // upon error. + cb := func() error { + c, err := net.DialTCP("tcp", nil, &contAddr) + conn = c + return err + } + if err := testutil.Poll(cb, TestTimeout); err != nil { + return fmt.Errorf("timed out waiting to send IP, most recent error: %v", err) + } + if _, err := conn.Write([]byte{0}); err != nil { + return fmt.Errorf("error writing to container: %v", err) + } + return nil +} + +func TestFilterInputDropUDP(t *testing.T) { + singleTest(t, FilterInputDropUDP{}) +} + +func TestFilterInputDropUDPPort(t *testing.T) { + singleTest(t, FilterInputDropUDPPort{}) +} + +func TestFilterInputDropDifferentUDPPort(t *testing.T) { + singleTest(t, FilterInputDropDifferentUDPPort{}) +} + +func TestFilterInputDropAll(t *testing.T) { + singleTest(t, FilterInputDropAll{}) +} + +func TestFilterInputDropOnlyUDP(t *testing.T) { + singleTest(t, FilterInputDropOnlyUDP{}) +} + +func TestFilterInputDropTCPDestPort(t *testing.T) { + singleTest(t, FilterInputDropTCPDestPort{}) +} + +func TestFilterInputDropTCPSrcPort(t *testing.T) { + singleTest(t, FilterInputDropTCPSrcPort{}) +} + +func TestFilterInputCreateUserChain(t *testing.T) { + singleTest(t, FilterInputCreateUserChain{}) +} + +func TestFilterInputDefaultPolicyAccept(t *testing.T) { + singleTest(t, FilterInputDefaultPolicyAccept{}) +} + +func TestFilterInputDefaultPolicyDrop(t *testing.T) { + singleTest(t, FilterInputDefaultPolicyDrop{}) +} + +func TestFilterInputReturnUnderflow(t *testing.T) { + singleTest(t, FilterInputReturnUnderflow{}) +} + +func TestFilterOutputDropTCPDestPort(t *testing.T) { + singleTest(t, FilterOutputDropTCPDestPort{}) +} + +func TestFilterOutputDropTCPSrcPort(t *testing.T) { + singleTest(t, FilterOutputDropTCPSrcPort{}) +} + +func TestFilterOutputAcceptTCPOwner(t *testing.T) { + singleTest(t, FilterOutputAcceptTCPOwner{}) +} + +func TestFilterOutputDropTCPOwner(t *testing.T) { + singleTest(t, FilterOutputDropTCPOwner{}) +} + +func TestFilterOutputAcceptUDPOwner(t *testing.T) { + singleTest(t, FilterOutputAcceptUDPOwner{}) +} + +func TestFilterOutputDropUDPOwner(t *testing.T) { + singleTest(t, FilterOutputDropUDPOwner{}) +} + +func TestFilterOutputOwnerFail(t *testing.T) { + singleTest(t, FilterOutputOwnerFail{}) +} + +func TestFilterOutputAcceptGIDOwner(t *testing.T) { + singleTest(t, FilterOutputAcceptGIDOwner{}) +} + +func TestFilterOutputDropGIDOwner(t *testing.T) { + singleTest(t, FilterOutputDropGIDOwner{}) +} + +func TestFilterOutputInvertGIDOwner(t *testing.T) { + singleTest(t, FilterOutputInvertGIDOwner{}) +} + +func TestFilterOutputInvertUIDOwner(t *testing.T) { + singleTest(t, FilterOutputInvertUIDOwner{}) +} + +func TestFilterOutputInvertUIDAndGIDOwner(t *testing.T) { + singleTest(t, FilterOutputInvertUIDAndGIDOwner{}) +} + +func TestFilterOutputInterfaceAccept(t *testing.T) { + singleTest(t, FilterOutputInterfaceAccept{}) +} + +func TestFilterOutputInterfaceDrop(t *testing.T) { + singleTest(t, FilterOutputInterfaceDrop{}) +} + +func TestFilterOutputInterface(t *testing.T) { + singleTest(t, FilterOutputInterface{}) +} + +func TestFilterOutputInterfaceBeginsWith(t *testing.T) { + singleTest(t, FilterOutputInterfaceBeginsWith{}) +} + +func TestFilterOutputInterfaceInvertDrop(t *testing.T) { + singleTest(t, FilterOutputInterfaceInvertDrop{}) +} + +func TestFilterOutputInterfaceInvertAccept(t *testing.T) { + singleTest(t, FilterOutputInterfaceInvertAccept{}) +} + +func TestJumpSerialize(t *testing.T) { + singleTest(t, FilterInputSerializeJump{}) +} + +func TestJumpBasic(t *testing.T) { + singleTest(t, FilterInputJumpBasic{}) +} + +func TestJumpReturn(t *testing.T) { + singleTest(t, FilterInputJumpReturn{}) +} + +func TestJumpReturnDrop(t *testing.T) { + singleTest(t, FilterInputJumpReturnDrop{}) +} + +func TestJumpBuiltin(t *testing.T) { + singleTest(t, FilterInputJumpBuiltin{}) +} + +func TestJumpTwice(t *testing.T) { + singleTest(t, FilterInputJumpTwice{}) +} + +func TestInputDestination(t *testing.T) { + singleTest(t, FilterInputDestination{}) +} + +func TestInputInvertDestination(t *testing.T) { + singleTest(t, FilterInputInvertDestination{}) +} + +func TestOutputDestination(t *testing.T) { + singleTest(t, FilterOutputDestination{}) +} + +func TestOutputInvertDestination(t *testing.T) { + singleTest(t, FilterOutputInvertDestination{}) +} + +func TestNATPreRedirectUDPPort(t *testing.T) { + singleTest(t, NATPreRedirectUDPPort{}) +} + +func TestNATPreRedirectTCPPort(t *testing.T) { + singleTest(t, NATPreRedirectTCPPort{}) +} + +func TestNATPreRedirectTCPOutgoing(t *testing.T) { + singleTest(t, NATPreRedirectTCPOutgoing{}) +} + +func TestNATOutRedirectTCPIncoming(t *testing.T) { + singleTest(t, NATOutRedirectTCPIncoming{}) +} +func TestNATOutRedirectUDPPort(t *testing.T) { + singleTest(t, NATOutRedirectUDPPort{}) +} + +func TestNATOutRedirectTCPPort(t *testing.T) { + singleTest(t, NATOutRedirectTCPPort{}) +} + +func TestNATDropUDP(t *testing.T) { + singleTest(t, NATDropUDP{}) +} + +func TestNATAcceptAll(t *testing.T) { + singleTest(t, NATAcceptAll{}) +} + +func TestNATOutRedirectIP(t *testing.T) { + singleTest(t, NATOutRedirectIP{}) +} + +func TestNATOutDontRedirectIP(t *testing.T) { + singleTest(t, NATOutDontRedirectIP{}) +} + +func TestNATOutRedirectInvert(t *testing.T) { + singleTest(t, NATOutRedirectInvert{}) +} + +func TestNATPreRedirectIP(t *testing.T) { + singleTest(t, NATPreRedirectIP{}) +} + +func TestNATPreDontRedirectIP(t *testing.T) { + singleTest(t, NATPreDontRedirectIP{}) +} + +func TestNATPreRedirectInvert(t *testing.T) { + singleTest(t, NATPreRedirectInvert{}) +} + +func TestNATRedirectRequiresProtocol(t *testing.T) { + singleTest(t, NATRedirectRequiresProtocol{}) +} + +func TestNATLoopbackSkipsPrerouting(t *testing.T) { + singleTest(t, NATLoopbackSkipsPrerouting{}) +} + +func TestInputSource(t *testing.T) { + singleTest(t, FilterInputSource{}) +} + +func TestInputInvertSource(t *testing.T) { + singleTest(t, FilterInputInvertSource{}) +} + +func TestFilterAddrs(t *testing.T) { + tcs := []struct { + ipv6 bool + addrs []string + want []string + }{ + { + ipv6: false, + addrs: []string{"192.168.0.1", "192.168.0.2/24", "::1", "::2/128"}, + want: []string{"192.168.0.1", "192.168.0.2"}, + }, + { + ipv6: true, + addrs: []string{"192.168.0.1", "192.168.0.2/24", "::1", "::2/128"}, + want: []string{"::1", "::2"}, + }, + } + + for _, tc := range tcs { + if got := filterAddrs(tc.addrs, tc.ipv6); !reflect.DeepEqual(got, tc.want) { + t.Errorf("%v with IPv6 %t: got %v, but wanted %v", tc.addrs, tc.ipv6, got, tc.want) + } + } +} + +func TestNATPreOriginalDst(t *testing.T) { + singleTest(t, NATPreOriginalDst{}) +} + +func TestNATOutOriginalDst(t *testing.T) { + singleTest(t, NATOutOriginalDst{}) +} diff --git a/test/iptables/iptables_unsafe.go b/test/iptables/iptables_unsafe.go new file mode 100644 index 000000000..bd85a8fea --- /dev/null +++ b/test/iptables/iptables_unsafe.go @@ -0,0 +1,63 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package iptables + +import ( + "fmt" + "syscall" + "unsafe" +) + +type originalDstError struct { + errno syscall.Errno +} + +func (e originalDstError) Error() string { + return fmt.Sprintf("errno (%d) when calling getsockopt(SO_ORIGINAL_DST): %v", int(e.errno), e.errno.Error()) +} + +// SO_ORIGINAL_DST gets the original destination of a redirected packet via +// getsockopt. +const SO_ORIGINAL_DST = 80 + +func originalDestination4(connfd int) (syscall.RawSockaddrInet4, error) { + var addr syscall.RawSockaddrInet4 + var addrLen uint32 = syscall.SizeofSockaddrInet4 + if errno := originalDestination(connfd, syscall.SOL_IP, unsafe.Pointer(&addr), &addrLen); errno != 0 { + return syscall.RawSockaddrInet4{}, originalDstError{errno} + } + return addr, nil +} + +func originalDestination6(connfd int) (syscall.RawSockaddrInet6, error) { + var addr syscall.RawSockaddrInet6 + var addrLen uint32 = syscall.SizeofSockaddrInet6 + if errno := originalDestination(connfd, syscall.SOL_IPV6, unsafe.Pointer(&addr), &addrLen); errno != 0 { + return syscall.RawSockaddrInet6{}, originalDstError{errno} + } + return addr, nil +} + +func originalDestination(connfd int, level uintptr, optval unsafe.Pointer, optlen *uint32) syscall.Errno { + _, _, errno := syscall.Syscall6( + syscall.SYS_GETSOCKOPT, + uintptr(connfd), + level, + SO_ORIGINAL_DST, + uintptr(optval), + uintptr(unsafe.Pointer(optlen)), + 0) + return errno +} diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go new file mode 100644 index 000000000..a6ec5cca3 --- /dev/null +++ b/test/iptables/iptables_util.go @@ -0,0 +1,282 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package iptables + +import ( + "context" + "encoding/binary" + "errors" + "fmt" + "net" + "os/exec" + "strings" + "time" + + "gvisor.dev/gvisor/pkg/test/testutil" +) + +// filterTable calls `ip{6}tables -t filter` with the given args. +func filterTable(ipv6 bool, args ...string) error { + return tableCmd(ipv6, "filter", args) +} + +// natTable calls `ip{6}tables -t nat` with the given args. +func natTable(ipv6 bool, args ...string) error { + return tableCmd(ipv6, "nat", args) +} + +func tableCmd(ipv6 bool, table string, args []string) error { + args = append([]string{"-t", table}, args...) + binary := "iptables" + if ipv6 { + binary = "ip6tables" + } + cmd := exec.Command(binary, args...) + if out, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("error running iptables with args %v\nerror: %v\noutput: %s", args, err, string(out)) + } + return nil +} + +// filterTableRules is like filterTable, but runs multiple iptables commands. +func filterTableRules(ipv6 bool, argsList [][]string) error { + return tableRules(ipv6, "filter", argsList) +} + +// natTableRules is like natTable, but runs multiple iptables commands. +func natTableRules(ipv6 bool, argsList [][]string) error { + return tableRules(ipv6, "nat", argsList) +} + +func tableRules(ipv6 bool, table string, argsList [][]string) error { + for _, args := range argsList { + if err := tableCmd(ipv6, table, args); err != nil { + return err + } + } + return nil +} + +// listenUDP listens on a UDP port and returns the value of net.Conn.Read() for +// the first read on that port. +func listenUDP(ctx context.Context, port int) error { + localAddr := net.UDPAddr{ + Port: port, + } + conn, err := net.ListenUDP("udp", &localAddr) + if err != nil { + return err + } + defer conn.Close() + + ch := make(chan error) + go func() { + _, err = conn.Read([]byte{0}) + ch <- err + }() + + select { + case err := <-ch: + return err + case <-ctx.Done(): + return ctx.Err() + } +} + +// sendUDPLoop sends 1 byte UDP packets repeatedly to the IP and port specified +// over a duration. +func sendUDPLoop(ctx context.Context, ip net.IP, port int) error { + remote := net.UDPAddr{ + IP: ip, + Port: port, + } + conn, err := net.DialUDP("udp", nil, &remote) + if err != nil { + return err + } + defer conn.Close() + + for { + // This may return an error (connection refused) if the remote + // hasn't started listening yet or they're dropping our + // packets. So we ignore Write errors and depend on the remote + // to report a failure if it doesn't get a packet it needs. + conn.Write([]byte{0}) + select { + case <-ctx.Done(): + // Being cancelled or timing out isn't an error, as we + // cannot tell with UDP whether we succeeded. + return nil + // Continue looping. + case <-time.After(200 * time.Millisecond): + } + } +} + +// listenTCP listens for connections on a TCP port. +func listenTCP(ctx context.Context, port int) error { + localAddr := net.TCPAddr{ + Port: port, + } + + // Starts listening on port. + lConn, err := net.ListenTCP("tcp", &localAddr) + if err != nil { + return err + } + defer lConn.Close() + + // Accept connections on port. + ch := make(chan error) + go func() { + conn, err := lConn.AcceptTCP() + ch <- err + conn.Close() + }() + + select { + case err := <-ch: + return err + case <-ctx.Done(): + return fmt.Errorf("timed out waiting for a connection at %#v: %w", localAddr, ctx.Err()) + } +} + +// connectTCP connects to the given IP and port from an ephemeral local address. +func connectTCP(ctx context.Context, ip net.IP, port int) error { + contAddr := net.TCPAddr{ + IP: ip, + Port: port, + } + // The container may not be listening when we first connect, so retry + // upon error. + callback := func() error { + var d net.Dialer + conn, err := d.DialContext(ctx, "tcp", contAddr.String()) + if conn != nil { + conn.Close() + } + return err + } + if err := testutil.PollContext(ctx, callback); err != nil { + return fmt.Errorf("timed out waiting to connect IP on port %v, most recent error: %v", port, err) + } + + return nil +} + +// localAddrs returns a list of local network interface addresses. When ipv6 is +// true, only IPv6 addresses are returned. Otherwise only IPv4 addresses are +// returned. +func localAddrs(ipv6 bool) ([]string, error) { + addrs, err := net.InterfaceAddrs() + if err != nil { + return nil, err + } + addrStrs := make([]string, 0, len(addrs)) + for _, addr := range addrs { + // Add only IPv4 or only IPv6 addresses. + parts := strings.Split(addr.String(), "/") + if len(parts) != 2 { + return nil, fmt.Errorf("bad interface address: %q", addr.String()) + } + if isIPv6 := net.ParseIP(parts[0]).To4() == nil; isIPv6 == ipv6 { + addrStrs = append(addrStrs, addr.String()) + } + } + return filterAddrs(addrStrs, ipv6), nil +} + +func filterAddrs(addrs []string, ipv6 bool) []string { + addrStrs := make([]string, 0, len(addrs)) + for _, addr := range addrs { + // Add only IPv4 or only IPv6 addresses. + parts := strings.Split(addr, "/") + if isIPv6 := net.ParseIP(parts[0]).To4() == nil; isIPv6 == ipv6 { + addrStrs = append(addrStrs, parts[0]) + } + } + return addrStrs +} + +// getInterfaceName returns the name of the interface other than loopback. +func getInterfaceName() (string, bool) { + iface, ok := getNonLoopbackInterface() + if !ok { + return "", false + } + return iface.Name, true +} + +func getInterfaceAddrs(ipv6 bool) ([]net.IP, error) { + iface, ok := getNonLoopbackInterface() + if !ok { + return nil, errors.New("no non-loopback interface found") + } + addrs, err := iface.Addrs() + if err != nil { + return nil, err + } + + // Get only IPv4 or IPv6 addresses. + ips := make([]net.IP, 0, len(addrs)) + for _, addr := range addrs { + parts := strings.Split(addr.String(), "/") + var ip net.IP + // To16() returns IPv4 addresses as IPv4-mapped IPv6 addresses. + // So we check whether To4() returns nil to test whether the + // address is v4 or v6. + if v4 := net.ParseIP(parts[0]).To4(); ipv6 && v4 == nil { + ip = net.ParseIP(parts[0]).To16() + } else { + ip = v4 + } + if ip != nil { + ips = append(ips, ip) + } + } + return ips, nil +} + +func getNonLoopbackInterface() (net.Interface, bool) { + if interfaces, err := net.Interfaces(); err == nil { + for _, intf := range interfaces { + if intf.Name != "lo" { + return intf, true + } + } + } + return net.Interface{}, false +} + +func htons(x uint16) uint16 { + buf := make([]byte, 2) + binary.BigEndian.PutUint16(buf, x) + return binary.LittleEndian.Uint16(buf) +} + +func localIP(ipv6 bool) string { + if ipv6 { + return "::1" + } + return "127.0.0.1" +} + +func nowhereIP(ipv6 bool) string { + if ipv6 { + return "2001:db8::1" + } + return "192.0.2.1" +} diff --git a/test/iptables/nat.go b/test/iptables/nat.go new file mode 100644 index 000000000..dd9a18339 --- /dev/null +++ b/test/iptables/nat.go @@ -0,0 +1,657 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package iptables + +import ( + "context" + "errors" + "fmt" + "net" + "syscall" +) + +const redirectPort = 42 + +func init() { + RegisterTestCase(NATPreRedirectUDPPort{}) + RegisterTestCase(NATPreRedirectTCPPort{}) + RegisterTestCase(NATPreRedirectTCPOutgoing{}) + RegisterTestCase(NATOutRedirectTCPIncoming{}) + RegisterTestCase(NATOutRedirectUDPPort{}) + RegisterTestCase(NATOutRedirectTCPPort{}) + RegisterTestCase(NATDropUDP{}) + RegisterTestCase(NATAcceptAll{}) + RegisterTestCase(NATPreRedirectIP{}) + RegisterTestCase(NATPreDontRedirectIP{}) + RegisterTestCase(NATPreRedirectInvert{}) + RegisterTestCase(NATOutRedirectIP{}) + RegisterTestCase(NATOutDontRedirectIP{}) + RegisterTestCase(NATOutRedirectInvert{}) + RegisterTestCase(NATRedirectRequiresProtocol{}) + RegisterTestCase(NATLoopbackSkipsPrerouting{}) + RegisterTestCase(NATPreOriginalDst{}) + RegisterTestCase(NATOutOriginalDst{}) +} + +// NATPreRedirectUDPPort tests that packets are redirected to different port. +type NATPreRedirectUDPPort struct{ containerCase } + +// Name implements TestCase.Name. +func (NATPreRedirectUDPPort) Name() string { + return "NATPreRedirectUDPPort" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATPreRedirectUDPPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil { + return err + } + + if err := listenUDP(ctx, redirectPort); err != nil { + return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %v", redirectPort, err) + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (NATPreRedirectUDPPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) +} + +// NATPreRedirectTCPPort tests that connections are redirected on specified ports. +type NATPreRedirectTCPPort struct{ baseCase } + +// Name implements TestCase.Name. +func (NATPreRedirectTCPPort) Name() string { + return "NATPreRedirectTCPPort" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATPreRedirectTCPPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := natTable(ipv6, "-A", "PREROUTING", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil { + return err + } + + // Listen for TCP packets on redirect port. + return listenTCP(ctx, acceptPort) +} + +// LocalAction implements TestCase.LocalAction. +func (NATPreRedirectTCPPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return connectTCP(ctx, ip, dropPort) +} + +// NATPreRedirectTCPOutgoing verifies that outgoing TCP connections aren't +// affected by PREROUTING connection tracking. +type NATPreRedirectTCPOutgoing struct{ baseCase } + +// Name implements TestCase.Name. +func (NATPreRedirectTCPOutgoing) Name() string { + return "NATPreRedirectTCPOutgoing" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATPreRedirectTCPOutgoing) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // Redirect all incoming TCP traffic to a closed port. + if err := natTable(ipv6, "-A", "PREROUTING", "-p", "tcp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", dropPort)); err != nil { + return err + } + + // Establish a connection to the host process. + return connectTCP(ctx, ip, acceptPort) +} + +// LocalAction implements TestCase.LocalAction. +func (NATPreRedirectTCPOutgoing) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return listenTCP(ctx, acceptPort) +} + +// NATOutRedirectTCPIncoming verifies that incoming TCP connections aren't +// affected by OUTPUT connection tracking. +type NATOutRedirectTCPIncoming struct{ baseCase } + +// Name implements TestCase.Name. +func (NATOutRedirectTCPIncoming) Name() string { + return "NATOutRedirectTCPIncoming" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATOutRedirectTCPIncoming) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // Redirect all outgoing TCP traffic to a closed port. + if err := natTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", dropPort)); err != nil { + return err + } + + // Establish a connection to the host process. + return listenTCP(ctx, acceptPort) +} + +// LocalAction implements TestCase.LocalAction. +func (NATOutRedirectTCPIncoming) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return connectTCP(ctx, ip, acceptPort) +} + +// NATOutRedirectUDPPort tests that packets are redirected to different port. +type NATOutRedirectUDPPort struct{ containerCase } + +// Name implements TestCase.Name. +func (NATOutRedirectUDPPort) Name() string { + return "NATOutRedirectUDPPort" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATOutRedirectUDPPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return loopbackTest(ctx, ipv6, net.ParseIP(nowhereIP(ipv6)), "-A", "OUTPUT", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)) +} + +// LocalAction implements TestCase.LocalAction. +func (NATOutRedirectUDPPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // No-op. + return nil +} + +// NATDropUDP tests that packets are not received in ports other than redirect +// port. +type NATDropUDP struct{ containerCase } + +// Name implements TestCase.Name. +func (NATDropUDP) Name() string { + return "NATDropUDP" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATDropUDP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil { + return err + } + + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, acceptPort); err == nil { + return fmt.Errorf("packets on port %d should have been redirected to port %d", acceptPort, redirectPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (NATDropUDP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) +} + +// NATAcceptAll tests that all UDP packets are accepted. +type NATAcceptAll struct{ containerCase } + +// Name implements TestCase.Name. +func (NATAcceptAll) Name() string { + return "NATAcceptAll" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATAcceptAll) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "-j", "ACCEPT"); err != nil { + return err + } + + if err := listenUDP(ctx, acceptPort); err != nil { + return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %v", acceptPort, err) + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (NATAcceptAll) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) +} + +// NATOutRedirectIP uses iptables to select packets based on destination IP and +// redirects them. +type NATOutRedirectIP struct{ baseCase } + +// Name implements TestCase.Name. +func (NATOutRedirectIP) Name() string { + return "NATOutRedirectIP" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATOutRedirectIP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // Redirect OUTPUT packets to a listening localhost port. + return loopbackTest(ctx, ipv6, net.ParseIP(nowhereIP(ipv6)), + "-A", "OUTPUT", + "-d", nowhereIP(ipv6), + "-p", "udp", + "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", acceptPort)) +} + +// LocalAction implements TestCase.LocalAction. +func (NATOutRedirectIP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // No-op. + return nil +} + +// NATOutDontRedirectIP tests that iptables matching with "-d" does not match +// packets it shouldn't. +type NATOutDontRedirectIP struct{ localCase } + +// Name implements TestCase.Name. +func (NATOutDontRedirectIP) Name() string { + return "NATOutDontRedirectIP" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATOutDontRedirectIP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := natTable(ipv6, "-A", "OUTPUT", "-d", localIP(ipv6), "-p", "udp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", dropPort)); err != nil { + return err + } + return sendUDPLoop(ctx, ip, acceptPort) +} + +// LocalAction implements TestCase.LocalAction. +func (NATOutDontRedirectIP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return listenUDP(ctx, acceptPort) +} + +// NATOutRedirectInvert tests that iptables can match with "! -d". +type NATOutRedirectInvert struct{ baseCase } + +// Name implements TestCase.Name. +func (NATOutRedirectInvert) Name() string { + return "NATOutRedirectInvert" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATOutRedirectInvert) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // Redirect OUTPUT packets to a listening localhost port. + dest := "192.0.2.2" + if ipv6 { + dest = "2001:db8::2" + } + return loopbackTest(ctx, ipv6, net.ParseIP(nowhereIP(ipv6)), + "-A", "OUTPUT", + "!", "-d", dest, + "-p", "udp", + "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", acceptPort)) +} + +// LocalAction implements TestCase.LocalAction. +func (NATOutRedirectInvert) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // No-op. + return nil +} + +// NATPreRedirectIP tests that we can use iptables to select packets based on +// destination IP and redirect them. +type NATPreRedirectIP struct{ containerCase } + +// Name implements TestCase.Name. +func (NATPreRedirectIP) Name() string { + return "NATPreRedirectIP" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATPreRedirectIP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + addrs, err := localAddrs(ipv6) + if err != nil { + return err + } + + var rules [][]string + for _, addr := range addrs { + rules = append(rules, []string{"-A", "PREROUTING", "-p", "udp", "-d", addr, "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)}) + } + if err := natTableRules(ipv6, rules); err != nil { + return err + } + return listenUDP(ctx, acceptPort) +} + +// LocalAction implements TestCase.LocalAction. +func (NATPreRedirectIP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, dropPort) +} + +// NATPreDontRedirectIP tests that iptables matching with "-d" does not match +// packets it shouldn't. +type NATPreDontRedirectIP struct{ containerCase } + +// Name implements TestCase.Name. +func (NATPreDontRedirectIP) Name() string { + return "NATPreDontRedirectIP" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATPreDontRedirectIP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "-d", localIP(ipv6), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", dropPort)); err != nil { + return err + } + return listenUDP(ctx, acceptPort) +} + +// LocalAction implements TestCase.LocalAction. +func (NATPreDontRedirectIP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) +} + +// NATPreRedirectInvert tests that iptables can match with "! -d". +type NATPreRedirectInvert struct{ containerCase } + +// Name implements TestCase.Name. +func (NATPreRedirectInvert) Name() string { + return "NATPreRedirectInvert" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATPreRedirectInvert) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "!", "-d", localIP(ipv6), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil { + return err + } + return listenUDP(ctx, acceptPort) +} + +// LocalAction implements TestCase.LocalAction. +func (NATPreRedirectInvert) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, dropPort) +} + +// NATRedirectRequiresProtocol tests that use of the --to-ports flag requires a +// protocol to be specified with -p. +type NATRedirectRequiresProtocol struct{ baseCase } + +// Name implements TestCase.Name. +func (NATRedirectRequiresProtocol) Name() string { + return "NATRedirectRequiresProtocol" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATRedirectRequiresProtocol) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := natTable(ipv6, "-A", "PREROUTING", "-d", localIP(ipv6), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err == nil { + return errors.New("expected an error using REDIRECT --to-ports without a protocol") + } + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (NATRedirectRequiresProtocol) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // No-op. + return nil +} + +// NATOutRedirectTCPPort tests that connections are redirected on specified ports. +type NATOutRedirectTCPPort struct{ baseCase } + +// Name implements TestCase.Name. +func (NATOutRedirectTCPPort) Name() string { + return "NATOutRedirectTCPPort" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATOutRedirectTCPPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + if err := natTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil { + return err + } + + localAddr := net.TCPAddr{ + IP: net.ParseIP(localIP(ipv6)), + Port: acceptPort, + } + + // Starts listening on port. + lConn, err := net.ListenTCP("tcp", &localAddr) + if err != nil { + return err + } + defer lConn.Close() + + // Accept connections on port. + if err := connectTCP(ctx, ip, dropPort); err != nil { + return err + } + + conn, err := lConn.AcceptTCP() + if err != nil { + return err + } + conn.Close() + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (NATOutRedirectTCPPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return nil +} + +// NATLoopbackSkipsPrerouting tests that packets sent via loopback aren't +// affected by PREROUTING rules. +type NATLoopbackSkipsPrerouting struct{ baseCase } + +// Name implements TestCase.Name. +func (NATLoopbackSkipsPrerouting) Name() string { + return "NATLoopbackSkipsPrerouting" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATLoopbackSkipsPrerouting) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // Redirect anything sent to localhost to an unused port. + dest := []byte{127, 0, 0, 1} + if err := natTable(ipv6, "-A", "PREROUTING", "-p", "tcp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", dropPort)); err != nil { + return err + } + + // Establish a connection via localhost. If the PREROUTING rule did apply to + // loopback traffic, the connection would fail. + sendCh := make(chan error) + go func() { + sendCh <- connectTCP(ctx, dest, acceptPort) + }() + + if err := listenTCP(ctx, acceptPort); err != nil { + return err + } + return <-sendCh +} + +// LocalAction implements TestCase.LocalAction. +func (NATLoopbackSkipsPrerouting) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // No-op. + return nil +} + +// NATPreOriginalDst tests that SO_ORIGINAL_DST returns the pre-NAT destination +// of PREROUTING NATted packets. +type NATPreOriginalDst struct{ baseCase } + +// Name implements TestCase.Name. +func (NATPreOriginalDst) Name() string { + return "NATPreOriginalDst" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATPreOriginalDst) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // Redirect incoming TCP connections to acceptPort. + if err := natTable(ipv6, "-A", "PREROUTING", + "-p", "tcp", + "--destination-port", fmt.Sprintf("%d", dropPort), + "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", acceptPort)); err != nil { + return err + } + + addrs, err := getInterfaceAddrs(ipv6) + if err != nil { + return err + } + return listenForRedirectedConn(ctx, ipv6, addrs) +} + +// LocalAction implements TestCase.LocalAction. +func (NATPreOriginalDst) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return connectTCP(ctx, ip, dropPort) +} + +// NATOutOriginalDst tests that SO_ORIGINAL_DST returns the pre-NAT destination +// of OUTBOUND NATted packets. +type NATOutOriginalDst struct{ baseCase } + +// Name implements TestCase.Name. +func (NATOutOriginalDst) Name() string { + return "NATOutOriginalDst" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATOutOriginalDst) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // Redirect incoming TCP connections to acceptPort. + if err := natTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", acceptPort)); err != nil { + return err + } + + connCh := make(chan error) + go func() { + connCh <- connectTCP(ctx, ip, dropPort) + }() + + if err := listenForRedirectedConn(ctx, ipv6, []net.IP{ip}); err != nil { + return err + } + return <-connCh +} + +// LocalAction implements TestCase.LocalAction. +func (NATOutOriginalDst) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + // No-op. + return nil +} + +func listenForRedirectedConn(ctx context.Context, ipv6 bool, originalDsts []net.IP) error { + // The net package doesn't give guarantee access to the connection's + // underlying FD, and thus we cannot call getsockopt. We have to use + // traditional syscalls for SO_ORIGINAL_DST. + + // Create the listening socket, bind, listen, and accept. + family := syscall.AF_INET + if ipv6 { + family = syscall.AF_INET6 + } + sockfd, err := syscall.Socket(family, syscall.SOCK_STREAM, 0) + if err != nil { + return err + } + defer syscall.Close(sockfd) + + var bindAddr syscall.Sockaddr + if ipv6 { + bindAddr = &syscall.SockaddrInet6{ + Port: acceptPort, + Addr: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, // in6addr_any + } + } else { + bindAddr = &syscall.SockaddrInet4{ + Port: acceptPort, + Addr: [4]byte{0, 0, 0, 0}, // INADDR_ANY + } + } + if err := syscall.Bind(sockfd, bindAddr); err != nil { + return err + } + + if err := syscall.Listen(sockfd, 1); err != nil { + return err + } + + // Block on accept() in another goroutine. + connCh := make(chan int) + errCh := make(chan error) + go func() { + connFD, _, err := syscall.Accept(sockfd) + if err != nil { + errCh <- err + } + connCh <- connFD + }() + + // Wait for accept() to return or for the context to finish. + var connFD int + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errCh: + return err + case connFD = <-connCh: + } + defer syscall.Close(connFD) + + // Verify that, despite listening on acceptPort, SO_ORIGINAL_DST + // indicates the packet was sent to originalDst:dropPort. + if ipv6 { + got, err := originalDestination6(connFD) + if err != nil { + return err + } + // The original destination could be any of our IPs. + for _, dst := range originalDsts { + want := syscall.RawSockaddrInet6{ + Family: syscall.AF_INET6, + Port: htons(dropPort), + } + copy(want.Addr[:], dst.To16()) + if got == want { + return nil + } + } + return fmt.Errorf("SO_ORIGINAL_DST returned %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, originalDsts) + } else { + got, err := originalDestination4(connFD) + if err != nil { + return err + } + // The original destination could be any of our IPs. + for _, dst := range originalDsts { + want := syscall.RawSockaddrInet4{ + Family: syscall.AF_INET, + Port: htons(dropPort), + } + copy(want.Addr[:], dst.To4()) + if got == want { + return nil + } + } + return fmt.Errorf("SO_ORIGINAL_DST returned %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, originalDsts) + } +} + +// loopbackTests runs an iptables rule and ensures that packets sent to +// dest:dropPort are received by localhost:acceptPort. +func loopbackTest(ctx context.Context, ipv6 bool, dest net.IP, args ...string) error { + if err := natTable(ipv6, args...); err != nil { + return err + } + sendCh := make(chan error, 1) + listenCh := make(chan error, 1) + go func() { + sendCh <- sendUDPLoop(ctx, dest, dropPort) + }() + go func() { + listenCh <- listenUDP(ctx, acceptPort) + }() + select { + case err := <-listenCh: + return err + case err := <-sendCh: + return err + } +} diff --git a/test/iptables/runner/BUILD b/test/iptables/runner/BUILD new file mode 100644 index 000000000..24504a1b9 --- /dev/null +++ b/test/iptables/runner/BUILD @@ -0,0 +1,12 @@ +load("//tools:defs.bzl", "go_binary") + +package(licenses = ["notice"]) + +go_binary( + name = "runner", + testonly = 1, + srcs = ["main.go"], + pure = True, + visibility = ["//test/iptables:__subpackages__"], + deps = ["//test/iptables"], +) diff --git a/test/iptables/runner/main.go b/test/iptables/runner/main.go new file mode 100644 index 000000000..9ae2d1b4d --- /dev/null +++ b/test/iptables/runner/main.go @@ -0,0 +1,79 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package main runs iptables tests from within a docker container. +package main + +import ( + "context" + "flag" + "fmt" + "log" + "net" + + "gvisor.dev/gvisor/test/iptables" +) + +var ( + name = flag.String("name", "", "name of the test to run") + ipv6 = flag.Bool("ipv6", false, "whether the test utilizes ip6tables") +) + +func main() { + flag.Parse() + + // Find out which test we're running. + test, ok := iptables.Tests[*name] + if !ok { + log.Fatalf("No test found named %q", *name) + } + log.Printf("Running test %q", *name) + + // Get the IP of the local process. + ip, err := getIP() + if err != nil { + log.Fatal(err) + } + + // Run the test. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + if err := test.ContainerAction(ctx, ip, *ipv6); err != nil { + log.Fatalf("Failed running test %q: %v", *name, err) + } + + // Emit the final line. + log.Printf("%s", iptables.TerminalStatement) +} + +// getIP listens for a connection from the local process and returns the source +// IP of that connection. +func getIP() (net.IP, error) { + localAddr := net.TCPAddr{ + Port: iptables.IPExchangePort, + } + listener, err := net.ListenTCP("tcp", &localAddr) + if err != nil { + return net.IP{}, fmt.Errorf("failed listening for IP: %v", err) + } + defer listener.Close() + conn, err := listener.AcceptTCP() + if err != nil { + return net.IP{}, fmt.Errorf("failed accepting IP: %v", err) + } + defer conn.Close() + log.Printf("Connected to %v", conn.RemoteAddr()) + + return conn.RemoteAddr().(*net.TCPAddr).IP, nil +} diff --git a/test/packetdrill/BUILD b/test/packetdrill/BUILD new file mode 100644 index 000000000..49642f282 --- /dev/null +++ b/test/packetdrill/BUILD @@ -0,0 +1,45 @@ +load("//tools:defs.bzl", "bzl_library") +load("//test/packetdrill:defs.bzl", "packetdrill_test") + +package(licenses = ["notice"]) + +packetdrill_test( + name = "packetdrill_sanity_test", + scripts = ["sanity_test.pkt"], +) + +packetdrill_test( + name = "accept_ack_drop_test", + scripts = ["accept_ack_drop.pkt"], +) + +packetdrill_test( + name = "fin_wait2_timeout_test", + scripts = ["fin_wait2_timeout.pkt"], +) + +packetdrill_test( + name = "listen_close_before_handshake_complete_test", + scripts = ["listen_close_before_handshake_complete.pkt"], +) + +packetdrill_test( + name = "no_rst_to_rst_test", + scripts = ["no_rst_to_rst.pkt"], +) + +packetdrill_test( + name = "tcp_defer_accept_test", + scripts = ["tcp_defer_accept.pkt"], +) + +packetdrill_test( + name = "tcp_defer_accept_timeout_test", + scripts = ["tcp_defer_accept_timeout.pkt"], +) + +bzl_library( + name = "defs_bzl", + srcs = ["defs.bzl"], + visibility = ["//visibility:private"], +) diff --git a/test/packetdrill/accept_ack_drop.pkt b/test/packetdrill/accept_ack_drop.pkt new file mode 100644 index 000000000..76e638fd4 --- /dev/null +++ b/test/packetdrill/accept_ack_drop.pkt @@ -0,0 +1,27 @@ +// Test that the accept works if the final ACK is dropped and an ack with data +// follows the dropped ack. + +0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3 ++0 bind(3, ..., ...) = 0 + +// Set backlog to 1 so that we can easily test. ++0 listen(3, 1) = 0 + +// Establish a connection without timestamps. ++0.0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7> ++0.0 > S. 0:0(0) ack 1 <...> + ++0.0 < . 1:5(4) ack 1 win 257 ++0.0 > . 1:1(0) ack 5 <...> + +// This should cause connection to transition to connected state. ++0.000 accept(3, ..., ...) = 4 ++0.000 fcntl(4, F_SETFL, O_RDWR|O_NONBLOCK) = 0 + +// Now read the data and we should get 4 bytes. ++0.000 read(4,..., 4) = 4 ++0.000 close(4) = 0 + ++0.0 > F. 1:1(0) ack 5 <...> ++0.0 < F. 5:5(0) ack 2 win 257 ++0.01 > . 2:2(0) ack 6 <...>
\ No newline at end of file diff --git a/test/packetdrill/defs.bzl b/test/packetdrill/defs.bzl new file mode 100644 index 000000000..fc28ce9ba --- /dev/null +++ b/test/packetdrill/defs.bzl @@ -0,0 +1,91 @@ +"""Defines a rule for packetdrill test targets.""" + +def _packetdrill_test_impl(ctx): + test_runner = ctx.executable._test_runner + runner = ctx.actions.declare_file("%s-runner" % ctx.label.name) + + script_paths = [] + for script in ctx.files.scripts: + script_paths.append(script.short_path) + runner_content = "\n".join([ + "#!/bin/bash", + # This test will run part in a distinct user namespace. This can cause + # permission problems, because all runfiles may not be owned by the + # current user, and no other users will be mapped in that namespace. + # Make sure that everything is readable here. + "find . -type f -exec chmod a+rx {} \\;", + "find . -type d -exec chmod a+rx {} \\;", + "%s %s --init_script %s $@ -- %s\n" % ( + test_runner.short_path, + " ".join(ctx.attr.flags), + ctx.files._init_script[0].short_path, + " ".join(script_paths), + ), + ]) + ctx.actions.write(runner, runner_content, is_executable = True) + + transitive_files = depset() + if hasattr(ctx.attr._test_runner, "data_runfiles"): + transitive_files = ctx.attr._test_runner.data_runfiles.files + runfiles = ctx.runfiles( + files = [test_runner] + ctx.files._init_script + ctx.files.scripts, + transitive_files = transitive_files, + collect_default = True, + collect_data = True, + ) + return [DefaultInfo(executable = runner, runfiles = runfiles)] + +_packetdrill_test = rule( + attrs = { + "_test_runner": attr.label( + executable = True, + cfg = "host", + allow_files = True, + default = "packetdrill_test.sh", + ), + "_init_script": attr.label( + allow_single_file = True, + default = "packetdrill_setup.sh", + ), + "flags": attr.string_list( + mandatory = False, + default = [], + ), + "scripts": attr.label_list( + mandatory = True, + allow_files = True, + ), + }, + test = True, + implementation = _packetdrill_test_impl, +) + +PACKETDRILL_TAGS = [ + "local", + "manual", + "packetdrill", +] + +def packetdrill_linux_test(name, **kwargs): + if "tags" not in kwargs: + kwargs["tags"] = PACKETDRILL_TAGS + _packetdrill_test( + name = name, + flags = ["--dut_platform", "linux"], + **kwargs + ) + +def packetdrill_netstack_test(name, **kwargs): + if "tags" not in kwargs: + kwargs["tags"] = PACKETDRILL_TAGS + _packetdrill_test( + name = name, + # This is the default runtime unless + # "--test_arg=--runtime=OTHER_RUNTIME" is used to override the value. + flags = ["--dut_platform", "netstack", "--runtime", "runsc-d"], + **kwargs + ) + +def packetdrill_test(name, **kwargs): + packetdrill_linux_test(name + "_linux_test", **kwargs) + packetdrill_netstack_test(name + "_netstack_test", **kwargs) diff --git a/test/packetdrill/fin_wait2_timeout.pkt b/test/packetdrill/fin_wait2_timeout.pkt new file mode 100644 index 000000000..93ab08575 --- /dev/null +++ b/test/packetdrill/fin_wait2_timeout.pkt @@ -0,0 +1,23 @@ +// Test that a socket in FIN_WAIT_2 eventually times out and a subsequent +// packet generates a RST. + +0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3 ++0 bind(3, ..., ...) = 0 + ++0 listen(3, 1) = 0 + +// Establish a connection without timestamps. ++0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7> ++0 > S. 0:0(0) ack 1 <...> ++0 < P. 1:1(0) ack 1 win 257 + ++0.100 accept(3, ..., ...) = 4 +// set FIN_WAIT2 timeout to 1 seconds. ++0.100 setsockopt(4, SOL_TCP, TCP_LINGER2, [1], 4) = 0 ++0 close(4) = 0 + ++0 > F. 1:1(0) ack 1 <...> ++0 < . 1:1(0) ack 2 win 257 + ++2 < . 1:1(0) ack 2 win 257 ++0 > R 2:2(0) win 0 diff --git a/test/packetdrill/listen_close_before_handshake_complete.pkt b/test/packetdrill/listen_close_before_handshake_complete.pkt new file mode 100644 index 000000000..51c3f1a32 --- /dev/null +++ b/test/packetdrill/listen_close_before_handshake_complete.pkt @@ -0,0 +1,31 @@ +// Test that closing a listening socket closes any connections in SYN-RCVD +// state and any packets bound for these connections generate a RESET. + +0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3 ++0 bind(3, ..., ...) = 0 + +// Set backlog to 1 so that we can easily test. ++0 listen(3, 1) = 0 + +// Establish a connection without timestamps. ++0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7> ++0 > S. 0:0(0) ack 1 <...> + ++0.100 close(3) = 0 ++0.1 < P. 1:1(0) ack 1 win 257 + +// Linux generates a reset with no ack number/bit set. This is contradictory to +// what is specified in Rule 1 under Reset Generation in +// https://tools.ietf.org/html/rfc793#section-3.4. +// "1. If the connection does not exist (CLOSED) then a reset is sent +// in response to any incoming segment except another reset. In +// particular, SYNs addressed to a non-existent connection are rejected +// by this means. +// +// If the incoming segment has an ACK field, the reset takes its +// sequence number from the ACK field of the segment, otherwise the +// reset has sequence number zero and the ACK field is set to the sum +// of the sequence number and segment length of the incoming segment. +// The connection remains in the CLOSED state." + ++0.0 > R 1:1(0) win 0
\ No newline at end of file diff --git a/test/packetdrill/no_rst_to_rst.pkt b/test/packetdrill/no_rst_to_rst.pkt new file mode 100644 index 000000000..612747827 --- /dev/null +++ b/test/packetdrill/no_rst_to_rst.pkt @@ -0,0 +1,36 @@ +// Test a RST is not generated in response to a RST and a RST is correctly +// generated when an accepted endpoint is RST due to an incoming RST. + +0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3 ++0 bind(3, ..., ...) = 0 + ++0 listen(3, 1) = 0 + +// Establish a connection without timestamps. ++0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7> ++0 > S. 0:0(0) ack 1 <...> ++0 < P. 1:1(0) ack 1 win 257 + ++0.100 accept(3, ..., ...) = 4 + ++0.200 < R 1:1(0) win 0 + ++0.300 read(4,..., 4) = -1 ECONNRESET (Connection Reset by Peer) + ++0.00 < . 1:1(0) ack 1 win 257 + +// Linux generates a reset with no ack number/bit set. This is contradictory to +// what is specified in Rule 1 under Reset Generation in +// https://tools.ietf.org/html/rfc793#section-3.4. +// "1. If the connection does not exist (CLOSED) then a reset is sent +// in response to any incoming segment except another reset. In +// particular, SYNs addressed to a non-existent connection are rejected +// by this means. +// +// If the incoming segment has an ACK field, the reset takes its +// sequence number from the ACK field of the segment, otherwise the +// reset has sequence number zero and the ACK field is set to the sum +// of the sequence number and segment length of the incoming segment. +// The connection remains in the CLOSED state." + ++0.00 > R 1:1(0) win 0
\ No newline at end of file diff --git a/test/runtimes/runner.sh b/test/packetdrill/packetdrill_setup.sh index a8d9a3460..b858072f0 100755 --- a/test/runtimes/runner.sh +++ b/test/packetdrill/packetdrill_setup.sh @@ -14,22 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -set -euf -x -o pipefail - -echo -- "$@" - -# Create outputs dir if it does not exist. -if [[ -n "${TEST_UNDECLARED_OUTPUTS_DIR}" ]]; then - mkdir -p "${TEST_UNDECLARED_OUTPUTS_DIR}" - chmod a+rwx "${TEST_UNDECLARED_OUTPUTS_DIR}" -fi - -# Update the timestamp on the shard status file. Bazel looks for this. -touch "${TEST_SHARD_STATUS_FILE}" - -# Get location of runner binary. -readonly runner=$(find "${TEST_SRCDIR}" -name runner) - -# Pass the arguments of this script directly to the runner. -exec "${runner}" "$@" - +# This script runs both within the sentry context and natively. It should tweak +# TCP parameters to match expectations found in the script files. +sysctl -q net.ipv4.tcp_sack=1 +sysctl -q net.ipv4.tcp_rmem="4096 2097152 $((8*1024*1024))" +sysctl -q net.ipv4.tcp_wmem="4096 2097152 $((8*1024*1024))" + +# There may be errors from the above, but they will show up in the test logs and +# we always want to proceed from this point. It's possible that values were +# already set correctly and the nodes were not available in the namespace. +exit 0 diff --git a/test/packetdrill/packetdrill_test.sh b/test/packetdrill/packetdrill_test.sh new file mode 100755 index 000000000..922547d65 --- /dev/null +++ b/test/packetdrill/packetdrill_test.sh @@ -0,0 +1,226 @@ +#!/bin/bash + +# Copyright 2020 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Run a packetdrill test. Two docker containers are made, one for the +# Device-Under-Test (DUT) and one for the test runner. Each is attached with +# two networks, one for control packets that aid the test and one for test +# packets which are sent as part of the test and observed for correctness. + +set -euxo pipefail + +function failure() { + local lineno=$1 + local msg=$2 + local filename="$0" + echo "FAIL: $filename:$lineno: $msg" +} +trap 'failure ${LINENO} "$BASH_COMMAND"' ERR + +declare -r LONGOPTS="dut_platform:,init_script:,runtime:" + +# Don't use declare below so that the error from getopt will end the script. +PARSED=$(getopt --options "" --longoptions=$LONGOPTS --name "$0" -- "$@") + +eval set -- "$PARSED" + +while true; do + case "$1" in + --dut_platform) + # Either "linux" or "netstack". + declare -r DUT_PLATFORM="$2" + shift 2 + ;; + --init_script) + declare -r INIT_SCRIPT="$2" + shift 2 + ;; + --runtime) + # Not readonly because there might be multiple --runtime arguments and we + # want to use just the last one. Only used if --dut_platform is + # "netstack". + declare RUNTIME="$2" + shift 2 + ;; + --) + shift + break + ;; + *) + echo "Programming error" + exit 3 + esac +done + +# All the other arguments are scripts. +declare -r scripts="$@" + +# Check that the required flags are defined in a way that is safe for "set -u". +if [[ "${DUT_PLATFORM-}" == "netstack" ]]; then + if [[ -z "${RUNTIME-}" ]]; then + echo "FAIL: Missing --runtime argument: ${RUNTIME-}" + exit 2 + fi + declare -r RUNTIME_ARG="--runtime ${RUNTIME}" +elif [[ "${DUT_PLATFORM-}" == "linux" ]]; then + declare -r RUNTIME_ARG="" +else + echo "FAIL: Bad or missing --dut_platform argument: ${DUT_PLATFORM-}" + exit 2 +fi +if [[ ! -x "${INIT_SCRIPT-}" ]]; then + echo "FAIL: Bad or missing --init_script: ${INIT_SCRIPT-}" + exit 2 +fi + +function new_net_prefix() { + # Class C, 192.0.0.0 to 223.255.255.255, transitionally has mask 24. + echo "$(shuf -i 192-223 -n 1).$(shuf -i 0-255 -n 1).$(shuf -i 0-255 -n 1)" +} + +# Variables specific to the control network and interface start with CTRL_. +# Variables specific to the test network and interface start with TEST_. +# Variables specific to the DUT start with DUT_. +# Variables specific to the test runner start with TEST_RUNNER_. +declare -r PACKETDRILL="/packetdrill/gtests/net/packetdrill/packetdrill" +# Use random numbers so that test networks don't collide. +declare CTRL_NET="ctrl_net-$(shuf -i 0-99999999 -n 1)" +declare CTRL_NET_PREFIX=$(new_net_prefix) +declare TEST_NET="test_net-$(shuf -i 0-99999999 -n 1)" +declare TEST_NET_PREFIX=$(new_net_prefix) +declare -r tolerance_usecs=100000 +# On both DUT and test runner, testing packets are on the eth2 interface. +declare -r TEST_DEVICE="eth2" +# Number of bits in the *_NET_PREFIX variables. +declare -r NET_MASK="24" +# Last bits of the DUT's IP address. +declare -r DUT_NET_SUFFIX=".10" +# Control port. +declare -r CTRL_PORT="40000" +# Last bits of the test runner's IP address. +declare -r TEST_RUNNER_NET_SUFFIX=".20" +declare -r TIMEOUT="60" +declare -r IMAGE_TAG="gcr.io/gvisor-presubmit/packetdrill" + +# Make sure that docker is installed. +docker --version + +function finish { + local cleanup_success=1 + for net in "${CTRL_NET}" "${TEST_NET}"; do + # Kill all processes attached to ${net}. + for docker_command in "kill" "rm"; do + (docker network inspect "${net}" \ + --format '{{range $key, $value := .Containers}}{{$key}} {{end}}' \ + | xargs -r docker "${docker_command}") || \ + cleanup_success=0 + done + # Remove the network. + docker network rm "${net}" || \ + cleanup_success=0 + done + + if ((!$cleanup_success)); then + echo "FAIL: Cleanup command failed" + exit 4 + fi +} +trap finish EXIT + +# Subnet for control packets between test runner and DUT. +while ! docker network create \ + "--subnet=${CTRL_NET_PREFIX}.0/${NET_MASK}" "${CTRL_NET}"; do + sleep 0.1 + CTRL_NET_PREFIX=$(new_net_prefix) + CTRL_NET="ctrl_net-$(shuf -i 0-99999999 -n 1)" +done + +# Subnet for the packets that are part of the test. +while ! docker network create \ + "--subnet=${TEST_NET_PREFIX}.0/${NET_MASK}" "${TEST_NET}"; do + sleep 0.1 + TEST_NET_PREFIX=$(new_net_prefix) + TEST_NET="test_net-$(shuf -i 0-99999999 -n 1)" +done + +# Create the DUT container and connect to network. +DUT=$(docker create ${RUNTIME_ARG} --privileged --rm \ + --stop-timeout ${TIMEOUT} -it ${IMAGE_TAG}) +docker network connect "${CTRL_NET}" \ + --ip "${CTRL_NET_PREFIX}${DUT_NET_SUFFIX}" "${DUT}" \ + || (docker kill ${DUT}; docker rm ${DUT}; false) +docker network connect "${TEST_NET}" \ + --ip "${TEST_NET_PREFIX}${DUT_NET_SUFFIX}" "${DUT}" \ + || (docker kill ${DUT}; docker rm ${DUT}; false) +docker start "${DUT}" + +# Create the test runner container and connect to network. +TEST_RUNNER=$(docker create --privileged --rm \ + --stop-timeout ${TIMEOUT} -it ${IMAGE_TAG}) +docker network connect "${CTRL_NET}" \ + --ip "${CTRL_NET_PREFIX}${TEST_RUNNER_NET_SUFFIX}" "${TEST_RUNNER}" \ + || (docker kill ${TEST_RUNNER}; docker rm ${REST_RUNNER}; false) +docker network connect "${TEST_NET}" \ + --ip "${TEST_NET_PREFIX}${TEST_RUNNER_NET_SUFFIX}" "${TEST_RUNNER}" \ + || (docker kill ${TEST_RUNNER}; docker rm ${REST_RUNNER}; false) +docker start "${TEST_RUNNER}" + +# Run tcpdump in the test runner unbuffered, without dns resolution, just on the +# interface with the test packets. +docker exec -t ${TEST_RUNNER} tcpdump -U -n -i "${TEST_DEVICE}" & + +# Start a packetdrill server on the test_runner. The packetdrill server sends +# packets and asserts that they are received. +docker exec -d "${TEST_RUNNER}" \ + ${PACKETDRILL} --wire_server --wire_server_dev="${TEST_DEVICE}" \ + --wire_server_ip="${CTRL_NET_PREFIX}${TEST_RUNNER_NET_SUFFIX}" \ + --wire_server_port="${CTRL_PORT}" \ + --local_ip="${TEST_NET_PREFIX}${TEST_RUNNER_NET_SUFFIX}" \ + --remote_ip="${TEST_NET_PREFIX}${DUT_NET_SUFFIX}" + +# Because the Linux kernel receives the SYN-ACK but didn't send the SYN it will +# issue a RST. To prevent this IPtables can be used to filter those out. +docker exec "${TEST_RUNNER}" \ + iptables -A OUTPUT -p tcp --tcp-flags RST RST -j DROP + +# Wait for the packetdrill server on the test runner to come. Attempt to +# connect to it from the DUT every 100 milliseconds until success. +while ! docker exec "${DUT}" \ + nc -zv "${CTRL_NET_PREFIX}${TEST_RUNNER_NET_SUFFIX}" "${CTRL_PORT}"; do + sleep 0.1 +done + +# Copy the packetdrill setup script to the DUT. +docker cp -L "${INIT_SCRIPT}" "${DUT}:packetdrill_setup.sh" + +# Copy the packetdrill scripts to the DUT. +declare -a dut_scripts +for script in $scripts; do + docker cp -L "${script}" "${DUT}:$(basename ${script})" + dut_scripts+=("/$(basename ${script})") +done + +# Start a packetdrill client on the DUT. The packetdrill client runs POSIX +# socket commands and also sends instructions to the server. +docker exec -t "${DUT}" \ + ${PACKETDRILL} --wire_client --wire_client_dev="${TEST_DEVICE}" \ + --wire_server_ip="${CTRL_NET_PREFIX}${TEST_RUNNER_NET_SUFFIX}" \ + --wire_server_port="${CTRL_PORT}" \ + --local_ip="${TEST_NET_PREFIX}${DUT_NET_SUFFIX}" \ + --remote_ip="${TEST_NET_PREFIX}${TEST_RUNNER_NET_SUFFIX}" \ + --init_scripts=/packetdrill_setup.sh \ + --tolerance_usecs="${tolerance_usecs}" "${dut_scripts[@]}" + +echo PASS: No errors. diff --git a/test/packetdrill/reset_for_ack_when_no_syn_cookies_in_use.pkt b/test/packetdrill/reset_for_ack_when_no_syn_cookies_in_use.pkt new file mode 100644 index 000000000..a86b90ce6 --- /dev/null +++ b/test/packetdrill/reset_for_ack_when_no_syn_cookies_in_use.pkt @@ -0,0 +1,9 @@ +// Test that a listening socket generates a RST when it receives an +// ACK and syn cookies are not in use. + +0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3 ++0 bind(3, ..., ...) = 0 + ++0 listen(3, 1) = 0 ++0.1 < . 1:1(0) ack 1 win 32792 ++0 > R 1:1(0) ack 0 win 0
\ No newline at end of file diff --git a/test/packetdrill/sanity_test.pkt b/test/packetdrill/sanity_test.pkt new file mode 100644 index 000000000..b3b58c366 --- /dev/null +++ b/test/packetdrill/sanity_test.pkt @@ -0,0 +1,7 @@ +// Basic sanity test. One system call. +// +// All of the plumbing has to be working however, and the packetdrill wire +// client needs to be able to connect to the wire server and send the script, +// probe local interfaces, run through the test w/ timings, etc. + +0.000 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3 diff --git a/test/packetdrill/tcp_defer_accept.pkt b/test/packetdrill/tcp_defer_accept.pkt new file mode 100644 index 000000000..a17f946db --- /dev/null +++ b/test/packetdrill/tcp_defer_accept.pkt @@ -0,0 +1,48 @@ +// Test that a bare ACK does not complete a connection when TCP_DEFER_ACCEPT +// timeout is not hit but an ACK w/ data does complete and deliver the +// connection to the accept queue. + +0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3 ++0 setsockopt(3, SOL_TCP, TCP_DEFER_ACCEPT, [5], 4) = 0 ++0.000 fcntl(3, F_SETFL, O_RDWR|O_NONBLOCK) = 0 ++0 bind(3, ..., ...) = 0 + +// Set backlog to 1 so that we can easily test. ++0 listen(3, 1) = 0 + +// Establish a connection without timestamps. ++0.0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7> ++0.0 > S. 0:0(0) ack 1 <...> + +// Send a bare ACK this should not complete the connection as we +// set the TCP_DEFER_ACCEPT above. ++0.0 < . 1:1(0) ack 1 win 257 + +// The bare ACK should be dropped and no connection should be delivered +// to the accept queue. ++0.100 accept(3, ..., ...) = -1 EWOULDBLOCK (operation would block) + +// Send another bare ACK and it should still fail we set TCP_DEFER_ACCEPT +// to 5 seconds above. ++2.5 < . 1:1(0) ack 1 win 257 ++0.100 accept(3, ..., ...) = -1 EWOULDBLOCK (operation would block) + +// set accept socket back to blocking. ++0.000 fcntl(3, F_SETFL, O_RDWR) = 0 + +// Now send an ACK w/ data. This should complete the connection +// and deliver the socket to the accept queue. ++0.1 < . 1:5(4) ack 1 win 257 ++0.0 > . 1:1(0) ack 5 <...> + +// This should cause connection to transition to connected state. ++0.000 accept(3, ..., ...) = 4 ++0.000 fcntl(4, F_SETFL, O_RDWR|O_NONBLOCK) = 0 + +// Now read the data and we should get 4 bytes. ++0.000 read(4,..., 4) = 4 ++0.000 close(4) = 0 + ++0.0 > F. 1:1(0) ack 5 <...> ++0.0 < F. 5:5(0) ack 2 win 257 ++0.01 > . 2:2(0) ack 6 <...>
\ No newline at end of file diff --git a/test/packetdrill/tcp_defer_accept_timeout.pkt b/test/packetdrill/tcp_defer_accept_timeout.pkt new file mode 100644 index 000000000..201fdeb14 --- /dev/null +++ b/test/packetdrill/tcp_defer_accept_timeout.pkt @@ -0,0 +1,48 @@ +// Test that a bare ACK is accepted after TCP_DEFER_ACCEPT timeout +// is hit and a connection is delivered. + +0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3 ++0 setsockopt(3, SOL_TCP, TCP_DEFER_ACCEPT, [3], 4) = 0 ++0.000 fcntl(3, F_SETFL, O_RDWR|O_NONBLOCK) = 0 ++0 bind(3, ..., ...) = 0 + +// Set backlog to 1 so that we can easily test. ++0 listen(3, 1) = 0 + +// Establish a connection without timestamps. ++0.0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7> ++0.0 > S. 0:0(0) ack 1 <...> + +// Send a bare ACK this should not complete the connection as we +// set the TCP_DEFER_ACCEPT above. ++0.0 < . 1:1(0) ack 1 win 257 + +// The bare ACK should be dropped and no connection should be delivered +// to the accept queue. ++0.100 accept(3, ..., ...) = -1 EWOULDBLOCK (operation would block) + +// Send another bare ACK and it should still fail we set TCP_DEFER_ACCEPT +// to 5 seconds above. ++2.5 < . 1:1(0) ack 1 win 257 ++0.100 accept(3, ..., ...) = -1 EWOULDBLOCK (operation would block) + +// set accept socket back to blocking. ++0.000 fcntl(3, F_SETFL, O_RDWR) = 0 + +// We should see one more retransmit of the SYN-ACK as a last ditch +// attempt when TCP_DEFER_ACCEPT timeout is hit to trigger another +// ACK or a packet with data. ++.35~+2.35 > S. 0:0(0) ack 1 <...> + +// Now send another bare ACK after TCP_DEFER_ACCEPT time has been passed. ++0.0 < . 1:1(0) ack 1 win 257 + +// The ACK above should cause connection to transition to connected state. ++0.000 accept(3, ..., ...) = 4 ++0.000 fcntl(4, F_SETFL, O_RDWR|O_NONBLOCK) = 0 + ++0.000 close(4) = 0 + ++0.0 > F. 1:1(0) ack 1 <...> ++0.0 < F. 1:1(0) ack 2 win 257 ++0.01 > . 2:2(0) ack 2 <...> diff --git a/test/packetimpact/README.md b/test/packetimpact/README.md new file mode 100644 index 000000000..ffa96ba98 --- /dev/null +++ b/test/packetimpact/README.md @@ -0,0 +1,702 @@ +# Packetimpact + +## What is packetimpact? + +Packetimpact is a tool for platform-independent network testing. It is heavily +inspired by [packetdrill](https://github.com/google/packetdrill). It creates two +docker containers connected by a network. One is for the test bench, which +operates the test. The other is for the device-under-test (DUT), which is the +software being tested. The test bench communicates over the network with the DUT +to check correctness of the network. + +### Goals + +Packetimpact aims to provide: + +* A **multi-platform** solution that can test both Linux and gVisor. +* **Conciseness** on par with packetdrill scripts. +* **Control-flow** like for loops, conditionals, and variables. +* **Flexibilty** to specify every byte in a packet or use multiple sockets. + +## How to run packetimpact tests? + +Build the test container image by running the following at the root of the +repository: + +```bash +$ make load-packetimpact +``` + +Run a test, e.g. `fin_wait2_timeout`, against Linux: + +```bash +$ bazel test //test/packetimpact/tests:fin_wait2_timeout_native_test +``` + +Run the same test, but against gVisor: + +```bash +$ bazel test //test/packetimpact/tests:fin_wait2_timeout_netstack_test +``` + +## When to use packetimpact? + +There are a few ways to write networking tests for gVisor currently: + +* [Go unit tests](https://github.com/google/gvisor/tree/master/pkg/tcpip) +* [syscall tests](https://github.com/google/gvisor/tree/master/test/syscalls/linux) +* [packetdrill tests](https://github.com/google/gvisor/tree/master/test/packetdrill) +* packetimpact tests + +The right choice depends on the needs of the test. + +Feature | Go unit test | syscall test | packetdrill | packetimpact +-------------- | ------------ | ------------ | ----------- | ------------ +Multi-platform | no | **YES** | **YES** | **YES** +Concise | no | somewhat | somewhat | **VERY** +Control-flow | **YES** | **YES** | no | **YES** +Flexible | **VERY** | no | somewhat | **VERY** + +### Go unit tests + +If the test depends on the internals of gVisor and doesn't need to run on Linux +or other platforms for comparison purposes, a Go unit test can be appropriate. +They can observe internals of gVisor networking. The downside is that they are +**not concise** and **not multi-platform**. If you require insight on gVisor +internals, this is the right choice. + +### Syscall tests + +Syscall tests are **multi-platform** but cannot examine the internals of gVisor +networking. They are **concise**. They can use **control-flow** structures like +conditionals, for loops, and variables. However, they are limited to only what +the POSIX interface provides so they are **not flexible**. For example, you +would have difficulty writing a syscall test that intentionally sends a bad IP +checksum. Or if you did write that test with raw sockets, it would be very +**verbose** to write a test that intentionally send wrong checksums, wrong +protocols, wrong sequence numbers, etc. + +### Packetdrill tests + +Packetdrill tests are **multi-platform** and can run against both Linux and +gVisor. They are **concise** and use a special packetdrill scripting language. +They are **more flexible** than a syscall test in that they can send packets +that a syscall test would have difficulty sending, like a packet with a +calcuated ACK number. But they are also somewhat limimted in flexibiilty in that +they can't do tests with multiple sockets. They have **no control-flow** ability +like variables or conditionals. For example, it isn't possible to send a packet +that depends on the window size of a previous packet because the packetdrill +language can't express that. Nor could you branch based on whether or not the +other side supports window scaling, for example. + +### Packetimpact tests + +Packetimpact tests are similar to Packetdrill tests except that they are written +in Go instead of the packetdrill scripting language. That gives them all the +**control-flow** abilities of Go (loops, functions, variables, etc). They are +**multi-platform** in the same way as packetdrill tests but even more +**flexible** because Go is more expressive than the scripting language of +packetdrill. However, Go is **not as concise** as the packetdrill language. Many +design decisions below are made to mitigate that. + +## How it works + +``` + Testbench Device-Under-Test (DUT) + +-------------------+ +------------------------+ + | | TEST NET | | + | rawsockets.go <-->| <===========> | <---+ | + | ^ | | | | + | | | | | | + | v | | | | + | unittest | | | | + | ^ | | | | + | | | | | | + | v | | v | + | dut.go <========gRPC========> posix server | + | | CONTROL NET | | + +-------------------+ +------------------------+ +``` + +Two docker containers are created by a "runner" script, one for the testbench +and the other for the device under test (DUT). The script connects the two +containers with a control network and test network. It also does some other +tasks like waiting until the DUT is ready before starting the test and disabling +Linux networking that would interfere with the test bench. + +### DUT + +The DUT container runs a program called the "posix_server". The posix_server is +written in c++ for maximum portability. It is compiled on the host. The script +that starts the containers copies it into the DUT's container and runs it. It's +job is to receive directions from the test bench on what actions to take. For +this, the posix_server does three steps in a loop: + +1. Listen for a request from the test bench. +2. Execute a command. +3. Send the response back to the test bench. + +The requests and responses are +[protobufs](https://developers.google.com/protocol-buffers) and the +communication is done with [gRPC](https://grpc.io/). The commands run are +[POSIX socket commands](https://en.wikipedia.org/wiki/Berkeley_sockets#Socket_API_functions), +with the inputs and outputs converted into protobuf requests and responses. All +communication is on the control network, so that the test network is unaffected +by extra packets. + +For example, this is the request and response pair to call +[`socket()`](http://man7.org/linux/man-pages/man2/socket.2.html): + +```protocol-buffer +message SocketRequest { + int32 domain = 1; + int32 type = 2; + int32 protocol = 3; +} + +message SocketResponse { + int32 fd = 1; + int32 errno_ = 2; +} +``` + +##### Alternatives considered + +* We could have use JSON for communication instead. It would have been a + lighter-touch than protobuf but protobuf handles all the data type and has + strict typing to prevent a class of errors. The test bench could be written + in other languages, too. +* Instead of mimicking the POSIX interfaces, arguments could have had a more + natural form, like the `bind()` getting a string IP address instead of bytes + in a `sockaddr_t`. However, conforming to the existing structures keeps more + of the complexity in Go and keeps the posix_server simpler and thus more + likely to compile everywhere. + +### Test Bench + +The test bench does most of the work in a test. It is a Go program that compiles +on the host and is copied by the script into test bench's container. It is a +regular [go unit test](https://golang.org/pkg/testing/) that imports the test +bench framework. The test bench framwork is based on three basic utilities: + +* Commanding the DUT to run POSIX commands and return responses. +* Sending raw packets to the DUT on the test network. +* Listening for raw packets from the DUT on the test network. + +#### DUT commands + +To keep the interface to the DUT consistent and easy-to-use, each POSIX command +supported by the posix_server is wrapped in functions with signatures similar to +the ones in the [Go unix package](https://godoc.org/golang.org/x/sys/unix). This +way all the details of endianess and (un)marshalling of go structs such as +[unix.Timeval](https://godoc.org/golang.org/x/sys/unix#Timeval) is handled in +one place. This also makes it straight-forward to convert tests that use `unix.` +or `syscall.` calls to `dut.` calls. + +For example, creating a connection to the DUT and commanding it to make a socket +looks like this: + +```go +dut := testbench.NewDut(t) +fd, err := dut.SocketWithErrno(unix.AF_INET, unix.SOCK_STREAM, unix.IPPROTO_IP) +if fd < 0 { + t.Fatalf(...) +} +``` + +Because the usual case is to fail the test when the DUT fails to create a +socket, there is a concise version of each of the `...WithErrno` functions that +does that: + +```go +dut := testbench.NewDut(t) +fd := dut.Socket(unix.AF_INET, unix.SOCK_STREAM, unix.IPPROTO_IP) +``` + +The DUT and other structs in the code store a `*testing.T` so that they can +provide versions of functions that call `t.Fatalf(...)`. This helps keep tests +concise. + +##### Alternatives considered + +* Instead of mimicking the `unix.` go interface, we could have invented a more + natural one, like using `float64` instead of `Timeval`. However, using the + same function signatures that `unix.` has makes it easier to convert code to + `dut.`. Also, using an existing interface ensures that we don't invent an + interface that isn't extensible. For example, if we invented a function for + `bind()` that didn't support IPv6 and later we had to add a second `bind6()` + function. + +#### Sending/Receiving Raw Packets + +The framework wraps POSIX sockets for sending and receiving raw frames. Both +send and receive are synchronous commands. +[SO_RCVTIMEO](http://man7.org/linux/man-pages/man7/socket.7.html) is used to set +a timeout on the receive commands. For ease of use, these are wrapped in an +`Injector` and a `Sniffer`. They have functions: + +```go +func (s *Sniffer) Recv(timeout time.Duration) []byte {...} +func (i *Injector) Send(b []byte) {...} +``` + +##### Alternatives considered + +* [gopacket](https://github.com/google/gopacket) pcap has raw socket support + but requires cgo. cgo is not guaranteed to be portable from the host to the + container and in practice, the container doesn't recognize binaries built on + the host if they use cgo. +* Both gVisor and gopacket have the ability to read and write pcap files + without cgo but that is insufficient here because we can't just replay pcap + files, we need a more dynamic solution. +* The sniffer and injector can't share a socket because they need to be bound + differently. +* Sniffing could have been done asynchronously with channels, obviating the + need for `SO_RCVTIMEO`. But that would introduce asynchronous complication. + `SO_RCVTIMEO` is well supported on the test bench. + +#### `Layer` struct + +A large part of packetimpact tests is creating packets to send and comparing +received packets against expectations. To keep tests concise, it is useful to be +able to specify just the important parts of packets that need to be set. For +example, sending a packet with default values except for TCP Flags. And for +packets received, it's useful to be able to compare just the necessary parts of +received packets and ignore the rest. + +To aid in both of those, Go structs with optional fields are created for each +encapsulation type, such as IPv4, TCP, and Ethernet. This is inspired by +[scapy](https://scapy.readthedocs.io/en/latest/). For example, here is the +struct for Ethernet: + +```go +type Ether struct { + LayerBase + SrcAddr *tcpip.LinkAddress + DstAddr *tcpip.LinkAddress + Type *tcpip.NetworkProtocolNumber +} +``` + +Each struct has the same fields as those in the +[gVisor headers](https://github.com/google/gvisor/tree/master/pkg/tcpip/header) +but with a pointer for each field that may be `nil`. + +##### Alternatives considered + +* Just use []byte like gVisor headers do. The drawback is that it makes the + tests more verbose. + * For example, there would be no way to call `Send(myBytes)` concisely and + indicate if the checksum should be calculated automatically versus + overridden. The only way would be to add lines to the test to calculate + it before each Send, which is wordy. Or make multiple versions of Send: + one that checksums IP, one that doesn't, one that checksums TCP, one + that does both, etc. That would be many combinations. + * Filtering inputs would become verbose. Either: + * large conditionals that need to be repeated many places: + `h[FlagOffset] == SYN && h[LengthOffset:LengthOffset+2] == ...` or + * Many functions, one per field, like: `filterByFlag(myBytes, SYN)`, + `filterByLength(myBytes, 20)`, `filterByNextProto(myBytes, 0x8000)`, + etc. + * Using pointers allows us to combine `Layer`s with reflection. So the + default `Layers` can be overridden by a `Layers` with just the TCP + conection's src/dst which can be overridden by one with just a test + specific TCP window size. + * It's a proven way to separate the details of a packet from the byte + format as shown by scapy's success. +* Use packetgo. It's more general than parsing packets with gVisor. However: + * packetgo doesn't have optional fields so many of the above problems + still apply. + * It would be yet another dependency. + * It's not as well known to engineers that are already writing gVisor + code. + * It might be a good candidate for replacing the parsing of packets into + `Layer`s if all that parsing turns out to be more work than parsing by + packetgo and converting *that* to `Layer`. packetgo has easier to use + getters for the layers. This could be done later in a way that doesn't + break tests. + +#### `Layer` methods + +The `Layer` structs provide a way to partially specify an encapsulation. They +also need methods for using those partially specified encapsulation, for example +to marshal them to bytes or compare them. For those, each encapsulation +implements the `Layer` interface: + +```go +// Layer is the interface that all encapsulations must implement. +// +// A Layer is an encapsulation in a packet, such as TCP, IPv4, IPv6, etc. A +// Layer contains all the fields of the encapsulation. Each field is a pointer +// and may be nil. +type Layer interface { + // toBytes converts the Layer into bytes. In places where the Layer's field + // isn't nil, the value that is pointed to is used. When the field is nil, a + // reasonable default for the Layer is used. For example, "64" for IPv4 TTL + // and a calculated checksum for TCP or IP. Some layers require information + // from the previous or next layers in order to compute a default, such as + // TCP's checksum or Ethernet's type, so each Layer has a doubly-linked list + // to the layer's neighbors. + toBytes() ([]byte, error) + + // match checks if the current Layer matches the provided Layer. If either + // Layer has a nil in a given field, that field is considered matching. + // Otherwise, the values pointed to by the fields must match. + match(Layer) bool + + // length in bytes of the current encapsulation + length() int + + // next gets a pointer to the encapsulated Layer. + next() Layer + + // prev gets a pointer to the Layer encapsulating this one. + prev() Layer + + // setNext sets the pointer to the encapsulated Layer. + setNext(Layer) + + // setPrev sets the pointer to the Layer encapsulating this one. + setPrev(Layer) +} +``` + +The `next` and `prev` make up a link listed so that each layer can get at the +information in the layer around it. This is necessary for some protocols, like +TCP that needs the layer before and payload after to compute the checksum. Any +sequence of `Layer` structs is valid so long as the parser and `toBytes` +functions can map from type to protool number and vice-versa. When the mapping +fails, an error is emitted explaining what functionality is missing. The +solution is either to fix the ordering or implement the missing protocol. + +For each `Layer` there is also a parsing function. For example, this one is for +Ethernet: + +``` +func ParseEther(b []byte) (Layers, error) +``` + +The parsing function converts bytes received on the wire into a `Layer` +(actually `Layers`, see below) which has no `nil`s in it. By using +`match(Layer)` to compare against another `Layer` that *does* have `nil`s in it, +the received bytes can be partially compared. The `nil`s behave as +"don't-cares". + +##### Alternatives considered + +* Matching against `[]byte` instead of converting to `Layer` first. + * The downside is that it precludes the use of a `cmp.Equal` one-liner to + do comparisons. + * It creates confusion in the code to deal with both representations at + different times. For example, is the checksum calculated on `[]byte` or + `Layer` when sending? What about when checking received packets? + +#### `Layers` + +``` +type Layers []Layer + +func (ls *Layers) match(other Layers) bool {...} +func (ls *Layers) toBytes() ([]byte, error) {...} +``` + +`Layers` is an array of `Layer`. It represents a stack of encapsulations, such +as `Layers{Ether{},IPv4{},TCP{},Payload{}}`. It also has `toBytes()` and +`match(Layers)`, like `Layer`. The parse functions above actually return +`Layers` and not `Layer` because they know about the headers below and +sequentially call each parser on the remaining, encapsulated bytes. + +All this leads to the ability to write concise packet processing. For example: + +```go +etherType := 0x8000 +flags = uint8(header.TCPFlagSyn|header.TCPFlagAck) +toMatch := Layers{Ether{Type: ðerType}, IPv4{}, TCP{Flags: &flags}} +for { + recvBytes := sniffer.Recv(time.Second) + if recvBytes == nil { + println("Got no packet for 1 second") + } + gotPacket, err := ParseEther(recvBytes) + if err == nil && toMatch.match(gotPacket) { + println("Got a TCP/IPv4/Eth packet with SYNACK") + } +} +``` + +##### Alternatives considered + +* Don't use previous and next pointers. + * Each layer may need to be able to interrogate the layers around it, like + for computing the next protocol number or total length. So *some* + mechanism is needed for a `Layer` to see neighboring layers. + * We could pass the entire array `Layers` to the `toBytes()` function. + Passing an array to a method that includes in the array the function + receiver itself seems wrong. + +#### `layerState` + +`Layers` represents the different headers of a packet but a connection includes +more state. For example, a TCP connection needs to keep track of the next +expected sequence number and also the next sequence number to send. This is +stored in a `layerState` struct. This is the `layerState` for TCP: + +```go +// tcpState maintains state about a TCP connection. +type tcpState struct { + out, in TCP + localSeqNum, remoteSeqNum *seqnum.Value + synAck *TCP + portPickerFD int + finSent bool +} +``` + +The next sequence numbers for each side of the connection are stored. `out` and +`in` have defaults for the TCP header, such as the expected source and +destination ports for outgoing packets and incoming packets. + +##### `layerState` interface + +```go +// layerState stores the state of a layer of a connection. +type layerState interface { + // outgoing returns an outgoing layer to be sent in a frame. + outgoing() Layer + + // incoming creates an expected Layer for comparing against a received Layer. + // Because the expectation can depend on values in the received Layer, it is + // an input to incoming. For example, the ACK number needs to be checked in a + // TCP packet but only if the ACK flag is set in the received packet. + incoming(received Layer) Layer + + // sent updates the layerState based on the Layer that was sent. The input is + // a Layer with all prev and next pointers populated so that the entire frame + // as it was sent is available. + sent(sent Layer) error + + // received updates the layerState based on a Layer that is receieved. The + // input is a Layer with all prev and next pointers populated so that the + // entire frame as it was receieved is available. + received(received Layer) error + + // close frees associated resources held by the LayerState. + close() error +} +``` + +`outgoing` generates the default Layer for an outgoing packet. For TCP, this +would be a `TCP` with the source and destination ports populated. Because they +are static, they are stored inside the `out` member of `tcpState`. However, the +sequence numbers change frequently so the outgoing sequence number is stored in +the `localSeqNum` and put into the output of outgoing for each call. + +`incoming` does the same functions for packets that arrive but instead of +generating a packet to send, it generates an expect packet for filtering packets +that arrive. For example, if a `TCP` header arrives with the wrong ports, it can +be ignored as belonging to a different connection. `incoming` needs the received +header itself as an input because the filter may depend on the input. For +example, the expected sequence number depends on the flags in the TCP header. + +`sent` and `received` are run for each header that is actually sent or received +and used to update the internal state. `incoming` and `outgoing` should *not* be +used for these purpose. For example, `incoming` is called on every packet that +arrives but only packets that match ought to actually update the state. +`outgoing` is called to created outgoing packets and those packets are always +sent, so unlike `incoming`/`received`, there is one `outgoing` call for each +`sent` call. + +`close` cleans up after the layerState. For example, TCP and UDP need to keep a +port reserved and then release it. + +#### Connections + +Using `layerState` above, we can create connections. + +```go +// Connection holds a collection of layer states for maintaining a connection +// along with sockets for sniffer and injecting packets. +type Connection struct { + layerStates []layerState + injector Injector + sniffer Sniffer + t *testing.T +} +``` + +The connection stores an array of `layerState` in the order that the headers +should be present in the frame to send. For example, Ether then IPv4 then TCP. +The injector and sniffer are for writing and reading frames. A `*testing.T` is +stored so that internal errors can be reported directly without code in the unit +test. + +The `Connection` has some useful functions: + +```go +// Close frees associated resources held by the Connection. +func (conn *Connection) Close() {...} +// CreateFrame builds a frame for the connection with layer overriding defaults +// of the innermost layer and additionalLayers added after it. +func (conn *Connection) CreateFrame(layer Layer, additionalLayers ...Layer) Layers {...} +// SendFrame sends a frame on the wire and updates the state of all layers. +func (conn *Connection) SendFrame(frame Layers) {...} +// Send a packet with reasonable defaults. Potentially override the final layer +// in the connection with the provided layer and add additionLayers. +func (conn *Connection) Send(layer Layer, additionalLayers ...Layer) {...} +// Expect a frame with the final layerStates layer matching the provided Layer +// within the timeout specified. If it doesn't arrive in time, it returns nil. +func (conn *Connection) Expect(layer Layer, timeout time.Duration) (Layer, error) {...} +// ExpectFrame expects a frame that matches the provided Layers within the +// timeout specified. If it doesn't arrive in time, it returns nil. +func (conn *Connection) ExpectFrame(layers Layers, timeout time.Duration) (Layers, error) {...} +// Drain drains the sniffer's receive buffer by receiving packets until there's +// nothing else to receive. +func (conn *Connection) Drain() {...} +``` + +`CreateFrame` uses the `[]layerState` to create a frame to send. The first +argument is for overriding defaults in the last header of the frame, because +this is the most common need. For a TCPIPv4 connection, this would be the TCP +header. Optional additionalLayers can be specified to add to the frame being +created, such as a `Payload` for `TCP`. + +`SendFrame` sends the frame to the DUT. It is combined with `CreateFrame` to +make `Send`. For unittests with basic sending needs, `Send` can be used. If more +control is needed over the frame, it can be made with `CreateFrame`, modified in +the unit test, and then sent with `SendFrame`. + +On the receiving side, there is `Expect` and `ExpectFrame`. Like with the +sending side, there are two forms of each function, one for just the last header +and one for the whole frame. The expect functions use the `[]layerState` to +create a template for the expected incoming frame. That frame is then overridden +by the values in the first argument. Finally, a loop starts sniffing packets on +the wire for frames. If a matching frame is found before the timeout, it is +returned without error. If not, nil is returned and the error contains text of +all the received frames that didn't match. Exactly one of the outputs will be +non-nil, even if no frames are received at all. + +`Drain` sniffs and discards all the frames that have yet to be received. A +common way to write a test is: + +```go +conn.Drain() // Discard all outstanding frames. +conn.Send(...) // Send a frame with overrides. +// Now expect a frame with a certain header and fail if it doesn't arrive. +if _, err := conn.Expect(...); err != nil { t.Fatal(...) } +``` + +Or for a test where we want to check that no frame arrives: + +```go +if gotOne, _ := conn.Expect(...); gotOne != nil { t.Fatal(...) } +``` + +#### Specializing `Connection` + +Because there are some common combinations of `layerState` into `Connection`, +they are defined: + +```go +// TCPIPv4 maintains the state for all the layers in a TCP/IPv4 connection. +type TCPIPv4 Connection +// UDPIPv4 maintains the state for all the layers in a UDP/IPv4 connection. +type UDPIPv4 Connection +``` + +Each has a `NewXxx` function to create a new connection with reasonable +defaults. They also have functions that call the underlying `Connection` +functions but with specialization and tighter type-checking. For example: + +```go +func (conn *TCPIPv4) Send(tcp TCP, additionalLayers ...Layer) { + (*Connection)(conn).Send(&tcp, additionalLayers...) +} +func (conn *TCPIPv4) Drain() { + conn.sniffer.Drain() +} +``` + +They may also have some accessors to get or set the internal state of the +connection: + +```go +func (conn *TCPIPv4) state() *tcpState { + state, ok := conn.layerStates[len(conn.layerStates)-1].(*tcpState) + if !ok { + conn.t.Fatalf("expected final state of %v to be tcpState", conn.layerStates) + } + return state +} +func (conn *TCPIPv4) RemoteSeqNum() *seqnum.Value { + return conn.state().remoteSeqNum +} +func (conn *TCPIPv4) LocalSeqNum() *seqnum.Value { + return conn.state().localSeqNum +} +``` + +Unittests will in practice use these functions and not the functions on +`Connection`. For example, `NewTCPIPv4()` and then call `Send` on that rather +than cast is to a `Connection` and call `Send` on that cast result. + +##### Alternatives considered + +* Instead of storing `outgoing` and `incoming`, store values. + * There would be many more things to store instead, like `localMac`, + `remoteMac`, `localIP`, `remoteIP`, `localPort`, and `remotePort`. + * Construction of a packet would be many lines to copy each of these + values into a `[]byte`. And there would be slight variations needed for + each encapsulation stack, like TCPIPv6 and ARP. + * Filtering incoming packets would be a long sequence: + * Compare the MACs, then + * Parse the next header, then + * Compare the IPs, then + * Parse the next header, then + * Compare the TCP ports. Instead it's all just one call to + `cmp.Equal(...)`, for all sequences. + * A TCPIPv6 connection could share most of the code. Only the type of the + IP addresses are different. The types of `outgoing` and `incoming` would + be remain `Layers`. + * An ARP connection could share all the Ethernet parts. The IP `Layer` + could be factored out of `outgoing`. After that, the IPv4 and IPv6 + connections could implement one interface and a single TCP struct could + have either network protocol through composition. + +## Putting it all together + +Here's what te start of a packetimpact unit test looks like. This test creates a +TCP connection with the DUT. There are added comments for explanation in this +document but a real test might not include them in order to stay even more +concise. + +```go +func TestMyTcpTest(t *testing.T) { + // Prepare a DUT for communication. + dut := testbench.NewDUT(t) + + // This does: + // dut.Socket() + // dut.Bind() + // dut.Getsockname() to learn the new port number + // dut.Listen() + listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(listenFD) // Tell the DUT to close the socket at the end of the test. + + // Monitor a new TCP connection with sniffer, injector, sequence number tracking, + // and reasonable outgoing and incoming packet field default IPs, MACs, and port numbers. + conn := testbench.NewTCPIPv4(t, dut, remotePort) + + // Perform a 3-way handshake: send SYN, expect SYNACK, send ACK. + conn.Handshake() + + // Tell the DUT to accept the new connection. + acceptFD := dut.Accept(acceptFd) +} +``` + +## Other notes + +* The time between receiving a SYN-ACK and replying with an ACK in `Handshake` + is about 3ms. This is much slower than the native unix response, which is + about 0.3ms. Packetdrill gets closer to 0.3ms. For tests where timing is + crucial, packetdrill is faster and more precise. diff --git a/test/packetimpact/dut/BUILD b/test/packetimpact/dut/BUILD new file mode 100644 index 000000000..3ce63c2c6 --- /dev/null +++ b/test/packetimpact/dut/BUILD @@ -0,0 +1,18 @@ +load("//tools:defs.bzl", "cc_binary", "grpcpp") + +package( + default_visibility = ["//test/packetimpact:__subpackages__"], + licenses = ["notice"], +) + +cc_binary( + name = "posix_server", + srcs = ["posix_server.cc"], + linkstatic = 1, + static = True, # This is needed for running in a docker container. + deps = [ + grpcpp, + "//test/packetimpact/proto:posix_server_cc_grpc_proto", + "//test/packetimpact/proto:posix_server_cc_proto", + ], +) diff --git a/test/packetimpact/dut/posix_server.cc b/test/packetimpact/dut/posix_server.cc new file mode 100644 index 000000000..29d4cc6fe --- /dev/null +++ b/test/packetimpact/dut/posix_server.cc @@ -0,0 +1,371 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at // +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <arpa/inet.h> +#include <fcntl.h> +#include <getopt.h> +#include <netdb.h> +#include <netinet/in.h> +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <unistd.h> + +#include <iostream> +#include <unordered_map> + +#include "include/grpcpp/security/server_credentials.h" +#include "include/grpcpp/server_builder.h" +#include "test/packetimpact/proto/posix_server.grpc.pb.h" +#include "test/packetimpact/proto/posix_server.pb.h" + +// Converts a sockaddr_storage to a Sockaddr message. +::grpc::Status sockaddr_to_proto(const sockaddr_storage &addr, + socklen_t addrlen, + posix_server::Sockaddr *sockaddr_proto) { + switch (addr.ss_family) { + case AF_INET: { + auto addr_in = reinterpret_cast<const sockaddr_in *>(&addr); + auto response_in = sockaddr_proto->mutable_in(); + response_in->set_family(addr_in->sin_family); + response_in->set_port(ntohs(addr_in->sin_port)); + response_in->mutable_addr()->assign( + reinterpret_cast<const char *>(&addr_in->sin_addr.s_addr), 4); + return ::grpc::Status::OK; + } + case AF_INET6: { + auto addr_in6 = reinterpret_cast<const sockaddr_in6 *>(&addr); + auto response_in6 = sockaddr_proto->mutable_in6(); + response_in6->set_family(addr_in6->sin6_family); + response_in6->set_port(ntohs(addr_in6->sin6_port)); + response_in6->set_flowinfo(ntohl(addr_in6->sin6_flowinfo)); + response_in6->mutable_addr()->assign( + reinterpret_cast<const char *>(&addr_in6->sin6_addr.s6_addr), 16); + // sin6_scope_id is stored in host byte order. + // + // https://www.gnu.org/software/libc/manual/html_node/Internet-Address-Formats.html + response_in6->set_scope_id(addr_in6->sin6_scope_id); + return ::grpc::Status::OK; + } + } + return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Unknown Sockaddr"); +} + +::grpc::Status proto_to_sockaddr(const posix_server::Sockaddr &sockaddr_proto, + sockaddr_storage *addr, socklen_t *addr_len) { + switch (sockaddr_proto.sockaddr_case()) { + case posix_server::Sockaddr::SockaddrCase::kIn: { + auto proto_in = sockaddr_proto.in(); + if (proto_in.addr().size() != 4) { + return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + "IPv4 address must be 4 bytes"); + } + auto addr_in = reinterpret_cast<sockaddr_in *>(addr); + addr_in->sin_family = proto_in.family(); + addr_in->sin_port = htons(proto_in.port()); + proto_in.addr().copy(reinterpret_cast<char *>(&addr_in->sin_addr.s_addr), + 4); + *addr_len = sizeof(*addr_in); + break; + } + case posix_server::Sockaddr::SockaddrCase::kIn6: { + auto proto_in6 = sockaddr_proto.in6(); + if (proto_in6.addr().size() != 16) { + return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + "IPv6 address must be 16 bytes"); + } + auto addr_in6 = reinterpret_cast<sockaddr_in6 *>(addr); + addr_in6->sin6_family = proto_in6.family(); + addr_in6->sin6_port = htons(proto_in6.port()); + addr_in6->sin6_flowinfo = htonl(proto_in6.flowinfo()); + proto_in6.addr().copy( + reinterpret_cast<char *>(&addr_in6->sin6_addr.s6_addr), 16); + // sin6_scope_id is stored in host byte order. + // + // https://www.gnu.org/software/libc/manual/html_node/Internet-Address-Formats.html + addr_in6->sin6_scope_id = proto_in6.scope_id(); + *addr_len = sizeof(*addr_in6); + break; + } + case posix_server::Sockaddr::SockaddrCase::SOCKADDR_NOT_SET: + default: + return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + "Unknown Sockaddr"); + } + return ::grpc::Status::OK; +} + +class PosixImpl final : public posix_server::Posix::Service { + ::grpc::Status Accept(grpc_impl::ServerContext *context, + const ::posix_server::AcceptRequest *request, + ::posix_server::AcceptResponse *response) override { + sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + response->set_fd(accept(request->sockfd(), + reinterpret_cast<sockaddr *>(&addr), &addrlen)); + response->set_errno_(errno); + return sockaddr_to_proto(addr, addrlen, response->mutable_addr()); + } + + ::grpc::Status Bind(grpc_impl::ServerContext *context, + const ::posix_server::BindRequest *request, + ::posix_server::BindResponse *response) override { + if (!request->has_addr()) { + return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + "Missing address"); + } + + sockaddr_storage addr; + socklen_t addr_len; + auto err = proto_to_sockaddr(request->addr(), &addr, &addr_len); + if (!err.ok()) { + return err; + } + + response->set_ret( + bind(request->sockfd(), reinterpret_cast<sockaddr *>(&addr), addr_len)); + response->set_errno_(errno); + return ::grpc::Status::OK; + } + + ::grpc::Status Close(grpc_impl::ServerContext *context, + const ::posix_server::CloseRequest *request, + ::posix_server::CloseResponse *response) override { + response->set_ret(close(request->fd())); + response->set_errno_(errno); + return ::grpc::Status::OK; + } + + ::grpc::Status Connect(grpc_impl::ServerContext *context, + const ::posix_server::ConnectRequest *request, + ::posix_server::ConnectResponse *response) override { + if (!request->has_addr()) { + return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + "Missing address"); + } + sockaddr_storage addr; + socklen_t addr_len; + auto err = proto_to_sockaddr(request->addr(), &addr, &addr_len); + if (!err.ok()) { + return err; + } + + response->set_ret(connect(request->sockfd(), + reinterpret_cast<sockaddr *>(&addr), addr_len)); + response->set_errno_(errno); + return ::grpc::Status::OK; + } + + ::grpc::Status Fcntl(grpc_impl::ServerContext *context, + const ::posix_server::FcntlRequest *request, + ::posix_server::FcntlResponse *response) override { + response->set_ret(::fcntl(request->fd(), request->cmd(), request->arg())); + response->set_errno_(errno); + return ::grpc::Status::OK; + } + + ::grpc::Status GetSockName( + grpc_impl::ServerContext *context, + const ::posix_server::GetSockNameRequest *request, + ::posix_server::GetSockNameResponse *response) override { + sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + response->set_ret(getsockname( + request->sockfd(), reinterpret_cast<sockaddr *>(&addr), &addrlen)); + response->set_errno_(errno); + return sockaddr_to_proto(addr, addrlen, response->mutable_addr()); + } + + ::grpc::Status GetSockOpt( + grpc_impl::ServerContext *context, + const ::posix_server::GetSockOptRequest *request, + ::posix_server::GetSockOptResponse *response) override { + switch (request->type()) { + case ::posix_server::GetSockOptRequest::BYTES: { + socklen_t optlen = request->optlen(); + std::vector<char> buf(optlen); + response->set_ret(::getsockopt(request->sockfd(), request->level(), + request->optname(), buf.data(), + &optlen)); + if (optlen >= 0) { + response->mutable_optval()->set_bytesval(buf.data(), optlen); + } + break; + } + case ::posix_server::GetSockOptRequest::INT: { + int intval = 0; + socklen_t optlen = sizeof(intval); + response->set_ret(::getsockopt(request->sockfd(), request->level(), + request->optname(), &intval, &optlen)); + response->mutable_optval()->set_intval(intval); + break; + } + case ::posix_server::GetSockOptRequest::TIME: { + timeval tv; + socklen_t optlen = sizeof(tv); + response->set_ret(::getsockopt(request->sockfd(), request->level(), + request->optname(), &tv, &optlen)); + response->mutable_optval()->mutable_timeval()->set_seconds(tv.tv_sec); + response->mutable_optval()->mutable_timeval()->set_microseconds( + tv.tv_usec); + break; + } + default: + return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + "Unknown SockOpt Type"); + } + response->set_errno_(errno); + return ::grpc::Status::OK; + } + + ::grpc::Status Listen(grpc_impl::ServerContext *context, + const ::posix_server::ListenRequest *request, + ::posix_server::ListenResponse *response) override { + response->set_ret(listen(request->sockfd(), request->backlog())); + response->set_errno_(errno); + return ::grpc::Status::OK; + } + + ::grpc::Status Send(::grpc::ServerContext *context, + const ::posix_server::SendRequest *request, + ::posix_server::SendResponse *response) override { + response->set_ret(::send(request->sockfd(), request->buf().data(), + request->buf().size(), request->flags())); + response->set_errno_(errno); + return ::grpc::Status::OK; + } + + ::grpc::Status SendTo(::grpc::ServerContext *context, + const ::posix_server::SendToRequest *request, + ::posix_server::SendToResponse *response) override { + if (!request->has_dest_addr()) { + return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + "Missing address"); + } + sockaddr_storage addr; + socklen_t addr_len; + auto err = proto_to_sockaddr(request->dest_addr(), &addr, &addr_len); + if (!err.ok()) { + return err; + } + + response->set_ret(::sendto(request->sockfd(), request->buf().data(), + request->buf().size(), request->flags(), + reinterpret_cast<sockaddr *>(&addr), addr_len)); + response->set_errno_(errno); + return ::grpc::Status::OK; + } + + ::grpc::Status SetSockOpt( + grpc_impl::ServerContext *context, + const ::posix_server::SetSockOptRequest *request, + ::posix_server::SetSockOptResponse *response) override { + switch (request->optval().val_case()) { + case ::posix_server::SockOptVal::kBytesval: + response->set_ret(setsockopt(request->sockfd(), request->level(), + request->optname(), + request->optval().bytesval().c_str(), + request->optval().bytesval().size())); + break; + case ::posix_server::SockOptVal::kIntval: { + int opt = request->optval().intval(); + response->set_ret(::setsockopt(request->sockfd(), request->level(), + request->optname(), &opt, sizeof(opt))); + break; + } + case ::posix_server::SockOptVal::kTimeval: { + timeval tv = {.tv_sec = static_cast<__time_t>( + request->optval().timeval().seconds()), + .tv_usec = static_cast<__suseconds_t>( + request->optval().timeval().microseconds())}; + response->set_ret(setsockopt(request->sockfd(), request->level(), + request->optname(), &tv, sizeof(tv))); + break; + } + default: + return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + "Unknown SockOpt Type"); + } + response->set_errno_(errno); + return ::grpc::Status::OK; + } + + ::grpc::Status Socket(grpc_impl::ServerContext *context, + const ::posix_server::SocketRequest *request, + ::posix_server::SocketResponse *response) override { + response->set_fd( + socket(request->domain(), request->type(), request->protocol())); + response->set_errno_(errno); + return ::grpc::Status::OK; + } + + ::grpc::Status Recv(::grpc::ServerContext *context, + const ::posix_server::RecvRequest *request, + ::posix_server::RecvResponse *response) override { + std::vector<char> buf(request->len()); + response->set_ret( + recv(request->sockfd(), buf.data(), buf.size(), request->flags())); + if (response->ret() >= 0) { + response->set_buf(buf.data(), response->ret()); + } + response->set_errno_(errno); + return ::grpc::Status::OK; + } +}; + +// Parse command line options. Returns a pointer to the first argument beyond +// the options. +void parse_command_line_options(int argc, char *argv[], std::string *ip, + int *port) { + static struct option options[] = {{"ip", required_argument, NULL, 1}, + {"port", required_argument, NULL, 2}, + {0, 0, 0, 0}}; + + // Parse the arguments. + int c; + while ((c = getopt_long(argc, argv, "", options, NULL)) > 0) { + if (c == 1) { + *ip = optarg; + } else if (c == 2) { + *port = std::stoi(std::string(optarg)); + } + } +} + +void run_server(const std::string &ip, int port) { + PosixImpl posix_service; + grpc::ServerBuilder builder; + std::string server_address = ip + ":" + std::to_string(port); + // Set the authentication mechanism. + std::shared_ptr<grpc::ServerCredentials> creds = + grpc::InsecureServerCredentials(); + builder.AddListeningPort(server_address, creds); + builder.RegisterService(&posix_service); + + std::unique_ptr<grpc::Server> server(builder.BuildAndStart()); + std::cerr << "Server listening on " << server_address << std::endl; + server->Wait(); + std::cerr << "posix_server is finished." << std::endl; +} + +int main(int argc, char *argv[]) { + std::cerr << "posix_server is starting." << std::endl; + std::string ip; + int port; + parse_command_line_options(argc, argv, &ip, &port); + + std::cerr << "Got IP " << ip << " and port " << port << "." << std::endl; + run_server(ip, port); +} diff --git a/test/packetimpact/netdevs/BUILD b/test/packetimpact/netdevs/BUILD new file mode 100644 index 000000000..8d1193fed --- /dev/null +++ b/test/packetimpact/netdevs/BUILD @@ -0,0 +1,23 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package( + licenses = ["notice"], +) + +go_library( + name = "netdevs", + srcs = ["netdevs.go"], + visibility = ["//test/packetimpact:__subpackages__"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/header", + ], +) + +go_test( + name = "netdevs_test", + size = "small", + srcs = ["netdevs_test.go"], + library = ":netdevs", + deps = ["@com_github_google_go_cmp//cmp:go_default_library"], +) diff --git a/test/packetimpact/netdevs/netdevs.go b/test/packetimpact/netdevs/netdevs.go new file mode 100644 index 000000000..eecfe0730 --- /dev/null +++ b/test/packetimpact/netdevs/netdevs.go @@ -0,0 +1,115 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package netdevs contains utilities for working with network devices. +package netdevs + +import ( + "fmt" + "net" + "regexp" + "strconv" + "strings" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +// A DeviceInfo represents a network device. +type DeviceInfo struct { + ID uint32 + MAC net.HardwareAddr + IPv4Addr net.IP + IPv4Net *net.IPNet + IPv6Addr net.IP + IPv6Net *net.IPNet +} + +var ( + deviceLine = regexp.MustCompile(`^\s*(\d+): (\w+)`) + linkLine = regexp.MustCompile(`^\s*link/\w+ ([0-9a-fA-F:]+)`) + inetLine = regexp.MustCompile(`^\s*inet ([0-9./]+)`) + inet6Line = regexp.MustCompile(`^\s*inet6 ([0-9a-fA-Z:/]+)`) +) + +// ParseDevices parses the output from `ip addr show` into a map from device +// name to information about the device. +// +// Note: if multiple IPv6 addresses are assigned to a device, the last address +// displayed by `ip addr show` will be used. This is fine for packetimpact +// because we will always only have at most one IPv6 address assigned to each +// device. +func ParseDevices(cmdOutput string) (map[string]DeviceInfo, error) { + var currentDevice string + var currentInfo DeviceInfo + deviceInfos := make(map[string]DeviceInfo) + for _, line := range strings.Split(cmdOutput, "\n") { + if m := deviceLine.FindStringSubmatch(line); m != nil { + if currentDevice != "" { + deviceInfos[currentDevice] = currentInfo + } + id, err := strconv.ParseUint(m[1], 10, 32) + if err != nil { + return nil, fmt.Errorf("parsing device ID %s: %w", m[1], err) + } + currentInfo = DeviceInfo{ID: uint32(id)} + currentDevice = m[2] + } else if m := linkLine.FindStringSubmatch(line); m != nil { + mac, err := net.ParseMAC(m[1]) + if err != nil { + return nil, err + } + currentInfo.MAC = mac + } else if m := inetLine.FindStringSubmatch(line); m != nil { + ipv4Addr, ipv4Net, err := net.ParseCIDR(m[1]) + if err != nil { + return nil, err + } + currentInfo.IPv4Addr = ipv4Addr + currentInfo.IPv4Net = ipv4Net + } else if m := inet6Line.FindStringSubmatch(line); m != nil { + ipv6Addr, ipv6Net, err := net.ParseCIDR(m[1]) + if err != nil { + return nil, err + } + currentInfo.IPv6Addr = ipv6Addr + currentInfo.IPv6Net = ipv6Net + } + } + if currentDevice != "" { + deviceInfos[currentDevice] = currentInfo + } + return deviceInfos, nil +} + +// MACToIP converts the MAC address to an IPv6 link local address as described +// in RFC 4291 page 20: https://tools.ietf.org/html/rfc4291#page-20 +func MACToIP(mac net.HardwareAddr) net.IP { + addr := make([]byte, header.IPv6AddressSize) + addr[0] = 0xfe + addr[1] = 0x80 + header.EthernetAdddressToModifiedEUI64IntoBuf(tcpip.LinkAddress(mac), addr[8:]) + return net.IP(addr) +} + +// FindDeviceByIP finds a DeviceInfo and device name from an IP address in the +// output of ParseDevices. +func FindDeviceByIP(ip net.IP, devices map[string]DeviceInfo) (string, DeviceInfo, error) { + for dev, info := range devices { + if info.IPv4Addr.Equal(ip) { + return dev, info, nil + } + } + return "", DeviceInfo{}, fmt.Errorf("can't find %s on any interface", ip) +} diff --git a/test/packetimpact/netdevs/netdevs_test.go b/test/packetimpact/netdevs/netdevs_test.go new file mode 100644 index 000000000..24ad12198 --- /dev/null +++ b/test/packetimpact/netdevs/netdevs_test.go @@ -0,0 +1,227 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package netdevs + +import ( + "fmt" + "net" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func mustParseMAC(s string) net.HardwareAddr { + mac, err := net.ParseMAC(s) + if err != nil { + panic(fmt.Sprintf("failed to parse test MAC %q: %s", s, err)) + } + return mac +} + +func TestParseDevices(t *testing.T) { + for _, v := range []struct { + desc string + cmdOutput string + want map[string]DeviceInfo + }{ + { + desc: "v4 and v6", + cmdOutput: ` +1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue state UNKNOWN group default qlen 1000 + link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00 + inet 127.0.0.1/8 scope host lo + valid_lft forever preferred_lft forever + inet6 ::1/128 scope host + valid_lft forever preferred_lft forever +2613: eth0@if2614: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc noqueue state UP group default + link/ether 02:42:c0:a8:09:02 brd ff:ff:ff:ff:ff:ff link-netnsid 0 + inet 192.168.9.2/24 brd 192.168.9.255 scope global eth0 + valid_lft forever preferred_lft forever + inet6 fe80::42:c0ff:fea8:902/64 scope link tentative + valid_lft forever preferred_lft forever +2615: eth2@if2616: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc noqueue state UP group default + link/ether 02:42:df:f5:e1:0a brd ff:ff:ff:ff:ff:ff link-netnsid 0 + inet 223.245.225.10/24 brd 223.245.225.255 scope global eth2 + valid_lft forever preferred_lft forever + inet6 fe80::42:dfff:fef5:e10a/64 scope link tentative + valid_lft forever preferred_lft forever +2617: eth1@if2618: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc noqueue state UP group default + link/ether 02:42:da:33:13:0a brd ff:ff:ff:ff:ff:ff link-netnsid 0 + inet 218.51.19.10/24 brd 218.51.19.255 scope global eth1 + valid_lft forever preferred_lft forever + inet6 fe80::42:daff:fe33:130a/64 scope link tentative + valid_lft forever preferred_lft forever`, + want: map[string]DeviceInfo{ + "lo": DeviceInfo{ + ID: 1, + MAC: mustParseMAC("00:00:00:00:00:00"), + IPv4Addr: net.IPv4(127, 0, 0, 1), + IPv4Net: &net.IPNet{ + IP: net.IPv4(127, 0, 0, 0), + Mask: net.CIDRMask(8, 32), + }, + IPv6Addr: net.ParseIP("::1"), + IPv6Net: &net.IPNet{ + IP: net.ParseIP("::1"), + Mask: net.CIDRMask(128, 128), + }, + }, + "eth0": DeviceInfo{ + ID: 2613, + MAC: mustParseMAC("02:42:c0:a8:09:02"), + IPv4Addr: net.IPv4(192, 168, 9, 2), + IPv4Net: &net.IPNet{ + IP: net.IPv4(192, 168, 9, 0), + Mask: net.CIDRMask(24, 32), + }, + IPv6Addr: net.ParseIP("fe80::42:c0ff:fea8:902"), + IPv6Net: &net.IPNet{ + IP: net.ParseIP("fe80::"), + Mask: net.CIDRMask(64, 128), + }, + }, + "eth1": DeviceInfo{ + ID: 2617, + MAC: mustParseMAC("02:42:da:33:13:0a"), + IPv4Addr: net.IPv4(218, 51, 19, 10), + IPv4Net: &net.IPNet{ + IP: net.IPv4(218, 51, 19, 0), + Mask: net.CIDRMask(24, 32), + }, + IPv6Addr: net.ParseIP("fe80::42:daff:fe33:130a"), + IPv6Net: &net.IPNet{ + IP: net.ParseIP("fe80::"), + Mask: net.CIDRMask(64, 128), + }, + }, + "eth2": DeviceInfo{ + ID: 2615, + MAC: mustParseMAC("02:42:df:f5:e1:0a"), + IPv4Addr: net.IPv4(223, 245, 225, 10), + IPv4Net: &net.IPNet{ + IP: net.IPv4(223, 245, 225, 0), + Mask: net.CIDRMask(24, 32), + }, + IPv6Addr: net.ParseIP("fe80::42:dfff:fef5:e10a"), + IPv6Net: &net.IPNet{ + IP: net.ParseIP("fe80::"), + Mask: net.CIDRMask(64, 128), + }, + }, + }, + }, + { + desc: "v4 only", + cmdOutput: ` +2613: eth0@if2614: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc noqueue state UP group default + link/ether 02:42:c0:a8:09:02 brd ff:ff:ff:ff:ff:ff link-netnsid 0 + inet 192.168.9.2/24 brd 192.168.9.255 scope global eth0 + valid_lft forever preferred_lft forever`, + want: map[string]DeviceInfo{ + "eth0": DeviceInfo{ + ID: 2613, + MAC: mustParseMAC("02:42:c0:a8:09:02"), + IPv4Addr: net.IPv4(192, 168, 9, 2), + IPv4Net: &net.IPNet{ + IP: net.IPv4(192, 168, 9, 0), + Mask: net.CIDRMask(24, 32), + }, + }, + }, + }, + { + desc: "v6 only", + cmdOutput: ` +2615: eth2@if2616: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc noqueue state UP group default + link/ether 02:42:df:f5:e1:0a brd ff:ff:ff:ff:ff:ff link-netnsid 0 + inet6 fe80::42:dfff:fef5:e10a/64 scope link tentative + valid_lft forever preferred_lft forever`, + want: map[string]DeviceInfo{ + "eth2": DeviceInfo{ + ID: 2615, + MAC: mustParseMAC("02:42:df:f5:e1:0a"), + IPv6Addr: net.ParseIP("fe80::42:dfff:fef5:e10a"), + IPv6Net: &net.IPNet{ + IP: net.ParseIP("fe80::"), + Mask: net.CIDRMask(64, 128), + }, + }, + }, + }, + } { + t.Run(v.desc, func(t *testing.T) { + got, err := ParseDevices(v.cmdOutput) + if err != nil { + t.Errorf("ParseDevices(\n%s\n) got unexpected error: %s", v.cmdOutput, err) + } + if diff := cmp.Diff(v.want, got); diff != "" { + t.Errorf("ParseDevices(\n%s\n) got output diff (-want, +got):\n%s", v.cmdOutput, diff) + } + }) + } +} + +func TestParseDevicesErrors(t *testing.T) { + for _, v := range []struct { + desc string + cmdOutput string + }{ + { + desc: "invalid MAC addr", + cmdOutput: ` +2617: eth1@if2618: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc noqueue state UP group default + link/ether 02:42:da:33:13:0a:ffffffff brd ff:ff:ff:ff:ff:ff link-netnsid 0 + inet 218.51.19.10/24 brd 218.51.19.255 scope global eth1 + valid_lft forever preferred_lft forever + inet6 fe80::42:daff:fe33:130a/64 scope link tentative + valid_lft forever preferred_lft forever`, + }, + { + desc: "invalid v4 addr", + cmdOutput: ` +2617: eth1@if2618: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc noqueue state UP group default + link/ether 02:42:da:33:13:0a brd ff:ff:ff:ff:ff:ff link-netnsid 0 + inet 1234.4321.424242.0/24 brd 218.51.19.255 scope global eth1 + valid_lft forever preferred_lft forever + inet6 fe80::42:daff:fe33:130a/64 scope link tentative + valid_lft forever preferred_lft forever`, + }, + { + desc: "invalid v6 addr", + cmdOutput: ` +2617: eth1@if2618: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc noqueue state UP group default + link/ether 02:42:da:33:13:0a brd ff:ff:ff:ff:ff:ff link-netnsid 0 + inet 218.51.19.10/24 brd 218.51.19.255 scope global eth1 + valid_lft forever preferred_lft forever + inet6 fe80:ffffffff::42:daff:fe33:130a/64 scope link tentative + valid_lft forever preferred_lft forever`, + }, + { + desc: "invalid CIDR missing prefixlen", + cmdOutput: ` +2617: eth1@if2618: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc noqueue state UP group default + link/ether 02:42:da:33:13:0a brd ff:ff:ff:ff:ff:ff link-netnsid 0 + inet 218.51.19.10 brd 218.51.19.255 scope global eth1 + valid_lft forever preferred_lft forever + inet6 fe80::42:daff:fe33:130a scope link tentative + valid_lft forever preferred_lft forever`, + }, + } { + t.Run(v.desc, func(t *testing.T) { + if _, err := ParseDevices(v.cmdOutput); err == nil { + t.Errorf("ParseDevices(\n%s\n) succeeded unexpectedly, want error", v.cmdOutput) + } + }) + } +} diff --git a/test/packetimpact/proto/BUILD b/test/packetimpact/proto/BUILD new file mode 100644 index 000000000..4a4370f42 --- /dev/null +++ b/test/packetimpact/proto/BUILD @@ -0,0 +1,12 @@ +load("//tools:defs.bzl", "proto_library") + +package( + default_visibility = ["//test/packetimpact:__subpackages__"], + licenses = ["notice"], +) + +proto_library( + name = "posix_server", + srcs = ["posix_server.proto"], + has_services = 1, +) diff --git a/test/packetimpact/proto/posix_server.proto b/test/packetimpact/proto/posix_server.proto new file mode 100644 index 000000000..ccd20b10d --- /dev/null +++ b/test/packetimpact/proto/posix_server.proto @@ -0,0 +1,230 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package posix_server; + +message SockaddrIn { + int32 family = 1; + uint32 port = 2; + bytes addr = 3; +} + +message SockaddrIn6 { + uint32 family = 1; + uint32 port = 2; + uint32 flowinfo = 3; + bytes addr = 4; + uint32 scope_id = 5; +} + +message Sockaddr { + oneof sockaddr { + SockaddrIn in = 1; + SockaddrIn6 in6 = 2; + } +} + +message Timeval { + int64 seconds = 1; + int64 microseconds = 2; +} + +message SockOptVal { + oneof val { + bytes bytesval = 1; + int32 intval = 2; + Timeval timeval = 3; + } +} + +// Request and Response pairs for each Posix service RPC call, sorted. + +message AcceptRequest { + int32 sockfd = 1; +} + +message AcceptResponse { + int32 fd = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. + Sockaddr addr = 3; +} + +message BindRequest { + int32 sockfd = 1; + Sockaddr addr = 2; +} + +message BindResponse { + int32 ret = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. +} + +message CloseRequest { + int32 fd = 1; +} + +message CloseResponse { + int32 ret = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. +} + +message ConnectRequest { + int32 sockfd = 1; + Sockaddr addr = 2; +} + +message ConnectResponse { + int32 ret = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. +} + +message FcntlRequest { + int32 fd = 1; + int32 cmd = 2; + int32 arg = 3; +} + +message FcntlResponse { + int32 ret = 1; + int32 errno_ = 2; +} + +message GetSockNameRequest { + int32 sockfd = 1; +} + +message GetSockNameResponse { + int32 ret = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. + Sockaddr addr = 3; +} + +message GetSockOptRequest { + int32 sockfd = 1; + int32 level = 2; + int32 optname = 3; + int32 optlen = 4; + enum SockOptType { + UNSPECIFIED = 0; + BYTES = 1; + INT = 2; + TIME = 3; + } + SockOptType type = 5; +} + +message GetSockOptResponse { + int32 ret = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. + SockOptVal optval = 3; +} + +message ListenRequest { + int32 sockfd = 1; + int32 backlog = 2; +} + +message ListenResponse { + int32 ret = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. +} + +message SendRequest { + int32 sockfd = 1; + bytes buf = 2; + int32 flags = 3; +} + +message SendResponse { + int32 ret = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. +} + +message SendToRequest { + int32 sockfd = 1; + bytes buf = 2; + int32 flags = 3; + Sockaddr dest_addr = 4; +} + +message SendToResponse { + int32 ret = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. +} + +message SetSockOptRequest { + int32 sockfd = 1; + int32 level = 2; + int32 optname = 3; + SockOptVal optval = 4; +} + +message SetSockOptResponse { + int32 ret = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. +} + +message SocketRequest { + int32 domain = 1; + int32 type = 2; + int32 protocol = 3; +} + +message SocketResponse { + int32 fd = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. +} + +message RecvRequest { + int32 sockfd = 1; + int32 len = 2; + int32 flags = 3; +} + +message RecvResponse { + int32 ret = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. + bytes buf = 3; +} + +service Posix { + // Call accept() on the DUT. + rpc Accept(AcceptRequest) returns (AcceptResponse); + // Call bind() on the DUT. + rpc Bind(BindRequest) returns (BindResponse); + // Call close() on the DUT. + rpc Close(CloseRequest) returns (CloseResponse); + // Call connect() on the DUT. + rpc Connect(ConnectRequest) returns (ConnectResponse); + // Call fcntl() on the DUT. + rpc Fcntl(FcntlRequest) returns (FcntlResponse); + // Call getsockname() on the DUT. + rpc GetSockName(GetSockNameRequest) returns (GetSockNameResponse); + // Call getsockopt() on the DUT. + rpc GetSockOpt(GetSockOptRequest) returns (GetSockOptResponse); + // Call listen() on the DUT. + rpc Listen(ListenRequest) returns (ListenResponse); + // Call send() on the DUT. + rpc Send(SendRequest) returns (SendResponse); + // Call sendto() on the DUT. + rpc SendTo(SendToRequest) returns (SendToResponse); + // Call setsockopt() on the DUT. + rpc SetSockOpt(SetSockOptRequest) returns (SetSockOptResponse); + // Call socket() on the DUT. + rpc Socket(SocketRequest) returns (SocketResponse); + // Call recv() on the DUT. + rpc Recv(RecvRequest) returns (RecvResponse); +} diff --git a/test/packetimpact/runner/BUILD b/test/packetimpact/runner/BUILD new file mode 100644 index 000000000..ff2be9b30 --- /dev/null +++ b/test/packetimpact/runner/BUILD @@ -0,0 +1,27 @@ +load("//tools:defs.bzl", "bzl_library", "go_test") + +package( + default_visibility = ["//test/packetimpact:__subpackages__"], + licenses = ["notice"], +) + +go_test( + name = "packetimpact_test", + srcs = ["packetimpact_test.go"], + tags = [ + # Not intended to be run directly. + "local", + "manual", + ], + deps = [ + "//pkg/test/dockerutil", + "//test/packetimpact/netdevs", + "@com_github_docker_docker//api/types/mount:go_default_library", + ], +) + +bzl_library( + name = "defs_bzl", + srcs = ["defs.bzl"], + visibility = ["//visibility:private"], +) diff --git a/test/packetimpact/runner/defs.bzl b/test/packetimpact/runner/defs.bzl new file mode 100644 index 000000000..93a36c6c2 --- /dev/null +++ b/test/packetimpact/runner/defs.bzl @@ -0,0 +1,143 @@ +"""Defines rules for packetimpact test targets.""" + +load("//tools:defs.bzl", "go_test") + +def _packetimpact_test_impl(ctx): + test_runner = ctx.executable._test_runner + bench = ctx.actions.declare_file("%s-bench" % ctx.label.name) + bench_content = "\n".join([ + "#!/bin/bash", + # This test will run part in a distinct user namespace. This can cause + # permission problems, because all runfiles may not be owned by the + # current user, and no other users will be mapped in that namespace. + # Make sure that everything is readable here. + "find . -type f -or -type d -exec chmod a+rx {} \\;", + "%s %s --testbench_binary %s $@\n" % ( + test_runner.short_path, + " ".join(ctx.attr.flags), + ctx.files.testbench_binary[0].short_path, + ), + ]) + ctx.actions.write(bench, bench_content, is_executable = True) + + transitive_files = [] + if hasattr(ctx.attr._test_runner, "data_runfiles"): + transitive_files.append(ctx.attr._test_runner.data_runfiles.files) + runfiles = ctx.runfiles( + files = [test_runner] + ctx.files.testbench_binary + ctx.files._posix_server_binary, + transitive_files = depset(transitive = transitive_files), + collect_default = True, + collect_data = True, + ) + return [DefaultInfo(executable = bench, runfiles = runfiles)] + +_packetimpact_test = rule( + attrs = { + "_test_runner": attr.label( + executable = True, + cfg = "target", + default = ":packetimpact_test", + ), + "_posix_server_binary": attr.label( + cfg = "target", + default = "//test/packetimpact/dut:posix_server", + ), + "testbench_binary": attr.label( + cfg = "target", + mandatory = True, + ), + "flags": attr.string_list( + mandatory = False, + default = [], + ), + }, + test = True, + implementation = _packetimpact_test_impl, +) + +PACKETIMPACT_TAGS = [ + "local", + "manual", + "packetimpact", +] + +def packetimpact_native_test( + name, + testbench_binary, + expect_failure = False, + **kwargs): + """Add a native packetimpact test. + + Args: + name: name of the test + testbench_binary: the testbench binary + expect_failure: the test must fail + **kwargs: all the other args, forwarded to _packetimpact_test + """ + expect_failure_flag = ["--expect_failure"] if expect_failure else [] + _packetimpact_test( + name = name + "_native_test", + testbench_binary = testbench_binary, + flags = ["--native"] + expect_failure_flag, + tags = PACKETIMPACT_TAGS, + **kwargs + ) + +def packetimpact_netstack_test( + name, + testbench_binary, + expect_failure = False, + **kwargs): + """Add a packetimpact test on netstack. + + Args: + name: name of the test + testbench_binary: the testbench binary + expect_failure: the test must fail + **kwargs: all the other args, forwarded to _packetimpact_test + """ + expect_failure_flag = [] + if expect_failure: + expect_failure_flag = ["--expect_failure"] + _packetimpact_test( + name = name + "_netstack_test", + testbench_binary = testbench_binary, + # Note that a distinct runtime must be provided in the form + # --test_arg=--runtime=other when invoking bazel. + flags = expect_failure_flag, + tags = PACKETIMPACT_TAGS, + **kwargs + ) + +def packetimpact_go_test(name, size = "small", pure = True, expect_native_failure = False, expect_netstack_failure = False, **kwargs): + """Add packetimpact tests written in go. + + Args: + name: name of the test + size: size of the test + pure: make a static go binary + expect_native_failure: the test must fail natively + expect_netstack_failure: the test must fail for Netstack + **kwargs: all the other args, forwarded to go_test + """ + testbench_binary = name + "_test" + go_test( + name = testbench_binary, + size = size, + pure = pure, + tags = [ + "local", + "manual", + ], + **kwargs + ) + packetimpact_native_test( + name = name, + expect_failure = expect_native_failure, + testbench_binary = testbench_binary, + ) + packetimpact_netstack_test( + name = name, + expect_failure = expect_netstack_failure, + testbench_binary = testbench_binary, + ) diff --git a/test/packetimpact/runner/packetimpact_test.go b/test/packetimpact/runner/packetimpact_test.go new file mode 100644 index 000000000..e8c183977 --- /dev/null +++ b/test/packetimpact/runner/packetimpact_test.go @@ -0,0 +1,383 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// The runner starts docker containers and networking for a packetimpact test. +package packetimpact_test + +import ( + "context" + "flag" + "fmt" + "io/ioutil" + "log" + "math/rand" + "net" + "os" + "os/exec" + "path" + "strings" + "testing" + "time" + + "github.com/docker/docker/api/types/mount" + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/packetimpact/netdevs" +) + +// stringList implements flag.Value. +type stringList []string + +// String implements flag.Value.String. +func (l *stringList) String() string { + return strings.Join(*l, ",") +} + +// Set implements flag.Value.Set. +func (l *stringList) Set(value string) error { + *l = append(*l, value) + return nil +} + +var ( + native = flag.Bool("native", false, "whether the test should be run natively") + testbenchBinary = flag.String("testbench_binary", "", "path to the testbench binary") + tshark = flag.Bool("tshark", false, "use more verbose tshark in logs instead of tcpdump") + extraTestArgs = stringList{} + expectFailure = flag.Bool("expect_failure", false, "expect that the test will fail when run") + + dutAddr = net.IPv4(0, 0, 0, 10) + testbenchAddr = net.IPv4(0, 0, 0, 20) +) + +const ctrlPort = "40000" + +// logger implements testutil.Logger. +// +// Labels logs based on their source and formats multi-line logs. +type logger string + +// Name implements testutil.Logger.Name. +func (l logger) Name() string { + return string(l) +} + +// Logf implements testutil.Logger.Logf. +func (l logger) Logf(format string, args ...interface{}) { + lines := strings.Split(fmt.Sprintf(format, args...), "\n") + log.Printf("%s: %s", l, lines[0]) + for _, line := range lines[1:] { + log.Printf("%*s %s", len(l), "", line) + } +} + +func TestOne(t *testing.T) { + flag.Var(&extraTestArgs, "extra_test_arg", "extra arguments to pass to the testbench") + flag.Parse() + if *testbenchBinary == "" { + t.Fatal("--testbench_binary is missing") + } + dockerutil.EnsureSupportedDockerVersion() + ctx := context.Background() + + // Create the networks needed for the test. One control network is needed for + // the gRPC control packets and one test network on which to transmit the test + // packets. + ctrlNet := dockerutil.NewNetwork(ctx, logger("ctrlNet")) + testNet := dockerutil.NewNetwork(ctx, logger("testNet")) + for _, dn := range []*dockerutil.Network{ctrlNet, testNet} { + for { + if err := createDockerNetwork(ctx, dn); err != nil { + t.Log("creating docker network:", err) + const wait = 100 * time.Millisecond + t.Logf("sleeping %s and will try creating docker network again", wait) + // This can fail if another docker network claimed the same IP so we'll + // just try again. + time.Sleep(wait) + continue + } + break + } + defer func(dn *dockerutil.Network) { + if err := dn.Cleanup(ctx); err != nil { + t.Errorf("unable to cleanup container %s: %s", dn.Name, err) + } + }(dn) + // Sanity check. + inspect, err := dn.Inspect(ctx) + if err != nil { + t.Fatalf("failed to inspect network %s: %v", dn.Name, err) + } else if inspect.Name != dn.Name { + t.Fatalf("name mismatch for network want: %s got: %s", dn.Name, inspect.Name) + } + + } + + tmpDir, err := ioutil.TempDir("", "container-output") + if err != nil { + t.Fatal("creating temp dir:", err) + } + defer os.RemoveAll(tmpDir) + + const testOutputDir = "/tmp/testoutput" + + // Create the Docker container for the DUT. + var dut *dockerutil.Container + if *native { + dut = dockerutil.MakeNativeContainer(ctx, logger("dut")) + } else { + dut = dockerutil.MakeContainer(ctx, logger("dut")) + } + + runOpts := dockerutil.RunOpts{ + Image: "packetimpact", + CapAdd: []string{"NET_ADMIN"}, + Mounts: []mount.Mount{mount.Mount{ + Type: mount.TypeBind, + Source: tmpDir, + Target: testOutputDir, + ReadOnly: false, + }}, + } + + const containerPosixServerBinary = "/packetimpact/posix_server" + dut.CopyFiles(&runOpts, "/packetimpact", "/test/packetimpact/dut/posix_server") + + conf, hostconf, _ := dut.ConfigsFrom(runOpts, containerPosixServerBinary, "--ip=0.0.0.0", "--port="+ctrlPort) + hostconf.AutoRemove = true + hostconf.Sysctls = map[string]string{"net.ipv6.conf.all.disable_ipv6": "0"} + + if err := dut.CreateFrom(ctx, conf, hostconf, nil); err != nil { + t.Fatalf("unable to create container %s: %v", dut.Name, err) + } + + defer dut.CleanUp(ctx) + + // Add ctrlNet as eth1 and testNet as eth2. + const testNetDev = "eth2" + if err := addNetworks(ctx, dut, dutAddr, []*dockerutil.Network{ctrlNet, testNet}); err != nil { + t.Fatal(err) + } + + if err := dut.Start(ctx); err != nil { + t.Fatalf("unable to start container %s: %s", dut.Name, err) + } + + if _, err := dut.WaitForOutput(ctx, "Server listening.*\n", 60*time.Second); err != nil { + t.Fatalf("%s on container %s never listened: %s", containerPosixServerBinary, dut.Name, err) + } + + dutTestDevice, dutDeviceInfo, err := deviceByIP(ctx, dut, addressInSubnet(dutAddr, *testNet.Subnet)) + if err != nil { + t.Fatal(err) + } + + remoteMAC := dutDeviceInfo.MAC + remoteIPv6 := dutDeviceInfo.IPv6Addr + // Netstack as DUT doesn't assign IPv6 addresses automatically so do it if + // needed. + if remoteIPv6 == nil { + if _, err := dut.Exec(ctx, dockerutil.ExecOpts{}, "ip", "addr", "add", netdevs.MACToIP(remoteMAC).String(), "scope", "link", "dev", dutTestDevice); err != nil { + t.Fatalf("unable to ip addr add on container %s: %s", dut.Name, err) + } + // Now try again, to make sure that it worked. + _, dutDeviceInfo, err = deviceByIP(ctx, dut, addressInSubnet(dutAddr, *testNet.Subnet)) + if err != nil { + t.Fatal(err) + } + remoteIPv6 = dutDeviceInfo.IPv6Addr + if remoteIPv6 == nil { + t.Fatal("unable to set IPv6 address on container", dut.Name) + } + } + + // Create the Docker container for the testbench. + testbench := dockerutil.MakeNativeContainer(ctx, logger("testbench")) + + tbb := path.Base(*testbenchBinary) + containerTestbenchBinary := "/packetimpact/" + tbb + runOpts = dockerutil.RunOpts{ + Image: "packetimpact", + CapAdd: []string{"NET_ADMIN"}, + Mounts: []mount.Mount{mount.Mount{ + Type: mount.TypeBind, + Source: tmpDir, + Target: testOutputDir, + ReadOnly: false, + }}, + } + testbench.CopyFiles(&runOpts, "/packetimpact", "/test/packetimpact/tests/"+tbb) + + // Run tcpdump in the test bench unbuffered, without DNS resolution, just on + // the interface with the test packets. + snifferArgs := []string{ + "tcpdump", + "-S", "-vvv", "-U", "-n", + "-i", testNetDev, + "-w", testOutputDir + "/dump.pcap", + } + snifferRegex := "tcpdump: listening.*\n" + if *tshark { + // Run tshark in the test bench unbuffered, without DNS resolution, just on + // the interface with the test packets. + snifferArgs = []string{ + "tshark", "-V", "-l", "-n", "-i", testNetDev, + "-o", "tcp.check_checksum:TRUE", + "-o", "udp.check_checksum:TRUE", + } + snifferRegex = "Capturing on.*\n" + } + + defer func() { + if err := exec.Command("/bin/cp", "-r", tmpDir, os.Getenv("TEST_UNDECLARED_OUTPUTS_DIR")).Run(); err != nil { + t.Error("unable to copy container output files:", err) + } + }() + + conf, hostconf, _ = testbench.ConfigsFrom(runOpts, snifferArgs...) + hostconf.AutoRemove = true + hostconf.Sysctls = map[string]string{"net.ipv6.conf.all.disable_ipv6": "0"} + + if err := testbench.CreateFrom(ctx, conf, hostconf, nil); err != nil { + t.Fatalf("unable to create container %s: %s", testbench.Name, err) + } + defer testbench.CleanUp(ctx) + + // Add ctrlNet as eth1 and testNet as eth2. + if err := addNetworks(ctx, testbench, testbenchAddr, []*dockerutil.Network{ctrlNet, testNet}); err != nil { + t.Fatal(err) + } + + if err := testbench.Start(ctx); err != nil { + t.Fatalf("unable to start container %s: %s", testbench.Name, err) + } + + // Kill so that it will flush output. + defer func() { + time.Sleep(1 * time.Second) + testbench.Exec(ctx, dockerutil.ExecOpts{}, "killall", snifferArgs[0]) + }() + + if _, err := testbench.WaitForOutput(ctx, snifferRegex, 60*time.Second); err != nil { + t.Fatalf("sniffer on %s never listened: %s", dut.Name, err) + } + + // Because the Linux kernel receives the SYN-ACK but didn't send the SYN it + // will issue an RST. To prevent this IPtables can be used to filter out all + // incoming packets. The raw socket that packetimpact tests use will still see + // everything. + for _, bin := range []string{"iptables", "ip6tables"} { + if logs, err := testbench.Exec(ctx, dockerutil.ExecOpts{}, bin, "-A", "INPUT", "-i", testNetDev, "-p", "tcp", "-j", "DROP"); err != nil { + t.Fatalf("unable to Exec %s on container %s: %s, logs from testbench:\n%s", bin, testbench.Name, err, logs) + } + } + + // FIXME(b/156449515): Some piece of the system has a race. The old + // bash script version had a sleep, so we have one too. The race should + // be fixed and this sleep removed. + time.Sleep(time.Second) + + // Start a packetimpact test on the test bench. The packetimpact test sends + // and receives packets and also sends POSIX socket commands to the + // posix_server to be executed on the DUT. + testArgs := []string{containerTestbenchBinary} + testArgs = append(testArgs, extraTestArgs...) + testArgs = append(testArgs, + "--posix_server_ip", addressInSubnet(dutAddr, *ctrlNet.Subnet).String(), + "--posix_server_port", ctrlPort, + "--remote_ipv4", addressInSubnet(dutAddr, *testNet.Subnet).String(), + "--local_ipv4", addressInSubnet(testbenchAddr, *testNet.Subnet).String(), + "--remote_ipv6", remoteIPv6.String(), + "--remote_mac", remoteMAC.String(), + "--remote_interface_id", fmt.Sprintf("%d", dutDeviceInfo.ID), + "--device", testNetDev, + fmt.Sprintf("--native=%t", *native), + ) + testbenchLogs, err := testbench.Exec(ctx, dockerutil.ExecOpts{}, testArgs...) + if (err != nil) != *expectFailure { + var dutLogs string + if logs, err := dut.Logs(ctx); err != nil { + dutLogs = fmt.Sprintf("failed to fetch DUT logs: %s", err) + } else { + dutLogs = logs + } + + t.Errorf(`test error: %v, expect failure: %t + +====== Begin of DUT Logs ====== + +%s + +====== End of DUT Logs ====== + +====== Begin of Testbench Logs ====== + +%s + +====== End of Testbench Logs ======`, + err, *expectFailure, dutLogs, testbenchLogs) + } +} + +func addNetworks(ctx context.Context, d *dockerutil.Container, addr net.IP, networks []*dockerutil.Network) error { + for _, dn := range networks { + ip := addressInSubnet(addr, *dn.Subnet) + // Connect to the network with the specified IP address. + if err := dn.Connect(ctx, d, ip.String(), ""); err != nil { + return fmt.Errorf("unable to connect container %s to network %s: %w", d.Name, dn.Name, err) + } + } + return nil +} + +// addressInSubnet combines the subnet provided with the address and returns a +// new address. The return address bits come from the subnet where the mask is 1 +// and from the ip address where the mask is 0. +func addressInSubnet(addr net.IP, subnet net.IPNet) net.IP { + var octets []byte + for i := 0; i < 4; i++ { + octets = append(octets, (subnet.IP.To4()[i]&subnet.Mask[i])+(addr.To4()[i]&(^subnet.Mask[i]))) + } + return net.IP(octets) +} + +// createDockerNetwork makes a randomly-named network that will start with the +// namePrefix. The network will be a random /24 subnet. +func createDockerNetwork(ctx context.Context, n *dockerutil.Network) error { + randSource := rand.NewSource(time.Now().UnixNano()) + r1 := rand.New(randSource) + // Class C, 192.0.0.0 to 223.255.255.255, transitionally has mask 24. + ip := net.IPv4(byte(r1.Intn(224-192)+192), byte(r1.Intn(256)), byte(r1.Intn(256)), 0) + n.Subnet = &net.IPNet{ + IP: ip, + Mask: ip.DefaultMask(), + } + return n.Create(ctx) +} + +// deviceByIP finds a deviceInfo and device name from an IP address. +func deviceByIP(ctx context.Context, d *dockerutil.Container, ip net.IP) (string, netdevs.DeviceInfo, error) { + out, err := d.Exec(ctx, dockerutil.ExecOpts{}, "ip", "addr", "show") + if err != nil { + return "", netdevs.DeviceInfo{}, fmt.Errorf("listing devices on %s container: %w", d.Name, err) + } + devs, err := netdevs.ParseDevices(out) + if err != nil { + return "", netdevs.DeviceInfo{}, fmt.Errorf("parsing devices from %s container: %w", d.Name, err) + } + testDevice, deviceInfo, err := netdevs.FindDeviceByIP(ip, devs) + if err != nil { + return "", netdevs.DeviceInfo{}, fmt.Errorf("can't find deviceInfo for container %s: %w", d.Name, err) + } + return testDevice, deviceInfo, nil +} diff --git a/test/packetimpact/testbench/BUILD b/test/packetimpact/testbench/BUILD new file mode 100644 index 000000000..5a0ee1367 --- /dev/null +++ b/test/packetimpact/testbench/BUILD @@ -0,0 +1,46 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package( + default_visibility = ["//test/packetimpact:__subpackages__"], + licenses = ["notice"], +) + +go_library( + name = "testbench", + srcs = [ + "connections.go", + "dut.go", + "dut_client.go", + "layers.go", + "rawsockets.go", + "testbench.go", + ], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/seqnum", + "//pkg/usermem", + "//test/packetimpact/netdevs", + "//test/packetimpact/proto:posix_server_go_proto", + "@com_github_google_go_cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", + "@com_github_mohae_deepcopy//:go_default_library", + "@org_golang_google_grpc//:go_default_library", + "@org_golang_google_grpc//keepalive:go_default_library", + "@org_golang_x_sys//unix:go_default_library", + "@org_uber_go_multierr//:go_default_library", + ], +) + +go_test( + name = "testbench_test", + size = "small", + srcs = ["layers_test.go"], + library = ":testbench", + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/header", + "@com_github_mohae_deepcopy//:go_default_library", + ], +) diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go new file mode 100644 index 000000000..3af5f83fd --- /dev/null +++ b/test/packetimpact/testbench/connections.go @@ -0,0 +1,1205 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package testbench has utilities to send and receive packets and also command +// the DUT to run POSIX functions. +package testbench + +import ( + "fmt" + "math/rand" + "net" + "testing" + "time" + + "github.com/mohae/deepcopy" + "go.uber.org/multierr" + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" +) + +func portFromSockaddr(sa unix.Sockaddr) (uint16, error) { + switch sa := sa.(type) { + case *unix.SockaddrInet4: + return uint16(sa.Port), nil + case *unix.SockaddrInet6: + return uint16(sa.Port), nil + } + return 0, fmt.Errorf("sockaddr type %T does not contain port", sa) +} + +// pickPort makes a new socket and returns the socket FD and port. The domain +// should be AF_INET or AF_INET6. The caller must close the FD when done with +// the port if there is no error. +func pickPort(domain, typ int) (fd int, port uint16, err error) { + fd, err = unix.Socket(domain, typ, 0) + if err != nil { + return -1, 0, fmt.Errorf("creating socket: %w", err) + } + defer func() { + if err != nil { + if cerr := unix.Close(fd); cerr != nil { + err = multierr.Append(err, fmt.Errorf("failed to close socket %d: %w", fd, cerr)) + } + } + }() + var sa unix.Sockaddr + switch domain { + case unix.AF_INET: + var sa4 unix.SockaddrInet4 + copy(sa4.Addr[:], net.ParseIP(LocalIPv4).To4()) + sa = &sa4 + case unix.AF_INET6: + sa6 := unix.SockaddrInet6{ZoneId: uint32(LocalInterfaceID)} + copy(sa6.Addr[:], net.ParseIP(LocalIPv6).To16()) + sa = &sa6 + default: + return -1, 0, fmt.Errorf("invalid domain %d, it should be one of unix.AF_INET or unix.AF_INET6", domain) + } + if err = unix.Bind(fd, sa); err != nil { + return -1, 0, fmt.Errorf("binding to %+v: %w", sa, err) + } + sa, err = unix.Getsockname(fd) + if err != nil { + return -1, 0, fmt.Errorf("Getsocketname(%d): %w", fd, err) + } + port, err = portFromSockaddr(sa) + if err != nil { + return -1, 0, fmt.Errorf("extracting port from socket address %+v: %w", sa, err) + } + return fd, port, nil +} + +// layerState stores the state of a layer of a connection. +type layerState interface { + // outgoing returns an outgoing layer to be sent in a frame. It should not + // update layerState, that is done in layerState.sent. + outgoing() Layer + + // incoming creates an expected Layer for comparing against a received Layer. + // Because the expectation can depend on values in the received Layer, it is + // an input to incoming. For example, the ACK number needs to be checked in a + // TCP packet but only if the ACK flag is set in the received packet. It + // should not update layerState, that is done in layerState.received. The + // caller takes ownership of the returned Layer. + incoming(received Layer) Layer + + // sent updates the layerState based on the Layer that was sent. The input is + // a Layer with all prev and next pointers populated so that the entire frame + // as it was sent is available. + sent(sent Layer) error + + // received updates the layerState based on a Layer that is receieved. The + // input is a Layer with all prev and next pointers populated so that the + // entire frame as it was receieved is available. + received(received Layer) error + + // close frees associated resources held by the LayerState. + close() error +} + +// etherState maintains state about an Ethernet connection. +type etherState struct { + out, in Ether +} + +var _ layerState = (*etherState)(nil) + +// newEtherState creates a new etherState. +func newEtherState(out, in Ether) (*etherState, error) { + lMAC, err := tcpip.ParseMACAddress(LocalMAC) + if err != nil { + return nil, fmt.Errorf("parsing local MAC: %q: %w", LocalMAC, err) + } + + rMAC, err := tcpip.ParseMACAddress(RemoteMAC) + if err != nil { + return nil, fmt.Errorf("parsing remote MAC: %q: %w", RemoteMAC, err) + } + s := etherState{ + out: Ether{SrcAddr: &lMAC, DstAddr: &rMAC}, + in: Ether{SrcAddr: &rMAC, DstAddr: &lMAC}, + } + if err := s.out.merge(&out); err != nil { + return nil, err + } + if err := s.in.merge(&in); err != nil { + return nil, err + } + return &s, nil +} + +func (s *etherState) outgoing() Layer { + return deepcopy.Copy(&s.out).(Layer) +} + +// incoming implements layerState.incoming. +func (s *etherState) incoming(Layer) Layer { + return deepcopy.Copy(&s.in).(Layer) +} + +func (*etherState) sent(Layer) error { + return nil +} + +func (*etherState) received(Layer) error { + return nil +} + +func (*etherState) close() error { + return nil +} + +// ipv4State maintains state about an IPv4 connection. +type ipv4State struct { + out, in IPv4 +} + +var _ layerState = (*ipv4State)(nil) + +// newIPv4State creates a new ipv4State. +func newIPv4State(out, in IPv4) (*ipv4State, error) { + lIP := tcpip.Address(net.ParseIP(LocalIPv4).To4()) + rIP := tcpip.Address(net.ParseIP(RemoteIPv4).To4()) + s := ipv4State{ + out: IPv4{SrcAddr: &lIP, DstAddr: &rIP}, + in: IPv4{SrcAddr: &rIP, DstAddr: &lIP}, + } + if err := s.out.merge(&out); err != nil { + return nil, err + } + if err := s.in.merge(&in); err != nil { + return nil, err + } + return &s, nil +} + +func (s *ipv4State) outgoing() Layer { + return deepcopy.Copy(&s.out).(Layer) +} + +// incoming implements layerState.incoming. +func (s *ipv4State) incoming(Layer) Layer { + return deepcopy.Copy(&s.in).(Layer) +} + +func (*ipv4State) sent(Layer) error { + return nil +} + +func (*ipv4State) received(Layer) error { + return nil +} + +func (*ipv4State) close() error { + return nil +} + +// ipv6State maintains state about an IPv6 connection. +type ipv6State struct { + out, in IPv6 +} + +var _ layerState = (*ipv6State)(nil) + +// newIPv6State creates a new ipv6State. +func newIPv6State(out, in IPv6) (*ipv6State, error) { + lIP := tcpip.Address(net.ParseIP(LocalIPv6).To16()) + rIP := tcpip.Address(net.ParseIP(RemoteIPv6).To16()) + s := ipv6State{ + out: IPv6{SrcAddr: &lIP, DstAddr: &rIP}, + in: IPv6{SrcAddr: &rIP, DstAddr: &lIP}, + } + if err := s.out.merge(&out); err != nil { + return nil, err + } + if err := s.in.merge(&in); err != nil { + return nil, err + } + return &s, nil +} + +// outgoing returns an outgoing layer to be sent in a frame. +func (s *ipv6State) outgoing() Layer { + return deepcopy.Copy(&s.out).(Layer) +} + +func (s *ipv6State) incoming(Layer) Layer { + return deepcopy.Copy(&s.in).(Layer) +} + +func (s *ipv6State) sent(Layer) error { + // Nothing to do. + return nil +} + +func (s *ipv6State) received(Layer) error { + // Nothing to do. + return nil +} + +// close cleans up any resources held. +func (s *ipv6State) close() error { + return nil +} + +// tcpState maintains state about a TCP connection. +type tcpState struct { + out, in TCP + localSeqNum, remoteSeqNum *seqnum.Value + synAck *TCP + portPickerFD int + finSent bool +} + +var _ layerState = (*tcpState)(nil) + +// SeqNumValue is a helper routine that allocates a new seqnum.Value value to +// store v and returns a pointer to it. +func SeqNumValue(v seqnum.Value) *seqnum.Value { + return &v +} + +// newTCPState creates a new TCPState. +func newTCPState(domain int, out, in TCP) (*tcpState, error) { + portPickerFD, localPort, err := pickPort(domain, unix.SOCK_STREAM) + if err != nil { + return nil, err + } + s := tcpState{ + out: TCP{SrcPort: &localPort}, + in: TCP{DstPort: &localPort}, + localSeqNum: SeqNumValue(seqnum.Value(rand.Uint32())), + portPickerFD: portPickerFD, + finSent: false, + } + if err := s.out.merge(&out); err != nil { + return nil, err + } + if err := s.in.merge(&in); err != nil { + return nil, err + } + return &s, nil +} + +func (s *tcpState) outgoing() Layer { + newOutgoing := deepcopy.Copy(s.out).(TCP) + if s.localSeqNum != nil { + newOutgoing.SeqNum = Uint32(uint32(*s.localSeqNum)) + } + if s.remoteSeqNum != nil { + newOutgoing.AckNum = Uint32(uint32(*s.remoteSeqNum)) + } + return &newOutgoing +} + +// incoming implements layerState.incoming. +func (s *tcpState) incoming(received Layer) Layer { + tcpReceived, ok := received.(*TCP) + if !ok { + return nil + } + newIn := deepcopy.Copy(s.in).(TCP) + if s.remoteSeqNum != nil { + newIn.SeqNum = Uint32(uint32(*s.remoteSeqNum)) + } + if s.localSeqNum != nil && (*tcpReceived.Flags&header.TCPFlagAck) != 0 { + // The caller didn't specify an AckNum so we'll expect the calculated one, + // but only if the ACK flag is set because the AckNum is not valid in a + // header if ACK is not set. + newIn.AckNum = Uint32(uint32(*s.localSeqNum)) + } + return &newIn +} + +func (s *tcpState) sent(sent Layer) error { + tcp, ok := sent.(*TCP) + if !ok { + return fmt.Errorf("can't update tcpState with %T Layer", sent) + } + if !s.finSent { + // update localSeqNum by the payload only when FIN is not yet sent by us + for current := tcp.next(); current != nil; current = current.next() { + s.localSeqNum.UpdateForward(seqnum.Size(current.length())) + } + } + if tcp.Flags != nil && *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 { + s.localSeqNum.UpdateForward(1) + } + if *tcp.Flags&(header.TCPFlagFin) != 0 { + s.finSent = true + } + return nil +} + +func (s *tcpState) received(l Layer) error { + tcp, ok := l.(*TCP) + if !ok { + return fmt.Errorf("can't update tcpState with %T Layer", l) + } + s.remoteSeqNum = SeqNumValue(seqnum.Value(*tcp.SeqNum)) + if *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 { + s.remoteSeqNum.UpdateForward(1) + } + for current := tcp.next(); current != nil; current = current.next() { + s.remoteSeqNum.UpdateForward(seqnum.Size(current.length())) + } + return nil +} + +// close frees the port associated with this connection. +func (s *tcpState) close() error { + if err := unix.Close(s.portPickerFD); err != nil { + return err + } + s.portPickerFD = -1 + return nil +} + +// udpState maintains state about a UDP connection. +type udpState struct { + out, in UDP + portPickerFD int +} + +var _ layerState = (*udpState)(nil) + +// newUDPState creates a new udpState. +func newUDPState(domain int, out, in UDP) (*udpState, error) { + portPickerFD, localPort, err := pickPort(domain, unix.SOCK_DGRAM) + if err != nil { + return nil, fmt.Errorf("picking port: %w", err) + } + s := udpState{ + out: UDP{SrcPort: &localPort}, + in: UDP{DstPort: &localPort}, + portPickerFD: portPickerFD, + } + if err := s.out.merge(&out); err != nil { + return nil, err + } + if err := s.in.merge(&in); err != nil { + return nil, err + } + return &s, nil +} + +func (s *udpState) outgoing() Layer { + return deepcopy.Copy(&s.out).(Layer) +} + +// incoming implements layerState.incoming. +func (s *udpState) incoming(Layer) Layer { + return deepcopy.Copy(&s.in).(Layer) +} + +func (*udpState) sent(l Layer) error { + return nil +} + +func (*udpState) received(l Layer) error { + return nil +} + +// close frees the port associated with this connection. +func (s *udpState) close() error { + if err := unix.Close(s.portPickerFD); err != nil { + return err + } + s.portPickerFD = -1 + return nil +} + +// Connection holds a collection of layer states for maintaining a connection +// along with sockets for sniffer and injecting packets. +type Connection struct { + layerStates []layerState + injector Injector + sniffer Sniffer +} + +// Returns the default incoming frame against which to match. If received is +// longer than layerStates then that may still count as a match. The reverse is +// never a match and nil is returned. +func (conn *Connection) incoming(received Layers) Layers { + if len(received) < len(conn.layerStates) { + return nil + } + in := Layers{} + for i, s := range conn.layerStates { + toMatch := s.incoming(received[i]) + if toMatch == nil { + return nil + } + in = append(in, toMatch) + } + return in +} + +func (conn *Connection) match(override, received Layers) bool { + toMatch := conn.incoming(received) + if toMatch == nil { + return false // Not enough layers in gotLayers for matching. + } + if err := toMatch.merge(override); err != nil { + return false // Failing to merge is not matching. + } + return toMatch.match(received) +} + +// Close frees associated resources held by the Connection. +func (conn *Connection) Close(t *testing.T) { + t.Helper() + + errs := multierr.Combine(conn.sniffer.close(), conn.injector.close()) + for _, s := range conn.layerStates { + if err := s.close(); err != nil { + errs = multierr.Append(errs, fmt.Errorf("unable to close %+v: %s", s, err)) + } + } + if errs != nil { + t.Fatalf("unable to close %+v: %s", conn, errs) + } +} + +// CreateFrame builds a frame for the connection with defaults overriden +// from the innermost layer out, and additionalLayers added after it. +// +// Note that overrideLayers can have a length that is less than the number +// of layers in this connection, and in such cases the innermost layers are +// overriden first. As an example, valid values of overrideLayers for a TCP- +// over-IPv4-over-Ethernet connection are: nil, [TCP], [IPv4, TCP], and +// [Ethernet, IPv4, TCP]. +func (conn *Connection) CreateFrame(t *testing.T, overrideLayers Layers, additionalLayers ...Layer) Layers { + t.Helper() + + var layersToSend Layers + for i, s := range conn.layerStates { + layer := s.outgoing() + // overrideLayers and conn.layerStates have their tails aligned, so + // to find the index we move backwards by the distance i is to the + // end. + if j := len(overrideLayers) - (len(conn.layerStates) - i); j >= 0 { + if err := layer.merge(overrideLayers[j]); err != nil { + t.Fatalf("can't merge %+v into %+v: %s", layer, overrideLayers[j], err) + } + } + layersToSend = append(layersToSend, layer) + } + layersToSend = append(layersToSend, additionalLayers...) + return layersToSend +} + +// SendFrameStateless sends a frame without updating any of the layer states. +// +// This method is useful for sending out-of-band control messages such as +// ICMP packets, where it would not make sense to update the transport layer's +// state using the ICMP header. +func (conn *Connection) SendFrameStateless(t *testing.T, frame Layers) { + t.Helper() + + outBytes, err := frame.ToBytes() + if err != nil { + t.Fatalf("can't build outgoing packet: %s", err) + } + conn.injector.Send(t, outBytes) +} + +// SendFrame sends a frame on the wire and updates the state of all layers. +func (conn *Connection) SendFrame(t *testing.T, frame Layers) { + t.Helper() + + outBytes, err := frame.ToBytes() + if err != nil { + t.Fatalf("can't build outgoing packet: %s", err) + } + conn.injector.Send(t, outBytes) + + // frame might have nil values where the caller wanted to use default values. + // sentFrame will have no nil values in it because it comes from parsing the + // bytes that were actually sent. + sentFrame := parse(parseEther, outBytes) + // Update the state of each layer based on what was sent. + for i, s := range conn.layerStates { + if err := s.sent(sentFrame[i]); err != nil { + t.Fatalf("Unable to update the state of %+v with %s: %s", s, sentFrame[i], err) + } + } +} + +// send sends a packet, possibly with layers of this connection overridden and +// additional layers added. +// +// Types defined with Connection as the underlying type should expose +// type-safe versions of this method. +func (conn *Connection) send(t *testing.T, overrideLayers Layers, additionalLayers ...Layer) { + t.Helper() + + conn.SendFrame(t, conn.CreateFrame(t, overrideLayers, additionalLayers...)) +} + +// recvFrame gets the next successfully parsed frame (of type Layers) within the +// timeout provided. If no parsable frame arrives before the timeout, it returns +// nil. +func (conn *Connection) recvFrame(t *testing.T, timeout time.Duration) Layers { + t.Helper() + + if timeout <= 0 { + return nil + } + b := conn.sniffer.Recv(t, timeout) + if b == nil { + return nil + } + return parse(parseEther, b) +} + +// layersError stores the Layers that we got and the Layers that we wanted to +// match. +type layersError struct { + got, want Layers +} + +func (e *layersError) Error() string { + return e.got.diff(e.want) +} + +// Expect expects a frame with the final layerStates layer matching the +// provided Layer within the timeout specified. If it doesn't arrive in time, +// an error is returned. +func (conn *Connection) Expect(t *testing.T, layer Layer, timeout time.Duration) (Layer, error) { + t.Helper() + + // Make a frame that will ignore all but the final layer. + layers := make([]Layer, len(conn.layerStates)) + layers[len(layers)-1] = layer + + gotFrame, err := conn.ExpectFrame(t, layers, timeout) + if err != nil { + return nil, err + } + if len(conn.layerStates)-1 < len(gotFrame) { + return gotFrame[len(conn.layerStates)-1], nil + } + t.Fatalf("the received frame should be at least as long as the expected layers, got %d layers, want at least %d layers, got frame: %#v", len(gotFrame), len(conn.layerStates), gotFrame) + panic("unreachable") +} + +// ExpectFrame expects a frame that matches the provided Layers within the +// timeout specified. If one arrives in time, the Layers is returned without an +// error. If it doesn't arrive in time, it returns nil and error is non-nil. +func (conn *Connection) ExpectFrame(t *testing.T, layers Layers, timeout time.Duration) (Layers, error) { + t.Helper() + + deadline := time.Now().Add(timeout) + var errs error + for { + var gotLayers Layers + if timeout = time.Until(deadline); timeout > 0 { + gotLayers = conn.recvFrame(t, timeout) + } + if gotLayers == nil { + if errs == nil { + return nil, fmt.Errorf("got no frames matching %v during %s", layers, timeout) + } + return nil, fmt.Errorf("got no frames matching %v during %s: got %w", layers, timeout, errs) + } + if conn.match(layers, gotLayers) { + for i, s := range conn.layerStates { + if err := s.received(gotLayers[i]); err != nil { + t.Fatalf("failed to update test connection's layer states based on received frame: %s", err) + } + } + return gotLayers, nil + } + errs = multierr.Combine(errs, &layersError{got: gotLayers, want: conn.incoming(gotLayers)}) + } +} + +// Drain drains the sniffer's receive buffer by receiving packets until there's +// nothing else to receive. +func (conn *Connection) Drain(t *testing.T) { + t.Helper() + + conn.sniffer.Drain(t) +} + +// TCPIPv4 maintains the state for all the layers in a TCP/IPv4 connection. +type TCPIPv4 Connection + +// NewTCPIPv4 creates a new TCPIPv4 connection with reasonable defaults. +func NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 { + t.Helper() + + etherState, err := newEtherState(Ether{}, Ether{}) + if err != nil { + t.Fatalf("can't make etherState: %s", err) + } + ipv4State, err := newIPv4State(IPv4{}, IPv4{}) + if err != nil { + t.Fatalf("can't make ipv4State: %s", err) + } + tcpState, err := newTCPState(unix.AF_INET, outgoingTCP, incomingTCP) + if err != nil { + t.Fatalf("can't make tcpState: %s", err) + } + injector, err := NewInjector(t) + if err != nil { + t.Fatalf("can't make injector: %s", err) + } + sniffer, err := NewSniffer(t) + if err != nil { + t.Fatalf("can't make sniffer: %s", err) + } + + return TCPIPv4{ + layerStates: []layerState{etherState, ipv4State, tcpState}, + injector: injector, + sniffer: sniffer, + } +} + +// Connect performs a TCP 3-way handshake. The input Connection should have a +// final TCP Layer. +func (conn *TCPIPv4) Connect(t *testing.T) { + t.Helper() + + // Send the SYN. + conn.Send(t, TCP{Flags: Uint8(header.TCPFlagSyn)}) + + // Wait for the SYN-ACK. + synAck, err := conn.Expect(t, TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) + if err != nil { + t.Fatalf("didn't get synack during handshake: %s", err) + } + conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck + + // Send an ACK. + conn.Send(t, TCP{Flags: Uint8(header.TCPFlagAck)}) +} + +// ConnectWithOptions performs a TCP 3-way handshake with given TCP options. +// The input Connection should have a final TCP Layer. +func (conn *TCPIPv4) ConnectWithOptions(t *testing.T, options []byte) { + t.Helper() + + // Send the SYN. + conn.Send(t, TCP{Flags: Uint8(header.TCPFlagSyn), Options: options}) + + // Wait for the SYN-ACK. + synAck, err := conn.Expect(t, TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) + if err != nil { + t.Fatalf("didn't get synack during handshake: %s", err) + } + conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck + + // Send an ACK. + conn.Send(t, TCP{Flags: Uint8(header.TCPFlagAck)}) +} + +// ExpectData is a convenient method that expects a Layer and the Layer after +// it. If it doens't arrive in time, it returns nil. +func (conn *TCPIPv4) ExpectData(t *testing.T, tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) { + t.Helper() + + expected := make([]Layer, len(conn.layerStates)) + expected[len(expected)-1] = tcp + if payload != nil { + expected = append(expected, payload) + } + return (*Connection)(conn).ExpectFrame(t, expected, timeout) +} + +// ExpectNextData attempts to receive the next incoming segment for the +// connection and expects that to match the given layers. +// +// It differs from ExpectData() in that here we are only interested in the next +// received segment, while ExpectData() can receive multiple segments for the +// connection until there is a match with given layers or a timeout. +func (conn *TCPIPv4) ExpectNextData(t *testing.T, tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) { + t.Helper() + + // Receive the first incoming TCP segment for this connection. + got, err := conn.ExpectData(t, &TCP{}, nil, timeout) + if err != nil { + return nil, err + } + + expected := make([]Layer, len(conn.layerStates)) + expected[len(expected)-1] = tcp + if payload != nil { + expected = append(expected, payload) + tcp.SeqNum = Uint32(uint32(*conn.RemoteSeqNum(t)) - uint32(payload.Length())) + } + if !(*Connection)(conn).match(expected, got) { + return nil, fmt.Errorf("next frame is not matching %s during %s: got %s", expected, timeout, got) + } + return got, nil +} + +// Send a packet with reasonable defaults. Potentially override the TCP layer in +// the connection with the provided layer and add additionLayers. +func (conn *TCPIPv4) Send(t *testing.T, tcp TCP, additionalLayers ...Layer) { + t.Helper() + + (*Connection)(conn).send(t, Layers{&tcp}, additionalLayers...) +} + +// Close frees associated resources held by the TCPIPv4 connection. +func (conn *TCPIPv4) Close(t *testing.T) { + t.Helper() + + (*Connection)(conn).Close(t) +} + +// Expect expects a frame with the TCP layer matching the provided TCP within +// the timeout specified. If it doesn't arrive in time, an error is returned. +func (conn *TCPIPv4) Expect(t *testing.T, tcp TCP, timeout time.Duration) (*TCP, error) { + t.Helper() + + layer, err := (*Connection)(conn).Expect(t, &tcp, timeout) + if layer == nil { + return nil, err + } + gotTCP, ok := layer.(*TCP) + if !ok { + t.Fatalf("expected %s to be TCP", layer) + } + return gotTCP, err +} + +func (conn *TCPIPv4) tcpState(t *testing.T) *tcpState { + t.Helper() + + state, ok := conn.layerStates[2].(*tcpState) + if !ok { + t.Fatalf("got transport-layer state type=%T, expected tcpState", conn.layerStates[2]) + } + return state +} + +func (conn *TCPIPv4) ipv4State(t *testing.T) *ipv4State { + t.Helper() + + state, ok := conn.layerStates[1].(*ipv4State) + if !ok { + t.Fatalf("expected network-layer state type=%T, expected ipv4State", conn.layerStates[1]) + } + return state +} + +// RemoteSeqNum returns the next expected sequence number from the DUT. +func (conn *TCPIPv4) RemoteSeqNum(t *testing.T) *seqnum.Value { + t.Helper() + + return conn.tcpState(t).remoteSeqNum +} + +// LocalSeqNum returns the next sequence number to send from the testbench. +func (conn *TCPIPv4) LocalSeqNum(t *testing.T) *seqnum.Value { + t.Helper() + + return conn.tcpState(t).localSeqNum +} + +// SynAck returns the SynAck that was part of the handshake. +func (conn *TCPIPv4) SynAck(t *testing.T) *TCP { + t.Helper() + + return conn.tcpState(t).synAck +} + +// LocalAddr gets the local socket address of this connection. +func (conn *TCPIPv4) LocalAddr(t *testing.T) *unix.SockaddrInet4 { + t.Helper() + + sa := &unix.SockaddrInet4{Port: int(*conn.tcpState(t).out.SrcPort)} + copy(sa.Addr[:], *conn.ipv4State(t).out.SrcAddr) + return sa +} + +// Drain drains the sniffer's receive buffer by receiving packets until there's +// nothing else to receive. +func (conn *TCPIPv4) Drain(t *testing.T) { + t.Helper() + + conn.sniffer.Drain(t) +} + +// IPv6Conn maintains the state for all the layers in a IPv6 connection. +type IPv6Conn Connection + +// NewIPv6Conn creates a new IPv6Conn connection with reasonable defaults. +func NewIPv6Conn(t *testing.T, outgoingIPv6, incomingIPv6 IPv6) IPv6Conn { + t.Helper() + + etherState, err := newEtherState(Ether{}, Ether{}) + if err != nil { + t.Fatalf("can't make EtherState: %s", err) + } + ipv6State, err := newIPv6State(outgoingIPv6, incomingIPv6) + if err != nil { + t.Fatalf("can't make IPv6State: %s", err) + } + + injector, err := NewInjector(t) + if err != nil { + t.Fatalf("can't make injector: %s", err) + } + sniffer, err := NewSniffer(t) + if err != nil { + t.Fatalf("can't make sniffer: %s", err) + } + + return IPv6Conn{ + layerStates: []layerState{etherState, ipv6State}, + injector: injector, + sniffer: sniffer, + } +} + +// Send sends a frame with ipv6 overriding the IPv6 layer defaults and +// additionalLayers added after it. +func (conn *IPv6Conn) Send(t *testing.T, ipv6 IPv6, additionalLayers ...Layer) { + t.Helper() + + (*Connection)(conn).send(t, Layers{&ipv6}, additionalLayers...) +} + +// Close to clean up any resources held. +func (conn *IPv6Conn) Close(t *testing.T) { + t.Helper() + + (*Connection)(conn).Close(t) +} + +// ExpectFrame expects a frame that matches the provided Layers within the +// timeout specified. If it doesn't arrive in time, an error is returned. +func (conn *IPv6Conn) ExpectFrame(t *testing.T, frame Layers, timeout time.Duration) (Layers, error) { + t.Helper() + + return (*Connection)(conn).ExpectFrame(t, frame, timeout) +} + +// UDPIPv4 maintains the state for all the layers in a UDP/IPv4 connection. +type UDPIPv4 Connection + +// NewUDPIPv4 creates a new UDPIPv4 connection with reasonable defaults. +func NewUDPIPv4(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv4 { + t.Helper() + + etherState, err := newEtherState(Ether{}, Ether{}) + if err != nil { + t.Fatalf("can't make etherState: %s", err) + } + ipv4State, err := newIPv4State(IPv4{}, IPv4{}) + if err != nil { + t.Fatalf("can't make ipv4State: %s", err) + } + udpState, err := newUDPState(unix.AF_INET, outgoingUDP, incomingUDP) + if err != nil { + t.Fatalf("can't make udpState: %s", err) + } + injector, err := NewInjector(t) + if err != nil { + t.Fatalf("can't make injector: %s", err) + } + sniffer, err := NewSniffer(t) + if err != nil { + t.Fatalf("can't make sniffer: %s", err) + } + + return UDPIPv4{ + layerStates: []layerState{etherState, ipv4State, udpState}, + injector: injector, + sniffer: sniffer, + } +} + +func (conn *UDPIPv4) udpState(t *testing.T) *udpState { + t.Helper() + + state, ok := conn.layerStates[2].(*udpState) + if !ok { + t.Fatalf("got transport-layer state type=%T, expected udpState", conn.layerStates[2]) + } + return state +} + +func (conn *UDPIPv4) ipv4State(t *testing.T) *ipv4State { + t.Helper() + + state, ok := conn.layerStates[1].(*ipv4State) + if !ok { + t.Fatalf("got network-layer state type=%T, expected ipv4State", conn.layerStates[1]) + } + return state +} + +// LocalAddr gets the local socket address of this connection. +func (conn *UDPIPv4) LocalAddr(t *testing.T) *unix.SockaddrInet4 { + t.Helper() + + sa := &unix.SockaddrInet4{Port: int(*conn.udpState(t).out.SrcPort)} + copy(sa.Addr[:], *conn.ipv4State(t).out.SrcAddr) + return sa +} + +// Send sends a packet with reasonable defaults, potentially overriding the UDP +// layer and adding additionLayers. +func (conn *UDPIPv4) Send(t *testing.T, udp UDP, additionalLayers ...Layer) { + t.Helper() + + (*Connection)(conn).send(t, Layers{&udp}, additionalLayers...) +} + +// SendIP sends a packet with reasonable defaults, potentially overriding the +// UDP and IPv4 headers and adding additionLayers. +func (conn *UDPIPv4) SendIP(t *testing.T, ip IPv4, udp UDP, additionalLayers ...Layer) { + t.Helper() + + (*Connection)(conn).send(t, Layers{&ip, &udp}, additionalLayers...) +} + +// Expect expects a frame with the UDP layer matching the provided UDP within +// the timeout specified. If it doesn't arrive in time, an error is returned. +func (conn *UDPIPv4) Expect(t *testing.T, udp UDP, timeout time.Duration) (*UDP, error) { + t.Helper() + + layer, err := (*Connection)(conn).Expect(t, &udp, timeout) + if err != nil { + return nil, err + } + gotUDP, ok := layer.(*UDP) + if !ok { + t.Fatalf("expected %s to be UDP", layer) + } + return gotUDP, nil +} + +// ExpectData is a convenient method that expects a Layer and the Layer after +// it. If it doens't arrive in time, it returns nil. +func (conn *UDPIPv4) ExpectData(t *testing.T, udp UDP, payload Payload, timeout time.Duration) (Layers, error) { + t.Helper() + + expected := make([]Layer, len(conn.layerStates)) + expected[len(expected)-1] = &udp + if payload.length() != 0 { + expected = append(expected, &payload) + } + return (*Connection)(conn).ExpectFrame(t, expected, timeout) +} + +// Close frees associated resources held by the UDPIPv4 connection. +func (conn *UDPIPv4) Close(t *testing.T) { + t.Helper() + + (*Connection)(conn).Close(t) +} + +// Drain drains the sniffer's receive buffer by receiving packets until there's +// nothing else to receive. +func (conn *UDPIPv4) Drain(t *testing.T) { + t.Helper() + + conn.sniffer.Drain(t) +} + +// UDPIPv6 maintains the state for all the layers in a UDP/IPv6 connection. +type UDPIPv6 Connection + +// NewUDPIPv6 creates a new UDPIPv6 connection with reasonable defaults. +func NewUDPIPv6(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv6 { + t.Helper() + + etherState, err := newEtherState(Ether{}, Ether{}) + if err != nil { + t.Fatalf("can't make etherState: %s", err) + } + ipv6State, err := newIPv6State(IPv6{}, IPv6{}) + if err != nil { + t.Fatalf("can't make IPv6State: %s", err) + } + udpState, err := newUDPState(unix.AF_INET6, outgoingUDP, incomingUDP) + if err != nil { + t.Fatalf("can't make udpState: %s", err) + } + injector, err := NewInjector(t) + if err != nil { + t.Fatalf("can't make injector: %s", err) + } + sniffer, err := NewSniffer(t) + if err != nil { + t.Fatalf("can't make sniffer: %s", err) + } + return UDPIPv6{ + layerStates: []layerState{etherState, ipv6State, udpState}, + injector: injector, + sniffer: sniffer, + } +} + +func (conn *UDPIPv6) udpState(t *testing.T) *udpState { + t.Helper() + + state, ok := conn.layerStates[2].(*udpState) + if !ok { + t.Fatalf("got transport-layer state type=%T, expected udpState", conn.layerStates[2]) + } + return state +} + +func (conn *UDPIPv6) ipv6State(t *testing.T) *ipv6State { + t.Helper() + + state, ok := conn.layerStates[1].(*ipv6State) + if !ok { + t.Fatalf("got network-layer state type=%T, expected ipv6State", conn.layerStates[1]) + } + return state +} + +// LocalAddr gets the local socket address of this connection. +func (conn *UDPIPv6) LocalAddr(t *testing.T) *unix.SockaddrInet6 { + t.Helper() + + sa := &unix.SockaddrInet6{ + Port: int(*conn.udpState(t).out.SrcPort), + // Local address is in perspective to the remote host, so it's scoped to the + // ID of the remote interface. + ZoneId: uint32(RemoteInterfaceID), + } + copy(sa.Addr[:], *conn.ipv6State(t).out.SrcAddr) + return sa +} + +// Send sends a packet with reasonable defaults, potentially overriding the UDP +// layer and adding additionLayers. +func (conn *UDPIPv6) Send(t *testing.T, udp UDP, additionalLayers ...Layer) { + t.Helper() + + (*Connection)(conn).send(t, Layers{&udp}, additionalLayers...) +} + +// SendIPv6 sends a packet with reasonable defaults, potentially overriding the +// UDP and IPv6 headers and adding additionLayers. +func (conn *UDPIPv6) SendIPv6(t *testing.T, ip IPv6, udp UDP, additionalLayers ...Layer) { + t.Helper() + + (*Connection)(conn).send(t, Layers{&ip, &udp}, additionalLayers...) +} + +// Expect expects a frame with the UDP layer matching the provided UDP within +// the timeout specified. If it doesn't arrive in time, an error is returned. +func (conn *UDPIPv6) Expect(t *testing.T, udp UDP, timeout time.Duration) (*UDP, error) { + t.Helper() + + layer, err := (*Connection)(conn).Expect(t, &udp, timeout) + if err != nil { + return nil, err + } + gotUDP, ok := layer.(*UDP) + if !ok { + t.Fatalf("expected %s to be UDP", layer) + } + return gotUDP, nil +} + +// ExpectData is a convenient method that expects a Layer and the Layer after +// it. If it doens't arrive in time, it returns nil. +func (conn *UDPIPv6) ExpectData(t *testing.T, udp UDP, payload Payload, timeout time.Duration) (Layers, error) { + t.Helper() + + expected := make([]Layer, len(conn.layerStates)) + expected[len(expected)-1] = &udp + if payload.length() != 0 { + expected = append(expected, &payload) + } + return (*Connection)(conn).ExpectFrame(t, expected, timeout) +} + +// Close frees associated resources held by the UDPIPv6 connection. +func (conn *UDPIPv6) Close(t *testing.T) { + t.Helper() + + (*Connection)(conn).Close(t) +} + +// Drain drains the sniffer's receive buffer by receiving packets until there's +// nothing else to receive. +func (conn *UDPIPv6) Drain(t *testing.T) { + t.Helper() + + conn.sniffer.Drain(t) +} + +// TCPIPv6 maintains the state for all the layers in a TCP/IPv6 connection. +type TCPIPv6 Connection + +// NewTCPIPv6 creates a new TCPIPv6 connection with reasonable defaults. +func NewTCPIPv6(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv6 { + etherState, err := newEtherState(Ether{}, Ether{}) + if err != nil { + t.Fatalf("can't make etherState: %s", err) + } + ipv6State, err := newIPv6State(IPv6{}, IPv6{}) + if err != nil { + t.Fatalf("can't make ipv6State: %s", err) + } + tcpState, err := newTCPState(unix.AF_INET6, outgoingTCP, incomingTCP) + if err != nil { + t.Fatalf("can't make tcpState: %s", err) + } + injector, err := NewInjector(t) + if err != nil { + t.Fatalf("can't make injector: %s", err) + } + sniffer, err := NewSniffer(t) + if err != nil { + t.Fatalf("can't make sniffer: %s", err) + } + + return TCPIPv6{ + layerStates: []layerState{etherState, ipv6State, tcpState}, + injector: injector, + sniffer: sniffer, + } +} + +func (conn *TCPIPv6) SrcPort() uint16 { + state := conn.layerStates[2].(*tcpState) + return *state.out.SrcPort +} + +// ExpectData is a convenient method that expects a Layer and the Layer after +// it. If it doens't arrive in time, it returns nil. +func (conn *TCPIPv6) ExpectData(t *testing.T, tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) { + t.Helper() + + expected := make([]Layer, len(conn.layerStates)) + expected[len(expected)-1] = tcp + if payload != nil { + expected = append(expected, payload) + } + return (*Connection)(conn).ExpectFrame(t, expected, timeout) +} + +// Close frees associated resources held by the TCPIPv6 connection. +func (conn *TCPIPv6) Close(t *testing.T) { + t.Helper() + + (*Connection)(conn).Close(t) +} diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go new file mode 100644 index 000000000..73c532e75 --- /dev/null +++ b/test/packetimpact/testbench/dut.go @@ -0,0 +1,702 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testbench + +import ( + "context" + "flag" + "net" + "strconv" + "syscall" + "testing" + + pb "gvisor.dev/gvisor/test/packetimpact/proto/posix_server_go_proto" + + "golang.org/x/sys/unix" + "google.golang.org/grpc" + "google.golang.org/grpc/keepalive" +) + +// DUT communicates with the DUT to force it to make POSIX calls. +type DUT struct { + conn *grpc.ClientConn + posixServer POSIXClient +} + +// NewDUT creates a new connection with the DUT over gRPC. +func NewDUT(t *testing.T) DUT { + t.Helper() + + flag.Parse() + if err := genPseudoFlags(); err != nil { + t.Fatal("generating psuedo flags:", err) + } + + posixServerAddress := POSIXServerIP + ":" + strconv.Itoa(POSIXServerPort) + conn, err := grpc.Dial(posixServerAddress, grpc.WithInsecure(), grpc.WithKeepaliveParams(keepalive.ClientParameters{Timeout: RPCKeepalive})) + if err != nil { + t.Fatalf("failed to grpc.Dial(%s): %s", posixServerAddress, err) + } + posixServer := NewPOSIXClient(conn) + return DUT{ + conn: conn, + posixServer: posixServer, + } +} + +// TearDown closes the underlying connection. +func (dut *DUT) TearDown() { + dut.conn.Close() +} + +func (dut *DUT) sockaddrToProto(t *testing.T, sa unix.Sockaddr) *pb.Sockaddr { + t.Helper() + + switch s := sa.(type) { + case *unix.SockaddrInet4: + return &pb.Sockaddr{ + Sockaddr: &pb.Sockaddr_In{ + In: &pb.SockaddrIn{ + Family: unix.AF_INET, + Port: uint32(s.Port), + Addr: s.Addr[:], + }, + }, + } + case *unix.SockaddrInet6: + return &pb.Sockaddr{ + Sockaddr: &pb.Sockaddr_In6{ + In6: &pb.SockaddrIn6{ + Family: unix.AF_INET6, + Port: uint32(s.Port), + Flowinfo: 0, + ScopeId: s.ZoneId, + Addr: s.Addr[:], + }, + }, + } + } + t.Fatalf("can't parse Sockaddr struct: %+v", sa) + return nil +} + +func (dut *DUT) protoToSockaddr(t *testing.T, sa *pb.Sockaddr) unix.Sockaddr { + t.Helper() + + switch s := sa.Sockaddr.(type) { + case *pb.Sockaddr_In: + ret := unix.SockaddrInet4{ + Port: int(s.In.GetPort()), + } + copy(ret.Addr[:], s.In.GetAddr()) + return &ret + case *pb.Sockaddr_In6: + ret := unix.SockaddrInet6{ + Port: int(s.In6.GetPort()), + ZoneId: s.In6.GetScopeId(), + } + copy(ret.Addr[:], s.In6.GetAddr()) + return &ret + } + t.Fatalf("can't parse Sockaddr proto: %#v", sa) + return nil +} + +// CreateBoundSocket makes a new socket on the DUT, with type typ and protocol +// proto, and bound to the IP address addr. Returns the new file descriptor and +// the port that was selected on the DUT. +func (dut *DUT) CreateBoundSocket(t *testing.T, typ, proto int32, addr net.IP) (int32, uint16) { + t.Helper() + + var fd int32 + if addr.To4() != nil { + fd = dut.Socket(t, unix.AF_INET, typ, proto) + sa := unix.SockaddrInet4{} + copy(sa.Addr[:], addr.To4()) + dut.Bind(t, fd, &sa) + } else if addr.To16() != nil { + fd = dut.Socket(t, unix.AF_INET6, typ, proto) + sa := unix.SockaddrInet6{} + copy(sa.Addr[:], addr.To16()) + sa.ZoneId = uint32(RemoteInterfaceID) + dut.Bind(t, fd, &sa) + } else { + t.Fatalf("invalid IP address: %s", addr) + } + sa := dut.GetSockName(t, fd) + var port int + switch s := sa.(type) { + case *unix.SockaddrInet4: + port = s.Port + case *unix.SockaddrInet6: + port = s.Port + default: + t.Fatalf("unknown sockaddr type from getsockname: %T", sa) + } + return fd, uint16(port) +} + +// CreateListener makes a new TCP connection. If it fails, the test ends. +func (dut *DUT) CreateListener(t *testing.T, typ, proto, backlog int32) (int32, uint16) { + t.Helper() + + fd, remotePort := dut.CreateBoundSocket(t, typ, proto, net.ParseIP(RemoteIPv4)) + dut.Listen(t, fd, backlog) + return fd, remotePort +} + +// All the functions that make gRPC calls to the POSIX service are below, sorted +// alphabetically. + +// Accept calls accept on the DUT and causes a fatal test failure if it doesn't +// succeed. If more control over the timeout or error handling is needed, use +// AcceptWithErrno. +func (dut *DUT) Accept(t *testing.T, sockfd int32) (int32, unix.Sockaddr) { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + fd, sa, err := dut.AcceptWithErrno(ctx, t, sockfd) + if fd < 0 { + t.Fatalf("failed to accept: %s", err) + } + return fd, sa +} + +// AcceptWithErrno calls accept on the DUT. +func (dut *DUT) AcceptWithErrno(ctx context.Context, t *testing.T, sockfd int32) (int32, unix.Sockaddr, error) { + t.Helper() + + req := pb.AcceptRequest{ + Sockfd: sockfd, + } + resp, err := dut.posixServer.Accept(ctx, &req) + if err != nil { + t.Fatalf("failed to call Accept: %s", err) + } + return resp.GetFd(), dut.protoToSockaddr(t, resp.GetAddr()), syscall.Errno(resp.GetErrno_()) +} + +// Bind calls bind on the DUT and causes a fatal test failure if it doesn't +// succeed. If more control over the timeout or error handling is +// needed, use BindWithErrno. +func (dut *DUT) Bind(t *testing.T, fd int32, sa unix.Sockaddr) { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + ret, err := dut.BindWithErrno(ctx, t, fd, sa) + if ret != 0 { + t.Fatalf("failed to bind socket: %s", err) + } +} + +// BindWithErrno calls bind on the DUT. +func (dut *DUT) BindWithErrno(ctx context.Context, t *testing.T, fd int32, sa unix.Sockaddr) (int32, error) { + t.Helper() + + req := pb.BindRequest{ + Sockfd: fd, + Addr: dut.sockaddrToProto(t, sa), + } + resp, err := dut.posixServer.Bind(ctx, &req) + if err != nil { + t.Fatalf("failed to call Bind: %s", err) + } + return resp.GetRet(), syscall.Errno(resp.GetErrno_()) +} + +// Close calls close on the DUT and causes a fatal test failure if it doesn't +// succeed. If more control over the timeout or error handling is needed, use +// CloseWithErrno. +func (dut *DUT) Close(t *testing.T, fd int32) { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + ret, err := dut.CloseWithErrno(ctx, t, fd) + if ret != 0 { + t.Fatalf("failed to close: %s", err) + } +} + +// CloseWithErrno calls close on the DUT. +func (dut *DUT) CloseWithErrno(ctx context.Context, t *testing.T, fd int32) (int32, error) { + t.Helper() + + req := pb.CloseRequest{ + Fd: fd, + } + resp, err := dut.posixServer.Close(ctx, &req) + if err != nil { + t.Fatalf("failed to call Close: %s", err) + } + return resp.GetRet(), syscall.Errno(resp.GetErrno_()) +} + +// Connect calls connect on the DUT and causes a fatal test failure if it +// doesn't succeed. If more control over the timeout or error handling is +// needed, use ConnectWithErrno. +func (dut *DUT) Connect(t *testing.T, fd int32, sa unix.Sockaddr) { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + ret, err := dut.ConnectWithErrno(ctx, t, fd, sa) + // Ignore 'operation in progress' error that can be returned when the socket + // is non-blocking. + if err != syscall.Errno(unix.EINPROGRESS) && ret != 0 { + t.Fatalf("failed to connect socket: %s", err) + } +} + +// ConnectWithErrno calls bind on the DUT. +func (dut *DUT) ConnectWithErrno(ctx context.Context, t *testing.T, fd int32, sa unix.Sockaddr) (int32, error) { + t.Helper() + + req := pb.ConnectRequest{ + Sockfd: fd, + Addr: dut.sockaddrToProto(t, sa), + } + resp, err := dut.posixServer.Connect(ctx, &req) + if err != nil { + t.Fatalf("failed to call Connect: %s", err) + } + return resp.GetRet(), syscall.Errno(resp.GetErrno_()) +} + +// Fcntl calls fcntl on the DUT and causes a fatal test failure if it +// doesn't succeed. If more control over the timeout or error handling is +// needed, use FcntlWithErrno. +func (dut *DUT) Fcntl(t *testing.T, fd, cmd, arg int32) int32 { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + ret, err := dut.FcntlWithErrno(ctx, t, fd, cmd, arg) + if ret == -1 { + t.Fatalf("failed to Fcntl: ret=%d, errno=%s", ret, err) + } + return ret +} + +// FcntlWithErrno calls fcntl on the DUT. +func (dut *DUT) FcntlWithErrno(ctx context.Context, t *testing.T, fd, cmd, arg int32) (int32, error) { + t.Helper() + + req := pb.FcntlRequest{ + Fd: fd, + Cmd: cmd, + Arg: arg, + } + resp, err := dut.posixServer.Fcntl(ctx, &req) + if err != nil { + t.Fatalf("failed to call Fcntl: %s", err) + } + return resp.GetRet(), syscall.Errno(resp.GetErrno_()) +} + +// GetSockName calls getsockname on the DUT and causes a fatal test failure if +// it doesn't succeed. If more control over the timeout or error handling is +// needed, use GetSockNameWithErrno. +func (dut *DUT) GetSockName(t *testing.T, sockfd int32) unix.Sockaddr { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + ret, sa, err := dut.GetSockNameWithErrno(ctx, t, sockfd) + if ret != 0 { + t.Fatalf("failed to getsockname: %s", err) + } + return sa +} + +// GetSockNameWithErrno calls getsockname on the DUT. +func (dut *DUT) GetSockNameWithErrno(ctx context.Context, t *testing.T, sockfd int32) (int32, unix.Sockaddr, error) { + t.Helper() + + req := pb.GetSockNameRequest{ + Sockfd: sockfd, + } + resp, err := dut.posixServer.GetSockName(ctx, &req) + if err != nil { + t.Fatalf("failed to call Bind: %s", err) + } + return resp.GetRet(), dut.protoToSockaddr(t, resp.GetAddr()), syscall.Errno(resp.GetErrno_()) +} + +func (dut *DUT) getSockOpt(ctx context.Context, t *testing.T, sockfd, level, optname, optlen int32, typ pb.GetSockOptRequest_SockOptType) (int32, *pb.SockOptVal, error) { + t.Helper() + + req := pb.GetSockOptRequest{ + Sockfd: sockfd, + Level: level, + Optname: optname, + Optlen: optlen, + Type: typ, + } + resp, err := dut.posixServer.GetSockOpt(ctx, &req) + if err != nil { + t.Fatalf("failed to call GetSockOpt: %s", err) + } + optval := resp.GetOptval() + if optval == nil { + t.Fatalf("GetSockOpt response does not contain a value") + } + return resp.GetRet(), optval, syscall.Errno(resp.GetErrno_()) +} + +// GetSockOpt calls getsockopt on the DUT and causes a fatal test failure if it +// doesn't succeed. If more control over the timeout or error handling is +// needed, use GetSockOptWithErrno. Because endianess and the width of values +// might differ between the testbench and DUT architectures, prefer to use a +// more specific GetSockOptXxx function. +func (dut *DUT) GetSockOpt(t *testing.T, sockfd, level, optname, optlen int32) []byte { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + ret, optval, err := dut.GetSockOptWithErrno(ctx, t, sockfd, level, optname, optlen) + if ret != 0 { + t.Fatalf("failed to GetSockOpt: %s", err) + } + return optval +} + +// GetSockOptWithErrno calls getsockopt on the DUT. Because endianess and the +// width of values might differ between the testbench and DUT architectures, +// prefer to use a more specific GetSockOptXxxWithErrno function. +func (dut *DUT) GetSockOptWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname, optlen int32) (int32, []byte, error) { + t.Helper() + + ret, optval, errno := dut.getSockOpt(ctx, t, sockfd, level, optname, optlen, pb.GetSockOptRequest_BYTES) + bytesval, ok := optval.Val.(*pb.SockOptVal_Bytesval) + if !ok { + t.Fatalf("GetSockOpt got value type: %T, want bytes", optval.Val) + } + return ret, bytesval.Bytesval, errno +} + +// GetSockOptInt calls getsockopt on the DUT and causes a fatal test failure +// if it doesn't succeed. If more control over the int optval or error handling +// is needed, use GetSockOptIntWithErrno. +func (dut *DUT) GetSockOptInt(t *testing.T, sockfd, level, optname int32) int32 { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + ret, intval, err := dut.GetSockOptIntWithErrno(ctx, t, sockfd, level, optname) + if ret != 0 { + t.Fatalf("failed to GetSockOptInt: %s", err) + } + return intval +} + +// GetSockOptIntWithErrno calls getsockopt with an integer optval. +func (dut *DUT) GetSockOptIntWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname int32) (int32, int32, error) { + t.Helper() + + ret, optval, errno := dut.getSockOpt(ctx, t, sockfd, level, optname, 0, pb.GetSockOptRequest_INT) + intval, ok := optval.Val.(*pb.SockOptVal_Intval) + if !ok { + t.Fatalf("GetSockOpt got value type: %T, want int", optval.Val) + } + return ret, intval.Intval, errno +} + +// GetSockOptTimeval calls getsockopt on the DUT and causes a fatal test failure +// if it doesn't succeed. If more control over the timeout or error handling is +// needed, use GetSockOptTimevalWithErrno. +func (dut *DUT) GetSockOptTimeval(t *testing.T, sockfd, level, optname int32) unix.Timeval { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + ret, timeval, err := dut.GetSockOptTimevalWithErrno(ctx, t, sockfd, level, optname) + if ret != 0 { + t.Fatalf("failed to GetSockOptTimeval: %s", err) + } + return timeval +} + +// GetSockOptTimevalWithErrno calls getsockopt and returns a timeval. +func (dut *DUT) GetSockOptTimevalWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname int32) (int32, unix.Timeval, error) { + t.Helper() + + ret, optval, errno := dut.getSockOpt(ctx, t, sockfd, level, optname, 0, pb.GetSockOptRequest_TIME) + tv, ok := optval.Val.(*pb.SockOptVal_Timeval) + if !ok { + t.Fatalf("GetSockOpt got value type: %T, want timeval", optval.Val) + } + timeval := unix.Timeval{ + Sec: tv.Timeval.Seconds, + Usec: tv.Timeval.Microseconds, + } + return ret, timeval, errno +} + +// Listen calls listen on the DUT and causes a fatal test failure if it doesn't +// succeed. If more control over the timeout or error handling is needed, use +// ListenWithErrno. +func (dut *DUT) Listen(t *testing.T, sockfd, backlog int32) { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + ret, err := dut.ListenWithErrno(ctx, t, sockfd, backlog) + if ret != 0 { + t.Fatalf("failed to listen: %s", err) + } +} + +// ListenWithErrno calls listen on the DUT. +func (dut *DUT) ListenWithErrno(ctx context.Context, t *testing.T, sockfd, backlog int32) (int32, error) { + t.Helper() + + req := pb.ListenRequest{ + Sockfd: sockfd, + Backlog: backlog, + } + resp, err := dut.posixServer.Listen(ctx, &req) + if err != nil { + t.Fatalf("failed to call Listen: %s", err) + } + return resp.GetRet(), syscall.Errno(resp.GetErrno_()) +} + +// Send calls send on the DUT and causes a fatal test failure if it doesn't +// succeed. If more control over the timeout or error handling is needed, use +// SendWithErrno. +func (dut *DUT) Send(t *testing.T, sockfd int32, buf []byte, flags int32) int32 { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + ret, err := dut.SendWithErrno(ctx, t, sockfd, buf, flags) + if ret == -1 { + t.Fatalf("failed to send: %s", err) + } + return ret +} + +// SendWithErrno calls send on the DUT. +func (dut *DUT) SendWithErrno(ctx context.Context, t *testing.T, sockfd int32, buf []byte, flags int32) (int32, error) { + t.Helper() + + req := pb.SendRequest{ + Sockfd: sockfd, + Buf: buf, + Flags: flags, + } + resp, err := dut.posixServer.Send(ctx, &req) + if err != nil { + t.Fatalf("failed to call Send: %s", err) + } + return resp.GetRet(), syscall.Errno(resp.GetErrno_()) +} + +// SendTo calls sendto on the DUT and causes a fatal test failure if it doesn't +// succeed. If more control over the timeout or error handling is needed, use +// SendToWithErrno. +func (dut *DUT) SendTo(t *testing.T, sockfd int32, buf []byte, flags int32, destAddr unix.Sockaddr) int32 { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + ret, err := dut.SendToWithErrno(ctx, t, sockfd, buf, flags, destAddr) + if ret == -1 { + t.Fatalf("failed to sendto: %s", err) + } + return ret +} + +// SendToWithErrno calls sendto on the DUT. +func (dut *DUT) SendToWithErrno(ctx context.Context, t *testing.T, sockfd int32, buf []byte, flags int32, destAddr unix.Sockaddr) (int32, error) { + t.Helper() + + req := pb.SendToRequest{ + Sockfd: sockfd, + Buf: buf, + Flags: flags, + DestAddr: dut.sockaddrToProto(t, destAddr), + } + resp, err := dut.posixServer.SendTo(ctx, &req) + if err != nil { + t.Fatalf("faled to call SendTo: %s", err) + } + return resp.GetRet(), syscall.Errno(resp.GetErrno_()) +} + +// SetNonBlocking will set O_NONBLOCK flag for fd if nonblocking +// is true, otherwise it will clear the flag. +func (dut *DUT) SetNonBlocking(t *testing.T, fd int32, nonblocking bool) { + t.Helper() + + flags := dut.Fcntl(t, fd, unix.F_GETFL, 0) + if nonblocking { + flags |= unix.O_NONBLOCK + } else { + flags &= ^unix.O_NONBLOCK + } + dut.Fcntl(t, fd, unix.F_SETFL, flags) +} + +func (dut *DUT) setSockOpt(ctx context.Context, t *testing.T, sockfd, level, optname int32, optval *pb.SockOptVal) (int32, error) { + t.Helper() + + req := pb.SetSockOptRequest{ + Sockfd: sockfd, + Level: level, + Optname: optname, + Optval: optval, + } + resp, err := dut.posixServer.SetSockOpt(ctx, &req) + if err != nil { + t.Fatalf("failed to call SetSockOpt: %s", err) + } + return resp.GetRet(), syscall.Errno(resp.GetErrno_()) +} + +// SetSockOpt calls setsockopt on the DUT and causes a fatal test failure if it +// doesn't succeed. If more control over the timeout or error handling is +// needed, use SetSockOptWithErrno. Because endianess and the width of values +// might differ between the testbench and DUT architectures, prefer to use a +// more specific SetSockOptXxx function. +func (dut *DUT) SetSockOpt(t *testing.T, sockfd, level, optname int32, optval []byte) { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + ret, err := dut.SetSockOptWithErrno(ctx, t, sockfd, level, optname, optval) + if ret != 0 { + t.Fatalf("failed to SetSockOpt: %s", err) + } +} + +// SetSockOptWithErrno calls setsockopt on the DUT. Because endianess and the +// width of values might differ between the testbench and DUT architectures, +// prefer to use a more specific SetSockOptXxxWithErrno function. +func (dut *DUT) SetSockOptWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname int32, optval []byte) (int32, error) { + t.Helper() + + return dut.setSockOpt(ctx, t, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Bytesval{optval}}) +} + +// SetSockOptInt calls setsockopt on the DUT and causes a fatal test failure +// if it doesn't succeed. If more control over the int optval or error handling +// is needed, use SetSockOptIntWithErrno. +func (dut *DUT) SetSockOptInt(t *testing.T, sockfd, level, optname, optval int32) { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + ret, err := dut.SetSockOptIntWithErrno(ctx, t, sockfd, level, optname, optval) + if ret != 0 { + t.Fatalf("failed to SetSockOptInt: %s", err) + } +} + +// SetSockOptIntWithErrno calls setsockopt with an integer optval. +func (dut *DUT) SetSockOptIntWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname, optval int32) (int32, error) { + t.Helper() + + return dut.setSockOpt(ctx, t, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Intval{optval}}) +} + +// SetSockOptTimeval calls setsockopt on the DUT and causes a fatal test failure +// if it doesn't succeed. If more control over the timeout or error handling is +// needed, use SetSockOptTimevalWithErrno. +func (dut *DUT) SetSockOptTimeval(t *testing.T, sockfd, level, optname int32, tv *unix.Timeval) { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + ret, err := dut.SetSockOptTimevalWithErrno(ctx, t, sockfd, level, optname, tv) + if ret != 0 { + t.Fatalf("failed to SetSockOptTimeval: %s", err) + } +} + +// SetSockOptTimevalWithErrno calls setsockopt with the timeval converted to +// bytes. +func (dut *DUT) SetSockOptTimevalWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname int32, tv *unix.Timeval) (int32, error) { + t.Helper() + + timeval := pb.Timeval{ + Seconds: int64(tv.Sec), + Microseconds: int64(tv.Usec), + } + return dut.setSockOpt(ctx, t, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Timeval{&timeval}}) +} + +// Socket calls socket on the DUT and returns the file descriptor. If socket +// fails on the DUT, the test ends. +func (dut *DUT) Socket(t *testing.T, domain, typ, proto int32) int32 { + t.Helper() + + fd, err := dut.SocketWithErrno(t, domain, typ, proto) + if fd < 0 { + t.Fatalf("failed to create socket: %s", err) + } + return fd +} + +// SocketWithErrno calls socket on the DUT and returns the fd and errno. +func (dut *DUT) SocketWithErrno(t *testing.T, domain, typ, proto int32) (int32, error) { + t.Helper() + + req := pb.SocketRequest{ + Domain: domain, + Type: typ, + Protocol: proto, + } + ctx := context.Background() + resp, err := dut.posixServer.Socket(ctx, &req) + if err != nil { + t.Fatalf("failed to call Socket: %s", err) + } + return resp.GetFd(), syscall.Errno(resp.GetErrno_()) +} + +// Recv calls recv on the DUT and causes a fatal test failure if it doesn't +// succeed. If more control over the timeout or error handling is needed, use +// RecvWithErrno. +func (dut *DUT) Recv(t *testing.T, sockfd, len, flags int32) []byte { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + ret, buf, err := dut.RecvWithErrno(ctx, t, sockfd, len, flags) + if ret == -1 { + t.Fatalf("failed to recv: %s", err) + } + return buf +} + +// RecvWithErrno calls recv on the DUT. +func (dut *DUT) RecvWithErrno(ctx context.Context, t *testing.T, sockfd, len, flags int32) (int32, []byte, error) { + t.Helper() + + req := pb.RecvRequest{ + Sockfd: sockfd, + Len: len, + Flags: flags, + } + resp, err := dut.posixServer.Recv(ctx, &req) + if err != nil { + t.Fatalf("failed to call Recv: %s", err) + } + return resp.GetRet(), resp.GetBuf(), syscall.Errno(resp.GetErrno_()) +} diff --git a/test/packetimpact/testbench/dut_client.go b/test/packetimpact/testbench/dut_client.go new file mode 100644 index 000000000..d0e68c5da --- /dev/null +++ b/test/packetimpact/testbench/dut_client.go @@ -0,0 +1,28 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testbench + +import ( + "google.golang.org/grpc" + pb "gvisor.dev/gvisor/test/packetimpact/proto/posix_server_go_proto" +) + +// PosixClient is a gRPC client for the Posix service. +type POSIXClient pb.PosixClient + +// NewPOSIXClient makes a new gRPC client for the POSIX service. +func NewPOSIXClient(c grpc.ClientConnInterface) POSIXClient { + return pb.NewPosixClient(c) +} diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go new file mode 100644 index 000000000..a35562ca8 --- /dev/null +++ b/test/packetimpact/testbench/layers.go @@ -0,0 +1,1506 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testbench + +import ( + "encoding/binary" + "encoding/hex" + "fmt" + "reflect" + "strings" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "go.uber.org/multierr" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +// Layer is the interface that all encapsulations must implement. +// +// A Layer is an encapsulation in a packet, such as TCP, IPv4, IPv6, etc. A +// Layer contains all the fields of the encapsulation. Each field is a pointer +// and may be nil. +type Layer interface { + fmt.Stringer + + // ToBytes converts the Layer into bytes. In places where the Layer's field + // isn't nil, the value that is pointed to is used. When the field is nil, a + // reasonable default for the Layer is used. For example, "64" for IPv4 TTL + // and a calculated checksum for TCP or IP. Some layers require information + // from the previous or next layers in order to compute a default, such as + // TCP's checksum or Ethernet's type, so each Layer has a doubly-linked list + // to the layer's neighbors. + ToBytes() ([]byte, error) + + // match checks if the current Layer matches the provided Layer. If either + // Layer has a nil in a given field, that field is considered matching. + // Otherwise, the values pointed to by the fields must match. The LayerBase is + // ignored. + match(Layer) bool + + // length in bytes of the current encapsulation + length() int + + // next gets a pointer to the encapsulated Layer. + next() Layer + + // prev gets a pointer to the Layer encapsulating this one. + Prev() Layer + + // setNext sets the pointer to the encapsulated Layer. + setNext(Layer) + + // setPrev sets the pointer to the Layer encapsulating this one. + setPrev(Layer) + + // merge overrides the values in the interface with the provided values. + merge(Layer) error +} + +// LayerBase is the common elements of all layers. +type LayerBase struct { + nextLayer Layer + prevLayer Layer +} + +func (lb *LayerBase) next() Layer { + return lb.nextLayer +} + +// Prev returns the previous layer. +func (lb *LayerBase) Prev() Layer { + return lb.prevLayer +} + +func (lb *LayerBase) setNext(l Layer) { + lb.nextLayer = l +} + +func (lb *LayerBase) setPrev(l Layer) { + lb.prevLayer = l +} + +// equalLayer compares that two Layer structs match while ignoring field in +// which either input has a nil and also ignoring the LayerBase of the inputs. +func equalLayer(x, y Layer) bool { + if x == nil || y == nil { + return true + } + // opt ignores comparison pairs where either of the inputs is a nil. + opt := cmp.FilterValues(func(x, y interface{}) bool { + for _, l := range []interface{}{x, y} { + v := reflect.ValueOf(l) + if (v.Kind() == reflect.Ptr || v.Kind() == reflect.Slice) && v.IsNil() { + return true + } + } + return false + }, cmp.Ignore()) + return cmp.Equal(x, y, opt, cmpopts.IgnoreTypes(LayerBase{})) +} + +// mergeLayer merges y into x. Any fields for which y has a non-nil value, that +// value overwrite the corresponding fields in x. +func mergeLayer(x, y Layer) error { + if y == nil { + return nil + } + if reflect.TypeOf(x) != reflect.TypeOf(y) { + return fmt.Errorf("can't merge %T into %T", y, x) + } + vx := reflect.ValueOf(x).Elem() + vy := reflect.ValueOf(y).Elem() + t := vy.Type() + for i := 0; i < vy.NumField(); i++ { + t := t.Field(i) + if t.Anonymous { + // Ignore the LayerBase in the Layer struct. + continue + } + v := vy.Field(i) + if v.IsNil() { + continue + } + vx.Field(i).Set(v) + } + return nil +} + +func stringLayer(l Layer) string { + v := reflect.ValueOf(l).Elem() + t := v.Type() + var ret []string + for i := 0; i < v.NumField(); i++ { + t := t.Field(i) + if t.Anonymous { + // Ignore the LayerBase in the Layer struct. + continue + } + v := v.Field(i) + if v.IsNil() { + continue + } + v = reflect.Indirect(v) + if v.Kind() == reflect.Slice && v.Type().Elem().Kind() == reflect.Uint8 { + ret = append(ret, fmt.Sprintf("%s:\n%v", t.Name, hex.Dump(v.Bytes()))) + } else { + ret = append(ret, fmt.Sprintf("%s:%v", t.Name, v)) + } + } + return fmt.Sprintf("&%s{%s}", t, strings.Join(ret, " ")) +} + +// Ether can construct and match an ethernet encapsulation. +type Ether struct { + LayerBase + SrcAddr *tcpip.LinkAddress + DstAddr *tcpip.LinkAddress + Type *tcpip.NetworkProtocolNumber +} + +func (l *Ether) String() string { + return stringLayer(l) +} + +// ToBytes implements Layer.ToBytes. +func (l *Ether) ToBytes() ([]byte, error) { + b := make([]byte, header.EthernetMinimumSize) + h := header.Ethernet(b) + fields := &header.EthernetFields{} + if l.SrcAddr != nil { + fields.SrcAddr = *l.SrcAddr + } + if l.DstAddr != nil { + fields.DstAddr = *l.DstAddr + } + if l.Type != nil { + fields.Type = *l.Type + } else { + switch n := l.next().(type) { + case *IPv4: + fields.Type = header.IPv4ProtocolNumber + case *IPv6: + fields.Type = header.IPv6ProtocolNumber + default: + return nil, fmt.Errorf("ethernet header's next layer is unrecognized: %#v", n) + } + } + h.Encode(fields) + return h, nil +} + +// LinkAddress is a helper routine that allocates a new tcpip.LinkAddress value +// to store v and returns a pointer to it. +func LinkAddress(v tcpip.LinkAddress) *tcpip.LinkAddress { + return &v +} + +// NetworkProtocolNumber is a helper routine that allocates a new +// tcpip.NetworkProtocolNumber value to store v and returns a pointer to it. +func NetworkProtocolNumber(v tcpip.NetworkProtocolNumber) *tcpip.NetworkProtocolNumber { + return &v +} + +// layerParser parses the input bytes and returns a Layer along with the next +// layerParser to run. If there is no more parsing to do, the returned +// layerParser is nil. +type layerParser func([]byte) (Layer, layerParser) + +// parse parses bytes starting with the first layerParser and using successive +// layerParsers until all the bytes are parsed. +func parse(parser layerParser, b []byte) Layers { + var layers Layers + for { + var layer Layer + layer, parser = parser(b) + layers = append(layers, layer) + if parser == nil { + break + } + b = b[layer.length():] + } + layers.linkLayers() + return layers +} + +// parseEther parses the bytes assuming that they start with an ethernet header +// and continues parsing further encapsulations. +func parseEther(b []byte) (Layer, layerParser) { + h := header.Ethernet(b) + ether := Ether{ + SrcAddr: LinkAddress(h.SourceAddress()), + DstAddr: LinkAddress(h.DestinationAddress()), + Type: NetworkProtocolNumber(h.Type()), + } + var nextParser layerParser + switch h.Type() { + case header.IPv4ProtocolNumber: + nextParser = parseIPv4 + case header.IPv6ProtocolNumber: + nextParser = parseIPv6 + default: + // Assume that the rest is a payload. + nextParser = parsePayload + } + return ðer, nextParser +} + +func (l *Ether) match(other Layer) bool { + return equalLayer(l, other) +} + +func (l *Ether) length() int { + return header.EthernetMinimumSize +} + +// merge implements Layer.merge. +func (l *Ether) merge(other Layer) error { + return mergeLayer(l, other) +} + +// IPv4 can construct and match an IPv4 encapsulation. +type IPv4 struct { + LayerBase + IHL *uint8 + TOS *uint8 + TotalLength *uint16 + ID *uint16 + Flags *uint8 + FragmentOffset *uint16 + TTL *uint8 + Protocol *uint8 + Checksum *uint16 + SrcAddr *tcpip.Address + DstAddr *tcpip.Address +} + +func (l *IPv4) String() string { + return stringLayer(l) +} + +// ToBytes implements Layer.ToBytes. +func (l *IPv4) ToBytes() ([]byte, error) { + b := make([]byte, header.IPv4MinimumSize) + h := header.IPv4(b) + fields := &header.IPv4Fields{ + IHL: 20, + TOS: 0, + TotalLength: 0, + ID: 0, + Flags: 0, + FragmentOffset: 0, + TTL: 64, + Protocol: 0, + Checksum: 0, + SrcAddr: tcpip.Address(""), + DstAddr: tcpip.Address(""), + } + if l.TOS != nil { + fields.TOS = *l.TOS + } + if l.TotalLength != nil { + fields.TotalLength = *l.TotalLength + } else { + fields.TotalLength = uint16(l.length()) + current := l.next() + for current != nil { + fields.TotalLength += uint16(current.length()) + current = current.next() + } + } + if l.ID != nil { + fields.ID = *l.ID + } + if l.Flags != nil { + fields.Flags = *l.Flags + } + if l.FragmentOffset != nil { + fields.FragmentOffset = *l.FragmentOffset + } + if l.TTL != nil { + fields.TTL = *l.TTL + } + if l.Protocol != nil { + fields.Protocol = *l.Protocol + } else { + switch n := l.next().(type) { + case *TCP: + fields.Protocol = uint8(header.TCPProtocolNumber) + case *UDP: + fields.Protocol = uint8(header.UDPProtocolNumber) + case *ICMPv4: + fields.Protocol = uint8(header.ICMPv4ProtocolNumber) + default: + // TODO(b/150301488): Support more protocols as needed. + return nil, fmt.Errorf("ipv4 header's next layer is unrecognized: %#v", n) + } + } + if l.SrcAddr != nil { + fields.SrcAddr = *l.SrcAddr + } + if l.DstAddr != nil { + fields.DstAddr = *l.DstAddr + } + if l.Checksum != nil { + fields.Checksum = *l.Checksum + } + h.Encode(fields) + if l.Checksum == nil { + h.SetChecksum(^h.CalculateChecksum()) + } + return h, nil +} + +// Uint16 is a helper routine that allocates a new +// uint16 value to store v and returns a pointer to it. +func Uint16(v uint16) *uint16 { + return &v +} + +// Uint8 is a helper routine that allocates a new +// uint8 value to store v and returns a pointer to it. +func Uint8(v uint8) *uint8 { + return &v +} + +// Address is a helper routine that allocates a new tcpip.Address value to store +// v and returns a pointer to it. +func Address(v tcpip.Address) *tcpip.Address { + return &v +} + +// parseIPv4 parses the bytes assuming that they start with an ipv4 header and +// continues parsing further encapsulations. +func parseIPv4(b []byte) (Layer, layerParser) { + h := header.IPv4(b) + tos, _ := h.TOS() + ipv4 := IPv4{ + IHL: Uint8(h.HeaderLength()), + TOS: &tos, + TotalLength: Uint16(h.TotalLength()), + ID: Uint16(h.ID()), + Flags: Uint8(h.Flags()), + FragmentOffset: Uint16(h.FragmentOffset()), + TTL: Uint8(h.TTL()), + Protocol: Uint8(h.Protocol()), + Checksum: Uint16(h.Checksum()), + SrcAddr: Address(h.SourceAddress()), + DstAddr: Address(h.DestinationAddress()), + } + var nextParser layerParser + switch h.TransportProtocol() { + case header.TCPProtocolNumber: + nextParser = parseTCP + case header.UDPProtocolNumber: + nextParser = parseUDP + case header.ICMPv4ProtocolNumber: + nextParser = parseICMPv4 + default: + // Assume that the rest is a payload. + nextParser = parsePayload + } + return &ipv4, nextParser +} + +func (l *IPv4) match(other Layer) bool { + return equalLayer(l, other) +} + +func (l *IPv4) length() int { + if l.IHL == nil { + return header.IPv4MinimumSize + } + return int(*l.IHL) +} + +// merge implements Layer.merge. +func (l *IPv4) merge(other Layer) error { + return mergeLayer(l, other) +} + +// IPv6 can construct and match an IPv6 encapsulation. +type IPv6 struct { + LayerBase + TrafficClass *uint8 + FlowLabel *uint32 + PayloadLength *uint16 + NextHeader *uint8 + HopLimit *uint8 + SrcAddr *tcpip.Address + DstAddr *tcpip.Address +} + +func (l *IPv6) String() string { + return stringLayer(l) +} + +// ToBytes implements Layer.ToBytes. +func (l *IPv6) ToBytes() ([]byte, error) { + b := make([]byte, header.IPv6MinimumSize) + h := header.IPv6(b) + fields := &header.IPv6Fields{ + HopLimit: 64, + } + if l.TrafficClass != nil { + fields.TrafficClass = *l.TrafficClass + } + if l.FlowLabel != nil { + fields.FlowLabel = *l.FlowLabel + } + if l.PayloadLength != nil { + fields.PayloadLength = *l.PayloadLength + } else { + for current := l.next(); current != nil; current = current.next() { + fields.PayloadLength += uint16(current.length()) + } + } + if l.NextHeader != nil { + fields.NextHeader = *l.NextHeader + } else { + nh, err := nextHeaderByLayer(l.next()) + if err != nil { + return nil, err + } + fields.NextHeader = nh + } + if l.HopLimit != nil { + fields.HopLimit = *l.HopLimit + } + if l.SrcAddr != nil { + fields.SrcAddr = *l.SrcAddr + } + if l.DstAddr != nil { + fields.DstAddr = *l.DstAddr + } + h.Encode(fields) + return h, nil +} + +// nextIPv6PayloadParser finds the corresponding parser for nextHeader. +func nextIPv6PayloadParser(nextHeader uint8) layerParser { + switch tcpip.TransportProtocolNumber(nextHeader) { + case header.TCPProtocolNumber: + return parseTCP + case header.UDPProtocolNumber: + return parseUDP + case header.ICMPv6ProtocolNumber: + return parseICMPv6 + } + switch header.IPv6ExtensionHeaderIdentifier(nextHeader) { + case header.IPv6HopByHopOptionsExtHdrIdentifier: + return parseIPv6HopByHopOptionsExtHdr + case header.IPv6DestinationOptionsExtHdrIdentifier: + return parseIPv6DestinationOptionsExtHdr + case header.IPv6FragmentExtHdrIdentifier: + return parseIPv6FragmentExtHdr + } + return parsePayload +} + +// parseIPv6 parses the bytes assuming that they start with an ipv6 header and +// continues parsing further encapsulations. +func parseIPv6(b []byte) (Layer, layerParser) { + h := header.IPv6(b) + tos, flowLabel := h.TOS() + ipv6 := IPv6{ + TrafficClass: &tos, + FlowLabel: &flowLabel, + PayloadLength: Uint16(h.PayloadLength()), + NextHeader: Uint8(h.NextHeader()), + HopLimit: Uint8(h.HopLimit()), + SrcAddr: Address(h.SourceAddress()), + DstAddr: Address(h.DestinationAddress()), + } + nextParser := nextIPv6PayloadParser(h.NextHeader()) + return &ipv6, nextParser +} + +func (l *IPv6) match(other Layer) bool { + return equalLayer(l, other) +} + +func (l *IPv6) length() int { + return header.IPv6MinimumSize +} + +// merge overrides the values in l with the values from other but only in fields +// where the value is not nil. +func (l *IPv6) merge(other Layer) error { + return mergeLayer(l, other) +} + +// IPv6HopByHopOptionsExtHdr can construct and match an IPv6HopByHopOptions +// Extension Header. +type IPv6HopByHopOptionsExtHdr struct { + LayerBase + NextHeader *header.IPv6ExtensionHeaderIdentifier + Options []byte +} + +// IPv6DestinationOptionsExtHdr can construct and match an IPv6DestinationOptions +// Extension Header. +type IPv6DestinationOptionsExtHdr struct { + LayerBase + NextHeader *header.IPv6ExtensionHeaderIdentifier + Options []byte +} + +// IPv6FragmentExtHdr can construct and match an IPv6 Fragment Extension Header. +type IPv6FragmentExtHdr struct { + LayerBase + NextHeader *header.IPv6ExtensionHeaderIdentifier + FragmentOffset *uint16 + MoreFragments *bool + Identification *uint32 +} + +// nextHeaderByLayer finds the correct next header protocol value for layer l. +func nextHeaderByLayer(l Layer) (uint8, error) { + if l == nil { + return uint8(header.IPv6NoNextHeaderIdentifier), nil + } + switch l.(type) { + case *TCP: + return uint8(header.TCPProtocolNumber), nil + case *UDP: + return uint8(header.UDPProtocolNumber), nil + case *ICMPv6: + return uint8(header.ICMPv6ProtocolNumber), nil + case *Payload: + return uint8(header.IPv6NoNextHeaderIdentifier), nil + case *IPv6HopByHopOptionsExtHdr: + return uint8(header.IPv6HopByHopOptionsExtHdrIdentifier), nil + case *IPv6DestinationOptionsExtHdr: + return uint8(header.IPv6DestinationOptionsExtHdrIdentifier), nil + case *IPv6FragmentExtHdr: + return uint8(header.IPv6FragmentExtHdrIdentifier), nil + default: + // TODO(b/161005083): Support more protocols as needed. + return 0, fmt.Errorf("failed to deduce the IPv6 header's next protocol: %T", l) + } +} + +// ipv6OptionsExtHdrToBytes serializes an options extension header into bytes. +func ipv6OptionsExtHdrToBytes(nextHeader *header.IPv6ExtensionHeaderIdentifier, nextLayer Layer, options []byte) ([]byte, error) { + length := len(options) + 2 + if length%8 != 0 { + return nil, fmt.Errorf("IPv6 extension headers must be a multiple of 8 octets long, but the length given: %d, options: %s", length, hex.Dump(options)) + } + bytes := make([]byte, length) + if nextHeader != nil { + bytes[0] = byte(*nextHeader) + } else { + nh, err := nextHeaderByLayer(nextLayer) + if err != nil { + return nil, err + } + bytes[0] = nh + } + // ExtHdrLen field is the length of the extension header + // in 8-octet unit, ignoring the first 8 octets. + // https://tools.ietf.org/html/rfc2460#section-4.3 + // https://tools.ietf.org/html/rfc2460#section-4.6 + bytes[1] = uint8((length - 8) / 8) + copy(bytes[2:], options) + return bytes, nil +} + +// IPv6ExtHdrIdent is a helper routine that allocates a new +// header.IPv6ExtensionHeaderIdentifier value to store v and returns a pointer +// to it. +func IPv6ExtHdrIdent(id header.IPv6ExtensionHeaderIdentifier) *header.IPv6ExtensionHeaderIdentifier { + return &id +} + +// ToBytes implements Layer.ToBytes. +func (l *IPv6HopByHopOptionsExtHdr) ToBytes() ([]byte, error) { + return ipv6OptionsExtHdrToBytes(l.NextHeader, l.next(), l.Options) +} + +// ToBytes implements Layer.ToBytes. +func (l *IPv6DestinationOptionsExtHdr) ToBytes() ([]byte, error) { + return ipv6OptionsExtHdrToBytes(l.NextHeader, l.next(), l.Options) +} + +// ToBytes implements Layer.ToBytes. +func (l *IPv6FragmentExtHdr) ToBytes() ([]byte, error) { + var offset, mflag uint16 + var ident uint32 + bytes := make([]byte, header.IPv6FragmentExtHdrLength) + if l.NextHeader != nil { + bytes[0] = byte(*l.NextHeader) + } else { + nh, err := nextHeaderByLayer(l.next()) + if err != nil { + return nil, err + } + bytes[0] = nh + } + bytes[1] = 0 // reserved + if l.MoreFragments != nil && *l.MoreFragments { + mflag = 1 + } + if l.FragmentOffset != nil { + offset = *l.FragmentOffset + } + if l.Identification != nil { + ident = *l.Identification + } + offsetAndMflag := offset<<3 | mflag + binary.BigEndian.PutUint16(bytes[2:], offsetAndMflag) + binary.BigEndian.PutUint32(bytes[4:], ident) + + return bytes, nil +} + +// parseIPv6ExtHdr parses an IPv6 extension header and returns the NextHeader +// field, the rest of the payload and a parser function for the corresponding +// next extension header. +func parseIPv6ExtHdr(b []byte) (header.IPv6ExtensionHeaderIdentifier, []byte, layerParser) { + nextHeader := b[0] + // For HopByHop and Destination options extension headers, + // This field is the length of the extension header in + // 8-octet units, not including the first 8 octets. + // https://tools.ietf.org/html/rfc2460#section-4.3 + // https://tools.ietf.org/html/rfc2460#section-4.6 + length := b[1]*8 + 8 + data := b[2:length] + nextParser := nextIPv6PayloadParser(nextHeader) + return header.IPv6ExtensionHeaderIdentifier(nextHeader), data, nextParser +} + +// parseIPv6HopByHopOptionsExtHdr parses the bytes assuming that they start +// with an IPv6 HopByHop Options Extension Header. +func parseIPv6HopByHopOptionsExtHdr(b []byte) (Layer, layerParser) { + nextHeader, options, nextParser := parseIPv6ExtHdr(b) + return &IPv6HopByHopOptionsExtHdr{NextHeader: &nextHeader, Options: options}, nextParser +} + +// parseIPv6DestinationOptionsExtHdr parses the bytes assuming that they start +// with an IPv6 Destination Options Extension Header. +func parseIPv6DestinationOptionsExtHdr(b []byte) (Layer, layerParser) { + nextHeader, options, nextParser := parseIPv6ExtHdr(b) + return &IPv6DestinationOptionsExtHdr{NextHeader: &nextHeader, Options: options}, nextParser +} + +// Bool is a helper routine that allocates a new +// bool value to store v and returns a pointer to it. +func Bool(v bool) *bool { + return &v +} + +// parseIPv6FragmentExtHdr parses the bytes assuming that they start +// with an IPv6 Fragment Extension Header. +func parseIPv6FragmentExtHdr(b []byte) (Layer, layerParser) { + nextHeader := b[0] + var extHdr header.IPv6FragmentExtHdr + copy(extHdr[:], b[2:]) + return &IPv6FragmentExtHdr{ + NextHeader: IPv6ExtHdrIdent(header.IPv6ExtensionHeaderIdentifier(nextHeader)), + FragmentOffset: Uint16(extHdr.FragmentOffset()), + MoreFragments: Bool(extHdr.More()), + Identification: Uint32(extHdr.ID()), + }, nextIPv6PayloadParser(nextHeader) +} + +func (l *IPv6HopByHopOptionsExtHdr) length() int { + return len(l.Options) + 2 +} + +func (l *IPv6HopByHopOptionsExtHdr) match(other Layer) bool { + return equalLayer(l, other) +} + +// merge overrides the values in l with the values from other but only in fields +// where the value is not nil. +func (l *IPv6HopByHopOptionsExtHdr) merge(other Layer) error { + return mergeLayer(l, other) +} + +func (l *IPv6HopByHopOptionsExtHdr) String() string { + return stringLayer(l) +} + +func (l *IPv6DestinationOptionsExtHdr) length() int { + return len(l.Options) + 2 +} + +func (l *IPv6DestinationOptionsExtHdr) match(other Layer) bool { + return equalLayer(l, other) +} + +// merge overrides the values in l with the values from other but only in fields +// where the value is not nil. +func (l *IPv6DestinationOptionsExtHdr) merge(other Layer) error { + return mergeLayer(l, other) +} + +func (l *IPv6DestinationOptionsExtHdr) String() string { + return stringLayer(l) +} + +func (*IPv6FragmentExtHdr) length() int { + return header.IPv6FragmentExtHdrLength +} + +func (l *IPv6FragmentExtHdr) match(other Layer) bool { + return equalLayer(l, other) +} + +// merge overrides the values in l with the values from other but only in fields +// where the value is not nil. +func (l *IPv6FragmentExtHdr) merge(other Layer) error { + return mergeLayer(l, other) +} + +func (l *IPv6FragmentExtHdr) String() string { + return stringLayer(l) +} + +// ICMPv6 can construct and match an ICMPv6 encapsulation. +type ICMPv6 struct { + LayerBase + Type *header.ICMPv6Type + Code *header.ICMPv6Code + Checksum *uint16 + Payload []byte +} + +func (l *ICMPv6) String() string { + // TODO(eyalsoha): Do something smarter here when *l.Type is ParameterProblem? + // We could parse the contents of the Payload as if it were an IPv6 packet. + return stringLayer(l) +} + +// ToBytes implements Layer.ToBytes. +func (l *ICMPv6) ToBytes() ([]byte, error) { + b := make([]byte, header.ICMPv6HeaderSize+len(l.Payload)) + h := header.ICMPv6(b) + if l.Type != nil { + h.SetType(*l.Type) + } + if l.Code != nil { + h.SetCode(*l.Code) + } + copy(h.NDPPayload(), l.Payload) + if l.Checksum != nil { + h.SetChecksum(*l.Checksum) + } else { + // It is possible that the ICMPv6 header does not follow the IPv6 header + // immediately, there could be one or more extension headers in between. + // We need to search forward to find the IPv6 header. + for prev := l.Prev(); prev != nil; prev = prev.Prev() { + if ipv6, ok := prev.(*IPv6); ok { + payload, err := payload(l) + if err != nil { + return nil, err + } + h.SetChecksum(header.ICMPv6Checksum(h, *ipv6.SrcAddr, *ipv6.DstAddr, payload)) + break + } + } + } + return h, nil +} + +// ICMPv6Type is a helper routine that allocates a new ICMPv6Type value to store +// v and returns a pointer to it. +func ICMPv6Type(v header.ICMPv6Type) *header.ICMPv6Type { + return &v +} + +// ICMPv6Code is a helper routine that allocates a new ICMPv6Type value to store +// v and returns a pointer to it. +func ICMPv6Code(v header.ICMPv6Code) *header.ICMPv6Code { + return &v +} + +// Byte is a helper routine that allocates a new byte value to store +// v and returns a pointer to it. +func Byte(v byte) *byte { + return &v +} + +// parseICMPv6 parses the bytes assuming that they start with an ICMPv6 header. +func parseICMPv6(b []byte) (Layer, layerParser) { + h := header.ICMPv6(b) + icmpv6 := ICMPv6{ + Type: ICMPv6Type(h.Type()), + Code: ICMPv6Code(h.Code()), + Checksum: Uint16(h.Checksum()), + Payload: h.NDPPayload(), + } + return &icmpv6, nil +} + +func (l *ICMPv6) match(other Layer) bool { + return equalLayer(l, other) +} + +func (l *ICMPv6) length() int { + return header.ICMPv6HeaderSize + len(l.Payload) +} + +// merge overrides the values in l with the values from other but only in fields +// where the value is not nil. +func (l *ICMPv6) merge(other Layer) error { + return mergeLayer(l, other) +} + +// ICMPv4Type is a helper routine that allocates a new header.ICMPv4Type value +// to store t and returns a pointer to it. +func ICMPv4Type(t header.ICMPv4Type) *header.ICMPv4Type { + return &t +} + +// ICMPv4Code is a helper routine that allocates a new header.ICMPv4Code value +// to store t and returns a pointer to it. +func ICMPv4Code(t header.ICMPv4Code) *header.ICMPv4Code { + return &t +} + +// ICMPv4 can construct and match an ICMPv4 encapsulation. +type ICMPv4 struct { + LayerBase + Type *header.ICMPv4Type + Code *header.ICMPv4Code + Checksum *uint16 +} + +func (l *ICMPv4) String() string { + return stringLayer(l) +} + +// ToBytes implements Layer.ToBytes. +func (l *ICMPv4) ToBytes() ([]byte, error) { + b := make([]byte, header.ICMPv4MinimumSize) + h := header.ICMPv4(b) + if l.Type != nil { + h.SetType(*l.Type) + } + if l.Code != nil { + h.SetCode(*l.Code) + } + if l.Checksum != nil { + h.SetChecksum(*l.Checksum) + return h, nil + } + payload, err := payload(l) + if err != nil { + return nil, err + } + h.SetChecksum(header.ICMPv4Checksum(h, payload)) + return h, nil +} + +// parseICMPv4 parses the bytes as an ICMPv4 header, returning a Layer and a +// parser for the encapsulated payload. +func parseICMPv4(b []byte) (Layer, layerParser) { + h := header.ICMPv4(b) + icmpv4 := ICMPv4{ + Type: ICMPv4Type(h.Type()), + Code: ICMPv4Code(h.Code()), + Checksum: Uint16(h.Checksum()), + } + return &icmpv4, parsePayload +} + +func (l *ICMPv4) match(other Layer) bool { + return equalLayer(l, other) +} + +func (l *ICMPv4) length() int { + return header.ICMPv4MinimumSize +} + +// merge overrides the values in l with the values from other but only in fields +// where the value is not nil. +func (l *ICMPv4) merge(other Layer) error { + return mergeLayer(l, other) +} + +// TCP can construct and match a TCP encapsulation. +type TCP struct { + LayerBase + SrcPort *uint16 + DstPort *uint16 + SeqNum *uint32 + AckNum *uint32 + DataOffset *uint8 + Flags *uint8 + WindowSize *uint16 + Checksum *uint16 + UrgentPointer *uint16 + Options []byte +} + +func (l *TCP) String() string { + return stringLayer(l) +} + +// ToBytes implements Layer.ToBytes. +func (l *TCP) ToBytes() ([]byte, error) { + b := make([]byte, l.length()) + h := header.TCP(b) + if l.SrcPort != nil { + h.SetSourcePort(*l.SrcPort) + } + if l.DstPort != nil { + h.SetDestinationPort(*l.DstPort) + } + if l.SeqNum != nil { + h.SetSequenceNumber(*l.SeqNum) + } + if l.AckNum != nil { + h.SetAckNumber(*l.AckNum) + } + if l.DataOffset != nil { + h.SetDataOffset(*l.DataOffset) + } else { + h.SetDataOffset(uint8(l.length())) + } + if l.Flags != nil { + h.SetFlags(*l.Flags) + } + if l.WindowSize != nil { + h.SetWindowSize(*l.WindowSize) + } else { + h.SetWindowSize(32768) + } + if l.UrgentPointer != nil { + h.SetUrgentPoiner(*l.UrgentPointer) + } + copy(b[header.TCPMinimumSize:], l.Options) + header.AddTCPOptionPadding(b[header.TCPMinimumSize:], len(l.Options)) + if l.Checksum != nil { + h.SetChecksum(*l.Checksum) + return h, nil + } + if err := setTCPChecksum(&h, l); err != nil { + return nil, err + } + return h, nil +} + +// totalLength returns the length of the provided layer and all following +// layers. +func totalLength(l Layer) int { + var totalLength int + for ; l != nil; l = l.next() { + totalLength += l.length() + } + return totalLength +} + +// payload returns a buffer.VectorisedView of l's payload. +func payload(l Layer) (buffer.VectorisedView, error) { + var payloadBytes buffer.VectorisedView + for current := l.next(); current != nil; current = current.next() { + payload, err := current.ToBytes() + if err != nil { + return buffer.VectorisedView{}, fmt.Errorf("can't get bytes for next header: %s", payload) + } + payloadBytes.AppendView(payload) + } + return payloadBytes, nil +} + +// layerChecksum calculates the checksum of the Layer header, including the +// peusdeochecksum of the layer before it and all the bytes after it. +func layerChecksum(l Layer, protoNumber tcpip.TransportProtocolNumber) (uint16, error) { + totalLength := uint16(totalLength(l)) + var xsum uint16 + switch p := l.Prev().(type) { + case *IPv4: + xsum = header.PseudoHeaderChecksum(protoNumber, *p.SrcAddr, *p.DstAddr, totalLength) + case *IPv6: + xsum = header.PseudoHeaderChecksum(protoNumber, *p.SrcAddr, *p.DstAddr, totalLength) + default: + // TODO(b/161246171): Support more protocols. + return 0, fmt.Errorf("checksum for protocol %d is not supported when previous layer is %T", protoNumber, p) + } + payloadBytes, err := payload(l) + if err != nil { + return 0, err + } + xsum = header.ChecksumVV(payloadBytes, xsum) + return xsum, nil +} + +// setTCPChecksum calculates the checksum of the TCP header and sets it in h. +func setTCPChecksum(h *header.TCP, tcp *TCP) error { + h.SetChecksum(0) + xsum, err := layerChecksum(tcp, header.TCPProtocolNumber) + if err != nil { + return err + } + h.SetChecksum(^h.CalculateChecksum(xsum)) + return nil +} + +// Uint32 is a helper routine that allocates a new +// uint32 value to store v and returns a pointer to it. +func Uint32(v uint32) *uint32 { + return &v +} + +// parseTCP parses the bytes assuming that they start with a tcp header and +// continues parsing further encapsulations. +func parseTCP(b []byte) (Layer, layerParser) { + h := header.TCP(b) + tcp := TCP{ + SrcPort: Uint16(h.SourcePort()), + DstPort: Uint16(h.DestinationPort()), + SeqNum: Uint32(h.SequenceNumber()), + AckNum: Uint32(h.AckNumber()), + DataOffset: Uint8(h.DataOffset()), + Flags: Uint8(h.Flags()), + WindowSize: Uint16(h.WindowSize()), + Checksum: Uint16(h.Checksum()), + UrgentPointer: Uint16(h.UrgentPointer()), + Options: b[header.TCPMinimumSize:h.DataOffset()], + } + return &tcp, parsePayload +} + +func (l *TCP) match(other Layer) bool { + return equalLayer(l, other) +} + +func (l *TCP) length() int { + if l.DataOffset == nil { + // TCP header including the options must end on a 32-bit + // boundary; the user could potentially give us a slice + // whose length is not a multiple of 4 bytes, so we have + // to do the alignment here. + optlen := (len(l.Options) + 3) & ^3 + return header.TCPMinimumSize + optlen + } + return int(*l.DataOffset) +} + +// merge implements Layer.merge. +func (l *TCP) merge(other Layer) error { + return mergeLayer(l, other) +} + +// UDP can construct and match a UDP encapsulation. +type UDP struct { + LayerBase + SrcPort *uint16 + DstPort *uint16 + Length *uint16 + Checksum *uint16 +} + +func (l *UDP) String() string { + return stringLayer(l) +} + +// ToBytes implements Layer.ToBytes. +func (l *UDP) ToBytes() ([]byte, error) { + b := make([]byte, header.UDPMinimumSize) + h := header.UDP(b) + if l.SrcPort != nil { + h.SetSourcePort(*l.SrcPort) + } + if l.DstPort != nil { + h.SetDestinationPort(*l.DstPort) + } + if l.Length != nil { + h.SetLength(*l.Length) + } else { + h.SetLength(uint16(totalLength(l))) + } + if l.Checksum != nil { + h.SetChecksum(*l.Checksum) + return h, nil + } + if err := setUDPChecksum(&h, l); err != nil { + return nil, err + } + return h, nil +} + +// setUDPChecksum calculates the checksum of the UDP header and sets it in h. +func setUDPChecksum(h *header.UDP, udp *UDP) error { + h.SetChecksum(0) + xsum, err := layerChecksum(udp, header.UDPProtocolNumber) + if err != nil { + return err + } + h.SetChecksum(^h.CalculateChecksum(xsum)) + return nil +} + +// parseUDP parses the bytes assuming that they start with a udp header and +// returns the parsed layer and the next parser to use. +func parseUDP(b []byte) (Layer, layerParser) { + h := header.UDP(b) + udp := UDP{ + SrcPort: Uint16(h.SourcePort()), + DstPort: Uint16(h.DestinationPort()), + Length: Uint16(h.Length()), + Checksum: Uint16(h.Checksum()), + } + return &udp, parsePayload +} + +func (l *UDP) match(other Layer) bool { + return equalLayer(l, other) +} + +func (l *UDP) length() int { + return header.UDPMinimumSize +} + +// merge implements Layer.merge. +func (l *UDP) merge(other Layer) error { + return mergeLayer(l, other) +} + +// Payload has bytes beyond OSI layer 4. +type Payload struct { + LayerBase + Bytes []byte +} + +func (l *Payload) String() string { + return stringLayer(l) +} + +// parsePayload parses the bytes assuming that they start with a payload and +// continue to the end. There can be no further encapsulations. +func parsePayload(b []byte) (Layer, layerParser) { + payload := Payload{ + Bytes: b, + } + return &payload, nil +} + +// ToBytes implements Layer.ToBytes. +func (l *Payload) ToBytes() ([]byte, error) { + return l.Bytes, nil +} + +// Length returns payload byte length. +func (l *Payload) Length() int { + return l.length() +} + +func (l *Payload) match(other Layer) bool { + return equalLayer(l, other) +} + +func (l *Payload) length() int { + return len(l.Bytes) +} + +// merge implements Layer.merge. +func (l *Payload) merge(other Layer) error { + return mergeLayer(l, other) +} + +// Layers is an array of Layer and supports similar functions to Layer. +type Layers []Layer + +// linkLayers sets the linked-list ponters in ls. +func (ls *Layers) linkLayers() { + for i, l := range *ls { + if i > 0 { + l.setPrev((*ls)[i-1]) + } else { + l.setPrev(nil) + } + if i+1 < len(*ls) { + l.setNext((*ls)[i+1]) + } else { + l.setNext(nil) + } + } +} + +// ToBytes converts the Layers into bytes. It creates a linked list of the Layer +// structs and then concatentates the output of ToBytes on each Layer. +func (ls *Layers) ToBytes() ([]byte, error) { + ls.linkLayers() + outBytes := []byte{} + for _, l := range *ls { + layerBytes, err := l.ToBytes() + if err != nil { + return nil, err + } + outBytes = append(outBytes, layerBytes...) + } + return outBytes, nil +} + +func (ls *Layers) match(other Layers) bool { + if len(*ls) > len(other) { + return false + } + for i, l := range *ls { + if !equalLayer(l, other[i]) { + return false + } + } + return true +} + +// layerDiff stores the diffs for each field along with the label for the Layer. +// If rows is nil, that means that there was no diff. +type layerDiff struct { + label string + rows []layerDiffRow +} + +// layerDiffRow stores the fields and corresponding values for two got and want +// layers. If the value was nil then the string stored is the empty string. +type layerDiffRow struct { + field, got, want string +} + +// diffLayer extracts all differing fields between two layers. +func diffLayer(got, want Layer) []layerDiffRow { + vGot := reflect.ValueOf(got).Elem() + vWant := reflect.ValueOf(want).Elem() + if vGot.Type() != vWant.Type() { + return nil + } + t := vGot.Type() + var result []layerDiffRow + for i := 0; i < t.NumField(); i++ { + t := t.Field(i) + if t.Anonymous { + // Ignore the LayerBase in the Layer struct. + continue + } + vGot := vGot.Field(i) + vWant := vWant.Field(i) + gotString := "" + if !vGot.IsNil() { + gotString = fmt.Sprint(reflect.Indirect(vGot)) + } + wantString := "" + if !vWant.IsNil() { + wantString = fmt.Sprint(reflect.Indirect(vWant)) + } + result = append(result, layerDiffRow{t.Name, gotString, wantString}) + } + return result +} + +// layerType returns a concise string describing the type of the Layer, like +// "TCP", or "IPv6". +func layerType(l Layer) string { + return reflect.TypeOf(l).Elem().Name() +} + +// diff compares Layers and returns a representation of the difference. Each +// Layer in the Layers is pairwise compared. If an element in either is nil, it +// is considered a match with the other Layer. If two Layers have differing +// types, they don't match regardless of the contents. If two Layers have the +// same type then the fields in the Layer are pairwise compared. Fields that are +// nil always match. Two non-nil fields only match if they point to equal +// values. diff returns an empty string if and only if *ls and other match. +func (ls *Layers) diff(other Layers) string { + var allDiffs []layerDiff + // Check the cases where one list is longer than the other, where one or both + // elements are nil, where the sides have different types, and where the sides + // have the same type. + for i := 0; i < len(*ls) || i < len(other); i++ { + if i >= len(*ls) { + // Matching ls against other where other is longer than ls. missing + // matches everything so we just include a label without any rows. Having + // no rows is a sign that there was no diff. + allDiffs = append(allDiffs, layerDiff{ + label: "missing matches " + layerType(other[i]), + }) + continue + } + + if i >= len(other) { + // Matching ls against other where ls is longer than other. missing + // matches everything so we just include a label without any rows. Having + // no rows is a sign that there was no diff. + allDiffs = append(allDiffs, layerDiff{ + label: layerType((*ls)[i]) + " matches missing", + }) + continue + } + + if (*ls)[i] == nil && other[i] == nil { + // Matching ls against other where both elements are nil. nil matches + // everything so we just include a label without any rows. Having no rows + // is a sign that there was no diff. + allDiffs = append(allDiffs, layerDiff{ + label: "nil matches nil", + }) + continue + } + + if (*ls)[i] == nil { + // Matching ls against other where the element in ls is nil. nil matches + // everything so we just include a label without any rows. Having no rows + // is a sign that there was no diff. + allDiffs = append(allDiffs, layerDiff{ + label: "nil matches " + layerType(other[i]), + }) + continue + } + + if other[i] == nil { + // Matching ls against other where the element in other is nil. nil + // matches everything so we just include a label without any rows. Having + // no rows is a sign that there was no diff. + allDiffs = append(allDiffs, layerDiff{ + label: layerType((*ls)[i]) + " matches nil", + }) + continue + } + + if reflect.TypeOf((*ls)[i]) == reflect.TypeOf(other[i]) { + // Matching ls against other where both elements have the same type. Match + // each field pairwise and only report a diff if there is a mismatch, + // which is only when both sides are non-nil and have differring values. + diff := diffLayer((*ls)[i], other[i]) + var layerDiffRows []layerDiffRow + for _, d := range diff { + if d.got == "" || d.want == "" || d.got == d.want { + continue + } + layerDiffRows = append(layerDiffRows, layerDiffRow{ + d.field, + d.got, + d.want, + }) + } + if len(layerDiffRows) > 0 { + allDiffs = append(allDiffs, layerDiff{ + label: layerType((*ls)[i]), + rows: layerDiffRows, + }) + } else { + allDiffs = append(allDiffs, layerDiff{ + label: layerType((*ls)[i]) + " matches " + layerType(other[i]), + // Having no rows is a sign that there was no diff. + }) + } + continue + } + // Neither side is nil and the types are different, so we'll display one + // side then the other. + allDiffs = append(allDiffs, layerDiff{ + label: layerType((*ls)[i]) + " doesn't match " + layerType(other[i]), + }) + diff := diffLayer((*ls)[i], (*ls)[i]) + layerDiffRows := []layerDiffRow{} + for _, d := range diff { + if len(d.got) == 0 { + continue + } + layerDiffRows = append(layerDiffRows, layerDiffRow{ + d.field, + d.got, + "", + }) + } + allDiffs = append(allDiffs, layerDiff{ + label: layerType((*ls)[i]), + rows: layerDiffRows, + }) + + layerDiffRows = []layerDiffRow{} + diff = diffLayer(other[i], other[i]) + for _, d := range diff { + if len(d.want) == 0 { + continue + } + layerDiffRows = append(layerDiffRows, layerDiffRow{ + d.field, + "", + d.want, + }) + } + allDiffs = append(allDiffs, layerDiff{ + label: layerType(other[i]), + rows: layerDiffRows, + }) + } + + output := "" + // These are for output formatting. + maxLabelLen, maxFieldLen, maxGotLen, maxWantLen := 0, 0, 0, 0 + foundOne := false + for _, l := range allDiffs { + if len(l.label) > maxLabelLen && len(l.rows) > 0 { + maxLabelLen = len(l.label) + } + if l.rows != nil { + foundOne = true + } + for _, r := range l.rows { + if len(r.field) > maxFieldLen { + maxFieldLen = len(r.field) + } + if l := len(fmt.Sprint(r.got)); l > maxGotLen { + maxGotLen = l + } + if l := len(fmt.Sprint(r.want)); l > maxWantLen { + maxWantLen = l + } + } + } + if !foundOne { + return "" + } + for _, l := range allDiffs { + if len(l.rows) == 0 { + output += "(" + l.label + ")\n" + continue + } + for i, r := range l.rows { + var label string + if i == 0 { + label = l.label + ":" + } + output += fmt.Sprintf( + "%*s %*s %*v %*v\n", + maxLabelLen+1, label, + maxFieldLen+1, r.field+":", + maxGotLen, r.got, + maxWantLen, r.want, + ) + } + } + return output +} + +// merge merges the other Layers into ls. If the other Layers is longer, those +// additional Layer structs are added to ls. The errors from merging are +// collected and returned. +func (ls *Layers) merge(other Layers) error { + var errs error + for i, o := range other { + if i < len(*ls) { + errs = multierr.Combine(errs, (*ls)[i].merge(o)) + } else { + *ls = append(*ls, o) + } + } + return errs +} diff --git a/test/packetimpact/testbench/layers_test.go b/test/packetimpact/testbench/layers_test.go new file mode 100644 index 000000000..eca0780b5 --- /dev/null +++ b/test/packetimpact/testbench/layers_test.go @@ -0,0 +1,728 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testbench + +import ( + "bytes" + "net" + "testing" + + "github.com/mohae/deepcopy" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +func TestLayerMatch(t *testing.T) { + var nilPayload *Payload + noPayload := &Payload{} + emptyPayload := &Payload{Bytes: []byte{}} + fullPayload := &Payload{Bytes: []byte{1, 2, 3}} + emptyTCP := &TCP{SrcPort: Uint16(1234), LayerBase: LayerBase{nextLayer: emptyPayload}} + fullTCP := &TCP{SrcPort: Uint16(1234), LayerBase: LayerBase{nextLayer: fullPayload}} + for _, tt := range []struct { + a, b Layer + want bool + }{ + {nilPayload, nilPayload, true}, + {nilPayload, noPayload, true}, + {nilPayload, emptyPayload, true}, + {nilPayload, fullPayload, true}, + {noPayload, noPayload, true}, + {noPayload, emptyPayload, true}, + {noPayload, fullPayload, true}, + {emptyPayload, emptyPayload, true}, + {emptyPayload, fullPayload, false}, + {fullPayload, fullPayload, true}, + {emptyTCP, fullTCP, true}, + } { + if got := tt.a.match(tt.b); got != tt.want { + t.Errorf("%s.match(%s) = %t, want %t", tt.a, tt.b, got, tt.want) + } + if got := tt.b.match(tt.a); got != tt.want { + t.Errorf("%s.match(%s) = %t, want %t", tt.b, tt.a, got, tt.want) + } + } +} + +func TestLayerMergeMismatch(t *testing.T) { + tcp := &TCP{} + otherTCP := &TCP{} + ipv4 := &IPv4{} + ether := &Ether{} + for _, tt := range []struct { + a, b Layer + success bool + }{ + {tcp, tcp, true}, + {tcp, otherTCP, true}, + {tcp, ipv4, false}, + {tcp, ether, false}, + {tcp, nil, true}, + + {otherTCP, otherTCP, true}, + {otherTCP, ipv4, false}, + {otherTCP, ether, false}, + {otherTCP, nil, true}, + + {ipv4, ipv4, true}, + {ipv4, ether, false}, + {ipv4, nil, true}, + + {ether, ether, true}, + {ether, nil, true}, + } { + if err := tt.a.merge(tt.b); (err == nil) != tt.success { + t.Errorf("%s.merge(%s) got %s, wanted the opposite", tt.a, tt.b, err) + } + if tt.b != nil { + if err := tt.b.merge(tt.a); (err == nil) != tt.success { + t.Errorf("%s.merge(%s) got %s, wanted the opposite", tt.b, tt.a, err) + } + } + } +} + +func TestLayerMerge(t *testing.T) { + zero := Uint32(0) + one := Uint32(1) + two := Uint32(2) + empty := []byte{} + foo := []byte("foo") + bar := []byte("bar") + for _, tt := range []struct { + a, b Layer + want Layer + }{ + {&TCP{AckNum: nil}, &TCP{AckNum: nil}, &TCP{AckNum: nil}}, + {&TCP{AckNum: nil}, &TCP{AckNum: zero}, &TCP{AckNum: zero}}, + {&TCP{AckNum: nil}, &TCP{AckNum: one}, &TCP{AckNum: one}}, + {&TCP{AckNum: nil}, &TCP{AckNum: two}, &TCP{AckNum: two}}, + {&TCP{AckNum: nil}, nil, &TCP{AckNum: nil}}, + + {&TCP{AckNum: zero}, &TCP{AckNum: nil}, &TCP{AckNum: zero}}, + {&TCP{AckNum: zero}, &TCP{AckNum: zero}, &TCP{AckNum: zero}}, + {&TCP{AckNum: zero}, &TCP{AckNum: one}, &TCP{AckNum: one}}, + {&TCP{AckNum: zero}, &TCP{AckNum: two}, &TCP{AckNum: two}}, + {&TCP{AckNum: zero}, nil, &TCP{AckNum: zero}}, + + {&TCP{AckNum: one}, &TCP{AckNum: nil}, &TCP{AckNum: one}}, + {&TCP{AckNum: one}, &TCP{AckNum: zero}, &TCP{AckNum: zero}}, + {&TCP{AckNum: one}, &TCP{AckNum: one}, &TCP{AckNum: one}}, + {&TCP{AckNum: one}, &TCP{AckNum: two}, &TCP{AckNum: two}}, + {&TCP{AckNum: one}, nil, &TCP{AckNum: one}}, + + {&TCP{AckNum: two}, &TCP{AckNum: nil}, &TCP{AckNum: two}}, + {&TCP{AckNum: two}, &TCP{AckNum: zero}, &TCP{AckNum: zero}}, + {&TCP{AckNum: two}, &TCP{AckNum: one}, &TCP{AckNum: one}}, + {&TCP{AckNum: two}, &TCP{AckNum: two}, &TCP{AckNum: two}}, + {&TCP{AckNum: two}, nil, &TCP{AckNum: two}}, + + {&Payload{Bytes: nil}, &Payload{Bytes: nil}, &Payload{Bytes: nil}}, + {&Payload{Bytes: nil}, &Payload{Bytes: empty}, &Payload{Bytes: empty}}, + {&Payload{Bytes: nil}, &Payload{Bytes: foo}, &Payload{Bytes: foo}}, + {&Payload{Bytes: nil}, &Payload{Bytes: bar}, &Payload{Bytes: bar}}, + {&Payload{Bytes: nil}, nil, &Payload{Bytes: nil}}, + + {&Payload{Bytes: empty}, &Payload{Bytes: nil}, &Payload{Bytes: empty}}, + {&Payload{Bytes: empty}, &Payload{Bytes: empty}, &Payload{Bytes: empty}}, + {&Payload{Bytes: empty}, &Payload{Bytes: foo}, &Payload{Bytes: foo}}, + {&Payload{Bytes: empty}, &Payload{Bytes: bar}, &Payload{Bytes: bar}}, + {&Payload{Bytes: empty}, nil, &Payload{Bytes: empty}}, + + {&Payload{Bytes: foo}, &Payload{Bytes: nil}, &Payload{Bytes: foo}}, + {&Payload{Bytes: foo}, &Payload{Bytes: empty}, &Payload{Bytes: empty}}, + {&Payload{Bytes: foo}, &Payload{Bytes: foo}, &Payload{Bytes: foo}}, + {&Payload{Bytes: foo}, &Payload{Bytes: bar}, &Payload{Bytes: bar}}, + {&Payload{Bytes: foo}, nil, &Payload{Bytes: foo}}, + + {&Payload{Bytes: bar}, &Payload{Bytes: nil}, &Payload{Bytes: bar}}, + {&Payload{Bytes: bar}, &Payload{Bytes: empty}, &Payload{Bytes: empty}}, + {&Payload{Bytes: bar}, &Payload{Bytes: foo}, &Payload{Bytes: foo}}, + {&Payload{Bytes: bar}, &Payload{Bytes: bar}, &Payload{Bytes: bar}}, + {&Payload{Bytes: bar}, nil, &Payload{Bytes: bar}}, + } { + a := deepcopy.Copy(tt.a).(Layer) + if err := a.merge(tt.b); err != nil { + t.Errorf("%s.merge(%s) = %s, wanted nil", tt.a, tt.b, err) + continue + } + if a.String() != tt.want.String() { + t.Errorf("%s.merge(%s) merge result got %s, want %s", tt.a, tt.b, a, tt.want) + } + } +} + +func TestLayerStringFormat(t *testing.T) { + for _, tt := range []struct { + name string + l Layer + want string + }{ + { + name: "TCP", + l: &TCP{ + SrcPort: Uint16(34785), + DstPort: Uint16(47767), + SeqNum: Uint32(3452155723), + AckNum: Uint32(2596996163), + DataOffset: Uint8(5), + Flags: Uint8(20), + WindowSize: Uint16(64240), + Checksum: Uint16(0x2e2b), + }, + want: "&testbench.TCP{" + + "SrcPort:34785 " + + "DstPort:47767 " + + "SeqNum:3452155723 " + + "AckNum:2596996163 " + + "DataOffset:5 " + + "Flags:20 " + + "WindowSize:64240 " + + "Checksum:11819" + + "}", + }, + { + name: "UDP", + l: &UDP{ + SrcPort: Uint16(34785), + DstPort: Uint16(47767), + Length: Uint16(12), + }, + want: "&testbench.UDP{" + + "SrcPort:34785 " + + "DstPort:47767 " + + "Length:12" + + "}", + }, + { + name: "IPv4", + l: &IPv4{ + IHL: Uint8(5), + TOS: Uint8(0), + TotalLength: Uint16(44), + ID: Uint16(0), + Flags: Uint8(2), + FragmentOffset: Uint16(0), + TTL: Uint8(64), + Protocol: Uint8(6), + Checksum: Uint16(0x2e2b), + SrcAddr: Address(tcpip.Address([]byte{197, 34, 63, 10})), + DstAddr: Address(tcpip.Address([]byte{197, 34, 63, 20})), + }, + want: "&testbench.IPv4{" + + "IHL:5 " + + "TOS:0 " + + "TotalLength:44 " + + "ID:0 " + + "Flags:2 " + + "FragmentOffset:0 " + + "TTL:64 " + + "Protocol:6 " + + "Checksum:11819 " + + "SrcAddr:197.34.63.10 " + + "DstAddr:197.34.63.20" + + "}", + }, + { + name: "Ether", + l: &Ether{ + SrcAddr: LinkAddress(tcpip.LinkAddress([]byte{0x02, 0x42, 0xc5, 0x22, 0x3f, 0x0a})), + DstAddr: LinkAddress(tcpip.LinkAddress([]byte{0x02, 0x42, 0xc5, 0x22, 0x3f, 0x14})), + Type: NetworkProtocolNumber(4), + }, + want: "&testbench.Ether{" + + "SrcAddr:02:42:c5:22:3f:0a " + + "DstAddr:02:42:c5:22:3f:14 " + + "Type:4" + + "}", + }, + { + name: "Payload", + l: &Payload{ + Bytes: []byte("Hooray for packetimpact."), + }, + want: "&testbench.Payload{Bytes:\n" + + "00000000 48 6f 6f 72 61 79 20 66 6f 72 20 70 61 63 6b 65 |Hooray for packe|\n" + + "00000010 74 69 6d 70 61 63 74 2e |timpact.|\n" + + "}", + }, + } { + t.Run(tt.name, func(t *testing.T) { + if got := tt.l.String(); got != tt.want { + t.Errorf("%s.String() = %s, want: %s", tt.name, got, tt.want) + } + }) + } +} + +func TestConnectionMatch(t *testing.T) { + conn := Connection{ + layerStates: []layerState{ðerState{}}, + } + protoNum0 := tcpip.NetworkProtocolNumber(0) + protoNum1 := tcpip.NetworkProtocolNumber(1) + for _, tt := range []struct { + description string + override, received Layers + wantMatch bool + }{ + { + description: "shorter override", + override: []Layer{&Ether{}}, + received: []Layer{&Ether{}, &Payload{Bytes: []byte("hello")}}, + wantMatch: true, + }, + { + description: "longer override", + override: []Layer{&Ether{}, &Payload{Bytes: []byte("hello")}}, + received: []Layer{&Ether{}}, + wantMatch: false, + }, + { + description: "ether layer mismatch", + override: []Layer{&Ether{Type: &protoNum0}}, + received: []Layer{&Ether{Type: &protoNum1}}, + wantMatch: false, + }, + { + description: "both nil", + override: nil, + received: nil, + wantMatch: false, + }, + { + description: "nil override", + override: nil, + received: []Layer{&Ether{}}, + wantMatch: true, + }, + } { + t.Run(tt.description, func(t *testing.T) { + if gotMatch := conn.match(tt.override, tt.received); gotMatch != tt.wantMatch { + t.Fatalf("conn.match(%s, %s) = %t, want %t", tt.override, tt.received, gotMatch, tt.wantMatch) + } + }) + } +} + +func TestLayersDiff(t *testing.T) { + for _, tt := range []struct { + x, y Layers + want string + }{ + { + Layers{&Ether{Type: NetworkProtocolNumber(12)}, &TCP{DataOffset: Uint8(5), SeqNum: Uint32(5)}}, + Layers{&Ether{Type: NetworkProtocolNumber(13)}, &TCP{DataOffset: Uint8(7), SeqNum: Uint32(6)}}, + "Ether: Type: 12 13\n" + + " TCP: SeqNum: 5 6\n" + + " DataOffset: 5 7\n", + }, + { + Layers{&Ether{Type: NetworkProtocolNumber(12)}, &UDP{SrcPort: Uint16(123)}}, + Layers{&Ether{Type: NetworkProtocolNumber(13)}, &TCP{DataOffset: Uint8(7), SeqNum: Uint32(6)}}, + "Ether: Type: 12 13\n" + + "(UDP doesn't match TCP)\n" + + " UDP: SrcPort: 123 \n" + + " TCP: SeqNum: 6\n" + + " DataOffset: 7\n", + }, + { + Layers{&UDP{SrcPort: Uint16(123)}}, + Layers{&Ether{Type: NetworkProtocolNumber(13)}, &TCP{DataOffset: Uint8(7), SeqNum: Uint32(6)}}, + "(UDP doesn't match Ether)\n" + + " UDP: SrcPort: 123 \n" + + "Ether: Type: 13\n" + + "(missing matches TCP)\n", + }, + { + Layers{nil, &UDP{SrcPort: Uint16(123)}}, + Layers{&Ether{Type: NetworkProtocolNumber(13)}, &TCP{DataOffset: Uint8(7), SeqNum: Uint32(6)}}, + "(nil matches Ether)\n" + + "(UDP doesn't match TCP)\n" + + "UDP: SrcPort: 123 \n" + + "TCP: SeqNum: 6\n" + + " DataOffset: 7\n", + }, + { + Layers{&Ether{Type: NetworkProtocolNumber(13)}, &IPv4{IHL: Uint8(4)}, &TCP{DataOffset: Uint8(7), SeqNum: Uint32(6)}}, + Layers{&Ether{Type: NetworkProtocolNumber(13)}, &IPv4{IHL: Uint8(6)}, &TCP{DataOffset: Uint8(7), SeqNum: Uint32(6)}}, + "(Ether matches Ether)\n" + + "IPv4: IHL: 4 6\n" + + "(TCP matches TCP)\n", + }, + { + Layers{&Payload{Bytes: []byte("foo")}}, + Layers{&Payload{Bytes: []byte("bar")}}, + "Payload: Bytes: [102 111 111] [98 97 114]\n", + }, + { + Layers{&Payload{Bytes: []byte("")}}, + Layers{&Payload{}}, + "", + }, + { + Layers{&Payload{Bytes: []byte("")}}, + Layers{&Payload{Bytes: []byte("")}}, + "", + }, + { + Layers{&UDP{}}, + Layers{&TCP{}}, + "(UDP doesn't match TCP)\n" + + "(UDP)\n" + + "(TCP)\n", + }, + } { + if got := tt.x.diff(tt.y); got != tt.want { + t.Errorf("%s.diff(%s) = %q, want %q", tt.x, tt.y, got, tt.want) + } + if tt.x.match(tt.y) != (tt.x.diff(tt.y) == "") { + t.Errorf("match and diff of %s and %s disagree", tt.x, tt.y) + } + if tt.y.match(tt.x) != (tt.y.diff(tt.x) == "") { + t.Errorf("match and diff of %s and %s disagree", tt.y, tt.x) + } + } +} + +func TestTCPOptions(t *testing.T) { + for _, tt := range []struct { + description string + wantBytes []byte + wantLayers Layers + }{ + { + description: "without payload", + wantBytes: []byte{ + // IPv4 Header + 0x45, 0x00, 0x00, 0x2c, 0x00, 0x01, 0x00, 0x00, 0x40, 0x06, + 0xf9, 0x77, 0xc0, 0xa8, 0x00, 0x02, 0xc0, 0xa8, 0x00, 0x01, + // TCP Header + 0x30, 0x39, 0xd4, 0x31, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x60, 0x02, 0x20, 0x00, 0xf5, 0x1c, 0x00, 0x00, + // WindowScale Option + 0x03, 0x03, 0x02, + // NOP Option + 0x00, + }, + wantLayers: []Layer{ + &IPv4{ + IHL: Uint8(20), + TOS: Uint8(0), + TotalLength: Uint16(44), + ID: Uint16(1), + Flags: Uint8(0), + FragmentOffset: Uint16(0), + TTL: Uint8(64), + Protocol: Uint8(uint8(header.TCPProtocolNumber)), + Checksum: Uint16(0xf977), + SrcAddr: Address(tcpip.Address(net.ParseIP("192.168.0.2").To4())), + DstAddr: Address(tcpip.Address(net.ParseIP("192.168.0.1").To4())), + }, + &TCP{ + SrcPort: Uint16(12345), + DstPort: Uint16(54321), + SeqNum: Uint32(0), + AckNum: Uint32(0), + Flags: Uint8(header.TCPFlagSyn), + WindowSize: Uint16(8192), + Checksum: Uint16(0xf51c), + UrgentPointer: Uint16(0), + Options: []byte{3, 3, 2, 0}, + }, + &Payload{Bytes: nil}, + }, + }, + { + description: "with payload", + wantBytes: []byte{ + // IPv4 header + 0x45, 0x00, 0x00, 0x37, 0x00, 0x01, 0x00, 0x00, 0x40, 0x06, + 0xf9, 0x6c, 0xc0, 0xa8, 0x00, 0x02, 0xc0, 0xa8, 0x00, 0x01, + // TCP header + 0x30, 0x39, 0xd4, 0x31, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x60, 0x02, 0x20, 0x00, 0xe5, 0x21, 0x00, 0x00, + // WindowScale Option + 0x03, 0x03, 0x02, + // NOP Option + 0x00, + // Payload: "Sample Data" + 0x53, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x20, 0x44, 0x61, 0x74, 0x61, + }, + wantLayers: []Layer{ + &IPv4{ + IHL: Uint8(20), + TOS: Uint8(0), + TotalLength: Uint16(55), + ID: Uint16(1), + Flags: Uint8(0), + FragmentOffset: Uint16(0), + TTL: Uint8(64), + Protocol: Uint8(uint8(header.TCPProtocolNumber)), + Checksum: Uint16(0xf96c), + SrcAddr: Address(tcpip.Address(net.ParseIP("192.168.0.2").To4())), + DstAddr: Address(tcpip.Address(net.ParseIP("192.168.0.1").To4())), + }, + &TCP{ + SrcPort: Uint16(12345), + DstPort: Uint16(54321), + SeqNum: Uint32(0), + AckNum: Uint32(0), + Flags: Uint8(header.TCPFlagSyn), + WindowSize: Uint16(8192), + Checksum: Uint16(0xe521), + UrgentPointer: Uint16(0), + Options: []byte{3, 3, 2, 0}, + }, + &Payload{Bytes: []byte("Sample Data")}, + }, + }, + } { + t.Run(tt.description, func(t *testing.T) { + layers := parse(parseIPv4, tt.wantBytes) + if !layers.match(tt.wantLayers) { + t.Fatalf("match failed with diff: %s", layers.diff(tt.wantLayers)) + } + gotBytes, err := layers.ToBytes() + if err != nil { + t.Fatalf("ToBytes() failed on %s: %s", &layers, err) + } + if !bytes.Equal(tt.wantBytes, gotBytes) { + t.Fatalf("mismatching bytes, gotBytes: %x, wantBytes: %x", gotBytes, tt.wantBytes) + } + }) + } +} + +func TestIPv6ExtHdrOptions(t *testing.T) { + for _, tt := range []struct { + description string + wantBytes []byte + wantLayers Layers + }{ + { + description: "IPv6/HopByHop", + wantBytes: []byte{ + // IPv6 Header + 0x60, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x40, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef, + // HopByHop Options + 0x3b, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00, + }, + wantLayers: []Layer{ + &IPv6{ + SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))), + DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))), + }, + &IPv6HopByHopOptionsExtHdr{ + NextHeader: IPv6ExtHdrIdent(header.IPv6NoNextHeaderIdentifier), + Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00}, + }, + &Payload{ + Bytes: nil, + }, + }, + }, + { + description: "IPv6/HopByHop/Payload", + wantBytes: []byte{ + // IPv6 Header + 0x60, 0x00, 0x00, 0x00, 0x00, 0x13, 0x00, 0x40, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef, + // HopByHop Options + 0x3b, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00, + // Sample Data + 0x53, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x20, 0x44, 0x61, 0x74, 0x61, + }, + wantLayers: []Layer{ + &IPv6{ + SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))), + DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))), + }, + &IPv6HopByHopOptionsExtHdr{ + NextHeader: IPv6ExtHdrIdent(header.IPv6NoNextHeaderIdentifier), + Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00}, + }, + &Payload{ + Bytes: []byte("Sample Data"), + }, + }, + }, + { + description: "IPv6/HopByHop/Destination/ICMPv6", + wantBytes: []byte{ + // IPv6 Header + 0x60, 0x00, 0x00, 0x00, 0x00, 0x18, 0x00, 0x40, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef, + // HopByHop Options + 0x3c, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00, + // Destination Options + 0x3a, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00, + // ICMPv6 Param Problem + 0x04, 0x00, 0x5f, 0x98, 0x00, 0x00, 0x00, 0x06, + }, + wantLayers: []Layer{ + &IPv6{ + SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))), + DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))), + }, + &IPv6HopByHopOptionsExtHdr{ + NextHeader: IPv6ExtHdrIdent(header.IPv6DestinationOptionsExtHdrIdentifier), + Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00}, + }, + &IPv6DestinationOptionsExtHdr{ + NextHeader: IPv6ExtHdrIdent(header.IPv6ExtensionHeaderIdentifier(header.ICMPv6ProtocolNumber)), + Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00}, + }, + &ICMPv6{ + Type: ICMPv6Type(header.ICMPv6ParamProblem), + Code: ICMPv6Code(header.ICMPv6ErroneousHeader), + Checksum: Uint16(0x5f98), + Payload: []byte{0x00, 0x00, 0x00, 0x06}, + }, + }, + }, + { + description: "IPv6/HopByHop/Fragment", + wantBytes: []byte{ + // IPv6 Header + 0x60, 0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x40, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef, + // HopByHop Options + 0x2c, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00, + // Fragment ExtHdr + 0x3b, 0x00, 0x03, 0x20, 0x00, 0x00, 0x00, 0x2a, + }, + wantLayers: []Layer{ + &IPv6{ + SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))), + DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))), + }, + &IPv6HopByHopOptionsExtHdr{ + NextHeader: IPv6ExtHdrIdent(header.IPv6FragmentExtHdrIdentifier), + Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00}, + }, + &IPv6FragmentExtHdr{ + NextHeader: IPv6ExtHdrIdent(header.IPv6NoNextHeaderIdentifier), + FragmentOffset: Uint16(100), + MoreFragments: Bool(false), + Identification: Uint32(42), + }, + &Payload{ + Bytes: nil, + }, + }, + }, + { + description: "IPv6/DestOpt/Fragment/Payload", + wantBytes: []byte{ + // IPv6 Header + 0x60, 0x00, 0x00, 0x00, 0x00, 0x1b, 0x3c, 0x40, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef, + // Destination Options + 0x2c, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00, + // Fragment ExtHdr + 0x3b, 0x00, 0x03, 0x21, 0x00, 0x00, 0x00, 0x2a, + // Sample Data + 0x53, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x20, 0x44, 0x61, 0x74, 0x61, + }, + wantLayers: []Layer{ + &IPv6{ + SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))), + DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))), + }, + &IPv6DestinationOptionsExtHdr{ + NextHeader: IPv6ExtHdrIdent(header.IPv6FragmentExtHdrIdentifier), + Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00}, + }, + &IPv6FragmentExtHdr{ + NextHeader: IPv6ExtHdrIdent(header.IPv6NoNextHeaderIdentifier), + FragmentOffset: Uint16(100), + MoreFragments: Bool(true), + Identification: Uint32(42), + }, + &Payload{ + Bytes: []byte("Sample Data"), + }, + }, + }, + { + description: "IPv6/Fragment/Payload", + wantBytes: []byte{ + // IPv6 Header + 0x60, 0x00, 0x00, 0x00, 0x00, 0x13, 0x2c, 0x40, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef, + // Fragment ExtHdr + 0x3b, 0x00, 0x03, 0x21, 0x00, 0x00, 0x00, 0x2a, + // Sample Data + 0x53, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x20, 0x44, 0x61, 0x74, 0x61, + }, + wantLayers: []Layer{ + &IPv6{ + SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))), + DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))), + }, + &IPv6FragmentExtHdr{ + NextHeader: IPv6ExtHdrIdent(header.IPv6NoNextHeaderIdentifier), + FragmentOffset: Uint16(100), + MoreFragments: Bool(true), + Identification: Uint32(42), + }, + &Payload{ + Bytes: []byte("Sample Data"), + }, + }, + }, + } { + t.Run(tt.description, func(t *testing.T) { + layers := parse(parseIPv6, tt.wantBytes) + if !layers.match(tt.wantLayers) { + t.Fatalf("match failed with diff: %s", layers.diff(tt.wantLayers)) + } + // Make sure we can generate correct next header values and checksums + for _, layer := range layers { + switch layer := layer.(type) { + case *IPv6HopByHopOptionsExtHdr: + layer.NextHeader = nil + case *IPv6DestinationOptionsExtHdr: + layer.NextHeader = nil + case *IPv6FragmentExtHdr: + layer.NextHeader = nil + case *ICMPv6: + layer.Checksum = nil + } + } + gotBytes, err := layers.ToBytes() + if err != nil { + t.Fatalf("ToBytes() failed on %s: %s", &layers, err) + } + if !bytes.Equal(tt.wantBytes, gotBytes) { + t.Fatalf("mismatching bytes, gotBytes: %x, wantBytes: %x", gotBytes, tt.wantBytes) + } + }) + } +} diff --git a/test/packetimpact/testbench/rawsockets.go b/test/packetimpact/testbench/rawsockets.go new file mode 100644 index 000000000..57e822725 --- /dev/null +++ b/test/packetimpact/testbench/rawsockets.go @@ -0,0 +1,188 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testbench + +import ( + "encoding/binary" + "fmt" + "math" + "net" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/usermem" +) + +// Sniffer can sniff raw packets on the wire. +type Sniffer struct { + fd int +} + +func htons(x uint16) uint16 { + buf := [2]byte{} + binary.BigEndian.PutUint16(buf[:], x) + return usermem.ByteOrder.Uint16(buf[:]) +} + +// NewSniffer creates a Sniffer connected to *device. +func NewSniffer(t *testing.T) (Sniffer, error) { + t.Helper() + + snifferFd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_RAW, int(htons(unix.ETH_P_ALL))) + if err != nil { + return Sniffer{}, err + } + if err := unix.SetsockoptInt(snifferFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, 1); err != nil { + t.Fatalf("can't set sockopt SO_RCVBUFFORCE to 1: %s", err) + } + if err := unix.SetsockoptInt(snifferFd, unix.SOL_SOCKET, unix.SO_RCVBUF, 1e7); err != nil { + t.Fatalf("can't setsockopt SO_RCVBUF to 10M: %s", err) + } + return Sniffer{ + fd: snifferFd, + }, nil +} + +// maxReadSize should be large enough for the maximum frame size in bytes. If a +// packet too large for the buffer arrives, the test will get a fatal error. +const maxReadSize int = 65536 + +// Recv tries to read one frame until the timeout is up. +func (s *Sniffer) Recv(t *testing.T, timeout time.Duration) []byte { + t.Helper() + + deadline := time.Now().Add(timeout) + for { + timeout = deadline.Sub(time.Now()) + if timeout <= 0 { + return nil + } + whole, frac := math.Modf(timeout.Seconds()) + tv := unix.Timeval{ + Sec: int64(whole), + Usec: int64(frac * float64(time.Microsecond/time.Second)), + } + + if err := unix.SetsockoptTimeval(s.fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv); err != nil { + t.Fatalf("can't setsockopt SO_RCVTIMEO: %s", err) + } + + buf := make([]byte, maxReadSize) + nread, _, err := unix.Recvfrom(s.fd, buf, unix.MSG_TRUNC) + if err == unix.EINTR || err == unix.EAGAIN { + // There was a timeout. + continue + } + if err != nil { + t.Fatalf("can't read: %s", err) + } + if nread > maxReadSize { + t.Fatalf("received a truncated frame of %d bytes, want at most %d bytes", nread, maxReadSize) + } + return buf[:nread] + } +} + +// Drain drains the Sniffer's socket receive buffer by receiving until there's +// nothing else to receive. +func (s *Sniffer) Drain(t *testing.T) { + t.Helper() + + flags, err := unix.FcntlInt(uintptr(s.fd), unix.F_GETFL, 0) + if err != nil { + t.Fatalf("failed to get sniffer socket fd flags: %s", err) + } + nonBlockingFlags := flags | unix.O_NONBLOCK + if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, nonBlockingFlags); err != nil { + t.Fatalf("failed to make sniffer socket non-blocking with flags %b: %s", nonBlockingFlags, err) + } + for { + buf := make([]byte, maxReadSize) + _, _, err := unix.Recvfrom(s.fd, buf, unix.MSG_TRUNC) + if err == unix.EINTR || err == unix.EAGAIN || err == unix.EWOULDBLOCK { + break + } + } + if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, flags); err != nil { + t.Fatalf("failed to restore sniffer socket fd flags to %b: %s", flags, err) + } +} + +// close the socket that Sniffer is using. +func (s *Sniffer) close() error { + if err := unix.Close(s.fd); err != nil { + return fmt.Errorf("can't close sniffer socket: %w", err) + } + s.fd = -1 + return nil +} + +// Injector can inject raw frames. +type Injector struct { + fd int +} + +// NewInjector creates a new injector on *device. +func NewInjector(t *testing.T) (Injector, error) { + t.Helper() + + ifInfo, err := net.InterfaceByName(Device) + if err != nil { + return Injector{}, err + } + + var haddr [8]byte + copy(haddr[:], ifInfo.HardwareAddr) + sa := unix.SockaddrLinklayer{ + Protocol: unix.ETH_P_IP, + Ifindex: ifInfo.Index, + Halen: uint8(len(ifInfo.HardwareAddr)), + Addr: haddr, + } + + injectFd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_RAW, int(htons(unix.ETH_P_ALL))) + if err != nil { + return Injector{}, err + } + if err := unix.Bind(injectFd, &sa); err != nil { + return Injector{}, err + } + return Injector{ + fd: injectFd, + }, nil +} + +// Send a raw frame. +func (i *Injector) Send(t *testing.T, b []byte) { + t.Helper() + + n, err := unix.Write(i.fd, b) + if err != nil { + t.Fatalf("can't write bytes of len %d: %s", len(b), err) + } + if n != len(b) { + t.Fatalf("got %d bytes written, want %d", n, len(b)) + } +} + +// close the underlying socket. +func (i *Injector) close() error { + if err := unix.Close(i.fd); err != nil { + return fmt.Errorf("can't close sniffer socket: %w", err) + } + i.fd = -1 + return nil +} diff --git a/test/packetimpact/testbench/testbench.go b/test/packetimpact/testbench/testbench.go new file mode 100644 index 000000000..e3629e1f3 --- /dev/null +++ b/test/packetimpact/testbench/testbench.go @@ -0,0 +1,128 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testbench + +import ( + "flag" + "fmt" + "math/rand" + "net" + "os/exec" + "testing" + "time" + + "gvisor.dev/gvisor/test/packetimpact/netdevs" +) + +var ( + // Native indicates that the test is being run natively. + Native = false + // Device is the local device on the test network. + Device = "" + + // LocalIPv4 is the local IPv4 address on the test network. + LocalIPv4 = "" + // RemoteIPv4 is the DUT's IPv4 address on the test network. + RemoteIPv4 = "" + // IPv4PrefixLength is the network prefix length of the IPv4 test network. + IPv4PrefixLength = 0 + + // LocalIPv6 is the local IPv6 address on the test network. + LocalIPv6 = "" + // RemoteIPv6 is the DUT's IPv6 address on the test network. + RemoteIPv6 = "" + + // LocalInterfaceID is the ID of the local interface on the test network. + LocalInterfaceID uint32 + // RemoteInterfaceID is the ID of the remote interface on the test network. + // + // Not using uint32 because package flag does not support uint32. + RemoteInterfaceID uint64 + + // LocalMAC is the local MAC address on the test network. + LocalMAC = "" + // RemoteMAC is the DUT's MAC address on the test network. + RemoteMAC = "" + + // POSIXServerIP is the POSIX server's IP address on the control network. + POSIXServerIP = "" + // POSIXServerPort is the UDP port the POSIX server is bound to on the + // control network. + POSIXServerPort = 40000 + + // RPCKeepalive is the gRPC keepalive. + RPCKeepalive = 10 * time.Second + // RPCTimeout is the gRPC timeout. + RPCTimeout = 100 * time.Millisecond +) + +// RegisterFlags defines flags and associates them with the package-level +// exported variables above. It should be called by tests in their init +// functions. +func RegisterFlags(fs *flag.FlagSet) { + fs.StringVar(&POSIXServerIP, "posix_server_ip", POSIXServerIP, "ip address to listen to for UDP commands") + fs.IntVar(&POSIXServerPort, "posix_server_port", POSIXServerPort, "port to listen to for UDP commands") + fs.DurationVar(&RPCTimeout, "rpc_timeout", RPCTimeout, "gRPC timeout") + fs.DurationVar(&RPCKeepalive, "rpc_keepalive", RPCKeepalive, "gRPC keepalive") + fs.StringVar(&LocalIPv4, "local_ipv4", LocalIPv4, "local IPv4 address for test packets") + fs.StringVar(&RemoteIPv4, "remote_ipv4", RemoteIPv4, "remote IPv4 address for test packets") + fs.StringVar(&RemoteIPv6, "remote_ipv6", RemoteIPv6, "remote IPv6 address for test packets") + fs.StringVar(&RemoteMAC, "remote_mac", RemoteMAC, "remote mac address for test packets") + fs.StringVar(&Device, "device", Device, "local device for test packets") + fs.BoolVar(&Native, "native", Native, "whether the test is running natively") + fs.Uint64Var(&RemoteInterfaceID, "remote_interface_id", RemoteInterfaceID, "remote interface ID for test packets") +} + +// genPseudoFlags populates flag-like global config based on real flags. +// +// genPseudoFlags must only be called after flag.Parse. +func genPseudoFlags() error { + out, err := exec.Command("ip", "addr", "show").CombinedOutput() + if err != nil { + return fmt.Errorf("listing devices: %q: %w", string(out), err) + } + devs, err := netdevs.ParseDevices(string(out)) + if err != nil { + return fmt.Errorf("parsing devices: %w", err) + } + + _, deviceInfo, err := netdevs.FindDeviceByIP(net.ParseIP(LocalIPv4), devs) + if err != nil { + return fmt.Errorf("can't find deviceInfo: %w", err) + } + + LocalMAC = deviceInfo.MAC.String() + LocalIPv6 = deviceInfo.IPv6Addr.String() + LocalInterfaceID = deviceInfo.ID + + if deviceInfo.IPv4Net != nil { + IPv4PrefixLength, _ = deviceInfo.IPv4Net.Mask.Size() + } else { + IPv4PrefixLength, _ = net.ParseIP(LocalIPv4).DefaultMask().Size() + } + + return nil +} + +// GenerateRandomPayload generates a random byte slice of the specified length, +// causing a fatal test failure if it is unable to do so. +func GenerateRandomPayload(t *testing.T, n int) []byte { + t.Helper() + buf := make([]byte, n) + if _, err := rand.Read(buf); err != nil { + t.Fatalf("rand.Read(buf) failed: %s", err) + } + return buf +} diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD new file mode 100644 index 000000000..74658fea0 --- /dev/null +++ b/test/packetimpact/tests/BUILD @@ -0,0 +1,310 @@ +load("//test/packetimpact/runner:defs.bzl", "packetimpact_go_test") + +package( + default_visibility = ["//test/packetimpact:__subpackages__"], + licenses = ["notice"], +) + +packetimpact_go_test( + name = "fin_wait2_timeout", + srcs = ["fin_wait2_timeout_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "ipv4_id_uniqueness", + srcs = ["ipv4_id_uniqueness_test.go"], + deps = [ + "//pkg/abi/linux", + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "udp_discard_mcast_source_addr", + srcs = ["udp_discard_mcast_source_addr_test.go"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "udp_recv_mcast_bcast", + srcs = ["udp_recv_mcast_bcast_test.go"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@com_github_google_go_cmp//cmp:go_default_library", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "udp_any_addr_recv_unicast", + srcs = ["udp_any_addr_recv_unicast_test.go"], + deps = [ + "//pkg/tcpip", + "//test/packetimpact/testbench", + "@com_github_google_go_cmp//cmp:go_default_library", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "udp_icmp_error_propagation", + srcs = ["udp_icmp_error_propagation_test.go"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_reordering", + srcs = ["tcp_reordering_test.go"], + # TODO(b/139368047): Fix netstack then remove the line below. + expect_netstack_failure = True, + deps = [ + "//pkg/tcpip/header", + "//pkg/tcpip/seqnum", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_window_shrink", + srcs = ["tcp_window_shrink_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_zero_window_probe", + srcs = ["tcp_zero_window_probe_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_zero_window_probe_retransmit", + srcs = ["tcp_zero_window_probe_retransmit_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_zero_window_probe_usertimeout", + srcs = ["tcp_zero_window_probe_usertimeout_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_retransmits", + srcs = ["tcp_retransmits_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_outside_the_window", + srcs = ["tcp_outside_the_window_test.go"], + deps = [ + "//pkg/tcpip/header", + "//pkg/tcpip/seqnum", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_noaccept_close_rst", + srcs = ["tcp_noaccept_close_rst_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_send_window_sizes_piggyback", + srcs = ["tcp_send_window_sizes_piggyback_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_close_wait_ack", + srcs = ["tcp_close_wait_ack_test.go"], + deps = [ + "//pkg/tcpip/header", + "//pkg/tcpip/seqnum", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_paws_mechanism", + srcs = ["tcp_paws_mechanism_test.go"], + # TODO(b/156682000): Fix netstack then remove the line below. + expect_netstack_failure = True, + deps = [ + "//pkg/tcpip/header", + "//pkg/tcpip/seqnum", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_user_timeout", + srcs = ["tcp_user_timeout_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_queue_receive_in_syn_sent", + srcs = ["tcp_queue_receive_in_syn_sent_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_synsent_reset", + srcs = ["tcp_synsent_reset_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_synrcvd_reset", + srcs = ["tcp_synrcvd_reset_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_network_unreachable", + srcs = ["tcp_network_unreachable_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_cork_mss", + srcs = ["tcp_cork_mss_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_handshake_window_size", + srcs = ["tcp_handshake_window_size_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "icmpv6_param_problem", + srcs = ["icmpv6_param_problem_test.go"], + # TODO(b/153485026): Fix netstack then remove the line below. + expect_netstack_failure = True, + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "ipv6_unknown_options_action", + srcs = ["ipv6_unknown_options_action_test.go"], + # TODO(b/159928940): Fix netstack then remove the line below. + expect_netstack_failure = True, + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "ipv6_fragment_reassembly", + srcs = ["ipv6_fragment_reassembly_test.go"], + # TODO(b/160919104): Fix netstack then remove the line below. + expect_netstack_failure = True, + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "udp_send_recv_dgram", + srcs = ["udp_send_recv_dgram_test.go"], + deps = [ + "//test/packetimpact/testbench", + "@com_github_google_go_cmp//cmp:go_default_library", + "@org_golang_x_sys//unix:go_default_library", + ], +) diff --git a/test/packetimpact/tests/fin_wait2_timeout_test.go b/test/packetimpact/tests/fin_wait2_timeout_test.go new file mode 100644 index 000000000..a61054c2c --- /dev/null +++ b/test/packetimpact/tests/fin_wait2_timeout_test.go @@ -0,0 +1,75 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fin_wait2_timeout_test + +import ( + "flag" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +func TestFinWait2Timeout(t *testing.T) { + for _, tt := range []struct { + description string + linger2 bool + }{ + {"WithLinger2", true}, + {"WithoutLinger2", false}, + } { + t.Run(tt.description, func(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFd) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + conn.Connect(t) + + acceptFd, _ := dut.Accept(t, listenFd) + if tt.linger2 { + tv := unix.Timeval{Sec: 1, Usec: 0} + dut.SetSockOptTimeval(t, acceptFd, unix.SOL_TCP, unix.TCP_LINGER2, &tv) + } + dut.Close(t, acceptFd) + + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + t.Fatalf("expected a FIN-ACK within 1 second but got none: %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + + time.Sleep(5 * time.Second) + conn.Drain(t) + + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + if tt.linger2 { + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil { + t.Fatalf("expected a RST packet within a second but got none: %s", err) + } + } else { + if got, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, 10*time.Second); got != nil || err == nil { + t.Fatalf("expected no RST packets within ten seconds but got one: %s", got) + } + } + }) + } +} diff --git a/test/packetimpact/tests/icmpv6_param_problem_test.go b/test/packetimpact/tests/icmpv6_param_problem_test.go new file mode 100644 index 000000000..2d59d552d --- /dev/null +++ b/test/packetimpact/tests/icmpv6_param_problem_test.go @@ -0,0 +1,78 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package icmpv6_param_problem_test + +import ( + "encoding/binary" + "flag" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +// TestICMPv6ParamProblemTest sends a packet with a bad next header. The DUT +// should respond with an ICMPv6 Parameter Problem message. +func TestICMPv6ParamProblemTest(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + conn := testbench.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{}) + defer conn.Close(t) + ipv6 := testbench.IPv6{ + // 254 is reserved and used for experimentation and testing. This should + // cause an error. + NextHeader: testbench.Uint8(254), + } + icmpv6 := testbench.ICMPv6{ + Type: testbench.ICMPv6Type(header.ICMPv6EchoRequest), + Payload: []byte("hello world"), + } + + toSend := (*testbench.Connection)(&conn).CreateFrame(t, testbench.Layers{&ipv6}, &icmpv6) + (*testbench.Connection)(&conn).SendFrame(t, toSend) + + // Build the expected ICMPv6 payload, which includes an index to the + // problematic byte and also the problematic packet as described in + // https://tools.ietf.org/html/rfc4443#page-12 . + ipv6Sent := toSend[1:] + expectedPayload, err := ipv6Sent.ToBytes() + if err != nil { + t.Fatalf("can't convert %s to bytes: %s", ipv6Sent, err) + } + + // The problematic field is the NextHeader. + b := make([]byte, 4) + binary.BigEndian.PutUint32(b, header.IPv6NextHeaderOffset) + expectedPayload = append(b, expectedPayload...) + expectedICMPv6 := testbench.ICMPv6{ + Type: testbench.ICMPv6Type(header.ICMPv6ParamProblem), + Payload: expectedPayload, + } + + paramProblem := testbench.Layers{ + &testbench.Ether{}, + &testbench.IPv6{}, + &expectedICMPv6, + } + timeout := time.Second + if _, err := conn.ExpectFrame(t, paramProblem, timeout); err != nil { + t.Errorf("expected %s within %s but got none: %s", paramProblem, timeout, err) + } +} diff --git a/test/packetimpact/tests/ipv4_id_uniqueness_test.go b/test/packetimpact/tests/ipv4_id_uniqueness_test.go new file mode 100644 index 000000000..cf881418c --- /dev/null +++ b/test/packetimpact/tests/ipv4_id_uniqueness_test.go @@ -0,0 +1,122 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv4_id_uniqueness_test + +import ( + "context" + "flag" + "fmt" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +func recvTCPSegment(t *testing.T, conn *testbench.TCPIPv4, expect *testbench.TCP, expectPayload *testbench.Payload) (uint16, error) { + layers, err := conn.ExpectData(t, expect, expectPayload, time.Second) + if err != nil { + return 0, fmt.Errorf("failed to receive TCP segment: %s", err) + } + if len(layers) < 2 { + return 0, fmt.Errorf("got packet with layers: %v, expected to have at least 2 layers (link and network)", layers) + } + ipv4, ok := layers[1].(*testbench.IPv4) + if !ok { + return 0, fmt.Errorf("got network layer: %T, expected: *IPv4", layers[1]) + } + if *ipv4.Flags&header.IPv4FlagDontFragment != 0 { + return 0, fmt.Errorf("got IPv4 DF=1, expected DF=0") + } + return *ipv4.ID, nil +} + +// RFC 6864 section 4.2 states: "The IPv4 ID of non-atomic datagrams MUST NOT +// be reused when sending a copy of an earlier non-atomic datagram." +// +// This test creates a TCP connection, uses the IP_MTU_DISCOVER socket option +// to force the DF bit to be 0, and checks that a retransmitted segment has a +// different IPv4 Identification value than the original segment. +func TestIPv4RetransmitIdentificationUniqueness(t *testing.T) { + for _, tc := range []struct { + name string + payload []byte + }{ + {"SmallPayload", []byte("sample data")}, + // 512 bytes is chosen because sending more than this in a single segment + // causes the retransmission to send less than the original amount. + {"512BytePayload", testbench.GenerateRandomPayload(t, 512)}, + } { + t.Run(tc.name, func(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFD) + + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + + conn.Connect(t) + remoteFD, _ := dut.Accept(t, listenFD) + defer dut.Close(t, remoteFD) + + dut.SetSockOptInt(t, remoteFD, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + + // TODO(b/129291778) The following socket option clears the DF bit on + // IP packets sent over the socket, and is currently not supported by + // gVisor. gVisor by default sends packets with DF=0 anyway, so the + // socket option being not supported does not affect the operation of + // this test. Once the socket option is supported, the following call + // can be changed to simply assert success. + ret, errno := dut.SetSockOptIntWithErrno(context.Background(), t, remoteFD, unix.IPPROTO_IP, linux.IP_MTU_DISCOVER, linux.IP_PMTUDISC_DONT) + if ret == -1 && errno != unix.ENOTSUP { + t.Fatalf("failed to set IP_MTU_DISCOVER socket option to IP_PMTUDISC_DONT: %s", errno) + } + + samplePayload := &testbench.Payload{Bytes: tc.payload} + + dut.Send(t, remoteFD, tc.payload, 0) + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("failed to receive TCP segment sent for RTT calculation: %s", err) + } + // Let the DUT estimate RTO with RTT from the DATA-ACK. + // TODO(gvisor.dev/issue/2685) Estimate RTO during handshake, after which + // we can skip sending this ACK. + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + + dut.Send(t, remoteFD, tc.payload, 0) + expectTCP := &testbench.TCP{SeqNum: testbench.Uint32(uint32(*conn.RemoteSeqNum(t)))} + originalID, err := recvTCPSegment(t, &conn, expectTCP, samplePayload) + if err != nil { + t.Fatalf("failed to receive TCP segment: %s", err) + } + + retransmitID, err := recvTCPSegment(t, &conn, expectTCP, samplePayload) + if err != nil { + t.Fatalf("failed to receive retransmitted TCP segment: %s", err) + } + if originalID == retransmitID { + t.Fatalf("unexpectedly got retransmitted TCP segment with same IPv4 ID field=%d", originalID) + } + }) + } +} diff --git a/test/packetimpact/tests/ipv6_fragment_reassembly_test.go b/test/packetimpact/tests/ipv6_fragment_reassembly_test.go new file mode 100644 index 000000000..a24c85566 --- /dev/null +++ b/test/packetimpact/tests/ipv6_fragment_reassembly_test.go @@ -0,0 +1,168 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv6_fragment_reassembly_test + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "flag" + "net" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +const ( + // The payload length for the first fragment we send. This number + // is a multiple of 8 near 750 (half of 1500). + firstPayloadLength = 752 + // The ID field for our outgoing fragments. + fragmentID = 1 + // A node must be able to accept a fragmented packet that, + // after reassembly, is as large as 1500 octets. + reassemblyCap = 1500 +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +func TestIPv6FragmentReassembly(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + conn := testbench.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{}) + defer conn.Close(t) + + firstPayloadToSend := make([]byte, firstPayloadLength) + for i := range firstPayloadToSend { + firstPayloadToSend[i] = 'A' + } + + secondPayloadLength := reassemblyCap - firstPayloadLength - header.ICMPv6EchoMinimumSize + secondPayloadToSend := firstPayloadToSend[:secondPayloadLength] + + icmpv6EchoPayload := make([]byte, 4) + binary.BigEndian.PutUint16(icmpv6EchoPayload[0:], 0) + binary.BigEndian.PutUint16(icmpv6EchoPayload[2:], 0) + icmpv6EchoPayload = append(icmpv6EchoPayload, firstPayloadToSend...) + + lIP := tcpip.Address(net.ParseIP(testbench.LocalIPv6).To16()) + rIP := tcpip.Address(net.ParseIP(testbench.RemoteIPv6).To16()) + icmpv6 := testbench.ICMPv6{ + Type: testbench.ICMPv6Type(header.ICMPv6EchoRequest), + Code: testbench.ICMPv6Code(header.ICMPv6UnusedCode), + Payload: icmpv6EchoPayload, + } + icmpv6Bytes, err := icmpv6.ToBytes() + if err != nil { + t.Fatalf("failed to serialize ICMPv6: %s", err) + } + cksum := header.ICMPv6Checksum( + header.ICMPv6(icmpv6Bytes), + lIP, + rIP, + buffer.NewVectorisedView(len(secondPayloadToSend), []buffer.View{secondPayloadToSend}), + ) + + conn.Send(t, testbench.IPv6{}, + &testbench.IPv6FragmentExtHdr{ + FragmentOffset: testbench.Uint16(0), + MoreFragments: testbench.Bool(true), + Identification: testbench.Uint32(fragmentID), + }, + &testbench.ICMPv6{ + Type: testbench.ICMPv6Type(header.ICMPv6EchoRequest), + Code: testbench.ICMPv6Code(header.ICMPv6UnusedCode), + Payload: icmpv6EchoPayload, + Checksum: &cksum, + }) + + icmpv6ProtoNum := header.IPv6ExtensionHeaderIdentifier(header.ICMPv6ProtocolNumber) + + conn.Send(t, testbench.IPv6{}, + &testbench.IPv6FragmentExtHdr{ + NextHeader: &icmpv6ProtoNum, + FragmentOffset: testbench.Uint16((firstPayloadLength + header.ICMPv6EchoMinimumSize) / 8), + MoreFragments: testbench.Bool(false), + Identification: testbench.Uint32(fragmentID), + }, + &testbench.Payload{ + Bytes: secondPayloadToSend, + }) + + gotEchoReplyFirstPart, err := conn.ExpectFrame(t, testbench.Layers{ + &testbench.Ether{}, + &testbench.IPv6{}, + &testbench.IPv6FragmentExtHdr{ + FragmentOffset: testbench.Uint16(0), + MoreFragments: testbench.Bool(true), + }, + &testbench.ICMPv6{ + Type: testbench.ICMPv6Type(header.ICMPv6EchoReply), + Code: testbench.ICMPv6Code(header.ICMPv6UnusedCode), + }, + }, time.Second) + if err != nil { + t.Fatalf("expected a fragmented ICMPv6 Echo Reply, but got none: %s", err) + } + + id := *gotEchoReplyFirstPart[2].(*testbench.IPv6FragmentExtHdr).Identification + gotFirstPayload, err := gotEchoReplyFirstPart[len(gotEchoReplyFirstPart)-1].ToBytes() + if err != nil { + t.Fatalf("failed to serialize ICMPv6: %s", err) + } + icmpPayload := gotFirstPayload[header.ICMPv6EchoMinimumSize:] + receivedLen := len(icmpPayload) + wantSecondPayloadLen := reassemblyCap - header.ICMPv6EchoMinimumSize - receivedLen + wantFirstPayload := make([]byte, receivedLen) + for i := range wantFirstPayload { + wantFirstPayload[i] = 'A' + } + wantSecondPayload := wantFirstPayload[:wantSecondPayloadLen] + if !bytes.Equal(icmpPayload, wantFirstPayload) { + t.Fatalf("received unexpected payload, got: %s, want: %s", + hex.Dump(icmpPayload), + hex.Dump(wantFirstPayload)) + } + + gotEchoReplySecondPart, err := conn.ExpectFrame(t, testbench.Layers{ + &testbench.Ether{}, + &testbench.IPv6{}, + &testbench.IPv6FragmentExtHdr{ + NextHeader: &icmpv6ProtoNum, + FragmentOffset: testbench.Uint16(uint16((receivedLen + header.ICMPv6EchoMinimumSize) / 8)), + MoreFragments: testbench.Bool(false), + Identification: &id, + }, + &testbench.ICMPv6{}, + }, time.Second) + if err != nil { + t.Fatalf("expected the rest of ICMPv6 Echo Reply, but got none: %s", err) + } + secondPayload, err := gotEchoReplySecondPart[len(gotEchoReplySecondPart)-1].ToBytes() + if err != nil { + t.Fatalf("failed to serialize ICMPv6 Echo Reply: %s", err) + } + if !bytes.Equal(secondPayload, wantSecondPayload) { + t.Fatalf("received unexpected payload, got: %s, want: %s", + hex.Dump(secondPayload), + hex.Dump(wantSecondPayload)) + } +} diff --git a/test/packetimpact/tests/ipv6_unknown_options_action_test.go b/test/packetimpact/tests/ipv6_unknown_options_action_test.go new file mode 100644 index 000000000..e79d74476 --- /dev/null +++ b/test/packetimpact/tests/ipv6_unknown_options_action_test.go @@ -0,0 +1,187 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv6_unknown_options_action_test + +import ( + "encoding/binary" + "flag" + "net" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +func mkHopByHopOptionsExtHdr(optType byte) testbench.Layer { + return &testbench.IPv6HopByHopOptionsExtHdr{ + Options: []byte{optType, 0x04, 0x00, 0x00, 0x00, 0x00}, + } +} + +func mkDestinationOptionsExtHdr(optType byte) testbench.Layer { + return &testbench.IPv6DestinationOptionsExtHdr{ + Options: []byte{optType, 0x04, 0x00, 0x00, 0x00, 0x00}, + } +} + +func optionTypeFromAction(action header.IPv6OptionUnknownAction) byte { + return byte(action << 6) +} + +func TestIPv6UnknownOptionAction(t *testing.T) { + for _, tt := range []struct { + description string + mkExtHdr func(optType byte) testbench.Layer + action header.IPv6OptionUnknownAction + multicastDst bool + wantICMPv6 bool + }{ + { + description: "0b00/hbh", + mkExtHdr: mkHopByHopOptionsExtHdr, + action: header.IPv6OptionUnknownActionSkip, + multicastDst: false, + wantICMPv6: false, + }, + { + description: "0b01/hbh", + mkExtHdr: mkHopByHopOptionsExtHdr, + action: header.IPv6OptionUnknownActionDiscard, + multicastDst: false, + wantICMPv6: false, + }, + { + description: "0b10/hbh/unicast", + mkExtHdr: mkHopByHopOptionsExtHdr, + action: header.IPv6OptionUnknownActionDiscardSendICMP, + multicastDst: false, + wantICMPv6: true, + }, + { + description: "0b10/hbh/multicast", + mkExtHdr: mkHopByHopOptionsExtHdr, + action: header.IPv6OptionUnknownActionDiscardSendICMP, + multicastDst: true, + wantICMPv6: true, + }, + { + description: "0b11/hbh/unicast", + mkExtHdr: mkHopByHopOptionsExtHdr, + action: header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest, + multicastDst: false, + wantICMPv6: true, + }, + { + description: "0b11/hbh/multicast", + mkExtHdr: mkHopByHopOptionsExtHdr, + action: header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest, + multicastDst: true, + wantICMPv6: false, + }, + { + description: "0b00/destination", + mkExtHdr: mkDestinationOptionsExtHdr, + action: header.IPv6OptionUnknownActionSkip, + multicastDst: false, + wantICMPv6: false, + }, + { + description: "0b01/destination", + mkExtHdr: mkDestinationOptionsExtHdr, + action: header.IPv6OptionUnknownActionDiscard, + multicastDst: false, + wantICMPv6: false, + }, + { + description: "0b10/destination/unicast", + mkExtHdr: mkDestinationOptionsExtHdr, + action: header.IPv6OptionUnknownActionDiscardSendICMP, + multicastDst: false, + wantICMPv6: true, + }, + { + description: "0b10/destination/multicast", + mkExtHdr: mkDestinationOptionsExtHdr, + action: header.IPv6OptionUnknownActionDiscardSendICMP, + multicastDst: true, + wantICMPv6: true, + }, + { + description: "0b11/destination/unicast", + mkExtHdr: mkDestinationOptionsExtHdr, + action: header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest, + multicastDst: false, + wantICMPv6: true, + }, + { + description: "0b11/destination/multicast", + mkExtHdr: mkDestinationOptionsExtHdr, + action: header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest, + multicastDst: true, + wantICMPv6: false, + }, + } { + t.Run(tt.description, func(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + ipv6Conn := testbench.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{}) + conn := (*testbench.Connection)(&ipv6Conn) + defer ipv6Conn.Close(t) + + outgoingOverride := testbench.Layers{} + if tt.multicastDst { + outgoingOverride = testbench.Layers{&testbench.IPv6{ + DstAddr: testbench.Address(tcpip.Address(net.ParseIP("ff02::1"))), + }} + } + + outgoing := conn.CreateFrame(t, outgoingOverride, tt.mkExtHdr(optionTypeFromAction(tt.action))) + conn.SendFrame(t, outgoing) + ipv6Sent := outgoing[1:] + invokingPacket, err := ipv6Sent.ToBytes() + if err != nil { + t.Fatalf("failed to serialize the outgoing packet: %s", err) + } + icmpv6Payload := make([]byte, 4) + // The pointer in the ICMPv6 parameter problem message should point to + // the option type of the unknown option. In our test case, it is the + // first option in the extension header whose option type is 2 bytes + // after the IPv6 header (after NextHeader and ExtHdrLen). + binary.BigEndian.PutUint32(icmpv6Payload, header.IPv6MinimumSize+2) + icmpv6Payload = append(icmpv6Payload, invokingPacket...) + gotICMPv6, err := ipv6Conn.ExpectFrame(t, testbench.Layers{ + &testbench.Ether{}, + &testbench.IPv6{}, + &testbench.ICMPv6{ + Type: testbench.ICMPv6Type(header.ICMPv6ParamProblem), + Code: testbench.ICMPv6Code(header.ICMPv6UnknownOption), + Payload: icmpv6Payload, + }, + }, time.Second) + if tt.wantICMPv6 && err != nil { + t.Fatalf("expected ICMPv6 Parameter Problem but got none: %s", err) + } + if !tt.wantICMPv6 && gotICMPv6 != nil { + t.Fatalf("expected no ICMPv6 Parameter Problem but got one: %s", gotICMPv6) + } + }) + } +} diff --git a/test/packetimpact/tests/tcp_close_wait_ack_test.go b/test/packetimpact/tests/tcp_close_wait_ack_test.go new file mode 100644 index 000000000..e6a96f214 --- /dev/null +++ b/test/packetimpact/tests/tcp_close_wait_ack_test.go @@ -0,0 +1,109 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_close_wait_ack_test + +import ( + "flag" + "fmt" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +func TestCloseWaitAck(t *testing.T) { + for _, tt := range []struct { + description string + makeTestingTCP func(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset, windowSize seqnum.Size) testbench.TCP + seqNumOffset seqnum.Size + expectAck bool + }{ + {"OTW", generateOTWSeqSegment, 0, false}, + {"OTW", generateOTWSeqSegment, 1, true}, + {"OTW", generateOTWSeqSegment, 2, true}, + {"ACK", generateUnaccACKSegment, 0, false}, + {"ACK", generateUnaccACKSegment, 1, true}, + {"ACK", generateUnaccACKSegment, 2, true}, + } { + t.Run(fmt.Sprintf("%s%d", tt.description, tt.seqNumOffset), func(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFd) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + + conn.Connect(t) + acceptFd, _ := dut.Accept(t, listenFd) + + // Send a FIN to DUT to intiate the active close + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)}) + gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) + if err != nil { + t.Fatalf("expected an ACK for our fin and DUT should enter CLOSE_WAIT: %s", err) + } + windowSize := seqnum.Size(*gotTCP.WindowSize) + + // Send a segment with OTW Seq / unacc ACK and expect an ACK back + conn.Send(t, tt.makeTestingTCP(t, &conn, tt.seqNumOffset, windowSize), &testbench.Payload{Bytes: []byte("Sample Data")}) + gotAck, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) + if tt.expectAck && err != nil { + t.Fatalf("expected an ack but got none: %s", err) + } + if !tt.expectAck && gotAck != nil { + t.Fatalf("expected no ack but got one: %s", gotAck) + } + + // Now let's verify DUT is indeed in CLOSE_WAIT + dut.Close(t, acceptFd) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)}, time.Second); err != nil { + t.Fatalf("expected DUT to send a FIN: %s", err) + } + // Ack the FIN from DUT + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + // Send some extra data to DUT + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &testbench.Payload{Bytes: []byte("Sample Data")}) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil { + t.Fatalf("expected DUT to send an RST: %s", err) + } + }) + } +} + +// generateOTWSeqSegment generates an segment with +// seqnum = RCV.NXT + RCV.WND + seqNumOffset, the generated segment is only +// acceptable when seqNumOffset is 0, otherwise an ACK is expected from the +// receiver. +func generateOTWSeqSegment(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP { + lastAcceptable := conn.LocalSeqNum(t).Add(windowSize) + otwSeq := uint32(lastAcceptable.Add(seqNumOffset)) + return testbench.TCP{SeqNum: testbench.Uint32(otwSeq), Flags: testbench.Uint8(header.TCPFlagAck)} +} + +// generateUnaccACKSegment generates an segment with +// acknum = SND.NXT + seqNumOffset, the generated segment is only acceptable +// when seqNumOffset is 0, otherwise an ACK is expected from the receiver. +func generateUnaccACKSegment(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP { + lastAcceptable := conn.RemoteSeqNum(t) + unaccAck := uint32(lastAcceptable.Add(seqNumOffset)) + return testbench.TCP{AckNum: testbench.Uint32(unaccAck), Flags: testbench.Uint8(header.TCPFlagAck)} +} diff --git a/test/packetimpact/tests/tcp_cork_mss_test.go b/test/packetimpact/tests/tcp_cork_mss_test.go new file mode 100644 index 000000000..8feea4a82 --- /dev/null +++ b/test/packetimpact/tests/tcp_cork_mss_test.go @@ -0,0 +1,84 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_cork_mss_test + +import ( + "flag" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +// TestTCPCorkMSS tests for segment coalesce and split as per MSS. +func TestTCPCorkMSS(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFD) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + + const mss = uint32(header.TCPDefaultMSS) + options := make([]byte, header.TCPOptionMSSLength) + header.EncodeMSSOption(mss, options) + conn.ConnectWithOptions(t, options) + + acceptFD, _ := dut.Accept(t, listenFD) + defer dut.Close(t, acceptFD) + + dut.SetSockOptInt(t, acceptFD, unix.IPPROTO_TCP, unix.TCP_CORK, 1) + + // Let the dut application send 2 small segments to be held up and coalesced + // until the application sends a larger segment to fill up to > MSS. + sampleData := []byte("Sample Data") + dut.Send(t, acceptFD, sampleData, 0) + dut.Send(t, acceptFD, sampleData, 0) + + expectedData := sampleData + expectedData = append(expectedData, sampleData...) + largeData := make([]byte, mss+1) + expectedData = append(expectedData, largeData...) + dut.Send(t, acceptFD, largeData, 0) + + // Expect the segments to be coalesced and sent and capped to MSS. + expectedPayload := testbench.Payload{Bytes: expectedData[:mss]} + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &expectedPayload, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + // Expect the coalesced segment to be split and transmitted. + expectedPayload = testbench.Payload{Bytes: expectedData[mss:]} + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &expectedPayload, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + + // Check for segments to *not* be held up because of TCP_CORK when + // the current send window is less than MSS. + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(uint16(2 * len(sampleData)))}) + dut.Send(t, acceptFD, sampleData, 0) + dut.Send(t, acceptFD, sampleData, 0) + expectedPayload = testbench.Payload{Bytes: append(sampleData, sampleData...)} + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &expectedPayload, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) +} diff --git a/test/packetimpact/tests/tcp_handshake_window_size_test.go b/test/packetimpact/tests/tcp_handshake_window_size_test.go new file mode 100644 index 000000000..22937d92f --- /dev/null +++ b/test/packetimpact/tests/tcp_handshake_window_size_test.go @@ -0,0 +1,66 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_handshake_window_size_test + +import ( + "flag" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +// TestTCPHandshakeWindowSize tests if the stack is honoring the window size +// communicated during handshake. +func TestTCPHandshakeWindowSize(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFD) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + + // Start handshake with zero window size. + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn), WindowSize: testbench.Uint16(uint16(0))}) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil { + t.Fatalf("expected SYN-ACK: %s", err) + } + // Update the advertised window size to a non-zero value with the ACK that + // completes the handshake. + // + // Set the window size with MSB set and expect the dut to treat it as + // an unsigned value. + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(uint16(1 << 15))}) + + acceptFd, _ := dut.Accept(t, listenFD) + defer dut.Close(t, acceptFd) + + sampleData := []byte("Sample Data") + samplePayload := &testbench.Payload{Bytes: sampleData} + + // Since we advertised a zero window followed by a non-zero window, + // expect the dut to honor the recently advertised non-zero window + // and actually send out the data instead of probing for zero window. + dut.Send(t, acceptFd, sampleData, 0) + if _, err := conn.ExpectNextData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } +} diff --git a/test/packetimpact/tests/tcp_network_unreachable_test.go b/test/packetimpact/tests/tcp_network_unreachable_test.go new file mode 100644 index 000000000..2f57dff19 --- /dev/null +++ b/test/packetimpact/tests/tcp_network_unreachable_test.go @@ -0,0 +1,141 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_synsent_reset_test + +import ( + "context" + "flag" + "net" + "syscall" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +// TestTCPSynSentUnreachable verifies that TCP connections fail immediately when +// an ICMP destination unreachable message is sent in response to the inital +// SYN. +func TestTCPSynSentUnreachable(t *testing.T) { + // Create the DUT and connection. + dut := testbench.NewDUT(t) + defer dut.TearDown() + clientFD, clientPort := dut.CreateBoundSocket(t, unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv4)) + port := uint16(9001) + conn := testbench.NewTCPIPv4(t, testbench.TCP{SrcPort: &port, DstPort: &clientPort}, testbench.TCP{SrcPort: &clientPort, DstPort: &port}) + defer conn.Close(t) + + // Bring the DUT to SYN-SENT state with a non-blocking connect. + ctx, cancel := context.WithTimeout(context.Background(), testbench.RPCTimeout) + defer cancel() + sa := unix.SockaddrInet4{Port: int(port)} + copy(sa.Addr[:], net.IP(net.ParseIP(testbench.LocalIPv4)).To4()) + if _, err := dut.ConnectWithErrno(ctx, t, clientFD, &sa); err != syscall.Errno(unix.EINPROGRESS) { + t.Errorf("expected connect to fail with EINPROGRESS, but got %v", err) + } + + // Get the SYN. + tcpLayers, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, nil, time.Second) + if err != nil { + t.Fatalf("expected SYN: %s", err) + } + + // Send a host unreachable message. + rawConn := (*testbench.Connection)(&conn) + layers := rawConn.CreateFrame(t, nil) + layers = layers[:len(layers)-1] + const ipLayer = 1 + const tcpLayer = ipLayer + 1 + ip, ok := tcpLayers[ipLayer].(*testbench.IPv4) + if !ok { + t.Fatalf("expected %s to be IPv4", tcpLayers[ipLayer]) + } + tcp, ok := tcpLayers[tcpLayer].(*testbench.TCP) + if !ok { + t.Fatalf("expected %s to be TCP", tcpLayers[tcpLayer]) + } + var icmpv4 testbench.ICMPv4 = testbench.ICMPv4{ + Type: testbench.ICMPv4Type(header.ICMPv4DstUnreachable), + Code: testbench.ICMPv4Code(header.ICMPv4HostUnreachable)} + layers = append(layers, &icmpv4, ip, tcp) + rawConn.SendFrameStateless(t, layers) + + if _, err = dut.ConnectWithErrno(ctx, t, clientFD, &sa); err != syscall.Errno(unix.EHOSTUNREACH) { + t.Errorf("expected connect to fail with EHOSTUNREACH, but got %v", err) + } +} + +// TestTCPSynSentUnreachable6 verifies that TCP connections fail immediately when +// an ICMP destination unreachable message is sent in response to the inital +// SYN. +func TestTCPSynSentUnreachable6(t *testing.T) { + // Create the DUT and connection. + dut := testbench.NewDUT(t) + defer dut.TearDown() + clientFD, clientPort := dut.CreateBoundSocket(t, unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv6)) + conn := testbench.NewTCPIPv6(t, testbench.TCP{DstPort: &clientPort}, testbench.TCP{SrcPort: &clientPort}) + defer conn.Close(t) + + // Bring the DUT to SYN-SENT state with a non-blocking connect. + ctx, cancel := context.WithTimeout(context.Background(), testbench.RPCTimeout) + defer cancel() + sa := unix.SockaddrInet6{ + Port: int(conn.SrcPort()), + ZoneId: uint32(testbench.RemoteInterfaceID), + } + copy(sa.Addr[:], net.IP(net.ParseIP(testbench.LocalIPv6)).To16()) + if _, err := dut.ConnectWithErrno(ctx, t, clientFD, &sa); err != syscall.Errno(unix.EINPROGRESS) { + t.Errorf("expected connect to fail with EINPROGRESS, but got %v", err) + } + + // Get the SYN. + tcpLayers, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, nil, time.Second) + if err != nil { + t.Fatalf("expected SYN: %s", err) + } + + // Send a host unreachable message. + rawConn := (*testbench.Connection)(&conn) + layers := rawConn.CreateFrame(t, nil) + layers = layers[:len(layers)-1] + const ipLayer = 1 + const tcpLayer = ipLayer + 1 + ip, ok := tcpLayers[ipLayer].(*testbench.IPv6) + if !ok { + t.Fatalf("expected %s to be IPv6", tcpLayers[ipLayer]) + } + tcp, ok := tcpLayers[tcpLayer].(*testbench.TCP) + if !ok { + t.Fatalf("expected %s to be TCP", tcpLayers[tcpLayer]) + } + var icmpv6 testbench.ICMPv6 = testbench.ICMPv6{ + Type: testbench.ICMPv6Type(header.ICMPv6DstUnreachable), + Code: testbench.ICMPv6Code(header.ICMPv6NetworkUnreachable), + // Per RFC 4443 3.1, the payload contains 4 zeroed bytes. + Payload: []byte{0, 0, 0, 0}, + } + layers = append(layers, &icmpv6, ip, tcp) + rawConn.SendFrameStateless(t, layers) + + if _, err = dut.ConnectWithErrno(ctx, t, clientFD, &sa); err != syscall.Errno(unix.ENETUNREACH) { + t.Errorf("expected connect to fail with ENETUNREACH, but got %v", err) + } +} diff --git a/test/packetimpact/tests/tcp_noaccept_close_rst_test.go b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go new file mode 100644 index 000000000..82b7a85ff --- /dev/null +++ b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go @@ -0,0 +1,42 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_noaccept_close_rst_test + +import ( + "flag" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +func TestTcpNoAcceptCloseReset(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + conn.Connect(t) + defer conn.Close(t) + dut.Close(t, listenFd) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}, 1*time.Second); err != nil { + t.Fatalf("expected a RST-ACK packet but got none: %s", err) + } +} diff --git a/test/packetimpact/tests/tcp_outside_the_window_test.go b/test/packetimpact/tests/tcp_outside_the_window_test.go new file mode 100644 index 000000000..08f759f7c --- /dev/null +++ b/test/packetimpact/tests/tcp_outside_the_window_test.go @@ -0,0 +1,93 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_outside_the_window_test + +import ( + "flag" + "fmt" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +// TestTCPOutsideTheWindows tests the behavior of the DUT when packets arrive +// that are inside or outside the TCP window. Packets that are outside the +// window should force an extra ACK, as described in RFC793 page 69: +// https://tools.ietf.org/html/rfc793#page-69 +func TestTCPOutsideTheWindow(t *testing.T) { + for _, tt := range []struct { + description string + tcpFlags uint8 + payload []testbench.Layer + seqNumOffset seqnum.Size + expectACK bool + }{ + {"SYN", header.TCPFlagSyn, nil, 0, true}, + {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 0, true}, + {"ACK", header.TCPFlagAck, nil, 0, false}, + {"FIN", header.TCPFlagFin, nil, 0, false}, + {"Data", header.TCPFlagAck, []testbench.Layer{&testbench.Payload{Bytes: []byte("abc123")}}, 0, true}, + + {"SYN", header.TCPFlagSyn, nil, 1, true}, + {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 1, true}, + {"ACK", header.TCPFlagAck, nil, 1, true}, + {"FIN", header.TCPFlagFin, nil, 1, false}, + {"Data", header.TCPFlagAck, []testbench.Layer{&testbench.Payload{Bytes: []byte("abc123")}}, 1, true}, + + {"SYN", header.TCPFlagSyn, nil, 2, true}, + {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 2, true}, + {"ACK", header.TCPFlagAck, nil, 2, true}, + {"FIN", header.TCPFlagFin, nil, 2, false}, + {"Data", header.TCPFlagAck, []testbench.Layer{&testbench.Payload{Bytes: []byte("abc123")}}, 2, true}, + } { + t.Run(fmt.Sprintf("%s%d", tt.description, tt.seqNumOffset), func(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFD) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + conn.Connect(t) + acceptFD, _ := dut.Accept(t, listenFD) + defer dut.Close(t, acceptFD) + + windowSize := seqnum.Size(*conn.SynAck(t).WindowSize) + tt.seqNumOffset + conn.Drain(t) + // Ignore whatever incrementing that this out-of-order packet might cause + // to the AckNum. + localSeqNum := testbench.Uint32(uint32(*conn.LocalSeqNum(t))) + conn.Send(t, testbench.TCP{ + Flags: testbench.Uint8(tt.tcpFlags), + SeqNum: testbench.Uint32(uint32(conn.LocalSeqNum(t).Add(windowSize))), + }, tt.payload...) + timeout := 3 * time.Second + gotACK, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: localSeqNum}, timeout) + if tt.expectACK && err != nil { + t.Fatalf("expected an ACK packet within %s but got none: %s", timeout, err) + } + if !tt.expectACK && gotACK != nil { + t.Fatalf("expected no ACK packet within %s but got one: %s", timeout, gotACK) + } + }) + } +} diff --git a/test/packetimpact/tests/tcp_paws_mechanism_test.go b/test/packetimpact/tests/tcp_paws_mechanism_test.go new file mode 100644 index 000000000..37f3b56dd --- /dev/null +++ b/test/packetimpact/tests/tcp_paws_mechanism_test.go @@ -0,0 +1,109 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_paws_mechanism_test + +import ( + "encoding/hex" + "flag" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +func TestPAWSMechanism(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFD) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + + options := make([]byte, header.TCPOptionTSLength) + header.EncodeTSOption(currentTS(), 0, options) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn), Options: options}) + synAck, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) + if err != nil { + t.Fatalf("didn't get synack during handshake: %s", err) + } + parsedSynOpts := header.ParseSynOptions(synAck.Options, true) + if !parsedSynOpts.TS { + t.Fatalf("expected TSOpt from DUT, options we got:\n%s", hex.Dump(synAck.Options)) + } + tsecr := parsedSynOpts.TSVal + header.EncodeTSOption(currentTS(), tsecr, options) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}) + acceptFD, _ := dut.Accept(t, listenFD) + defer dut.Close(t, acceptFD) + + sampleData := []byte("Sample Data") + sentTSVal := currentTS() + header.EncodeTSOption(sentTSVal, tsecr, options) + // 3ms here is chosen arbitrarily to make sure we have increasing timestamps + // every time we send one, it should not cause any flakiness because timestamps + // only need to be non-decreasing. + time.Sleep(3 * time.Millisecond) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}, &testbench.Payload{Bytes: sampleData}) + + gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) + if err != nil { + t.Fatalf("expected an ACK but got none: %s", err) + } + + parsedOpts := header.ParseTCPOptions(gotTCP.Options) + if !parsedOpts.TS { + t.Fatalf("expected TS option in response, options we got:\n%s", hex.Dump(gotTCP.Options)) + } + if parsedOpts.TSVal < tsecr { + t.Fatalf("TSVal should be non-decreasing, but %d < %d", parsedOpts.TSVal, tsecr) + } + if parsedOpts.TSEcr != sentTSVal { + t.Fatalf("TSEcr should match our sent TSVal, %d != %d", parsedOpts.TSEcr, sentTSVal) + } + tsecr = parsedOpts.TSVal + lastAckNum := gotTCP.AckNum + + badTSVal := sentTSVal - 100 + header.EncodeTSOption(badTSVal, tsecr, options) + // 3ms here is chosen arbitrarily and this time.Sleep() should not cause flakiness + // due to the exact same reasoning discussed above. + time.Sleep(3 * time.Millisecond) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}, &testbench.Payload{Bytes: sampleData}) + + gotTCP, err = conn.Expect(t, testbench.TCP{AckNum: lastAckNum, Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) + if err != nil { + t.Fatalf("expected segment with AckNum %d but got none: %s", lastAckNum, err) + } + parsedOpts = header.ParseTCPOptions(gotTCP.Options) + if !parsedOpts.TS { + t.Fatalf("expected TS option in response, options we got:\n%s", hex.Dump(gotTCP.Options)) + } + if parsedOpts.TSVal < tsecr { + t.Fatalf("TSVal should be non-decreasing, but %d < %d", parsedOpts.TSVal, tsecr) + } + if parsedOpts.TSEcr != sentTSVal { + t.Fatalf("TSEcr should match our sent TSVal, %d != %d", parsedOpts.TSEcr, sentTSVal) + } +} + +func currentTS() uint32 { + return uint32(time.Now().UnixNano() / 1e6) +} diff --git a/test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go b/test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go new file mode 100644 index 000000000..d9f3ea0f2 --- /dev/null +++ b/test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go @@ -0,0 +1,132 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_queue_receive_in_syn_sent_test + +import ( + "bytes" + "context" + "encoding/hex" + "errors" + "flag" + "net" + "sync" + "syscall" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +// TestQueueReceiveInSynSent tests receive behavior when the TCP state +// is SYN-SENT. +// It tests for 2 variants where the receive is blocked and: +// (1) we complete handshake and send sample data. +// (2) we send a TCP RST. +func TestQueueReceiveInSynSent(t *testing.T) { + for _, tt := range []struct { + description string + reset bool + }{ + {description: "Send DATA", reset: false}, + {description: "Send RST", reset: true}, + } { + t.Run(tt.description, func(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + + socket, remotePort := dut.CreateBoundSocket(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv4)) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + + sampleData := []byte("Sample Data") + + dut.SetNonBlocking(t, socket, true) + if _, err := dut.ConnectWithErrno(context.Background(), t, socket, conn.LocalAddr(t)); !errors.Is(err, syscall.EINPROGRESS) { + t.Fatalf("failed to bring DUT to SYN-SENT, got: %s, want EINPROGRESS", err) + } + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, time.Second); err != nil { + t.Fatalf("expected a SYN from DUT, but got none: %s", err) + } + + if _, _, err := dut.RecvWithErrno(context.Background(), t, socket, int32(len(sampleData)), 0); err != syscall.Errno(unix.EWOULDBLOCK) { + t.Fatalf("expected error %s, got %s", syscall.Errno(unix.EWOULDBLOCK), err) + } + + // Test blocking read. + dut.SetNonBlocking(t, socket, false) + + var wg sync.WaitGroup + defer wg.Wait() + wg.Add(1) + var block sync.WaitGroup + block.Add(1) + go func() { + defer wg.Done() + ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) + defer cancel() + + block.Done() + // Issue RECEIVE call in SYN-SENT, this should be queued for + // process until the connection is established. + n, buff, err := dut.RecvWithErrno(ctx, t, socket, int32(len(sampleData)), 0) + if tt.reset { + if err != syscall.Errno(unix.ECONNREFUSED) { + t.Errorf("expected error %s, got %s", syscall.Errno(unix.ECONNREFUSED), err) + } + if n != -1 { + t.Errorf("expected return value %d, got %d", -1, n) + } + return + } + if n == -1 { + t.Errorf("failed to recv on DUT: %s", err) + } + if got := buff[:n]; !bytes.Equal(got, sampleData) { + t.Errorf("received data doesn't match, got:\n%s, want:\n%s", hex.Dump(got), hex.Dump(sampleData)) + } + }() + + // Wait for the goroutine to be scheduled and before it + // blocks on endpoint receive. + block.Wait() + // The following sleep is used to prevent the connection + // from being established before we are blocked on Recv. + time.Sleep(100 * time.Millisecond) + + if tt.reset { + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}) + return + } + + // Bring the connection to Established. + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil { + t.Fatalf("expected an ACK from DUT, but got none: %s", err) + } + + // Send sample payload and expect an ACK. + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &testbench.Payload{Bytes: sampleData}) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil { + t.Fatalf("expected an ACK from DUT, but got none: %s", err) + } + }) + } +} diff --git a/test/packetimpact/tests/tcp_reordering_test.go b/test/packetimpact/tests/tcp_reordering_test.go new file mode 100644 index 000000000..b4aeaab57 --- /dev/null +++ b/test/packetimpact/tests/tcp_reordering_test.go @@ -0,0 +1,174 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package reordering_test + +import ( + "flag" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" + tb "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + tb.RegisterFlags(flag.CommandLine) +} + +func TestReorderingWindow(t *testing.T) { + dut := tb.NewDUT(t) + defer dut.TearDown() + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFd) + conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + + // Enable SACK. + opts := make([]byte, 40) + optsOff := 0 + optsOff += header.EncodeNOP(opts[optsOff:]) + optsOff += header.EncodeNOP(opts[optsOff:]) + optsOff += header.EncodeSACKPermittedOption(opts[optsOff:]) + + // Ethernet guarantees that the MTU is at least 1500 bytes. + const minMTU = 1500 + const mss = minMTU - header.IPv4MinimumSize - header.TCPMinimumSize + optsOff += header.EncodeMSSOption(mss, opts[optsOff:]) + + conn.ConnectWithOptions(t, opts[:optsOff]) + + acceptFd, _ := dut.Accept(t, listenFd) + defer dut.Close(t, acceptFd) + + if tb.Native { + // Linux has changed its handling of reordering, force the old behavior. + dut.SetSockOpt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_CONGESTION, []byte("reno")) + } + + pls := dut.GetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_MAXSEG) + if !tb.Native { + // netstack does not impliment TCP_MAXSEG correctly. Fake it + // here. Netstack uses the max SACK size which is 32. The MSS + // option is 8 bytes, making the total 36 bytes. + pls = mss - 36 + } + + payload := make([]byte, pls) + + seqNum1 := *conn.RemoteSeqNum(t) + const numPkts = 10 + // Send some packets, checking that we receive each. + for i, sn := 0, seqNum1; i < numPkts; i++ { + dut.Send(t, acceptFd, payload, 0) + + gotOne, err := conn.Expect(t, tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second) + sn.UpdateForward(seqnum.Size(len(payload))) + if err != nil { + t.Errorf("Expect #%d: %s", i+1, err) + continue + } + if gotOne == nil { + t.Errorf("#%d: expected a packet within a second but got none", i+1) + } + } + + seqNum2 := *conn.RemoteSeqNum(t) + + // SACK packets #2-4. + sackBlock := make([]byte, 40) + sbOff := 0 + sbOff += header.EncodeNOP(sackBlock[sbOff:]) + sbOff += header.EncodeNOP(sackBlock[sbOff:]) + sbOff += header.EncodeSACKBlocks([]header.SACKBlock{{ + seqNum1.Add(seqnum.Size(len(payload))), + seqNum1.Add(seqnum.Size(4 * len(payload))), + }}, sackBlock[sbOff:]) + conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]}) + + // ACK first packet. + conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum1) + uint32(len(payload)))}) + + // Check for retransmit. + gotOne, err := conn.Expect(t, tb.TCP{SeqNum: tb.Uint32(uint32(seqNum1))}, time.Second) + if err != nil { + t.Error("Expect for retransmit:", err) + } + if gotOne == nil { + t.Error("expected a retransmitted packet within a second but got none") + } + + // ACK all send packets with a DSACK block for packet #1. This tells + // the other end that we got both the original and retransmit for + // packet #1. + dsackBlock := make([]byte, 40) + dsbOff := 0 + dsbOff += header.EncodeNOP(dsackBlock[dsbOff:]) + dsbOff += header.EncodeNOP(dsackBlock[dsbOff:]) + dsbOff += header.EncodeSACKBlocks([]header.SACKBlock{{ + seqNum1.Add(seqnum.Size(len(payload))), + seqNum1.Add(seqnum.Size(4 * len(payload))), + }}, dsackBlock[dsbOff:]) + + conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum2)), Options: dsackBlock[:dsbOff]}) + + // Send half of the original window of packets, checking that we + // received each. + for i, sn := 0, seqNum2; i < numPkts/2; i++ { + dut.Send(t, acceptFd, payload, 0) + + gotOne, err := conn.Expect(t, tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second) + sn.UpdateForward(seqnum.Size(len(payload))) + if err != nil { + t.Errorf("Expect #%d: %s", i+1, err) + continue + } + if gotOne == nil { + t.Errorf("#%d: expected a packet within a second but got none", i+1) + } + } + + if !tb.Native { + // The window should now be halved, so we should receive any + // more, even if we send them. + dut.Send(t, acceptFd, payload, 0) + if got, err := conn.Expect(t, tb.TCP{}, 100*time.Millisecond); got != nil || err == nil { + t.Fatalf("expected no packets within 100 millisecond, but got one: %s", got) + } + return + } + + // Linux reduces the window by three. Check that we can receive the rest. + for i, sn := 0, seqNum2.Add(seqnum.Size(numPkts/2*len(payload))); i < 2; i++ { + dut.Send(t, acceptFd, payload, 0) + + gotOne, err := conn.Expect(t, tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second) + sn.UpdateForward(seqnum.Size(len(payload))) + if err != nil { + t.Errorf("Expect #%d: %s", i+1, err) + continue + } + if gotOne == nil { + t.Errorf("#%d: expected a packet within a second but got none", i+1) + } + } + + // The window should now be full. + dut.Send(t, acceptFd, payload, 0) + if got, err := conn.Expect(t, tb.TCP{}, 100*time.Millisecond); got != nil || err == nil { + t.Fatalf("expected no packets within 100 millisecond, but got one: %s", got) + } +} diff --git a/test/packetimpact/tests/tcp_retransmits_test.go b/test/packetimpact/tests/tcp_retransmits_test.go new file mode 100644 index 000000000..072014ff8 --- /dev/null +++ b/test/packetimpact/tests/tcp_retransmits_test.go @@ -0,0 +1,84 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_retransmits_test + +import ( + "flag" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +// TestRetransmits tests retransmits occur at exponentially increasing +// time intervals. +func TestRetransmits(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFd) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + + conn.Connect(t) + acceptFd, _ := dut.Accept(t, listenFd) + defer dut.Close(t, acceptFd) + + dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + + sampleData := []byte("Sample Data") + samplePayload := &testbench.Payload{Bytes: sampleData} + + dut.Send(t, acceptFd, sampleData, 0) + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + // Give a chance for the dut to estimate RTO with RTT from the DATA-ACK. + // TODO(gvisor.dev/issue/2685) Estimate RTO during handshake, after which + // we can skip sending this ACK. + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + + startRTO := time.Second + current := startRTO + first := time.Now() + dut.Send(t, acceptFd, sampleData, 0) + seq := testbench.Uint32(uint32(*conn.RemoteSeqNum(t))) + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: seq}, samplePayload, startRTO); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + // Expect retransmits of the same segment. + for i := 0; i < 5; i++ { + start := time.Now() + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: seq}, samplePayload, 2*current); err != nil { + t.Fatalf("expected payload was not received: %s loop %d", err, i) + } + if i == 0 { + startRTO = time.Now().Sub(first) + current = 2 * startRTO + continue + } + // Check if the probes came at exponentially increasing intervals. + if p := time.Since(start); p < current-startRTO { + t.Fatalf("retransmit came sooner interval %d probe %d", p, i) + } + current *= 2 + } +} diff --git a/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go b/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go new file mode 100644 index 000000000..f91b06ba1 --- /dev/null +++ b/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go @@ -0,0 +1,105 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_send_window_sizes_piggyback_test + +import ( + "flag" + "fmt" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +// TestSendWindowSizesPiggyback tests cases where segment sizes are close to +// sender window size and checks for ACK piggybacking for each of those case. +func TestSendWindowSizesPiggyback(t *testing.T) { + sampleData := []byte("Sample Data") + segmentSize := uint16(len(sampleData)) + // Advertise receive window sizes that are lesser, equal to or greater than + // enqueued segment size and check for segment transmits. The test attempts + // to enqueue a segment on the dut before acknowledging previous segment and + // lets the dut piggyback any ACKs along with the enqueued segment. + for _, tt := range []struct { + description string + windowSize uint16 + expectedPayload1 []byte + expectedPayload2 []byte + enqueue bool + }{ + // Expect the first segment to be split as it cannot be accomodated in + // the sender window. This means we need not enqueue a new segment after + // the first segment. + {"WindowSmallerThanSegment", segmentSize - 1, sampleData[:(segmentSize - 1)], sampleData[(segmentSize - 1):], false /* enqueue */}, + + {"WindowEqualToSegment", segmentSize, sampleData, sampleData, true /* enqueue */}, + + // Expect the second segment to not be split as its size is greater than + // the available sender window size. The segments should not be split + // when there is pending unacknowledged data and the segment-size is + // greater than available sender window. + {"WindowGreaterThanSegment", segmentSize + 1, sampleData, sampleData, true /* enqueue */}, + } { + t.Run(fmt.Sprintf("%s%d", tt.description, tt.windowSize), func(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFd) + + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort, WindowSize: testbench.Uint16(tt.windowSize)}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + + conn.Connect(t) + acceptFd, _ := dut.Accept(t, listenFd) + defer dut.Close(t, acceptFd) + + dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + + expectedTCP := testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)} + + dut.Send(t, acceptFd, sampleData, 0) + expectedPayload := testbench.Payload{Bytes: tt.expectedPayload1} + if _, err := conn.ExpectData(t, &expectedTCP, &expectedPayload, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + + // Expect any enqueued segment to be transmitted by the dut along with + // piggybacked ACK for our data. + + if tt.enqueue { + // Enqueue a segment for the dut to transmit. + dut.Send(t, acceptFd, sampleData, 0) + } + + // Send ACK for the previous segment along with data for the dut to + // receive and ACK back. Sending this ACK would make room for the dut + // to transmit any enqueued segment. + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh), WindowSize: testbench.Uint16(tt.windowSize)}, &testbench.Payload{Bytes: sampleData}) + + // Expect the dut to piggyback the ACK for received data along with + // the segment enqueued for transmit. + expectedPayload = testbench.Payload{Bytes: tt.expectedPayload2} + if _, err := conn.ExpectData(t, &expectedTCP, &expectedPayload, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + }) + } +} diff --git a/test/packetimpact/tests/tcp_synrcvd_reset_test.go b/test/packetimpact/tests/tcp_synrcvd_reset_test.go new file mode 100644 index 000000000..57d034dd1 --- /dev/null +++ b/test/packetimpact/tests/tcp_synrcvd_reset_test.go @@ -0,0 +1,52 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_syn_reset_test + +import ( + "flag" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +// TestTCPSynRcvdReset tests transition from SYN-RCVD to CLOSED. +func TestTCPSynRcvdReset(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFD) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + + // Expect dut connection to have transitioned to SYN-RCVD state. + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil { + t.Fatalf("expected SYN-ACK %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}) + // Expect the connection to have transitioned SYN-RCVD to CLOSED. + // TODO(gvisor.dev/issue/478): Check for TCP_INFO on the dut side. + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil { + t.Fatalf("expected a TCP RST %s", err) + } +} diff --git a/test/packetimpact/tests/tcp_synsent_reset_test.go b/test/packetimpact/tests/tcp_synsent_reset_test.go new file mode 100644 index 000000000..eac8eb19d --- /dev/null +++ b/test/packetimpact/tests/tcp_synsent_reset_test.go @@ -0,0 +1,90 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_synsent_reset_test + +import ( + "flag" + "net" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + tb "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + tb.RegisterFlags(flag.CommandLine) +} + +// dutSynSentState sets up the dut connection in SYN-SENT state. +func dutSynSentState(t *testing.T) (*tb.DUT, *tb.TCPIPv4, uint16, uint16) { + t.Helper() + + dut := tb.NewDUT(t) + + clientFD, clientPort := dut.CreateBoundSocket(t, unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, net.ParseIP(tb.RemoteIPv4)) + port := uint16(9001) + conn := tb.NewTCPIPv4(t, tb.TCP{SrcPort: &port, DstPort: &clientPort}, tb.TCP{SrcPort: &clientPort, DstPort: &port}) + + sa := unix.SockaddrInet4{Port: int(port)} + copy(sa.Addr[:], net.IP(net.ParseIP(tb.LocalIPv4)).To4()) + // Bring the dut to SYN-SENT state with a non-blocking connect. + dut.Connect(t, clientFD, &sa) + if _, err := conn.ExpectData(t, &tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn)}, nil, time.Second); err != nil { + t.Fatalf("expected SYN\n") + } + + return &dut, &conn, port, clientPort +} + +// TestTCPSynSentReset tests RFC793, p67: SYN-SENT to CLOSED transition. +func TestTCPSynSentReset(t *testing.T) { + dut, conn, _, _ := dutSynSentState(t) + defer conn.Close(t) + defer dut.TearDown() + conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagRst | header.TCPFlagAck)}) + // Expect the connection to have closed. + // TODO(gvisor.dev/issue/478): Check for TCP_INFO on the dut side. + conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}) + if _, err := conn.ExpectData(t, &tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil { + t.Fatalf("expected a TCP RST") + } +} + +// TestTCPSynSentRcvdReset tests RFC793, p70, SYN-SENT to SYN-RCVD to CLOSED +// transitions. +func TestTCPSynSentRcvdReset(t *testing.T) { + dut, c, remotePort, clientPort := dutSynSentState(t) + defer dut.TearDown() + defer c.Close(t) + + conn := tb.NewTCPIPv4(t, tb.TCP{SrcPort: &remotePort, DstPort: &clientPort}, tb.TCP{SrcPort: &clientPort, DstPort: &remotePort}) + defer conn.Close(t) + // Initiate new SYN connection with the same port pair + // (simultaneous open case), expect the dut connection to move to + // SYN-RCVD state + conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn)}) + if _, err := conn.ExpectData(t, &tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil { + t.Fatalf("expected SYN-ACK %s\n", err) + } + conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}) + // Expect the connection to have transitioned SYN-RCVD to CLOSED. + // TODO(gvisor.dev/issue/478): Check for TCP_INFO on the dut side. + conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}) + if _, err := conn.ExpectData(t, &tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil { + t.Fatalf("expected a TCP RST") + } +} diff --git a/test/packetimpact/tests/tcp_user_timeout_test.go b/test/packetimpact/tests/tcp_user_timeout_test.go new file mode 100644 index 000000000..551dc78e7 --- /dev/null +++ b/test/packetimpact/tests/tcp_user_timeout_test.go @@ -0,0 +1,100 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_user_timeout_test + +import ( + "flag" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +func sendPayload(t *testing.T, conn *testbench.TCPIPv4, dut *testbench.DUT, fd int32) { + sampleData := make([]byte, 100) + for i := range sampleData { + sampleData[i] = uint8(i) + } + conn.Drain(t) + dut.Send(t, fd, sampleData, 0) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &testbench.Payload{Bytes: sampleData}, time.Second); err != nil { + t.Fatalf("expected data but got none: %w", err) + } +} + +func sendFIN(t *testing.T, conn *testbench.TCPIPv4, dut *testbench.DUT, fd int32) { + dut.Close(t, fd) +} + +func TestTCPUserTimeout(t *testing.T) { + for _, tt := range []struct { + description string + userTimeout time.Duration + sendDelay time.Duration + }{ + {"NoUserTimeout", 0, 3 * time.Second}, + {"ACKBeforeUserTimeout", 5 * time.Second, 4 * time.Second}, + {"ACKAfterUserTimeout", 5 * time.Second, 7 * time.Second}, + } { + for _, ttf := range []struct { + description string + f func(_ *testing.T, _ *testbench.TCPIPv4, _ *testbench.DUT, fd int32) + }{ + {"AfterPayload", sendPayload}, + {"AfterFIN", sendFIN}, + } { + t.Run(tt.description+ttf.description, func(t *testing.T) { + // Create a socket, listen, TCP handshake, and accept. + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFD) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + conn.Connect(t) + acceptFD, _ := dut.Accept(t, listenFD) + + if tt.userTimeout != 0 { + dut.SetSockOptInt(t, acceptFD, unix.SOL_TCP, unix.TCP_USER_TIMEOUT, int32(tt.userTimeout.Milliseconds())) + } + + ttf.f(t, &conn, &dut, acceptFD) + + time.Sleep(tt.sendDelay) + conn.Drain(t) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + + // If TCP_USER_TIMEOUT was set and the above delay was longer than the + // TCP_USER_TIMEOUT then the DUT should send a RST in response to the + // testbench's packet. + expectRST := tt.userTimeout != 0 && tt.sendDelay > tt.userTimeout + expectTimeout := 5 * time.Second + got, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, expectTimeout) + if expectRST && err != nil { + t.Errorf("expected RST packet within %s but got none: %s", expectTimeout, err) + } + if !expectRST && got != nil { + t.Errorf("expected no RST packet within %s but got one: %s", expectTimeout, got) + } + }) + } + } +} diff --git a/test/packetimpact/tests/tcp_window_shrink_test.go b/test/packetimpact/tests/tcp_window_shrink_test.go new file mode 100644 index 000000000..5b001fbec --- /dev/null +++ b/test/packetimpact/tests/tcp_window_shrink_test.go @@ -0,0 +1,73 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_window_shrink_test + +import ( + "flag" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +func TestWindowShrink(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFd) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + + conn.Connect(t) + acceptFd, _ := dut.Accept(t, listenFd) + defer dut.Close(t, acceptFd) + + dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + + sampleData := []byte("Sample Data") + samplePayload := &testbench.Payload{Bytes: sampleData} + + dut.Send(t, acceptFd, sampleData, 0) + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + + dut.Send(t, acceptFd, sampleData, 0) + dut.Send(t, acceptFd, sampleData, 0) + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + // We close our receiving window here + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) + + dut.Send(t, acceptFd, []byte("Sample Data"), 0) + // Note: There is another kind of zero-window probing which Windows uses (by sending one + // new byte at `RemoteSeqNum`), if netstack wants to go that way, we may want to change + // the following lines. + expectedRemoteSeqNum := *conn.RemoteSeqNum(t) - 1 + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: testbench.Uint32(uint32(expectedRemoteSeqNum))}, nil, time.Second); err != nil { + t.Fatalf("expected a packet with sequence number %d: %s", expectedRemoteSeqNum, err) + } +} diff --git a/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go b/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go new file mode 100644 index 000000000..da93267d6 --- /dev/null +++ b/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go @@ -0,0 +1,104 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_zero_window_probe_retransmit_test + +import ( + "flag" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +// TestZeroWindowProbeRetransmit tests retransmits of zero window probes +// to be sent at exponentially inreasing time intervals. +func TestZeroWindowProbeRetransmit(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFd) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + + conn.Connect(t) + acceptFd, _ := dut.Accept(t, listenFd) + defer dut.Close(t, acceptFd) + + dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + + sampleData := []byte("Sample Data") + samplePayload := &testbench.Payload{Bytes: sampleData} + + // Send and receive sample data to the dut. + dut.Send(t, acceptFd, sampleData, 0) + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil { + t.Fatalf("expected packet was not received: %s", err) + } + + // Check for the dut to keep the connection alive as long as the zero window + // probes are acknowledged. Check if the zero window probes are sent at + // exponentially increasing intervals. The timeout intervals are function + // of the recorded first zero probe transmission duration. + // + // Advertize zero receive window again. + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) + probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum(t) - 1)) + ackProbe := testbench.Uint32(uint32(*conn.RemoteSeqNum(t))) + + startProbeDuration := time.Second + current := startProbeDuration + first := time.Now() + // Ask the dut to send out data. + dut.Send(t, acceptFd, sampleData, 0) + // Expect the dut to keep the connection alive as long as the remote is + // acknowledging the zero-window probes. + for i := 0; i < 5; i++ { + start := time.Now() + // Expect zero-window probe with a timeout which is a function of the typical + // first retransmission time. The retransmission times is supposed to + // exponentially increase. + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: probeSeq}, nil, 2*current); err != nil { + t.Fatalf("expected a probe with sequence number %d: loop %d", probeSeq, i) + } + if i == 0 { + startProbeDuration = time.Now().Sub(first) + current = 2 * startProbeDuration + continue + } + // Check if the probes came at exponentially increasing intervals. + if got, want := time.Since(start), current-startProbeDuration; got < want { + t.Errorf("got zero probe %d after %s, want >= %s", i, got, want) + } + // Acknowledge the zero-window probes from the dut. + conn.Send(t, testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) + current *= 2 + } + // Advertize non-zero window. + conn.Send(t, testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck)}) + // Expect the dut to recover and transmit data. + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: ackProbe}, samplePayload, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } +} diff --git a/test/packetimpact/tests/tcp_zero_window_probe_test.go b/test/packetimpact/tests/tcp_zero_window_probe_test.go new file mode 100644 index 000000000..44cac42f8 --- /dev/null +++ b/test/packetimpact/tests/tcp_zero_window_probe_test.go @@ -0,0 +1,112 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_zero_window_probe_test + +import ( + "flag" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +// TestZeroWindowProbe tests few cases of zero window probing over the +// same connection. +func TestZeroWindowProbe(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFd) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + + conn.Connect(t) + acceptFd, _ := dut.Accept(t, listenFd) + defer dut.Close(t, acceptFd) + + dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + + sampleData := []byte("Sample Data") + samplePayload := &testbench.Payload{Bytes: sampleData} + + start := time.Now() + // Send and receive sample data to the dut. + dut.Send(t, acceptFd, sampleData, 0) + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + sendTime := time.Now().Sub(start) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil { + t.Fatalf("expected packet was not received: %s", err) + } + + // Test 1: Check for receive of a zero window probe, record the duration for + // probe to be sent. + // + // Advertize zero window to the dut. + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) + + // Expected sequence number of the zero window probe. + probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum(t) - 1)) + // Expected ack number of the ACK for the probe. + ackProbe := testbench.Uint32(uint32(*conn.RemoteSeqNum(t))) + + // Expect there are no zero-window probes sent until there is data to be sent out + // from the dut. + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: probeSeq}, nil, 2*time.Second); err == nil { + t.Fatalf("unexpected packet with sequence number %d: %s", probeSeq, err) + } + + start = time.Now() + // Ask the dut to send out data. + dut.Send(t, acceptFd, sampleData, 0) + // Expect zero-window probe from the dut. + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: probeSeq}, nil, time.Second); err != nil { + t.Fatalf("expected a packet with sequence number %d: %s", probeSeq, err) + } + // Expect the probe to be sent after some time. Compare against the previous + // time recorded when the dut immediately sends out data on receiving the + // send command. + if startProbeDuration := time.Now().Sub(start); startProbeDuration <= sendTime { + t.Fatalf("expected the first probe to be sent out after retransmission interval, got %s want > %s", startProbeDuration, sendTime) + } + + // Test 2: Check if the dut recovers on advertizing non-zero receive window. + // and sends out the sample payload after the send window opens. + // + // Advertize non-zero window to the dut and ack the zero window probe. + conn.Send(t, testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck)}) + // Expect the dut to recover and transmit data. + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: ackProbe}, samplePayload, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + + // Test 3: Sanity check for dut's processing of a similar probe it sent. + // Check if the dut responds as we do for a similar probe sent to it. + // Basically with sequence number to one byte behind the unacknowledged + // sequence number. + p := testbench.Uint32(uint32(*conn.LocalSeqNum(t))) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), SeqNum: testbench.Uint32(uint32(*conn.LocalSeqNum(t) - 1))}) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: p}, nil, time.Second); err != nil { + t.Fatalf("expected a packet with ack number: %d: %s", p, err) + } +} diff --git a/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go b/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go new file mode 100644 index 000000000..09a1c653f --- /dev/null +++ b/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go @@ -0,0 +1,98 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_zero_window_probe_usertimeout_test + +import ( + "flag" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +// TestZeroWindowProbeUserTimeout sanity tests user timeout when we are +// retransmitting zero window probes. +func TestZeroWindowProbeUserTimeout(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFd) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + + conn.Connect(t) + acceptFd, _ := dut.Accept(t, listenFd) + defer dut.Close(t, acceptFd) + + dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + + sampleData := []byte("Sample Data") + samplePayload := &testbench.Payload{Bytes: sampleData} + + // Send and receive sample data to the dut. + dut.Send(t, acceptFd, sampleData, 0) + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil { + t.Fatalf("expected packet was not received: %s", err) + } + + // Test 1: Check for receive of a zero window probe, record the duration for + // probe to be sent. + // + // Advertize zero window to the dut. + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) + + // Expected sequence number of the zero window probe. + probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum(t) - 1)) + start := time.Now() + // Ask the dut to send out data. + dut.Send(t, acceptFd, sampleData, 0) + // Expect zero-window probe from the dut. + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: probeSeq}, nil, time.Second); err != nil { + t.Fatalf("expected a packet with sequence number %d: %s", probeSeq, err) + } + // Record the duration for first probe, the dut sends the zero window probe after + // a retransmission time interval. + startProbeDuration := time.Now().Sub(start) + + // Test 2: Check if the dut times out the connection by honoring usertimeout + // when the dut is sending zero-window probes. + // + // Reduce the retransmit timeout. + dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_USER_TIMEOUT, int32(startProbeDuration.Milliseconds())) + // Advertize zero window again. + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) + // Ask the dut to send out data that would trigger zero window probe retransmissions. + dut.Send(t, acceptFd, sampleData, 0) + + // Wait for the connection to timeout after multiple zero-window probe retransmissions. + time.Sleep(8 * startProbeDuration) + + // Expect the connection to have timed out and closed which would cause the dut + // to reply with a RST to the ACK we send. + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil { + t.Fatalf("expected a TCP RST") + } +} diff --git a/test/packetimpact/tests/udp_any_addr_recv_unicast_test.go b/test/packetimpact/tests/udp_any_addr_recv_unicast_test.go new file mode 100644 index 000000000..17f32ef65 --- /dev/null +++ b/test/packetimpact/tests/udp_any_addr_recv_unicast_test.go @@ -0,0 +1,51 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package udp_any_addr_recv_unicast_test + +import ( + "flag" + "net" + "testing" + + "github.com/google/go-cmp/cmp" + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +func TestAnyRecvUnicastUDP(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + boundFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.IPv4zero) + defer dut.Close(t, boundFD) + conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) + defer conn.Close(t) + + payload := testbench.GenerateRandomPayload(t, 1<<10 /* 1 KiB */) + conn.SendIP( + t, + testbench.IPv4{DstAddr: testbench.Address(tcpip.Address(net.ParseIP(testbench.RemoteIPv4).To4()))}, + testbench.UDP{}, + &testbench.Payload{Bytes: payload}, + ) + got, want := dut.Recv(t, boundFD, int32(len(payload)+1), 0), payload + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("received payload does not match sent payload, diff (-want, +got):\n%s", diff) + } +} diff --git a/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go b/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go new file mode 100644 index 000000000..d30177e64 --- /dev/null +++ b/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go @@ -0,0 +1,94 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package udp_discard_mcast_source_addr_test + +import ( + "context" + "flag" + "fmt" + "net" + "syscall" + "testing" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +var oneSecond = unix.Timeval{Sec: 1, Usec: 0} + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +func TestDiscardsUDPPacketsWithMcastSourceAddressV4(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + remoteFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP(testbench.RemoteIPv4)) + defer dut.Close(t, remoteFD) + dut.SetSockOptTimeval(t, remoteFD, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &oneSecond) + conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) + defer conn.Close(t) + + for _, mcastAddr := range []net.IP{ + net.IPv4allsys, + net.IPv4allrouter, + net.IPv4(224, 0, 1, 42), + net.IPv4(232, 1, 2, 3), + } { + t.Run(fmt.Sprintf("srcaddr=%s", mcastAddr), func(t *testing.T) { + conn.SendIP( + t, + testbench.IPv4{SrcAddr: testbench.Address(tcpip.Address(mcastAddr.To4()))}, + testbench.UDP{}, + ) + + ret, payload, errno := dut.RecvWithErrno(context.Background(), t, remoteFD, 100, 0) + if errno != syscall.EAGAIN || errno != syscall.EWOULDBLOCK { + t.Errorf("Recv got unexpected result, ret=%d, payload=%q, errno=%s", ret, payload, errno) + } + }) + } +} + +func TestDiscardsUDPPacketsWithMcastSourceAddressV6(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + remoteFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP(testbench.RemoteIPv6)) + defer dut.Close(t, remoteFD) + dut.SetSockOptTimeval(t, remoteFD, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &oneSecond) + conn := testbench.NewUDPIPv6(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) + defer conn.Close(t) + + for _, mcastAddr := range []net.IP{ + net.IPv6interfacelocalallnodes, + net.IPv6linklocalallnodes, + net.IPv6linklocalallrouters, + net.ParseIP("fe01::42"), + net.ParseIP("fe02::4242"), + } { + t.Run(fmt.Sprintf("srcaddr=%s", mcastAddr), func(t *testing.T) { + conn.SendIPv6( + t, + testbench.IPv6{SrcAddr: testbench.Address(tcpip.Address(mcastAddr.To16()))}, + testbench.UDP{}, + ) + ret, payload, errno := dut.RecvWithErrno(context.Background(), t, remoteFD, 100, 0) + if errno != syscall.EAGAIN || errno != syscall.EWOULDBLOCK { + t.Errorf("Recv got unexpected result, ret=%d, payload=%q, errno=%s", ret, payload, errno) + } + }) + } +} diff --git a/test/packetimpact/tests/udp_icmp_error_propagation_test.go b/test/packetimpact/tests/udp_icmp_error_propagation_test.go new file mode 100644 index 000000000..df35d16c8 --- /dev/null +++ b/test/packetimpact/tests/udp_icmp_error_propagation_test.go @@ -0,0 +1,363 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package udp_icmp_error_propagation_test + +import ( + "context" + "flag" + "fmt" + "net" + "sync" + "syscall" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +type connectionMode bool + +func (c connectionMode) String() string { + if c { + return "Connected" + } + return "Connectionless" +} + +type icmpError int + +const ( + portUnreachable icmpError = iota + timeToLiveExceeded +) + +func (e icmpError) String() string { + switch e { + case portUnreachable: + return "PortUnreachable" + case timeToLiveExceeded: + return "TimeToLiveExpired" + } + return "Unknown ICMP error" +} + +func (e icmpError) ToICMPv4() *testbench.ICMPv4 { + switch e { + case portUnreachable: + return &testbench.ICMPv4{ + Type: testbench.ICMPv4Type(header.ICMPv4DstUnreachable), + Code: testbench.ICMPv4Code(header.ICMPv4PortUnreachable)} + case timeToLiveExceeded: + return &testbench.ICMPv4{ + Type: testbench.ICMPv4Type(header.ICMPv4TimeExceeded), + Code: testbench.ICMPv4Code(header.ICMPv4TTLExceeded)} + } + return nil +} + +type errorDetection struct { + name string + useValidConn bool + f func(context.Context, *testing.T, testData) +} + +type testData struct { + dut *testbench.DUT + conn *testbench.UDPIPv4 + remoteFD int32 + remotePort uint16 + cleanFD int32 + cleanPort uint16 + wantErrno syscall.Errno +} + +// wantErrno computes the errno to expect given the connection mode of a UDP +// socket and the ICMP error it will receive. +func wantErrno(c connectionMode, icmpErr icmpError) syscall.Errno { + if c && icmpErr == portUnreachable { + return syscall.Errno(unix.ECONNREFUSED) + } + return syscall.Errno(0) +} + +// sendICMPError sends an ICMP error message in response to a UDP datagram. +func sendICMPError(t *testing.T, conn *testbench.UDPIPv4, icmpErr icmpError, udp *testbench.UDP) { + t.Helper() + + layers := (*testbench.Connection)(conn).CreateFrame(t, nil) + layers = layers[:len(layers)-1] + ip, ok := udp.Prev().(*testbench.IPv4) + if !ok { + t.Fatalf("expected %s to be IPv4", udp.Prev()) + } + if icmpErr == timeToLiveExceeded { + *ip.TTL = 1 + // Let serialization recalculate the checksum since we set the TTL + // to 1. + ip.Checksum = nil + } + // Note that the ICMP payload is valid in this case because the UDP + // payload is empty. If the UDP payload were not empty, the packet + // length during serialization may not be calculated correctly, + // resulting in a mal-formed packet. + layers = append(layers, icmpErr.ToICMPv4(), ip, udp) + + (*testbench.Connection)(conn).SendFrameStateless(t, layers) +} + +// testRecv tests observing the ICMP error through the recv syscall. A packet +// is sent to the DUT, and if wantErrno is non-zero, then the first recv should +// fail and the second should succeed. Otherwise if wantErrno is zero then the +// first recv should succeed immediately. +func testRecv(ctx context.Context, t *testing.T, d testData) { + t.Helper() + + // Check that receiving on the clean socket works. + d.conn.Send(t, testbench.UDP{DstPort: &d.cleanPort}) + d.dut.Recv(t, d.cleanFD, 100, 0) + + d.conn.Send(t, testbench.UDP{}) + + if d.wantErrno != syscall.Errno(0) { + ctx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + ret, _, err := d.dut.RecvWithErrno(ctx, t, d.remoteFD, 100, 0) + if ret != -1 { + t.Fatalf("recv after ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", d.wantErrno) + } + if err != d.wantErrno { + t.Fatalf("recv after ICMP error resulted in error (%[1]d) %[1]v, expected (%[2]d) %[2]v", err, d.wantErrno) + } + } + + d.dut.Recv(t, d.remoteFD, 100, 0) +} + +// testSendTo tests observing the ICMP error through the send syscall. If +// wantErrno is non-zero, the first send should fail and a subsequent send +// should suceed; while if wantErrno is zero then the first send should just +// succeed. +func testSendTo(ctx context.Context, t *testing.T, d testData) { + // Check that sending on the clean socket works. + d.dut.SendTo(t, d.cleanFD, nil, 0, d.conn.LocalAddr(t)) + if _, err := d.conn.Expect(t, testbench.UDP{SrcPort: &d.cleanPort}, time.Second); err != nil { + t.Fatalf("did not receive UDP packet from clean socket on DUT: %s", err) + } + + if d.wantErrno != syscall.Errno(0) { + ctx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + ret, err := d.dut.SendToWithErrno(ctx, t, d.remoteFD, nil, 0, d.conn.LocalAddr(t)) + + if ret != -1 { + t.Fatalf("sendto after ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", d.wantErrno) + } + if err != d.wantErrno { + t.Fatalf("sendto after ICMP error resulted in error (%[1]d) %[1]v, expected (%[2]d) %[2]v", err, d.wantErrno) + } + } + + d.dut.SendTo(t, d.remoteFD, nil, 0, d.conn.LocalAddr(t)) + if _, err := d.conn.Expect(t, testbench.UDP{}, time.Second); err != nil { + t.Fatalf("did not receive UDP packet as expected: %s", err) + } +} + +func testSockOpt(_ context.Context, t *testing.T, d testData) { + // Check that there's no pending error on the clean socket. + if errno := syscall.Errno(d.dut.GetSockOptInt(t, d.cleanFD, unix.SOL_SOCKET, unix.SO_ERROR)); errno != syscall.Errno(0) { + t.Fatalf("unexpected error (%[1]d) %[1]v on clean socket", errno) + } + + if errno := syscall.Errno(d.dut.GetSockOptInt(t, d.remoteFD, unix.SOL_SOCKET, unix.SO_ERROR)); errno != d.wantErrno { + t.Fatalf("SO_ERROR sockopt after ICMP error is (%[1]d) %[1]v, expected (%[2]d) %[2]v", errno, d.wantErrno) + } + + // Check that after clearing socket error, sending doesn't fail. + d.dut.SendTo(t, d.remoteFD, nil, 0, d.conn.LocalAddr(t)) + if _, err := d.conn.Expect(t, testbench.UDP{}, time.Second); err != nil { + t.Fatalf("did not receive UDP packet as expected: %s", err) + } +} + +// TestUDPICMPErrorPropagation tests that ICMP error messages in response to +// UDP datagrams are processed correctly. RFC 1122 section 4.1.3.3 states that: +// "UDP MUST pass to the application layer all ICMP error messages that it +// receives from the IP layer." +// +// The test cases are parametrized in 3 dimensions: 1. the UDP socket is either +// put into connection mode or left connectionless, 2. the ICMP message type +// and code, and 3. the method by which the ICMP error is observed on the +// socket: sendto, recv, or getsockopt(SO_ERROR). +// +// Linux's udp(7) man page states: "All fatal errors will be passed to the user +// as an error return even when the socket is not connected. This includes +// asynchronous errors received from the network." In practice, the only +// combination of parameters to the test that causes an error to be observable +// on the UDP socket is receiving a port unreachable message on a connected +// socket. +func TestUDPICMPErrorPropagation(t *testing.T) { + for _, connect := range []connectionMode{true, false} { + for _, icmpErr := range []icmpError{portUnreachable, timeToLiveExceeded} { + wantErrno := wantErrno(connect, icmpErr) + + for _, errDetect := range []errorDetection{ + errorDetection{"SendTo", false, testSendTo}, + // Send to an address that's different from the one that caused an ICMP + // error to be returned. + errorDetection{"SendToValid", true, testSendTo}, + errorDetection{"Recv", false, testRecv}, + errorDetection{"SockOpt", false, testSockOpt}, + } { + t.Run(fmt.Sprintf("%s/%s/%s", connect, icmpErr, errDetect.name), func(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + + remoteFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.IPv4zero) + defer dut.Close(t, remoteFD) + + // Create a second, clean socket on the DUT to ensure that the ICMP + // error messages only affect the sockets they are intended for. + cleanFD, cleanPort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.IPv4zero) + defer dut.Close(t, cleanFD) + + conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) + defer conn.Close(t) + + if connect { + dut.Connect(t, remoteFD, conn.LocalAddr(t)) + dut.Connect(t, cleanFD, conn.LocalAddr(t)) + } + + dut.SendTo(t, remoteFD, nil, 0, conn.LocalAddr(t)) + udp, err := conn.Expect(t, testbench.UDP{}, time.Second) + if err != nil { + t.Fatalf("did not receive message from DUT: %s", err) + } + + sendICMPError(t, &conn, icmpErr, udp) + + errDetectConn := &conn + if errDetect.useValidConn { + // connClean is a UDP socket on the test runner that was not + // involved in the generation of the ICMP error. As such, + // interactions between it and the the DUT should be independent of + // the ICMP error at least at the port level. + connClean := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) + defer connClean.Close(t) + + errDetectConn = &connClean + } + + errDetect.f(context.Background(), t, testData{&dut, errDetectConn, remoteFD, remotePort, cleanFD, cleanPort, wantErrno}) + }) + } + } + } +} + +// TestICMPErrorDuringUDPRecv tests behavior when a UDP socket is in the middle +// of a blocking recv and receives an ICMP error. +func TestICMPErrorDuringUDPRecv(t *testing.T) { + for _, connect := range []connectionMode{true, false} { + for _, icmpErr := range []icmpError{portUnreachable, timeToLiveExceeded} { + wantErrno := wantErrno(connect, icmpErr) + + t.Run(fmt.Sprintf("%s/%s", connect, icmpErr), func(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + + remoteFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.IPv4zero) + defer dut.Close(t, remoteFD) + + // Create a second, clean socket on the DUT to ensure that the ICMP + // error messages only affect the sockets they are intended for. + cleanFD, cleanPort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.IPv4zero) + defer dut.Close(t, cleanFD) + + conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) + defer conn.Close(t) + + if connect { + dut.Connect(t, remoteFD, conn.LocalAddr(t)) + dut.Connect(t, cleanFD, conn.LocalAddr(t)) + } + + dut.SendTo(t, remoteFD, nil, 0, conn.LocalAddr(t)) + udp, err := conn.Expect(t, testbench.UDP{}, time.Second) + if err != nil { + t.Fatalf("did not receive message from DUT: %s", err) + } + + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + + if wantErrno != syscall.Errno(0) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + ret, _, err := dut.RecvWithErrno(ctx, t, remoteFD, 100, 0) + if ret != -1 { + t.Errorf("recv during ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", wantErrno) + return + } + if err != wantErrno { + t.Errorf("recv during ICMP error resulted in error (%[1]d) %[1]v, expected (%[2]d) %[2]v", err, wantErrno) + return + } + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + if ret, _, err := dut.RecvWithErrno(ctx, t, remoteFD, 100, 0); ret == -1 { + t.Errorf("recv after ICMP error failed with (%[1]d) %[1]", err) + } + }() + + go func() { + defer wg.Done() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + if ret, _, err := dut.RecvWithErrno(ctx, t, cleanFD, 100, 0); ret == -1 { + t.Errorf("recv on clean socket failed with (%[1]d) %[1]", err) + } + }() + + // TODO(b/155684889) This sleep is to allow time for the DUT to + // actually call recv since we want the ICMP error to arrive during the + // blocking recv, and should be replaced when a better synchronization + // alternative is available. + time.Sleep(2 * time.Second) + + sendICMPError(t, &conn, icmpErr, udp) + + conn.Send(t, testbench.UDP{DstPort: &cleanPort}) + conn.Send(t, testbench.UDP{}) + wg.Wait() + }) + } + } +} diff --git a/test/packetimpact/tests/udp_recv_mcast_bcast_test.go b/test/packetimpact/tests/udp_recv_mcast_bcast_test.go new file mode 100644 index 000000000..526173969 --- /dev/null +++ b/test/packetimpact/tests/udp_recv_mcast_bcast_test.go @@ -0,0 +1,110 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package udp_recv_mcast_bcast_test + +import ( + "context" + "flag" + "fmt" + "net" + "syscall" + "testing" + + "github.com/google/go-cmp/cmp" + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +func TestUDPRecvMcastBcast(t *testing.T) { + subnetBcastAddr := broadcastAddr(net.ParseIP(testbench.RemoteIPv4), net.CIDRMask(testbench.IPv4PrefixLength, 32)) + + for _, v := range []struct { + bound, to net.IP + }{ + {bound: net.IPv4zero, to: subnetBcastAddr}, + {bound: net.IPv4zero, to: net.IPv4bcast}, + {bound: net.IPv4zero, to: net.IPv4allsys}, + + {bound: subnetBcastAddr, to: subnetBcastAddr}, + {bound: subnetBcastAddr, to: net.IPv4bcast}, + + {bound: net.IPv4bcast, to: net.IPv4bcast}, + {bound: net.IPv4allsys, to: net.IPv4allsys}, + } { + t.Run(fmt.Sprintf("bound=%s,to=%s", v.bound, v.to), func(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + boundFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, v.bound) + defer dut.Close(t, boundFD) + conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) + defer conn.Close(t) + + payload := testbench.GenerateRandomPayload(t, 1<<10 /* 1 KiB */) + conn.SendIP( + t, + testbench.IPv4{DstAddr: testbench.Address(tcpip.Address(v.to.To4()))}, + testbench.UDP{}, + &testbench.Payload{Bytes: payload}, + ) + got, want := dut.Recv(t, boundFD, int32(len(payload)+1), 0), payload + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("received payload does not match sent payload, diff (-want, +got):\n%s", diff) + } + }) + } +} + +func TestUDPDoesntRecvMcastBcastOnUnicastAddr(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + boundFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP(testbench.RemoteIPv4)) + dut.SetSockOptTimeval(t, boundFD, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &unix.Timeval{Sec: 1, Usec: 0}) + defer dut.Close(t, boundFD) + conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) + defer conn.Close(t) + + for _, to := range []net.IP{ + broadcastAddr(net.ParseIP(testbench.RemoteIPv4), net.CIDRMask(testbench.IPv4PrefixLength, 32)), + net.IPv4(255, 255, 255, 255), + net.IPv4(224, 0, 0, 1), + } { + t.Run(fmt.Sprint("to=%s", to), func(t *testing.T) { + payload := testbench.GenerateRandomPayload(t, 1<<10 /* 1 KiB */) + conn.SendIP( + t, + testbench.IPv4{DstAddr: testbench.Address(tcpip.Address(to.To4()))}, + testbench.UDP{}, + &testbench.Payload{Bytes: payload}, + ) + ret, payload, errno := dut.RecvWithErrno(context.Background(), t, boundFD, 100, 0) + if errno != syscall.EAGAIN || errno != syscall.EWOULDBLOCK { + t.Errorf("Recv got unexpected result, ret=%d, payload=%q, errno=%s", ret, payload, errno) + } + }) + } +} + +func broadcastAddr(ip net.IP, mask net.IPMask) net.IP { + ip4 := ip.To4() + for i := range ip4 { + ip4[i] |= ^mask[i] + } + return ip4 +} diff --git a/test/packetimpact/tests/udp_send_recv_dgram_test.go b/test/packetimpact/tests/udp_send_recv_dgram_test.go new file mode 100644 index 000000000..91b967400 --- /dev/null +++ b/test/packetimpact/tests/udp_send_recv_dgram_test.go @@ -0,0 +1,104 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package udp_send_recv_dgram_test + +import ( + "flag" + "net" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +type udpConn interface { + Send(*testing.T, testbench.UDP, ...testbench.Layer) + ExpectData(*testing.T, testbench.UDP, testbench.Payload, time.Duration) (testbench.Layers, error) + Drain(*testing.T) + Close(*testing.T) +} + +func TestUDP(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + + for _, isIPv4 := range []bool{true, false} { + ipVersionName := "IPv6" + if isIPv4 { + ipVersionName = "IPv4" + } + t.Run(ipVersionName, func(t *testing.T) { + var addr string + if isIPv4 { + addr = testbench.RemoteIPv4 + } else { + addr = testbench.RemoteIPv6 + } + boundFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP(addr)) + defer dut.Close(t, boundFD) + + var conn udpConn + var localAddr unix.Sockaddr + if isIPv4 { + v4Conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) + localAddr = v4Conn.LocalAddr(t) + conn = &v4Conn + } else { + v6Conn := testbench.NewUDPIPv6(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) + localAddr = v6Conn.LocalAddr(t) + conn = &v6Conn + } + defer conn.Close(t) + + testCases := []struct { + name string + payload []byte + }{ + {"emptypayload", nil}, + {"small payload", []byte("hello world")}, + {"1kPayload", testbench.GenerateRandomPayload(t, 1<<10)}, + // Even though UDP allows larger dgrams we don't test it here as + // they need to be fragmented and written out as individual + // frames. + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Run("Send", func(t *testing.T) { + conn.Send(t, testbench.UDP{}, &testbench.Payload{Bytes: tc.payload}) + got, want := dut.Recv(t, boundFD, int32(len(tc.payload)+1), 0), tc.payload + if diff := cmp.Diff(want, got); diff != "" { + t.Fatalf("received payload does not match sent payload, diff (-want, +got):\n%s", diff) + } + }) + t.Run("Recv", func(t *testing.T) { + conn.Drain(t) + if got, want := int(dut.SendTo(t, boundFD, tc.payload, 0, localAddr)), len(tc.payload); got != want { + t.Fatalf("short write got: %d, want: %d", got, want) + } + if _, err := conn.ExpectData(t, testbench.UDP{SrcPort: &remotePort}, testbench.Payload{Bytes: tc.payload}, time.Second); err != nil { + t.Fatal(err) + } + }) + }) + } + }) + } +} diff --git a/test/perf/BUILD b/test/perf/BUILD new file mode 100644 index 000000000..471d8c2ab --- /dev/null +++ b/test/perf/BUILD @@ -0,0 +1,117 @@ +load("//test/runner:defs.bzl", "syscall_test") + +package(licenses = ["notice"]) + +syscall_test( + test = "//test/perf/linux:clock_getres_benchmark", +) + +syscall_test( + test = "//test/perf/linux:clock_gettime_benchmark", +) + +syscall_test( + test = "//test/perf/linux:death_benchmark", +) + +syscall_test( + test = "//test/perf/linux:epoll_benchmark", +) + +syscall_test( + size = "large", + test = "//test/perf/linux:fork_benchmark", +) + +syscall_test( + size = "large", + test = "//test/perf/linux:futex_benchmark", +) + +syscall_test( + size = "enormous", + shard_count = 10, + tags = ["nogotsan"], + test = "//test/perf/linux:getdents_benchmark", +) + +syscall_test( + size = "large", + test = "//test/perf/linux:getpid_benchmark", +) + +syscall_test( + size = "enormous", + tags = ["nogotsan"], + test = "//test/perf/linux:gettid_benchmark", +) + +syscall_test( + size = "large", + test = "//test/perf/linux:mapping_benchmark", +) + +syscall_test( + size = "large", + add_overlay = True, + test = "//test/perf/linux:open_benchmark", +) + +syscall_test( + test = "//test/perf/linux:pipe_benchmark", +) + +syscall_test( + size = "large", + add_overlay = True, + test = "//test/perf/linux:randread_benchmark", +) + +syscall_test( + size = "large", + add_overlay = True, + test = "//test/perf/linux:read_benchmark", +) + +syscall_test( + size = "large", + test = "//test/perf/linux:sched_yield_benchmark", +) + +syscall_test( + size = "large", + test = "//test/perf/linux:send_recv_benchmark", +) + +syscall_test( + size = "large", + add_overlay = True, + test = "//test/perf/linux:seqwrite_benchmark", +) + +syscall_test( + size = "enormous", + test = "//test/perf/linux:signal_benchmark", +) + +syscall_test( + test = "//test/perf/linux:sleep_benchmark", +) + +syscall_test( + size = "large", + add_overlay = True, + test = "//test/perf/linux:stat_benchmark", +) + +syscall_test( + size = "enormous", + add_overlay = True, + test = "//test/perf/linux:unlink_benchmark", +) + +syscall_test( + size = "large", + add_overlay = True, + test = "//test/perf/linux:write_benchmark", +) diff --git a/test/perf/linux/BUILD b/test/perf/linux/BUILD new file mode 100644 index 000000000..b4e907826 --- /dev/null +++ b/test/perf/linux/BUILD @@ -0,0 +1,356 @@ +load("//tools:defs.bzl", "cc_binary", "gbenchmark", "gtest") + +package( + default_visibility = ["//:sandbox"], + licenses = ["notice"], +) + +cc_binary( + name = "getpid_benchmark", + testonly = 1, + srcs = [ + "getpid_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:test_main", + ], +) + +cc_binary( + name = "send_recv_benchmark", + testonly = 1, + srcs = [ + "send_recv_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/syscalls/linux:socket_test_util", + "//test/util:file_descriptor", + "//test/util:logging", + "//test/util:posix_error", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:thread_util", + "@com_google_absl//absl/synchronization", + ], +) + +cc_binary( + name = "gettid_benchmark", + testonly = 1, + srcs = [ + "gettid_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:test_main", + ], +) + +cc_binary( + name = "sched_yield_benchmark", + testonly = 1, + srcs = [ + "sched_yield_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "clock_getres_benchmark", + testonly = 1, + srcs = [ + "clock_getres_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:test_main", + ], +) + +cc_binary( + name = "clock_gettime_benchmark", + testonly = 1, + srcs = [ + "clock_gettime_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:test_main", + "@com_google_absl//absl/time", + ], +) + +cc_binary( + name = "open_benchmark", + testonly = 1, + srcs = [ + "open_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:fs_util", + "//test/util:logging", + "//test/util:temp_path", + "//test/util:test_main", + ], +) + +cc_binary( + name = "read_benchmark", + testonly = 1, + srcs = [ + "read_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:fs_util", + "//test/util:logging", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "randread_benchmark", + testonly = 1, + srcs = [ + "randread_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:file_descriptor", + "//test/util:logging", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_absl//absl/random", + ], +) + +cc_binary( + name = "write_benchmark", + testonly = 1, + srcs = [ + "write_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:logging", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "seqwrite_benchmark", + testonly = 1, + srcs = [ + "seqwrite_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:logging", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_absl//absl/random", + ], +) + +cc_binary( + name = "pipe_benchmark", + testonly = 1, + srcs = [ + "pipe_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:logging", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:thread_util", + ], +) + +cc_binary( + name = "fork_benchmark", + testonly = 1, + srcs = [ + "fork_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:cleanup", + "//test/util:file_descriptor", + "//test/util:logging", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:thread_util", + "@com_google_absl//absl/synchronization", + ], +) + +cc_binary( + name = "futex_benchmark", + testonly = 1, + srcs = [ + "futex_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:logging", + "//test/util:test_main", + "//test/util:thread_util", + "@com_google_absl//absl/time", + ], +) + +cc_binary( + name = "epoll_benchmark", + testonly = 1, + srcs = [ + "epoll_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:epoll_util", + "//test/util:file_descriptor", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:thread_util", + "@com_google_absl//absl/time", + ], +) + +cc_binary( + name = "death_benchmark", + testonly = 1, + srcs = [ + "death_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:logging", + "//test/util:test_main", + ], +) + +cc_binary( + name = "mapping_benchmark", + testonly = 1, + srcs = [ + "mapping_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:logging", + "//test/util:memory_util", + "//test/util:posix_error", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "signal_benchmark", + testonly = 1, + srcs = [ + "signal_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:logging", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "getdents_benchmark", + testonly = 1, + srcs = [ + "getdents_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:file_descriptor", + "//test/util:fs_util", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "sleep_benchmark", + testonly = 1, + srcs = [ + "sleep_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:logging", + "//test/util:test_main", + ], +) + +cc_binary( + name = "stat_benchmark", + testonly = 1, + srcs = [ + "stat_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:fs_util", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_absl//absl/strings", + ], +) + +cc_binary( + name = "unlink_benchmark", + testonly = 1, + srcs = [ + "unlink_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:fs_util", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + ], +) diff --git a/test/perf/linux/clock_getres_benchmark.cc b/test/perf/linux/clock_getres_benchmark.cc new file mode 100644 index 000000000..b051293ad --- /dev/null +++ b/test/perf/linux/clock_getres_benchmark.cc @@ -0,0 +1,39 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <time.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" + +namespace gvisor { +namespace testing { + +namespace { + +// clock_getres(1) is very nearly a no-op syscall, but it does require copying +// out to a userspace struct. It thus provides a nice small copy-out benchmark. +void BM_ClockGetRes(benchmark::State& state) { + struct timespec ts; + for (auto _ : state) { + clock_getres(CLOCK_MONOTONIC, &ts); + } +} + +BENCHMARK(BM_ClockGetRes); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/clock_gettime_benchmark.cc b/test/perf/linux/clock_gettime_benchmark.cc new file mode 100644 index 000000000..6691bebd9 --- /dev/null +++ b/test/perf/linux/clock_gettime_benchmark.cc @@ -0,0 +1,60 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <pthread.h> +#include <time.h> + +#include "gtest/gtest.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "benchmark/benchmark.h" + +namespace gvisor { +namespace testing { + +namespace { + +void BM_ClockGettimeThreadCPUTime(benchmark::State& state) { + clockid_t clockid; + ASSERT_EQ(0, pthread_getcpuclockid(pthread_self(), &clockid)); + struct timespec tp; + + for (auto _ : state) { + clock_gettime(clockid, &tp); + } +} + +BENCHMARK(BM_ClockGettimeThreadCPUTime); + +void BM_VDSOClockGettime(benchmark::State& state) { + const clockid_t clock = state.range(0); + struct timespec tp; + absl::Time start = absl::Now(); + + // Don't benchmark the calibration phase. + while (absl::Now() < start + absl::Milliseconds(2100)) { + clock_gettime(clock, &tp); + } + + for (auto _ : state) { + clock_gettime(clock, &tp); + } +} + +BENCHMARK(BM_VDSOClockGettime)->Arg(CLOCK_MONOTONIC)->Arg(CLOCK_REALTIME); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/death_benchmark.cc b/test/perf/linux/death_benchmark.cc new file mode 100644 index 000000000..cb2b6fd07 --- /dev/null +++ b/test/perf/linux/death_benchmark.cc @@ -0,0 +1,36 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <signal.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/logging.h" + +namespace gvisor { +namespace testing { + +namespace { + +// DeathTest is not so much a microbenchmark as a macrobenchmark. It is testing +// the ability of gVisor (on whatever platform) to execute all the related +// stack-dumping routines associated with EXPECT_EXIT / EXPECT_DEATH. +TEST(DeathTest, ZeroEqualsOne) { + EXPECT_EXIT({ TEST_CHECK(0 == 1); }, ::testing::KilledBySignal(SIGABRT), ""); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/epoll_benchmark.cc b/test/perf/linux/epoll_benchmark.cc new file mode 100644 index 000000000..0b121338a --- /dev/null +++ b/test/perf/linux/epoll_benchmark.cc @@ -0,0 +1,99 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <sys/epoll.h> +#include <sys/eventfd.h> + +#include <atomic> +#include <cerrno> +#include <cstdint> +#include <cstdlib> +#include <ctime> +#include <memory> + +#include "gtest/gtest.h" +#include "absl/time/time.h" +#include "benchmark/benchmark.h" +#include "test/util/epoll_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +// Returns a new eventfd. +PosixErrorOr<FileDescriptor> NewEventFD() { + int fd = eventfd(0, /* flags = */ 0); + MaybeSave(); + if (fd < 0) { + return PosixError(errno, "eventfd"); + } + return FileDescriptor(fd); +} + +// Also stolen from epoll.cc unit tests. +void BM_EpollTimeout(benchmark::State& state) { + constexpr int kFDsPerEpoll = 3; + auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD()); + + std::vector<FileDescriptor> eventfds; + for (int i = 0; i < kFDsPerEpoll; i++) { + eventfds.push_back(ASSERT_NO_ERRNO_AND_VALUE(NewEventFD())); + ASSERT_NO_ERRNO( + RegisterEpollFD(epollfd.get(), eventfds[i].get(), EPOLLIN, 0)); + } + + struct epoll_event result[kFDsPerEpoll]; + int timeout_ms = state.range(0); + + for (auto _ : state) { + EXPECT_EQ(0, epoll_wait(epollfd.get(), result, kFDsPerEpoll, timeout_ms)); + } +} + +BENCHMARK(BM_EpollTimeout)->Range(0, 8); + +// Also stolen from epoll.cc unit tests. +void BM_EpollAllEvents(benchmark::State& state) { + auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD()); + const int fds_per_epoll = state.range(0); + constexpr uint64_t kEventVal = 5; + + std::vector<FileDescriptor> eventfds; + for (int i = 0; i < fds_per_epoll; i++) { + eventfds.push_back(ASSERT_NO_ERRNO_AND_VALUE(NewEventFD())); + ASSERT_NO_ERRNO( + RegisterEpollFD(epollfd.get(), eventfds[i].get(), EPOLLIN, 0)); + + ASSERT_THAT(WriteFd(eventfds[i].get(), &kEventVal, sizeof(kEventVal)), + SyscallSucceedsWithValue(sizeof(kEventVal))); + } + + std::vector<struct epoll_event> result(fds_per_epoll); + + for (auto _ : state) { + EXPECT_EQ(fds_per_epoll, + epoll_wait(epollfd.get(), result.data(), fds_per_epoll, 0)); + } +} + +BENCHMARK(BM_EpollAllEvents)->Range(2, 1024); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/fork_benchmark.cc b/test/perf/linux/fork_benchmark.cc new file mode 100644 index 000000000..84fdbc8a0 --- /dev/null +++ b/test/perf/linux/fork_benchmark.cc @@ -0,0 +1,350 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <unistd.h> + +#include "gtest/gtest.h" +#include "absl/synchronization/barrier.h" +#include "benchmark/benchmark.h" +#include "test/util/cleanup.h" +#include "test/util/file_descriptor.h" +#include "test/util/logging.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +constexpr int kBusyMax = 250; + +// Do some CPU-bound busy-work. +int busy(int max) { + // Prevent the compiler from optimizing this work away, + volatile int count = 0; + + for (int i = 1; i < max; i++) { + for (int j = 2; j < i / 2; j++) { + if (i % j == 0) { + count++; + } + } + } + + return count; +} + +void BM_CPUBoundUniprocess(benchmark::State& state) { + for (auto _ : state) { + busy(kBusyMax); + } +} + +BENCHMARK(BM_CPUBoundUniprocess); + +void BM_CPUBoundAsymmetric(benchmark::State& state) { + const size_t max = state.max_iterations; + pid_t child = fork(); + if (child == 0) { + for (int i = 0; i < max; i++) { + busy(kBusyMax); + } + _exit(0); + } + ASSERT_THAT(child, SyscallSucceeds()); + ASSERT_TRUE(state.KeepRunningBatch(max)); + + int status; + EXPECT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds()); + EXPECT_TRUE(WIFEXITED(status)); + EXPECT_EQ(0, WEXITSTATUS(status)); + ASSERT_FALSE(state.KeepRunning()); +} + +BENCHMARK(BM_CPUBoundAsymmetric)->UseRealTime(); + +void BM_CPUBoundSymmetric(benchmark::State& state) { + std::vector<pid_t> children; + auto child_cleanup = Cleanup([&] { + for (const pid_t child : children) { + int status; + EXPECT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds()); + EXPECT_TRUE(WIFEXITED(status)); + EXPECT_EQ(0, WEXITSTATUS(status)); + } + ASSERT_FALSE(state.KeepRunning()); + }); + + const int processes = state.range(0); + for (int i = 0; i < processes; i++) { + size_t cur = (state.max_iterations + (processes - 1)) / processes; + if ((state.iterations() + cur) >= state.max_iterations) { + cur = state.max_iterations - state.iterations(); + } + pid_t child = fork(); + if (child == 0) { + for (int i = 0; i < cur; i++) { + busy(kBusyMax); + } + _exit(0); + } + ASSERT_THAT(child, SyscallSucceeds()); + if (cur > 0) { + // We can have a zero cur here, depending. + ASSERT_TRUE(state.KeepRunningBatch(cur)); + } + children.push_back(child); + } +} + +BENCHMARK(BM_CPUBoundSymmetric)->Range(2, 16)->UseRealTime(); + +// Child routine for ProcessSwitch/ThreadSwitch. +// Reads from readfd and writes the result to writefd. +void SwitchChild(int readfd, int writefd) { + while (1) { + char buf; + int ret = ReadFd(readfd, &buf, 1); + if (ret == 0) { + break; + } + TEST_CHECK_MSG(ret == 1, "read failed"); + + ret = WriteFd(writefd, &buf, 1); + if (ret == -1) { + TEST_CHECK_MSG(errno == EPIPE, "unexpected write failure"); + break; + } + TEST_CHECK_MSG(ret == 1, "write failed"); + } +} + +// Send bytes in a loop through a series of pipes, each passing through a +// different process. +// +// Proc 0 Proc 1 +// * ----------> * +// ^ Pipe 1 | +// | | +// | Pipe 0 | Pipe 2 +// | | +// | | +// | Pipe 3 v +// * <---------- * +// Proc 3 Proc 2 +// +// This exercises context switching through multiple processes. +void BM_ProcessSwitch(benchmark::State& state) { + // Code below assumes there are at least two processes. + const int num_processes = state.range(0); + ASSERT_GE(num_processes, 2); + + std::vector<pid_t> children; + auto child_cleanup = Cleanup([&] { + for (const pid_t child : children) { + int status; + EXPECT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds()); + EXPECT_TRUE(WIFEXITED(status)); + EXPECT_EQ(0, WEXITSTATUS(status)); + } + }); + + // Must come after children, as the FDs must be closed before the children + // will exit. + std::vector<FileDescriptor> read_fds; + std::vector<FileDescriptor> write_fds; + + for (int i = 0; i < num_processes; i++) { + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + read_fds.emplace_back(fds[0]); + write_fds.emplace_back(fds[1]); + } + + // This process is one of the processes in the loop. It will be considered + // index 0. + for (int i = 1; i < num_processes; i++) { + // Read from current pipe index, write to next. + const int read_index = i; + const int read_fd = read_fds[read_index].get(); + + const int write_index = (i + 1) % num_processes; + const int write_fd = write_fds[write_index].get(); + + // std::vector isn't safe to use from the fork child. + FileDescriptor* read_array = read_fds.data(); + FileDescriptor* write_array = write_fds.data(); + + pid_t child = fork(); + if (!child) { + // Close all other FDs. + for (int j = 0; j < num_processes; j++) { + if (j != read_index) { + read_array[j].reset(); + } + if (j != write_index) { + write_array[j].reset(); + } + } + + SwitchChild(read_fd, write_fd); + _exit(0); + } + ASSERT_THAT(child, SyscallSucceeds()); + children.push_back(child); + } + + // Read from current pipe index (0), write to next (1). + const int read_index = 0; + const int read_fd = read_fds[read_index].get(); + + const int write_index = 1; + const int write_fd = write_fds[write_index].get(); + + // Kick start the loop. + char buf = 'a'; + ASSERT_THAT(WriteFd(write_fd, &buf, 1), SyscallSucceedsWithValue(1)); + + for (auto _ : state) { + ASSERT_THAT(ReadFd(read_fd, &buf, 1), SyscallSucceedsWithValue(1)); + ASSERT_THAT(WriteFd(write_fd, &buf, 1), SyscallSucceedsWithValue(1)); + } +} + +BENCHMARK(BM_ProcessSwitch)->Range(2, 16)->UseRealTime(); + +// Equivalent to BM_ThreadSwitch using threads instead of processes. +void BM_ThreadSwitch(benchmark::State& state) { + // Code below assumes there are at least two threads. + const int num_threads = state.range(0); + ASSERT_GE(num_threads, 2); + + // Must come after threads, as the FDs must be closed before the children + // will exit. + std::vector<std::unique_ptr<ScopedThread>> threads; + std::vector<FileDescriptor> read_fds; + std::vector<FileDescriptor> write_fds; + + for (int i = 0; i < num_threads; i++) { + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + read_fds.emplace_back(fds[0]); + write_fds.emplace_back(fds[1]); + } + + // This thread is one of the threads in the loop. It will be considered + // index 0. + for (int i = 1; i < num_threads; i++) { + // Read from current pipe index, write to next. + // + // Transfer ownership of the FDs to the thread. + const int read_index = i; + const int read_fd = read_fds[read_index].release(); + + const int write_index = (i + 1) % num_threads; + const int write_fd = write_fds[write_index].release(); + + threads.emplace_back(std::make_unique<ScopedThread>([read_fd, write_fd] { + FileDescriptor read(read_fd); + FileDescriptor write(write_fd); + SwitchChild(read.get(), write.get()); + })); + } + + // Read from current pipe index (0), write to next (1). + const int read_index = 0; + const int read_fd = read_fds[read_index].get(); + + const int write_index = 1; + const int write_fd = write_fds[write_index].get(); + + // Kick start the loop. + char buf = 'a'; + ASSERT_THAT(WriteFd(write_fd, &buf, 1), SyscallSucceedsWithValue(1)); + + for (auto _ : state) { + ASSERT_THAT(ReadFd(read_fd, &buf, 1), SyscallSucceedsWithValue(1)); + ASSERT_THAT(WriteFd(write_fd, &buf, 1), SyscallSucceedsWithValue(1)); + } + + // The two FDs still owned by this thread are closed, causing the next thread + // to exit its loop and close its FDs, and so on until all threads exit. +} + +BENCHMARK(BM_ThreadSwitch)->Range(2, 16)->UseRealTime(); + +void BM_ThreadStart(benchmark::State& state) { + const int num_threads = state.range(0); + + for (auto _ : state) { + state.PauseTiming(); + + auto barrier = new absl::Barrier(num_threads + 1); + std::vector<std::unique_ptr<ScopedThread>> threads; + + state.ResumeTiming(); + + for (size_t i = 0; i < num_threads; ++i) { + threads.emplace_back(std::make_unique<ScopedThread>([barrier] { + if (barrier->Block()) { + delete barrier; + } + })); + } + + if (barrier->Block()) { + delete barrier; + } + + state.PauseTiming(); + + for (const auto& thread : threads) { + thread->Join(); + } + + state.ResumeTiming(); + } +} + +BENCHMARK(BM_ThreadStart)->Range(1, 2048)->UseRealTime(); + +// Benchmark the complete fork + exit + wait. +void BM_ProcessLifecycle(benchmark::State& state) { + const int num_procs = state.range(0); + + std::vector<pid_t> pids(num_procs); + for (auto _ : state) { + for (size_t i = 0; i < num_procs; ++i) { + int pid = fork(); + if (pid == 0) { + _exit(0); + } + ASSERT_THAT(pid, SyscallSucceeds()); + pids[i] = pid; + } + + for (const int pid : pids) { + ASSERT_THAT(RetryEINTR(waitpid)(pid, nullptr, 0), + SyscallSucceedsWithValue(pid)); + } + } +} + +BENCHMARK(BM_ProcessLifecycle)->Range(1, 512)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/futex_benchmark.cc b/test/perf/linux/futex_benchmark.cc new file mode 100644 index 000000000..e686041c9 --- /dev/null +++ b/test/perf/linux/futex_benchmark.cc @@ -0,0 +1,198 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <linux/futex.h> + +#include <atomic> +#include <cerrno> +#include <cstdint> +#include <cstdlib> +#include <ctime> + +#include "gtest/gtest.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "benchmark/benchmark.h" +#include "test/util/logging.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +inline int FutexWait(std::atomic<int32_t>* v, int32_t val) { + return syscall(SYS_futex, v, FUTEX_WAIT_PRIVATE, val, nullptr); +} + +inline int FutexWaitMonotonicTimeout(std::atomic<int32_t>* v, int32_t val, + const struct timespec* timeout) { + return syscall(SYS_futex, v, FUTEX_WAIT_PRIVATE, val, timeout); +} + +inline int FutexWaitMonotonicDeadline(std::atomic<int32_t>* v, int32_t val, + const struct timespec* deadline) { + return syscall(SYS_futex, v, FUTEX_WAIT_BITSET_PRIVATE, val, deadline, + nullptr, FUTEX_BITSET_MATCH_ANY); +} + +inline int FutexWaitRealtimeDeadline(std::atomic<int32_t>* v, int32_t val, + const struct timespec* deadline) { + return syscall(SYS_futex, v, FUTEX_WAIT_BITSET_PRIVATE | FUTEX_CLOCK_REALTIME, + val, deadline, nullptr, FUTEX_BITSET_MATCH_ANY); +} + +inline int FutexWake(std::atomic<int32_t>* v, int32_t count) { + return syscall(SYS_futex, v, FUTEX_WAKE_PRIVATE, count); +} + +// This just uses FUTEX_WAKE on an address with nothing waiting, very simple. +void BM_FutexWakeNop(benchmark::State& state) { + std::atomic<int32_t> v(0); + + for (auto _ : state) { + TEST_PCHECK(FutexWake(&v, 1) == 0); + } +} + +BENCHMARK(BM_FutexWakeNop)->MinTime(5); + +// This just uses FUTEX_WAIT on an address whose value has changed, i.e., the +// syscall won't wait. +void BM_FutexWaitNop(benchmark::State& state) { + std::atomic<int32_t> v(0); + + for (auto _ : state) { + TEST_PCHECK(FutexWait(&v, 1) == -1 && errno == EAGAIN); + } +} + +BENCHMARK(BM_FutexWaitNop)->MinTime(5); + +// This uses FUTEX_WAIT with a timeout on an address whose value never +// changes, such that it always times out. Timeout overhead can be estimated by +// timer overruns for short timeouts. +void BM_FutexWaitMonotonicTimeout(benchmark::State& state) { + const absl::Duration timeout = absl::Nanoseconds(state.range(0)); + std::atomic<int32_t> v(0); + auto ts = absl::ToTimespec(timeout); + + for (auto _ : state) { + TEST_PCHECK(FutexWaitMonotonicTimeout(&v, 0, &ts) == -1 && + errno == ETIMEDOUT); + } +} + +BENCHMARK(BM_FutexWaitMonotonicTimeout) + ->MinTime(5) + ->UseRealTime() + ->Arg(1) + ->Arg(10) + ->Arg(100) + ->Arg(1000) + ->Arg(10000); + +// This uses FUTEX_WAIT_BITSET with a deadline that is in the past. This allows +// estimation of the overhead of setting up a timer for a deadline (as opposed +// to a timeout as specified for FUTEX_WAIT). +void BM_FutexWaitMonotonicDeadline(benchmark::State& state) { + std::atomic<int32_t> v(0); + struct timespec ts = {}; + + for (auto _ : state) { + TEST_PCHECK(FutexWaitMonotonicDeadline(&v, 0, &ts) == -1 && + errno == ETIMEDOUT); + } +} + +BENCHMARK(BM_FutexWaitMonotonicDeadline)->MinTime(5); + +// This is equivalent to BM_FutexWaitMonotonicDeadline, but uses CLOCK_REALTIME +// instead of CLOCK_MONOTONIC for the deadline. +void BM_FutexWaitRealtimeDeadline(benchmark::State& state) { + std::atomic<int32_t> v(0); + struct timespec ts = {}; + + for (auto _ : state) { + TEST_PCHECK(FutexWaitRealtimeDeadline(&v, 0, &ts) == -1 && + errno == ETIMEDOUT); + } +} + +BENCHMARK(BM_FutexWaitRealtimeDeadline)->MinTime(5); + +int64_t GetCurrentMonotonicTimeNanos() { + struct timespec ts; + TEST_CHECK(clock_gettime(CLOCK_MONOTONIC, &ts) != -1); + return ts.tv_sec * 1000000000ULL + ts.tv_nsec; +} + +void SpinNanos(int64_t delay_ns) { + if (delay_ns <= 0) { + return; + } + const int64_t end = GetCurrentMonotonicTimeNanos() + delay_ns; + while (GetCurrentMonotonicTimeNanos() < end) { + // spin + } +} + +// Each iteration of FutexRoundtripDelayed involves a thread sending a futex +// wakeup to another thread, which spins for delay_us and then sends a futex +// wakeup back. The time per iteration is 2 * (delay_us + kBeforeWakeDelayNs + +// futex/scheduling overhead). +void BM_FutexRoundtripDelayed(benchmark::State& state) { + const int delay_us = state.range(0); + const int64_t delay_ns = delay_us * 1000; + // Spin for an extra kBeforeWakeDelayNs before invoking FUTEX_WAKE to reduce + // the probability that the wakeup comes before the wait, preventing the wait + // from ever taking effect and causing the benchmark to underestimate the + // actual wakeup time. + constexpr int64_t kBeforeWakeDelayNs = 500; + std::atomic<int32_t> v(0); + ScopedThread t([&] { + for (int i = 0; i < state.max_iterations; i++) { + SpinNanos(delay_ns); + while (v.load(std::memory_order_acquire) == 0) { + FutexWait(&v, 0); + } + SpinNanos(kBeforeWakeDelayNs + delay_ns); + v.store(0, std::memory_order_release); + FutexWake(&v, 1); + } + }); + for (auto _ : state) { + SpinNanos(kBeforeWakeDelayNs + delay_ns); + v.store(1, std::memory_order_release); + FutexWake(&v, 1); + SpinNanos(delay_ns); + while (v.load(std::memory_order_acquire) == 1) { + FutexWait(&v, 1); + } + } +} + +BENCHMARK(BM_FutexRoundtripDelayed) + ->MinTime(5) + ->UseRealTime() + ->Arg(0) + ->Arg(10) + ->Arg(20) + ->Arg(50) + ->Arg(100); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/getdents_benchmark.cc b/test/perf/linux/getdents_benchmark.cc new file mode 100644 index 000000000..d8e81fa8c --- /dev/null +++ b/test/perf/linux/getdents_benchmark.cc @@ -0,0 +1,149 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <sys/stat.h> +#include <sys/types.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/file_descriptor.h" +#include "test/util/fs_util.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +#ifndef SYS_getdents64 +#if defined(__x86_64__) +#define SYS_getdents64 217 +#elif defined(__aarch64__) +#define SYS_getdents64 217 +#else +#error "Unknown architecture" +#endif +#endif // SYS_getdents64 + +namespace gvisor { +namespace testing { + +namespace { + +constexpr int kBufferSize = 65536; + +PosixErrorOr<TempPath> CreateDirectory(int count, + std::vector<std::string>* files) { + ASSIGN_OR_RETURN_ERRNO(TempPath dir, TempPath::CreateDir()); + + ASSIGN_OR_RETURN_ERRNO(FileDescriptor dfd, + Open(dir.path(), O_RDONLY | O_DIRECTORY)); + + for (int i = 0; i < count; i++) { + auto file = NewTempRelPath(); + auto res = MknodAt(dfd, file, S_IFREG | 0644, 0); + RETURN_IF_ERRNO(res); + files->push_back(file); + } + + return std::move(dir); +} + +PosixError CleanupDirectory(const TempPath& dir, + std::vector<std::string>* files) { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor dfd, + Open(dir.path(), O_RDONLY | O_DIRECTORY)); + + for (auto it = files->begin(); it != files->end(); ++it) { + auto res = UnlinkAt(dfd, *it, 0); + RETURN_IF_ERRNO(res); + } + return NoError(); +} + +// Creates a directory containing `files` files, and reads all the directory +// entries from the directory using a single FD. +void BM_GetdentsSameFD(benchmark::State& state) { + // Create directory with given files. + const int count = state.range(0); + + // Keep a vector of all of the file TempPaths that is destroyed before dir. + // + // Normally, we'd simply allow dir to recursively clean up the contained + // files, but that recursive cleanup uses getdents, which may be very slow in + // extreme benchmarks. + TempPath dir; + std::vector<std::string> files; + dir = ASSERT_NO_ERRNO_AND_VALUE(CreateDirectory(count, &files)); + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY | O_DIRECTORY)); + char buffer[kBufferSize]; + + // We read all directory entries on each iteration, but report this as a + // "batch" iteration so that reported times are per file. + while (state.KeepRunningBatch(count)) { + ASSERT_THAT(lseek(fd.get(), 0, SEEK_SET), SyscallSucceeds()); + + int ret; + do { + ASSERT_THAT(ret = syscall(SYS_getdents64, fd.get(), buffer, kBufferSize), + SyscallSucceeds()); + } while (ret > 0); + } + + ASSERT_NO_ERRNO(CleanupDirectory(dir, &files)); + + state.SetItemsProcessed(state.iterations()); +} + +BENCHMARK(BM_GetdentsSameFD)->Range(1, 1 << 16)->UseRealTime(); + +// Creates a directory containing `files` files, and reads all the directory +// entries from the directory using a new FD each time. +void BM_GetdentsNewFD(benchmark::State& state) { + // Create directory with given files. + const int count = state.range(0); + + // Keep a vector of all of the file TempPaths that is destroyed before dir. + // + // Normally, we'd simply allow dir to recursively clean up the contained + // files, but that recursive cleanup uses getdents, which may be very slow in + // extreme benchmarks. + TempPath dir; + std::vector<std::string> files; + dir = ASSERT_NO_ERRNO_AND_VALUE(CreateDirectory(count, &files)); + char buffer[kBufferSize]; + + // We read all directory entries on each iteration, but report this as a + // "batch" iteration so that reported times are per file. + while (state.KeepRunningBatch(count)) { + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY | O_DIRECTORY)); + + int ret; + do { + ASSERT_THAT(ret = syscall(SYS_getdents64, fd.get(), buffer, kBufferSize), + SyscallSucceeds()); + } while (ret > 0); + } + + ASSERT_NO_ERRNO(CleanupDirectory(dir, &files)); + + state.SetItemsProcessed(state.iterations()); +} + +BENCHMARK(BM_GetdentsNewFD)->Range(1, 1 << 12)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/getpid_benchmark.cc b/test/perf/linux/getpid_benchmark.cc new file mode 100644 index 000000000..db74cb264 --- /dev/null +++ b/test/perf/linux/getpid_benchmark.cc @@ -0,0 +1,37 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <sys/syscall.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" + +namespace gvisor { +namespace testing { + +namespace { + +void BM_Getpid(benchmark::State& state) { + for (auto _ : state) { + syscall(SYS_getpid); + } +} + +BENCHMARK(BM_Getpid); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/gettid_benchmark.cc b/test/perf/linux/gettid_benchmark.cc new file mode 100644 index 000000000..8f4961f5e --- /dev/null +++ b/test/perf/linux/gettid_benchmark.cc @@ -0,0 +1,38 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <sys/syscall.h> +#include <sys/types.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" + +namespace gvisor { +namespace testing { + +namespace { + +void BM_Gettid(benchmark::State& state) { + for (auto _ : state) { + syscall(SYS_gettid); + } +} + +BENCHMARK(BM_Gettid)->ThreadRange(1, 4000)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/mapping_benchmark.cc b/test/perf/linux/mapping_benchmark.cc new file mode 100644 index 000000000..39c30fe69 --- /dev/null +++ b/test/perf/linux/mapping_benchmark.cc @@ -0,0 +1,163 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <stdlib.h> +#include <sys/mman.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/logging.h" +#include "test/util/memory_util.h" +#include "test/util/posix_error.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +// Conservative value for /proc/sys/vm/max_map_count, which limits the number of +// VMAs, minus a safety margin for VMAs that already exist for the test binary. +// The default value for max_map_count is +// include/linux/mm.h:DEFAULT_MAX_MAP_COUNT = 65530. +constexpr size_t kMaxVMAs = 64001; + +// Map then unmap pages without touching them. +void BM_MapUnmap(benchmark::State& state) { + // Number of pages to map. + const int pages = state.range(0); + + while (state.KeepRunning()) { + void* addr = mmap(0, pages * kPageSize, PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); + TEST_CHECK_MSG(addr != MAP_FAILED, "mmap failed"); + + int ret = munmap(addr, pages * kPageSize); + TEST_CHECK_MSG(ret == 0, "munmap failed"); + } +} + +BENCHMARK(BM_MapUnmap)->Range(1, 1 << 17)->UseRealTime(); + +// Map, touch, then unmap pages. +void BM_MapTouchUnmap(benchmark::State& state) { + // Number of pages to map. + const int pages = state.range(0); + + while (state.KeepRunning()) { + void* addr = mmap(0, pages * kPageSize, PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); + TEST_CHECK_MSG(addr != MAP_FAILED, "mmap failed"); + + char* c = reinterpret_cast<char*>(addr); + char* end = c + pages * kPageSize; + while (c < end) { + *c = 42; + c += kPageSize; + } + + int ret = munmap(addr, pages * kPageSize); + TEST_CHECK_MSG(ret == 0, "munmap failed"); + } +} + +BENCHMARK(BM_MapTouchUnmap)->Range(1, 1 << 17)->UseRealTime(); + +// Map and touch many pages, unmapping all at once. +// +// NOTE(b/111429208): This is a regression test to ensure performant mapping and +// allocation even with tons of mappings. +void BM_MapTouchMany(benchmark::State& state) { + // Number of pages to map. + const int page_count = state.range(0); + + while (state.KeepRunning()) { + std::vector<void*> pages; + + for (int i = 0; i < page_count; i++) { + void* addr = mmap(nullptr, kPageSize, PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); + TEST_CHECK_MSG(addr != MAP_FAILED, "mmap failed"); + + char* c = reinterpret_cast<char*>(addr); + *c = 42; + + pages.push_back(addr); + } + + for (void* addr : pages) { + int ret = munmap(addr, kPageSize); + TEST_CHECK_MSG(ret == 0, "munmap failed"); + } + } + + state.SetBytesProcessed(kPageSize * page_count * state.iterations()); +} + +BENCHMARK(BM_MapTouchMany)->Range(1, 1 << 12)->UseRealTime(); + +void BM_PageFault(benchmark::State& state) { + // Map the region in which we will take page faults. To ensure that each page + // fault maps only a single page, each page we touch must correspond to a + // distinct VMA. Thus we need a 1-page gap between each 1-page VMA. However, + // each gap consists of a PROT_NONE VMA, instead of an unmapped hole, so that + // if there are background threads running, they can't inadvertently creating + // mappings in our gaps that are unmapped when the test ends. + size_t test_pages = kMaxVMAs; + // Ensure that test_pages is odd, since we want the test region to both + // begin and end with a mapped page. + if (test_pages % 2 == 0) { + test_pages--; + } + const size_t test_region_bytes = test_pages * kPageSize; + // Use MAP_SHARED here because madvise(MADV_DONTNEED) on private mappings on + // gVisor won't force future sentry page faults (by design). Use MAP_POPULATE + // so that Linux pre-allocates the shmem file used to back the mapping. + Mapping m = ASSERT_NO_ERRNO_AND_VALUE( + MmapAnon(test_region_bytes, PROT_READ, MAP_SHARED | MAP_POPULATE)); + for (size_t i = 0; i < test_pages / 2; i++) { + ASSERT_THAT( + mprotect(reinterpret_cast<void*>(m.addr() + ((2 * i + 1) * kPageSize)), + kPageSize, PROT_NONE), + SyscallSucceeds()); + } + + const size_t mapped_pages = test_pages / 2 + 1; + // "Start" at the end of the mapped region to force the mapped region to be + // reset, since we mapped it with MAP_POPULATE. + size_t cur_page = mapped_pages; + for (auto _ : state) { + if (cur_page >= mapped_pages) { + // We've reached the end of our mapped region and have to reset it to + // incur page faults again. + state.PauseTiming(); + ASSERT_THAT(madvise(m.ptr(), test_region_bytes, MADV_DONTNEED), + SyscallSucceeds()); + cur_page = 0; + state.ResumeTiming(); + } + const uintptr_t addr = m.addr() + (2 * cur_page * kPageSize); + const char c = *reinterpret_cast<volatile char*>(addr); + benchmark::DoNotOptimize(c); + cur_page++; + } +} + +BENCHMARK(BM_PageFault)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/open_benchmark.cc b/test/perf/linux/open_benchmark.cc new file mode 100644 index 000000000..68008f6d5 --- /dev/null +++ b/test/perf/linux/open_benchmark.cc @@ -0,0 +1,56 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <fcntl.h> +#include <stdlib.h> +#include <unistd.h> + +#include <memory> +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/fs_util.h" +#include "test/util/logging.h" +#include "test/util/temp_path.h" + +namespace gvisor { +namespace testing { + +namespace { + +void BM_Open(benchmark::State& state) { + const int size = state.range(0); + std::vector<TempPath> cache; + for (int i = 0; i < size; i++) { + auto path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + cache.emplace_back(std::move(path)); + } + + unsigned int seed = 1; + for (auto _ : state) { + const int chosen = rand_r(&seed) % size; + int fd = open(cache[chosen].path().c_str(), O_RDONLY); + TEST_CHECK(fd != -1); + close(fd); + } +} + +BENCHMARK(BM_Open)->Range(1, 128)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/pipe_benchmark.cc b/test/perf/linux/pipe_benchmark.cc new file mode 100644 index 000000000..8f5f6a2a3 --- /dev/null +++ b/test/perf/linux/pipe_benchmark.cc @@ -0,0 +1,66 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <stdlib.h> +#include <sys/stat.h> +#include <unistd.h> + +#include <cerrno> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/logging.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +void BM_Pipe(benchmark::State& state) { + int fds[2]; + TEST_CHECK(pipe(fds) == 0); + + const int size = state.range(0); + std::vector<char> wbuf(size); + std::vector<char> rbuf(size); + RandomizeBuffer(wbuf.data(), size); + + ScopedThread t([&] { + auto const fd = fds[1]; + for (int i = 0; i < state.max_iterations; i++) { + TEST_CHECK(WriteFd(fd, wbuf.data(), wbuf.size()) == size); + } + }); + + for (auto _ : state) { + TEST_CHECK(ReadFd(fds[0], rbuf.data(), rbuf.size()) == size); + } + + t.Join(); + + close(fds[0]); + close(fds[1]); + + state.SetBytesProcessed(static_cast<int64_t>(size) * + static_cast<int64_t>(state.iterations())); +} + +BENCHMARK(BM_Pipe)->Range(1, 1 << 20)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/randread_benchmark.cc b/test/perf/linux/randread_benchmark.cc new file mode 100644 index 000000000..b0eb8c24e --- /dev/null +++ b/test/perf/linux/randread_benchmark.cc @@ -0,0 +1,100 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <fcntl.h> +#include <stdlib.h> +#include <sys/stat.h> +#include <sys/uio.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/file_descriptor.h" +#include "test/util/logging.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +// Create a 1GB file that will be read from at random positions. This should +// invalid any performance gains from caching. +const uint64_t kFileSize = 1ULL << 30; + +// How many bytes to write at once to initialize the file used to read from. +const uint32_t kWriteSize = 65536; + +// Largest benchmarked read unit. +const uint32_t kMaxRead = 1UL << 26; + +TempPath CreateFile(uint64_t file_size) { + auto path = TempPath::CreateFile().ValueOrDie(); + FileDescriptor fd = Open(path.path(), O_WRONLY).ValueOrDie(); + + // Try to minimize syscalls by using maximum size writev() requests. + std::vector<char> buffer(kWriteSize); + RandomizeBuffer(buffer.data(), buffer.size()); + const std::vector<std::vector<struct iovec>> iovecs_list = + GenerateIovecs(file_size, buffer.data(), buffer.size()); + for (const auto& iovecs : iovecs_list) { + TEST_CHECK(writev(fd.get(), iovecs.data(), iovecs.size()) >= 0); + } + + return path; +} + +// Global test state, initialized once per process lifetime. +struct GlobalState { + const TempPath tmpfile; + explicit GlobalState(TempPath tfile) : tmpfile(std::move(tfile)) {} +}; + +GlobalState& GetGlobalState() { + // This gets created only once throughout the lifetime of the process. + // Use a dynamically allocated object (that is never deleted) to avoid order + // of destruction of static storage variables issues. + static GlobalState* const state = + // The actual file size is the maximum random seek range (kFileSize) + the + // maximum read size so we can read that number of bytes at the end of the + // file. + new GlobalState(CreateFile(kFileSize + kMaxRead)); + return *state; +} + +void BM_RandRead(benchmark::State& state) { + const int size = state.range(0); + + GlobalState& global_state = GetGlobalState(); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(global_state.tmpfile.path(), O_RDONLY)); + std::vector<char> buf(size); + + unsigned int seed = 1; + for (auto _ : state) { + TEST_CHECK(PreadFd(fd.get(), buf.data(), buf.size(), + rand_r(&seed) % kFileSize) == size); + } + + state.SetBytesProcessed(static_cast<int64_t>(size) * + static_cast<int64_t>(state.iterations())); +} + +BENCHMARK(BM_RandRead)->Range(1, kMaxRead)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/read_benchmark.cc b/test/perf/linux/read_benchmark.cc new file mode 100644 index 000000000..62445867d --- /dev/null +++ b/test/perf/linux/read_benchmark.cc @@ -0,0 +1,53 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <fcntl.h> +#include <stdlib.h> +#include <sys/stat.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/fs_util.h" +#include "test/util/logging.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +void BM_Read(benchmark::State& state) { + const int size = state.range(0); + const std::string contents(size, 0); + auto path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + GetAbsoluteTestTmpdir(), contents, TempPath::kDefaultFileMode)); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_RDONLY)); + + std::vector<char> buf(size); + for (auto _ : state) { + TEST_CHECK(PreadFd(fd.get(), buf.data(), buf.size(), 0) == size); + } + + state.SetBytesProcessed(static_cast<int64_t>(size) * + static_cast<int64_t>(state.iterations())); +} + +BENCHMARK(BM_Read)->Range(1, 1 << 26)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/sched_yield_benchmark.cc b/test/perf/linux/sched_yield_benchmark.cc new file mode 100644 index 000000000..6756b5575 --- /dev/null +++ b/test/perf/linux/sched_yield_benchmark.cc @@ -0,0 +1,37 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <sched.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +void BM_Sched_yield(benchmark::State& state) { + for (auto ignored : state) { + TEST_CHECK(sched_yield() == 0); + } +} + +BENCHMARK(BM_Sched_yield)->ThreadRange(1, 2000)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/send_recv_benchmark.cc b/test/perf/linux/send_recv_benchmark.cc new file mode 100644 index 000000000..d73e49523 --- /dev/null +++ b/test/perf/linux/send_recv_benchmark.cc @@ -0,0 +1,372 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <netinet/in.h> +#include <netinet/tcp.h> +#include <poll.h> +#include <sys/ioctl.h> +#include <sys/socket.h> + +#include <cstring> + +#include "gtest/gtest.h" +#include "absl/synchronization/notification.h" +#include "benchmark/benchmark.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/logging.h" +#include "test/util/posix_error.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +constexpr ssize_t kMessageSize = 1024; + +class Message { + public: + explicit Message(int byte = 0) : Message(byte, kMessageSize, 0) {} + + explicit Message(int byte, int sz) : Message(byte, sz, 0) {} + + explicit Message(int byte, int sz, int cmsg_sz) + : buffer_(sz, byte), cmsg_buffer_(cmsg_sz, 0) { + iov_.iov_base = buffer_.data(); + iov_.iov_len = sz; + hdr_.msg_iov = &iov_; + hdr_.msg_iovlen = 1; + hdr_.msg_control = cmsg_buffer_.data(); + hdr_.msg_controllen = cmsg_sz; + } + + struct msghdr* header() { + return &hdr_; + } + + private: + std::vector<char> buffer_; + std::vector<char> cmsg_buffer_; + struct iovec iov_ = {}; + struct msghdr hdr_ = {}; +}; + +void BM_Recvmsg(benchmark::State& state) { + int sockets[2]; + TEST_CHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0); + FileDescriptor send_socket(sockets[0]), recv_socket(sockets[1]); + absl::Notification notification; + Message send_msg('a'), recv_msg; + + ScopedThread t([&send_msg, &send_socket, ¬ification] { + while (!notification.HasBeenNotified()) { + sendmsg(send_socket.get(), send_msg.header(), 0); + } + }); + + int64_t bytes_received = 0; + for (auto ignored : state) { + int n = recvmsg(recv_socket.get(), recv_msg.header(), 0); + TEST_CHECK(n > 0); + bytes_received += n; + } + + notification.Notify(); + recv_socket.reset(); + + state.SetBytesProcessed(bytes_received); +} + +BENCHMARK(BM_Recvmsg)->UseRealTime(); + +void BM_Sendmsg(benchmark::State& state) { + int sockets[2]; + TEST_CHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0); + FileDescriptor send_socket(sockets[0]), recv_socket(sockets[1]); + absl::Notification notification; + Message send_msg('a'), recv_msg; + + ScopedThread t([&recv_msg, &recv_socket, ¬ification] { + while (!notification.HasBeenNotified()) { + recvmsg(recv_socket.get(), recv_msg.header(), 0); + } + }); + + int64_t bytes_sent = 0; + for (auto ignored : state) { + int n = sendmsg(send_socket.get(), send_msg.header(), 0); + TEST_CHECK(n > 0); + bytes_sent += n; + } + + notification.Notify(); + send_socket.reset(); + + state.SetBytesProcessed(bytes_sent); +} + +BENCHMARK(BM_Sendmsg)->UseRealTime(); + +void BM_Recvfrom(benchmark::State& state) { + int sockets[2]; + TEST_CHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0); + FileDescriptor send_socket(sockets[0]), recv_socket(sockets[1]); + absl::Notification notification; + char send_buffer[kMessageSize], recv_buffer[kMessageSize]; + + ScopedThread t([&send_socket, &send_buffer, ¬ification] { + while (!notification.HasBeenNotified()) { + sendto(send_socket.get(), send_buffer, kMessageSize, 0, nullptr, 0); + } + }); + + int bytes_received = 0; + for (auto ignored : state) { + int n = recvfrom(recv_socket.get(), recv_buffer, kMessageSize, 0, nullptr, + nullptr); + TEST_CHECK(n > 0); + bytes_received += n; + } + + notification.Notify(); + recv_socket.reset(); + + state.SetBytesProcessed(bytes_received); +} + +BENCHMARK(BM_Recvfrom)->UseRealTime(); + +void BM_Sendto(benchmark::State& state) { + int sockets[2]; + TEST_CHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0); + FileDescriptor send_socket(sockets[0]), recv_socket(sockets[1]); + absl::Notification notification; + char send_buffer[kMessageSize], recv_buffer[kMessageSize]; + + ScopedThread t([&recv_socket, &recv_buffer, ¬ification] { + while (!notification.HasBeenNotified()) { + recvfrom(recv_socket.get(), recv_buffer, kMessageSize, 0, nullptr, + nullptr); + } + }); + + int64_t bytes_sent = 0; + for (auto ignored : state) { + int n = sendto(send_socket.get(), send_buffer, kMessageSize, 0, nullptr, 0); + TEST_CHECK(n > 0); + bytes_sent += n; + } + + notification.Notify(); + send_socket.reset(); + + state.SetBytesProcessed(bytes_sent); +} + +BENCHMARK(BM_Sendto)->UseRealTime(); + +PosixErrorOr<sockaddr_storage> InetLoopbackAddr(int family) { + struct sockaddr_storage addr; + memset(&addr, 0, sizeof(addr)); + addr.ss_family = family; + switch (family) { + case AF_INET: + reinterpret_cast<struct sockaddr_in*>(&addr)->sin_addr.s_addr = + htonl(INADDR_LOOPBACK); + break; + case AF_INET6: + reinterpret_cast<struct sockaddr_in6*>(&addr)->sin6_addr = + in6addr_loopback; + break; + default: + return PosixError(EINVAL, + absl::StrCat("unknown socket family: ", family)); + } + return addr; +} + +// BM_RecvmsgWithControlBuf measures the performance of recvmsg when we allocate +// space for control messages. Note that we do not expect to receive any. +void BM_RecvmsgWithControlBuf(benchmark::State& state) { + auto listen_socket = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET6, SOCK_STREAM, IPPROTO_TCP)); + + // Initialize address to the loopback one. + sockaddr_storage addr = ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(AF_INET6)); + socklen_t addrlen = sizeof(addr); + + // Bind to some port then start listening. + ASSERT_THAT(bind(listen_socket.get(), + reinterpret_cast<struct sockaddr*>(&addr), addrlen), + SyscallSucceeds()); + + ASSERT_THAT(listen(listen_socket.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the address we're listening on, then connect to it. We need to do this + // because we're allowing the stack to pick a port for us. + ASSERT_THAT(getsockname(listen_socket.get(), + reinterpret_cast<struct sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + + auto send_socket = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET6, SOCK_STREAM, IPPROTO_TCP)); + + ASSERT_THAT( + RetryEINTR(connect)(send_socket.get(), + reinterpret_cast<struct sockaddr*>(&addr), addrlen), + SyscallSucceeds()); + + // Accept the connection. + auto recv_socket = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_socket.get(), nullptr, nullptr)); + + absl::Notification notification; + Message send_msg('a'); + // Create a msghdr with a buffer allocated for control messages. + Message recv_msg(0, kMessageSize, /*cmsg_sz=*/24); + + ScopedThread t([&send_msg, &send_socket, ¬ification] { + while (!notification.HasBeenNotified()) { + sendmsg(send_socket.get(), send_msg.header(), 0); + } + }); + + int64_t bytes_received = 0; + for (auto ignored : state) { + int n = recvmsg(recv_socket.get(), recv_msg.header(), 0); + TEST_CHECK(n > 0); + bytes_received += n; + } + + notification.Notify(); + recv_socket.reset(); + + state.SetBytesProcessed(bytes_received); +} + +BENCHMARK(BM_RecvmsgWithControlBuf)->UseRealTime(); + +// BM_SendmsgTCP measures the sendmsg throughput with varying payload sizes. +// +// state.Args[0] indicates whether the underlying socket should be blocking or +// non-blocking w/ 0 indicating non-blocking and 1 to indicate blocking. +// state.Args[1] is the size of the payload to be used per sendmsg call. +void BM_SendmsgTCP(benchmark::State& state) { + auto listen_socket = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)); + + // Initialize address to the loopback one. + sockaddr_storage addr = ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(AF_INET)); + socklen_t addrlen = sizeof(addr); + + // Bind to some port then start listening. + ASSERT_THAT(bind(listen_socket.get(), + reinterpret_cast<struct sockaddr*>(&addr), addrlen), + SyscallSucceeds()); + + ASSERT_THAT(listen(listen_socket.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the address we're listening on, then connect to it. We need to do this + // because we're allowing the stack to pick a port for us. + ASSERT_THAT(getsockname(listen_socket.get(), + reinterpret_cast<struct sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + + auto send_socket = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)); + + ASSERT_THAT( + RetryEINTR(connect)(send_socket.get(), + reinterpret_cast<struct sockaddr*>(&addr), addrlen), + SyscallSucceeds()); + + // Accept the connection. + auto recv_socket = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_socket.get(), nullptr, nullptr)); + + // Check if we want to run the test w/ a blocking send socket + // or non-blocking. + const int blocking = state.range(0); + if (!blocking) { + // Set the send FD to O_NONBLOCK. + int opts; + ASSERT_THAT(opts = fcntl(send_socket.get(), F_GETFL), SyscallSucceeds()); + opts |= O_NONBLOCK; + ASSERT_THAT(fcntl(send_socket.get(), F_SETFL, opts), SyscallSucceeds()); + } + + absl::Notification notification; + + // Get the buffer size we should use for this iteration of the test. + const int buf_size = state.range(1); + Message send_msg('a', buf_size), recv_msg(0, buf_size); + + ScopedThread t([&recv_msg, &recv_socket, ¬ification] { + while (!notification.HasBeenNotified()) { + TEST_CHECK(recvmsg(recv_socket.get(), recv_msg.header(), 0) >= 0); + } + }); + + int64_t bytes_sent = 0; + int ncalls = 0; + for (auto ignored : state) { + int sent = 0; + while (true) { + struct msghdr hdr = {}; + struct iovec iov = {}; + struct msghdr* snd_header = send_msg.header(); + iov.iov_base = static_cast<char*>(snd_header->msg_iov->iov_base) + sent; + iov.iov_len = snd_header->msg_iov->iov_len - sent; + hdr.msg_iov = &iov; + hdr.msg_iovlen = 1; + int n = RetryEINTR(sendmsg)(send_socket.get(), &hdr, 0); + ncalls++; + if (n > 0) { + sent += n; + if (sent == buf_size) { + break; + } + // n can be > 0 but less than requested size. In which case we don't + // poll. + continue; + } + // Poll the fd for it to become writable. + struct pollfd poll_fd = {send_socket.get(), POLL_OUT, 0}; + EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10), + SyscallSucceedsWithValue(0)); + } + bytes_sent += static_cast<int64_t>(sent); + } + + notification.Notify(); + send_socket.reset(); + state.SetBytesProcessed(bytes_sent); +} + +void Args(benchmark::internal::Benchmark* benchmark) { + for (int blocking = 0; blocking < 2; blocking++) { + for (int buf_size = 1024; buf_size <= 256 << 20; buf_size *= 2) { + benchmark->Args({blocking, buf_size}); + } + } +} + +BENCHMARK(BM_SendmsgTCP)->Apply(&Args)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/seqwrite_benchmark.cc b/test/perf/linux/seqwrite_benchmark.cc new file mode 100644 index 000000000..af49e4477 --- /dev/null +++ b/test/perf/linux/seqwrite_benchmark.cc @@ -0,0 +1,66 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <fcntl.h> +#include <stdlib.h> +#include <sys/stat.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/logging.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +// The maximum file size of the test file, when writes get beyond this point +// they wrap around. This should be large enough to blow away caches. +const uint64_t kMaxFile = 1 << 30; + +// Perform writes of various sizes sequentially to one file. Wraps around if it +// goes above a certain maximum file size. +void BM_SeqWrite(benchmark::State& state) { + auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_WRONLY)); + + const int size = state.range(0); + std::vector<char> buf(size); + RandomizeBuffer(buf.data(), buf.size()); + + // Start writes at offset 0. + uint64_t offset = 0; + for (auto _ : state) { + TEST_CHECK(PwriteFd(fd.get(), buf.data(), buf.size(), offset) == + buf.size()); + offset += buf.size(); + // Wrap around if going above the maximum file size. + if (offset >= kMaxFile) { + offset = 0; + } + } + + state.SetBytesProcessed(static_cast<int64_t>(size) * + static_cast<int64_t>(state.iterations())); +} + +BENCHMARK(BM_SeqWrite)->Range(1, 1 << 26)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/signal_benchmark.cc b/test/perf/linux/signal_benchmark.cc new file mode 100644 index 000000000..cec679191 --- /dev/null +++ b/test/perf/linux/signal_benchmark.cc @@ -0,0 +1,61 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <signal.h> +#include <string.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/logging.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +void FixupHandler(int sig, siginfo_t* si, void* void_ctx) { + static unsigned int dataval = 0; + + // Skip the offending instruction. + ucontext_t* ctx = reinterpret_cast<ucontext_t*>(void_ctx); + ctx->uc_mcontext.gregs[REG_RAX] = reinterpret_cast<greg_t>(&dataval); +} + +void BM_FaultSignalFixup(benchmark::State& state) { + // Set up the signal handler. + struct sigaction sa = {}; + sigemptyset(&sa.sa_mask); + sa.sa_sigaction = FixupHandler; + sa.sa_flags = SA_SIGINFO; + TEST_CHECK(sigaction(SIGSEGV, &sa, nullptr) == 0); + + // Fault, fault, fault. + for (auto _ : state) { + // Trigger the segfault. + asm volatile( + "movq $0, %%rax\n" + "movq $0x77777777, (%%rax)\n" + : + : + : "rax"); + } +} + +BENCHMARK(BM_FaultSignalFixup)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/sleep_benchmark.cc b/test/perf/linux/sleep_benchmark.cc new file mode 100644 index 000000000..99ef05117 --- /dev/null +++ b/test/perf/linux/sleep_benchmark.cc @@ -0,0 +1,60 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <sys/syscall.h> +#include <time.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/logging.h" + +namespace gvisor { +namespace testing { + +namespace { + +// Sleep for 'param' nanoseconds. +void BM_Sleep(benchmark::State& state) { + const int nanoseconds = state.range(0); + + for (auto _ : state) { + struct timespec ts; + ts.tv_sec = 0; + ts.tv_nsec = nanoseconds; + + int ret; + do { + ret = syscall(SYS_nanosleep, &ts, &ts); + if (ret < 0) { + TEST_CHECK(errno == EINTR); + } + } while (ret < 0); + } +} + +BENCHMARK(BM_Sleep) + ->Arg(0) + ->Arg(1) + ->Arg(1000) // 1us + ->Arg(1000 * 1000) // 1ms + ->Arg(10 * 1000 * 1000) // 10ms + ->Arg(50 * 1000 * 1000) // 50ms + ->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/stat_benchmark.cc b/test/perf/linux/stat_benchmark.cc new file mode 100644 index 000000000..f15424482 --- /dev/null +++ b/test/perf/linux/stat_benchmark.cc @@ -0,0 +1,62 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <sys/stat.h> +#include <sys/types.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "absl/strings/str_cat.h" +#include "benchmark/benchmark.h" +#include "test/util/fs_util.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +// Creates a file in a nested directory hierarchy at least `depth` directories +// deep, and stats that file multiple times. +void BM_Stat(benchmark::State& state) { + // Create nested directories with given depth. + int depth = state.range(0); + const TempPath top_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + std::string dir_path = top_dir.path(); + + while (depth-- > 0) { + // Don't use TempPath because it will make paths too long to use. + // + // The top_dir destructor will clean up this whole tree. + dir_path = JoinPath(dir_path, absl::StrCat(depth)); + ASSERT_NO_ERRNO(Mkdir(dir_path, 0755)); + } + + // Create the file that will be stat'd. + const TempPath file = + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir_path)); + + struct stat st; + for (auto _ : state) { + ASSERT_THAT(stat(file.path().c_str(), &st), SyscallSucceeds()); + } +} + +BENCHMARK(BM_Stat)->Range(1, 100)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/unlink_benchmark.cc b/test/perf/linux/unlink_benchmark.cc new file mode 100644 index 000000000..92243a042 --- /dev/null +++ b/test/perf/linux/unlink_benchmark.cc @@ -0,0 +1,66 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <sys/stat.h> +#include <sys/types.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/fs_util.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +// Creates a directory containing `files` files, and unlinks all the files. +void BM_Unlink(benchmark::State& state) { + // Create directory with given files. + const int file_count = state.range(0); + + // We unlink all files on each iteration, but report this as a "batch" + // iteration so that reported times are per file. + TempPath dir; + while (state.KeepRunningBatch(file_count)) { + state.PauseTiming(); + // N.B. dir is declared outside the loop so that destruction of the previous + // iteration's directory occurs here, inside of PauseTiming. + dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + + std::vector<TempPath> files; + for (int i = 0; i < file_count; i++) { + TempPath file = + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path())); + files.push_back(std::move(file)); + } + state.ResumeTiming(); + + while (!files.empty()) { + // Destructor unlinks. + files.pop_back(); + } + } + + state.SetItemsProcessed(state.iterations()); +} + +BENCHMARK(BM_Unlink)->Range(1, 100 * 1000)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/write_benchmark.cc b/test/perf/linux/write_benchmark.cc new file mode 100644 index 000000000..7b060c70e --- /dev/null +++ b/test/perf/linux/write_benchmark.cc @@ -0,0 +1,52 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <fcntl.h> +#include <stdlib.h> +#include <sys/stat.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/logging.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +void BM_Write(benchmark::State& state) { + auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_WRONLY)); + + const int size = state.range(0); + std::vector<char> buf(size); + RandomizeBuffer(buf.data(), size); + + for (auto _ : state) { + TEST_CHECK(PwriteFd(fd.get(), buf.data(), size, 0) == size); + } + + state.SetBytesProcessed(static_cast<int64_t>(size) * + static_cast<int64_t>(state.iterations())); +} + +BENCHMARK(BM_Write)->Range(1, 1 << 26)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/root/BUILD b/test/root/BUILD index d5dd9bca2..a9130b34f 100644 --- a/test/root/BUILD +++ b/test/root/BUILD @@ -1,11 +1,11 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") +load("//tools/vm:defs.bzl", "vm_test") package(licenses = ["notice"]) go_library( name = "root", srcs = ["root.go"], - importpath = "gvisor.dev/gvisor/test/root", ) go_test( @@ -17,28 +17,39 @@ go_test( "crictl_test.go", "main_test.go", "oom_score_adj_test.go", + "runsc_test.go", ], data = [ "//runsc", ], - embed = [":root"], + library = ":root", tags = [ # Requires docker and runsc to be configured before the test runs. - # Also test only runs as root. + # Also, the test needs to be run as root. Note that below, the + # root_vm_test relies on the default runtime 'runsc' being installed by + # the default installer. "manual", "local", ], visibility = ["//:sandbox"], deps = [ - "//runsc/boot", + "//pkg/cleanup", + "//pkg/test/criutil", + "//pkg/test/dockerutil", + "//pkg/test/testutil", "//runsc/cgroup", "//runsc/container", - "//runsc/criutil", - "//runsc/dockerutil", "//runsc/specutils", - "//runsc/testutil", - "//test/root/testdata", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", + "@com_github_cenkalti_backoff//:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", "@com_github_syndtr_gocapability//capability:go_default_library", + "@org_golang_x_sys//unix:go_default_library", ], ) + +vm_test( + name = "root_vm_test", + size = "large", + shard_count = 1, + targets = [":root_test"], +) diff --git a/test/root/cgroup_test.go b/test/root/cgroup_test.go index 76f1e4f2a..a26b83081 100644 --- a/test/root/cgroup_test.go +++ b/test/root/cgroup_test.go @@ -16,6 +16,7 @@ package root import ( "bufio" + "context" "fmt" "io/ioutil" "os" @@ -24,10 +25,11 @@ import ( "strconv" "strings" "testing" + "time" + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/pkg/test/testutil" "gvisor.dev/gvisor/runsc/cgroup" - "gvisor.dev/gvisor/runsc/dockerutil" - "gvisor.dev/gvisor/runsc/testutil" ) func verifyPid(pid int, path string) error { @@ -52,15 +54,82 @@ func verifyPid(pid int, path string) error { if scanner.Err() != nil { return scanner.Err() } - return fmt.Errorf("got: %s, want: %d", gots, pid) + return fmt.Errorf("got: %v, want: %d", gots, pid) +} + +func TestMemCgroup(t *testing.T) { + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) + + // Start a new container and allocate the specified about of memory. + allocMemSize := 128 << 20 + allocMemLimit := 2 * allocMemSize + + if err := d.Spawn(ctx, dockerutil.RunOpts{ + Image: "basic/ubuntu", + Memory: allocMemLimit, // Must be in bytes. + }, "python3", "-c", fmt.Sprintf("import time; s = 'a' * %d; time.sleep(100)", allocMemSize)); err != nil { + t.Fatalf("docker run failed: %v", err) + } + + // Extract the ID to lookup the cgroup. + gid := d.ID() + t.Logf("cgroup ID: %s", gid) + + // Wait when the container will allocate memory. + memUsage := 0 + start := time.Now() + for time.Since(start) < 30*time.Second { + // Sleep for a brief period of time after spawning the + // container (so that Docker can create the cgroup etc. + // or after looping below (so the application can start). + time.Sleep(100 * time.Millisecond) + + // Read the cgroup memory limit. + path := filepath.Join("/sys/fs/cgroup/memory/docker", gid, "memory.limit_in_bytes") + outRaw, err := ioutil.ReadFile(path) + if err != nil { + // It's possible that the container does not exist yet. + continue + } + out := strings.TrimSpace(string(outRaw)) + memLimit, err := strconv.Atoi(out) + if err != nil { + t.Fatalf("Atoi(%v): %v", out, err) + } + if memLimit != allocMemLimit { + // The group may not have had the correct limit set yet. + continue + } + + // Read the cgroup memory usage. + path = filepath.Join("/sys/fs/cgroup/memory/docker", gid, "memory.max_usage_in_bytes") + outRaw, err = ioutil.ReadFile(path) + if err != nil { + t.Fatalf("error reading usage: %v", err) + } + out = strings.TrimSpace(string(outRaw)) + memUsage, err = strconv.Atoi(out) + if err != nil { + t.Fatalf("Atoi(%v): %v", out, err) + } + t.Logf("read usage: %v, wanted: %v", memUsage, allocMemSize) + + // Are we done? + if memUsage >= allocMemSize { + return + } + } + + t.Fatalf("%vMB is less than %vMB", memUsage>>20, allocMemSize>>20) } // TestCgroup sets cgroup options and checks that cgroup was properly configured. func TestCgroup(t *testing.T) { - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatal("docker pull failed:", err) - } - d := dockerutil.MakeDocker("cgroup-test") + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // This is not a comprehensive list of attributes. // @@ -69,84 +138,133 @@ func TestCgroup(t *testing.T) { // are often run on a single core virtual machine, and there is only a single // CPU available in our current set, and every container's set. attrs := []struct { - arg string + field string + value int64 ctrl string file string want string skipIfNotFound bool }{ { - arg: "--cpu-shares=1000", - ctrl: "cpu", - file: "cpu.shares", - want: "1000", + field: "cpu-shares", + value: 1000, + ctrl: "cpu", + file: "cpu.shares", + want: "1000", }, { - arg: "--cpu-period=2000", - ctrl: "cpu", - file: "cpu.cfs_period_us", - want: "2000", + field: "cpu-period", + value: 2000, + ctrl: "cpu", + file: "cpu.cfs_period_us", + want: "2000", }, { - arg: "--cpu-quota=3000", - ctrl: "cpu", - file: "cpu.cfs_quota_us", - want: "3000", + field: "cpu-quota", + value: 3000, + ctrl: "cpu", + file: "cpu.cfs_quota_us", + want: "3000", }, { - arg: "--kernel-memory=100MB", - ctrl: "memory", - file: "memory.kmem.limit_in_bytes", - want: "104857600", + field: "kernel-memory", + value: 100 << 20, + ctrl: "memory", + file: "memory.kmem.limit_in_bytes", + want: "104857600", }, { - arg: "--memory=1GB", - ctrl: "memory", - file: "memory.limit_in_bytes", - want: "1073741824", + field: "memory", + value: 1 << 30, + ctrl: "memory", + file: "memory.limit_in_bytes", + want: "1073741824", }, { - arg: "--memory-reservation=500MB", - ctrl: "memory", - file: "memory.soft_limit_in_bytes", - want: "524288000", + field: "memory-reservation", + value: 500 << 20, + ctrl: "memory", + file: "memory.soft_limit_in_bytes", + want: "524288000", }, { - arg: "--memory-swap=2GB", + field: "memory-swap", + value: 2 << 30, ctrl: "memory", file: "memory.memsw.limit_in_bytes", want: "2147483648", skipIfNotFound: true, // swap may be disabled on the machine. }, { - arg: "--memory-swappiness=5", - ctrl: "memory", - file: "memory.swappiness", - want: "5", + field: "memory-swappiness", + value: 5, + ctrl: "memory", + file: "memory.swappiness", + want: "5", + }, + { + field: "blkio-weight", + value: 750, + ctrl: "blkio", + file: "blkio.weight", + want: "750", + skipIfNotFound: true, // blkio groups may not be available. }, { - arg: "--blkio-weight=750", - ctrl: "blkio", - file: "blkio.weight", - want: "750", + field: "pids-limit", + value: 1000, + ctrl: "pids", + file: "pids.max", + want: "1000", }, } - args := make([]string, 0, len(attrs)) + // Make configs. + conf, hostconf, _ := d.ConfigsFrom(dockerutil.RunOpts{ + Image: "basic/alpine", + }, "sleep", "10000") + + // Add Cgroup arguments to configs. for _, attr := range attrs { - args = append(args, attr.arg) + switch attr.field { + case "cpu-shares": + hostconf.Resources.CPUShares = attr.value + case "cpu-period": + hostconf.Resources.CPUPeriod = attr.value + case "cpu-quota": + hostconf.Resources.CPUQuota = attr.value + case "kernel-memory": + hostconf.Resources.KernelMemory = attr.value + case "memory": + hostconf.Resources.Memory = attr.value + case "memory-reservation": + hostconf.Resources.MemoryReservation = attr.value + case "memory-swap": + hostconf.Resources.MemorySwap = attr.value + case "memory-swappiness": + val := attr.value + hostconf.Resources.MemorySwappiness = &val + case "blkio-weight": + hostconf.Resources.BlkioWeight = uint16(attr.value) + case "pids-limit": + val := attr.value + hostconf.Resources.PidsLimit = &val + + } } - args = append(args, "alpine", "sleep", "10000") - if err := d.Run(args...); err != nil { - t.Fatal("docker create failed:", err) + // Create container. + if err := d.CreateFrom(ctx, conf, hostconf, nil); err != nil { + t.Fatalf("create failed with: %v", err) } - defer d.CleanUp() - gid, err := d.ID() - if err != nil { - t.Fatalf("Docker.ID() failed: %v", err) + // Start container. + if err := d.Start(ctx); err != nil { + t.Fatalf("start failed with: %v", err) } + + // Lookup the relevant cgroup ID. + gid := d.ID() t.Logf("cgroup ID: %s", gid) // Check list of attributes defined above. @@ -161,7 +279,7 @@ func TestCgroup(t *testing.T) { t.Fatalf("failed to read %q: %v", path, err) } if got := strings.TrimSpace(string(out)); got != attr.want { - t.Errorf("arg: %q, cgroup attribute %s/%s, got: %q, want: %q", attr.arg, attr.ctrl, attr.file, got, attr.want) + t.Errorf("field: %q, cgroup attribute %s/%s, got: %q, want: %q", attr.field, attr.ctrl, attr.file, got, attr.want) } } @@ -179,7 +297,7 @@ func TestCgroup(t *testing.T) { "pids", "systemd", } - pid, err := d.SandboxPid() + pid, err := d.SandboxPid(ctx) if err != nil { t.Fatalf("SandboxPid: %v", err) } @@ -191,25 +309,34 @@ func TestCgroup(t *testing.T) { } } +// TestCgroupParent sets the "CgroupParent" option and checks that the child and parent's +// cgroups are created correctly relative to each other. func TestCgroupParent(t *testing.T) { - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatal("docker pull failed:", err) - } - d := dockerutil.MakeDocker("cgroup-test") + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) - parent := testutil.RandomName("runsc") - if err := d.Run("--cgroup-parent", parent, "alpine", "sleep", "10000"); err != nil { - t.Fatal("docker create failed:", err) + // Construct a known cgroup name. + parent := testutil.RandomID("runsc-") + conf, hostconf, _ := d.ConfigsFrom(dockerutil.RunOpts{ + Image: "basic/alpine", + }, "sleep", "10000") + hostconf.Resources.CgroupParent = parent + + if err := d.CreateFrom(ctx, conf, hostconf, nil); err != nil { + t.Fatalf("create failed with: %v", err) } - defer d.CleanUp() - gid, err := d.ID() - if err != nil { - t.Fatalf("Docker.ID() failed: %v", err) + + if err := d.Start(ctx); err != nil { + t.Fatalf("start failed with: %v", err) } + + // Extract the ID to look up the cgroup. + gid := d.ID() t.Logf("cgroup ID: %s", gid) // Check that sandbox is inside cgroup. - pid, err := d.SandboxPid() + pid, err := d.SandboxPid(ctx) if err != nil { t.Fatalf("SandboxPid: %v", err) } diff --git a/test/root/chroot_test.go b/test/root/chroot_test.go index be0f63d18..58fcd6f08 100644 --- a/test/root/chroot_test.go +++ b/test/root/chroot_test.go @@ -16,6 +16,7 @@ package root import ( + "context" "fmt" "io/ioutil" "os/exec" @@ -24,19 +25,23 @@ import ( "strings" "testing" - "gvisor.dev/gvisor/runsc/dockerutil" + "gvisor.dev/gvisor/pkg/test/dockerutil" ) // TestChroot verifies that the sandbox is chroot'd and that mounts are cleaned // up after the sandbox is destroyed. func TestChroot(t *testing.T) { - d := dockerutil.MakeDocker("chroot-test") - if err := d.Run("alpine", "sleep", "10000"); err != nil { + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) + + if err := d.Spawn(ctx, dockerutil.RunOpts{ + Image: "basic/alpine", + }, "sleep", "10000"); err != nil { t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() - pid, err := d.SandboxPid() + pid, err := d.SandboxPid(ctx) if err != nil { t.Fatalf("Docker.SandboxPid(): %v", err) } @@ -72,20 +77,24 @@ func TestChroot(t *testing.T) { t.Errorf("chroot got children %v, want %v", fi[0].Name(), "proc") } - d.CleanUp() + d.CleanUp(ctx) } func TestChrootGofer(t *testing.T) { - d := dockerutil.MakeDocker("chroot-test") - if err := d.Run("alpine", "sleep", "10000"); err != nil { + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) + + if err := d.Spawn(ctx, dockerutil.RunOpts{ + Image: "basic/alpine", + }, "sleep", "10000"); err != nil { t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() // It's tricky to find gofers. Get sandbox PID first, then find parent. From // parent get all immediate children, remove the sandbox, and everything else // are gofers. - sandPID, err := d.SandboxPid() + sandPID, err := d.SandboxPid(ctx) if err != nil { t.Fatalf("Docker.SandboxPid(): %v", err) } diff --git a/test/root/crictl_test.go b/test/root/crictl_test.go index 3f90c4c6a..df91fa0fe 100644 --- a/test/root/crictl_test.go +++ b/test/root/crictl_test.go @@ -16,196 +16,362 @@ package root import ( "bytes" + "encoding/json" "fmt" "io" "io/ioutil" - "log" "net/http" "os" "os/exec" "path" - "path/filepath" + "regexp" + "strconv" "strings" + "sync" "testing" "time" - "gvisor.dev/gvisor/runsc/criutil" - "gvisor.dev/gvisor/runsc/dockerutil" - "gvisor.dev/gvisor/runsc/specutils" - "gvisor.dev/gvisor/runsc/testutil" - "gvisor.dev/gvisor/test/root/testdata" + "gvisor.dev/gvisor/pkg/cleanup" + "gvisor.dev/gvisor/pkg/test/criutil" + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/pkg/test/testutil" ) // Tests for crictl have to be run as root (rather than in a user namespace) // because crictl creates named network namespaces in /var/run/netns/. -// TestCrictlSanity refers to b/112433158. -func TestCrictlSanity(t *testing.T) { - // Setup containerd and crictl. - crictl, cleanup, err := setup(t) - if err != nil { - t.Fatalf("failed to setup crictl: %v", err) +// Sandbox returns a JSON config for a simple sandbox. Sandbox names must be +// unique so different names should be used when running tests on the same +// containerd instance. +func Sandbox(name string) string { + // Sandbox is a default JSON config for a sandbox. + s := map[string]interface{}{ + "metadata": map[string]string{ + "name": name, + "namespace": "default", + "uid": testutil.RandomID(""), + }, + "linux": map[string]string{}, + "log_directory": "/tmp", } - defer cleanup() - podID, contID, err := crictl.StartPodAndContainer("httpd", testdata.Sandbox, testdata.Httpd) + + v, err := json.Marshal(s) if err != nil { - t.Fatal(err) + // This shouldn't happen. + panic(err) } + return string(v) +} - // Look for the httpd page. - if err = httpGet(crictl, podID, "index.html"); err != nil { - t.Fatalf("failed to get page: %v", err) +// SimpleSpec returns a JSON config for a simple container that runs the +// specified command in the specified image. +func SimpleSpec(name, image string, cmd []string, extra map[string]interface{}) string { + s := map[string]interface{}{ + "metadata": map[string]string{ + "name": name, + }, + "image": map[string]string{ + "image": testutil.ImageByName(image), + }, + // Log files are not deleted after root tests are run. Log to random + // paths to ensure logs are fresh. + "log_path": fmt.Sprintf("%s.log", testutil.RandomID(name)), + "stdin": false, + "tty": false, + } + if len(cmd) > 0 { // Omit if empty. + s["command"] = cmd + } + for k, v := range extra { + s[k] = v // Extra settings. + } + v, err := json.Marshal(s) + if err != nil { + // This shouldn't happen. + panic(err) } + return string(v) +} + +// Httpd is a JSON config for an httpd container. +var Httpd = SimpleSpec("httpd", "basic/httpd", nil, nil) + +// TestCrictlSanity refers to b/112433158. +func TestCrictlSanity(t *testing.T) { + for _, version := range allVersions { + t.Run(version, func(t *testing.T) { + // Setup containerd and crictl. + crictl, cleanup, err := setup(t, version) + if err != nil { + t.Fatalf("failed to setup crictl: %v", err) + } + defer cleanup() + podID, contID, err := crictl.StartPodAndContainer(containerdRuntime, "basic/httpd", Sandbox("default"), Httpd) + if err != nil { + t.Fatalf("start failed: %v", err) + } - // Stop everything. - if err := crictl.StopPodAndContainer(podID, contID); err != nil { - t.Fatal(err) + // Look for the httpd page. + if err = httpGet(crictl, podID, "index.html"); err != nil { + t.Fatalf("failed to get page: %v", err) + } + + // Stop everything. + if err := crictl.StopPodAndContainer(podID, contID); err != nil { + t.Fatalf("stop failed: %v", err) + } + }) } } +// HttpdMountPaths is a JSON config for an httpd container with additional +// mounts. +var HttpdMountPaths = SimpleSpec("httpd", "basic/httpd", nil, map[string]interface{}{ + "mounts": []map[string]interface{}{ + map[string]interface{}{ + "container_path": "/var/run/secrets/kubernetes.io/serviceaccount", + "host_path": "/var/lib/kubelet/pods/82bae206-cdf5-11e8-b245-8cdcd43ac064/volumes/kubernetes.io~secret/default-token-2rpfx", + "readonly": true, + }, + map[string]interface{}{ + "container_path": "/etc/hosts", + "host_path": "/var/lib/kubelet/pods/82bae206-cdf5-11e8-b245-8cdcd43ac064/etc-hosts", + "readonly": false, + }, + map[string]interface{}{ + "container_path": "/dev/termination-log", + "host_path": "/var/lib/kubelet/pods/82bae206-cdf5-11e8-b245-8cdcd43ac064/containers/httpd/d1709580", + "readonly": false, + }, + map[string]interface{}{ + "container_path": "/usr/local/apache2/htdocs/test", + "host_path": "/var/lib/kubelet/pods/82bae206-cdf5-11e8-b245-8cdcd43ac064", + "readonly": true, + }, + }, + "linux": map[string]interface{}{}, +}) + // TestMountPaths refers to b/117635704. func TestMountPaths(t *testing.T) { - // Setup containerd and crictl. - crictl, cleanup, err := setup(t) - if err != nil { - t.Fatalf("failed to setup crictl: %v", err) - } - defer cleanup() - podID, contID, err := crictl.StartPodAndContainer("httpd", testdata.Sandbox, testdata.HttpdMountPaths) - if err != nil { - t.Fatal(err) - } + for _, version := range allVersions { + t.Run(version, func(t *testing.T) { + // Setup containerd and crictl. + crictl, cleanup, err := setup(t, version) + if err != nil { + t.Fatalf("failed to setup crictl: %v", err) + } + defer cleanup() + podID, contID, err := crictl.StartPodAndContainer(containerdRuntime, "basic/httpd", Sandbox("default"), HttpdMountPaths) + if err != nil { + t.Fatalf("start failed: %v", err) + } - // Look for the directory available at /test. - if err = httpGet(crictl, podID, "test"); err != nil { - t.Fatalf("failed to get page: %v", err) - } + // Look for the directory available at /test. + if err = httpGet(crictl, podID, "test"); err != nil { + t.Fatalf("failed to get page: %v", err) + } - // Stop everything. - if err := crictl.StopPodAndContainer(podID, contID); err != nil { - t.Fatal(err) + // Stop everything. + if err := crictl.StopPodAndContainer(podID, contID); err != nil { + t.Fatalf("stop failed: %v", err) + } + }) } } // TestMountPaths refers to b/118728671. func TestMountOverSymlinks(t *testing.T) { - // Setup containerd and crictl. - crictl, cleanup, err := setup(t) - if err != nil { - t.Fatalf("failed to setup crictl: %v", err) - } - defer cleanup() - podID, contID, err := crictl.StartPodAndContainer("k8s.gcr.io/busybox", testdata.Sandbox, testdata.MountOverSymlink) - if err != nil { - t.Fatal(err) - } + for _, version := range allVersions { + t.Run(version, func(t *testing.T) { + // Setup containerd and crictl. + crictl, cleanup, err := setup(t, version) + if err != nil { + t.Fatalf("failed to setup crictl: %v", err) + } + defer cleanup() - out, err := crictl.Exec(contID, "readlink", "/etc/resolv.conf") - if err != nil { - t.Fatal(err) - } - if want := "/tmp/resolv.conf"; !strings.Contains(string(out), want) { - t.Fatalf("/etc/resolv.conf is not pointing to %q: %q", want, string(out)) - } + spec := SimpleSpec("busybox", "basic/resolv", []string{"sleep", "1000"}, nil) + podID, contID, err := crictl.StartPodAndContainer(containerdRuntime, "basic/resolv", Sandbox("default"), spec) + if err != nil { + t.Fatalf("start failed: %v", err) + } - etc, err := crictl.Exec(contID, "cat", "/etc/resolv.conf") - if err != nil { - t.Fatal(err) - } - tmp, err := crictl.Exec(contID, "cat", "/tmp/resolv.conf") - if err != nil { - t.Fatal(err) - } - if tmp != etc { - t.Fatalf("file content doesn't match:\n\t/etc/resolv.conf: %s\n\t/tmp/resolv.conf: %s", string(etc), string(tmp)) - } + out, err := crictl.Exec(contID, "readlink", "/etc/resolv.conf") + if err != nil { + t.Fatalf("readlink failed: %v, out: %s", err, out) + } + if want := "/tmp/resolv.conf"; !strings.Contains(string(out), want) { + t.Fatalf("/etc/resolv.conf is not pointing to %q: %q", want, string(out)) + } + + etc, err := crictl.Exec(contID, "cat", "/etc/resolv.conf") + if err != nil { + t.Fatalf("cat failed: %v, out: %s", err, etc) + } + tmp, err := crictl.Exec(contID, "cat", "/tmp/resolv.conf") + if err != nil { + t.Fatalf("cat failed: %v, out: %s", err, out) + } + if tmp != etc { + t.Fatalf("file content doesn't match:\n\t/etc/resolv.conf: %s\n\t/tmp/resolv.conf: %s", string(etc), string(tmp)) + } - // Stop everything. - if err := crictl.StopPodAndContainer(podID, contID); err != nil { - t.Fatal(err) + // Stop everything. + if err := crictl.StopPodAndContainer(podID, contID); err != nil { + t.Fatalf("stop failed: %v", err) + } + }) } } // TestHomeDir tests that the HOME environment variable is set for -// multi-containers. +// Pod containers. func TestHomeDir(t *testing.T) { - // Setup containerd and crictl. - crictl, cleanup, err := setup(t) - if err != nil { - t.Fatalf("failed to setup crictl: %v", err) - } - defer cleanup() - contSpec := testdata.SimpleSpec("root", "k8s.gcr.io/busybox", []string{"sleep", "1000"}) - podID, contID, err := crictl.StartPodAndContainer("k8s.gcr.io/busybox", testdata.Sandbox, contSpec) - if err != nil { - t.Fatal(err) - } + for _, version := range allVersions { + t.Run(version, func(t *testing.T) { + // Setup containerd and crictl. + crictl, cleanup, err := setup(t, version) + if err != nil { + t.Fatalf("failed to setup crictl: %v", err) + } + defer cleanup() - t.Run("root container", func(t *testing.T) { - out, err := crictl.Exec(contID, "sh", "-c", "echo $HOME") - if err != nil { - t.Fatal(err) - } - if got, want := strings.TrimSpace(string(out)), "/root"; got != want { - t.Fatalf("Home directory invalid. Got %q, Want : %q", got, want) - } - }) + // Note that container ID returned here is a sub-container. All Pod + // containers are sub-containers. The root container of the sandbox is the + // pause container. + t.Run("sub-container", func(t *testing.T) { + contSpec := SimpleSpec("subcontainer", "basic/busybox", []string{"sh", "-c", "echo $HOME"}, nil) + podID, contID, err := crictl.StartPodAndContainer(containerdRuntime, "basic/busybox", Sandbox("subcont-sandbox"), contSpec) + if err != nil { + t.Fatalf("start failed: %v", err) + } - t.Run("sub-container", func(t *testing.T) { - // Create a sub container in the same pod. - subContSpec := testdata.SimpleSpec("subcontainer", "k8s.gcr.io/busybox", []string{"sleep", "1000"}) - subContID, err := crictl.StartContainer(podID, "k8s.gcr.io/busybox", testdata.Sandbox, subContSpec) - if err != nil { - t.Fatal(err) - } + out, err := crictl.Logs(contID) + if err != nil { + t.Fatalf("failed retrieving container logs: %v, out: %s", err, out) + } + if got, want := strings.TrimSpace(string(out)), "/root"; got != want { + t.Fatalf("Home directory invalid. Got %q, Want : %q", got, want) + } - out, err := crictl.Exec(subContID, "sh", "-c", "echo $HOME") - if err != nil { - t.Fatal(err) - } - if got, want := strings.TrimSpace(string(out)), "/root"; got != want { - t.Fatalf("Home directory invalid. Got %q, Want: %q", got, want) - } + // Stop everything; note that the pod may have already stopped. + crictl.StopPodAndContainer(podID, contID) + }) - if err := crictl.StopContainer(subContID); err != nil { - t.Fatal(err) - } - }) + // Tests that HOME is set for the exec process. + t.Run("exec", func(t *testing.T) { + contSpec := SimpleSpec("exec", "basic/busybox", []string{"sleep", "1000"}, nil) + podID, contID, err := crictl.StartPodAndContainer(containerdRuntime, "basic/busybox", Sandbox("exec-sandbox"), contSpec) + if err != nil { + t.Fatalf("start failed: %v", err) + } - // Stop everything. - if err := crictl.StopPodAndContainer(podID, contID); err != nil { - t.Fatal(err) - } + out, err := crictl.Exec(contID, "sh", "-c", "echo $HOME") + if err != nil { + t.Fatalf("failed retrieving container logs: %v, out: %s", err, out) + } + if got, want := strings.TrimSpace(string(out)), "/root"; got != want { + t.Fatalf("Home directory invalid. Got %q, Want : %q", got, want) + } + // Stop everything. + if err := crictl.StopPodAndContainer(podID, contID); err != nil { + t.Fatalf("stop failed: %v", err) + } + }) + }) + } } +const containerdRuntime = "runsc" + +const v1Template = ` +disabled_plugins = ["restart"] +[plugins.cri] + disable_tcp_service = true +[plugins.linux] + shim = "%s" + shim_debug = true +[plugins.cri.containerd.runtimes.` + containerdRuntime + `] + runtime_type = "io.containerd.runtime.v1.linux" + runtime_engine = "%s" + runtime_root = "%s/root/runsc" +` + +const v2Template = ` +disabled_plugins = ["restart"] +[plugins.cri] + disable_tcp_service = true +[plugins.linux] + shim_debug = true +[plugins.cri.containerd.runtimes.` + containerdRuntime + `] + runtime_type = "io.containerd.` + containerdRuntime + `.v1" +[plugins.cri.containerd.runtimes.` + containerdRuntime + `.options] + TypeUrl = "io.containerd.` + containerdRuntime + `.v1.options" +` + +const ( + // v1 is the containerd API v1. + v1 string = "v1" + + // v1 is the containerd API v21. + v2 string = "v2" +) + +// allVersions is the set of known versions. +var allVersions = []string{v1, v2} + // setup sets up before a test. Specifically it: // * Creates directories and a socket for containerd to utilize. // * Runs containerd and waits for it to reach a "ready" state for testing. // * Returns a cleanup function that should be called at the end of the test. -func setup(t *testing.T) (*criutil.Crictl, func(), error) { - var cleanups []func() - cleanupFunc := func() { - for i := len(cleanups) - 1; i >= 0; i-- { - cleanups[i]() - } - } - cleanup := specutils.MakeCleanup(cleanupFunc) - defer cleanup.Clean() - +func setup(t *testing.T, version string) (*criutil.Crictl, func(), error) { // Create temporary containerd root and state directories, and a socket // via which crictl and containerd communicate. containerdRoot, err := ioutil.TempDir(testutil.TmpDir(), "containerd-root") if err != nil { t.Fatalf("failed to create containerd root: %v", err) } - cleanups = append(cleanups, func() { os.RemoveAll(containerdRoot) }) + cu := cleanup.Make(func() { os.RemoveAll(containerdRoot) }) + defer cu.Clean() + t.Logf("Using containerd root: %s", containerdRoot) + containerdState, err := ioutil.TempDir(testutil.TmpDir(), "containerd-state") if err != nil { t.Fatalf("failed to create containerd state: %v", err) } - cleanups = append(cleanups, func() { os.RemoveAll(containerdState) }) - sockAddr := filepath.Join(testutil.TmpDir(), "containerd-test.sock") + cu.Add(func() { os.RemoveAll(containerdState) }) + t.Logf("Using containerd state: %s", containerdState) + + sockDir, err := ioutil.TempDir(testutil.TmpDir(), "containerd-sock") + if err != nil { + t.Fatalf("failed to create containerd socket directory: %v", err) + } + cu.Add(func() { os.RemoveAll(sockDir) }) + sockAddr := path.Join(sockDir, "test.sock") + t.Logf("Using containerd socket: %s", sockAddr) + + // Extract the containerd version. + versionCmd := exec.Command(getContainerd(), "-v") + out, err := versionCmd.CombinedOutput() + if err != nil { + t.Fatalf("error extracting containerd version: %v (%s)", err, string(out)) + } + r := regexp.MustCompile(" v([0-9]+)\\.([0-9]+)\\.([0-9+])") + vs := r.FindStringSubmatch(string(out)) + if len(vs) != 4 { + t.Fatalf("error unexpected version string: %s", string(out)) + } + major, err := strconv.ParseUint(vs[1], 10, 64) + if err != nil { + t.Fatalf("error parsing containerd major version: %v (%s)", err, string(out)) + } + minor, err := strconv.ParseUint(vs[2], 10, 64) + if err != nil { + t.Fatalf("error parsing containerd minor version: %v (%s)", err, string(out)) + } + t.Logf("Using containerd version: %d.%d", major, minor) // We rewrite a configuration. This is based on the current docker // configuration for the runtime under test. @@ -213,50 +379,125 @@ func setup(t *testing.T) (*criutil.Crictl, func(), error) { if err != nil { t.Fatalf("error discovering runtime path: %v", err) } - config, err := testutil.WriteTmpFile("containerd-config", testdata.ContainerdConfig(runtime)) + t.Logf("Using runtime: %v", runtime) + + // Construct a PATH that includes the runtime directory. This is + // because the shims will be installed there, and containerd may infer + // the binary name and search the PATH. + runtimeDir := path.Dir(runtime) + modifiedPath := os.Getenv("PATH") + if modifiedPath != "" { + modifiedPath = ":" + modifiedPath // We prepend below. + } + modifiedPath = path.Dir(getContainerd()) + modifiedPath + modifiedPath = runtimeDir + ":" + modifiedPath + t.Logf("Using PATH: %v", modifiedPath) + + var ( + config string + runpArgs []string + ) + switch version { + case v1: + // This is only supported less than 1.3. + if major > 1 || (major == 1 && minor >= 3) { + t.Skipf("skipping unsupported containerd (want less than 1.3, got %d.%d)", major, minor) + } + + // We provide the shim, followed by the runtime, and then a + // temporary root directory. + config = fmt.Sprintf(v1Template, criutil.ResolvePath("gvisor-containerd-shim"), runtime, containerdRoot) + case v2: + // This is only supported past 1.2. + if major < 1 || (major == 1 && minor <= 1) { + t.Skipf("skipping incompatible containerd (want at least 1.2, got %d.%d)", major, minor) + } + + // The runtime is provided via parameter. Note that the v2 shim + // binary name is always containerd-shim-* so we don't actually + // care about the docker runtime name. + config = v2Template + default: + t.Fatalf("unknown version: %d", version) + } + t.Logf("Using config: %s", config) + + // Generate the configuration for the test. + configFile, configCleanup, err := testutil.WriteTmpFile("containerd-config", config) if err != nil { t.Fatalf("failed to write containerd config") } - cleanups = append(cleanups, func() { os.RemoveAll(config) }) + cu.Add(configCleanup) // Start containerd. - containerd := exec.Command(getContainerd(), - "--config", config, + args := []string{ + getContainerd(), + "--config", configFile, "--log-level", "debug", "--root", containerdRoot, "--state", containerdState, - "--address", sockAddr) - cleanups = append(cleanups, func() { - if err := testutil.KillCommand(containerd); err != nil { - log.Printf("error killing containerd: %v", err) - } - }) - containerdStderr, err := containerd.StderrPipe() - if err != nil { - t.Fatalf("failed to get containerd stderr: %v", err) + "--address", sockAddr, } - containerdStdout, err := containerd.StdoutPipe() + t.Logf("Using args: %s", strings.Join(args, " ")) + cmd := exec.Command(args[0], args[1:]...) + cmd.Env = append(os.Environ(), "PATH="+modifiedPath) + + // Include output in logs. + stderrPipe, err := cmd.StderrPipe() if err != nil { - t.Fatalf("failed to get containerd stdout: %v", err) + t.Fatalf("failed to create stderr pipe: %v", err) } - if err := containerd.Start(); err != nil { + cu.Add(func() { stderrPipe.Close() }) + stdoutPipe, err := cmd.StdoutPipe() + if err != nil { + t.Fatalf("failed to create stdout pipe: %v", err) + } + cu.Add(func() { stdoutPipe.Close() }) + var ( + wg sync.WaitGroup + stderr bytes.Buffer + stdout bytes.Buffer + ) + startupR, startupW := io.Pipe() + wg.Add(2) + go func() { + defer wg.Done() + io.Copy(io.MultiWriter(startupW, &stderr), stderrPipe) + }() + go func() { + defer wg.Done() + io.Copy(io.MultiWriter(startupW, &stdout), stdoutPipe) + }() + cu.Add(func() { + wg.Wait() + t.Logf("containerd stdout: %s", stdout.String()) + t.Logf("containerd stderr: %s", stderr.String()) + }) + + // Start the process. + if err := cmd.Start(); err != nil { t.Fatalf("failed running containerd: %v", err) } - // Wait for containerd to boot. Then put all containerd output into a - // buffer to be logged at the end of the test. - testutil.WaitUntilRead(containerdStderr, "Start streaming server", nil, 10*time.Second) - stdoutBuf := &bytes.Buffer{} - stderrBuf := &bytes.Buffer{} - go func() { io.Copy(stdoutBuf, containerdStdout) }() - go func() { io.Copy(stderrBuf, containerdStderr) }() - cleanups = append(cleanups, func() { - t.Logf("containerd stdout: %s", string(stdoutBuf.Bytes())) - t.Logf("containerd stderr: %s", string(stderrBuf.Bytes())) + // Wait for containerd to boot. + if err := testutil.WaitUntilRead(startupR, "Start streaming server", nil, 10*time.Second); err != nil { + t.Fatalf("failed to start containerd: %v", err) + } + + // Discard all subsequent data. + go io.Copy(ioutil.Discard, startupR) + + // Create the crictl interface. + cc := criutil.NewCrictl(t, sockAddr, runpArgs) + cu.Add(cc.CleanUp) + + // Kill must be the last cleanup (as it will be executed first). + cu.Add(func() { + // Best effort: ignore errors. + testutil.KillCommand(cmd) }) - cleanup.Release() - return criutil.NewCrictl(20*time.Second, sockAddr), cleanupFunc, nil + return cc, cu.Release(), nil } // httpGet GETs the contents of a file served from a pod on port 80. diff --git a/test/root/main_test.go b/test/root/main_test.go index d74dec85f..9fb17e0dd 100644 --- a/test/root/main_test.go +++ b/test/root/main_test.go @@ -21,7 +21,7 @@ import ( "testing" "github.com/syndtr/gocapability/capability" - "gvisor.dev/gvisor/runsc/dockerutil" + "gvisor.dev/gvisor/pkg/test/dockerutil" "gvisor.dev/gvisor/runsc/specutils" ) diff --git a/test/root/oom_score_adj_test.go b/test/root/oom_score_adj_test.go index 126f0975a..4243eb59e 100644 --- a/test/root/oom_score_adj_test.go +++ b/test/root/oom_score_adj_test.go @@ -20,10 +20,10 @@ import ( "testing" specs "github.com/opencontainers/runtime-spec/specs-go" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/pkg/cleanup" + "gvisor.dev/gvisor/pkg/test/testutil" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/specutils" - "gvisor.dev/gvisor/runsc/testutil" ) var ( @@ -40,15 +40,6 @@ var ( // TestOOMScoreAdjSingle tests that oom_score_adj is set properly in a // single container sandbox. func TestOOMScoreAdjSingle(t *testing.T) { - rootDir, err := testutil.SetupRootDir() - if err != nil { - t.Fatalf("error creating root dir: %v", err) - } - defer os.RemoveAll(rootDir) - - conf := testutil.TestConfig() - conf.RootDir = rootDir - ppid, err := specutils.GetParentPid(os.Getpid()) if err != nil { t.Fatalf("getting parent pid: %v", err) @@ -89,11 +80,11 @@ func TestOOMScoreAdjSingle(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.Name, func(t *testing.T) { - id := testutil.UniqueContainerID() + id := testutil.RandomContainerID() s := testutil.NewSpecWithArgs("sleep", "1000") s.Process.OOMScoreAdj = testCase.OOMScoreAdj - containers, cleanup, err := startContainers(conf, []*specs.Spec{s}, []string{id}) + containers, cleanup, err := startContainers(t, []*specs.Spec{s}, []string{id}) if err != nil { t.Fatalf("error starting containers: %v", err) } @@ -131,15 +122,6 @@ func TestOOMScoreAdjSingle(t *testing.T) { // TestOOMScoreAdjMulti tests that oom_score_adj is set properly in a // multi-container sandbox. func TestOOMScoreAdjMulti(t *testing.T) { - rootDir, err := testutil.SetupRootDir() - if err != nil { - t.Fatalf("error creating root dir: %v", err) - } - defer os.RemoveAll(rootDir) - - conf := testutil.TestConfig() - conf.RootDir = rootDir - ppid, err := specutils.GetParentPid(os.Getpid()) if err != nil { t.Fatalf("getting parent pid: %v", err) @@ -257,7 +239,7 @@ func TestOOMScoreAdjMulti(t *testing.T) { } } - containers, cleanup, err := startContainers(conf, specs, ids) + containers, cleanup, err := startContainers(t, specs, ids) if err != nil { t.Fatalf("error starting containers: %v", err) } @@ -321,7 +303,7 @@ func TestOOMScoreAdjMulti(t *testing.T) { func createSpecs(cmds ...[]string) ([]*specs.Spec, []string) { var specs []*specs.Spec var ids []string - rootID := testutil.UniqueContainerID() + rootID := testutil.RandomContainerID() for i, cmd := range cmds { spec := testutil.NewSpecWithArgs(cmd...) @@ -335,35 +317,34 @@ func createSpecs(cmds ...[]string) ([]*specs.Spec, []string) { specutils.ContainerdContainerTypeAnnotation: specutils.ContainerdContainerTypeContainer, specutils.ContainerdSandboxIDAnnotation: rootID, } - ids = append(ids, testutil.UniqueContainerID()) + ids = append(ids, testutil.RandomContainerID()) } specs = append(specs, spec) } return specs, ids } -func startContainers(conf *boot.Config, specs []*specs.Spec, ids []string) ([]*container.Container, func(), error) { - if len(conf.RootDir) == 0 { - panic("conf.RootDir not set. Call testutil.SetupRootDir() to set.") - } - +func startContainers(t *testing.T, specs []*specs.Spec, ids []string) ([]*container.Container, func(), error) { var containers []*container.Container - var bundles []string - cleanup := func() { - for _, c := range containers { - c.Destroy() - } - for _, b := range bundles { - os.RemoveAll(b) - } + + // All containers must share the same root. + rootDir, clean, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) } + cu := cleanup.Make(clean) + defer cu.Clean() + + // Point this to from the configuration. + conf := testutil.TestConfig(t) + conf.RootDir = rootDir + for i, spec := range specs { - bundleDir, err := testutil.SetupBundleDir(spec) + bundleDir, clean, err := testutil.SetupBundleDir(spec) if err != nil { - cleanup() - return nil, nil, fmt.Errorf("error setting up container: %v", err) + return nil, nil, fmt.Errorf("error setting up bundle: %v", err) } - bundles = append(bundles, bundleDir) + cu.Add(clean) args := container.Args{ ID: ids[i], @@ -372,15 +353,14 @@ func startContainers(conf *boot.Config, specs []*specs.Spec, ids []string) ([]*c } cont, err := container.New(conf, args) if err != nil { - cleanup() return nil, nil, fmt.Errorf("error creating container: %v", err) } containers = append(containers, cont) if err := cont.Start(conf); err != nil { - cleanup() return nil, nil, fmt.Errorf("error starting container: %v", err) } } - return containers, cleanup, nil + + return containers, cu.Release(), nil } diff --git a/test/root/runsc_test.go b/test/root/runsc_test.go new file mode 100644 index 000000000..25204bebb --- /dev/null +++ b/test/root/runsc_test.go @@ -0,0 +1,151 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package root + +import ( + "bytes" + "fmt" + "io/ioutil" + "os" + "os/exec" + "path/filepath" + "strconv" + "strings" + "testing" + "time" + + "github.com/cenkalti/backoff" + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/test/testutil" + "gvisor.dev/gvisor/runsc/specutils" +) + +// TestDoKill checks that when "runsc do..." is killed, the sandbox process is +// also terminated. This ensures that parent death signal is propagate to the +// sandbox process correctly. +func TestDoKill(t *testing.T) { + // Make the sandbox process be reparented here when it's killed, so we can + // wait for it. + if err := unix.Prctl(unix.PR_SET_CHILD_SUBREAPER, 1, 0, 0, 0); err != nil { + t.Fatalf("prctl(PR_SET_CHILD_SUBREAPER): %v", err) + } + + cmd := exec.Command(specutils.ExePath, "do", "sleep", "10000") + buf := &bytes.Buffer{} + cmd.Stdout = buf + cmd.Stderr = buf + cmd.Start() + + var pid int + findSandbox := func() error { + var err error + pid, err = sandboxPid(cmd.Process.Pid) + if err != nil { + return &backoff.PermanentError{Err: err} + } + if pid == 0 { + return fmt.Errorf("sandbox process not found") + } + return nil + } + if err := testutil.Poll(findSandbox, 10*time.Second); err != nil { + t.Fatalf("failed to find sandbox: %v", err) + } + t.Logf("Found sandbox, pid: %d", pid) + + if err := cmd.Process.Kill(); err != nil { + t.Fatalf("failed to kill run process: %v", err) + } + cmd.Wait() + t.Logf("Parent process killed (%d). Output: %s", cmd.Process.Pid, buf.String()) + + ch := make(chan struct{}) + go func() { + defer func() { ch <- struct{}{} }() + t.Logf("Waiting for sandbox process (%d) termination", pid) + if _, err := unix.Wait4(pid, nil, 0, nil); err != nil { + t.Errorf("error waiting for sandbox process (%d): %v", pid, err) + } + }() + select { + case <-ch: + // Done + case <-time.After(5 * time.Second): + t.Fatalf("timeout waiting for sandbox process (%d) to exit", pid) + } +} + +// sandboxPid looks for the sandbox process inside the process tree starting +// from "pid". It returns 0 and no error if no sandbox process is found. It +// returns error if anything failed. +func sandboxPid(pid int) (int, error) { + cmd := exec.Command("pgrep", "-P", strconv.Itoa(pid)) + buf := &bytes.Buffer{} + cmd.Stdout = buf + if err := cmd.Start(); err != nil { + return 0, err + } + ps, err := cmd.Process.Wait() + if err != nil { + return 0, err + } + if ps.ExitCode() == 1 { + // pgrep returns 1 when no process is found. + return 0, nil + } + + var children []int + for _, line := range strings.Split(buf.String(), "\n") { + if len(line) == 0 { + continue + } + child, err := strconv.Atoi(line) + if err != nil { + return 0, err + } + + cmdline, err := ioutil.ReadFile(filepath.Join("/proc", line, "cmdline")) + if err != nil { + if os.IsNotExist(err) { + // Raced with process exit. + continue + } + return 0, err + } + args := strings.SplitN(string(cmdline), "\x00", 2) + if len(args) == 0 { + return 0, fmt.Errorf("malformed cmdline file: %q", cmdline) + } + // The sandbox process has the first argument set to "runsc-sandbox". + if args[0] == "runsc-sandbox" { + return child, nil + } + + children = append(children, child) + } + + // Sandbox process wasn't found, try another level down. + for _, pid := range children { + sand, err := sandboxPid(pid) + if err != nil { + return 0, err + } + if sand != 0 { + return sand, nil + } + // Not found, continue the search. + } + return 0, nil +} diff --git a/test/root/testdata/BUILD b/test/root/testdata/BUILD deleted file mode 100644 index 125633680..000000000 --- a/test/root/testdata/BUILD +++ /dev/null @@ -1,19 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "testdata", - srcs = [ - "busybox.go", - "containerd_config.go", - "httpd.go", - "httpd_mount_paths.go", - "sandbox.go", - "simple.go", - ], - importpath = "gvisor.dev/gvisor/test/root/testdata", - visibility = [ - "//visibility:public", - ], -) diff --git a/test/root/testdata/containerd_config.go b/test/root/testdata/containerd_config.go deleted file mode 100644 index e12f1ec88..000000000 --- a/test/root/testdata/containerd_config.go +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package testdata contains data required for root tests. -package testdata - -import "fmt" - -// containerdConfigTemplate is a .toml config for containerd. It contains a -// formatting verb so the runtime field can be set via fmt.Sprintf. -const containerdConfigTemplate = ` -disabled_plugins = ["restart"] -[plugins.linux] - runtime = "%s" - runtime_root = "/tmp/test-containerd/runsc" - shim = "/usr/local/bin/gvisor-containerd-shim" - shim_debug = true - -[plugins.cri.containerd.runtimes.runsc] - runtime_type = "io.containerd.runtime.v1.linux" - runtime_engine = "%s" -` - -// ContainerdConfig returns a containerd config file with the specified -// runtime. -func ContainerdConfig(runtime string) string { - return fmt.Sprintf(containerdConfigTemplate, runtime, runtime) -} diff --git a/test/root/testdata/httpd_mount_paths.go b/test/root/testdata/httpd_mount_paths.go deleted file mode 100644 index ac3f4446a..000000000 --- a/test/root/testdata/httpd_mount_paths.go +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package testdata - -// HttpdMountPaths is a JSON config for an httpd container with additional -// mounts. -const HttpdMountPaths = ` -{ - "metadata": { - "name": "httpd" - }, - "image":{ - "image": "httpd" - }, - "mounts": [ - { - "container_path": "/var/run/secrets/kubernetes.io/serviceaccount", - "host_path": "/var/lib/kubelet/pods/82bae206-cdf5-11e8-b245-8cdcd43ac064/volumes/kubernetes.io~secret/default-token-2rpfx", - "readonly": true - }, - { - "container_path": "/etc/hosts", - "host_path": "/var/lib/kubelet/pods/82bae206-cdf5-11e8-b245-8cdcd43ac064/etc-hosts", - "readonly": false - }, - { - "container_path": "/dev/termination-log", - "host_path": "/var/lib/kubelet/pods/82bae206-cdf5-11e8-b245-8cdcd43ac064/containers/httpd/d1709580", - "readonly": false - }, - { - "container_path": "/usr/local/apache2/htdocs/test", - "host_path": "/var/lib/kubelet/pods/82bae206-cdf5-11e8-b245-8cdcd43ac064", - "readonly": true - } - ], - "linux": { - }, - "log_path": "httpd.log" -} -` diff --git a/test/runner/BUILD b/test/runner/BUILD new file mode 100644 index 000000000..582d2946d --- /dev/null +++ b/test/runner/BUILD @@ -0,0 +1,29 @@ +load("//tools:defs.bzl", "bzl_library", "go_binary") + +package(licenses = ["notice"]) + +go_binary( + name = "runner", + testonly = 1, + srcs = ["runner.go"], + data = [ + "//runsc", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/log", + "//pkg/test/testutil", + "//runsc/specutils", + "//test/runner/gtest", + "//test/uds", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", + "@com_github_syndtr_gocapability//capability:go_default_library", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +bzl_library( + name = "defs_bzl", + srcs = ["defs.bzl"], + visibility = ["//visibility:private"], +) diff --git a/test/runner/defs.bzl b/test/runner/defs.bzl new file mode 100644 index 000000000..2d64934b0 --- /dev/null +++ b/test/runner/defs.bzl @@ -0,0 +1,249 @@ +"""Defines a rule for syscall test targets.""" + +load("//tools:defs.bzl", "default_platform", "loopback", "platforms") + +def _runner_test_impl(ctx): + # Generate a runner binary. + runner = ctx.actions.declare_file("%s-runner" % ctx.label.name) + runner_content = "\n".join([ + "#!/bin/bash", + "set -euf -x -o pipefail", + "if [[ -n \"${TEST_UNDECLARED_OUTPUTS_DIR}\" ]]; then", + " mkdir -p \"${TEST_UNDECLARED_OUTPUTS_DIR}\"", + " chmod a+rwx \"${TEST_UNDECLARED_OUTPUTS_DIR}\"", + "fi", + "exec %s %s %s\n" % ( + ctx.files.runner[0].short_path, + " ".join(ctx.attr.runner_args), + ctx.files.test[0].short_path, + ), + ]) + ctx.actions.write(runner, runner_content, is_executable = True) + + # Return with all transitive files. + runfiles = ctx.runfiles( + transitive_files = depset(transitive = [ + target.data_runfiles.files + for target in (ctx.attr.runner, ctx.attr.test) + if hasattr(target, "data_runfiles") + ]), + files = ctx.files.runner + ctx.files.test, + collect_default = True, + collect_data = True, + ) + return [DefaultInfo(executable = runner, runfiles = runfiles)] + +_runner_test = rule( + attrs = { + "runner": attr.label( + default = "//test/runner:runner", + ), + "test": attr.label( + mandatory = True, + ), + "runner_args": attr.string_list(), + "data": attr.label_list( + allow_files = True, + ), + }, + test = True, + implementation = _runner_test_impl, +) + +def _syscall_test( + test, + shard_count, + size, + platform, + use_tmpfs, + tags, + network = "none", + file_access = "exclusive", + overlay = False, + add_uds_tree = False, + vfs2 = False, + fuse = False): + # Prepend "runsc" to non-native platform names. + full_platform = platform if platform == "native" else "runsc_" + platform + + # Name the test appropriately. + name = test.split(":")[1] + "_" + full_platform + if file_access == "shared": + name += "_shared" + if overlay: + name += "_overlay" + if vfs2: + name += "_vfs2" + if fuse: + name += "_fuse" + if network != "none": + name += "_" + network + "net" + + # Apply all tags. + if tags == None: + tags = [] + + # Add the full_platform and file access in a tag to make it easier to run + # all the tests on a specific flavor. Use --test_tag_filters=ptrace,file_shared. + tags += [full_platform, "file_" + file_access] + + # Hash this target into one of 15 buckets. This can be used to + # randomly split targets between different workflows. + hash15 = hash(native.package_name() + name) % 15 + tags.append("hash15:" + str(hash15)) + + # TODO(b/139838000): Tests using hostinet must be disabled on Guitar until + # we figure out how to request ipv4 sockets on Guitar machines. + if network == "host": + tags.append("noguitar") + tags.append("block-network") + + # Disable off-host networking. + tags.append("requires-net:loopback") + + runner_args = [ + # Arguments are passed directly to runner binary. + "--platform=" + platform, + "--network=" + network, + "--use-tmpfs=" + str(use_tmpfs), + "--file-access=" + file_access, + "--overlay=" + str(overlay), + "--add-uds-tree=" + str(add_uds_tree), + "--vfs2=" + str(vfs2), + "--fuse=" + str(fuse), + ] + + # Call the rule above. + _runner_test( + name = name, + test = test, + runner_args = runner_args, + data = [loopback], + size = size, + tags = tags, + shard_count = shard_count, + ) + +def syscall_test( + test, + shard_count = 5, + size = "small", + use_tmpfs = False, + add_overlay = False, + add_uds_tree = False, + add_hostinet = False, + vfs2 = True, + fuse = False, + tags = None): + """syscall_test is a macro that will create targets for all platforms. + + Args: + test: the test target. + shard_count: shards for defined tests. + size: the defined test size. + use_tmpfs: use tmpfs in the defined tests. + add_overlay: add an overlay test. + add_uds_tree: add a UDS test. + add_hostinet: add a hostinet test. + tags: starting test tags. + """ + if not tags: + tags = [] + + vfs2_tags = list(tags) + if vfs2: + # Add tag to easily run VFS2 tests with --test_tag_filters=vfs2 + vfs2_tags.append("vfs2") + if fuse: + vfs2_tags.append("fuse") + + else: + # Don't automatically run tests tests not yet passing. + vfs2_tags.append("manual") + vfs2_tags.append("noguitar") + vfs2_tags.append("notap") + + _syscall_test( + test = test, + shard_count = shard_count, + size = size, + platform = default_platform, + use_tmpfs = use_tmpfs, + add_uds_tree = add_uds_tree, + tags = platforms[default_platform] + vfs2_tags, + vfs2 = True, + fuse = fuse, + ) + if fuse: + # Only generate *_vfs2_fuse target if fuse parameter is enabled. + return + + _syscall_test( + test = test, + shard_count = shard_count, + size = size, + platform = "native", + use_tmpfs = False, + add_uds_tree = add_uds_tree, + tags = list(tags), + ) + + for (platform, platform_tags) in platforms.items(): + _syscall_test( + test = test, + shard_count = shard_count, + size = size, + platform = platform, + use_tmpfs = use_tmpfs, + add_uds_tree = add_uds_tree, + tags = platform_tags + tags, + ) + + # TODO(gvisor.dev/issue/1487): Enable VFS2 overlay tests. + if add_overlay: + _syscall_test( + test = test, + shard_count = shard_count, + size = size, + platform = default_platform, + use_tmpfs = use_tmpfs, + add_uds_tree = add_uds_tree, + tags = platforms[default_platform] + tags, + overlay = True, + ) + + if add_hostinet: + _syscall_test( + test = test, + shard_count = shard_count, + size = size, + platform = default_platform, + use_tmpfs = use_tmpfs, + network = "host", + add_uds_tree = add_uds_tree, + tags = platforms[default_platform] + tags, + ) + + if not use_tmpfs: + # Also test shared gofer access. + _syscall_test( + test = test, + shard_count = shard_count, + size = size, + platform = default_platform, + use_tmpfs = use_tmpfs, + add_uds_tree = add_uds_tree, + tags = platforms[default_platform] + tags, + file_access = "shared", + ) + _syscall_test( + test = test, + shard_count = shard_count, + size = size, + platform = default_platform, + use_tmpfs = use_tmpfs, + add_uds_tree = add_uds_tree, + tags = platforms[default_platform] + vfs2_tags, + file_access = "shared", + vfs2 = True, + ) diff --git a/test/runner/gtest/BUILD b/test/runner/gtest/BUILD new file mode 100644 index 000000000..de4b2727c --- /dev/null +++ b/test/runner/gtest/BUILD @@ -0,0 +1,9 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "gtest", + srcs = ["gtest.go"], + visibility = ["//:sandbox"], +) diff --git a/test/runner/gtest/gtest.go b/test/runner/gtest/gtest.go new file mode 100644 index 000000000..e4445e01b --- /dev/null +++ b/test/runner/gtest/gtest.go @@ -0,0 +1,170 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package gtest contains helpers for running google-test tests from Go. +package gtest + +import ( + "fmt" + "os/exec" + "strings" +) + +var ( + // listTestFlag is the flag that will list tests in gtest binaries. + listTestFlag = "--gtest_list_tests" + + // filterTestFlag is the flag that will filter tests in gtest binaries. + filterTestFlag = "--gtest_filter" + + // listBechmarkFlag is the flag that will list benchmarks in gtest binaries. + listBenchmarkFlag = "--benchmark_list_tests" + + // filterBenchmarkFlag is the flag that will run specified benchmarks. + filterBenchmarkFlag = "--benchmark_filter" +) + +// TestCase is a single gtest test case. +type TestCase struct { + // Suite is the suite for this test. + Suite string + + // Name is the name of this individual test. + Name string + + // all indicates that this will run without flags. This takes + // precendence over benchmark below. + all bool + + // benchmark indicates that this is a benchmark. In this case, the + // suite will be empty, and we will use the appropriate test and + // benchmark flags. + benchmark bool +} + +// FullName returns the name of the test including the suite. It is suitable to +// pass to "-gtest_filter". +func (tc TestCase) FullName() string { + return fmt.Sprintf("%s.%s", tc.Suite, tc.Name) +} + +// Args returns arguments to be passed when invoking the test. +func (tc TestCase) Args() []string { + if tc.all { + return []string{} // No arguments. + } + if tc.benchmark { + return []string{ + fmt.Sprintf("%s=^%s$", filterBenchmarkFlag, tc.Name), + fmt.Sprintf("%s=", filterTestFlag), + } + } + return []string{ + fmt.Sprintf("%s=%s", filterTestFlag, tc.FullName()), + } +} + +// ParseTestCases calls a gtest test binary to list its test and returns a +// slice with the name and suite of each test. +// +// If benchmarks is true, then benchmarks will be included in the list of test +// cases provided. Note that this requires the binary to support the +// benchmarks_list_tests flag. +func ParseTestCases(testBin string, benchmarks bool, extraArgs ...string) ([]TestCase, error) { + // Run to extract test cases. + args := append([]string{listTestFlag}, extraArgs...) + cmd := exec.Command(testBin, args...) + out, err := cmd.Output() + if err != nil { + // We failed to list tests with the given flags. Just + // return something that will run the binary with no + // flags, which should execute all tests. + return []TestCase{ + TestCase{ + Suite: "Default", + Name: "All", + all: true, + }, + }, nil + } + + // Parse test output. + var t []TestCase + var suite string + for _, line := range strings.Split(string(out), "\n") { + // Strip comments. + line = strings.Split(line, "#")[0] + + // New suite? + if !strings.HasPrefix(line, " ") { + suite = strings.TrimSuffix(strings.TrimSpace(line), ".") + continue + } + + // Individual test. + name := strings.TrimSpace(line) + + // Do we have a suite yet? + if suite == "" { + return nil, fmt.Errorf("test without a suite: %v", name) + } + + // Add this individual test. + t = append(t, TestCase{ + Suite: suite, + Name: name, + }) + } + + // Finished? + if !benchmarks { + return t, nil + } + + // Run again to extract benchmarks. + args = append([]string{listBenchmarkFlag}, extraArgs...) + cmd = exec.Command(testBin, args...) + out, err = cmd.Output() + if err != nil { + // We were able to enumerate tests above, but not benchmarks? + // We requested them, so we return an error in this case. + exitErr, ok := err.(*exec.ExitError) + if !ok { + return nil, fmt.Errorf("could not enumerate gtest benchmarks: %v", err) + } + return nil, fmt.Errorf("could not enumerate gtest benchmarks: %v\nstderr\n%s", err, exitErr.Stderr) + } + + benches := strings.Trim(string(out), "\n") + if len(benches) == 0 { + return t, nil + } + + // Parse benchmark output. + for _, line := range strings.Split(benches, "\n") { + // Strip comments. + line = strings.Split(line, "#")[0] + + // Single benchmark. + name := strings.TrimSpace(line) + + // Add the single benchmark. + t = append(t, TestCase{ + Suite: "Benchmarks", + Name: name, + benchmark: true, + }) + } + return t, nil +} diff --git a/test/syscalls/syscall_test_runner.go b/test/runner/runner.go index 856398994..5ac91310d 100644 --- a/test/syscalls/syscall_test_runner.go +++ b/test/runner/runner.go @@ -30,25 +30,25 @@ import ( "time" specs "github.com/opencontainers/runtime-spec/specs-go" + "github.com/syndtr/gocapability/capability" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/test/testutil" "gvisor.dev/gvisor/runsc/specutils" - "gvisor.dev/gvisor/runsc/testutil" - "gvisor.dev/gvisor/test/syscalls/gtest" + "gvisor.dev/gvisor/test/runner/gtest" "gvisor.dev/gvisor/test/uds" ) -// Location of syscall tests, relative to the repo root. -const testDir = "test/syscalls/linux" - var ( - testName = flag.String("test-name", "", "name of test binary to run") debug = flag.Bool("debug", false, "enable debug logs") strace = flag.Bool("strace", false, "enable strace logs") platform = flag.String("platform", "ptrace", "platform to run on") + network = flag.String("network", "none", "network stack to run on (sandbox, host, none)") useTmpfs = flag.Bool("use-tmpfs", false, "mounts tmpfs for /tmp") fileAccess = flag.String("file-access", "exclusive", "mounts root in exclusive or shared mode") overlay = flag.Bool("overlay", false, "wrap filesystem mounts with writable tmpfs overlay") + vfs2 = flag.Bool("vfs2", false, "enable VFS2") + fuse = flag.Bool("fuse", false, "enable FUSE") parallel = flag.Bool("parallel", false, "run tests in parallel") runscPath = flag.String("runsc", "", "path to runsc binary") @@ -102,10 +102,17 @@ func runTestCaseNative(testBin string, tc gtest.TestCase, t *testing.T) { env = append(env, "TEST_UDS_ATTACH_TREE="+socketDir) } - cmd := exec.Command(testBin, gtest.FilterTestFlag+"="+tc.FullName()) + cmd := exec.Command(testBin, tc.Args()...) cmd.Env = env cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr + + if specutils.HasCapabilities(capability.CAP_NET_ADMIN) { + cmd.SysProcAttr = &syscall.SysProcAttr{ + Cloneflags: syscall.CLONE_NEWNET, + } + } + if err := cmd.Run(); err != nil { ws := err.(*exec.ExitError).Sys().(syscall.WaitStatus) t.Errorf("test %q exited with status %d, want 0", tc.FullName(), ws.ExitStatus()) @@ -118,26 +125,26 @@ func runTestCaseNative(testBin string, tc gtest.TestCase, t *testing.T) { // // Returns an error if the sandboxed application exits non-zero. func runRunsc(tc gtest.TestCase, spec *specs.Spec) error { - bundleDir, err := testutil.SetupBundleDir(spec) + bundleDir, cleanup, err := testutil.SetupBundleDir(spec) if err != nil { return fmt.Errorf("SetupBundleDir failed: %v", err) } - defer os.RemoveAll(bundleDir) + defer cleanup() - rootDir, err := testutil.SetupRootDir() + rootDir, cleanup, err := testutil.SetupRootDir() if err != nil { return fmt.Errorf("SetupRootDir failed: %v", err) } - defer os.RemoveAll(rootDir) + defer cleanup() name := tc.FullName() - id := testutil.UniqueContainerID() + id := testutil.RandomContainerID() log.Infof("Running test %q in container %q", name, id) specutils.LogSpec(spec) args := []string{ "-root", rootDir, - "-network=none", + "-network", *network, "-log-format=text", "-TESTONLY-unsafe-nonroot=true", "-net-raw=true", @@ -149,6 +156,12 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error { if *overlay { args = append(args, "-overlay") } + if *vfs2 { + args = append(args, "-vfs2") + if *fuse { + args = append(args, "-fuse") + } + } if *debug { args = append(args, "-debug", "-log-packets=true") } @@ -159,12 +172,14 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error { args = append(args, "-fsgofer-host-uds") } - if outDir, ok := syscall.Getenv("TEST_UNDECLARED_OUTPUTS_DIR"); ok { - tdir := filepath.Join(outDir, strings.Replace(name, "/", "_", -1)) - if err := os.MkdirAll(tdir, 0755); err != nil { + testLogDir := "" + if undeclaredOutputsDir, ok := syscall.Getenv("TEST_UNDECLARED_OUTPUTS_DIR"); ok { + // Create log directory dedicated for this test. + testLogDir = filepath.Join(undeclaredOutputsDir, strings.Replace(name, "/", "_", -1)) + if err := os.MkdirAll(testLogDir, 0755); err != nil { return fmt.Errorf("could not create test dir: %v", err) } - debugLogDir, err := ioutil.TempDir(tdir, "runsc") + debugLogDir, err := ioutil.TempDir(testLogDir, "runsc") if err != nil { return fmt.Errorf("could not create temp dir: %v", err) } @@ -200,22 +215,25 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error { cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr sig := make(chan os.Signal, 1) + defer close(sig) signal.Notify(sig, syscall.SIGTERM) + defer signal.Stop(sig) go func() { s, ok := <-sig if !ok { return } log.Warningf("%s: Got signal: %v", name, s) - done := make(chan bool) - go func() { - dArgs := append(args, "-alsologtostderr=true", "debug", "--stacks", id) - cmd := exec.Command(*runscPath, dArgs...) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - cmd.Run() + done := make(chan bool, 1) + dArgs := append([]string{}, args...) + dArgs = append(dArgs, "-alsologtostderr=true", "debug", "--stacks", id) + go func(dArgs []string) { + debug := exec.Command(*runscPath, dArgs...) + debug.Stdout = os.Stdout + debug.Stderr = os.Stderr + debug.Run() done <- true - }() + }(dArgs) timeout := time.After(3 * time.Second) select { @@ -225,19 +243,21 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error { } log.Warningf("Send SIGTERM to the sandbox process") - dArgs := append(args, "debug", + dArgs = append(args, "debug", fmt.Sprintf("--signal=%d", syscall.SIGTERM), id) - cmd = exec.Command(*runscPath, dArgs...) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - cmd.Run() + signal := exec.Command(*runscPath, dArgs...) + signal.Stdout = os.Stdout + signal.Stderr = os.Stderr + signal.Run() }() err = cmd.Run() - - signal.Stop(sig) - close(sig) + if err == nil && len(testLogDir) > 0 { + // If the test passed, then we erase the log directory. This speeds up + // uploading logs in continuous integration & saves on disk space. + os.RemoveAll(testLogDir) + } return err } @@ -294,7 +314,7 @@ func setupUDSTree(spec *specs.Spec) (cleanup func(), err error) { func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) { // Run a new container with the test executable and filter for the // given test suite and name. - spec := testutil.NewSpecWithArgs(testBin, gtest.FilterTestFlag+"="+tc.FullName()) + spec := testutil.NewSpecWithArgs(append([]string{testBin}, tc.Args()...)...) // Mark the root as writeable, as some tests attempt to // write to the rootfs, and expect EACCES, not EROFS. @@ -302,6 +322,7 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) { // Test spec comes with pre-defined mounts that we don't want. Reset it. spec.Mounts = nil + testTmpDir := "/tmp" if *useTmpfs { // Forces '/tmp' to be mounted as tmpfs, otherwise test that rely on // features only available in gVisor's internal tmpfs may fail. @@ -327,17 +348,38 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) { t.Fatalf("could not chmod temp dir: %v", err) } - spec.Mounts = append(spec.Mounts, specs.Mount{ - Type: "bind", - Destination: "/tmp", - Source: tmpDir, - }) + // "/tmp" is not replaced with a tmpfs mount inside the sandbox + // when it's not empty. This ensures that testTmpDir uses gofer + // in exclusive mode. + testTmpDir = tmpDir + if *fileAccess == "shared" { + // All external mounts except the root mount are shared. + spec.Mounts = append(spec.Mounts, specs.Mount{ + Type: "bind", + Destination: "/tmp", + Source: tmpDir, + }) + testTmpDir = "/tmp" + } } - // Set environment variable that indicates we are - // running in gVisor and with the given platform. + // Set environment variables that indicate we are running in gVisor with + // the given platform, network, and filesystem stack. platformVar := "TEST_ON_GVISOR" - env := append(os.Environ(), platformVar+"="+*platform) + networkVar := "GVISOR_NETWORK" + env := append(os.Environ(), platformVar+"="+*platform, networkVar+"="+*network) + vfsVar := "GVISOR_VFS" + if *vfs2 { + env = append(env, vfsVar+"=VFS2") + fuseVar := "FUSE_ENABLED" + if *fuse { + env = append(env, fuseVar+"=TRUE") + } else { + env = append(env, fuseVar+"=FALSE") + } + } else { + env = append(env, vfsVar+"=VFS1") + } // Remove env variables that cause the gunit binary to write output // files, since they will stomp on eachother, and on the output files @@ -350,12 +392,8 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) { // Set TEST_TMPDIR to /tmp, as some of the syscall tests require it to // be backed by tmpfs. - for i, kv := range env { - if strings.HasPrefix(kv, "TEST_TMPDIR=") { - env[i] = "TEST_TMPDIR=/tmp" - break - } - } + env = filterEnv(env, []string{"TEST_TMPDIR"}) + env = append(env, fmt.Sprintf("TEST_TMPDIR=%s", testTmpDir)) spec.Process.Env = env @@ -372,12 +410,12 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) { } } -// filterEnv returns an environment with the blacklisted variables removed. -func filterEnv(env, blacklist []string) []string { +// filterEnv returns an environment with the excluded variables removed. +func filterEnv(env, exclude []string) []string { var out []string for _, kv := range env { ok := true - for _, k := range blacklist { + for _, k := range exclude { if strings.HasPrefix(kv, k+"=") { ok = false break @@ -401,9 +439,10 @@ func matchString(a, b string) (bool, error) { func main() { flag.Parse() - if *testName == "" { - fatalf("test-name flag must be provided") + if flag.NArg() != 1 { + fatalf("test must be provided") } + testBin := flag.Args()[0] // Only argument. log.SetLevel(log.Info) if *debug { @@ -433,34 +472,31 @@ func main() { } } - // Get path to test binary. - fullTestName := filepath.Join(testDir, *testName) - testBin, err := testutil.FindFile(fullTestName) - if err != nil { - fatalf("FindFile(%q) failed: %v", fullTestName, err) - } - // Get all test cases in each binary. - testCases, err := gtest.ParseTestCases(testBin) + testCases, err := gtest.ParseTestCases(testBin, true) if err != nil { fatalf("ParseTestCases(%q) failed: %v", testBin, err) } // Get subset of tests corresponding to shard. - begin, end, err := testutil.TestBoundsForShard(len(testCases)) + indices, err := testutil.TestIndicesForShard(len(testCases)) if err != nil { fatalf("TestsForShard() failed: %v", err) } - testCases = testCases[begin:end] + + // Resolve the absolute path for the binary. + testBin, err = filepath.Abs(testBin) + if err != nil { + fatalf("Abs() failed: %v", err) + } // Run the tests. var tests []testing.InternalTest - for _, tc := range testCases { + for _, tci := range indices { // Capture tc. - tc := tc - testName := fmt.Sprintf("%s_%s", tc.Suite, tc.Name) + tc := testCases[tci] tests = append(tests, testing.InternalTest{ - Name: testName, + Name: fmt.Sprintf("%s_%s", tc.Suite, tc.Name), F: func(t *testing.T) { if *parallel { t.Parallel() diff --git a/test/runtimes/BUILD b/test/runtimes/BUILD index 2e125525b..066338ee3 100644 --- a/test/runtimes/BUILD +++ b/test/runtimes/BUILD @@ -1,53 +1,46 @@ -# These packages are used to run language runtime tests inside gVisor sandboxes. - -load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_test") -load("//test/runtimes:build_defs.bzl", "runtime_test") +load("//tools:defs.bzl", "bzl_library") +load("//test/runtimes:defs.bzl", "runtime_test") package(licenses = ["notice"]) -go_binary( - name = "runner", - testonly = 1, - srcs = ["runner.go"], - deps = [ - "//runsc/dockerutil", - "//runsc/testutil", - ], -) - runtime_test( - blacklist_file = "blacklist_go1.12.csv", - image = "gcr.io/gvisor-presubmit/go1.12", + name = "go1.12", + exclude_file = "exclude_go1.12.csv", lang = "go", + shard_count = 8, ) runtime_test( - blacklist_file = "blacklist_java11.csv", - image = "gcr.io/gvisor-presubmit/java11", + name = "java11", + batch = 100, + exclude_file = "exclude_java11.csv", lang = "java", + shard_count = 16, ) runtime_test( - blacklist_file = "blacklist_nodejs12.4.0.csv", - image = "gcr.io/gvisor-presubmit/nodejs12.4.0", + name = "nodejs12.4.0", + exclude_file = "exclude_nodejs12.4.0.csv", lang = "nodejs", + shard_count = 8, ) runtime_test( - blacklist_file = "blacklist_php7.3.6.csv", - image = "gcr.io/gvisor-presubmit/php7.3.6", + name = "php7.3.6", + exclude_file = "exclude_php7.3.6.csv", lang = "php", + shard_count = 8, ) runtime_test( - blacklist_file = "blacklist_python3.7.3.csv", - image = "gcr.io/gvisor-presubmit/python3.7.3", + name = "python3.7.3", + exclude_file = "exclude_python3.7.3.csv", lang = "python", + shard_count = 8, ) -go_test( - name = "blacklist_test", - size = "small", - srcs = ["blacklist_test.go"], - embed = [":runner"], +bzl_library( + name = "defs_bzl", + srcs = ["defs.bzl"], + visibility = ["//visibility:private"], ) diff --git a/test/runtimes/README.md b/test/runtimes/README.md deleted file mode 100644 index e41e78f77..000000000 --- a/test/runtimes/README.md +++ /dev/null @@ -1,41 +0,0 @@ -# Runtimes Tests Dockerfiles - -The Dockerfiles defined under this path are configured to host the execution of -the runtimes language tests. Each Dockerfile can support the language indicated -by its directory. - -The following runtimes are currently supported: - -- Go 1.12 -- Java 11 -- Node.js 12 -- PHP 7.3 -- Python 3.7 - -#### Prerequisites: - -1) [Install and configure Docker](https://docs.docker.com/install/) - -2) Build each Docker container from the runtimes/images directory: - -```bash -$ cd images -$ docker build -f Dockerfile_$LANG [-t $NAME] . -``` - -### Testing: - -If the prerequisites have been fulfilled, you can run the tests with the -following command: - -```bash -$ docker run --rm -it $NAME [FLAG] -``` - -Running the command with no flags will cause all the available tests to execute. - -Flags can be added for additional functionality: - -- --list: Print a list of all available tests -- --test <name>: Run a single test from the list of available tests -- --v: Print the language version diff --git a/test/runtimes/blacklist_go1.12.csv b/test/runtimes/blacklist_go1.12.csv deleted file mode 100644 index 8c8ae0c5d..000000000 --- a/test/runtimes/blacklist_go1.12.csv +++ /dev/null @@ -1,16 +0,0 @@ -test name,bug id,comment -cgo_errors,,FLAKY -cgo_test,,FLAKY -go_test:cmd/go,,FLAKY -go_test:cmd/vendor/golang.org/x/sys/unix,b/118783622,/dev devices missing -go_test:net,b/118784196,socket: invalid argument. Works as intended: see bug. -go_test:os,b/118780122,we have a pollable filesystem but that's a surprise -go_test:os/signal,b/118780860,/dev/pts not properly supported -go_test:runtime,b/118782341,sigtrap not reported or caught or something -go_test:syscall,b/118781998,bad bytes -- bad mem addr -race,b/118782931,thread sanitizer. Works as intended: b/62219744. -runtime:cpu124,b/118778254,segmentation fault -test:0_1,,FLAKY -testasan,, -testcarchive,b/118782924,no sigpipe -testshared,,FLAKY diff --git a/test/runtimes/blacklist_java11.csv b/test/runtimes/blacklist_java11.csv deleted file mode 100644 index c012e5a56..000000000 --- a/test/runtimes/blacklist_java11.csv +++ /dev/null @@ -1,126 +0,0 @@ -test name,bug id,comment -com/sun/crypto/provider/Cipher/PBE/PKCS12Cipher.java,,Fails in Docker -com/sun/jdi/NashornPopFrameTest.java,, -com/sun/jdi/ProcessAttachTest.java,, -com/sun/management/HotSpotDiagnosticMXBean/CheckOrigin.java,,Fails in Docker -com/sun/management/OperatingSystemMXBean/GetCommittedVirtualMemorySize.java,, -com/sun/management/UnixOperatingSystemMXBean/GetMaxFileDescriptorCount.sh,, -com/sun/tools/attach/AttachSelf.java,, -com/sun/tools/attach/BasicTests.java,, -com/sun/tools/attach/PermissionTest.java,, -com/sun/tools/attach/StartManagementAgent.java,, -com/sun/tools/attach/TempDirTest.java,, -com/sun/tools/attach/modules/Driver.java,, -java/lang/Character/CheckScript.java,,Fails in Docker -java/lang/Character/CheckUnicode.java,,Fails in Docker -java/lang/Class/GetPackageBootLoaderChildLayer.java,, -java/lang/ClassLoader/nativeLibrary/NativeLibraryTest.java,,Fails in Docker -java/lang/String/nativeEncoding/StringPlatformChars.java,, -java/net/DatagramSocket/ReuseAddressTest.java,, -java/net/DatagramSocket/SendDatagramToBadAddress.java,b/78473345, -java/net/Inet4Address/PingThis.java,, -java/net/InterfaceAddress/NetworkPrefixLength.java,b/78507103, -java/net/MulticastSocket/MulticastTTL.java,, -java/net/MulticastSocket/Promiscuous.java,, -java/net/MulticastSocket/SetLoopbackMode.java,, -java/net/MulticastSocket/SetTTLAndGetTTL.java,, -java/net/MulticastSocket/Test.java,, -java/net/MulticastSocket/TestDefaults.java,, -java/net/MulticastSocket/TimeToLive.java,, -java/net/NetworkInterface/NetworkInterfaceStreamTest.java,, -java/net/Socket/SetSoLinger.java,b/78527327,SO_LINGER is not yet supported -java/net/Socket/TrafficClass.java,b/78527818,Not supported on gVisor -java/net/Socket/UrgentDataTest.java,b/111515323, -java/net/Socket/setReuseAddress/Basic.java,b/78519214,SO_REUSEADDR enabled by default -java/net/SocketOption/OptionsTest.java,,Fails in Docker -java/net/SocketOption/TcpKeepAliveTest.java,, -java/net/SocketPermission/SocketPermissionTest.java,, -java/net/URLConnection/6212146/TestDriver.java,,Fails in Docker -java/net/httpclient/RequestBuilderTest.java,,Fails in Docker -java/net/httpclient/ShortResponseBody.java,, -java/net/httpclient/ShortResponseBodyWithRetry.java,, -java/nio/channels/AsyncCloseAndInterrupt.java,, -java/nio/channels/AsynchronousServerSocketChannel/Basic.java,, -java/nio/channels/AsynchronousSocketChannel/Basic.java,b/77921528,SO_KEEPALIVE is not settable -java/nio/channels/DatagramChannel/BasicMulticastTests.java,, -java/nio/channels/DatagramChannel/SocketOptionTests.java,,Fails in Docker -java/nio/channels/DatagramChannel/UseDGWithIPv6.java,, -java/nio/channels/FileChannel/directio/DirectIOTest.java,,Fails in Docker -java/nio/channels/Selector/OutOfBand.java,, -java/nio/channels/Selector/SelectWithConsumer.java,,Flaky -java/nio/channels/ServerSocketChannel/SocketOptionTests.java,, -java/nio/channels/SocketChannel/LingerOnClose.java,, -java/nio/channels/SocketChannel/SocketOptionTests.java,b/77965901, -java/nio/channels/spi/SelectorProvider/inheritedChannel/InheritedChannelTest.java,,Fails in Docker -java/rmi/activation/Activatable/extLoadedImpl/ext.sh,, -java/rmi/transport/checkLeaseInfoLeak/CheckLeaseLeak.java,, -java/text/Format/NumberFormat/CurrencyFormat.java,,Fails in Docker -java/text/Format/NumberFormat/CurrencyFormat.java,,Fails in Docker -java/util/Calendar/JapaneseEraNameTest.java,, -java/util/Currency/CurrencyTest.java,,Fails in Docker -java/util/Currency/ValidateISO4217.java,,Fails in Docker -java/util/Locale/LSRDataTest.java,, -java/util/concurrent/locks/Lock/TimedAcquireLeak.java,, -java/util/jar/JarFile/mrjar/MultiReleaseJarAPI.java,,Fails in Docker -java/util/logging/LogManager/Configuration/updateConfiguration/SimpleUpdateConfigWithInputStreamTest.java,, -java/util/logging/TestLoggerWeakRefLeak.java,, -javax/imageio/AppletResourceTest.java,, -javax/management/security/HashedPasswordFileTest.java,, -javax/net/ssl/SSLSession/JSSERenegotiate.java,,Fails in Docker -javax/sound/sampled/AudioInputStream/FrameLengthAfterConversion.java,, -jdk/jfr/event/runtime/TestNetworkUtilizationEvent.java,, -jdk/jfr/event/runtime/TestThreadParkEvent.java,, -jdk/jfr/event/sampling/TestNative.java,, -jdk/jfr/jcmd/TestJcmdChangeLogLevel.java,, -jdk/jfr/jcmd/TestJcmdConfigure.java,, -jdk/jfr/jcmd/TestJcmdDump.java,, -jdk/jfr/jcmd/TestJcmdDumpGeneratedFilename.java,, -jdk/jfr/jcmd/TestJcmdDumpLimited.java,, -jdk/jfr/jcmd/TestJcmdDumpPathToGCRoots.java,, -jdk/jfr/jcmd/TestJcmdLegacy.java,, -jdk/jfr/jcmd/TestJcmdSaveToFile.java,, -jdk/jfr/jcmd/TestJcmdStartDirNotExist.java,, -jdk/jfr/jcmd/TestJcmdStartInvaldFile.java,, -jdk/jfr/jcmd/TestJcmdStartPathToGCRoots.java,, -jdk/jfr/jcmd/TestJcmdStartStopDefault.java,, -jdk/jfr/jcmd/TestJcmdStartWithOptions.java,, -jdk/jfr/jcmd/TestJcmdStartWithSettings.java,, -jdk/jfr/jcmd/TestJcmdStopInvalidFile.java,, -jdk/jfr/jvm/TestJfrJavaBase.java,, -jdk/jfr/startupargs/TestStartRecording.java,, -jdk/modules/incubator/ImageModules.java,, -jdk/net/Sockets/ExtOptionTest.java,, -jdk/net/Sockets/QuickAckTest.java,, -lib/security/cacerts/VerifyCACerts.java,, -sun/management/jmxremote/bootstrap/CustomLauncherTest.java,, -sun/management/jmxremote/bootstrap/JvmstatCountersTest.java,, -sun/management/jmxremote/bootstrap/LocalManagementTest.java,, -sun/management/jmxremote/bootstrap/RmiRegistrySslTest.java,, -sun/management/jmxremote/bootstrap/RmiSslBootstrapTest.sh,, -sun/management/jmxremote/startstop/JMXStartStopTest.java,, -sun/management/jmxremote/startstop/JMXStatusPerfCountersTest.java,, -sun/management/jmxremote/startstop/JMXStatusTest.java,, -sun/text/resources/LocaleDataTest.java,, -sun/tools/jcmd/TestJcmdSanity.java,, -sun/tools/jhsdb/AlternateHashingTest.java,, -sun/tools/jhsdb/BasicLauncherTest.java,, -sun/tools/jhsdb/HeapDumpTest.java,, -sun/tools/jhsdb/heapconfig/JMapHeapConfigTest.java,, -sun/tools/jinfo/BasicJInfoTest.java,, -sun/tools/jinfo/JInfoTest.java,, -sun/tools/jmap/BasicJMapTest.java,, -sun/tools/jstack/BasicJStackTest.java,, -sun/tools/jstack/DeadlockDetectionTest.java,, -sun/tools/jstatd/TestJstatdExternalRegistry.java,, -sun/tools/jstatd/TestJstatdPort.java,,Flaky -sun/tools/jstatd/TestJstatdPortAndServer.java,,Flaky -sun/util/calendar/zi/TestZoneInfo310.java,, -tools/jar/modularJar/Basic.java,, -tools/jar/multiRelease/Basic.java,, -tools/jimage/JImageExtractTest.java,, -tools/jimage/JImageTest.java,, -tools/jlink/JLinkTest.java,, -tools/jlink/plugins/IncludeLocalesPluginTest.java,, -tools/jmod/hashes/HashesTest.java,, -tools/launcher/BigJar.java,b/111611473, -tools/launcher/modules/patch/systemmodules/PatchSystemModules.java,, diff --git a/test/runtimes/blacklist_nodejs12.4.0.csv b/test/runtimes/blacklist_nodejs12.4.0.csv deleted file mode 100644 index 4ab4e2927..000000000 --- a/test/runtimes/blacklist_nodejs12.4.0.csv +++ /dev/null @@ -1,47 +0,0 @@ -test name,bug id,comment -benchmark/test-benchmark-fs.js,, -benchmark/test-benchmark-module.js,, -benchmark/test-benchmark-napi.js,, -doctool/test-make-doc.js,b/68848110,Expected to fail. -fixtures/test-error-first-line-offset.js,, -fixtures/test-fs-readfile-error.js,, -fixtures/test-fs-stat-sync-overflow.js,, -internet/test-dgram-broadcast-multi-process.js,, -internet/test-dgram-multicast-multi-process.js,, -internet/test-dgram-multicast-set-interface-lo.js,, -parallel/test-cluster-dgram-reuse.js,b/64024294, -parallel/test-dgram-bind-fd.js,b/132447356, -parallel/test-dgram-create-socket-handle-fd.js,b/132447238, -parallel/test-dgram-createSocket-type.js,b/68847739, -parallel/test-dgram-socket-buffer-size.js,b/68847921, -parallel/test-fs-access.js,, -parallel/test-fs-write-stream-double-close.js,, -parallel/test-fs-write-stream-throw-type-error.js,b/110226209, -parallel/test-fs-write-stream.js,, -parallel/test-http2-respond-file-error-pipe-offset.js,, -parallel/test-os.js,, -parallel/test-process-uid-gid.js,, -pseudo-tty/test-assert-colors.js,, -pseudo-tty/test-assert-no-color.js,, -pseudo-tty/test-assert-position-indicator.js,, -pseudo-tty/test-async-wrap-getasyncid-tty.js,, -pseudo-tty/test-fatal-error.js,, -pseudo-tty/test-handle-wrap-isrefed-tty.js,, -pseudo-tty/test-readable-tty-keepalive.js,, -pseudo-tty/test-set-raw-mode-reset-process-exit.js,, -pseudo-tty/test-set-raw-mode-reset-signal.js,, -pseudo-tty/test-set-raw-mode-reset.js,, -pseudo-tty/test-stderr-stdout-handle-sigwinch.js,, -pseudo-tty/test-stdout-read.js,, -pseudo-tty/test-tty-color-support.js,, -pseudo-tty/test-tty-isatty.js,, -pseudo-tty/test-tty-stdin-call-end.js,, -pseudo-tty/test-tty-stdin-end.js,, -pseudo-tty/test-stdin-write.js,, -pseudo-tty/test-tty-stdout-end.js,, -pseudo-tty/test-tty-stdout-resize.js,, -pseudo-tty/test-tty-stream-constructors.js,, -pseudo-tty/test-tty-window-size.js,, -pseudo-tty/test-tty-wrap.js,, -pummel/test-net-pingpong.js,, -pummel/test-vm-memleak.js,, diff --git a/test/runtimes/blacklist_python3.7.3.csv b/test/runtimes/blacklist_python3.7.3.csv deleted file mode 100644 index 2b9947212..000000000 --- a/test/runtimes/blacklist_python3.7.3.csv +++ /dev/null @@ -1,27 +0,0 @@ -test name,bug id,comment -test_asynchat,b/76031995,SO_REUSEADDR -test_asyncio,,Fails on Docker. -test_asyncore,b/76031995,SO_REUSEADDR -test_epoll,, -test_fcntl,,fcntl invalid argument -- artificial test to make sure something works in 64 bit mode. -test_ftplib,,Fails in Docker -test_httplib,b/76031995,SO_REUSEADDR -test_imaplib,, -test_logging,, -test_multiprocessing_fork,,Flaky. Sometimes times out. -test_multiprocessing_forkserver,,Flaky. Sometimes times out. -test_multiprocessing_main_handling,,Flaky. Sometimes times out. -test_multiprocessing_spawn,,Flaky. Sometimes times out. -test_nntplib,b/76031995,tests should not set SO_REUSEADDR -test_poplib,,Fails on Docker -test_posix,b/76174079,posix.sched_get_priority_min not implemented + posix.sched_rr_get_interval not permitted -test_pty,b/76157709,out of pty devices -test_readline,b/76157709,out of pty devices -test_resource,b/76174079, -test_selectors,b/76116849,OSError not raised with epoll -test_smtplib,b/76031995,SO_REUSEADDR and unclosed sockets -test_socket,b/75983380, -test_ssl,b/76031995,SO_REUSEADDR -test_subprocess,, -test_support,b/76031995,SO_REUSEADDR -test_telnetlib,b/76031995,SO_REUSEADDR diff --git a/test/runtimes/build_defs.bzl b/test/runtimes/build_defs.bzl deleted file mode 100644 index 7c11624b4..000000000 --- a/test/runtimes/build_defs.bzl +++ /dev/null @@ -1,57 +0,0 @@ -"""Defines a rule for runtime test targets.""" - -load("@io_bazel_rules_go//go:def.bzl", "go_test") - -# runtime_test is a macro that will create targets to run the given test target -# with different runtime options. -def runtime_test( - lang, - image, - shard_count = 50, - size = "enormous", - blacklist_file = ""): - args = [ - "--lang", - lang, - "--image", - image, - ] - data = [ - ":runner", - ] - if blacklist_file != "": - args += ["--blacklist_file", "test/runtimes/" + blacklist_file] - data += [blacklist_file] - - # Add a test that the blacklist parses correctly. - blacklist_test(lang, blacklist_file) - - sh_test( - name = lang + "_test", - srcs = ["runner.sh"], - args = args, - data = data, - size = size, - shard_count = shard_count, - tags = [ - # Requires docker and runsc to be configured before the test runs. - "manual", - "local", - ], - ) - -def blacklist_test(lang, blacklist_file): - """Test that a blacklist parses correctly.""" - go_test( - name = lang + "_blacklist_test", - embed = [":runner"], - srcs = ["blacklist_test.go"], - args = ["--blacklist_file", "test/runtimes/" + blacklist_file], - data = [blacklist_file], - ) - -def sh_test(**kwargs): - """Wraps the standard sh_test.""" - native.sh_test( - **kwargs - ) diff --git a/test/runtimes/defs.bzl b/test/runtimes/defs.bzl new file mode 100644 index 000000000..702522d86 --- /dev/null +++ b/test/runtimes/defs.bzl @@ -0,0 +1,90 @@ +"""Defines a rule for runtime test targets.""" + +load("//tools:defs.bzl", "go_test") + +def _runtime_test_impl(ctx): + # Construct arguments. + args = [ + "--lang", + ctx.attr.lang, + "--image", + ctx.attr.image, + "--batch", + str(ctx.attr.batch), + ] + if ctx.attr.exclude_file: + args += [ + "--exclude_file", + ctx.files.exclude_file[0].short_path, + ] + + # Build a runner. + runner = ctx.actions.declare_file("%s-executer" % ctx.label.name) + runner_content = "\n".join([ + "#!/bin/bash", + "%s %s $@\n" % (ctx.files._runner[0].short_path, " ".join(args)), + ]) + ctx.actions.write(runner, runner_content, is_executable = True) + + # Return the runner. + return [DefaultInfo( + executable = runner, + runfiles = ctx.runfiles( + files = ctx.files._runner + ctx.files.exclude_file + ctx.files._proctor, + collect_default = True, + collect_data = True, + ), + )] + +_runtime_test = rule( + implementation = _runtime_test_impl, + attrs = { + "image": attr.string( + mandatory = False, + ), + "lang": attr.string( + mandatory = True, + ), + "exclude_file": attr.label( + mandatory = False, + allow_single_file = True, + ), + "batch": attr.int( + default = 50, + mandatory = False, + ), + "_runner": attr.label( + default = "//test/runtimes/runner:runner", + executable = True, + cfg = "target", + ), + "_proctor": attr.label( + default = "//test/runtimes/proctor:proctor", + executable = True, + cfg = "target", + ), + }, + test = True, +) + +def runtime_test(name, **kwargs): + _runtime_test( + name = name, + image = name, # Resolved as images/runtimes/%s. + tags = [ + "local", + "manual", + ], + size = "enormous", + **kwargs + ) + +def exclude_test(name, exclude_file): + """Test that a exclude file parses correctly.""" + go_test( + name = name + "_exclude_test", + library = ":runner", + srcs = ["exclude_test.go"], + args = ["--exclude_file", "test/runtimes/" + exclude_file], + data = [exclude_file], + ) diff --git a/test/runtimes/exclude_go1.12.csv b/test/runtimes/exclude_go1.12.csv new file mode 100644 index 000000000..81e02cf64 --- /dev/null +++ b/test/runtimes/exclude_go1.12.csv @@ -0,0 +1,13 @@ +test name,bug id,comment +cgo_errors,,FLAKY +cgo_test,,FLAKY +go_test:cmd/go,,FLAKY +go_test:net,b/162473575,setsockopt: protocol not available. +go_test:os,b/118780122,we have a pollable filesystem but that's a surprise +go_test:os/signal,b/118780860,/dev/pts not properly supported. Also being tracked in b/29356795. +go_test:runtime,b/118782341,sigtrap not reported or caught or something. Also being tracked in b/33003106. +go_test:syscall,b/118781998,bad bytes -- bad mem addr; FcntlFlock(F_GETLK) not supported. +runtime:cpu124,b/118778254,segmentation fault +test:0_1,,FLAKY +testcarchive,b/118782924,no sigpipe +testshared,,FLAKY diff --git a/test/runtimes/exclude_java11.csv b/test/runtimes/exclude_java11.csv new file mode 100644 index 000000000..997a29cad --- /dev/null +++ b/test/runtimes/exclude_java11.csv @@ -0,0 +1,208 @@ +test name,bug id,comment +com/sun/crypto/provider/Cipher/PBE/PKCS12Cipher.java,,Fails in Docker +com/sun/jdi/NashornPopFrameTest.java,, +com/sun/jdi/ProcessAttachTest.java,, +com/sun/management/HotSpotDiagnosticMXBean/CheckOrigin.java,,Fails in Docker +com/sun/management/OperatingSystemMXBean/GetCommittedVirtualMemorySize.java,, +com/sun/management/UnixOperatingSystemMXBean/GetMaxFileDescriptorCount.sh,, +com/sun/tools/attach/AttachSelf.java,, +com/sun/tools/attach/BasicTests.java,, +com/sun/tools/attach/PermissionTest.java,, +com/sun/tools/attach/StartManagementAgent.java,, +com/sun/tools/attach/TempDirTest.java,, +com/sun/tools/attach/modules/Driver.java,, +java/lang/Character/CheckScript.java,,Fails in Docker +java/lang/Character/CheckUnicode.java,,Fails in Docker +java/lang/Class/GetPackageBootLoaderChildLayer.java,, +java/lang/ClassLoader/nativeLibrary/NativeLibraryTest.java,,Fails in Docker +java/lang/module/ModuleDescriptorTest.java,, +java/lang/String/nativeEncoding/StringPlatformChars.java,, +java/net/CookieHandler/B6791927.java,,java.lang.RuntimeException: Expiration date shouldn't be 0 +java/net/ipv6tests/TcpTest.java,,java.net.ConnectException: Connection timed out (Connection timed out) +java/net/ipv6tests/UdpTest.java,,Times out +java/net/Inet6Address/B6558853.java,,Times out +java/net/InetAddress/CheckJNI.java,,java.net.ConnectException: Connection timed out (Connection timed out) +java/net/InterfaceAddress/NetworkPrefixLength.java,b/78507103, +java/net/MulticastSocket/B6425815.java,,java.net.SocketException: Protocol not available (Error getting socket option) +java/net/MulticastSocket/B6427403.java,,java.net.SocketException: Protocol not available +java/net/MulticastSocket/MulticastTTL.java,, +java/net/MulticastSocket/NetworkInterfaceEmptyGetInetAddressesTest.java,,java.net.SocketException: Protocol not available (Error getting socket option) +java/net/MulticastSocket/NoLoopbackPackets.java,,java.net.SocketException: Protocol not available +java/net/MulticastSocket/Promiscuous.java,, +java/net/MulticastSocket/SetLoopbackMode.java,, +java/net/MulticastSocket/SetTTLAndGetTTL.java,, +java/net/MulticastSocket/Test.java,, +java/net/MulticastSocket/TestDefaults.java,, +java/net/MulticastSocket/TimeToLive.java,, +java/net/NetworkInterface/NetworkInterfaceStreamTest.java,, +java/net/Socket/LinkLocal.java,,java.net.SocketTimeoutException: Receive timed out +java/net/Socket/SetSoLinger.java,b/78527327,SO_LINGER is not yet supported +java/net/Socket/UrgentDataTest.java,b/111515323, +java/net/SocketOption/OptionsTest.java,,Fails in Docker +java/net/SocketPermission/SocketPermissionTest.java,, +java/net/URLConnection/6212146/TestDriver.java,,Fails in Docker +java/net/httpclient/RequestBuilderTest.java,,Fails in Docker +java/nio/channels/DatagramChannel/BasicMulticastTests.java,, +java/nio/channels/DatagramChannel/SocketOptionTests.java,,java.net.SocketException: Invalid argument +java/nio/channels/DatagramChannel/UseDGWithIPv6.java,, +java/nio/channels/FileChannel/directio/DirectIOTest.java,,Fails in Docker +java/nio/channels/FileChannel/directio/PwriteDirect.java,,java.io.IOException: Invalid argument +java/nio/channels/Selector/OutOfBand.java,, +java/nio/channels/Selector/SelectWithConsumer.java,,Flaky +java/nio/channels/ServerSocketChannel/SocketOptionTests.java,, +java/nio/channels/SocketChannel/LingerOnClose.java,, +java/nio/channels/SocketChannel/SocketOptionTests.java,b/77965901, +java/nio/channels/spi/SelectorProvider/inheritedChannel/InheritedChannelTest.java,,Fails in Docker +java/rmi/activation/Activatable/extLoadedImpl/ext.sh,, +java/rmi/transport/checkLeaseInfoLeak/CheckLeaseLeak.java,, +java/text/Format/NumberFormat/CurrencyFormat.java,,Fails in Docker +java/text/Format/NumberFormat/CurrencyFormat.java,,Fails in Docker +java/util/Calendar/JapaneseEraNameTest.java,, +java/util/Currency/CurrencyTest.java,,Fails in Docker +java/util/Currency/ValidateISO4217.java,,Fails in Docker +java/util/EnumSet/BogusEnumSet.java,,"java.io.InvalidClassException: java.util.EnumSet; local class incompatible: stream classdesc serialVersionUID = -2409567991088730183, local class serialVersionUID = 1009687484059888093" +java/util/Locale/Bug8040211.java,,java.lang.RuntimeException: Failed. +java/util/Locale/LSRDataTest.java,, +java/util/Properties/CompatibilityTest.java,,"java.lang.RuntimeException: jdk.internal.org.xml.sax.SAXParseException; Internal DTD subset is not allowed. The Properties XML document must have the following DOCTYPE declaration: <!DOCTYPE properties SYSTEM ""http://java.sun.com/dtd/properties.dtd"">" +java/util/ResourceBundle/Control/XMLResourceBundleTest.java,,java.util.MissingResourceException: Can't find bundle for base name XmlRB locale +java/util/ResourceBundle/modules/xmlformat/xmlformat.sh,,Timeout reached: 60000. Process is not alive! +java/util/TimeZone/TimeZoneTest.java,,Uncaught exception thrown in test method TestShortZoneIDs +java/util/concurrent/locks/Lock/TimedAcquireLeak.java,, +java/util/jar/JarFile/mrjar/MultiReleaseJarAPI.java,,Fails in Docker +java/util/logging/LogManager/Configuration/updateConfiguration/SimpleUpdateConfigWithInputStreamTest.java,, +java/util/logging/TestLoggerWeakRefLeak.java,, +java/util/spi/ResourceBundleControlProvider/UserDefaultControlTest.java,,java.util.MissingResourceException: Can't find bundle for base name com.foo.XmlRB locale +javax/imageio/AppletResourceTest.java,, +javax/imageio/plugins/jpeg/JPEGsNotAcceleratedTest.java,,java.awt.HeadlessException: No X11 DISPLAY variable was set but this program performed an operation which requires it. +javax/management/security/HashedPasswordFileTest.java,, +javax/net/ssl/DTLS/DTLSBufferOverflowUnderflowTest.java,,Compilation failed +javax/net/ssl/DTLS/DTLSDataExchangeTest.java,,Compilation failed +javax/net/ssl/DTLS/DTLSEnginesClosureTest.java,,Compilation failed +javax/net/ssl/DTLS/DTLSHandshakeTest.java,,Compilation failed +javax/net/ssl/DTLS/DTLSHandshakeWithReplicatedPacketsTest.java,,Compilation failed +javax/net/ssl/DTLS/DTLSIncorrectAppDataTest.java,,Compilation failed +javax/net/ssl/DTLS/DTLSMFLNTest.java,,Compilation failed +javax/net/ssl/DTLS/DTLSNotEnabledRC4Test.java,,Compilation failed +javax/net/ssl/DTLS/DTLSRehandshakeTest.java,,Compilation failed +javax/net/ssl/DTLS/DTLSRehandshakeWithCipherChangeTest.java,,Compilation failed +javax/net/ssl/DTLS/DTLSRehandshakeWithDataExTest.java,,Compilation failed +javax/net/ssl/DTLS/DTLSSequenceNumberTest.java,,Compilation failed +javax/net/ssl/DTLS/DTLSUnsupportedCiphersTest.java,,Compilation failed +javax/net/ssl/DTLSv10/DTLSv10BufferOverflowUnderflowTest.java,,Compilation failed +javax/net/ssl/DTLSv10/DTLSv10DataExchangeTest.java,,Compilation failed +javax/net/ssl/DTLSv10/DTLSv10EnginesClosureTest.java,,Compilation failed +javax/net/ssl/DTLSv10/DTLSv10HandshakeTest.java,,Compilation failed +javax/net/ssl/DTLSv10/DTLSv10HandshakeWithReplicatedPacketsTest.java,,Compilation failed +javax/net/ssl/DTLSv10/DTLSv10IncorrectAppDataTest.java,,Compilation failed +javax/net/ssl/DTLSv10/DTLSv10MFLNTest.java,,Compilation failed +javax/net/ssl/DTLSv10/DTLSv10NotEnabledRC4Test.java,,Compilation failed +javax/net/ssl/DTLSv10/DTLSv10RehandshakeTest.java,,Compilation failed +javax/net/ssl/DTLSv10/DTLSv10RehandshakeWithCipherChangeTest.java,,Compilation failed +javax/net/ssl/DTLSv10/DTLSv10RehandshakeWithDataExTest.java,,Compilation failed +javax/net/ssl/DTLSv10/DTLSv10SequenceNumberTest.java,,Compilation failed +javax/net/ssl/DTLSv10/DTLSv10UnsupportedCiphersTest.java,,Compilation failed +javax/net/ssl/SSLSession/JSSERenegotiate.java,,Fails in Docker +javax/net/ssl/TLS/TLSDataExchangeTest.java,,Compilation failed +javax/net/ssl/TLS/TLSEnginesClosureTest.java,,Compilation failed +javax/net/ssl/TLS/TLSHandshakeTest.java,,Compilation failed +javax/net/ssl/TLS/TLSMFLNTest.java,,Compilation failed +javax/net/ssl/TLS/TLSNotEnabledRC4Test.java,,Compilation failed +javax/net/ssl/TLS/TLSRehandshakeTest.java,,Compilation failed +javax/net/ssl/TLS/TLSRehandshakeWithCipherChangeTest.java,,Compilation failed +javax/net/ssl/TLS/TLSRehandshakeWithDataExTest.java,,Compilation failed +javax/net/ssl/TLS/TLSUnsupportedCiphersTest.java,,Compilation failed +javax/net/ssl/TLSv1/TLSDataExchangeTest.java,,Compilation failed +javax/net/ssl/TLSv1/TLSEnginesClosureTest.java,,Compilation failed +javax/net/ssl/TLSv1/TLSHandshakeTest.java,,Compilation failed +javax/net/ssl/TLSv1/TLSMFLNTest.java,,Compilation failed +javax/net/ssl/TLSv1/TLSNotEnabledRC4Test.java,,Compilation failed +javax/net/ssl/TLSv1/TLSRehandshakeTest.java,,Compilation failed +javax/net/ssl/TLSv1/TLSRehandshakeWithCipherChangeTest.java,,Compilation failed +javax/net/ssl/TLSv1/TLSRehandshakeWithDataExTest.java,,Compilation failed +javax/net/ssl/TLSv1/TLSUnsupportedCiphersTest.java,,Compilation failed +javax/net/ssl/TLSv11/TLSDataExchangeTest.java,,Compilation failed +javax/net/ssl/TLSv11/TLSEnginesClosureTest.java,,Compilation failed +javax/net/ssl/TLSv11/TLSHandshakeTest.java,,Compilation failed +javax/net/ssl/TLSv11/TLSMFLNTest.java,,Compilation failed +javax/net/ssl/TLSv11/TLSNotEnabledRC4Test.java,,Compilation failed +javax/net/ssl/TLSv11/TLSRehandshakeTest.java,,Compilation failed +javax/net/ssl/TLSv11/TLSRehandshakeWithCipherChangeTest.java,,Compilation failed +javax/net/ssl/TLSv11/TLSRehandshakeWithDataExTest.java,,Compilation failed +javax/net/ssl/TLSv11/TLSUnsupportedCiphersTest.java,,Compilation failed +javax/net/ssl/TLSv12/TLSEnginesClosureTest.java,,Compilation failed +javax/sound/sampled/AudioInputStream/FrameLengthAfterConversion.java,, +jdk/jfr/cmd/TestHelp.java,,java.lang.RuntimeException: 'Available commands are:' missing from stdout/stderr +jdk/jfr/cmd/TestPrint.java,,Missing file' missing from stdout/stderr +jdk/jfr/cmd/TestPrintDefault.java,,java.lang.RuntimeException: 'JVMInformation' missing from stdout/stderr +jdk/jfr/cmd/TestPrintJSON.java,,javax.script.ScriptException: <eval>:1:17 Expected an operand but found eof var jsonObject = ^ in <eval> at line number 1 at column number 17 +jdk/jfr/cmd/TestPrintXML.java,,org.xml.sax.SAXParseException; lineNumber: 1; columnNumber: 1; Premature end of file. +jdk/jfr/cmd/TestReconstruct.java,,java.lang.RuntimeException: 'Too few arguments' missing from stdout/stderr +jdk/jfr/cmd/TestSplit.java,,java.lang.RuntimeException: 'Missing file' missing from stdout/stderr +jdk/jfr/cmd/TestSummary.java,,java.lang.RuntimeException: 'Missing file' missing from stdout/stderr +jdk/jfr/event/compiler/TestCompilerStats.java,,java.lang.RuntimeException: Field nmetodsSize not in event +jdk/jfr/event/metadata/TestDefaultConfigurations.java,,Setting 'threshold' in event 'jdk.SecurityPropertyModification' was not configured in the configuration 'default' +jdk/jfr/event/runtime/TestActiveSettingEvent.java,,java.lang.Exception: Could not find setting with name jdk.X509Validation#threshold +jdk/jfr/event/runtime/TestModuleEvents.java,,java.lang.RuntimeException: assertEquals: expected jdk.proxy1 to equal java.base +jdk/jfr/event/runtime/TestNetworkUtilizationEvent.java,, +jdk/jfr/event/runtime/TestThreadParkEvent.java,, +jdk/jfr/event/sampling/TestNative.java,, +jdk/jfr/jcmd/TestJcmdChangeLogLevel.java,, +jdk/jfr/jcmd/TestJcmdConfigure.java,, +jdk/jfr/jcmd/TestJcmdDump.java,, +jdk/jfr/jcmd/TestJcmdDumpGeneratedFilename.java,, +jdk/jfr/jcmd/TestJcmdDumpLimited.java,, +jdk/jfr/jcmd/TestJcmdDumpPathToGCRoots.java,, +jdk/jfr/jcmd/TestJcmdLegacy.java,, +jdk/jfr/jcmd/TestJcmdSaveToFile.java,, +jdk/jfr/jcmd/TestJcmdStartDirNotExist.java,, +jdk/jfr/jcmd/TestJcmdStartInvaldFile.java,, +jdk/jfr/jcmd/TestJcmdStartPathToGCRoots.java,, +jdk/jfr/jcmd/TestJcmdStartStopDefault.java,, +jdk/jfr/jcmd/TestJcmdStartWithOptions.java,, +jdk/jfr/jcmd/TestJcmdStartWithSettings.java,, +jdk/jfr/jcmd/TestJcmdStopInvalidFile.java,, +jdk/jfr/jvm/TestGetAllEventClasses.java,,Compilation failed +jdk/jfr/jvm/TestJfrJavaBase.java,, +jdk/jfr/startupargs/TestStartRecording.java,, +jdk/modules/incubator/ImageModules.java,, +jdk/net/Sockets/ExtOptionTest.java,, +jdk/net/Sockets/QuickAckTest.java,, +lib/security/cacerts/VerifyCACerts.java,, +sun/management/jmxremote/bootstrap/CustomLauncherTest.java,, +sun/management/jmxremote/bootstrap/JvmstatCountersTest.java,, +sun/management/jmxremote/bootstrap/LocalManagementTest.java,, +sun/management/jmxremote/bootstrap/RmiRegistrySslTest.java,, +sun/management/jmxremote/bootstrap/RmiSslBootstrapTest.sh,, +sun/management/jmxremote/startstop/JMXStartStopTest.java,, +sun/management/jmxremote/startstop/JMXStatusPerfCountersTest.java,, +sun/management/jmxremote/startstop/JMXStatusTest.java,, +sun/management/jdp/JdpDefaultsTest.java,, +sun/management/jdp/JdpJmxRemoteDynamicPortTest.java,, +sun/management/jdp/JdpOffTest.java,, +sun/management/jdp/JdpSpecificAddressTest.java,, +sun/text/resources/LocaleDataTest.java,, +sun/tools/jcmd/TestJcmdSanity.java,, +sun/tools/jhsdb/AlternateHashingTest.java,, +sun/tools/jhsdb/BasicLauncherTest.java,, +sun/tools/jhsdb/HeapDumpTest.java,, +sun/tools/jhsdb/heapconfig/JMapHeapConfigTest.java,, +sun/tools/jinfo/BasicJInfoTest.java,, +sun/tools/jinfo/JInfoTest.java,, +sun/tools/jmap/BasicJMapTest.java,, +sun/tools/jstack/BasicJStackTest.java,, +sun/tools/jstack/DeadlockDetectionTest.java,, +sun/tools/jstatd/TestJstatdExternalRegistry.java,, +sun/tools/jstatd/TestJstatdPort.java,,Flaky +sun/tools/jstatd/TestJstatdPortAndServer.java,,Flaky +sun/util/calendar/zi/TestZoneInfo310.java,, +tools/jar/modularJar/Basic.java,, +tools/jar/multiRelease/Basic.java,, +tools/jimage/JImageExtractTest.java,, +tools/jimage/JImageTest.java,, +tools/jlink/JLinkTest.java,, +tools/jlink/plugins/IncludeLocalesPluginTest.java,, +tools/jmod/hashes/HashesTest.java,, +tools/launcher/BigJar.java,b/111611473, +tools/launcher/HelpFlagsTest.java,,java.lang.AssertionError: HelpFlagsTest failed: Tool jfr not covered by this test. Add specification to jdkTools array! +tools/launcher/VersionCheck.java,,java.lang.AssertionError: VersionCheck failed: testToolVersion: [jfr]; +tools/launcher/modules/patch/systemmodules/PatchSystemModules.java,, diff --git a/test/runtimes/exclude_nodejs12.4.0.csv b/test/runtimes/exclude_nodejs12.4.0.csv new file mode 100644 index 000000000..1d8e65fd0 --- /dev/null +++ b/test/runtimes/exclude_nodejs12.4.0.csv @@ -0,0 +1,55 @@ +test name,bug id,comment +benchmark/test-benchmark-fs.js,, +benchmark/test-benchmark-napi.js,, +doctool/test-make-doc.js,b/68848110,Expected to fail. +internet/test-dgram-multicast-set-interface-lo.js,b/162798882, +internet/test-doctool-versions.js,, +internet/test-uv-threadpool-schedule.js,, +parallel/test-cluster-dgram-reuse.js,b/64024294, +parallel/test-dgram-bind-fd.js,b/132447356, +parallel/test-dgram-socket-buffer-size.js,b/68847921, +parallel/test-dns-channel-timeout.js,b/161893056, +parallel/test-fs-access.js,, +parallel/test-fs-watchfile.js,,Flaky - File already exists error +parallel/test-fs-write-stream.js,,Flaky +parallel/test-fs-write-stream-throw-type-error.js,b/110226209, +parallel/test-http-writable-true-after-close.js,,Flaky - Mismatched <anonymous> function calls. Expected exactly 1 actual 2 +parallel/test-os.js,b/63997097, +parallel/test-net-server-listen-options.js,,Flaky - EADDRINUSE +parallel/test-process-uid-gid.js,, +parallel/test-tls-cli-min-version-1.0.js,,Flaky - EADDRINUSE +parallel/test-tls-cli-min-version-1.1.js,,Flaky - EADDRINUSE +parallel/test-tls-cli-min-version-1.2.js,,Flaky - EADDRINUSE +parallel/test-tls-cli-min-version-1.3.js,,Flaky - EADDRINUSE +parallel/test-tls-cli-max-version-1.2.js,,Flaky - EADDRINUSE +parallel/test-tls-cli-max-version-1.3.js,,Flaky - EADDRINUSE +parallel/test-tls-min-max-version.js,,Flaky - EADDRINUSE +pseudo-tty/test-assert-colors.js,b/162801321, +pseudo-tty/test-assert-no-color.js,b/162801321, +pseudo-tty/test-assert-position-indicator.js,b/162801321, +pseudo-tty/test-async-wrap-getasyncid-tty.js,b/162801321, +pseudo-tty/test-fatal-error.js,b/162801321, +pseudo-tty/test-handle-wrap-isrefed-tty.js,b/162801321, +pseudo-tty/test-readable-tty-keepalive.js,b/162801321, +pseudo-tty/test-set-raw-mode-reset-process-exit.js,b/162801321, +pseudo-tty/test-set-raw-mode-reset-signal.js,b/162801321, +pseudo-tty/test-set-raw-mode-reset.js,b/162801321, +pseudo-tty/test-stderr-stdout-handle-sigwinch.js,b/162801321, +pseudo-tty/test-stdout-read.js,b/162801321, +pseudo-tty/test-tty-color-support.js,b/162801321, +pseudo-tty/test-tty-isatty.js,b/162801321, +pseudo-tty/test-tty-stdin-call-end.js,b/162801321, +pseudo-tty/test-tty-stdin-end.js,b/162801321, +pseudo-tty/test-stdin-write.js,b/162801321, +pseudo-tty/test-tty-stdout-end.js,b/162801321, +pseudo-tty/test-tty-stdout-resize.js,b/162801321, +pseudo-tty/test-tty-stream-constructors.js,b/162801321, +pseudo-tty/test-tty-window-size.js,b/162801321, +pseudo-tty/test-tty-wrap.js,b/162801321, +pummel/test-heapdump-http2.js,,Flaky +pummel/test-net-pingpong.js,, +pummel/test-vm-memleak.js,b/162799436, +sequential/test-child-process-pass-fd.js,b/63926391,Flaky +sequential/test-https-connect-localport.js,,Flaky - EADDRINUSE +sequential/test-net-bytes-per-incoming-chunk-overhead.js,,flaky - timeout +tick-processor/test-tick-processor-builtin.js,, diff --git a/test/runtimes/blacklist_php7.3.6.csv b/test/runtimes/exclude_php7.3.6.csv index 456bf7487..2ce979dc8 100644 --- a/test/runtimes/blacklist_php7.3.6.csv +++ b/test/runtimes/exclude_php7.3.6.csv @@ -8,22 +8,31 @@ ext/mbstring/tests/bug77165.phpt,, ext/mbstring/tests/bug77454.phpt,, ext/mbstring/tests/mb_convert_encoding_leak.phpt,, ext/mbstring/tests/mb_strrpos_encoding_3rd_param.phpt,, -ext/standard/tests/file/filetype_variation.phpt,, -ext/standard/tests/file/fopen_variation19.phpt,, +ext/session/tests/session_module_name_variation4.phpt,,Flaky +ext/session/tests/session_set_save_handler_class_018.phpt,, +ext/session/tests/session_set_save_handler_iface_003.phpt,, +ext/session/tests/session_set_save_handler_sid_001.phpt,, +ext/session/tests/session_set_save_handler_variation4.phpt,, +ext/standard/tests/file/fopen_variation19.phpt,b/162894964, +ext/standard/tests/file/lstat_stat_variation14.phpt,,Flaky ext/standard/tests/file/php_fd_wrapper_01.phpt,, ext/standard/tests/file/php_fd_wrapper_02.phpt,, ext/standard/tests/file/php_fd_wrapper_03.phpt,, ext/standard/tests/file/php_fd_wrapper_04.phpt,, -ext/standard/tests/file/realpath_bug77484.phpt,, +ext/standard/tests/file/realpath_bug77484.phpt,b/162894969, ext/standard/tests/file/rename_variation.phpt,b/68717309, -ext/standard/tests/file/symlink_link_linkinfo_is_link_variation4.phpt,, -ext/standard/tests/file/symlink_link_linkinfo_is_link_variation8.phpt,, +ext/standard/tests/file/symlink_link_linkinfo_is_link_variation4.phpt,b/162895341, +ext/standard/tests/file/symlink_link_linkinfo_is_link_variation8.phpt,b/162896223, ext/standard/tests/general_functions/escapeshellarg_bug71270.phpt,, ext/standard/tests/general_functions/escapeshellcmd_bug71270.phpt,, -ext/standard/tests/network/bug20134.phpt,, +ext/standard/tests/streams/proc_open_bug69900.phpt,,Flaky +ext/standard/tests/streams/stream_socket_sendto.phpt,, +ext/standard/tests/strings/007.phpt,, +sapi/cli/tests/upload_2G.phpt,, tests/output/stream_isatty_err.phpt,b/68720279, tests/output/stream_isatty_in-err.phpt,b/68720282, tests/output/stream_isatty_in-out-err.phpt,, tests/output/stream_isatty_in-out.phpt,b/68720299, tests/output/stream_isatty_out-err.phpt,b/68720311, tests/output/stream_isatty_out.phpt,b/68720325, +Zend/tests/concat_003.phpt,b/162896021, diff --git a/test/runtimes/exclude_python3.7.3.csv b/test/runtimes/exclude_python3.7.3.csv new file mode 100644 index 000000000..8760f8951 --- /dev/null +++ b/test/runtimes/exclude_python3.7.3.csv @@ -0,0 +1,21 @@ +test name,bug id,comment +test_asyncio,,Fails on Docker. +test_asyncore,b/162973328, +test_epoll,b/162983393, +test_fcntl,b/162978767,fcntl invalid argument -- artificial test to make sure something works in 64 bit mode. +test_httplib,b/163000009,OSError: [Errno 98] Address already in use +test_imaplib,b/162979661, +test_logging,b/162980079, +test_multiprocessing_fork,,Flaky. Sometimes times out. +test_multiprocessing_forkserver,,Flaky. Sometimes times out. +test_multiprocessing_main_handling,,Flaky. Sometimes times out. +test_multiprocessing_spawn,,Flaky. Sometimes times out. +test_posix,b/76174079,posix.sched_get_priority_min not implemented + posix.sched_rr_get_interval not permitted +test_pty,b/162979921, +test_readline,b/162980389,TestReadline hangs forever +test_resource,b/76174079, +test_selectors,b/76116849,OSError not raised with epoll +test_smtplib,b/162980434,unclosed sockets +test_signal,,Flaky - signal: alarm clock +test_socket,b/75983380, +test_subprocess,b/162980831, diff --git a/test/runtimes/images/Dockerfile_go1.12 b/test/runtimes/images/Dockerfile_go1.12 deleted file mode 100644 index ab9d6abf3..000000000 --- a/test/runtimes/images/Dockerfile_go1.12 +++ /dev/null @@ -1,10 +0,0 @@ -# Go is easy, since we already have everything we need to compile the proctor -# binary and run the tests in the golang Docker image. -FROM golang:1.12 -ADD ["proctor/", "/go/src/proctor/"] -RUN ["go", "build", "-o", "/proctor", "/go/src/proctor"] - -# Pre-compile the tests so we don't need to do so in each test run. -RUN ["go", "tool", "dist", "test", "-compile-only"] - -ENTRYPOINT ["/proctor", "--runtime=go"] diff --git a/test/runtimes/images/Dockerfile_java11 b/test/runtimes/images/Dockerfile_java11 deleted file mode 100644 index 9b7c3d5a3..000000000 --- a/test/runtimes/images/Dockerfile_java11 +++ /dev/null @@ -1,30 +0,0 @@ -# Compile the proctor binary. -FROM golang:1.12 AS golang -ADD ["proctor/", "/go/src/proctor/"] -RUN ["go", "build", "-o", "/proctor", "/go/src/proctor"] - -FROM ubuntu:bionic -RUN apt-get update && apt-get install -y \ - autoconf \ - build-essential \ - curl \ - make \ - openjdk-11-jdk \ - unzip \ - zip - -# Download the JDK test library. -WORKDIR /root -RUN set -ex \ - && curl -fsSL --retry 10 -o /tmp/jdktests.tar.gz http://hg.openjdk.java.net/jdk/jdk11/archive/76072a077ee1.tar.gz/test \ - && tar -xzf /tmp/jdktests.tar.gz \ - && mv jdk11-76072a077ee1/test test \ - && rm -f /tmp/jdktests.tar.gz - -# Install jtreg and add to PATH. -RUN curl -o jtreg.tar.gz https://ci.adoptopenjdk.net/view/Dependencies/job/jtreg/lastSuccessfulBuild/artifact/jtreg-4.2.0-tip.tar.gz -RUN tar -xzf jtreg.tar.gz -ENV PATH="/root/jtreg/bin:$PATH" - -COPY --from=golang /proctor /proctor -ENTRYPOINT ["/proctor", "--runtime=java"] diff --git a/test/runtimes/images/Dockerfile_nodejs12.4.0 b/test/runtimes/images/Dockerfile_nodejs12.4.0 deleted file mode 100644 index 26f68b487..000000000 --- a/test/runtimes/images/Dockerfile_nodejs12.4.0 +++ /dev/null @@ -1,28 +0,0 @@ -# Compile the proctor binary. -FROM golang:1.12 AS golang -ADD ["proctor/", "/go/src/proctor/"] -RUN ["go", "build", "-o", "/proctor", "/go/src/proctor"] - -FROM ubuntu:bionic -RUN apt-get update && apt-get install -y \ - curl \ - dumb-init \ - g++ \ - make \ - python - -WORKDIR /root -ARG VERSION=v12.4.0 -RUN curl -o node-${VERSION}.tar.gz https://nodejs.org/dist/${VERSION}/node-${VERSION}.tar.gz -RUN tar -zxf node-${VERSION}.tar.gz - -WORKDIR /root/node-${VERSION} -RUN ./configure -RUN make -RUN make test-build - -COPY --from=golang /proctor /proctor - -# Including dumb-init emulates the Linux "init" process, preventing the failure -# of tests involving worker processes. -ENTRYPOINT ["/usr/bin/dumb-init", "/proctor", "--runtime=nodejs"] diff --git a/test/runtimes/images/Dockerfile_php7.3.6 b/test/runtimes/images/Dockerfile_php7.3.6 deleted file mode 100644 index e6b4c6329..000000000 --- a/test/runtimes/images/Dockerfile_php7.3.6 +++ /dev/null @@ -1,27 +0,0 @@ -# Compile the proctor binary. -FROM golang:1.12 AS golang -ADD ["proctor/", "/go/src/proctor/"] -RUN ["go", "build", "-o", "/proctor", "/go/src/proctor"] - -FROM ubuntu:bionic -RUN apt-get update && apt-get install -y \ - autoconf \ - automake \ - bison \ - build-essential \ - curl \ - libtool \ - libxml2-dev \ - re2c - -WORKDIR /root -ARG VERSION=7.3.6 -RUN curl -o php-${VERSION}.tar.gz https://www.php.net/distributions/php-${VERSION}.tar.gz -RUN tar -zxf php-${VERSION}.tar.gz - -WORKDIR /root/php-${VERSION} -RUN ./configure -RUN make - -COPY --from=golang /proctor /proctor -ENTRYPOINT ["/proctor", "--runtime=php"] diff --git a/test/runtimes/images/Dockerfile_python3.7.3 b/test/runtimes/images/Dockerfile_python3.7.3 deleted file mode 100644 index 905cd22d7..000000000 --- a/test/runtimes/images/Dockerfile_python3.7.3 +++ /dev/null @@ -1,30 +0,0 @@ -# Compile the proctor binary. -FROM golang:1.12 AS golang -ADD ["proctor/", "/go/src/proctor/"] -RUN ["go", "build", "-o", "/proctor", "/go/src/proctor"] - -FROM ubuntu:bionic - -RUN apt-get update && apt-get install -y \ - curl \ - gcc \ - libbz2-dev \ - libffi-dev \ - liblzma-dev \ - libreadline-dev \ - libssl-dev \ - make \ - zlib1g-dev - -# Use flags -LJO to follow the html redirect and download .tar.gz. -WORKDIR /root -ARG VERSION=3.7.3 -RUN curl -LJO https://github.com/python/cpython/archive/v${VERSION}.tar.gz -RUN tar -zxf cpython-${VERSION}.tar.gz - -WORKDIR /root/cpython-${VERSION} -RUN ./configure --with-pydebug -RUN make -s -j2 - -COPY --from=golang /proctor /proctor -ENTRYPOINT ["/proctor", "--runtime=python"] diff --git a/test/runtimes/images/proctor/BUILD b/test/runtimes/proctor/BUILD index 09dc6c42f..f76e2ddc0 100644 --- a/test/runtimes/images/proctor/BUILD +++ b/test/runtimes/proctor/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_test") +load("//tools:defs.bzl", "go_binary", "go_test") package(licenses = ["notice"]) @@ -12,15 +12,17 @@ go_binary( "proctor.go", "python.go", ], - visibility = ["//test/runtimes/images:__subpackages__"], + pure = True, + visibility = ["//test/runtimes:__pkg__"], ) go_test( name = "proctor_test", size = "small", srcs = ["proctor_test.go"], - embed = [":proctor"], + library = ":proctor", + pure = True, deps = [ - "//runsc/testutil", + "//pkg/test/testutil", ], ) diff --git a/test/runtimes/images/proctor/go.go b/test/runtimes/proctor/go.go index 3e2d5d8db..d0ae844e6 100644 --- a/test/runtimes/images/proctor/go.go +++ b/test/runtimes/proctor/go.go @@ -74,17 +74,26 @@ func (goRunner) ListTests() ([]string, error) { return append(toolSlice, diskFiltered...), nil } -// TestCmd implements TestRunner.TestCmd. -func (goRunner) TestCmd(test string) *exec.Cmd { - // Check if test exists on disk by searching for file of the same name. - // This will determine whether or not it is a Go test on disk. - if strings.HasSuffix(test, ".go") { - // Test has suffix ".go" which indicates a disk test, run it as such. - cmd := exec.Command("go", "run", "run.go", "-v", "--", test) +// TestCmds implements TestRunner.TestCmds. +func (goRunner) TestCmds(tests []string) []*exec.Cmd { + var toolTests, onDiskTests []string + for _, test := range tests { + if strings.HasSuffix(test, ".go") { + onDiskTests = append(onDiskTests, test) + } else { + toolTests = append(toolTests, "^"+test+"$") + } + } + + var cmds []*exec.Cmd + if len(toolTests) > 0 { + cmds = append(cmds, exec.Command("go", "tool", "dist", "test", "-v", "-no-rebuild", "-run", strings.Join(toolTests, "\\|"))) + } + if len(onDiskTests) > 0 { + cmd := exec.Command("go", append([]string{"run", "run.go", "-v", "--"}, onDiskTests...)...) cmd.Dir = goTestDir - return cmd + cmds = append(cmds, cmd) } - // No ".go" suffix, run as a tool test. - return exec.Command("go", "tool", "dist", "test", "-run", test) + return cmds } diff --git a/test/runtimes/images/proctor/java.go b/test/runtimes/proctor/java.go index 8b362029d..d456fa681 100644 --- a/test/runtimes/images/proctor/java.go +++ b/test/runtimes/proctor/java.go @@ -60,12 +60,17 @@ func (javaRunner) ListTests() ([]string, error) { return testSlice, nil } -// TestCmd implements TestRunner.TestCmd. -func (javaRunner) TestCmd(test string) *exec.Cmd { - args := []string{ - "-noreport", - "-dir:" + javaTestDir, - test, - } - return exec.Command("jtreg", args...) +// TestCmds implements TestRunner.TestCmds. +func (javaRunner) TestCmds(tests []string) []*exec.Cmd { + args := append( + []string{ + "-agentvm", // Execute each action using a pool of reusable JVMs. + "-dir:" + javaTestDir, // Base directory for test files and directories. + "-noreport", // Do not generate a final report. + "-timeoutFactor:20", // Extend the default timeout (2 min) of all tests by this factor. + "-verbose:nopass", // Verbose output but supress it for tests that passed. + }, + tests..., + ) + return []*exec.Cmd{exec.Command("jtreg", args...)} } diff --git a/test/runtimes/images/proctor/nodejs.go b/test/runtimes/proctor/nodejs.go index bd57db444..dead5af4f 100644 --- a/test/runtimes/images/proctor/nodejs.go +++ b/test/runtimes/proctor/nodejs.go @@ -39,8 +39,8 @@ func (nodejsRunner) ListTests() ([]string, error) { return testSlice, nil } -// TestCmd implements TestRunner.TestCmd. -func (nodejsRunner) TestCmd(test string) *exec.Cmd { - args := []string{filepath.Join("tools", "test.py"), test} - return exec.Command("/usr/bin/python", args...) +// TestCmds implements TestRunner.TestCmds. +func (nodejsRunner) TestCmds(tests []string) []*exec.Cmd { + args := append([]string{filepath.Join("tools", "test.py"), "--timeout=180"}, tests...) + return []*exec.Cmd{exec.Command("/usr/bin/python", args...)} } diff --git a/test/runtimes/images/proctor/php.go b/test/runtimes/proctor/php.go index 9115040e1..6a83d64e3 100644 --- a/test/runtimes/images/proctor/php.go +++ b/test/runtimes/proctor/php.go @@ -17,6 +17,7 @@ package main import ( "os/exec" "regexp" + "strings" ) var phpTestRegEx = regexp.MustCompile(`^.+\.phpt$`) @@ -35,8 +36,8 @@ func (phpRunner) ListTests() ([]string, error) { return testSlice, nil } -// TestCmd implements TestRunner.TestCmd. -func (phpRunner) TestCmd(test string) *exec.Cmd { - args := []string{"test", "TESTS=" + test} - return exec.Command("make", args...) +// TestCmds implements TestRunner.TestCmds. +func (phpRunner) TestCmds(tests []string) []*exec.Cmd { + args := []string{"test", "TESTS=" + strings.Join(tests, " ")} + return []*exec.Cmd{exec.Command("make", args...)} } diff --git a/test/runtimes/images/proctor/proctor.go b/test/runtimes/proctor/proctor.go index e6178e82b..9e0642424 100644 --- a/test/runtimes/images/proctor/proctor.go +++ b/test/runtimes/proctor/proctor.go @@ -25,6 +25,7 @@ import ( "os/signal" "path/filepath" "regexp" + "strings" "syscall" ) @@ -34,15 +35,18 @@ type TestRunner interface { // ListTests returns a string slice of tests available to run. ListTests() ([]string, error) - // TestCmd returns an *exec.Cmd that will run the given test. - TestCmd(test string) *exec.Cmd + // TestCmds returns a slice of *exec.Cmd that will run the given tests. + // There is no correlation between the number of exec.Cmds returned and the + // number of tests. It could return one command to run all tests or a few + // commands that collectively run all. + TestCmds(tests []string) []*exec.Cmd } var ( - runtime = flag.String("runtime", "", "name of runtime") - list = flag.Bool("list", false, "list all available tests") - test = flag.String("test", "", "run a single test from the list of available tests") - pause = flag.Bool("pause", false, "cause container to pause indefinitely, reaping any zombie children") + runtime = flag.String("runtime", "", "name of runtime") + list = flag.Bool("list", false, "list all available tests") + testNames = flag.String("tests", "", "run a subset of the available tests") + pause = flag.Bool("pause", false, "cause container to pause indefinitely, reaping any zombie children") ) func main() { @@ -74,14 +78,25 @@ func main() { return } - // Run a single test. - if *test == "" { - log.Fatalf("test flag must be provided") + var tests []string + if *testNames == "" { + // Run every test. + tests, err = tr.ListTests() + if err != nil { + log.Fatalf("failed to get all tests: %v", err) + } + } else { + // Run subset of test. + tests = strings.Split(*testNames, ",") } - cmd := tr.TestCmd(*test) - cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr - if err := cmd.Run(); err != nil { - log.Fatalf("FAIL: %v", err) + + // Run tests. + cmds := tr.TestCmds(tests) + for _, cmd := range cmds { + cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr + if err := cmd.Run(); err != nil { + log.Fatalf("FAIL: %v", err) + } } } diff --git a/test/runtimes/images/proctor/proctor_test.go b/test/runtimes/proctor/proctor_test.go index 6bb61d142..6ef2de085 100644 --- a/test/runtimes/images/proctor/proctor_test.go +++ b/test/runtimes/proctor/proctor_test.go @@ -23,24 +23,24 @@ import ( "strings" "testing" - "gvisor.dev/gvisor/runsc/testutil" + "gvisor.dev/gvisor/pkg/test/testutil" ) func touch(t *testing.T, name string) { t.Helper() f, err := os.Create(name) if err != nil { - t.Fatal(err) + t.Fatalf("error creating file %q: %v", name, err) } if err := f.Close(); err != nil { - t.Fatal(err) + t.Fatalf("error closing file %q: %v", name, err) } } func TestSearchEmptyDir(t *testing.T) { td, err := ioutil.TempDir(testutil.TmpDir(), "searchtest") if err != nil { - t.Fatal(err) + t.Fatalf("error creating searchtest: %v", err) } defer os.RemoveAll(td) @@ -60,7 +60,7 @@ func TestSearchEmptyDir(t *testing.T) { func TestSearch(t *testing.T) { td, err := ioutil.TempDir(testutil.TmpDir(), "searchtest") if err != nil { - t.Fatal(err) + t.Fatalf("error creating searchtest: %v", err) } defer os.RemoveAll(td) @@ -101,14 +101,14 @@ func TestSearch(t *testing.T) { if strings.HasSuffix(item, "/") { // This item is a directory, create it. if err := os.MkdirAll(filepath.Join(td, item), 0755); err != nil { - t.Fatal(err) + t.Fatalf("error making directory: %v", err) } } else { // This item is a file, create the directory and touch file. // Create directory in which file should be created fullDirPath := filepath.Join(td, filepath.Dir(item)) if err := os.MkdirAll(fullDirPath, 0755); err != nil { - t.Fatal(err) + t.Fatalf("error making directory: %v", err) } // Create file with full path to file. touch(t, filepath.Join(td, item)) diff --git a/test/runtimes/images/proctor/python.go b/test/runtimes/proctor/python.go index b9e0fbe6f..7c598801b 100644 --- a/test/runtimes/images/proctor/python.go +++ b/test/runtimes/proctor/python.go @@ -42,8 +42,8 @@ func (pythonRunner) ListTests() ([]string, error) { return toolSlice, nil } -// TestCmd implements TestRunner.TestCmd. -func (pythonRunner) TestCmd(test string) *exec.Cmd { - args := []string{"-m", "test", test} - return exec.Command("./python", args...) +// TestCmds implements TestRunner.TestCmds. +func (pythonRunner) TestCmds(tests []string) []*exec.Cmd { + args := append([]string{"-m", "test"}, tests...) + return []*exec.Cmd{exec.Command("./python", args...)} } diff --git a/test/runtimes/runner/BUILD b/test/runtimes/runner/BUILD new file mode 100644 index 000000000..dc0d5d5b4 --- /dev/null +++ b/test/runtimes/runner/BUILD @@ -0,0 +1,22 @@ +load("//tools:defs.bzl", "go_binary", "go_test") + +package(licenses = ["notice"]) + +go_binary( + name = "runner", + testonly = 1, + srcs = ["main.go"], + visibility = ["//test/runtimes:__pkg__"], + deps = [ + "//pkg/log", + "//pkg/test/dockerutil", + "//pkg/test/testutil", + ], +) + +go_test( + name = "exclude_test", + size = "small", + srcs = ["exclude_test.go"], + library = ":runner", +) diff --git a/test/runtimes/blacklist_test.go b/test/runtimes/runner/exclude_test.go index 52f49b984..67c2170c8 100644 --- a/test/runtimes/blacklist_test.go +++ b/test/runtimes/runner/exclude_test.go @@ -25,13 +25,13 @@ func TestMain(m *testing.M) { os.Exit(m.Run()) } -// Test that the blacklist parses without error. -func TestBlacklists(t *testing.T) { - bl, err := getBlacklist() +// Test that the exclude file parses without error. +func TestExcludelist(t *testing.T) { + ex, err := getExcludes() if err != nil { - t.Fatalf("error parsing blacklist: %v", err) + t.Fatalf("error parsing exclude file: %v", err) } - if *blacklistFile != "" && len(bl) == 0 { - t.Errorf("got empty blacklist for file %q", blacklistFile) + if *excludeFile != "" && len(ex) == 0 { + t.Errorf("got empty excludes for file %q", *excludeFile) } } diff --git a/test/runtimes/runner.go b/test/runtimes/runner/main.go index bec37c69d..948e7cf9c 100644 --- a/test/runtimes/runner.go +++ b/test/runtimes/runner/main.go @@ -16,29 +16,31 @@ package main import ( + "context" "encoding/csv" "flag" "fmt" "io" - "log" "os" "sort" "strings" "testing" "time" - "gvisor.dev/gvisor/runsc/dockerutil" - "gvisor.dev/gvisor/runsc/testutil" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/pkg/test/testutil" ) var ( - lang = flag.String("lang", "", "language runtime to test") - image = flag.String("image", "", "docker image with runtime tests") - blacklistFile = flag.String("blacklist_file", "", "file containing blacklist of tests to exclude, in CSV format with fields: test name, bug id, comment") + lang = flag.String("lang", "", "language runtime to test") + image = flag.String("image", "", "docker image with runtime tests") + excludeFile = flag.String("exclude_file", "", "file containing list of tests to exclude, in CSV format with fields: test name, bug id, comment") + batchSize = flag.Int("batch", 50, "number of test cases run in one command") ) // Wait time for each test to run. -const timeout = 5 * time.Minute +const timeout = 90 * time.Minute func main() { flag.Parse() @@ -46,7 +48,6 @@ func main() { fmt.Fprintf(os.Stderr, "lang and image flags must not be empty\n") os.Exit(1) } - os.Exit(runTests()) } @@ -54,21 +55,27 @@ func main() { // defered functions before exiting. It returns an exit code that should be // passed to os.Exit. func runTests() int { - // Get tests to blacklist. - blacklist, err := getBlacklist() + // Get tests to exclude.. + excludes, err := getExcludes() if err != nil { - fmt.Fprintf(os.Stderr, "Error getting blacklist: %s\n", err.Error()) + fmt.Fprintf(os.Stderr, "Error getting exclude list: %s\n", err.Error()) return 1 } - // Create a single docker container that will be used for all tests. - d := dockerutil.MakeDocker("gvisor-" + *lang) - defer d.CleanUp() + // Construct the shared docker instance. + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, testutil.DefaultLogger(*lang)) + defer d.CleanUp(ctx) + + if err := testutil.TouchShardStatusFile(); err != nil { + fmt.Fprintf(os.Stderr, "error touching status shard file: %v\n", err) + return 1 + } // Get a slice of tests to run. This will also start a single Docker // container that will be used to run each test. The final test will // stop the Docker container. - tests, err := getTests(d, blacklist) + tests, err := getTests(ctx, d, excludes) if err != nil { fmt.Fprintf(os.Stderr, "%s\n", err.Error()) return 1 @@ -78,21 +85,19 @@ func runTests() int { return m.Run() } -// getTests returns a slice of tests to run, subject to the shard size and -// index. -func getTests(d dockerutil.Docker, blacklist map[string]struct{}) ([]testing.InternalTest, error) { - // Pull the image. - if err := dockerutil.Pull(*image); err != nil { - return nil, fmt.Errorf("docker pull %q failed: %v", *image, err) +// getTests executes all tests as table tests. +func getTests(ctx context.Context, d *dockerutil.Container, excludes map[string]struct{}) ([]testing.InternalTest, error) { + // Start the container. + opts := dockerutil.RunOpts{ + Image: fmt.Sprintf("runtimes/%s", *image), } - - // Run proctor with --pause flag to keep container alive forever. - if err := d.Run(*image, "--pause"); err != nil { + d.CopyFiles(&opts, "/proctor", "test/runtimes/proctor/proctor") + if err := d.Spawn(ctx, opts, "/proctor/proctor", "--pause"); err != nil { return nil, fmt.Errorf("docker run failed: %v", err) } // Get a list of all tests in the image. - list, err := d.Exec("/proctor", "--runtime", *lang, "--list") + list, err := d.Exec(ctx, dockerutil.ExecOpts{}, "/proctor/proctor", "--runtime", *lang, "--list") if err != nil { return nil, fmt.Errorf("docker exec failed: %v", err) } @@ -101,25 +106,29 @@ func getTests(d dockerutil.Docker, blacklist map[string]struct{}) ([]testing.Int // shard. tests := strings.Fields(list) sort.Strings(tests) - begin, end, err := testutil.TestBoundsForShard(len(tests)) + indices, err := testutil.TestIndicesForShard(len(tests)) if err != nil { return nil, fmt.Errorf("TestsForShard() failed: %v", err) } - log.Printf("Got bounds [%d:%d) for shard out of %d total tests", begin, end, len(tests)) - tests = tests[begin:end] var itests []testing.InternalTest - for _, tc := range tests { - // Capture tc in this scope. - tc := tc + for i := 0; i < len(indices); i += *batchSize { + var tcs []string + end := i + *batchSize + if end > len(indices) { + end = len(indices) + } + for _, tc := range indices[i:end] { + // Add test if not excluded. + if _, ok := excludes[tests[tc]]; ok { + log.Infof("Skipping test case %s\n", tests[tc]) + continue + } + tcs = append(tcs, tests[tc]) + } itests = append(itests, testing.InternalTest{ - Name: tc, + Name: strings.Join(tcs, ", "), F: func(t *testing.T) { - // Is the test blacklisted? - if _, ok := blacklist[tc]; ok { - t.Skip("SKIP: blacklisted test %q", tc) - } - var ( now = time.Now() done = make(chan struct{}) @@ -128,39 +137,36 @@ func getTests(d dockerutil.Docker, blacklist map[string]struct{}) ([]testing.Int ) go func() { - fmt.Printf("RUNNING %s...\n", tc) - output, err = d.Exec("/proctor", "--runtime", *lang, "--test", tc) + fmt.Printf("RUNNING the following in a batch\n%s\n", strings.Join(tcs, "\n")) + output, err = d.Exec(ctx, dockerutil.ExecOpts{}, "/proctor/proctor", "--runtime", *lang, "--tests", strings.Join(tcs, ",")) close(done) }() select { case <-done: if err == nil { - fmt.Printf("PASS: %s (%v)\n\n", tc, time.Since(now)) + fmt.Printf("PASS: (%v)\n\n", time.Since(now)) return } - t.Errorf("FAIL: %s (%v):\n%s\n", tc, time.Since(now), output) + t.Errorf("FAIL: (%v):\n%s\n", time.Since(now), output) case <-time.After(timeout): - t.Errorf("TIMEOUT: %s (%v):\n%s\n", tc, time.Since(now), output) + t.Errorf("TIMEOUT: (%v):\n%s\n", time.Since(now), output) } }, }) } + return itests, nil } -// getBlacklist reads the blacklist file and returns a set of test names to +// getBlacklist reads the exclude file and returns a set of test names to // exclude. -func getBlacklist() (map[string]struct{}, error) { - blacklist := make(map[string]struct{}) - if *blacklistFile == "" { - return blacklist, nil - } - file, err := testutil.FindFile(*blacklistFile) - if err != nil { - return nil, err +func getExcludes() (map[string]struct{}, error) { + excludes := make(map[string]struct{}) + if *excludeFile == "" { + return excludes, nil } - f, err := os.Open(file) + f, err := os.Open(*excludeFile) if err != nil { return nil, err } @@ -181,9 +187,9 @@ func getBlacklist() (map[string]struct{}, error) { if err != nil { return nil, err } - blacklist[record[0]] = struct{}{} + excludes[record[0]] = struct{}{} } - return blacklist, nil + return excludes, nil } // testDeps implements testing.testDeps (an unexported interface), and is diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index a53a23afd..0eadc6b08 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -1,15 +1,18 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") -load("//test/syscalls:build_defs.bzl", "syscall_test") +load("//test/runner:defs.bzl", "syscall_test") package(licenses = ["notice"]) -syscall_test(test = "//test/syscalls/linux:32bit_test") +syscall_test( + test = "//test/syscalls/linux:32bit_test", +) -syscall_test(test = "//test/syscalls/linux:accept_bind_stream_test") +syscall_test( + test = "//test/syscalls/linux:accept_bind_stream_test", +) syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:accept_bind_test", ) @@ -18,7 +21,9 @@ syscall_test( test = "//test/syscalls/linux:access_test", ) -syscall_test(test = "//test/syscalls/linux:affinity_test") +syscall_test( + test = "//test/syscalls/linux:affinity_test", +) syscall_test( add_overlay = True, @@ -31,9 +36,13 @@ syscall_test( test = "//test/syscalls/linux:alarm_test", ) -syscall_test(test = "//test/syscalls/linux:arch_prctl_test") +syscall_test( + test = "//test/syscalls/linux:arch_prctl_test", +) -syscall_test(test = "//test/syscalls/linux:bad_test") +syscall_test( + test = "//test/syscalls/linux:bad_test", +) syscall_test( size = "large", @@ -41,9 +50,27 @@ syscall_test( test = "//test/syscalls/linux:bind_test", ) -syscall_test(test = "//test/syscalls/linux:brk_test") +syscall_test( + test = "//test/syscalls/linux:brk_test", +) -syscall_test(test = "//test/syscalls/linux:socket_test") +syscall_test( + test = "//test/syscalls/linux:socket_test", +) + +syscall_test( + test = "//test/syscalls/linux:socket_capability_test", +) + +syscall_test( + size = "large", + shard_count = 50, + # Takes too long for TSAN. Since this is kind of a stress test that doesn't + # involve much concurrency, TSAN's usefulness here is limited anyway. + tags = ["nogotsan"], + test = "//test/syscalls/linux:socket_stress_test", + vfs2 = False, +) syscall_test( add_overlay = True, @@ -67,16 +94,22 @@ syscall_test( test = "//test/syscalls/linux:chroot_test", ) -syscall_test(test = "//test/syscalls/linux:clock_getres_test") +syscall_test( + test = "//test/syscalls/linux:clock_getres_test", +) syscall_test( size = "medium", test = "//test/syscalls/linux:clock_gettime_test", ) -syscall_test(test = "//test/syscalls/linux:clock_nanosleep_test") +syscall_test( + test = "//test/syscalls/linux:clock_nanosleep_test", +) -syscall_test(test = "//test/syscalls/linux:concurrency_test") +syscall_test( + test = "//test/syscalls/linux:concurrency_test", +) syscall_test( add_uds_tree = True, @@ -89,18 +122,27 @@ syscall_test( test = "//test/syscalls/linux:creat_test", ) -syscall_test(test = "//test/syscalls/linux:dev_test") +syscall_test( + fuse = "True", + test = "//test/syscalls/linux:dev_test", +) syscall_test( add_overlay = True, test = "//test/syscalls/linux:dup_test", ) -syscall_test(test = "//test/syscalls/linux:epoll_test") +syscall_test( + test = "//test/syscalls/linux:epoll_test", +) -syscall_test(test = "//test/syscalls/linux:eventfd_test") +syscall_test( + test = "//test/syscalls/linux:eventfd_test", +) -syscall_test(test = "//test/syscalls/linux:exceptions_test") +syscall_test( + test = "//test/syscalls/linux:exceptions_test", +) syscall_test( size = "medium", @@ -114,7 +156,9 @@ syscall_test( test = "//test/syscalls/linux:exec_binary_test", ) -syscall_test(test = "//test/syscalls/linux:exit_test") +syscall_test( + test = "//test/syscalls/linux:exit_test", +) syscall_test( add_overlay = True, @@ -126,7 +170,9 @@ syscall_test( test = "//test/syscalls/linux:fallocate_test", ) -syscall_test(test = "//test/syscalls/linux:fault_test") +syscall_test( + test = "//test/syscalls/linux:fault_test", +) syscall_test( add_overlay = True, @@ -144,11 +190,17 @@ syscall_test( test = "//test/syscalls/linux:flock_test", ) -syscall_test(test = "//test/syscalls/linux:fork_test") +syscall_test( + test = "//test/syscalls/linux:fork_test", +) -syscall_test(test = "//test/syscalls/linux:fpsig_fork_test") +syscall_test( + test = "//test/syscalls/linux:fpsig_fork_test", +) -syscall_test(test = "//test/syscalls/linux:fpsig_nested_test") +syscall_test( + test = "//test/syscalls/linux:fpsig_nested_test", +) syscall_test( add_overlay = True, @@ -161,18 +213,26 @@ syscall_test( test = "//test/syscalls/linux:futex_test", ) -syscall_test(test = "//test/syscalls/linux:getcpu_host_test") +syscall_test( + test = "//test/syscalls/linux:getcpu_host_test", +) -syscall_test(test = "//test/syscalls/linux:getcpu_test") +syscall_test( + test = "//test/syscalls/linux:getcpu_test", +) syscall_test( add_overlay = True, test = "//test/syscalls/linux:getdents_test", ) -syscall_test(test = "//test/syscalls/linux:getrandom_test") +syscall_test( + test = "//test/syscalls/linux:getrandom_test", +) -syscall_test(test = "//test/syscalls/linux:getrusage_test") +syscall_test( + test = "//test/syscalls/linux:getrusage_test", +) syscall_test( size = "medium", @@ -196,7 +256,9 @@ syscall_test( test = "//test/syscalls/linux:itimer_test", ) -syscall_test(test = "//test/syscalls/linux:kill_test") +syscall_test( + test = "//test/syscalls/linux:kill_test", +) syscall_test( add_overlay = True, @@ -209,13 +271,21 @@ syscall_test( test = "//test/syscalls/linux:lseek_test", ) -syscall_test(test = "//test/syscalls/linux:madvise_test") +syscall_test( + test = "//test/syscalls/linux:madvise_test", +) -syscall_test(test = "//test/syscalls/linux:memory_accounting_test") +syscall_test( + test = "//test/syscalls/linux:memory_accounting_test", +) -syscall_test(test = "//test/syscalls/linux:mempolicy_test") +syscall_test( + test = "//test/syscalls/linux:mempolicy_test", +) -syscall_test(test = "//test/syscalls/linux:mincore_test") +syscall_test( + test = "//test/syscalls/linux:mincore_test", +) syscall_test( add_overlay = True, @@ -225,7 +295,6 @@ syscall_test( syscall_test( add_overlay = True, test = "//test/syscalls/linux:mknod_test", - use_tmpfs = True, # mknod is not supported over gofer. ) syscall_test( @@ -249,7 +318,13 @@ syscall_test( test = "//test/syscalls/linux:msync_test", ) -syscall_test(test = "//test/syscalls/linux:munmap_test") +syscall_test( + test = "//test/syscalls/linux:munmap_test", +) + +syscall_test( + test = "//test/syscalls/linux:network_namespace_test", +) syscall_test( add_overlay = True, @@ -261,13 +336,28 @@ syscall_test( test = "//test/syscalls/linux:open_test", ) -syscall_test(test = "//test/syscalls/linux:packet_socket_raw_test") +syscall_test( + test = "//test/syscalls/linux:packet_socket_raw_test", +) -syscall_test(test = "//test/syscalls/linux:packet_socket_test") +syscall_test( + test = "//test/syscalls/linux:packet_socket_test", +) -syscall_test(test = "//test/syscalls/linux:partial_bad_buffer_test") +syscall_test( + test = "//test/syscalls/linux:partial_bad_buffer_test", +) -syscall_test(test = "//test/syscalls/linux:pause_test") +syscall_test( + test = "//test/syscalls/linux:pause_test", +) + +syscall_test( + size = "medium", + # Takes too long under gotsan to run. + tags = ["nogotsan"], + test = "//test/syscalls/linux:ping_socket_test", +) syscall_test( size = "large", @@ -276,16 +366,22 @@ syscall_test( test = "//test/syscalls/linux:pipe_test", ) -syscall_test(test = "//test/syscalls/linux:poll_test") +syscall_test( + test = "//test/syscalls/linux:poll_test", +) syscall_test( size = "medium", test = "//test/syscalls/linux:ppoll_test", ) -syscall_test(test = "//test/syscalls/linux:prctl_setuid_test") +syscall_test( + test = "//test/syscalls/linux:prctl_setuid_test", +) -syscall_test(test = "//test/syscalls/linux:prctl_test") +syscall_test( + test = "//test/syscalls/linux:prctl_test", +) syscall_test( add_overlay = True, @@ -302,23 +398,39 @@ syscall_test( test = "//test/syscalls/linux:preadv2_test", ) -syscall_test(test = "//test/syscalls/linux:priority_test") +syscall_test( + test = "//test/syscalls/linux:priority_test", +) syscall_test( size = "medium", test = "//test/syscalls/linux:proc_test", ) -syscall_test(test = "//test/syscalls/linux:proc_pid_uid_gid_map_test") +syscall_test( + test = "//test/syscalls/linux:proc_net_test", +) + +syscall_test( + test = "//test/syscalls/linux:proc_pid_oomscore_test", +) + +syscall_test( + test = "//test/syscalls/linux:proc_pid_smaps_test", +) -syscall_test(test = "//test/syscalls/linux:proc_net_test") +syscall_test( + test = "//test/syscalls/linux:proc_pid_uid_gid_map_test", +) syscall_test( size = "medium", test = "//test/syscalls/linux:pselect_test", ) -syscall_test(test = "//test/syscalls/linux:ptrace_test") +syscall_test( + test = "//test/syscalls/linux:ptrace_test", +) syscall_test( size = "medium", @@ -340,11 +452,17 @@ syscall_test( test = "//test/syscalls/linux:pwrite64_test", ) -syscall_test(test = "//test/syscalls/linux:raw_socket_hdrincl_test") +syscall_test( + test = "//test/syscalls/linux:raw_socket_hdrincl_test", +) -syscall_test(test = "//test/syscalls/linux:raw_socket_icmp_test") +syscall_test( + test = "//test/syscalls/linux:raw_socket_icmp_test", +) -syscall_test(test = "//test/syscalls/linux:raw_socket_ipv4_test") +syscall_test( + test = "//test/syscalls/linux:raw_socket_test", +) syscall_test( add_overlay = True, @@ -374,17 +492,37 @@ syscall_test( test = "//test/syscalls/linux:rename_test", ) -syscall_test(test = "//test/syscalls/linux:rlimits_test") +syscall_test( + test = "//test/syscalls/linux:rlimits_test", +) -syscall_test(test = "//test/syscalls/linux:rtsignal_test") +syscall_test( + test = "//test/syscalls/linux:rseq_test", +) -syscall_test(test = "//test/syscalls/linux:sched_test") +syscall_test( + test = "//test/syscalls/linux:rtsignal_test", +) -syscall_test(test = "//test/syscalls/linux:sched_yield_test") +syscall_test( + test = "//test/syscalls/linux:signalfd_test", +) -syscall_test(test = "//test/syscalls/linux:seccomp_test") +syscall_test( + test = "//test/syscalls/linux:sched_test", +) -syscall_test(test = "//test/syscalls/linux:select_test") +syscall_test( + test = "//test/syscalls/linux:sched_yield_test", +) + +syscall_test( + test = "//test/syscalls/linux:seccomp_test", +) + +syscall_test( + test = "//test/syscalls/linux:select_test", +) syscall_test( shard_count = 20, @@ -406,21 +544,29 @@ syscall_test( test = "//test/syscalls/linux:splice_test", ) -syscall_test(test = "//test/syscalls/linux:sigaction_test") +syscall_test( + test = "//test/syscalls/linux:sigaction_test", +) # TODO(b/119826902): Enable once the test passes in runsc. -# syscall_test(test = "//test/syscalls/linux:sigaltstack_test") +# syscall_test(vfs2="True",test = "//test/syscalls/linux:sigaltstack_test") -syscall_test(test = "//test/syscalls/linux:sigiret_test") +syscall_test( + test = "//test/syscalls/linux:sigiret_test", +) -syscall_test(test = "//test/syscalls/linux:sigprocmask_test") +syscall_test( + test = "//test/syscalls/linux:sigprocmask_test", +) syscall_test( size = "medium", test = "//test/syscalls/linux:sigstop_test", ) -syscall_test(test = "//test/syscalls/linux:sigtimedwait_test") +syscall_test( + test = "//test/syscalls/linux:sigtimedwait_test", +) syscall_test( size = "medium", @@ -434,7 +580,7 @@ syscall_test( syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_abstract_test", ) @@ -445,7 +591,7 @@ syscall_test( syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_domain_test", ) @@ -458,19 +604,27 @@ syscall_test( syscall_test( size = "large", add_overlay = True, - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_filesystem_test", ) syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_inet_loopback_test", ) syscall_test( size = "large", - shard_count = 10, + shard_count = 50, + # Takes too long for TSAN. Creates a lot of TCP sockets. + tags = ["nogotsan"], + test = "//test/syscalls/linux:socket_inet_loopback_nogotsan_test", +) + +syscall_test( + size = "large", + shard_count = 50, test = "//test/syscalls/linux:socket_ip_tcp_generic_loopback_test", ) @@ -481,13 +635,13 @@ syscall_test( syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_ip_tcp_loopback_test", ) syscall_test( size = "medium", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_ip_tcp_udp_generic_loopback_test", ) @@ -498,7 +652,7 @@ syscall_test( syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_ip_udp_loopback_test", ) @@ -507,19 +661,41 @@ syscall_test( test = "//test/syscalls/linux:socket_ipv4_udp_unbound_loopback_test", ) -syscall_test(test = "//test/syscalls/linux:socket_ip_unbound_test") +syscall_test( + test = "//test/syscalls/linux:socket_ip_unbound_test", +) + +syscall_test( + test = "//test/syscalls/linux:socket_netdevice_test", +) -syscall_test(test = "//test/syscalls/linux:socket_netdevice_test") +syscall_test( + test = "//test/syscalls/linux:socket_netlink_test", +) -syscall_test(test = "//test/syscalls/linux:socket_netlink_route_test") +syscall_test( + test = "//test/syscalls/linux:socket_netlink_route_test", +) + +syscall_test( + test = "//test/syscalls/linux:socket_netlink_uevent_test", +) -syscall_test(test = "//test/syscalls/linux:socket_blocking_local_test") +syscall_test( + test = "//test/syscalls/linux:socket_blocking_local_test", +) -syscall_test(test = "//test/syscalls/linux:socket_blocking_ip_test") +syscall_test( + test = "//test/syscalls/linux:socket_blocking_ip_test", +) -syscall_test(test = "//test/syscalls/linux:socket_non_stream_blocking_local_test") +syscall_test( + test = "//test/syscalls/linux:socket_non_stream_blocking_local_test", +) -syscall_test(test = "//test/syscalls/linux:socket_non_stream_blocking_udp_test") +syscall_test( + test = "//test/syscalls/linux:socket_non_stream_blocking_udp_test", +) syscall_test( size = "large", @@ -556,7 +732,7 @@ syscall_test( syscall_test( size = "large", add_overlay = True, - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_unix_pair_test", ) @@ -595,7 +771,7 @@ syscall_test( syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_unix_unbound_stream_test", ) @@ -634,11 +810,17 @@ syscall_test( test = "//test/syscalls/linux:sync_file_range_test", ) -syscall_test(test = "//test/syscalls/linux:sysinfo_test") +syscall_test( + test = "//test/syscalls/linux:sysinfo_test", +) -syscall_test(test = "//test/syscalls/linux:syslog_test") +syscall_test( + test = "//test/syscalls/linux:syslog_test", +) -syscall_test(test = "//test/syscalls/linux:sysret_test") +syscall_test( + test = "//test/syscalls/linux:sysret_test", +) syscall_test( size = "medium", @@ -646,52 +828,88 @@ syscall_test( test = "//test/syscalls/linux:tcp_socket_test", ) -syscall_test(test = "//test/syscalls/linux:tgkill_test") +syscall_test( + test = "//test/syscalls/linux:tgkill_test", +) -syscall_test(test = "//test/syscalls/linux:timerfd_test") +syscall_test( + test = "//test/syscalls/linux:timerfd_test", +) -syscall_test(test = "//test/syscalls/linux:timers_test") +syscall_test( + test = "//test/syscalls/linux:timers_test", +) -syscall_test(test = "//test/syscalls/linux:time_test") +syscall_test( + test = "//test/syscalls/linux:time_test", +) -syscall_test(test = "//test/syscalls/linux:tkill_test") +syscall_test( + test = "//test/syscalls/linux:tkill_test", +) syscall_test( add_overlay = True, test = "//test/syscalls/linux:truncate_test", ) -syscall_test(test = "//test/syscalls/linux:udp_bind_test") +syscall_test( + test = "//test/syscalls/linux:tuntap_test", +) + +syscall_test( + add_hostinet = True, + test = "//test/syscalls/linux:tuntap_hostinet_test", +) + +syscall_test( + test = "//test/syscalls/linux:udp_bind_test", +) syscall_test( size = "medium", + add_hostinet = True, shard_count = 10, test = "//test/syscalls/linux:udp_socket_test", ) -syscall_test(test = "//test/syscalls/linux:uidgid_test") +syscall_test( + test = "//test/syscalls/linux:uidgid_test", +) -syscall_test(test = "//test/syscalls/linux:uname_test") +syscall_test( + test = "//test/syscalls/linux:uname_test", +) syscall_test( add_overlay = True, test = "//test/syscalls/linux:unlink_test", ) -syscall_test(test = "//test/syscalls/linux:unshare_test") +syscall_test( + test = "//test/syscalls/linux:unshare_test", +) -syscall_test(test = "//test/syscalls/linux:utimes_test") +syscall_test( + test = "//test/syscalls/linux:utimes_test", +) syscall_test( size = "medium", test = "//test/syscalls/linux:vdso_clock_gettime_test", ) -syscall_test(test = "//test/syscalls/linux:vdso_test") +syscall_test( + test = "//test/syscalls/linux:vdso_test", +) -syscall_test(test = "//test/syscalls/linux:vsyscall_test") +syscall_test( + test = "//test/syscalls/linux:vsyscall_test", +) -syscall_test(test = "//test/syscalls/linux:vfork_test") +syscall_test( + test = "//test/syscalls/linux:vfork_test", +) syscall_test( size = "medium", @@ -704,26 +922,14 @@ syscall_test( test = "//test/syscalls/linux:write_test", ) -syscall_test(test = "//test/syscalls/linux:proc_net_unix_test") - -syscall_test(test = "//test/syscalls/linux:proc_net_tcp_test") - -syscall_test(test = "//test/syscalls/linux:proc_net_udp_test") - -go_binary( - name = "syscall_test_runner", - testonly = 1, - srcs = ["syscall_test_runner.go"], - data = [ - "//runsc", - ], - deps = [ - "//pkg/log", - "//runsc/specutils", - "//runsc/testutil", - "//test/syscalls/gtest", - "//test/uds", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", - "@org_golang_x_sys//unix:go_default_library", - ], +syscall_test( + test = "//test/syscalls/linux:proc_net_unix_test", +) + +syscall_test( + test = "//test/syscalls/linux:proc_net_tcp_test", +) + +syscall_test( + test = "//test/syscalls/linux:proc_net_udp_test", ) diff --git a/test/syscalls/build_defs.bzl b/test/syscalls/build_defs.bzl deleted file mode 100644 index dcf5b73ed..000000000 --- a/test/syscalls/build_defs.bzl +++ /dev/null @@ -1,136 +0,0 @@ -"""Defines a rule for syscall test targets.""" - -# syscall_test is a macro that will create targets to run the given test target -# on the host (native) and runsc. -def syscall_test( - test, - shard_count = 5, - size = "small", - use_tmpfs = False, - add_overlay = False, - add_uds_tree = False, - tags = None): - _syscall_test( - test = test, - shard_count = shard_count, - size = size, - platform = "native", - use_tmpfs = False, - add_uds_tree = add_uds_tree, - tags = tags, - ) - - _syscall_test( - test = test, - shard_count = shard_count, - size = size, - platform = "kvm", - use_tmpfs = use_tmpfs, - add_uds_tree = add_uds_tree, - tags = tags, - ) - - _syscall_test( - test = test, - shard_count = shard_count, - size = size, - platform = "ptrace", - use_tmpfs = use_tmpfs, - add_uds_tree = add_uds_tree, - tags = tags, - ) - - if add_overlay: - _syscall_test( - test = test, - shard_count = shard_count, - size = size, - platform = "ptrace", - use_tmpfs = False, # overlay is adding a writable tmpfs on top of root. - add_uds_tree = add_uds_tree, - tags = tags, - overlay = True, - ) - - if not use_tmpfs: - # Also test shared gofer access. - _syscall_test( - test = test, - shard_count = shard_count, - size = size, - platform = "ptrace", - use_tmpfs = use_tmpfs, - add_uds_tree = add_uds_tree, - tags = tags, - file_access = "shared", - ) - -def _syscall_test( - test, - shard_count, - size, - platform, - use_tmpfs, - tags, - file_access = "exclusive", - overlay = False, - add_uds_tree = False): - test_name = test.split(":")[1] - - # Prepend "runsc" to non-native platform names. - full_platform = platform if platform == "native" else "runsc_" + platform - - name = test_name + "_" + full_platform - if file_access == "shared": - name += "_shared" - if overlay: - name += "_overlay" - - if tags == None: - tags = [] - - # Add the full_platform and file access in a tag to make it easier to run - # all the tests on a specific flavor. Use --test_tag_filters=ptrace,file_shared. - tags += [full_platform, "file_" + file_access] - - # Add tag to prevent the tests from running in a Bazel sandbox. - # TODO(b/120560048): Make the tests run without this tag. - tags.append("no-sandbox") - - # TODO(b/112165693): KVM tests are tagged "manual" to until the platform is - # more stable. - if platform == "kvm": - tags += ["manual"] - tags += ["requires-kvm"] - - args = [ - # Arguments are passed directly to syscall_test_runner binary. - "--test-name=" + test_name, - "--platform=" + platform, - "--use-tmpfs=" + str(use_tmpfs), - "--file-access=" + file_access, - "--overlay=" + str(overlay), - "--add-uds-tree=" + str(add_uds_tree), - ] - - sh_test( - srcs = ["syscall_test_runner.sh"], - name = name, - data = [ - ":syscall_test_runner", - test, - ], - args = args, - size = size, - tags = tags, - shard_count = shard_count, - ) - -def sh_test(**kwargs): - """Wraps the standard sh_test.""" - native.sh_test( - **kwargs - ) - -def select_for_linux(for_linux, for_others = []): - return for_linux diff --git a/test/syscalls/gtest/BUILD b/test/syscalls/gtest/BUILD deleted file mode 100644 index 9293f25cb..000000000 --- a/test/syscalls/gtest/BUILD +++ /dev/null @@ -1,12 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "gtest", - srcs = ["gtest.go"], - importpath = "gvisor.dev/gvisor/test/syscalls/gtest", - visibility = [ - "//test:__subpackages__", - ], -) diff --git a/test/syscalls/gtest/gtest.go b/test/syscalls/gtest/gtest.go deleted file mode 100644 index bdec8eb07..000000000 --- a/test/syscalls/gtest/gtest.go +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package gtest contains helpers for running google-test tests from Go. -package gtest - -import ( - "fmt" - "os/exec" - "strings" -) - -var ( - // ListTestFlag is the flag that will list tests in gtest binaries. - ListTestFlag = "--gtest_list_tests" - - // FilterTestFlag is the flag that will filter tests in gtest binaries. - FilterTestFlag = "--gtest_filter" -) - -// TestCase is a single gtest test case. -type TestCase struct { - // Suite is the suite for this test. - Suite string - - // Name is the name of this individual test. - Name string -} - -// FullName returns the name of the test including the suite. It is suitable to -// pass to "-gtest_filter". -func (tc TestCase) FullName() string { - return fmt.Sprintf("%s.%s", tc.Suite, tc.Name) -} - -// ParseTestCases calls a gtest test binary to list its test and returns a -// slice with the name and suite of each test. -func ParseTestCases(testBin string, extraArgs ...string) ([]TestCase, error) { - args := append([]string{ListTestFlag}, extraArgs...) - cmd := exec.Command(testBin, args...) - out, err := cmd.Output() - if err != nil { - exitErr, ok := err.(*exec.ExitError) - if !ok { - return nil, fmt.Errorf("could not enumerate gtest tests: %v", err) - } - return nil, fmt.Errorf("could not enumerate gtest tests: %v\nstderr:\n%s", err, exitErr.Stderr) - } - - var t []TestCase - var suite string - for _, line := range strings.Split(string(out), "\n") { - // Strip comments. - line = strings.Split(line, "#")[0] - - // New suite? - if !strings.HasPrefix(line, " ") { - suite = strings.TrimSuffix(strings.TrimSpace(line), ".") - continue - } - - // Individual test. - name := strings.TrimSpace(line) - - // Do we have a suite yet? - if suite == "" { - return nil, fmt.Errorf("test without a suite: %v", name) - } - - // Add this individual test. - t = append(t, TestCase{ - Suite: suite, - Name: name, - }) - - } - - if len(t) == 0 { - return nil, fmt.Errorf("no tests parsed from %v", testBin) - } - return t, nil -} diff --git a/test/syscalls/linux/32bit.cc b/test/syscalls/linux/32bit.cc index a7cbee06b..3c825477c 100644 --- a/test/syscalls/linux/32bit.cc +++ b/test/syscalls/linux/32bit.cc @@ -15,10 +15,12 @@ #include <string.h> #include <sys/mman.h> +#include "gtest/gtest.h" +#include "absl/base/macros.h" #include "test/util/memory_util.h" +#include "test/util/platform_util.h" #include "test/util/posix_error.h" #include "test/util/test_util.h" -#include "gtest/gtest.h" #ifndef __x86_64__ #error "This test is x86-64 specific." @@ -30,7 +32,6 @@ namespace testing { namespace { constexpr char kInt3 = '\xcc'; - constexpr char kInt80[2] = {'\xcd', '\x80'}; constexpr char kSyscall[2] = {'\x0f', '\x05'}; constexpr char kSysenter[2] = {'\x0f', '\x34'}; @@ -43,6 +44,7 @@ void ExitGroup32(const char instruction[2], int code) { // Fill with INT 3 in case we execute too far. memset(m.ptr(), kInt3, m.len()); + // Copy in the actual instruction. memcpy(m.ptr(), instruction, 2); // We're playing *extremely* fast-and-loose with the various syscall ABIs @@ -71,77 +73,96 @@ void ExitGroup32(const char instruction[2], int code) { "iretl\n" "int $3\n" : - : [code] "m"(code), [ip] "d"(m.ptr()) - : "rax", "rbx", "rsp"); + : [ code ] "m"(code), [ ip ] "d"(m.ptr()) + : "rax", "rbx"); } constexpr int kExitCode = 42; TEST(Syscall32Bit, Int80) { - switch (GvisorPlatform()) { - case Platform::kKVM: - // TODO(b/111805002): 32-bit segments are broken (but not explictly - // disabled). - return; - case Platform::kPtrace: - // TODO(gvisor.dev/issue/167): The ptrace platform does not have a - // consistent story here. - return; - case Platform::kNative: + switch (PlatformSupport32Bit()) { + case PlatformSupport::NotSupported: + break; + case PlatformSupport::Segfault: + EXPECT_EXIT(ExitGroup32(kInt80, kExitCode), + ::testing::KilledBySignal(SIGSEGV), ""); break; - } - // Upstream Linux. 32-bit syscalls allowed. - EXPECT_EXIT(ExitGroup32(kInt80, kExitCode), ::testing::ExitedWithCode(42), - ""); -} + case PlatformSupport::Ignored: + // Since the call is ignored, we'll hit the int3 trap. + EXPECT_EXIT(ExitGroup32(kInt80, kExitCode), + ::testing::KilledBySignal(SIGTRAP), ""); + break; -TEST(Syscall32Bit, Sysenter) { - switch (GvisorPlatform()) { - case Platform::kKVM: - // TODO(b/111805002): See above. - return; - case Platform::kPtrace: - // TODO(gvisor.dev/issue/167): See above. - return; - case Platform::kNative: + case PlatformSupport::Allowed: + EXPECT_EXIT(ExitGroup32(kInt80, kExitCode), ::testing::ExitedWithCode(42), + ""); break; } +} - if (GetCPUVendor() == CPUVendor::kAMD) { +TEST(Syscall32Bit, Sysenter) { + if ((PlatformSupport32Bit() == PlatformSupport::Allowed || + PlatformSupport32Bit() == PlatformSupport::Ignored) && + GetCPUVendor() == CPUVendor::kAMD) { // SYSENTER is an illegal instruction in compatibility mode on AMD. EXPECT_EXIT(ExitGroup32(kSysenter, kExitCode), ::testing::KilledBySignal(SIGILL), ""); return; } - // Upstream Linux on !AMD, 32-bit syscalls allowed. - EXPECT_EXIT(ExitGroup32(kSysenter, kExitCode), ::testing::ExitedWithCode(42), - ""); -} + switch (PlatformSupport32Bit()) { + case PlatformSupport::NotSupported: + break; -TEST(Syscall32Bit, Syscall) { - switch (GvisorPlatform()) { - case Platform::kKVM: - // TODO(b/111805002): See above. - return; - case Platform::kPtrace: - // TODO(gvisor.dev/issue/167): See above. - return; - case Platform::kNative: + case PlatformSupport::Segfault: + EXPECT_EXIT(ExitGroup32(kSysenter, kExitCode), + ::testing::KilledBySignal(SIGSEGV), ""); + break; + + case PlatformSupport::Ignored: + // See above, except expected code is SIGSEGV. + EXPECT_EXIT(ExitGroup32(kSysenter, kExitCode), + ::testing::KilledBySignal(SIGSEGV), ""); + break; + + case PlatformSupport::Allowed: + EXPECT_EXIT(ExitGroup32(kSysenter, kExitCode), + ::testing::ExitedWithCode(42), ""); break; } +} - if (GetCPUVendor() == CPUVendor::kIntel) { +TEST(Syscall32Bit, Syscall) { + if ((PlatformSupport32Bit() == PlatformSupport::Allowed || + PlatformSupport32Bit() == PlatformSupport::Ignored) && + GetCPUVendor() == CPUVendor::kIntel) { // SYSCALL is an illegal instruction in compatibility mode on Intel. EXPECT_EXIT(ExitGroup32(kSyscall, kExitCode), ::testing::KilledBySignal(SIGILL), ""); return; } - // Upstream Linux on !Intel, 32-bit syscalls allowed. - EXPECT_EXIT(ExitGroup32(kSyscall, kExitCode), ::testing::ExitedWithCode(42), - ""); + switch (PlatformSupport32Bit()) { + case PlatformSupport::NotSupported: + break; + + case PlatformSupport::Segfault: + EXPECT_EXIT(ExitGroup32(kSyscall, kExitCode), + ::testing::KilledBySignal(SIGSEGV), ""); + break; + + case PlatformSupport::Ignored: + // See above. + EXPECT_EXIT(ExitGroup32(kSyscall, kExitCode), + ::testing::KilledBySignal(SIGSEGV), ""); + break; + + case PlatformSupport::Allowed: + EXPECT_EXIT(ExitGroup32(kSyscall, kExitCode), + ::testing::ExitedWithCode(42), ""); + break; + } } // Far call code called below. @@ -205,19 +226,20 @@ void FarCall32() { } TEST(Call32Bit, Disallowed) { - switch (GvisorPlatform()) { - case Platform::kKVM: - // TODO(b/111805002): See above. - return; - case Platform::kPtrace: - // The ptrace platform cannot prevent switching to compatibility mode. - ABSL_FALLTHROUGH_INTENDED; - case Platform::kNative: + switch (PlatformSupport32Bit()) { + case PlatformSupport::NotSupported: break; - } - // Shouldn't crash. - FarCall32(); + case PlatformSupport::Segfault: + EXPECT_EXIT(FarCall32(), ::testing::KilledBySignal(SIGSEGV), ""); + break; + + case PlatformSupport::Ignored: + ABSL_FALLTHROUGH_INTENDED; + case PlatformSupport::Allowed: + // Shouldn't crash. + FarCall32(); + } } } // namespace diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 833fbaa09..66a31cd28 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -1,11 +1,34 @@ -load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library") -load("//test/syscalls:build_defs.bzl", "select_for_linux") +load("//tools:defs.bzl", "cc_binary", "cc_library", "default_net_util", "gtest", "select_arch", "select_system") package( default_visibility = ["//:sandbox"], licenses = ["notice"], ) +exports_files( + [ + "socket.cc", + "socket_inet_loopback.cc", + "socket_ip_loopback_blocking.cc", + "socket_ip_tcp_generic_loopback.cc", + "socket_ip_tcp_loopback.cc", + "socket_ip_tcp_loopback_blocking.cc", + "socket_ip_tcp_loopback_nonblock.cc", + "socket_ip_tcp_udp_generic.cc", + "socket_ip_udp_loopback.cc", + "socket_ip_udp_loopback_blocking.cc", + "socket_ip_udp_loopback_nonblock.cc", + "socket_ip_unbound.cc", + "socket_ipv4_tcp_unbound_external_networking_test.cc", + "socket_ipv4_udp_unbound_external_networking_test.cc", + "socket_ipv4_udp_unbound_loopback.cc", + "tcp_socket.cc", + "udp_bind.cc", + "udp_socket.cc", + ], + visibility = ["//:sandbox"], +) + cc_binary( name = "sigaltstack_check", testonly = 1, @@ -70,14 +93,14 @@ cc_library( srcs = ["base_poll_test.cc"], hdrs = ["base_poll_test.h"], deps = [ + "@com_google_absl//absl/memory", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + gtest, "//test/util:logging", "//test/util:signal_util", "//test/util:test_util", "//test/util:thread_util", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], ) @@ -87,11 +110,11 @@ cc_library( hdrs = ["file_base.h"], deps = [ "//test/util:file_descriptor", + "@com_google_absl//absl/strings", + gtest, "//test/util:posix_error", "//test/util:temp_path", "//test/util:test_util", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", ], ) @@ -109,34 +132,37 @@ cc_library( ) cc_library( + name = "socket_netlink_route_util", + testonly = 1, + srcs = ["socket_netlink_route_util.cc"], + hdrs = ["socket_netlink_route_util.h"], + deps = [ + ":socket_netlink_util", + ], +) + +cc_library( name = "socket_test_util", testonly = 1, srcs = [ "socket_test_util.cc", - ] + select_for_linux( - [ - "socket_test_util_impl.cc", - ], - ), + "socket_test_util_impl.cc", + ], hdrs = ["socket_test_util.h"], - deps = [ - "@com_google_googletest//:gtest", + defines = select_system(), + deps = default_net_util() + [ + gtest, "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/time", + "@com_google_absl//absl/types:optional", "//test/util:file_descriptor", "//test/util:posix_error", "//test/util:temp_path", "//test/util:test_util", "//test/util:thread_util", - ] + select_for_linux([ - ]), -) - -cc_library( - name = "temp_umask", - hdrs = ["temp_umask.h"], + ], ) cc_library( @@ -146,9 +172,9 @@ cc_library( hdrs = ["unix_domain_socket_test_util.h"], deps = [ ":socket_test_util", - "//test/util:test_util", "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", + gtest, + "//test/util:test_util", ], ) @@ -170,28 +196,33 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:cleanup", + "@com_google_absl//absl/time", + gtest, "//test/util:posix_error", "//test/util:signal_util", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", "//test/util:timer_util", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], ) cc_binary( name = "32bit_test", testonly = 1, - srcs = ["32bit.cc"], + srcs = select_arch( + amd64 = ["32bit.cc"], + arm64 = [], + ), linkstatic = 1, deps = [ + "@com_google_absl//absl/base:core_headers", + gtest, "//test/util:memory_util", + "//test/util:platform_util", "//test/util:posix_error", "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -204,9 +235,9 @@ cc_binary( ":socket_test_util", ":unix_domain_socket_test_util", "//test/util:file_descriptor", + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -219,9 +250,9 @@ cc_binary( ":socket_test_util", ":unix_domain_socket_test_util", "//test/util:file_descriptor", + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -233,10 +264,10 @@ cc_binary( deps = [ "//test/util:capability_util", "//test/util:fs_util", + gtest, "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -248,12 +279,12 @@ cc_binary( deps = [ "//test/util:cleanup", "//test/util:fs_util", + "@com_google_absl//absl/strings", + gtest, "//test/util:posix_error", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", ], ) @@ -266,12 +297,11 @@ cc_binary( ], linkstatic = 1, deps = [ - # The heapchecker doesn't recognize that io_destroy munmaps. - "@com_google_googletest//:gtest", - "@com_google_absl//absl/strings", "//test/util:cleanup", "//test/util:file_descriptor", "//test/util:fs_util", + "@com_google_absl//absl/strings", + gtest, "//test/util:memory_util", "//test/util:posix_error", "//test/util:proc_util", @@ -288,12 +318,12 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:file_descriptor", + "@com_google_absl//absl/time", + gtest, "//test/util:logging", "//test/util:signal_util", "//test/util:test_util", "//test/util:thread_util", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], ) @@ -306,9 +336,9 @@ cc_binary( "//:sandbox", ], deps = [ + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -320,9 +350,9 @@ cc_binary( deps = [ ":socket_test_util", ":unix_domain_socket_test_util", + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -333,10 +363,26 @@ cc_binary( linkstatic = 1, deps = [ ":socket_test_util", + gtest, + "//test/util:file_descriptor", + "//test/util:temp_umask", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "socket_capability_test", + testonly = 1, + srcs = ["socket_capability.cc"], + linkstatic = 1, + deps = [ + ":socket_test_util", + "//test/util:capability_util", "//test/util:file_descriptor", + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -358,10 +404,10 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:capability_util", + gtest, "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -374,10 +420,10 @@ cc_binary( "//test/util:capability_util", "//test/util:file_descriptor", "//test/util:fs_util", + gtest, "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -390,14 +436,14 @@ cc_binary( "//test/util:capability_util", "//test/util:file_descriptor", "//test/util:fs_util", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/synchronization", + gtest, "//test/util:posix_error", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", - "@com_google_absl//absl/flags:flag", - "@com_google_absl//absl/synchronization", - "@com_google_googletest//:gtest", ], ) @@ -410,12 +456,12 @@ cc_binary( "//test/util:capability_util", "//test/util:file_descriptor", "//test/util:fs_util", + "@com_google_absl//absl/flags:flag", + gtest, "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", - "@com_google_absl//absl/flags:flag", - "@com_google_googletest//:gtest", ], ) @@ -429,12 +475,12 @@ cc_binary( "//test/util:cleanup", "//test/util:file_descriptor", "//test/util:fs_util", + "@com_google_absl//absl/strings", + gtest, "//test/util:mount_util", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", ], ) @@ -444,9 +490,9 @@ cc_binary( srcs = ["clock_getres.cc"], linkstatic = 1, deps = [ + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -456,11 +502,11 @@ cc_binary( srcs = ["clock_gettime.cc"], linkstatic = 1, deps = [ + "@com_google_absl//absl/time", + gtest, "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], ) @@ -470,12 +516,13 @@ cc_binary( srcs = ["concurrency.cc"], linkstatic = 1, deps = [ + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + gtest, + "//test/util:platform_util", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], ) @@ -488,9 +535,9 @@ cc_binary( ":socket_test_util", "//test/util:file_descriptor", "//test/util:fs_util", + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -501,10 +548,10 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:fs_util", + gtest, "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -515,9 +562,9 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:file_descriptor", + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -529,11 +576,11 @@ cc_binary( deps = [ "//test/util:eventfd_util", "//test/util:file_descriptor", + gtest, "//test/util:posix_error", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -546,10 +593,10 @@ cc_binary( "//test/util:epoll_util", "//test/util:eventfd_util", "//test/util:file_descriptor", + gtest, "//test/util:posix_error", "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -561,24 +608,28 @@ cc_binary( deps = [ "//test/util:epoll_util", "//test/util:eventfd_util", + gtest, "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", - "@com_google_googletest//:gtest", ], ) cc_binary( name = "exceptions_test", testonly = 1, - srcs = ["exceptions.cc"], + srcs = select_arch( + amd64 = ["exceptions.cc"], + arm64 = [], + ), linkstatic = 1, deps = [ + gtest, "//test/util:logging", + "//test/util:platform_util", "//test/util:signal_util", "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -588,10 +639,10 @@ cc_binary( srcs = ["getcpu.cc"], linkstatic = 1, deps = [ + "@com_google_absl//absl/time", + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], ) @@ -601,10 +652,10 @@ cc_binary( srcs = ["getcpu.cc"], linkstatic = 1, deps = [ + "@com_google_absl//absl/time", + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], ) @@ -614,13 +665,13 @@ cc_binary( srcs = ["getrusage.cc"], linkstatic = 1, deps = [ + "@com_google_absl//absl/time", + gtest, "//test/util:logging", "//test/util:memory_util", "//test/util:signal_util", "//test/util:test_main", "//test/util:test_util", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], ) @@ -633,14 +684,14 @@ cc_binary( "//test/util:cleanup", "//test/util:file_descriptor", "//test/util:fs_util", + "@com_google_absl//absl/strings", + gtest, "//test/util:multiprocess_util", "//test/util:posix_error", "//test/util:proc_util", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", ], ) @@ -663,15 +714,15 @@ cc_binary( deps = [ "//test/util:file_descriptor", "//test/util:fs_util", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:optional", + gtest, "//test/util:multiprocess_util", "//test/util:posix_error", "//test/util:temp_path", "//test/util:test_util", "//test/util:thread_util", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:optional", - "@com_google_googletest//:gtest", ], ) @@ -682,11 +733,11 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:file_descriptor", + "@com_google_absl//absl/time", + gtest, "//test/util:test_main", "//test/util:test_util", "//test/util:time_util", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], ) @@ -697,12 +748,17 @@ cc_binary( linkstatic = 1, deps = [ ":file_base", + ":socket_test_util", "//test/util:cleanup", + "//test/util:eventfd_util", "//test/util:file_descriptor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + gtest, + "//test/util:posix_error", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -712,9 +768,9 @@ cc_binary( srcs = ["fault.cc"], linkstatic = 1, deps = [ + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -725,10 +781,10 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:capability_util", + gtest, "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -740,18 +796,22 @@ cc_binary( deps = [ ":socket_test_util", "//test/util:cleanup", + "//test/util:epoll_util", "//test/util:eventfd_util", - "//test/util:multiprocess_util", - "//test/util:posix_error", - "//test/util:temp_path", - "//test/util:test_util", - "//test/util:timer_util", + "//test/util:fs_util", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", + gtest, + "//test/util:multiprocess_util", + "//test/util:posix_error", + "//test/util:save_util", + "//test/util:temp_path", + "//test/util:test_util", + "//test/util:thread_util", + "//test/util:timer_util", ], ) @@ -764,16 +824,19 @@ cc_binary( ], linkstatic = 1, deps = [ + ":socket_test_util", "//test/util:file_descriptor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + gtest, + "//test/util:epoll_util", + "//test/util:eventfd_util", "//test/util:posix_error", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", "//test/util:timer_util", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], ) @@ -784,13 +847,13 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:capability_util", + "@com_google_absl//absl/time", + gtest, "//test/util:logging", "//test/util:memory_util", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], ) @@ -800,11 +863,11 @@ cc_binary( srcs = ["fpsig_fork.cc"], linkstatic = 1, deps = [ + gtest, "//test/util:logging", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", - "@com_google_googletest//:gtest", ], ) @@ -814,10 +877,10 @@ cc_binary( srcs = ["fpsig_nested.cc"], linkstatic = 1, deps = [ + gtest, "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", - "@com_google_googletest//:gtest", ], ) @@ -828,10 +891,10 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:file_descriptor", + gtest, "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -842,10 +905,10 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:file_descriptor", + gtest, "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -857,6 +920,9 @@ cc_binary( deps = [ "//test/util:cleanup", "//test/util:file_descriptor", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/time", + gtest, "//test/util:memory_util", "//test/util:save_util", "//test/util:temp_path", @@ -865,9 +931,6 @@ cc_binary( "//test/util:thread_util", "//test/util:time_util", "//test/util:timer_util", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], ) @@ -880,12 +943,13 @@ cc_binary( "//test/util:eventfd_util", "//test/util:file_descriptor", "//test/util:fs_util", + "@com_google_absl//absl/container:node_hash_set", + "@com_google_absl//absl/strings", + gtest, "//test/util:posix_error", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", ], ) @@ -895,9 +959,9 @@ cc_binary( srcs = ["getrandom.cc"], linkstatic = 1, deps = [ + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -910,12 +974,14 @@ cc_binary( "//test/util:epoll_util", "//test/util:file_descriptor", "//test/util:fs_util", + "//test/util:posix_error", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", ], ) @@ -930,10 +996,10 @@ cc_binary( ":socket_test_util", ":unix_domain_socket_test_util", "//test/util:file_descriptor", + gtest, "//test/util:signal_util", "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -957,9 +1023,9 @@ cc_binary( ":socket_test_util", "//test/util:capability_util", "//test/util:file_descriptor", + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -970,6 +1036,9 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:file_descriptor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + gtest, "//test/util:logging", "//test/util:multiprocess_util", "//test/util:posix_error", @@ -977,9 +1046,6 @@ cc_binary( "//test/util:test_util", "//test/util:thread_util", "//test/util:timer_util", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], ) @@ -991,15 +1057,15 @@ cc_binary( deps = [ "//test/util:capability_util", "//test/util:file_descriptor", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + gtest, "//test/util:logging", "//test/util:signal_util", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", - "@com_google_absl//absl/flags:flag", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], ) @@ -1012,14 +1078,14 @@ cc_binary( "//test/util:capability_util", "//test/util:file_descriptor", "//test/util:fs_util", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/strings", + gtest, "//test/util:posix_error", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", - "@com_google_absl//absl/flags:flag", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", ], ) @@ -1030,10 +1096,10 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:file_descriptor", + gtest, "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -1044,6 +1110,7 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:file_descriptor", + gtest, "//test/util:logging", "//test/util:memory_util", "//test/util:multiprocess_util", @@ -1051,7 +1118,6 @@ cc_binary( "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -1062,12 +1128,12 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:cleanup", + "@com_google_absl//absl/memory", + gtest, "//test/util:memory_util", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", - "@com_google_absl//absl/memory", - "@com_google_googletest//:gtest", ], ) @@ -1077,11 +1143,11 @@ cc_binary( srcs = ["mincore.cc"], linkstatic = 1, deps = [ + gtest, "//test/util:memory_util", "//test/util:posix_error", "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -1091,13 +1157,13 @@ cc_binary( srcs = ["mkdir.cc"], linkstatic = 1, deps = [ - ":temp_umask", "//test/util:capability_util", "//test/util:fs_util", + gtest, "//test/util:temp_path", + "//test/util:temp_umask", "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -1108,11 +1174,11 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:file_descriptor", + gtest, "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", - "@com_google_googletest//:gtest", ], ) @@ -1124,12 +1190,12 @@ cc_binary( deps = [ "//test/util:capability_util", "//test/util:cleanup", + gtest, "//test/util:memory_util", "//test/util:multiprocess_util", "//test/util:rlimit_util", "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -1142,13 +1208,13 @@ cc_binary( "//test/util:cleanup", "//test/util:file_descriptor", "//test/util:fs_util", + "@com_google_absl//absl/strings", + gtest, "//test/util:memory_util", "//test/util:multiprocess_util", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", ], ) @@ -1161,6 +1227,9 @@ cc_binary( "//test/util:capability_util", "//test/util:file_descriptor", "//test/util:fs_util", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + gtest, "//test/util:mount_util", "//test/util:multiprocess_util", "//test/util:posix_error", @@ -1168,9 +1237,6 @@ cc_binary( "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], ) @@ -1180,10 +1246,9 @@ cc_binary( srcs = ["mremap.cc"], linkstatic = 1, deps = [ - # The heap check fails due to MremapDeathTest - "@com_google_googletest//:gtest", - "@com_google_absl//absl/strings", "//test/util:file_descriptor", + "@com_google_absl//absl/strings", + gtest, "//test/util:logging", "//test/util:memory_util", "//test/util:multiprocess_util", @@ -1215,9 +1280,9 @@ cc_binary( srcs = ["munmap.cc"], linkstatic = 1, deps = [ + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -1234,14 +1299,14 @@ cc_binary( "//test/util:cleanup", "//test/util:file_descriptor", "//test/util:fs_util", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + gtest, "//test/util:posix_error", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", ], ) @@ -1251,14 +1316,14 @@ cc_binary( srcs = ["open_create.cc"], linkstatic = 1, deps = [ - ":temp_umask", "//test/util:capability_util", "//test/util:file_descriptor", "//test/util:fs_util", + gtest, "//test/util:temp_path", + "//test/util:temp_umask", "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -1266,17 +1331,18 @@ cc_binary( name = "packet_socket_raw_test", testonly = 1, srcs = ["packet_socket_raw.cc"], + defines = select_system(), linkstatic = 1, deps = [ ":socket_test_util", ":unix_domain_socket_test_util", "//test/util:capability_util", "//test/util:file_descriptor", - "//test/util:test_main", - "//test/util:test_util", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:endian", - "@com_google_googletest//:gtest", + gtest, + "//test/util:test_main", + "//test/util:test_util", ], ) @@ -1290,11 +1356,11 @@ cc_binary( ":unix_domain_socket_test_util", "//test/util:capability_util", "//test/util:file_descriptor", - "//test/util:test_main", - "//test/util:test_util", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:endian", - "@com_google_googletest//:gtest", + gtest, + "//test/util:test_main", + "//test/util:test_util", ], ) @@ -1306,16 +1372,16 @@ cc_binary( deps = [ "//test/util:capability_util", "//test/util:file_descriptor", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + gtest, "//test/util:posix_error", "//test/util:pty_util", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], ) @@ -1327,12 +1393,12 @@ cc_binary( deps = [ "//test/util:capability_util", "//test/util:file_descriptor", + "@com_google_absl//absl/base:core_headers", + gtest, "//test/util:posix_error", "//test/util:pty_util", "//test/util:test_main", "//test/util:thread_util", - "@com_google_absl//absl/base:core_headers", - "@com_google_googletest//:gtest", ], ) @@ -1342,15 +1408,15 @@ cc_binary( srcs = ["partial_bad_buffer.cc"], linkstatic = 1, deps = [ - "//test/syscalls/linux:socket_test_util", + ":socket_test_util", "//test/util:file_descriptor", "//test/util:fs_util", + "@com_google_absl//absl/time", + gtest, "//test/util:posix_error", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], ) @@ -1360,13 +1426,28 @@ cc_binary( srcs = ["pause.cc"], linkstatic = 1, deps = [ + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + gtest, "//test/util:signal_util", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", + ], +) + +cc_binary( + name = "ping_socket_test", + testonly = 1, + srcs = ["ping_socket.cc"], + linkstatic = 1, + deps = [ + ":socket_test_util", + "//test/util:file_descriptor", + gtest, + "//test/util:save_util", + "//test/util:test_main", + "//test/util:test_util", ], ) @@ -1377,15 +1458,16 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:file_descriptor", + "//test/util:fs_util", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + gtest, "//test/util:posix_error", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], ) @@ -1398,13 +1480,13 @@ cc_binary( ":base_poll_test", "//test/util:eventfd_util", "//test/util:file_descriptor", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + gtest, "//test/util:logging", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], ) @@ -1415,23 +1497,27 @@ cc_binary( linkstatic = 1, deps = [ ":base_poll_test", + "@com_google_absl//absl/time", + gtest, "//test/util:signal_util", "//test/util:test_main", "//test/util:test_util", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], ) cc_binary( name = "arch_prctl_test", testonly = 1, - srcs = ["arch_prctl.cc"], + srcs = select_arch( + amd64 = ["arch_prctl.cc"], + arm64 = [], + ), linkstatic = 1, deps = [ + "//test/util:file_descriptor", + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -1443,12 +1529,12 @@ cc_binary( deps = [ "//test/util:capability_util", "//test/util:cleanup", + "@com_google_absl//absl/flags:flag", + gtest, "//test/util:multiprocess_util", "//test/util:posix_error", "//test/util:test_util", "//test/util:thread_util", - "@com_google_absl//absl/flags:flag", - "@com_google_googletest//:gtest", ], ) @@ -1459,13 +1545,13 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:capability_util", + "@com_google_absl//absl/flags:flag", + gtest, "//test/util:logging", "//test/util:multiprocess_util", "//test/util:posix_error", "//test/util:test_util", "//test/util:thread_util", - "@com_google_absl//absl/flags:flag", - "@com_google_googletest//:gtest", ], ) @@ -1476,10 +1562,10 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:file_descriptor", + gtest, "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -1490,6 +1576,8 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:file_descriptor", + "@com_google_absl//absl/time", + gtest, "//test/util:logging", "//test/util:memory_util", "//test/util:temp_path", @@ -1497,8 +1585,6 @@ cc_binary( "//test/util:test_util", "//test/util:thread_util", "//test/util:timer_util", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], ) @@ -1512,13 +1598,13 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:file_descriptor", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + gtest, "//test/util:posix_error", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", ], ) @@ -1530,11 +1616,11 @@ cc_binary( deps = [ "//test/util:capability_util", "//test/util:fs_util", + "@com_google_absl//absl/strings", + gtest, "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", ], ) @@ -1548,6 +1634,10 @@ cc_binary( "//test/util:cleanup", "//test/util:file_descriptor", "//test/util:fs_util", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + gtest, "//test/util:memory_util", "//test/util:posix_error", "//test/util:temp_path", @@ -1555,10 +1645,6 @@ cc_binary( "//test/util:thread_util", "//test/util:time_util", "//test/util:timer_util", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], ) @@ -1572,11 +1658,24 @@ cc_binary( "//test/util:capability_util", "//test/util:file_descriptor", "//test/util:fs_util", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + gtest, + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "proc_pid_oomscore_test", + testonly = 1, + srcs = ["proc_pid_oomscore.cc"], + linkstatic = 1, + deps = [ + "//test/util:fs_util", "//test/util:test_main", "//test/util:test_util", "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], ) @@ -1588,17 +1687,17 @@ cc_binary( deps = [ "//test/util:file_descriptor", "//test/util:fs_util", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", + gtest, "//test/util:memory_util", "//test/util:posix_error", "//test/util:proc_util", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:optional", - "@com_google_googletest//:gtest", ], ) @@ -1612,6 +1711,8 @@ cc_binary( "//test/util:cleanup", "//test/util:file_descriptor", "//test/util:fs_util", + "@com_google_absl//absl/strings", + gtest, "//test/util:logging", "//test/util:multiprocess_util", "//test/util:posix_error", @@ -1619,8 +1720,6 @@ cc_binary( "//test/util:test_main", "//test/util:test_util", "//test/util:time_util", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", ], ) @@ -1631,11 +1730,11 @@ cc_binary( linkstatic = 1, deps = [ ":base_poll_test", + "@com_google_absl//absl/time", + gtest, "//test/util:signal_util", "//test/util:test_main", "//test/util:test_util", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], ) @@ -1645,15 +1744,16 @@ cc_binary( srcs = ["ptrace.cc"], linkstatic = 1, deps = [ + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/time", + gtest, "//test/util:logging", "//test/util:multiprocess_util", + "//test/util:platform_util", "//test/util:signal_util", "//test/util:test_util", "//test/util:thread_util", "//test/util:time_util", - "@com_google_absl//absl/flags:flag", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], ) @@ -1663,10 +1763,10 @@ cc_binary( srcs = ["pwrite64.cc"], linkstatic = 1, deps = [ + gtest, "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -1680,12 +1780,12 @@ cc_binary( deps = [ ":file_base", "//test/util:file_descriptor", + "@com_google_absl//absl/strings", + gtest, "//test/util:posix_error", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", ], ) @@ -1699,28 +1799,29 @@ cc_binary( ":unix_domain_socket_test_util", "//test/util:capability_util", "//test/util:file_descriptor", - "//test/util:test_main", - "//test/util:test_util", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:endian", - "@com_google_googletest//:gtest", + gtest, + "//test/util:test_main", + "//test/util:test_util", ], ) cc_binary( - name = "raw_socket_ipv4_test", + name = "raw_socket_test", testonly = 1, - srcs = ["raw_socket_ipv4.cc"], + srcs = ["raw_socket.cc"], + defines = select_system(), linkstatic = 1, deps = [ ":socket_test_util", ":unix_domain_socket_test_util", "//test/util:capability_util", "//test/util:file_descriptor", + "@com_google_absl//absl/base:core_headers", + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_absl//absl/base:core_headers", - "@com_google_googletest//:gtest", ], ) @@ -1734,10 +1835,10 @@ cc_binary( ":unix_domain_socket_test_util", "//test/util:capability_util", "//test/util:file_descriptor", + "@com_google_absl//absl/base:core_headers", + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_absl//absl/base:core_headers", - "@com_google_googletest//:gtest", ], ) @@ -1748,10 +1849,10 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:file_descriptor", + gtest, "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -1762,10 +1863,10 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:file_descriptor", + gtest, "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -1781,13 +1882,13 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:file_descriptor", + "@com_google_absl//absl/strings", + gtest, "//test/util:posix_error", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", "//test/util:timer_util", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", ], ) @@ -1795,7 +1896,6 @@ cc_binary( name = "readv_socket_test", testonly = 1, srcs = [ - "file_base.h", "readv_common.cc", "readv_common.h", "readv_socket.cc", @@ -1803,12 +1903,12 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:file_descriptor", + "@com_google_absl//absl/strings", + gtest, "//test/util:posix_error", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", ], ) @@ -1822,11 +1922,11 @@ cc_binary( "//test/util:cleanup", "//test/util:file_descriptor", "//test/util:fs_util", + "@com_google_absl//absl/strings", + gtest, "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", ], ) @@ -1843,17 +1943,33 @@ cc_binary( ) cc_binary( + name = "rseq_test", + testonly = 1, + srcs = ["rseq.cc"], + data = ["//test/syscalls/linux/rseq"], + linkstatic = 1, + deps = [ + "//test/syscalls/linux/rseq:lib", + gtest, + "//test/util:logging", + "//test/util:multiprocess_util", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( name = "rtsignal_test", testonly = 1, srcs = ["rtsignal.cc"], linkstatic = 1, deps = [ "//test/util:cleanup", + gtest, "//test/util:logging", "//test/util:posix_error", "//test/util:signal_util", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -1863,9 +1979,9 @@ cc_binary( srcs = ["sched.cc"], linkstatic = 1, deps = [ + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -1875,9 +1991,9 @@ cc_binary( srcs = ["sched_yield.cc"], linkstatic = 1, deps = [ + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -1887,6 +2003,8 @@ cc_binary( srcs = ["seccomp.cc"], linkstatic = 1, deps = [ + "@com_google_absl//absl/base:core_headers", + gtest, "//test/util:logging", "//test/util:memory_util", "//test/util:multiprocess_util", @@ -1894,8 +2012,6 @@ cc_binary( "//test/util:proc_util", "//test/util:test_util", "//test/util:thread_util", - "@com_google_absl//absl/base:core_headers", - "@com_google_googletest//:gtest", ], ) @@ -1907,14 +2023,14 @@ cc_binary( deps = [ ":base_poll_test", "//test/util:file_descriptor", + "@com_google_absl//absl/time", + gtest, "//test/util:multiprocess_util", "//test/util:posix_error", "//test/util:rlimit_util", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], ) @@ -1926,13 +2042,13 @@ cc_binary( deps = [ "//test/util:eventfd_util", "//test/util:file_descriptor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + gtest, "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], ) @@ -1944,12 +2060,14 @@ cc_binary( deps = [ ":socket_test_util", "//test/util:file_descriptor", + "@com_google_absl//absl/strings", + gtest, + ":ip_socket_test_util", + ":unix_domain_socket_test_util", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", ], ) @@ -1960,13 +2078,13 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:file_descriptor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + gtest, "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], ) @@ -1976,9 +2094,9 @@ cc_binary( srcs = ["sigaction.cc"], linkstatic = 1, deps = [ + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -1993,28 +2111,34 @@ cc_binary( deps = [ "//test/util:cleanup", "//test/util:fs_util", + gtest, "//test/util:multiprocess_util", "//test/util:posix_error", "//test/util:signal_util", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", - "@com_google_googletest//:gtest", ], ) cc_binary( name = "sigiret_test", testonly = 1, - srcs = ["sigiret.cc"], + srcs = select_arch( + amd64 = ["sigiret.cc"], + arm64 = [], + ), linkstatic = 1, deps = [ + gtest, "//test/util:logging", "//test/util:signal_util", "//test/util:test_util", "//test/util:timer_util", - "@com_google_googletest//:gtest", - ], + ] + select_arch( + amd64 = [], + arm64 = ["//test/util:test_main"], + ), ) cc_binary( @@ -2024,14 +2148,14 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:file_descriptor", + "@com_google_absl//absl/synchronization", + gtest, "//test/util:logging", "//test/util:posix_error", "//test/util:signal_util", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", - "@com_google_absl//absl/synchronization", - "@com_google_googletest//:gtest", ], ) @@ -2041,10 +2165,10 @@ cc_binary( srcs = ["sigprocmask.cc"], linkstatic = 1, deps = [ + gtest, "//test/util:signal_util", "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -2054,13 +2178,13 @@ cc_binary( srcs = ["sigstop.cc"], linkstatic = 1, deps = [ + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/time", + gtest, "//test/util:multiprocess_util", "//test/util:posix_error", "//test/util:test_util", "//test/util:thread_util", - "@com_google_absl//absl/flags:flag", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], ) @@ -2071,13 +2195,13 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:file_descriptor", + "@com_google_absl//absl/time", + gtest, "//test/util:logging", "//test/util:signal_util", "//test/util:test_util", "//test/util:thread_util", "//test/util:timer_util", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], ) @@ -2093,14 +2217,30 @@ cc_library( deps = [ ":socket_test_util", ":unix_domain_socket_test_util", - "//test/util:test_util", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@com_google_googletest//:gtest", + gtest, + "//test/util:test_util", ], alwayslink = 1, ) +cc_binary( + name = "socket_stress_test", + testonly = 1, + srcs = [ + "socket_generic_stress.cc", + ], + linkstatic = 1, + deps = [ + ":ip_socket_test_util", + ":socket_test_util", + gtest, + "//test/util:test_main", + "//test/util:test_util", + ], +) + cc_library( name = "socket_unix_dgram_test_cases", testonly = 1, @@ -2109,8 +2249,8 @@ cc_library( deps = [ ":socket_test_util", ":unix_domain_socket_test_util", + gtest, "//test/util:test_util", - "@com_google_googletest//:gtest", ], alwayslink = 1, ) @@ -2123,8 +2263,8 @@ cc_library( deps = [ ":socket_test_util", ":unix_domain_socket_test_util", + gtest, "//test/util:test_util", - "@com_google_googletest//:gtest", ], alwayslink = 1, ) @@ -2140,8 +2280,11 @@ cc_library( ], deps = [ ":socket_test_util", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/time", + gtest, "//test/util:test_util", - "@com_google_googletest//:gtest", + "//test/util:thread_util", ], alwayslink = 1, ) @@ -2158,8 +2301,8 @@ cc_library( deps = [ ":socket_test_util", ":unix_domain_socket_test_util", + gtest, "//test/util:test_util", - "@com_google_googletest//:gtest", ], alwayslink = 1, ) @@ -2176,9 +2319,9 @@ cc_library( deps = [ ":socket_test_util", ":unix_domain_socket_test_util", + gtest, "//test/util:memory_util", "//test/util:test_util", - "@com_google_googletest//:gtest", ], alwayslink = 1, ) @@ -2196,8 +2339,8 @@ cc_library( ":ip_socket_test_util", ":socket_test_util", ":unix_domain_socket_test_util", + gtest, "//test/util:test_util", - "@com_google_googletest//:gtest", ], alwayslink = 1, ) @@ -2214,8 +2357,8 @@ cc_library( deps = [ ":ip_socket_test_util", ":socket_test_util", + gtest, "//test/util:test_util", - "@com_google_googletest//:gtest", ], alwayslink = 1, ) @@ -2232,8 +2375,9 @@ cc_library( deps = [ ":ip_socket_test_util", ":socket_test_util", + "@com_google_absl//absl/memory", + gtest, "//test/util:test_util", - "@com_google_googletest//:gtest", ], alwayslink = 1, ) @@ -2250,8 +2394,8 @@ cc_library( deps = [ ":ip_socket_test_util", ":socket_test_util", + gtest, "//test/util:test_util", - "@com_google_googletest//:gtest", ], alwayslink = 1, ) @@ -2268,8 +2412,8 @@ cc_library( deps = [ ":ip_socket_test_util", ":socket_test_util", + gtest, "//test/util:test_util", - "@com_google_googletest//:gtest", ], alwayslink = 1, ) @@ -2332,9 +2476,9 @@ cc_binary( deps = [ ":socket_test_util", ":unix_domain_socket_test_util", + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -2364,9 +2508,9 @@ cc_binary( deps = [ ":socket_test_util", ":unix_domain_socket_test_util", + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -2396,9 +2540,9 @@ cc_binary( deps = [ ":ip_socket_test_util", ":socket_test_util", + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -2496,10 +2640,10 @@ cc_binary( ":socket_bind_to_device_util", ":socket_test_util", "//test/util:capability_util", + gtest, "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", - "@com_google_googletest//:gtest", ], ) @@ -2515,10 +2659,11 @@ cc_binary( ":socket_bind_to_device_util", ":socket_test_util", "//test/util:capability_util", + "@com_google_absl//absl/container:node_hash_map", + gtest, "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", - "@com_google_googletest//:gtest", ], ) @@ -2534,10 +2679,10 @@ cc_binary( ":socket_bind_to_device_util", ":socket_test_util", "//test/util:capability_util", + gtest, "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", - "@com_google_googletest//:gtest", ], ) @@ -2583,9 +2728,9 @@ cc_binary( deps = [ ":ip_socket_test_util", ":socket_test_util", + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -2661,17 +2806,52 @@ cc_binary( srcs = ["socket_inet_loopback.cc"], linkstatic = 1, deps = [ + ":ip_socket_test_util", ":socket_test_util", "//test/util:file_descriptor", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + gtest, "//test/util:posix_error", "//test/util:save_util", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", + ], +) + +cc_binary( + name = "socket_inet_loopback_nogotsan_test", + testonly = 1, + srcs = ["socket_inet_loopback_nogotsan.cc"], + linkstatic = 1, + deps = [ + ":ip_socket_test_util", + ":socket_test_util", + "//test/util:file_descriptor", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", + gtest, + "//test/util:posix_error", + "//test/util:save_util", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:thread_util", + ], +) + +cc_binary( + name = "socket_netlink_test", + testonly = 1, + srcs = ["socket_netlink.cc"], + linkstatic = 1, + deps = [ + ":socket_test_util", + "//test/util:file_descriptor", + gtest, + "//test/util:test_main", + "//test/util:test_util", ], ) @@ -2681,14 +2861,31 @@ cc_binary( srcs = ["socket_netlink_route.cc"], linkstatic = 1, deps = [ + ":socket_netlink_route_util", ":socket_netlink_util", ":socket_test_util", + "//test/util:capability_util", "//test/util:cleanup", "//test/util:file_descriptor", + "@com_google_absl//absl/strings:str_format", + gtest, + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "socket_netlink_uevent_test", + testonly = 1, + srcs = ["socket_netlink_uevent.cc"], + linkstatic = 1, + deps = [ + ":socket_netlink_util", + ":socket_test_util", + "//test/util:file_descriptor", + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_absl//absl/strings:str_format", - "@com_google_googletest//:gtest", ], ) @@ -2706,9 +2903,9 @@ cc_library( deps = [ ":socket_test_util", ":unix_domain_socket_test_util", - "//test/util:test_util", "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", + gtest, + "//test/util:test_util", ], alwayslink = 1, ) @@ -2725,11 +2922,11 @@ cc_library( deps = [ ":socket_test_util", ":unix_domain_socket_test_util", + "@com_google_absl//absl/time", + gtest, "//test/util:test_util", "//test/util:thread_util", "//test/util:timer_util", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], alwayslink = 1, ) @@ -2746,10 +2943,10 @@ cc_library( deps = [ ":socket_test_util", ":unix_domain_socket_test_util", + "@com_google_absl//absl/strings", + gtest, "//test/util:test_util", "//test/util:thread_util", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", ], alwayslink = 1, ) @@ -2766,10 +2963,10 @@ cc_library( deps = [ ":socket_test_util", ":unix_domain_socket_test_util", + "@com_google_absl//absl/strings", + gtest, "//test/util:test_util", "//test/util:thread_util", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", ], alwayslink = 1, ) @@ -2786,11 +2983,11 @@ cc_library( deps = [ ":socket_test_util", ":unix_domain_socket_test_util", + "@com_google_absl//absl/time", + gtest, "//test/util:test_util", "//test/util:thread_util", "//test/util:timer_util", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], alwayslink = 1, ) @@ -2807,8 +3004,8 @@ cc_library( deps = [ ":socket_test_util", ":unix_domain_socket_test_util", + gtest, "//test/util:test_util", - "@com_google_googletest//:gtest", ], alwayslink = 1, ) @@ -2825,11 +3022,10 @@ cc_library( deps = [ ":socket_test_util", ":unix_domain_socket_test_util", + "@com_google_absl//absl/time", + gtest, "//test/util:test_util", "//test/util:thread_util", - "//test/util:timer_util", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], alwayslink = 1, ) @@ -2923,9 +3119,9 @@ cc_binary( deps = [ ":socket_test_util", ":unix_domain_socket_test_util", + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -2937,9 +3133,9 @@ cc_binary( deps = [ ":socket_test_util", ":unix_domain_socket_test_util", + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -2951,9 +3147,9 @@ cc_binary( deps = [ ":socket_test_util", ":unix_domain_socket_test_util", + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -2968,9 +3164,9 @@ cc_binary( ":socket_blocking_test_cases", ":socket_test_util", ":unix_domain_socket_test_util", + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -2985,9 +3181,9 @@ cc_binary( ":ip_socket_test_util", ":socket_blocking_test_cases", ":socket_test_util", + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -3002,9 +3198,9 @@ cc_binary( ":socket_non_stream_blocking_test_cases", ":socket_test_util", ":unix_domain_socket_test_util", + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -3019,9 +3215,9 @@ cc_binary( ":ip_socket_test_util", ":socket_non_stream_blocking_test_cases", ":socket_test_util", + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -3037,9 +3233,9 @@ cc_binary( ":socket_unix_cmsg_test_cases", ":socket_unix_test_cases", ":unix_domain_socket_test_util", + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -3051,9 +3247,9 @@ cc_binary( deps = [ ":socket_test_util", ":unix_domain_socket_test_util", + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -3065,9 +3261,9 @@ cc_binary( deps = [ ":socket_test_util", ":unix_domain_socket_test_util", + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -3080,10 +3276,10 @@ cc_binary( ":socket_netlink_util", ":socket_test_util", "//test/util:file_descriptor", + "@com_google_absl//absl/base:endian", + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_absl//absl/base:endian", - "@com_google_googletest//:gtest", ], ) @@ -3099,12 +3295,12 @@ cc_binary( "//test/util:cleanup", "//test/util:file_descriptor", "//test/util:fs_util", + "@com_google_absl//absl/strings", + gtest, "//test/util:posix_error", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", ], ) @@ -3115,11 +3311,11 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:file_descriptor", + "@com_google_absl//absl/time", + gtest, "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], ) @@ -3133,12 +3329,12 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:file_descriptor", + "@com_google_absl//absl/strings", + gtest, "//test/util:posix_error", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", ], ) @@ -3151,10 +3347,11 @@ cc_binary( "//test/util:capability_util", "//test/util:file_descriptor", "//test/util:fs_util", + "@com_google_absl//absl/time", + gtest, "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -3164,10 +3361,10 @@ cc_binary( srcs = ["sync.cc"], linkstatic = 1, deps = [ + gtest, "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -3177,10 +3374,10 @@ cc_binary( srcs = ["sysinfo.cc"], linkstatic = 1, deps = [ + "@com_google_absl//absl/time", + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], ) @@ -3190,9 +3387,9 @@ cc_binary( srcs = ["syslog.cc"], linkstatic = 1, deps = [ + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -3202,10 +3399,10 @@ cc_binary( srcs = ["sysret.cc"], linkstatic = 1, deps = [ + gtest, "//test/util:logging", "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -3213,18 +3410,17 @@ cc_binary( name = "tcp_socket_test", testonly = 1, srcs = ["tcp_socket.cc"], + defines = select_system(), linkstatic = 1, - # FIXME(b/135470853) - tags = ["flaky"], deps = [ ":socket_test_util", "//test/util:file_descriptor", + "@com_google_absl//absl/time", + gtest, "//test/util:posix_error", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], ) @@ -3234,11 +3430,11 @@ cc_binary( srcs = ["tgkill.cc"], linkstatic = 1, deps = [ + gtest, "//test/util:signal_util", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", - "@com_google_googletest//:gtest", ], ) @@ -3248,10 +3444,10 @@ cc_binary( srcs = ["time.cc"], linkstatic = 1, deps = [ + gtest, "//test/util:proc_util", "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -3276,15 +3472,15 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:cleanup", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/time", + gtest, "//test/util:logging", "//test/util:multiprocess_util", "//test/util:posix_error", "//test/util:signal_util", "//test/util:test_util", "//test/util:thread_util", - "@com_google_absl//absl/flags:flag", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], ) @@ -3294,11 +3490,11 @@ cc_binary( srcs = ["tkill.cc"], linkstatic = 1, deps = [ + gtest, "//test/util:logging", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", - "@com_google_googletest//:gtest", ], ) @@ -3312,28 +3508,78 @@ cc_binary( "//test/util:capability_util", "//test/util:cleanup", "//test/util:file_descriptor", + "@com_google_absl//absl/strings", + gtest, "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", + ], +) + +cc_binary( + name = "tuntap_test", + testonly = 1, + srcs = ["tuntap.cc"], + linkstatic = 1, + deps = [ + ":socket_test_util", + gtest, + ":socket_netlink_route_util", + "//test/util:capability_util", + "//test/util:file_descriptor", + "//test/util:fs_util", + "//test/util:posix_error", + "//test/util:test_main", + "//test/util:test_util", "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", ], ) cc_binary( - name = "udp_socket_test", + name = "tuntap_hostinet_test", testonly = 1, - srcs = ["udp_socket.cc"], + srcs = ["tuntap_hostinet.cc"], linkstatic = 1, deps = [ + gtest, + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_library( + name = "udp_socket_test_cases", + testonly = 1, + srcs = [ + "udp_socket_errqueue_test_case.cc", + "udp_socket_test_cases.cc", + ], + hdrs = ["udp_socket_test_cases.h"], + defines = select_system(), + deps = [ + ":ip_socket_test_util", ":socket_test_util", ":unix_domain_socket_test_util", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/time", + gtest, + "//test/util:file_descriptor", + "//test/util:posix_error", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", + ], + alwayslink = 1, +) + +cc_binary( + name = "udp_socket_test", + testonly = 1, + srcs = ["udp_socket.cc"], + linkstatic = 1, + deps = [ + ":udp_socket_test_cases", ], ) @@ -3345,9 +3591,9 @@ cc_binary( deps = [ ":socket_test_util", "//test/util:file_descriptor", + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -3358,14 +3604,14 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:capability_util", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/strings", + gtest, "//test/util:posix_error", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", "//test/util:uid_util", - "@com_google_absl//absl/flags:flag", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", ], ) @@ -3376,11 +3622,11 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:capability_util", + "@com_google_absl//absl/strings", + gtest, "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", ], ) @@ -3393,11 +3639,11 @@ cc_binary( "//test/util:capability_util", "//test/util:file_descriptor", "//test/util:fs_util", + "@com_google_absl//absl/strings", + gtest, "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", ], ) @@ -3407,11 +3653,11 @@ cc_binary( srcs = ["unshare.cc"], linkstatic = 1, deps = [ + "@com_google_absl//absl/synchronization", + gtest, "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", - "@com_google_absl//absl/synchronization", - "@com_google_googletest//:gtest", ], ) @@ -3437,11 +3683,11 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:fs_util", + gtest, "//test/util:posix_error", "//test/util:proc_util", "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -3451,13 +3697,13 @@ cc_binary( srcs = ["vfork.cc"], linkstatic = 1, deps = [ + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/time", + gtest, "//test/util:logging", "//test/util:multiprocess_util", "//test/util:test_util", "//test/util:time_util", - "@com_google_absl//absl/flags:flag", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], ) @@ -3469,6 +3715,10 @@ cc_binary( deps = [ "//test/util:cleanup", "//test/util:file_descriptor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + gtest, "//test/util:logging", "//test/util:multiprocess_util", "//test/util:posix_error", @@ -3477,10 +3727,6 @@ cc_binary( "//test/util:test_util", "//test/util:thread_util", "//test/util:time_util", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", ], ) @@ -3491,10 +3737,10 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:cleanup", + gtest, "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -3505,30 +3751,46 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:fs_util", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + gtest, "//test/util:posix_error", "//test/util:test_main", "//test/util:test_util", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_googletest//:gtest", ], ) cc_binary( - name = "semaphore_test", + name = "network_namespace_test", testonly = 1, - srcs = ["semaphore.cc"], + srcs = ["network_namespace.cc"], linkstatic = 1, deps = [ + ":socket_test_util", + gtest, "//test/util:capability_util", + "//test/util:posix_error", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", + ], +) + +cc_binary( + name = "semaphore_test", + testonly = 1, + srcs = ["semaphore.cc"], + linkstatic = 1, + deps = [ + "//test/util:capability_util", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", + gtest, + "//test/util:test_main", + "//test/util:test_util", + "//test/util:thread_util", ], ) @@ -3554,10 +3816,10 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:file_descriptor", + gtest, "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -3567,11 +3829,11 @@ cc_binary( srcs = ["vdso_clock_gettime.cc"], linkstatic = 1, deps = [ - "//test/util:test_main", - "//test/util:test_util", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", + gtest, + "//test/util:test_main", + "//test/util:test_util", ], ) @@ -3581,10 +3843,10 @@ cc_binary( srcs = ["vsyscall.cc"], linkstatic = 1, deps = [ + gtest, "//test/util:proc_util", "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -3597,11 +3859,11 @@ cc_binary( ":unix_domain_socket_test_util", "//test/util:file_descriptor", "//test/util:fs_util", - "//test/util:test_main", - "//test/util:test_util", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@com_google_googletest//:gtest", + gtest, + "//test/util:test_main", + "//test/util:test_util", ], ) @@ -3613,12 +3875,12 @@ cc_binary( deps = [ "//test/util:file_descriptor", "//test/util:fs_util", + gtest, "//test/util:memory_util", "//test/util:multiprocess_util", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", - "@com_google_googletest//:gtest", ], ) @@ -3630,10 +3892,10 @@ cc_binary( deps = [ ":ip_socket_test_util", "//test/util:file_descriptor", + "@com_google_absl//absl/strings", + gtest, "//test/util:test_main", "//test/util:test_util", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", ], ) @@ -3645,9 +3907,31 @@ cc_binary( deps = [ ":ip_socket_test_util", "//test/util:file_descriptor", + "@com_google_absl//absl/strings", + gtest, "//test/util:test_main", "//test/util:test_util", + ], +) + +cc_binary( + name = "xattr_test", + testonly = 1, + srcs = [ + "file_base.h", + "xattr.cc", + ], + linkstatic = 1, + deps = [ + "//test/util:capability_util", + "//test/util:file_descriptor", + "//test/util:fs_util", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", + gtest, + "//test/util:posix_error", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", ], ) diff --git a/test/syscalls/linux/accept_bind.cc b/test/syscalls/linux/accept_bind.cc index 427c42ede..f65a14fb8 100644 --- a/test/syscalls/linux/accept_bind.cc +++ b/test/syscalls/linux/accept_bind.cc @@ -13,9 +13,12 @@ // limitations under the License. #include <stdio.h> +#include <sys/socket.h> #include <sys/un.h> + #include <algorithm> #include <vector> + #include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" @@ -139,6 +142,47 @@ TEST_P(AllSocketPairTest, Connect) { SyscallSucceeds()); } +TEST_P(AllSocketPairTest, ConnectWithWrongType) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int type; + socklen_t typelen = sizeof(type); + EXPECT_THAT( + getsockopt(sockets->first_fd(), SOL_SOCKET, SO_TYPE, &type, &typelen), + SyscallSucceeds()); + switch (type) { + case SOCK_STREAM: + type = SOCK_SEQPACKET; + break; + case SOCK_SEQPACKET: + type = SOCK_STREAM; + break; + } + + const FileDescriptor another_socket = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, type, 0)); + + ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); + + ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds()); + + if (sockets->first_addr()->sa_data[0] != 0) { + ASSERT_THAT(connect(another_socket.get(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallFailsWithErrno(EPROTOTYPE)); + } else { + ASSERT_THAT(connect(another_socket.get(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallFailsWithErrno(ECONNREFUSED)); + } + + ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); +} + TEST_P(AllSocketPairTest, ConnectNonListening) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); diff --git a/test/syscalls/linux/accept_bind_stream.cc b/test/syscalls/linux/accept_bind_stream.cc index 7bcd91e9e..4857f160b 100644 --- a/test/syscalls/linux/accept_bind_stream.cc +++ b/test/syscalls/linux/accept_bind_stream.cc @@ -14,8 +14,10 @@ #include <stdio.h> #include <sys/un.h> + #include <algorithm> #include <vector> + #include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" diff --git a/test/syscalls/linux/aio.cc b/test/syscalls/linux/aio.cc index b27d4e10a..806d5729e 100644 --- a/test/syscalls/linux/aio.cc +++ b/test/syscalls/linux/aio.cc @@ -89,6 +89,7 @@ class AIOTest : public FileTest { FileTest::TearDown(); if (ctx_ != 0) { ASSERT_THAT(DestroyContext(), SyscallSucceeds()); + ctx_ = 0; } } @@ -129,7 +130,7 @@ TEST_F(AIOTest, BasicWrite) { // aio implementation uses aio_ring. gVisor doesn't and returns all zeroes. // Linux implements aio_ring, so skip the zeroes check. // - // TODO(b/65486370): Remove when gVisor implements aio_ring. + // TODO(gvisor.dev/issue/204): Remove when gVisor implements aio_ring. auto ring = reinterpret_cast<struct aio_ring*>(ctx_); auto magic = IsRunningOnGvisor() ? 0 : AIO_RING_MAGIC; EXPECT_EQ(ring->magic, magic); @@ -188,14 +189,19 @@ TEST_F(AIOTest, BadWrite) { } TEST_F(AIOTest, ExitWithPendingIo) { - // Setup a context that is 5 entries deep. - ASSERT_THAT(SetupContext(5), SyscallSucceeds()); + // Setup a context that is 100 entries deep. + ASSERT_THAT(SetupContext(100), SyscallSucceeds()); struct iocb cb = CreateCallback(); struct iocb* cbs[] = {&cb}; // Submit a request but don't complete it to make it pending. - EXPECT_THAT(Submit(1, cbs), SyscallSucceeds()); + for (int i = 0; i < 100; ++i) { + EXPECT_THAT(Submit(1, cbs), SyscallSucceeds()); + } + + ASSERT_THAT(DestroyContext(), SyscallSucceeds()); + ctx_ = 0; } int Submitter(void* arg) { diff --git a/test/syscalls/linux/alarm.cc b/test/syscalls/linux/alarm.cc index d89269985..940c97285 100644 --- a/test/syscalls/linux/alarm.cc +++ b/test/syscalls/linux/alarm.cc @@ -188,6 +188,5 @@ int main(int argc, char** argv) { TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0); gvisor::testing::TestInit(&argc, &argv); - - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/bad.cc b/test/syscalls/linux/bad.cc index f246a799e..a26fc6af3 100644 --- a/test/syscalls/linux/bad.cc +++ b/test/syscalls/linux/bad.cc @@ -22,11 +22,17 @@ namespace gvisor { namespace testing { namespace { +#ifdef __x86_64__ +// get_kernel_syms is not supported in Linux > 2.6, and not implemented in +// gVisor. +constexpr uint32_t kNotImplementedSyscall = SYS_get_kernel_syms; +#elif __aarch64__ +// Use the last of arch_specific_syscalls which are not implemented on arm64. +constexpr uint32_t kNotImplementedSyscall = __NR_arch_specific_syscall + 15; +#endif TEST(BadSyscallTest, NotImplemented) { - // get_kernel_syms is not supported in Linux > 2.6, and not implemented in - // gVisor. - EXPECT_THAT(syscall(SYS_get_kernel_syms), SyscallFailsWithErrno(ENOSYS)); + EXPECT_THAT(syscall(kNotImplementedSyscall), SyscallFailsWithErrno(ENOSYS)); } TEST(BadSyscallTest, NegativeOne) { diff --git a/test/syscalls/linux/chmod.cc b/test/syscalls/linux/chmod.cc index 7e918b9b2..a06b5cfd6 100644 --- a/test/syscalls/linux/chmod.cc +++ b/test/syscalls/linux/chmod.cc @@ -16,6 +16,7 @@ #include <sys/stat.h> #include <sys/types.h> #include <unistd.h> + #include <string> #include "gtest/gtest.h" diff --git a/test/syscalls/linux/chroot.cc b/test/syscalls/linux/chroot.cc index de1611c21..85ec013d5 100644 --- a/test/syscalls/linux/chroot.cc +++ b/test/syscalls/linux/chroot.cc @@ -19,6 +19,7 @@ #include <sys/stat.h> #include <syscall.h> #include <unistd.h> + #include <string> #include <vector> @@ -161,12 +162,12 @@ TEST(ChrootTest, DotDotFromOpenFD) { // getdents on fd should not error. char buf[1024]; - ASSERT_THAT(syscall(SYS_getdents, fd.get(), buf, sizeof(buf)), + ASSERT_THAT(syscall(SYS_getdents64, fd.get(), buf, sizeof(buf)), SyscallSucceeds()); } // Test that link resolution in a chroot can escape the root by following an -// open proc fd. +// open proc fd. Regression test for b/32316719. TEST(ChrootTest, ProcFdLinkResolutionInChroot) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_CHROOT))); diff --git a/test/syscalls/linux/clock_gettime.cc b/test/syscalls/linux/clock_gettime.cc index c9e3ed6b2..7f6015049 100644 --- a/test/syscalls/linux/clock_gettime.cc +++ b/test/syscalls/linux/clock_gettime.cc @@ -14,6 +14,7 @@ #include <pthread.h> #include <sys/time.h> + #include <cerrno> #include <cstdint> #include <ctime> @@ -55,11 +56,6 @@ void spin_ns(int64_t ns) { // Test that CLOCK_PROCESS_CPUTIME_ID is a superset of CLOCK_THREAD_CPUTIME_ID. TEST(ClockGettime, CputimeId) { - // TODO(b/128871825,golang.org/issue/10958): Test times out when there is a - // small number of core because one goroutine starves the others. - printf("CPUS: %d\n", std::thread::hardware_concurrency()); - SKIP_IF(std::thread::hardware_concurrency() <= 2); - constexpr int kNumThreads = 13; // arbitrary absl::Duration spin_time = absl::Seconds(1); diff --git a/test/syscalls/linux/concurrency.cc b/test/syscalls/linux/concurrency.cc index 4e0a13f8b..7cd6a75bd 100644 --- a/test/syscalls/linux/concurrency.cc +++ b/test/syscalls/linux/concurrency.cc @@ -13,12 +13,14 @@ // limitations under the License. #include <signal.h> + #include <atomic> #include "gtest/gtest.h" #include "absl/strings/string_view.h" #include "absl/time/clock.h" #include "absl/time/time.h" +#include "test/util/platform_util.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" @@ -44,7 +46,8 @@ TEST(ConcurrencyTest, SingleProcessMultithreaded) { } // Test that multiple threads in this process continue to execute in parallel, -// even if an unrelated second process is spawned. +// even if an unrelated second process is spawned. Regression test for +// b/32119508. TEST(ConcurrencyTest, MultiProcessMultithreaded) { // In PID 1, start TIDs 1 and 2, and put both to sleep. // @@ -98,6 +101,7 @@ TEST(ConcurrencyTest, MultiProcessMultithreaded) { // Test that multiple processes can execute concurrently, even if one process // never yields. TEST(ConcurrencyTest, MultiProcessConcurrency) { + SKIP_IF(PlatformSupportMultiProcess() == PlatformSupport::NotSupported); pid_t child_pid = fork(); if (child_pid == 0) { diff --git a/test/syscalls/linux/connect_external.cc b/test/syscalls/linux/connect_external.cc index bfe1da82e..1edb50e47 100644 --- a/test/syscalls/linux/connect_external.cc +++ b/test/syscalls/linux/connect_external.cc @@ -56,7 +56,7 @@ TEST_P(GoferStreamSeqpacketTest, Echo) { ProtocolSocket proto; std::tie(env, proto) = GetParam(); - char *val = getenv(env.c_str()); + char* val = getenv(env.c_str()); ASSERT_NE(val, nullptr); std::string root(val); @@ -69,7 +69,7 @@ TEST_P(GoferStreamSeqpacketTest, Echo) { addr.sun_family = AF_UNIX; memcpy(addr.sun_path, socket_path.c_str(), socket_path.length()); - ASSERT_THAT(connect(sock.get(), reinterpret_cast<struct sockaddr *>(&addr), + ASSERT_THAT(connect(sock.get(), reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)), SyscallSucceeds()); @@ -92,7 +92,7 @@ TEST_P(GoferStreamSeqpacketTest, NonListening) { ProtocolSocket proto; std::tie(env, proto) = GetParam(); - char *val = getenv(env.c_str()); + char* val = getenv(env.c_str()); ASSERT_NE(val, nullptr); std::string root(val); @@ -105,7 +105,7 @@ TEST_P(GoferStreamSeqpacketTest, NonListening) { addr.sun_family = AF_UNIX; memcpy(addr.sun_path, socket_path.c_str(), socket_path.length()); - ASSERT_THAT(connect(sock.get(), reinterpret_cast<struct sockaddr *>(&addr), + ASSERT_THAT(connect(sock.get(), reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)), SyscallFailsWithErrno(ECONNREFUSED)); } @@ -127,7 +127,7 @@ using GoferDgramTest = ::testing::TestWithParam<std::string>; // unnamed. The server thus has no way to reply to us. TEST_P(GoferDgramTest, Null) { std::string env = GetParam(); - char *val = getenv(env.c_str()); + char* val = getenv(env.c_str()); ASSERT_NE(val, nullptr); std::string root(val); @@ -140,7 +140,7 @@ TEST_P(GoferDgramTest, Null) { addr.sun_family = AF_UNIX; memcpy(addr.sun_path, socket_path.c_str(), socket_path.length()); - ASSERT_THAT(connect(sock.get(), reinterpret_cast<struct sockaddr *>(&addr), + ASSERT_THAT(connect(sock.get(), reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)), SyscallSucceeds()); diff --git a/test/syscalls/linux/dev.cc b/test/syscalls/linux/dev.cc index 4dd302eed..1d0d584cd 100644 --- a/test/syscalls/linux/dev.cc +++ b/test/syscalls/linux/dev.cc @@ -153,6 +153,27 @@ TEST(DevTest, TTYExists) { EXPECT_EQ(statbuf.st_mode, S_IFCHR | 0666); } +TEST(DevTest, OpenDevFuse) { + // Note(gvisor.dev/issue/3076) This won't work in the sentry until the new + // device registration is complete. + SKIP_IF(IsRunningWithVFS1() || IsRunningOnGvisor() || !IsFUSEEnabled()); + + ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/fuse", O_RDONLY)); +} + +TEST(DevTest, ReadDevFuseWithoutMount) { + // Note(gvisor.dev/issue/3076) This won't work in the sentry until the new + // device registration is complete. + SKIP_IF(IsRunningWithVFS1() || IsRunningOnGvisor()); + + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/fuse", O_RDONLY)); + + std::vector<char> buf(1); + EXPECT_THAT(ReadFd(fd.get(), buf.data(), sizeof(buf)), + SyscallFailsWithErrno(EPERM)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/epoll.cc b/test/syscalls/linux/epoll.cc index a4f8f3cec..2101e5c9f 100644 --- a/test/syscalls/linux/epoll.cc +++ b/test/syscalls/linux/epoll.cc @@ -56,10 +56,6 @@ TEST(EpollTest, AllWritable) { struct epoll_event result[kFDsPerEpoll]; ASSERT_THAT(RetryEINTR(epoll_wait)(epollfd.get(), result, kFDsPerEpoll, -1), SyscallSucceedsWithValue(kFDsPerEpoll)); - // TODO(edahlgren): Why do some tests check epoll_event::data, and others - // don't? Does Linux actually guarantee that, in any of these test cases, - // epoll_wait will necessarily write out the epoll_events in the order that - // they were registered? for (int i = 0; i < kFDsPerEpoll; i++) { ASSERT_EQ(result[i].events, EPOLLOUT); } @@ -426,6 +422,28 @@ TEST(EpollTest, CloseFile) { SyscallSucceedsWithValue(0)); } +TEST(EpollTest, PipeReaderHupAfterWriterClosed) { + auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD()); + int pipefds[2]; + ASSERT_THAT(pipe(pipefds), SyscallSucceeds()); + FileDescriptor rfd(pipefds[0]); + FileDescriptor wfd(pipefds[1]); + + ASSERT_NO_ERRNO(RegisterEpollFD(epollfd.get(), rfd.get(), 0, kMagicConstant)); + struct epoll_event result[kFDsPerEpoll]; + // Initially, rfd should not generate any events of interest. + ASSERT_THAT(epoll_wait(epollfd.get(), result, kFDsPerEpoll, 0), + SyscallSucceedsWithValue(0)); + // Close the write end of the pipe. + wfd.reset(); + // rfd should now generate EPOLLHUP, which EPOLL_CTL_ADD unconditionally adds + // to the set of events of interest. + ASSERT_THAT(epoll_wait(epollfd.get(), result, kFDsPerEpoll, 0), + SyscallSucceedsWithValue(1)); + EXPECT_EQ(result[0].events, EPOLLHUP); + EXPECT_EQ(result[0].data.u64, kMagicConstant); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/eventfd.cc b/test/syscalls/linux/eventfd.cc index 367682c3d..dc794415e 100644 --- a/test/syscalls/linux/eventfd.cc +++ b/test/syscalls/linux/eventfd.cc @@ -100,6 +100,23 @@ TEST(EventfdTest, SmallRead) { ASSERT_THAT(read(efd.get(), &l, 4), SyscallFailsWithErrno(EINVAL)); } +TEST(EventfdTest, IllegalSeek) { + FileDescriptor efd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, 0)); + EXPECT_THAT(lseek(efd.get(), 0, SEEK_SET), SyscallFailsWithErrno(ESPIPE)); +} + +TEST(EventfdTest, IllegalPread) { + FileDescriptor efd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, 0)); + int l; + EXPECT_THAT(pread(efd.get(), &l, sizeof(l), 0), + SyscallFailsWithErrno(ESPIPE)); +} + +TEST(EventfdTest, IllegalPwrite) { + FileDescriptor efd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, 0)); + EXPECT_THAT(pwrite(efd.get(), "x", 1, 0), SyscallFailsWithErrno(ESPIPE)); +} + TEST(EventfdTest, BigWrite) { FileDescriptor efd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, EFD_NONBLOCK | EFD_SEMAPHORE)); @@ -132,6 +149,31 @@ TEST(EventfdTest, BigWriteBigRead) { EXPECT_EQ(l[0], 1); } +TEST(EventfdTest, SpliceFromPipePartialSucceeds) { + int pipes[2]; + ASSERT_THAT(pipe2(pipes, O_NONBLOCK), SyscallSucceeds()); + const FileDescriptor pipe_rfd(pipes[0]); + const FileDescriptor pipe_wfd(pipes[1]); + constexpr uint64_t kVal{1}; + + FileDescriptor efd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, EFD_NONBLOCK)); + + uint64_t event_array[2]; + event_array[0] = kVal; + event_array[1] = kVal; + ASSERT_THAT(write(pipe_wfd.get(), event_array, sizeof(event_array)), + SyscallSucceedsWithValue(sizeof(event_array))); + EXPECT_THAT(splice(pipe_rfd.get(), /*__offin=*/nullptr, efd.get(), + /*__offout=*/nullptr, sizeof(event_array[0]) + 1, + SPLICE_F_NONBLOCK), + SyscallSucceedsWithValue(sizeof(event_array[0]))); + + uint64_t val; + ASSERT_THAT(read(efd.get(), &val, sizeof(val)), + SyscallSucceedsWithValue(sizeof(val))); + EXPECT_EQ(val, kVal); +} + // NotifyNonZero is inherently racy, so random save is disabled. TEST(EventfdTest, NotifyNonZero_NoRandomSave) { // Waits will time out at 10 seconds. diff --git a/test/syscalls/linux/exceptions.cc b/test/syscalls/linux/exceptions.cc index 370e85166..420b9543f 100644 --- a/test/syscalls/linux/exceptions.cc +++ b/test/syscalls/linux/exceptions.cc @@ -16,12 +16,30 @@ #include "gtest/gtest.h" #include "test/util/logging.h" +#include "test/util/platform_util.h" #include "test/util/signal_util.h" #include "test/util/test_util.h" namespace gvisor { namespace testing { +// Default value for the x87 FPU control word. See Intel SDM Vol 1, Ch 8.1.5 +// "x87 FPU Control Word". +constexpr uint16_t kX87ControlWordDefault = 0x37f; + +// Mask for the divide-by-zero exception. +constexpr uint16_t kX87ControlWordDiv0Mask = 1 << 2; + +// Default value for the SSE control register (MXCSR). See Intel SDM Vol 1, Ch +// 11.6.4 "Initialization of SSE/SSE3 Extensions". +constexpr uint32_t kMXCSRDefault = 0x1f80; + +// Mask for the divide-by-zero exception. +constexpr uint32_t kMXCSRDiv0Mask = 1 << 9; + +// Flag for a pending divide-by-zero exception. +constexpr uint32_t kMXCSRDiv0Flag = 1 << 2; + void inline Halt() { asm("hlt\r\n"); } void inline SetAlignmentCheck() { @@ -107,6 +125,170 @@ TEST(ExceptionTest, DivideByZero) { ::testing::KilledBySignal(SIGFPE), ""); } +// By default, x87 exceptions are masked and simply return a default value. +TEST(ExceptionTest, X87DivideByZeroMasked) { + int32_t quotient; + int32_t value = 1; + int32_t divisor = 0; + asm("fildl %[value]\r\n" + "fidivl %[divisor]\r\n" + "fistpl %[quotient]\r\n" + : [ quotient ] "=m"(quotient) + : [ value ] "m"(value), [ divisor ] "m"(divisor)); + + EXPECT_EQ(quotient, INT32_MIN); +} + +// When unmasked, division by zero raises SIGFPE. +TEST(ExceptionTest, X87DivideByZeroUnmasked) { + // See above. + struct sigaction sa = {}; + sa.sa_handler = SIG_DFL; + auto const cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGFPE, sa)); + + EXPECT_EXIT( + { + // Clear the divide by zero exception mask. + constexpr uint16_t kControlWord = + kX87ControlWordDefault & ~kX87ControlWordDiv0Mask; + + int32_t quotient; + int32_t value = 1; + int32_t divisor = 0; + asm volatile( + "fldcw %[cw]\r\n" + "fildl %[value]\r\n" + "fidivl %[divisor]\r\n" + "fistpl %[quotient]\r\n" + : [ quotient ] "=m"(quotient) + : [ cw ] "m"(kControlWord), [ value ] "m"(value), + [ divisor ] "m"(divisor)); + }, + ::testing::KilledBySignal(SIGFPE), ""); +} + +// Pending exceptions in the x87 status register are not clobbered by syscalls. +TEST(ExceptionTest, X87StatusClobber) { + // See above. + struct sigaction sa = {}; + sa.sa_handler = SIG_DFL; + auto const cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGFPE, sa)); + + EXPECT_EXIT( + { + // Clear the divide by zero exception mask. + constexpr uint16_t kControlWord = + kX87ControlWordDefault & ~kX87ControlWordDiv0Mask; + + int32_t quotient; + int32_t value = 1; + int32_t divisor = 0; + asm volatile( + "fildl %[value]\r\n" + "fidivl %[divisor]\r\n" + // Exception is masked, so it does not occur here. + "fistpl %[quotient]\r\n" + + // SYS_getpid placed in rax by constraint. + "syscall\r\n" + + // Unmask exception. The syscall didn't clobber the pending + // exception, so now it can be raised. + // + // N.B. "a floating-point exception will be generated upon execution + // of the *next* floating-point instruction". + "fldcw %[cw]\r\n" + "fwait\r\n" + : [ quotient ] "=m"(quotient) + : [ value ] "m"(value), [ divisor ] "m"(divisor), "a"(SYS_getpid), + [ cw ] "m"(kControlWord) + : "rcx", "r11"); + }, + ::testing::KilledBySignal(SIGFPE), ""); +} + +// By default, SSE exceptions are masked and simply return a default value. +TEST(ExceptionTest, SSEDivideByZeroMasked) { + uint32_t status; + int32_t quotient; + int32_t value = 1; + int32_t divisor = 0; + asm("cvtsi2ssl %[value], %%xmm0\r\n" + "cvtsi2ssl %[divisor], %%xmm1\r\n" + "divss %%xmm1, %%xmm0\r\n" + "cvtss2sil %%xmm0, %[quotient]\r\n" + : [ quotient ] "=r"(quotient), [ status ] "=r"(status) + : [ value ] "r"(value), [ divisor ] "r"(divisor) + : "xmm0", "xmm1"); + + EXPECT_EQ(quotient, INT32_MIN); +} + +// When unmasked, division by zero raises SIGFPE. +TEST(ExceptionTest, SSEDivideByZeroUnmasked) { + // See above. + struct sigaction sa = {}; + sa.sa_handler = SIG_DFL; + auto const cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGFPE, sa)); + + EXPECT_EXIT( + { + // Clear the divide by zero exception mask. + constexpr uint32_t kMXCSR = kMXCSRDefault & ~kMXCSRDiv0Mask; + + int32_t quotient; + int32_t value = 1; + int32_t divisor = 0; + asm volatile( + "ldmxcsr %[mxcsr]\r\n" + "cvtsi2ssl %[value], %%xmm0\r\n" + "cvtsi2ssl %[divisor], %%xmm1\r\n" + "divss %%xmm1, %%xmm0\r\n" + "cvtss2sil %%xmm0, %[quotient]\r\n" + : [ quotient ] "=r"(quotient) + : [ mxcsr ] "m"(kMXCSR), [ value ] "r"(value), + [ divisor ] "r"(divisor) + : "xmm0", "xmm1"); + }, + ::testing::KilledBySignal(SIGFPE), ""); +} + +// Pending exceptions in the SSE status register are not clobbered by syscalls. +TEST(ExceptionTest, SSEStatusClobber) { + uint32_t mxcsr; + int32_t quotient; + int32_t value = 1; + int32_t divisor = 0; + asm("cvtsi2ssl %[value], %%xmm0\r\n" + "cvtsi2ssl %[divisor], %%xmm1\r\n" + "divss %%xmm1, %%xmm0\r\n" + // Exception is masked, so it does not occur here. + "cvtss2sil %%xmm0, %[quotient]\r\n" + + // SYS_getpid placed in rax by constraint. + "syscall\r\n" + + // Intel SDM Vol 1, Ch 10.2.3.1 "SIMD Floating-Point Mask and Flag Bits": + // "If LDMXCSR or FXRSTOR clears a mask bit and sets the corresponding + // exception flag bit, a SIMD floating-point exception will not be + // generated as a result of this change. The unmasked exception will be + // generated only upon the execution of the next SSE/SSE2/SSE3 instruction + // that detects the unmasked exception condition." + // + // Though ambiguous, empirical evidence indicates that this means that + // exception flags set in the status register will never cause an + // exception to be raised; only a new exception condition will do so. + // + // Thus here we just check for the flag itself rather than trying to raise + // the exception. + "stmxcsr %[mxcsr]\r\n" + : [ quotient ] "=r"(quotient), [ mxcsr ] "+m"(mxcsr) + : [ value ] "r"(value), [ divisor ] "r"(divisor), "a"(SYS_getpid) + : "xmm0", "xmm1", "rcx", "r11"); + + EXPECT_TRUE(mxcsr & kMXCSRDiv0Flag); +} + TEST(ExceptionTest, IOAccessFault) { // See above. struct sigaction sa = {}; @@ -143,6 +325,7 @@ TEST(ExceptionTest, AlignmentHalt) { } TEST(ExceptionTest, AlignmentCheck) { + SKIP_IF(PlatformSupportAlignmentCheck() != PlatformSupport::Allowed); // See above. struct sigaction sa = {}; diff --git a/test/syscalls/linux/exec.cc b/test/syscalls/linux/exec.cc index 21a5ffd40..c5acfc794 100644 --- a/test/syscalls/linux/exec.cc +++ b/test/syscalls/linux/exec.cc @@ -47,23 +47,14 @@ namespace testing { namespace { -constexpr char kBasicWorkload[] = "exec_basic_workload"; -constexpr char kExitScript[] = "exit_script"; -constexpr char kStateWorkload[] = "exec_state_workload"; -constexpr char kProcExeWorkload[] = "exec_proc_exe_workload"; -constexpr char kAssertClosedWorkload[] = "exec_assert_closed_workload"; -constexpr char kPriorityWorkload[] = "priority_execve"; - -std::string WorkloadPath(absl::string_view binary) { - std::string full_path; - char* test_src = getenv("TEST_SRCDIR"); - if (test_src) { - full_path = JoinPath(test_src, "__main__/test/syscalls/linux", binary); - } - - TEST_CHECK(full_path.empty() == false); - return full_path; -} +constexpr char kBasicWorkload[] = "test/syscalls/linux/exec_basic_workload"; +constexpr char kExitScript[] = "test/syscalls/linux/exit_script"; +constexpr char kStateWorkload[] = "test/syscalls/linux/exec_state_workload"; +constexpr char kProcExeWorkload[] = + "test/syscalls/linux/exec_proc_exe_workload"; +constexpr char kAssertClosedWorkload[] = + "test/syscalls/linux/exec_assert_closed_workload"; +constexpr char kPriorityWorkload[] = "test/syscalls/linux/priority_execve"; constexpr char kExit42[] = "--exec_exit_42"; constexpr char kExecWithThread[] = "--exec_exec_with_thread"; @@ -171,44 +162,44 @@ TEST(ExecTest, EmptyPath) { } TEST(ExecTest, Basic) { - CheckExec(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload)}, {}, + CheckExec(RunfilePath(kBasicWorkload), {RunfilePath(kBasicWorkload)}, {}, ArgEnvExitStatus(0, 0), - absl::StrCat(WorkloadPath(kBasicWorkload), "\n")); + absl::StrCat(RunfilePath(kBasicWorkload), "\n")); } TEST(ExecTest, OneArg) { - CheckExec(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload), "1"}, - {}, ArgEnvExitStatus(1, 0), - absl::StrCat(WorkloadPath(kBasicWorkload), "\n1\n")); + CheckExec(RunfilePath(kBasicWorkload), {RunfilePath(kBasicWorkload), "1"}, {}, + ArgEnvExitStatus(1, 0), + absl::StrCat(RunfilePath(kBasicWorkload), "\n1\n")); } TEST(ExecTest, FiveArg) { - CheckExec(WorkloadPath(kBasicWorkload), - {WorkloadPath(kBasicWorkload), "1", "2", "3", "4", "5"}, {}, + CheckExec(RunfilePath(kBasicWorkload), + {RunfilePath(kBasicWorkload), "1", "2", "3", "4", "5"}, {}, ArgEnvExitStatus(5, 0), - absl::StrCat(WorkloadPath(kBasicWorkload), "\n1\n2\n3\n4\n5\n")); + absl::StrCat(RunfilePath(kBasicWorkload), "\n1\n2\n3\n4\n5\n")); } TEST(ExecTest, OneEnv) { - CheckExec(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload)}, {"1"}, + CheckExec(RunfilePath(kBasicWorkload), {RunfilePath(kBasicWorkload)}, {"1"}, ArgEnvExitStatus(0, 1), - absl::StrCat(WorkloadPath(kBasicWorkload), "\n1\n")); + absl::StrCat(RunfilePath(kBasicWorkload), "\n1\n")); } TEST(ExecTest, FiveEnv) { - CheckExec(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload)}, + CheckExec(RunfilePath(kBasicWorkload), {RunfilePath(kBasicWorkload)}, {"1", "2", "3", "4", "5"}, ArgEnvExitStatus(0, 5), - absl::StrCat(WorkloadPath(kBasicWorkload), "\n1\n2\n3\n4\n5\n")); + absl::StrCat(RunfilePath(kBasicWorkload), "\n1\n2\n3\n4\n5\n")); } TEST(ExecTest, OneArgOneEnv) { - CheckExec(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload), "arg"}, + CheckExec(RunfilePath(kBasicWorkload), {RunfilePath(kBasicWorkload), "arg"}, {"env"}, ArgEnvExitStatus(1, 1), - absl::StrCat(WorkloadPath(kBasicWorkload), "\narg\nenv\n")); + absl::StrCat(RunfilePath(kBasicWorkload), "\narg\nenv\n")); } TEST(ExecTest, InterpreterScript) { - CheckExec(WorkloadPath(kExitScript), {WorkloadPath(kExitScript), "25"}, {}, + CheckExec(RunfilePath(kExitScript), {RunfilePath(kExitScript), "25"}, {}, ArgEnvExitStatus(25, 0), ""); } @@ -216,7 +207,7 @@ TEST(ExecTest, InterpreterScript) { TEST(ExecTest, InterpreterScriptArgSplit) { // Symlink through /tmp to ensure the path is short enough. TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload))); + TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload))); TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path(), " foo bar"), @@ -230,7 +221,7 @@ TEST(ExecTest, InterpreterScriptArgSplit) { TEST(ExecTest, InterpreterScriptArgvZero) { // Symlink through /tmp to ensure the path is short enough. TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload))); + TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload))); TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path()), 0755)); @@ -244,7 +235,7 @@ TEST(ExecTest, InterpreterScriptArgvZero) { TEST(ExecTest, InterpreterScriptArgvZeroRelative) { // Symlink through /tmp to ensure the path is short enough. TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload))); + TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload))); TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path()), 0755)); @@ -261,7 +252,7 @@ TEST(ExecTest, InterpreterScriptArgvZeroRelative) { TEST(ExecTest, InterpreterScriptArgvZeroAdded) { // Symlink through /tmp to ensure the path is short enough. TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload))); + TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload))); TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path()), 0755)); @@ -274,7 +265,7 @@ TEST(ExecTest, InterpreterScriptArgvZeroAdded) { TEST(ExecTest, InterpreterScriptArgNUL) { // Symlink through /tmp to ensure the path is short enough. TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload))); + TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload))); TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( GetAbsoluteTestTmpdir(), @@ -289,7 +280,7 @@ TEST(ExecTest, InterpreterScriptArgNUL) { TEST(ExecTest, InterpreterScriptTrailingWhitespace) { // Symlink through /tmp to ensure the path is short enough. TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload))); + TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload))); TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path(), " "), 0755)); @@ -302,7 +293,7 @@ TEST(ExecTest, InterpreterScriptTrailingWhitespace) { TEST(ExecTest, InterpreterScriptArgWhitespace) { // Symlink through /tmp to ensure the path is short enough. TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload))); + TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload))); TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path(), " foo"), 0755)); @@ -325,7 +316,7 @@ TEST(ExecTest, InterpreterScriptNoPath) { TEST(ExecTest, ExecFn) { // Symlink through /tmp to ensure the path is short enough. TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kStateWorkload))); + TempPath::CreateSymlinkTo("/tmp", RunfilePath(kStateWorkload))); TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path(), " PrintExecFn"), @@ -342,7 +333,7 @@ TEST(ExecTest, ExecFn) { } TEST(ExecTest, ExecName) { - std::string path = WorkloadPath(kStateWorkload); + std::string path = RunfilePath(kStateWorkload); CheckExec(path, {path, "PrintExecName"}, {}, ArgEnvExitStatus(0, 0), absl::StrCat(Basename(path).substr(0, 15), "\n")); @@ -351,7 +342,7 @@ TEST(ExecTest, ExecName) { TEST(ExecTest, ExecNameScript) { // Symlink through /tmp to ensure the path is short enough. TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kStateWorkload))); + TempPath::CreateSymlinkTo("/tmp", RunfilePath(kStateWorkload))); TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( GetAbsoluteTestTmpdir(), @@ -405,13 +396,13 @@ TEST(ExecStateTest, HandlerReset) { ASSERT_THAT(sigaction(SIGUSR1, &sa, nullptr), SyscallSucceeds()); ExecveArray args = { - WorkloadPath(kStateWorkload), + RunfilePath(kStateWorkload), "CheckSigHandler", absl::StrCat(SIGUSR1), absl::StrCat(absl::Hex(reinterpret_cast<uintptr_t>(SIG_DFL))), }; - CheckExec(WorkloadPath(kStateWorkload), args, {}, W_EXITCODE(0, 0), ""); + CheckExec(RunfilePath(kStateWorkload), args, {}, W_EXITCODE(0, 0), ""); } // Ignored signal dispositions are not reset. @@ -421,13 +412,13 @@ TEST(ExecStateTest, IgnorePreserved) { ASSERT_THAT(sigaction(SIGUSR1, &sa, nullptr), SyscallSucceeds()); ExecveArray args = { - WorkloadPath(kStateWorkload), + RunfilePath(kStateWorkload), "CheckSigHandler", absl::StrCat(SIGUSR1), absl::StrCat(absl::Hex(reinterpret_cast<uintptr_t>(SIG_IGN))), }; - CheckExec(WorkloadPath(kStateWorkload), args, {}, W_EXITCODE(0, 0), ""); + CheckExec(RunfilePath(kStateWorkload), args, {}, W_EXITCODE(0, 0), ""); } // Signal masks are not reset on exec @@ -438,12 +429,12 @@ TEST(ExecStateTest, SignalMask) { ASSERT_THAT(sigprocmask(SIG_BLOCK, &s, nullptr), SyscallSucceeds()); ExecveArray args = { - WorkloadPath(kStateWorkload), + RunfilePath(kStateWorkload), "CheckSigBlocked", absl::StrCat(SIGUSR1), }; - CheckExec(WorkloadPath(kStateWorkload), args, {}, W_EXITCODE(0, 0), ""); + CheckExec(RunfilePath(kStateWorkload), args, {}, W_EXITCODE(0, 0), ""); } // itimers persist across execve. @@ -471,7 +462,7 @@ TEST(ExecStateTest, ItimerPreserved) { } }; - std::string filename = WorkloadPath(kStateWorkload); + std::string filename = RunfilePath(kStateWorkload); ExecveArray argv = { filename, "CheckItimerEnabled", @@ -495,8 +486,8 @@ TEST(ExecStateTest, ItimerPreserved) { TEST(ProcSelfExe, ChangesAcrossExecve) { // See exec_proc_exe_workload for more details. We simply // assert that the /proc/self/exe link changes across execve. - CheckExec(WorkloadPath(kProcExeWorkload), - {WorkloadPath(kProcExeWorkload), + CheckExec(RunfilePath(kProcExeWorkload), + {RunfilePath(kProcExeWorkload), ASSERT_NO_ERRNO_AND_VALUE(ProcessExePath(getpid()))}, {}, W_EXITCODE(0, 0), ""); } @@ -507,8 +498,8 @@ TEST(ExecTest, CloexecNormalFile) { const FileDescriptor fd_closed_on_exec = ASSERT_NO_ERRNO_AND_VALUE(Open(tempFile.path(), O_RDONLY | O_CLOEXEC)); - CheckExec(WorkloadPath(kAssertClosedWorkload), - {WorkloadPath(kAssertClosedWorkload), + CheckExec(RunfilePath(kAssertClosedWorkload), + {RunfilePath(kAssertClosedWorkload), absl::StrCat(fd_closed_on_exec.get())}, {}, W_EXITCODE(0, 0), ""); @@ -517,10 +508,10 @@ TEST(ExecTest, CloexecNormalFile) { const FileDescriptor fd_open_on_exec = ASSERT_NO_ERRNO_AND_VALUE(Open(tempFile.path(), O_RDONLY)); - CheckExec(WorkloadPath(kAssertClosedWorkload), - {WorkloadPath(kAssertClosedWorkload), - absl::StrCat(fd_open_on_exec.get())}, - {}, W_EXITCODE(2, 0), ""); + CheckExec( + RunfilePath(kAssertClosedWorkload), + {RunfilePath(kAssertClosedWorkload), absl::StrCat(fd_open_on_exec.get())}, + {}, W_EXITCODE(2, 0), ""); } TEST(ExecTest, CloexecEventfd) { @@ -528,19 +519,65 @@ TEST(ExecTest, CloexecEventfd) { ASSERT_THAT(efd = eventfd(0, EFD_CLOEXEC), SyscallSucceeds()); FileDescriptor fd(efd); - CheckExec(WorkloadPath(kAssertClosedWorkload), - {WorkloadPath(kAssertClosedWorkload), absl::StrCat(fd.get())}, {}, + CheckExec(RunfilePath(kAssertClosedWorkload), + {RunfilePath(kAssertClosedWorkload), absl::StrCat(fd.get())}, {}, W_EXITCODE(0, 0), ""); } +constexpr int kLinuxMaxSymlinks = 40; + +TEST(ExecTest, SymlinkLimitExceeded) { + std::string path = RunfilePath(kBasicWorkload); + + // Hold onto TempPath objects so they are not destructed prematurely. + std::vector<TempPath> symlinks; + for (int i = 0; i < kLinuxMaxSymlinks + 1; i++) { + symlinks.push_back( + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateSymlinkTo("/tmp", path))); + path = symlinks[i].path(); + } + + int execve_errno; + ASSERT_NO_ERRNO_AND_VALUE( + ForkAndExec(path, {path}, {}, /*child=*/nullptr, &execve_errno)); + EXPECT_EQ(execve_errno, ELOOP); +} + +TEST(ExecTest, SymlinkLimitRefreshedForInterpreter) { + std::string tmp_dir = "/tmp"; + std::string interpreter_path = "/bin/echo"; + TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + tmp_dir, absl::StrCat("#!", interpreter_path), 0755)); + std::string script_path = script.path(); + + // Hold onto TempPath objects so they are not destructed prematurely. + std::vector<TempPath> interpreter_symlinks; + std::vector<TempPath> script_symlinks; + // Replace both the interpreter and script paths with symlink chains of just + // over half the symlink limit each; this is the minimum required to test that + // the symlink limit applies separately to each traversal, while tolerating + // some symlinks in the resolution of (the original) interpreter_path and + // script_path. + for (int i = 0; i < (kLinuxMaxSymlinks / 2) + 1; i++) { + interpreter_symlinks.push_back(ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateSymlinkTo(tmp_dir, interpreter_path))); + interpreter_path = interpreter_symlinks[i].path(); + script_symlinks.push_back(ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateSymlinkTo(tmp_dir, script_path))); + script_path = script_symlinks[i].path(); + } + + CheckExec(script_path, {script_path}, {}, ArgEnvExitStatus(0, 0), ""); +} + TEST(ExecveatTest, BasicWithFDCWD) { - std::string path = WorkloadPath(kBasicWorkload); + std::string path = RunfilePath(kBasicWorkload); CheckExecveat(AT_FDCWD, path, {path}, {}, /*flags=*/0, ArgEnvExitStatus(0, 0), absl::StrCat(path, "\n")); } TEST(ExecveatTest, Basic) { - std::string absolute_path = WorkloadPath(kBasicWorkload); + std::string absolute_path = RunfilePath(kBasicWorkload); std::string parent_dir = std::string(Dirname(absolute_path)); std::string base = std::string(Basename(absolute_path)); const FileDescriptor dirfd = @@ -551,7 +588,7 @@ TEST(ExecveatTest, Basic) { } TEST(ExecveatTest, FDNotADirectory) { - std::string absolute_path = WorkloadPath(kBasicWorkload); + std::string absolute_path = RunfilePath(kBasicWorkload); std::string base = std::string(Basename(absolute_path)); const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(absolute_path, 0)); @@ -563,13 +600,13 @@ TEST(ExecveatTest, FDNotADirectory) { } TEST(ExecveatTest, AbsolutePathWithFDCWD) { - std::string path = WorkloadPath(kBasicWorkload); + std::string path = RunfilePath(kBasicWorkload); CheckExecveat(AT_FDCWD, path, {path}, {}, ArgEnvExitStatus(0, 0), 0, absl::StrCat(path, "\n")); } TEST(ExecveatTest, AbsolutePath) { - std::string path = WorkloadPath(kBasicWorkload); + std::string path = RunfilePath(kBasicWorkload); // File descriptor should be ignored when an absolute path is given. const int32_t badFD = -1; CheckExecveat(badFD, path, {path}, {}, ArgEnvExitStatus(0, 0), 0, @@ -577,7 +614,7 @@ TEST(ExecveatTest, AbsolutePath) { } TEST(ExecveatTest, EmptyPathBasic) { - std::string path = WorkloadPath(kBasicWorkload); + std::string path = RunfilePath(kBasicWorkload); const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_PATH)); CheckExecveat(fd.get(), "", {path}, {}, AT_EMPTY_PATH, ArgEnvExitStatus(0, 0), @@ -585,7 +622,7 @@ TEST(ExecveatTest, EmptyPathBasic) { } TEST(ExecveatTest, EmptyPathWithDirFD) { - std::string path = WorkloadPath(kBasicWorkload); + std::string path = RunfilePath(kBasicWorkload); std::string parent_dir = std::string(Dirname(path)); const FileDescriptor dirfd = ASSERT_NO_ERRNO_AND_VALUE(Open(parent_dir, O_DIRECTORY)); @@ -598,7 +635,7 @@ TEST(ExecveatTest, EmptyPathWithDirFD) { } TEST(ExecveatTest, EmptyPathWithoutEmptyPathFlag) { - std::string path = WorkloadPath(kBasicWorkload); + std::string path = RunfilePath(kBasicWorkload); const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_PATH)); int execve_errno; @@ -608,7 +645,7 @@ TEST(ExecveatTest, EmptyPathWithoutEmptyPathFlag) { } TEST(ExecveatTest, AbsolutePathWithEmptyPathFlag) { - std::string path = WorkloadPath(kBasicWorkload); + std::string path = RunfilePath(kBasicWorkload); const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_PATH)); CheckExecveat(fd.get(), path, {path}, {}, AT_EMPTY_PATH, @@ -616,7 +653,7 @@ TEST(ExecveatTest, AbsolutePathWithEmptyPathFlag) { } TEST(ExecveatTest, RelativePathWithEmptyPathFlag) { - std::string absolute_path = WorkloadPath(kBasicWorkload); + std::string absolute_path = RunfilePath(kBasicWorkload); std::string parent_dir = std::string(Dirname(absolute_path)); std::string base = std::string(Basename(absolute_path)); const FileDescriptor dirfd = @@ -629,7 +666,7 @@ TEST(ExecveatTest, RelativePathWithEmptyPathFlag) { TEST(ExecveatTest, SymlinkNoFollowWithRelativePath) { std::string parent_dir = "/tmp"; TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo(parent_dir, WorkloadPath(kBasicWorkload))); + TempPath::CreateSymlinkTo(parent_dir, RunfilePath(kBasicWorkload))); const FileDescriptor dirfd = ASSERT_NO_ERRNO_AND_VALUE(Open(parent_dir, O_DIRECTORY)); std::string base = std::string(Basename(link.path())); @@ -641,10 +678,35 @@ TEST(ExecveatTest, SymlinkNoFollowWithRelativePath) { EXPECT_EQ(execve_errno, ELOOP); } +TEST(ExecveatTest, UnshareFiles) { + TempPath tempFile = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateFileWith(GetAbsoluteTestTmpdir(), "bar", 0755)); + const FileDescriptor fd_closed_on_exec = + ASSERT_NO_ERRNO_AND_VALUE(Open(tempFile.path(), O_RDONLY | O_CLOEXEC)); + + ExecveArray argv = {"test"}; + ExecveArray envp; + std::string child_path = RunfilePath(kBasicWorkload); + pid_t child = + syscall(__NR_clone, SIGCHLD | CLONE_VFORK | CLONE_FILES, 0, 0, 0, 0); + if (child == 0) { + execve(child_path.c_str(), argv.get(), envp.get()); + _exit(1); + } + ASSERT_THAT(child, SyscallSucceeds()); + + int status; + ASSERT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds()); + EXPECT_EQ(status, 0); + + struct stat st; + EXPECT_THAT(fstat(fd_closed_on_exec.get(), &st), SyscallSucceeds()); +} + TEST(ExecveatTest, SymlinkNoFollowWithAbsolutePath) { std::string parent_dir = "/tmp"; TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo(parent_dir, WorkloadPath(kBasicWorkload))); + TempPath::CreateSymlinkTo(parent_dir, RunfilePath(kBasicWorkload))); std::string path = link.path(); int execve_errno; @@ -656,7 +718,7 @@ TEST(ExecveatTest, SymlinkNoFollowWithAbsolutePath) { TEST(ExecveatTest, SymlinkNoFollowAndEmptyPath) { TempPath link = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload))); + TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload))); std::string path = link.path(); const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, 0)); @@ -681,6 +743,39 @@ TEST(ExecveatTest, SymlinkNoFollowWithNormalFile) { ArgEnvExitStatus(0, 0), ""); } +TEST(ExecveatTest, BasicWithCloexecFD) { + std::string path = RunfilePath(kBasicWorkload); + const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_CLOEXEC)); + + CheckExecveat(fd.get(), "", {path}, {}, AT_SYMLINK_NOFOLLOW | AT_EMPTY_PATH, + ArgEnvExitStatus(0, 0), absl::StrCat(path, "\n")); +} + +TEST(ExecveatTest, InterpreterScriptWithCloexecFD) { + std::string path = RunfilePath(kExitScript); + const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_CLOEXEC)); + + int execve_errno; + ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(fd.get(), "", {path}, {}, + AT_EMPTY_PATH, /*child=*/nullptr, + &execve_errno)); + EXPECT_EQ(execve_errno, ENOENT); +} + +TEST(ExecveatTest, InterpreterScriptWithCloexecDirFD) { + std::string absolute_path = RunfilePath(kExitScript); + std::string parent_dir = std::string(Dirname(absolute_path)); + std::string base = std::string(Basename(absolute_path)); + const FileDescriptor dirfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(parent_dir, O_CLOEXEC | O_DIRECTORY)); + + int execve_errno; + ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(dirfd.get(), base, {base}, {}, + /*flags=*/0, /*child=*/nullptr, + &execve_errno)); + EXPECT_EQ(execve_errno, ENOENT); +} + TEST(ExecveatTest, InvalidFlags) { int execve_errno; ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat( @@ -701,7 +796,7 @@ TEST(GetpriorityTest, ExecveMaintainsPriority) { // Program run (priority_execve) will exit(X) where // X=getpriority(PRIO_PROCESS,0). Check that this exit value is prio. - CheckExec(WorkloadPath(kPriorityWorkload), {WorkloadPath(kPriorityWorkload)}, + CheckExec(RunfilePath(kPriorityWorkload), {RunfilePath(kPriorityWorkload)}, {}, W_EXITCODE(expected_exit_code, 0), ""); } @@ -747,26 +842,28 @@ void ExecFromThread() { bool ValidateProcCmdlineVsArgv(const int argc, const char* const* argv) { auto contents_or = GetContents("/proc/self/cmdline"); if (!contents_or.ok()) { - std::cerr << "Unable to get /proc/self/cmdline: " << contents_or.error(); + std::cerr << "Unable to get /proc/self/cmdline: " << contents_or.error() + << std::endl; return false; } auto contents = contents_or.ValueOrDie(); if (contents.back() != '\0') { - std::cerr << "Non-null terminated /proc/self/cmdline!"; + std::cerr << "Non-null terminated /proc/self/cmdline!" << std::endl; return false; } contents.pop_back(); std::vector<std::string> procfs_cmdline = absl::StrSplit(contents, '\0'); if (static_cast<int>(procfs_cmdline.size()) != argc) { - std::cerr << "argc = " << argc << " != " << procfs_cmdline.size(); + std::cerr << "argc = " << argc << " != " << procfs_cmdline.size() + << std::endl; return false; } for (int i = 0; i < argc; ++i) { if (procfs_cmdline[i] != argv[i]) { std::cerr << "Procfs command line argument " << i << " mismatch " - << procfs_cmdline[i] << " != " << argv[i]; + << procfs_cmdline[i] << " != " << argv[i] << std::endl; return false; } } @@ -803,6 +900,5 @@ int main(int argc, char** argv) { } gvisor::testing::TestInit(&argc, &argv); - - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/exec_binary.cc b/test/syscalls/linux/exec_binary.cc index 0a3931e5a..18d2f22c1 100644 --- a/test/syscalls/linux/exec_binary.cc +++ b/test/syscalls/linux/exec_binary.cc @@ -20,6 +20,7 @@ #include <sys/types.h> #include <sys/user.h> #include <unistd.h> + #include <algorithm> #include <functional> #include <iterator> @@ -47,10 +48,17 @@ namespace { using ::testing::AnyOf; using ::testing::Eq; -#ifndef __x86_64__ +#if !defined(__x86_64__) && !defined(__aarch64__) // The assembly stub and ELF internal details must be ported to other arches. -#error "Test only supported on x86-64" -#endif // __x86_64__ +#error "Test only supported on x86-64/arm64" +#endif // __x86_64__ || __aarch64__ + +#if defined(__x86_64__) +#define EM_TYPE EM_X86_64 +#define IP_REG(p) ((p).rip) +#define RAX_REG(p) ((p).rax) +#define RDI_REG(p) ((p).rdi) +#define RETURN_REG(p) ((p).rax) // amd64 stub that calls PTRACE_TRACEME and sends itself SIGSTOP. const char kPtraceCode[] = { @@ -138,6 +146,76 @@ const char kPtraceCode[] = { // Size of a syscall instruction. constexpr int kSyscallSize = 2; +#elif defined(__aarch64__) +#define EM_TYPE EM_AARCH64 +#define IP_REG(p) ((p).pc) +#define RAX_REG(p) ((p).regs[8]) +#define RDI_REG(p) ((p).regs[0]) +#define RETURN_REG(p) ((p).regs[0]) + +const char kPtraceCode[] = { + // MOVD $117, R8 /* ptrace */ + '\xa8', + '\x0e', + '\x80', + '\xd2', + // MOVD $0, R0 /* PTRACE_TRACEME */ + '\x00', + '\x00', + '\x80', + '\xd2', + // MOVD $0, R1 /* pid */ + '\x01', + '\x00', + '\x80', + '\xd2', + // MOVD $0, R2 /* addr */ + '\x02', + '\x00', + '\x80', + '\xd2', + // MOVD $0, R3 /* data */ + '\x03', + '\x00', + '\x80', + '\xd2', + // SVC + '\x01', + '\x00', + '\x00', + '\xd4', + // MOVD $172, R8 /* getpid */ + '\x88', + '\x15', + '\x80', + '\xd2', + // SVC + '\x01', + '\x00', + '\x00', + '\xd4', + // MOVD $129, R8 /* kill, R0=pid */ + '\x28', + '\x10', + '\x80', + '\xd2', + // MOVD $19, R1 /* SIGSTOP */ + '\x61', + '\x02', + '\x80', + '\xd2', + // SVC + '\x01', + '\x00', + '\x00', + '\xd4', +}; +// Size of a syscall instruction. +constexpr int kSyscallSize = 4; +#else +#error "Unknown architecture" +#endif + // This test suite tests executable loading in the kernel (ELF and interpreter // scripts). @@ -280,7 +358,7 @@ ElfBinary<64> StandardElf() { elf.header.e_ident[EI_DATA] = ELFDATA2LSB; elf.header.e_ident[EI_VERSION] = EV_CURRENT; elf.header.e_type = ET_EXEC; - elf.header.e_machine = EM_X86_64; + elf.header.e_machine = EM_TYPE; elf.header.e_version = EV_CURRENT; elf.header.e_phoff = sizeof(elf.header); elf.header.e_phentsize = sizeof(decltype(elf)::ElfPhdr); @@ -326,9 +404,15 @@ TEST(ElfTest, Execute) { ASSERT_NO_ERRNO(WaitStopped(child)); struct user_regs_struct regs; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, ®s), SyscallSucceeds()); - // RIP is just beyond the final syscall instruction. - EXPECT_EQ(regs.rip, elf.header.e_entry + sizeof(kPtraceCode)); + struct iovec iov; + iov.iov_base = ®s; + iov.iov_len = sizeof(regs); + EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov), + SyscallSucceeds()); + // Read exactly the full register set. + EXPECT_EQ(iov.iov_len, sizeof(regs)); + // RIP/PC is just beyond the final syscall instruction. + EXPECT_EQ(IP_REG(regs), elf.header.e_entry + sizeof(kPtraceCode)); EXPECT_THAT(child, ContainsMappings(std::vector<ProcMapsEntry>({ {0x40000, 0x41000, true, false, true, true, 0, 0, 0, 0, @@ -354,7 +438,12 @@ TEST(ElfTest, MissingText) { ASSERT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceedsWithValue(child)); // It runs off the end of the zeroes filling the end of the page. +#if defined(__x86_64__) EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGSEGV) << status; +#elif defined(__aarch64__) + // 0 is an invalid instruction opcode on arm64. + EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGILL) << status; +#endif } // Typical ELF with a data + bss segment @@ -717,9 +806,16 @@ TEST(ElfTest, PIE) { // RIP tells us which page the first segment was loaded into. struct user_regs_struct regs; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, ®s), SyscallSucceeds()); + struct iovec iov; + iov.iov_base = ®s; + iov.iov_len = sizeof(regs); + + EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov), + SyscallSucceeds()); + // Read exactly the full register set. + EXPECT_EQ(iov.iov_len, sizeof(regs)); - const uint64_t load_addr = regs.rip & ~(kPageSize - 1); + const uint64_t load_addr = IP_REG(regs) & ~(kPageSize - 1); EXPECT_THAT(child, ContainsMappings(std::vector<ProcMapsEntry>({ // text page. @@ -786,9 +882,15 @@ TEST(ElfTest, PIENonZeroStart) { // RIP tells us which page the first segment was loaded into. struct user_regs_struct regs; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, ®s), SyscallSucceeds()); + struct iovec iov; + iov.iov_base = ®s; + iov.iov_len = sizeof(regs); + EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov), + SyscallSucceeds()); + // Read exactly the full register set. + EXPECT_EQ(iov.iov_len, sizeof(regs)); - const uint64_t load_addr = regs.rip & ~(kPageSize - 1); + const uint64_t load_addr = IP_REG(regs) & ~(kPageSize - 1); // The ELF is loaded at an arbitrary address, not the first PT_LOAD vaddr. // @@ -909,9 +1011,15 @@ TEST(ElfTest, ELFInterpreter) { // RIP tells us which page the first segment of the interpreter was loaded // into. struct user_regs_struct regs; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, ®s), SyscallSucceeds()); + struct iovec iov; + iov.iov_base = ®s; + iov.iov_len = sizeof(regs); + EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov), + SyscallSucceeds()); + // Read exactly the full register set. + EXPECT_EQ(iov.iov_len, sizeof(regs)); - const uint64_t interp_load_addr = regs.rip & ~(kPageSize - 1); + const uint64_t interp_load_addr = IP_REG(regs) & ~(kPageSize - 1); EXPECT_THAT( child, ContainsMappings(std::vector<ProcMapsEntry>({ @@ -1083,9 +1191,15 @@ TEST(ElfTest, ELFInterpreterRelative) { // RIP tells us which page the first segment of the interpreter was loaded // into. struct user_regs_struct regs; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, ®s), SyscallSucceeds()); + struct iovec iov; + iov.iov_base = ®s; + iov.iov_len = sizeof(regs); + EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov), + SyscallSucceeds()); + // Read exactly the full register set. + EXPECT_EQ(iov.iov_len, sizeof(regs)); - const uint64_t interp_load_addr = regs.rip & ~(kPageSize - 1); + const uint64_t interp_load_addr = IP_REG(regs) & ~(kPageSize - 1); EXPECT_THAT( child, ContainsMappings(std::vector<ProcMapsEntry>({ @@ -1479,14 +1593,21 @@ TEST(ExecveTest, BrkAfterBinary) { ASSERT_NO_ERRNO(WaitStopped(child)); struct user_regs_struct regs; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, ®s), SyscallSucceeds()); + struct iovec iov; + iov.iov_base = ®s; + iov.iov_len = sizeof(regs); + EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov), + SyscallSucceeds()); + // Read exactly the full register set. + EXPECT_EQ(iov.iov_len, sizeof(regs)); // RIP is just beyond the final syscall instruction. Rewind to execute a brk // syscall. - regs.rip -= kSyscallSize; - regs.rax = __NR_brk; - regs.rdi = 0; - ASSERT_THAT(ptrace(PTRACE_SETREGS, child, 0, ®s), SyscallSucceeds()); + IP_REG(regs) -= kSyscallSize; + RAX_REG(regs) = __NR_brk; + RDI_REG(regs) = 0; + ASSERT_THAT(ptrace(PTRACE_SETREGSET, child, NT_PRSTATUS, &iov), + SyscallSucceeds()); // Resume the child, waiting for syscall entry. ASSERT_THAT(ptrace(PTRACE_SYSCALL, child, 0, 0), SyscallSucceeds()); @@ -1503,7 +1624,12 @@ TEST(ExecveTest, BrkAfterBinary) { ASSERT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGTRAP) << "status = " << status; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, ®s), SyscallSucceeds()); + iov.iov_base = ®s; + iov.iov_len = sizeof(regs); + EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov), + SyscallSucceeds()); + // Read exactly the full register set. + EXPECT_EQ(iov.iov_len, sizeof(regs)); // brk is after the text page. // @@ -1511,7 +1637,7 @@ TEST(ExecveTest, BrkAfterBinary) { // address will be, but it is always beyond the final page in the binary. // i.e., it does not start immediately after memsz in the middle of a page. // Userspace may expect to use that space. - EXPECT_GE(regs.rax, 0x41000); + EXPECT_GE(RETURN_REG(regs), 0x41000); } } // namespace diff --git a/test/syscalls/linux/exec_proc_exe_workload.cc b/test/syscalls/linux/exec_proc_exe_workload.cc index b790fe5be..2989379b7 100644 --- a/test/syscalls/linux/exec_proc_exe_workload.cc +++ b/test/syscalls/linux/exec_proc_exe_workload.cc @@ -21,6 +21,12 @@ #include "test/util/posix_error.h" int main(int argc, char** argv, char** envp) { + // This is annoying. Because remote build systems may put these binaries + // in a content-addressable-store, you may wind up with /proc/self/exe + // pointing to some random path (but with a sensible argv[0]). + // + // Therefore, this test simply checks that the /proc/self/exe + // is absolute and *doesn't* match argv[1]. std::string exe = gvisor::testing::ProcessExePath(getpid()).ValueOrDie(); if (exe[0] != '/') { diff --git a/test/syscalls/linux/fallocate.cc b/test/syscalls/linux/fallocate.cc index 1c3d00287..cabc2b751 100644 --- a/test/syscalls/linux/fallocate.cc +++ b/test/syscalls/linux/fallocate.cc @@ -15,16 +15,27 @@ #include <errno.h> #include <fcntl.h> #include <signal.h> +#include <sys/eventfd.h> #include <sys/resource.h> +#include <sys/signalfd.h> +#include <sys/socket.h> #include <sys/stat.h> +#include <sys/timerfd.h> #include <syscall.h> #include <time.h> #include <unistd.h> +#include <ctime> + #include "gtest/gtest.h" +#include "absl/strings/str_cat.h" +#include "absl/time/time.h" #include "test/syscalls/linux/file_base.h" +#include "test/syscalls/linux/socket_test_util.h" #include "test/util/cleanup.h" +#include "test/util/eventfd_util.h" #include "test/util/file_descriptor.h" +#include "test/util/posix_error.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" @@ -33,7 +44,7 @@ namespace testing { namespace { int fallocate(int fd, int mode, off_t offset, off_t len) { - return syscall(__NR_fallocate, fd, mode, offset, len); + return RetryEINTR(syscall)(__NR_fallocate, fd, mode, offset, len); } class AllocateTest : public FileTest { @@ -47,27 +58,33 @@ TEST_F(AllocateTest, Fallocate) { EXPECT_EQ(buf.st_size, 0); // Grow to ten bytes. - EXPECT_THAT(fallocate(test_file_fd_.get(), 0, 0, 10), SyscallSucceeds()); + ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 0, 10), SyscallSucceeds()); ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds()); EXPECT_EQ(buf.st_size, 10); // Allocate to a smaller size should be noop. - EXPECT_THAT(fallocate(test_file_fd_.get(), 0, 0, 5), SyscallSucceeds()); + ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 0, 5), SyscallSucceeds()); ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds()); EXPECT_EQ(buf.st_size, 10); // Grow again. - EXPECT_THAT(fallocate(test_file_fd_.get(), 0, 0, 20), SyscallSucceeds()); + ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 0, 20), SyscallSucceeds()); ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds()); EXPECT_EQ(buf.st_size, 20); // Grow with offset. - EXPECT_THAT(fallocate(test_file_fd_.get(), 0, 10, 20), SyscallSucceeds()); + ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 10, 20), SyscallSucceeds()); ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds()); EXPECT_EQ(buf.st_size, 30); // Grow with offset beyond EOF. - EXPECT_THAT(fallocate(test_file_fd_.get(), 0, 39, 1), SyscallSucceeds()); + ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 39, 1), SyscallSucceeds()); + ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds()); + EXPECT_EQ(buf.st_size, 40); + + // Given length 0 should fail with EINVAL. + ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 50, 0), + SyscallFailsWithErrno(EINVAL)); ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds()); EXPECT_EQ(buf.st_size, 40); } @@ -136,6 +153,34 @@ TEST_F(AllocateTest, FallocateRlimit) { ASSERT_THAT(sigprocmask(SIG_UNBLOCK, &new_mask, nullptr), SyscallSucceeds()); } +TEST_F(AllocateTest, FallocateOtherFDs) { + int fd; + ASSERT_THAT(fd = timerfd_create(CLOCK_MONOTONIC, 0), SyscallSucceeds()); + auto timer_fd = FileDescriptor(fd); + EXPECT_THAT(fallocate(timer_fd.get(), 0, 0, 10), + SyscallFailsWithErrno(ENODEV)); + + sigset_t mask; + sigemptyset(&mask); + ASSERT_THAT(fd = signalfd(-1, &mask, 0), SyscallSucceeds()); + auto sfd = FileDescriptor(fd); + EXPECT_THAT(fallocate(sfd.get(), 0, 0, 10), SyscallFailsWithErrno(ENODEV)); + + auto efd = + ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, EFD_NONBLOCK | EFD_SEMAPHORE)); + EXPECT_THAT(fallocate(efd.get(), 0, 0, 10), SyscallFailsWithErrno(ENODEV)); + + auto sockfd = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + EXPECT_THAT(fallocate(sockfd.get(), 0, 0, 10), SyscallFailsWithErrno(ENODEV)); + + int socks[2]; + ASSERT_THAT(socketpair(AF_UNIX, SOCK_STREAM, PF_UNIX, socks), + SyscallSucceeds()); + auto sock0 = FileDescriptor(socks[0]); + auto sock1 = FileDescriptor(socks[1]); + EXPECT_THAT(fallocate(sock0.get(), 0, 0, 10), SyscallFailsWithErrno(ENODEV)); +} + } // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/fault.cc b/test/syscalls/linux/fault.cc index f6e19026f..a85750382 100644 --- a/test/syscalls/linux/fault.cc +++ b/test/syscalls/linux/fault.cc @@ -37,6 +37,9 @@ int GetPcFromUcontext(ucontext_t* uc, uintptr_t* pc) { #elif defined(__i386__) *pc = uc->uc_mcontext.gregs[REG_EIP]; return 1; +#elif defined(__aarch64__) + *pc = uc->uc_mcontext.pc; + return 1; #else return 0; #endif diff --git a/test/syscalls/linux/fcntl.cc b/test/syscalls/linux/fcntl.cc index 8a45be12a..34016d4bd 100644 --- a/test/syscalls/linux/fcntl.cc +++ b/test/syscalls/linux/fcntl.cc @@ -14,10 +14,14 @@ #include <fcntl.h> #include <signal.h> +#include <sys/types.h> #include <syscall.h> #include <unistd.h> +#include <iostream> +#include <list> #include <string> +#include <vector> #include "gtest/gtest.h" #include "absl/base/macros.h" @@ -30,10 +34,13 @@ #include "test/syscalls/linux/socket_test_util.h" #include "test/util/cleanup.h" #include "test/util/eventfd_util.h" +#include "test/util/fs_util.h" #include "test/util/multiprocess_util.h" #include "test/util/posix_error.h" +#include "test/util/save_util.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" +#include "test/util/thread_util.h" #include "test/util/timer_util.h" ABSL_FLAG(std::string, child_setlock_on, "", @@ -53,10 +60,6 @@ ABSL_FLAG(int32_t, socket_fd, -1, namespace gvisor { namespace testing { -// O_LARGEFILE as defined by Linux. glibc tries to be clever by setting it to 0 -// because "it isn't needed", even though Linux can return it via F_GETFL. -constexpr int kOLargeFile = 00100000; - class FcntlLockTest : public ::testing::Test { public: void SetUp() override { @@ -116,6 +119,15 @@ PosixErrorOr<Cleanup> SubprocessLock(std::string const& path, bool for_write, return std::move(cleanup); } +TEST(FcntlTest, SetCloExecBadFD) { + // Open an eventfd file descriptor with FD_CLOEXEC descriptor flag not set. + FileDescriptor f = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, 0)); + auto fd = f.get(); + f.reset(); + ASSERT_THAT(fcntl(fd, F_GETFD), SyscallFailsWithErrno(EBADF)); + ASSERT_THAT(fcntl(fd, F_SETFD, FD_CLOEXEC), SyscallFailsWithErrno(EBADF)); +} + TEST(FcntlTest, SetCloExec) { // Open an eventfd file descriptor with FD_CLOEXEC descriptor flag not set. FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, 0)); @@ -183,45 +195,85 @@ TEST(FcntlTest, SetFlags) { EXPECT_EQ(rflags, expected); } -TEST_F(FcntlLockTest, SetLockBadFd) { +void TestLock(int fd, short lock_type = F_RDLCK) { // NOLINT, type in flock struct flock fl; - fl.l_type = F_WRLCK; + fl.l_type = lock_type; fl.l_whence = SEEK_SET; fl.l_start = 0; - // len 0 has a special meaning: lock all bytes despite how - // large the file grows. + // len 0 locks all bytes despite how large the file grows. fl.l_len = 0; - EXPECT_THAT(fcntl(-1, F_SETLK, &fl), SyscallFailsWithErrno(EBADF)); + EXPECT_THAT(fcntl(fd, F_SETLK, &fl), SyscallSucceeds()); } -TEST_F(FcntlLockTest, SetLockPipe) { - int fds[2]; - ASSERT_THAT(pipe(fds), SyscallSucceeds()); - +void TestLockBadFD(int fd, + short lock_type = F_RDLCK) { // NOLINT, type in flock struct flock fl; - fl.l_type = F_WRLCK; + fl.l_type = lock_type; fl.l_whence = SEEK_SET; fl.l_start = 0; - // Same as SetLockBadFd, but doesn't matter, we expect this to fail. + // len 0 locks all bytes despite how large the file grows. fl.l_len = 0; - EXPECT_THAT(fcntl(fds[0], F_SETLK, &fl), SyscallFailsWithErrno(EBADF)); - EXPECT_THAT(close(fds[0]), SyscallSucceeds()); - EXPECT_THAT(close(fds[1]), SyscallSucceeds()); + EXPECT_THAT(fcntl(fd, F_SETLK, &fl), SyscallFailsWithErrno(EBADF)); } +TEST_F(FcntlLockTest, SetLockBadFd) { TestLockBadFD(-1); } + TEST_F(FcntlLockTest, SetLockDir) { auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY, 0666)); + auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY, 0000)); + TestLock(fd.get()); +} - struct flock fl; - fl.l_type = F_RDLCK; - fl.l_whence = SEEK_SET; - fl.l_start = 0; - // Same as SetLockBadFd. - fl.l_len = 0; +TEST_F(FcntlLockTest, SetLockSymlink) { + // TODO(gvisor.dev/issue/2782): Replace with IsRunningWithVFS1() when O_PATH + // is supported. + SKIP_IF(IsRunningOnGvisor()); - EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds()); + auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + auto symlink = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateSymlinkTo(GetAbsoluteTestTmpdir(), file.path())); + + auto fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(symlink.path(), O_RDONLY | O_PATH, 0000)); + TestLockBadFD(fd.get()); +} + +TEST_F(FcntlLockTest, SetLockProc) { + auto fd = + ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/status", O_RDONLY, 0000)); + TestLock(fd.get()); +} + +TEST_F(FcntlLockTest, SetLockPipe) { + SKIP_IF(IsRunningWithVFS1()); + + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + + TestLock(fds[0]); + TestLockBadFD(fds[0], F_WRLCK); + + TestLock(fds[1], F_WRLCK); + TestLockBadFD(fds[1]); + + EXPECT_THAT(close(fds[0]), SyscallSucceeds()); + EXPECT_THAT(close(fds[1]), SyscallSucceeds()); +} + +TEST_F(FcntlLockTest, SetLockSocket) { + SKIP_IF(IsRunningWithVFS1()); + + int sock = socket(AF_UNIX, SOCK_STREAM, 0); + ASSERT_THAT(sock, SyscallSucceeds()); + + struct sockaddr_un addr = + ASSERT_NO_ERRNO_AND_VALUE(UniqueUnixAddr(true /* abstract */, AF_UNIX)); + ASSERT_THAT( + bind(sock, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)), + SyscallSucceeds()); + + TestLock(sock); + EXPECT_THAT(close(sock), SyscallSucceeds()); } TEST_F(FcntlLockTest, SetLockBadOpenFlagsWrite) { @@ -233,8 +285,7 @@ TEST_F(FcntlLockTest, SetLockBadOpenFlagsWrite) { fl0.l_type = F_WRLCK; fl0.l_whence = SEEK_SET; fl0.l_start = 0; - // Same as SetLockBadFd. - fl0.l_len = 0; + fl0.l_len = 0; // Lock all file // Expect that setting a write lock using a read only file descriptor // won't work. @@ -696,7 +747,7 @@ TEST_F(FcntlLockTest, SetWriteLockThenBlockingWriteLock) { << "Exited with code: " << status; } -// This test will veirfy that blocking works as expected when another process +// This test will verify that blocking works as expected when another process // holds a read lock when obtaining a write lock. This test will hold the lock // for some amount of time and then wait for the second process to send over the // socket_fd the amount of time it was blocked for before the lock succeeded. @@ -906,14 +957,346 @@ TEST(FcntlTest, DupAfterO_ASYNC) { EXPECT_EQ(after & O_ASYNC, O_ASYNC); } -TEST(FcntlTest, GetOwn) { +TEST(FcntlTest, GetOwnNone) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + // Use the raw syscall because the glibc wrapper may convert F_{GET,SET}OWN + // into F_{GET,SET}OWN_EX. + EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN), + SyscallSucceedsWithValue(0)); + MaybeSave(); +} + +TEST(FcntlTest, GetOwnExNone) { FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); - ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN), + f_owner_ex owner = {}; + EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN_EX, &owner), SyscallSucceedsWithValue(0)); } +TEST(FcntlTest, SetOwnInvalidPid) { + SKIP_IF(IsRunningWithVFS1()); + + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, 12345678), + SyscallFailsWithErrno(ESRCH)); +} + +TEST(FcntlTest, SetOwnInvalidPgrp) { + SKIP_IF(IsRunningWithVFS1()); + + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, -12345678), + SyscallFailsWithErrno(ESRCH)); +} + +TEST(FcntlTest, SetOwnPid) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + pid_t pid; + EXPECT_THAT(pid = getpid(), SyscallSucceeds()); + + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, pid), + SyscallSucceedsWithValue(0)); + + EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN), + SyscallSucceedsWithValue(pid)); + MaybeSave(); +} + +TEST(FcntlTest, SetOwnPgrp) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + pid_t pgid; + EXPECT_THAT(pgid = getpgrp(), SyscallSucceeds()); + + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, -pgid), + SyscallSucceedsWithValue(0)); + + // Verify with F_GETOWN_EX; using F_GETOWN on Linux may incorrectly treat the + // negative return value as an error, converting the return value to -1 and + // setting errno accordingly. + f_owner_ex got_owner = {}; + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN_EX, &got_owner), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(got_owner.type, F_OWNER_PGRP); + EXPECT_EQ(got_owner.pid, pgid); + MaybeSave(); +} + +TEST(FcntlTest, SetOwnUnset) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + // Set and unset pid. + pid_t pid; + EXPECT_THAT(pid = getpid(), SyscallSucceeds()); + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, pid), + SyscallSucceedsWithValue(0)); + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, 0), + SyscallSucceedsWithValue(0)); + + EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN), + SyscallSucceedsWithValue(0)); + + // Set and unset pgid. + pid_t pgid; + EXPECT_THAT(pgid = getpgrp(), SyscallSucceeds()); + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, -pgid), + SyscallSucceedsWithValue(0)); + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, 0), + SyscallSucceedsWithValue(0)); + + EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN), + SyscallSucceedsWithValue(0)); + MaybeSave(); +} + +// F_SETOWN flips the sign of negative values, an operation that is guarded +// against overflow. +TEST(FcntlTest, SetOwnOverflow) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, INT_MIN), + SyscallFailsWithErrno(EINVAL)); +} + +TEST(FcntlTest, SetOwnExInvalidType) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + f_owner_ex owner = {}; + owner.type = __pid_type(-1); + EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), + SyscallFailsWithErrno(EINVAL)); +} + +TEST(FcntlTest, SetOwnExInvalidTid) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + f_owner_ex owner = {}; + owner.type = F_OWNER_TID; + owner.pid = -1; + + EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), + SyscallFailsWithErrno(ESRCH)); +} + +TEST(FcntlTest, SetOwnExInvalidPid) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + f_owner_ex owner = {}; + owner.type = F_OWNER_PID; + owner.pid = -1; + + EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), + SyscallFailsWithErrno(ESRCH)); +} + +TEST(FcntlTest, SetOwnExInvalidPgrp) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + f_owner_ex owner = {}; + owner.type = F_OWNER_PGRP; + owner.pid = -1; + + EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), + SyscallFailsWithErrno(ESRCH)); +} + +TEST(FcntlTest, SetOwnExTid) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + f_owner_ex owner = {}; + owner.type = F_OWNER_TID; + EXPECT_THAT(owner.pid = syscall(__NR_gettid), SyscallSucceeds()); + + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), + SyscallSucceedsWithValue(0)); + + EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN), + SyscallSucceedsWithValue(owner.pid)); + MaybeSave(); +} + +TEST(FcntlTest, SetOwnExPid) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + f_owner_ex owner = {}; + owner.type = F_OWNER_PID; + EXPECT_THAT(owner.pid = getpid(), SyscallSucceeds()); + + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), + SyscallSucceedsWithValue(0)); + + EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN), + SyscallSucceedsWithValue(owner.pid)); + MaybeSave(); +} + +TEST(FcntlTest, SetOwnExPgrp) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + f_owner_ex set_owner = {}; + set_owner.type = F_OWNER_PGRP; + EXPECT_THAT(set_owner.pid = getpgrp(), SyscallSucceeds()); + + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &set_owner), + SyscallSucceedsWithValue(0)); + + // Verify with F_GETOWN_EX; using F_GETOWN on Linux may incorrectly treat the + // negative return value as an error, converting the return value to -1 and + // setting errno accordingly. + f_owner_ex got_owner = {}; + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN_EX, &got_owner), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(got_owner.type, set_owner.type); + EXPECT_EQ(got_owner.pid, set_owner.pid); + MaybeSave(); +} + +TEST(FcntlTest, SetOwnExUnset) { + SKIP_IF(IsRunningWithVFS1()); + + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + // Set and unset pid. + f_owner_ex owner = {}; + owner.type = F_OWNER_PID; + EXPECT_THAT(owner.pid = getpid(), SyscallSucceeds()); + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), + SyscallSucceedsWithValue(0)); + owner.pid = 0; + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), + SyscallSucceedsWithValue(0)); + + EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN), + SyscallSucceedsWithValue(0)); + + // Set and unset pgid. + owner.type = F_OWNER_PGRP; + EXPECT_THAT(owner.pid = getpgrp(), SyscallSucceeds()); + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), + SyscallSucceedsWithValue(0)); + owner.pid = 0; + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), + SyscallSucceedsWithValue(0)); + + EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN), + SyscallSucceedsWithValue(0)); + MaybeSave(); +} + +TEST(FcntlTest, GetOwnExTid) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + f_owner_ex set_owner = {}; + set_owner.type = F_OWNER_TID; + EXPECT_THAT(set_owner.pid = syscall(__NR_gettid), SyscallSucceeds()); + + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &set_owner), + SyscallSucceedsWithValue(0)); + + f_owner_ex got_owner = {}; + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN_EX, &got_owner), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(got_owner.type, set_owner.type); + EXPECT_EQ(got_owner.pid, set_owner.pid); +} + +TEST(FcntlTest, GetOwnExPid) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + f_owner_ex set_owner = {}; + set_owner.type = F_OWNER_PID; + EXPECT_THAT(set_owner.pid = getpid(), SyscallSucceeds()); + + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &set_owner), + SyscallSucceedsWithValue(0)); + + f_owner_ex got_owner = {}; + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN_EX, &got_owner), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(got_owner.type, set_owner.type); + EXPECT_EQ(got_owner.pid, set_owner.pid); +} + +TEST(FcntlTest, GetOwnExPgrp) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + f_owner_ex set_owner = {}; + set_owner.type = F_OWNER_PGRP; + EXPECT_THAT(set_owner.pid = getpgrp(), SyscallSucceeds()); + + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &set_owner), + SyscallSucceedsWithValue(0)); + + f_owner_ex got_owner = {}; + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN_EX, &got_owner), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(got_owner.type, set_owner.type); + EXPECT_EQ(got_owner.pid, set_owner.pid); +} + +// Make sure that making multiple concurrent changes to async signal generation +// does not cause any race issues. +TEST(FcntlTest, SetFlSetOwnDoNotRace) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + pid_t pid; + EXPECT_THAT(pid = getpid(), SyscallSucceeds()); + + constexpr absl::Duration runtime = absl::Milliseconds(300); + auto setAsync = [&s, &runtime] { + for (auto start = absl::Now(); absl::Now() - start < runtime;) { + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETFL, O_ASYNC), + SyscallSucceeds()); + sched_yield(); + } + }; + auto resetAsync = [&s, &runtime] { + for (auto start = absl::Now(); absl::Now() - start < runtime;) { + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETFL, 0), SyscallSucceeds()); + sched_yield(); + } + }; + auto setOwn = [&s, &pid, &runtime] { + for (auto start = absl::Now(); absl::Now() - start < runtime;) { + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, pid), + SyscallSucceeds()); + sched_yield(); + } + }; + + std::list<ScopedThread> threads; + for (int i = 0; i < 10; i++) { + threads.emplace_back(setAsync); + threads.emplace_back(resetAsync); + threads.emplace_back(setOwn); + } +} + } // namespace } // namespace testing @@ -943,8 +1326,7 @@ int main(int argc, char** argv) { fl.l_start = absl::GetFlag(FLAGS_child_setlock_start); fl.l_len = absl::GetFlag(FLAGS_child_setlock_len); - // Test the fcntl, no need to log, the error is unambiguously - // from fcntl at this point. + // Test the fcntl. int err = 0; int ret = 0; @@ -957,6 +1339,8 @@ int main(int argc, char** argv) { if (ret == -1 && errno != 0) { err = errno; + std::cerr << "CHILD lock " << setlock_on << " failed " << err + << std::endl; } // If there is a socket fd let's send back the time in microseconds it took @@ -971,5 +1355,5 @@ int main(int argc, char** argv) { exit(err); } - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/file_base.h b/test/syscalls/linux/file_base.h index 4d155b618..fb418e052 100644 --- a/test/syscalls/linux/file_base.h +++ b/test/syscalls/linux/file_base.h @@ -27,6 +27,7 @@ #include <sys/types.h> #include <sys/uio.h> #include <unistd.h> + #include <cstring> #include <string> @@ -51,17 +52,6 @@ class FileTest : public ::testing::Test { test_file_fd_ = ASSERT_NO_ERRNO_AND_VALUE( Open(test_file_name_, O_CREAT | O_RDWR, S_IRUSR | S_IWUSR)); - // FIXME(edahlgren): enable when mknod syscall is supported. - // test_fifo_name_ = NewTempAbsPath(); - // ASSERT_THAT(mknod(test_fifo_name_.c_str()), S_IFIFO|0644, 0, - // SyscallSucceeds()); - // ASSERT_THAT(test_fifo_[1] = open(test_fifo_name_.c_str(), - // O_WRONLY), - // SyscallSucceeds()); - // ASSERT_THAT(test_fifo_[0] = open(test_fifo_name_.c_str(), - // O_RDONLY), - // SyscallSucceeds()); - ASSERT_THAT(pipe(test_pipe_), SyscallSucceeds()); ASSERT_THAT(fcntl(test_pipe_[0], F_SETFL, O_NONBLOCK), SyscallSucceeds()); } @@ -95,110 +85,15 @@ class FileTest : public ::testing::Test { CloseFile(); UnlinkFile(); ClosePipes(); - - // FIXME(edahlgren): enable when mknod syscall is supported. - // close(test_fifo_[0]); - // close(test_fifo_[1]); - // unlink(test_fifo_name_.c_str()); } + protected: std::string test_file_name_; - std::string test_fifo_name_; FileDescriptor test_file_fd_; - int test_fifo_[2]; int test_pipe_[2]; }; -class SocketTest : public ::testing::Test { - public: - void SetUp() override { - test_unix_stream_socket_[0] = -1; - test_unix_stream_socket_[1] = -1; - test_unix_dgram_socket_[0] = -1; - test_unix_dgram_socket_[1] = -1; - test_unix_seqpacket_socket_[0] = -1; - test_unix_seqpacket_socket_[1] = -1; - test_tcp_socket_[0] = -1; - test_tcp_socket_[1] = -1; - - ASSERT_THAT(socketpair(AF_UNIX, SOCK_STREAM, 0, test_unix_stream_socket_), - SyscallSucceeds()); - ASSERT_THAT(fcntl(test_unix_stream_socket_[0], F_SETFL, O_NONBLOCK), - SyscallSucceeds()); - ASSERT_THAT(socketpair(AF_UNIX, SOCK_DGRAM, 0, test_unix_dgram_socket_), - SyscallSucceeds()); - ASSERT_THAT(fcntl(test_unix_dgram_socket_[0], F_SETFL, O_NONBLOCK), - SyscallSucceeds()); - ASSERT_THAT( - socketpair(AF_UNIX, SOCK_SEQPACKET, 0, test_unix_seqpacket_socket_), - SyscallSucceeds()); - ASSERT_THAT(fcntl(test_unix_seqpacket_socket_[0], F_SETFL, O_NONBLOCK), - SyscallSucceeds()); - } - - void TearDown() override { - close(test_unix_stream_socket_[0]); - close(test_unix_stream_socket_[1]); - - close(test_unix_dgram_socket_[0]); - close(test_unix_dgram_socket_[1]); - - close(test_unix_seqpacket_socket_[0]); - close(test_unix_seqpacket_socket_[1]); - - close(test_tcp_socket_[0]); - close(test_tcp_socket_[1]); - } - - int test_unix_stream_socket_[2]; - int test_unix_dgram_socket_[2]; - int test_unix_seqpacket_socket_[2]; - int test_tcp_socket_[2]; -}; - -// MatchesStringLength checks that a tuple argument of (struct iovec *, int) -// corresponding to an iovec array and its length, contains data that matches -// the string length strlen. -MATCHER_P(MatchesStringLength, strlen, "") { - struct iovec* iovs = arg.first; - int niov = arg.second; - int offset = 0; - for (int i = 0; i < niov; i++) { - offset += iovs[i].iov_len; - } - if (offset != static_cast<int>(strlen)) { - *result_listener << offset; - return false; - } - return true; -} - -// MatchesStringValue checks that a tuple argument of (struct iovec *, int) -// corresponding to an iovec array and its length, contains data that matches -// the string value str. -MATCHER_P(MatchesStringValue, str, "") { - struct iovec* iovs = arg.first; - int len = strlen(str); - int niov = arg.second; - int offset = 0; - for (int i = 0; i < niov; i++) { - struct iovec iov = iovs[i]; - if (len < offset) { - *result_listener << "strlen " << len << " < offset " << offset; - return false; - } - if (strncmp(static_cast<char*>(iov.iov_base), &str[offset], iov.iov_len)) { - absl::string_view iovec_string(static_cast<char*>(iov.iov_base), - iov.iov_len); - *result_listener << iovec_string << " @offset " << offset; - return false; - } - offset += iov.iov_len; - } - return true; -} - } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/flock.cc b/test/syscalls/linux/flock.cc index b4a91455d..638a93979 100644 --- a/test/syscalls/linux/flock.cc +++ b/test/syscalls/linux/flock.cc @@ -14,12 +14,14 @@ #include <errno.h> #include <sys/file.h> + #include <string> #include "gtest/gtest.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "test/syscalls/linux/file_base.h" +#include "test/syscalls/linux/socket_test_util.h" #include "test/util/file_descriptor.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" @@ -33,11 +35,6 @@ namespace { class FlockTest : public FileTest {}; -TEST_F(FlockTest, BadFD) { - // EBADF: fd is not an open file descriptor. - ASSERT_THAT(flock(-1, 0), SyscallFailsWithErrno(EBADF)); -} - TEST_F(FlockTest, InvalidOpCombinations) { // The operation cannot be both exclusive and shared. EXPECT_THAT(flock(test_file_fd_.get(), LOCK_EX | LOCK_SH | LOCK_NB), @@ -56,15 +53,6 @@ TEST_F(FlockTest, NoOperationSpecified) { SyscallFailsWithErrno(EINVAL)); } -TEST(FlockTestNoFixture, FlockSupportsPipes) { - int fds[2]; - ASSERT_THAT(pipe(fds), SyscallSucceeds()); - - EXPECT_THAT(flock(fds[0], LOCK_EX | LOCK_NB), SyscallSucceeds()); - EXPECT_THAT(close(fds[0]), SyscallSucceeds()); - EXPECT_THAT(close(fds[1]), SyscallSucceeds()); -} - TEST_F(FlockTest, TestSimpleExLock) { // Test that we can obtain an exclusive lock (no other holders) // and that we can unlock it. @@ -582,6 +570,66 @@ TEST_F(FlockTest, BlockingLockFirstExclusiveSecondExclusive_NoRandomSave) { EXPECT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceeds()); } +TEST(FlockTestNoFixture, BadFD) { + // EBADF: fd is not an open file descriptor. + ASSERT_THAT(flock(-1, 0), SyscallFailsWithErrno(EBADF)); +} + +TEST(FlockTestNoFixture, FlockDir) { + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY, 0000)); + EXPECT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB), SyscallSucceeds()); +} + +TEST(FlockTestNoFixture, FlockSymlink) { + // TODO(gvisor.dev/issue/2782): Replace with IsRunningWithVFS1() when O_PATH + // is supported. + SKIP_IF(IsRunningOnGvisor()); + + auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + auto symlink = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateSymlinkTo(GetAbsoluteTestTmpdir(), file.path())); + + auto fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(symlink.path(), O_RDONLY | O_PATH, 0000)); + EXPECT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB), SyscallFailsWithErrno(EBADF)); +} + +TEST(FlockTestNoFixture, FlockProc) { + auto fd = + ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/status", O_RDONLY, 0000)); + EXPECT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB), SyscallSucceeds()); +} + +TEST(FlockTestNoFixture, FlockPipe) { + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + + EXPECT_THAT(flock(fds[0], LOCK_EX | LOCK_NB), SyscallSucceeds()); + // Check that the pipe was locked above. + EXPECT_THAT(flock(fds[1], LOCK_EX | LOCK_NB), SyscallFailsWithErrno(EAGAIN)); + + EXPECT_THAT(flock(fds[0], LOCK_UN), SyscallSucceeds()); + EXPECT_THAT(flock(fds[1], LOCK_EX | LOCK_NB), SyscallSucceeds()); + + EXPECT_THAT(close(fds[0]), SyscallSucceeds()); + EXPECT_THAT(close(fds[1]), SyscallSucceeds()); +} + +TEST(FlockTestNoFixture, FlockSocket) { + int sock = socket(AF_UNIX, SOCK_STREAM, 0); + ASSERT_THAT(sock, SyscallSucceeds()); + + struct sockaddr_un addr = + ASSERT_NO_ERRNO_AND_VALUE(UniqueUnixAddr(true /* abstract */, AF_UNIX)); + ASSERT_THAT( + bind(sock, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)), + SyscallSucceeds()); + + EXPECT_THAT(flock(sock, LOCK_EX | LOCK_NB), SyscallSucceeds()); + EXPECT_THAT(close(sock), SyscallSucceeds()); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/fork.cc b/test/syscalls/linux/fork.cc index dd6e1a422..853f6231a 100644 --- a/test/syscalls/linux/fork.cc +++ b/test/syscalls/linux/fork.cc @@ -20,6 +20,7 @@ #include <sys/stat.h> #include <sys/types.h> #include <unistd.h> + #include <atomic> #include <cstdlib> @@ -214,6 +215,8 @@ TEST_F(ForkTest, PrivateMapping) { EXPECT_THAT(Wait(child), SyscallSucceedsWithValue(0)); } +// CPUID is x86 specific. +#ifdef __x86_64__ // Test that cpuid works after a fork. TEST_F(ForkTest, Cpuid) { pid_t child = Fork(); @@ -226,6 +229,7 @@ TEST_F(ForkTest, Cpuid) { } EXPECT_THAT(Wait(child), SyscallSucceedsWithValue(0)); } +#endif TEST_F(ForkTest, Mmap) { pid_t child = Fork(); @@ -267,7 +271,7 @@ TEST_F(ForkTest, Alarm) { EXPECT_EQ(0, alarmed); } -// Child cannot affect parent private memory. +// Child cannot affect parent private memory. Regression test for b/24137240. TEST_F(ForkTest, PrivateMemory) { std::atomic<uint32_t> local(0); @@ -294,6 +298,9 @@ TEST_F(ForkTest, PrivateMemory) { } // Kernel-accessed buffers should remain coherent across COW. +// +// The buffer must be >= usermem.ZeroCopyMinBytes, as UnsafeAccess operates +// differently. Regression test for b/33811887. TEST_F(ForkTest, COWSegment) { constexpr int kBufSize = 1024; char* read_buf = private_; @@ -424,7 +431,6 @@ TEST(CloneTest, NewUserNamespacePermitsAllOtherNamespaces) { << "status = " << status; } -#ifdef __x86_64__ // Clone with CLONE_SETTLS and a non-canonical TLS address is rejected. TEST(CloneTest, NonCanonicalTLS) { constexpr uintptr_t kNonCanonical = 1ull << 48; @@ -433,11 +439,25 @@ TEST(CloneTest, NonCanonicalTLS) { // on this. char stack; + // The raw system call interface on x86-64 is: + // long clone(unsigned long flags, void *stack, + // int *parent_tid, int *child_tid, + // unsigned long tls); + // + // While on arm64, the order of the last two arguments is reversed: + // long clone(unsigned long flags, void *stack, + // int *parent_tid, unsigned long tls, + // int *child_tid); +#if defined(__x86_64__) EXPECT_THAT(syscall(__NR_clone, SIGCHLD | CLONE_SETTLS, &stack, nullptr, nullptr, kNonCanonical), SyscallFailsWithErrno(EPERM)); -} +#elif defined(__aarch64__) + EXPECT_THAT(syscall(__NR_clone, SIGCHLD | CLONE_SETTLS, &stack, nullptr, + kNonCanonical, nullptr), + SyscallFailsWithErrno(EPERM)); #endif +} } // namespace } // namespace testing diff --git a/test/syscalls/linux/fpsig_fork.cc b/test/syscalls/linux/fpsig_fork.cc index e7e9f06a1..c47567b4e 100644 --- a/test/syscalls/linux/fpsig_fork.cc +++ b/test/syscalls/linux/fpsig_fork.cc @@ -27,9 +27,22 @@ namespace testing { namespace { +#ifdef __x86_64__ #define GET_XMM(__var, __xmm) \ asm volatile("movq %%" #__xmm ", %0" : "=r"(__var)) #define SET_XMM(__var, __xmm) asm volatile("movq %0, %%" #__xmm : : "r"(__var)) +#define GET_FP0(__var) GET_XMM(__var, xmm0) +#define SET_FP0(__var) SET_XMM(__var, xmm0) +#elif __aarch64__ +#define __stringify_1(x...) #x +#define __stringify(x...) __stringify_1(x) +#define GET_FPREG(var, regname) \ + asm volatile("str " __stringify(regname) ", %0" : "=m"(var)) +#define SET_FPREG(var, regname) \ + asm volatile("ldr " __stringify(regname) ", %0" : "=m"(var)) +#define GET_FP0(var) GET_FPREG(var, d0) +#define SET_FP0(var) SET_FPREG(var, d0) +#endif int parent, child; @@ -40,7 +53,10 @@ void sigusr1(int s, siginfo_t* siginfo, void* _uc) { TEST_CHECK_MSG(child >= 0, "fork failed"); uint64_t val = SIGUSR1; - SET_XMM(val, xmm0); + SET_FP0(val); + uint64_t got; + GET_FP0(got); + TEST_CHECK_MSG(val == got, "Basic FP check failed in sigusr1()"); } TEST(FPSigTest, Fork) { @@ -67,8 +83,9 @@ TEST(FPSigTest, Fork) { // be the one clobbered. uint64_t expected = 0xdeadbeeffacefeed; - SET_XMM(expected, xmm0); + SET_FP0(expected); +#ifdef __x86_64__ asm volatile( "movl %[killnr], %%eax;" "movl %[parent], %%edi;" @@ -76,14 +93,23 @@ TEST(FPSigTest, Fork) { "movl %[sig], %%edx;" "syscall;" : - : [killnr] "i"(__NR_tgkill), [parent] "rm"(parent), - [tid] "rm"(parent_tid), [sig] "i"(SIGUSR1) + : [ killnr ] "i"(__NR_tgkill), [ parent ] "rm"(parent), + [ tid ] "rm"(parent_tid), [ sig ] "i"(SIGUSR1) : "rax", "rdi", "rsi", "rdx", // Clobbered by syscall. "rcx", "r11"); +#elif __aarch64__ + asm volatile( + "mov x8, %0\n" + "mov x0, %1\n" + "mov x1, %2\n" + "mov x2, %3\n" + "svc #0\n" ::"r"(__NR_tgkill), + "r"(parent), "r"(parent_tid), "r"(SIGUSR1)); +#endif uint64_t got; - GET_XMM(got, xmm0); + GET_FP0(got); if (getpid() == parent) { // Parent. int status; diff --git a/test/syscalls/linux/fpsig_nested.cc b/test/syscalls/linux/fpsig_nested.cc index 395463aed..302d928d1 100644 --- a/test/syscalls/linux/fpsig_nested.cc +++ b/test/syscalls/linux/fpsig_nested.cc @@ -26,9 +26,22 @@ namespace testing { namespace { +#ifdef __x86_64__ #define GET_XMM(__var, __xmm) \ asm volatile("movq %%" #__xmm ", %0" : "=r"(__var)) #define SET_XMM(__var, __xmm) asm volatile("movq %0, %%" #__xmm : : "r"(__var)) +#define GET_FP0(__var) GET_XMM(__var, xmm0) +#define SET_FP0(__var) SET_XMM(__var, xmm0) +#elif __aarch64__ +#define __stringify_1(x...) #x +#define __stringify(x...) __stringify_1(x) +#define GET_FPREG(var, regname) \ + asm volatile("str " __stringify(regname) ", %0" : "=m"(var)) +#define SET_FPREG(var, regname) \ + asm volatile("ldr " __stringify(regname) ", %0" : "=m"(var)) +#define GET_FP0(var) GET_FPREG(var, d0) +#define SET_FP0(var) SET_FPREG(var, d0) +#endif int pid; int tid; @@ -40,20 +53,21 @@ void sigusr2(int s, siginfo_t* siginfo, void* _uc) { uint64_t val = SIGUSR2; // Record the value of %xmm0 on entry and then clobber it. - GET_XMM(entryxmm[1], xmm0); - SET_XMM(val, xmm0); - GET_XMM(exitxmm[1], xmm0); + GET_FP0(entryxmm[1]); + SET_FP0(val); + GET_FP0(exitxmm[1]); } void sigusr1(int s, siginfo_t* siginfo, void* _uc) { uint64_t val = SIGUSR1; // Record the value of %xmm0 on entry and then clobber it. - GET_XMM(entryxmm[0], xmm0); - SET_XMM(val, xmm0); + GET_FP0(entryxmm[0]); + SET_FP0(val); // Send a SIGUSR2 to ourself. The signal mask is configured such that // the SIGUSR2 handler will run before this handler returns. +#ifdef __x86_64__ asm volatile( "movl %[killnr], %%eax;" "movl %[pid], %%edi;" @@ -61,15 +75,24 @@ void sigusr1(int s, siginfo_t* siginfo, void* _uc) { "movl %[sig], %%edx;" "syscall;" : - : [killnr] "i"(__NR_tgkill), [pid] "rm"(pid), [tid] "rm"(tid), - [sig] "i"(SIGUSR2) + : [ killnr ] "i"(__NR_tgkill), [ pid ] "rm"(pid), [ tid ] "rm"(tid), + [ sig ] "i"(SIGUSR2) : "rax", "rdi", "rsi", "rdx", // Clobbered by syscall. "rcx", "r11"); +#elif __aarch64__ + asm volatile( + "mov x8, %0\n" + "mov x0, %1\n" + "mov x1, %2\n" + "mov x2, %3\n" + "svc #0\n" ::"r"(__NR_tgkill), + "r"(pid), "r"(tid), "r"(SIGUSR2)); +#endif // Record value of %xmm0 again to verify that the nested signal handler // does not clobber it. - GET_XMM(exitxmm[0], xmm0); + GET_FP0(exitxmm[0]); } TEST(FPSigTest, NestedSignals) { @@ -98,8 +121,9 @@ TEST(FPSigTest, NestedSignals) { // to signal the current thread ensures that this is the clobbered thread. uint64_t expected = 0xdeadbeeffacefeed; - SET_XMM(expected, xmm0); + SET_FP0(expected); +#ifdef __x86_64__ asm volatile( "movl %[killnr], %%eax;" "movl %[pid], %%edi;" @@ -107,14 +131,23 @@ TEST(FPSigTest, NestedSignals) { "movl %[sig], %%edx;" "syscall;" : - : [killnr] "i"(__NR_tgkill), [pid] "rm"(pid), [tid] "rm"(tid), - [sig] "i"(SIGUSR1) + : [ killnr ] "i"(__NR_tgkill), [ pid ] "rm"(pid), [ tid ] "rm"(tid), + [ sig ] "i"(SIGUSR1) : "rax", "rdi", "rsi", "rdx", // Clobbered by syscall. "rcx", "r11"); +#elif __aarch64__ + asm volatile( + "mov x8, %0\n" + "mov x0, %1\n" + "mov x1, %2\n" + "mov x2, %3\n" + "svc #0\n" ::"r"(__NR_tgkill), + "r"(pid), "r"(tid), "r"(SIGUSR1)); +#endif uint64_t got; - GET_XMM(got, xmm0); + GET_FP0(got); // // The checks below verifies the following: diff --git a/test/syscalls/linux/futex.cc b/test/syscalls/linux/futex.cc index d3e3f998c..90b1f0508 100644 --- a/test/syscalls/linux/futex.cc +++ b/test/syscalls/linux/futex.cc @@ -18,6 +18,7 @@ #include <sys/syscall.h> #include <sys/time.h> #include <sys/types.h> +#include <syscall.h> #include <unistd.h> #include <algorithm> @@ -239,6 +240,27 @@ TEST_P(PrivateAndSharedFutexTest, Wake1_NoRandomSave) { EXPECT_THAT(futex_wake(IsPrivate(), &a, 1), SyscallSucceedsWithValue(1)); } +TEST_P(PrivateAndSharedFutexTest, Wake0_NoRandomSave) { + constexpr int kInitialValue = 1; + std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue); + + // Prevent save/restore from interrupting futex_wait, which will cause it to + // return EAGAIN instead of the expected result if futex_wait is restarted + // after we change the value of a below. + DisableSave ds; + ScopedThread thread([&] { + EXPECT_THAT(futex_wait(IsPrivate(), &a, kInitialValue), + SyscallSucceedsWithValue(0)); + }); + absl::SleepFor(kWaiterStartupDelay); + + // Change a so that if futex_wake happens before futex_wait, the latter + // returns EAGAIN instead of hanging the test. + a.fetch_add(1); + // The Linux kernel wakes one waiter even if val is 0 or negative. + EXPECT_THAT(futex_wake(IsPrivate(), &a, 0), SyscallSucceedsWithValue(1)); +} + TEST_P(PrivateAndSharedFutexTest, WakeAll_NoRandomSave) { constexpr int kInitialValue = 1; std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue); @@ -716,6 +738,97 @@ TEST_P(PrivateAndSharedFutexTest, PITryLockConcurrency_NoRandomSave) { } } +int get_robust_list(int pid, struct robust_list_head** head_ptr, + size_t* len_ptr) { + return syscall(__NR_get_robust_list, pid, head_ptr, len_ptr); +} + +int set_robust_list(struct robust_list_head* head, size_t len) { + return syscall(__NR_set_robust_list, head, len); +} + +TEST(RobustFutexTest, BasicSetGet) { + struct robust_list_head hd = {}; + struct robust_list_head* hd_ptr = &hd; + + // Set! + EXPECT_THAT(set_robust_list(hd_ptr, sizeof(hd)), SyscallSucceedsWithValue(0)); + + // Get! + struct robust_list_head* new_hd_ptr = hd_ptr; + size_t len; + EXPECT_THAT(get_robust_list(0, &new_hd_ptr, &len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(new_hd_ptr, hd_ptr); + EXPECT_EQ(len, sizeof(hd)); +} + +TEST(RobustFutexTest, GetFromOtherTid) { + // Get the current tid and list head. + pid_t tid = gettid(); + struct robust_list_head* hd_ptr = {}; + size_t len; + EXPECT_THAT(get_robust_list(0, &hd_ptr, &len), SyscallSucceedsWithValue(0)); + + // Create a new thread. + ScopedThread t([&] { + // Current tid list head should be different from parent tid. + struct robust_list_head* got_hd_ptr = {}; + EXPECT_THAT(get_robust_list(0, &got_hd_ptr, &len), + SyscallSucceedsWithValue(0)); + EXPECT_NE(hd_ptr, got_hd_ptr); + + // Get the parent list head by passing its tid. + EXPECT_THAT(get_robust_list(tid, &got_hd_ptr, &len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(hd_ptr, got_hd_ptr); + }); + + // Wait for thread. + t.Join(); +} + +TEST(RobustFutexTest, InvalidSize) { + struct robust_list_head* hd = {}; + EXPECT_THAT(set_robust_list(hd, sizeof(*hd) + 1), + SyscallFailsWithErrno(EINVAL)); +} + +TEST(RobustFutexTest, PthreadMutexAttr) { + constexpr int kNumMutexes = 3; + + // Create a bunch of robust mutexes. + pthread_mutexattr_t attrs[kNumMutexes]; + pthread_mutex_t mtxs[kNumMutexes]; + for (int i = 0; i < kNumMutexes; i++) { + TEST_PCHECK(pthread_mutexattr_init(&attrs[i]) == 0); + TEST_PCHECK(pthread_mutexattr_setrobust(&attrs[i], PTHREAD_MUTEX_ROBUST) == + 0); + TEST_PCHECK(pthread_mutex_init(&mtxs[i], &attrs[i]) == 0); + } + + // Start thread to lock the mutexes and then exit. + ScopedThread t([&] { + for (int i = 0; i < kNumMutexes; i++) { + TEST_PCHECK(pthread_mutex_lock(&mtxs[i]) == 0); + } + pthread_exit(NULL); + }); + + // Wait for thread. + t.Join(); + + // Now try to take the mutexes. + for (int i = 0; i < kNumMutexes; i++) { + // Should get EOWNERDEAD. + EXPECT_EQ(pthread_mutex_lock(&mtxs[i]), EOWNERDEAD); + // Make the mutex consistent. + EXPECT_EQ(pthread_mutex_consistent(&mtxs[i]), 0); + // Unlock. + EXPECT_EQ(pthread_mutex_unlock(&mtxs[i]), 0); + } +} + } // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/getdents.cc b/test/syscalls/linux/getdents.cc index fe9cfafe8..b040cdcf7 100644 --- a/test/syscalls/linux/getdents.cc +++ b/test/syscalls/linux/getdents.cc @@ -23,6 +23,7 @@ #include <sys/types.h> #include <syscall.h> #include <unistd.h> + #include <map> #include <string> #include <unordered_map> @@ -31,6 +32,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "absl/container/node_hash_set.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "test/util/eventfd_util.h" @@ -227,19 +229,28 @@ class GetdentsTest : public ::testing::Test { // Multiple template parameters are not allowed, so we must use explicit // template specialization to set the syscall number. + +// SYS_getdents isn't defined on arm64. +#ifdef __x86_64__ template <> int GetdentsTest<struct linux_dirent>::SyscallNum() { return SYS_getdents; } +#endif template <> int GetdentsTest<struct linux_dirent64>::SyscallNum() { return SYS_getdents64; } -// Test both legacy getdents and getdents64. +#ifdef __x86_64__ +// Test both legacy getdents and getdents64 on x86_64. typedef ::testing::Types<struct linux_dirent, struct linux_dirent64> GetdentsTypes; +#elif __aarch64__ +// Test only getdents64 on arm64. +typedef ::testing::Types<struct linux_dirent64> GetdentsTypes; +#endif TYPED_TEST_SUITE(GetdentsTest, GetdentsTypes); // N.B. TYPED_TESTs require explicitly using this-> to access members of @@ -383,7 +394,7 @@ TYPED_TEST(GetdentsTest, ProcSelfFd) { // Make the buffer very small since we want to iterate. typename TestFixture::DirentBufferType dirents( 2 * sizeof(typename TestFixture::LinuxDirentType)); - std::unordered_set<int> prev_fds; + absl::node_hash_set<int> prev_fds; while (true) { dirents.Reset(); int rv; diff --git a/test/syscalls/linux/getrandom.cc b/test/syscalls/linux/getrandom.cc index f97f60029..f87cdd7a1 100644 --- a/test/syscalls/linux/getrandom.cc +++ b/test/syscalls/linux/getrandom.cc @@ -29,6 +29,8 @@ namespace { #define SYS_getrandom 318 #elif defined(__i386__) #define SYS_getrandom 355 +#elif defined(__aarch64__) +#define SYS_getrandom 278 #else #error "Unknown architecture" #endif diff --git a/test/syscalls/linux/getrusage.cc b/test/syscalls/linux/getrusage.cc index 9bdb1e4cd..0e51d42a8 100644 --- a/test/syscalls/linux/getrusage.cc +++ b/test/syscalls/linux/getrusage.cc @@ -67,7 +67,7 @@ TEST(GetrusageTest, Grandchild) { pid = fork(); if (pid == 0) { int flags = MAP_ANONYMOUS | MAP_POPULATE | MAP_PRIVATE; - void *addr = + void* addr = mmap(nullptr, kGrandchildSizeKb * 1024, PROT_WRITE, flags, -1, 0); TEST_PCHECK(addr != MAP_FAILED); } else { diff --git a/test/syscalls/linux/inotify.cc b/test/syscalls/linux/inotify.cc index 7384c27dc..5cb325a9e 100644 --- a/test/syscalls/linux/inotify.cc +++ b/test/syscalls/linux/inotify.cc @@ -18,7 +18,9 @@ #include <sys/epoll.h> #include <sys/inotify.h> #include <sys/ioctl.h> +#include <sys/sendfile.h> #include <sys/time.h> +#include <sys/xattr.h> #include <atomic> #include <list> @@ -28,11 +30,13 @@ #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/synchronization/mutex.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "test/util/epoll_util.h" #include "test/util/file_descriptor.h" #include "test/util/fs_util.h" +#include "test/util/posix_error.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" @@ -330,9 +334,32 @@ PosixErrorOr<int> InotifyAddWatch(int fd, const std::string& path, return wd; } -TEST(Inotify, InotifyFdNotWritable) { +TEST(Inotify, IllegalSeek) { const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(0)); - EXPECT_THAT(write(fd.get(), "x", 1), SyscallFailsWithErrno(EBADF)); + EXPECT_THAT(lseek(fd.get(), 0, SEEK_SET), SyscallFailsWithErrno(ESPIPE)); +} + +TEST(Inotify, IllegalPread) { + const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(0)); + int val; + EXPECT_THAT(pread(fd.get(), &val, sizeof(val), 0), + SyscallFailsWithErrno(ESPIPE)); +} + +TEST(Inotify, IllegalPwrite) { + const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(0)); + EXPECT_THAT(pwrite(fd.get(), "x", 1, 0), SyscallFailsWithErrno(ESPIPE)); +} + +TEST(Inotify, IllegalWrite) { + const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(0)); + int val = 0; + EXPECT_THAT(write(fd.get(), &val, sizeof(val)), SyscallFailsWithErrno(EBADF)); +} + +TEST(Inotify, InitFlags) { + EXPECT_THAT(inotify_init1(IN_NONBLOCK | IN_CLOEXEC), SyscallSucceeds()); + EXPECT_THAT(inotify_init1(12345), SyscallFailsWithErrno(EINVAL)); } TEST(Inotify, NonBlockingReadReturnsEagain) { @@ -395,7 +422,7 @@ TEST(Inotify, CanDeleteFileAfterRemovingWatch) { file1.reset(); } -TEST(Inotify, CanRemoveWatchAfterDeletingFile) { +TEST(Inotify, RemoveWatchAfterDeletingFileFails) { const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); TempPath file1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path())); @@ -491,17 +518,23 @@ TEST(Inotify, DeletingChildGeneratesEvents) { Event(IN_DELETE, root_wd, Basename(file1_path))})); } +// Creating a file in "parent/child" should generate events for child, but not +// parent. TEST(Inotify, CreatingFileGeneratesEvents) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + const TempPath parent = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + const TempPath child = + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(parent.path())); const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); + ASSERT_NO_ERRNO_AND_VALUE( + InotifyAddWatch(fd.get(), parent.path(), IN_ALL_EVENTS)); const int wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS)); + InotifyAddWatch(fd.get(), child.path(), IN_ALL_EVENTS)); // Create a new file in the directory. const TempPath file1 = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path())); + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(child.path())); const std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); @@ -554,6 +587,47 @@ TEST(Inotify, WritingFileGeneratesModifyEvent) { ASSERT_THAT(events, Are({Event(IN_MODIFY, wd, Basename(file1.path()))})); } +TEST(Inotify, SizeZeroReadWriteGeneratesNothing) { + const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); + const TempPath file1 = + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path())); + + const FileDescriptor file1_fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDWR)); + ASSERT_NO_ERRNO_AND_VALUE( + InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS)); + + // Read from the empty file. + int val; + ASSERT_THAT(read(file1_fd.get(), &val, sizeof(val)), + SyscallSucceedsWithValue(0)); + + // Write zero bytes. + ASSERT_THAT(write(file1_fd.get(), "", 0), SyscallSucceedsWithValue(0)); + + const std::vector<Event> events = + ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); + ASSERT_THAT(events, Are({})); +} + +TEST(Inotify, FailedFileCreationGeneratesNoEvents) { + const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + const std::string dir_path = dir.path(); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); + ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch(fd.get(), dir_path, IN_ALL_EVENTS)); + + const char* p = dir_path.c_str(); + ASSERT_THAT(mkdir(p, 0777), SyscallFails()); + ASSERT_THAT(mknod(p, S_IFIFO, 0777), SyscallFails()); + ASSERT_THAT(symlink(p, p), SyscallFails()); + ASSERT_THAT(link(p, p), SyscallFails()); + std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); + ASSERT_THAT(events, Are({})); +} + TEST(Inotify, WatchSetAfterOpenReportsCloseFdEvent) { const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); const FileDescriptor fd = @@ -602,7 +676,7 @@ TEST(Inotify, ChildrenDeletionInWatchedDirGeneratesEvent) { Event(IN_DELETE | IN_ISDIR, wd, Basename(dir1_path))})); } -TEST(Inotify, WatchTargetDeletionGeneratesEvent) { +TEST(Inotify, RmdirOnWatchedTargetGeneratesEvent) { const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); @@ -977,7 +1051,7 @@ TEST(Inotify, WatchOnRelativePath) { ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDONLY)); // Change working directory to root. - const char* old_working_dir = get_current_dir_name(); + const FileDescriptor cwd = ASSERT_NO_ERRNO_AND_VALUE(Open(".", O_PATH)); EXPECT_THAT(chdir(root.path().c_str()), SyscallSucceeds()); // Add a watch on file1 with a relative path. @@ -997,7 +1071,7 @@ TEST(Inotify, WatchOnRelativePath) { // continue to hold a reference, random save/restore tests can fail if a save // is triggered after "root" is unlinked; we can't save deleted fs objects // with active references. - EXPECT_THAT(chdir(old_working_dir), SyscallSucceeds()); + EXPECT_THAT(fchdir(cwd.get()), SyscallSucceeds()); } TEST(Inotify, ZeroLengthReadWriteDoesNotGenerateEvent) { @@ -1055,9 +1129,9 @@ TEST(Inotify, ChmodGeneratesAttribEvent_NoRandomSave) { const TempPath file1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path())); - const FileDescriptor root_fd = + FileDescriptor root_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(root.path(), O_RDONLY)); - const FileDescriptor file1_fd = + FileDescriptor file1_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDWR)); FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); @@ -1091,6 +1165,11 @@ TEST(Inotify, ChmodGeneratesAttribEvent_NoRandomSave) { ASSERT_THAT(fchmodat(root_fd.get(), file1_basename.c_str(), S_IWGRP, 0), SyscallSucceeds()); verify_chmod_events(); + + // Make sure the chmod'ed file descriptors are destroyed before DisableSave + // is destructed. + root_fd.reset(); + file1_fd.reset(); } TEST(Inotify, TruncateGeneratesModifyEvent) { @@ -1223,7 +1302,7 @@ TEST(Inotify, LinkGeneratesAttribAndCreateEvents) { InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS)); const int rc = link(file1.path().c_str(), link1.path().c_str()); - // link(2) is only supported on tmpfs in the sandbox. + // NOTE(b/34861058): link(2) is only supported on tmpfs in the sandbox. SKIP_IF(IsRunningOnGvisor() && rc != 0 && (errno == EPERM || errno == ENOENT)); ASSERT_THAT(rc, SyscallSucceeds()); @@ -1246,7 +1325,7 @@ TEST(Inotify, UtimesGeneratesAttribEvent) { const int wd = ASSERT_NO_ERRNO_AND_VALUE( InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS)); - struct timeval times[2] = {{1, 0}, {2, 0}}; + const struct timeval times[2] = {{1, 0}, {2, 0}}; EXPECT_THAT(futimes(file1_fd.get(), times), SyscallSucceeds()); const std::vector<Event> events = @@ -1317,21 +1396,27 @@ TEST(Inotify, HardlinksReuseSameWatch) { Event(IN_DELETE, root_wd, Basename(file1_path))})); } +// Calling mkdir within "parent/child" should generate an event for child, but +// not parent. TEST(Inotify, MkdirGeneratesCreateEventWithDirFlag) { - const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + const TempPath parent = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + const TempPath child = + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(parent.path())); const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - const int root_wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS)); + ASSERT_NO_ERRNO_AND_VALUE( + InotifyAddWatch(fd.get(), parent.path(), IN_ALL_EVENTS)); + const int child_wd = ASSERT_NO_ERRNO_AND_VALUE( + InotifyAddWatch(fd.get(), child.path(), IN_ALL_EVENTS)); - const TempPath dir1(NewTempAbsPathInDir(root.path())); + const TempPath dir1(NewTempAbsPathInDir(child.path())); ASSERT_THAT(mkdir(dir1.path().c_str(), 0777), SyscallSucceeds()); const std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); ASSERT_THAT( events, - Are({Event(IN_CREATE | IN_ISDIR, root_wd, Basename(dir1.path()))})); + Are({Event(IN_CREATE | IN_ISDIR, child_wd, Basename(dir1.path()))})); } TEST(Inotify, MultipleInotifyInstancesAndWatchesAllGetEvents) { @@ -1419,20 +1504,26 @@ TEST(Inotify, DuplicateWatchReturnsSameWatchDescriptor) { TEST(Inotify, UnmatchedEventsAreDiscarded) { const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const TempPath file1 = + TempPath file1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path())); const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch(fd.get(), file1.path(), IN_ACCESS)); + const int wd = ASSERT_NO_ERRNO_AND_VALUE( + InotifyAddWatch(fd.get(), file1.path(), IN_ACCESS)); - const FileDescriptor file1_fd = + FileDescriptor file1_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_WRONLY)); - const std::vector<Event> events = - ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); + std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); // We only asked for access events, the open event should be discarded. ASSERT_THAT(events, Are({})); + + // IN_IGNORED events are always generated, regardless of the mask. + file1_fd.reset(); + file1.reset(); + events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); + ASSERT_THAT(events, Are({Event(IN_IGNORED, wd)})); } TEST(Inotify, AddWatchWithInvalidEventMaskFails) { @@ -1591,6 +1682,754 @@ TEST(Inotify, EpollNoDeadlock) { } } +TEST(Inotify, Fallocate) { + const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR)); + + const FileDescriptor inotify_fd = + ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); + const int wd = ASSERT_NO_ERRNO_AND_VALUE( + InotifyAddWatch(inotify_fd.get(), file.path(), IN_ALL_EVENTS)); + + // Do an arbitrary modification with fallocate. + ASSERT_THAT(RetryEINTR(fallocate)(fd.get(), 0, 0, 123), SyscallSucceeds()); + std::vector<Event> events = + ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); + EXPECT_THAT(events, Are({Event(IN_MODIFY, wd)})); +} + +TEST(Inotify, Sendfile) { + SKIP_IF(IsRunningWithVFS1()); + + const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateFileWith(root.path(), "x", 0644)); + const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor in = + ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); + const FileDescriptor out = + ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY)); + + // Create separate inotify instances for the in and out fds. If both watches + // were on the same instance, we would have discrepancies between Linux and + // gVisor (order of events, duplicate events), which is not that important + // since inotify is asynchronous anyway. + const FileDescriptor in_inotify = + ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); + const FileDescriptor out_inotify = + ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); + const int in_wd = ASSERT_NO_ERRNO_AND_VALUE( + InotifyAddWatch(in_inotify.get(), in_file.path(), IN_ALL_EVENTS)); + const int out_wd = ASSERT_NO_ERRNO_AND_VALUE( + InotifyAddWatch(out_inotify.get(), out_file.path(), IN_ALL_EVENTS)); + + ASSERT_THAT(sendfile(out.get(), in.get(), /*offset=*/nullptr, 1), + SyscallSucceeds()); + + // Expect a single access event and a single modify event. + std::vector<Event> in_events = + ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(in_inotify.get())); + std::vector<Event> out_events = + ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(out_inotify.get())); + EXPECT_THAT(in_events, Are({Event(IN_ACCESS, in_wd)})); + EXPECT_THAT(out_events, Are({Event(IN_MODIFY, out_wd)})); +} + +// On Linux, inotify behavior is not very consistent with splice(2). We try our +// best to emulate Linux for very basic calls to splice. +TEST(Inotify, SpliceOnWatchTarget) { + int pipes[2]; + ASSERT_THAT(pipe2(pipes, O_NONBLOCK), SyscallSucceeds()); + + const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + const FileDescriptor inotify_fd = + ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); + const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + dir.path(), "some content", TempPath::kDefaultFileMode)); + + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR)); + const int dir_wd = ASSERT_NO_ERRNO_AND_VALUE( + InotifyAddWatch(inotify_fd.get(), dir.path(), IN_ALL_EVENTS)); + const int file_wd = ASSERT_NO_ERRNO_AND_VALUE( + InotifyAddWatch(inotify_fd.get(), file.path(), IN_ALL_EVENTS)); + + EXPECT_THAT(splice(fd.get(), nullptr, pipes[1], nullptr, 1, /*flags=*/0), + SyscallSucceedsWithValue(1)); + + // Surprisingly, events are not generated in Linux if we read from a file. + std::vector<Event> events = + ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); + ASSERT_THAT(events, Are({})); + + EXPECT_THAT(splice(pipes[0], nullptr, fd.get(), nullptr, 1, /*flags=*/0), + SyscallSucceedsWithValue(1)); + + events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); + ASSERT_THAT(events, Are({ + Event(IN_MODIFY, dir_wd, Basename(file.path())), + Event(IN_MODIFY, file_wd), + })); +} + +TEST(Inotify, SpliceOnInotifyFD) { + int pipes[2]; + ASSERT_THAT(pipe2(pipes, O_NONBLOCK), SyscallSucceeds()); + + const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); + const TempPath file1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + root.path(), "some content", TempPath::kDefaultFileMode)); + + const FileDescriptor file1_fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDONLY)); + const int watcher = ASSERT_NO_ERRNO_AND_VALUE( + InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS)); + + char buf; + EXPECT_THAT(read(file1_fd.get(), &buf, 1), SyscallSucceeds()); + + EXPECT_THAT(splice(fd.get(), nullptr, pipes[1], nullptr, + sizeof(struct inotify_event) + 1, SPLICE_F_NONBLOCK), + SyscallSucceedsWithValue(sizeof(struct inotify_event))); + + const FileDescriptor read_fd(pipes[0]); + const std::vector<Event> events = + ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(read_fd.get())); + ASSERT_THAT(events, Are({Event(IN_ACCESS, watcher)})); +} + +// Watches on a parent should not be triggered by actions on a hard link to one +// of its children that has a different parent. +TEST(Inotify, LinkOnOtherParent) { + const TempPath dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + const TempPath dir2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + const TempPath file = + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir1.path())); + std::string link_path = NewTempAbsPathInDir(dir2.path()); + + const int rc = link(file.path().c_str(), link_path.c_str()); + // NOTE(b/34861058): link(2) is only supported on tmpfs in the sandbox. + SKIP_IF(IsRunningOnGvisor() && rc != 0 && + (errno == EPERM || errno == ENOENT)); + ASSERT_THAT(rc, SyscallSucceeds()); + + const FileDescriptor inotify_fd = + ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); + ASSERT_NO_ERRNO_AND_VALUE( + InotifyAddWatch(inotify_fd.get(), dir1.path(), IN_ALL_EVENTS)); + + // Perform various actions on the link outside of dir1, which should trigger + // no inotify events. + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(link_path.c_str(), O_RDWR)); + int val = 0; + ASSERT_THAT(write(fd.get(), &val, sizeof(val)), SyscallSucceeds()); + ASSERT_THAT(read(fd.get(), &val, sizeof(val)), SyscallSucceeds()); + ASSERT_THAT(ftruncate(fd.get(), 12345), SyscallSucceeds()); + ASSERT_THAT(unlink(link_path.c_str()), SyscallSucceeds()); + const std::vector<Event> events = + ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); + EXPECT_THAT(events, Are({})); +} + +TEST(Inotify, Xattr) { + // TODO(gvisor.dev/issue/1636): Support extended attributes in runsc gofer. + SKIP_IF(IsRunningOnGvisor()); + + const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const std::string path = file.path(); + const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_RDWR)); + const FileDescriptor inotify_fd = + ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); + const int wd = ASSERT_NO_ERRNO_AND_VALUE( + InotifyAddWatch(inotify_fd.get(), path, IN_ALL_EVENTS)); + + const char* cpath = path.c_str(); + const char* name = "user.test"; + int val = 123; + ASSERT_THAT(setxattr(cpath, name, &val, sizeof(val), /*flags=*/0), + SyscallSucceeds()); + std::vector<Event> events = + ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); + EXPECT_THAT(events, Are({Event(IN_ATTRIB, wd)})); + + ASSERT_THAT(getxattr(cpath, name, &val, sizeof(val)), SyscallSucceeds()); + events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); + EXPECT_THAT(events, Are({})); + + char list[100]; + ASSERT_THAT(listxattr(cpath, list, sizeof(list)), SyscallSucceeds()); + events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); + EXPECT_THAT(events, Are({})); + + ASSERT_THAT(removexattr(cpath, name), SyscallSucceeds()); + events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); + EXPECT_THAT(events, Are({Event(IN_ATTRIB, wd)})); + + ASSERT_THAT(fsetxattr(fd.get(), name, &val, sizeof(val), /*flags=*/0), + SyscallSucceeds()); + events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); + EXPECT_THAT(events, Are({Event(IN_ATTRIB, wd)})); + + ASSERT_THAT(fgetxattr(fd.get(), name, &val, sizeof(val)), SyscallSucceeds()); + events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); + EXPECT_THAT(events, Are({})); + + ASSERT_THAT(flistxattr(fd.get(), list, sizeof(list)), SyscallSucceeds()); + events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); + EXPECT_THAT(events, Are({})); + + ASSERT_THAT(fremovexattr(fd.get(), name), SyscallSucceeds()); + events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); + EXPECT_THAT(events, Are({Event(IN_ATTRIB, wd)})); +} + +TEST(Inotify, Exec) { + const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + const TempPath bin = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateSymlinkTo(dir.path(), "/bin/true")); + + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); + const int wd = ASSERT_NO_ERRNO_AND_VALUE( + InotifyAddWatch(fd.get(), bin.path(), IN_ALL_EVENTS)); + + // Perform exec. + ScopedThread t([&bin]() { + ASSERT_THAT(execl(bin.path().c_str(), bin.path().c_str(), (char*)nullptr), + SyscallSucceeds()); + }); + t.Join(); + + std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); + EXPECT_THAT(events, Are({Event(IN_OPEN, wd), Event(IN_ACCESS, wd)})); +} + +// Watches without IN_EXCL_UNLINK, should continue to emit events for file +// descriptors after their corresponding files have been unlinked. +// +// We need to disable S/R because there are filesystems where we cannot re-open +// fds to an unlinked file across S/R, e.g. gofer-backed filesytems. +TEST(Inotify, IncludeUnlinkedFile_NoRandomSave) { + const DisableSave ds; + + const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + const TempPath file = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateFileWith(dir.path(), "123", TempPath::kDefaultFileMode)); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR)); + + const FileDescriptor inotify_fd = + ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); + const int dir_wd = ASSERT_NO_ERRNO_AND_VALUE( + InotifyAddWatch(inotify_fd.get(), dir.path(), IN_ALL_EVENTS)); + const int file_wd = ASSERT_NO_ERRNO_AND_VALUE( + InotifyAddWatch(inotify_fd.get(), file.path(), IN_ALL_EVENTS)); + + ASSERT_THAT(unlink(file.path().c_str()), SyscallSucceeds()); + int val = 0; + ASSERT_THAT(read(fd.get(), &val, sizeof(val)), SyscallSucceeds()); + ASSERT_THAT(write(fd.get(), &val, sizeof(val)), SyscallSucceeds()); + std::vector<Event> events = + ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); + EXPECT_THAT(events, Are({ + Event(IN_ATTRIB, file_wd), + Event(IN_DELETE, dir_wd, Basename(file.path())), + Event(IN_ACCESS, dir_wd, Basename(file.path())), + Event(IN_ACCESS, file_wd), + Event(IN_MODIFY, dir_wd, Basename(file.path())), + Event(IN_MODIFY, file_wd), + })); + + fd.reset(); + events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); + EXPECT_THAT(events, Are({ + Event(IN_CLOSE_WRITE, dir_wd, Basename(file.path())), + Event(IN_CLOSE_WRITE, file_wd), + Event(IN_DELETE_SELF, file_wd), + Event(IN_IGNORED, file_wd), + })); +} + +// Watches created with IN_EXCL_UNLINK will stop emitting events on fds for +// children that have already been unlinked. +// +// We need to disable S/R because there are filesystems where we cannot re-open +// fds to an unlinked file across S/R, e.g. gofer-backed filesytems. +TEST(Inotify, ExcludeUnlink_NoRandomSave) { + const DisableSave ds; + // TODO(gvisor.dev/issue/1624): This test fails on VFS1. + SKIP_IF(IsRunningWithVFS1()); + + const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + const TempPath file = + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path())); + + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR)); + + const FileDescriptor inotify_fd = + ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); + const int dir_wd = ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch( + inotify_fd.get(), dir.path(), IN_ALL_EVENTS | IN_EXCL_UNLINK)); + const int file_wd = ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch( + inotify_fd.get(), file.path(), IN_ALL_EVENTS | IN_EXCL_UNLINK)); + + // Unlink the child, which should cause further operations on the open file + // descriptor to be ignored. + ASSERT_THAT(unlink(file.path().c_str()), SyscallSucceeds()); + int val = 0; + ASSERT_THAT(write(fd.get(), &val, sizeof(val)), SyscallSucceeds()); + ASSERT_THAT(read(fd.get(), &val, sizeof(val)), SyscallSucceeds()); + std::vector<Event> events = + ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); + EXPECT_THAT(events, Are({ + Event(IN_ATTRIB, file_wd), + Event(IN_DELETE, dir_wd, Basename(file.path())), + })); + + fd.reset(); + events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); + ASSERT_THAT(events, Are({ + Event(IN_DELETE_SELF, file_wd), + Event(IN_IGNORED, file_wd), + })); +} + +// We need to disable S/R because there are filesystems where we cannot re-open +// fds to an unlinked file across S/R, e.g. gofer-backed filesytems. +TEST(Inotify, ExcludeUnlinkDirectory_NoRandomSave) { + // TODO(gvisor.dev/issue/1624): This test fails on VFS1. Remove once VFS1 is + // deleted. + SKIP_IF(IsRunningWithVFS1()); + + const DisableSave ds; + + const TempPath parent = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + TempPath dir = + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(parent.path())); + std::string dirPath = dir.path(); + const FileDescriptor inotify_fd = + ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(dirPath.c_str(), O_RDONLY | O_DIRECTORY)); + const int parent_wd = ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch( + inotify_fd.get(), parent.path(), IN_ALL_EVENTS | IN_EXCL_UNLINK)); + const int self_wd = ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch( + inotify_fd.get(), dir.path(), IN_ALL_EVENTS | IN_EXCL_UNLINK)); + + // Unlink the dir, and then close the open fd. + ASSERT_THAT(rmdir(dirPath.c_str()), SyscallSucceeds()); + dir.reset(); + + std::vector<Event> events = + ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); + // No close event should appear. + ASSERT_THAT(events, + Are({Event(IN_DELETE | IN_ISDIR, parent_wd, Basename(dirPath))})); + + fd.reset(); + events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); + ASSERT_THAT(events, Are({ + Event(IN_DELETE_SELF, self_wd), + Event(IN_IGNORED, self_wd), + })); +} + +// If "dir/child" and "dir/child2" are links to the same file, and "dir/child" +// is unlinked, a watch on "dir" with IN_EXCL_UNLINK will exclude future events +// for fds on "dir/child" but not "dir/child2". +// +// We need to disable S/R because there are filesystems where we cannot re-open +// fds to an unlinked file across S/R, e.g. gofer-backed filesytems. +TEST(Inotify, ExcludeUnlinkMultipleChildren_NoRandomSave) { + const DisableSave ds; + // TODO(gvisor.dev/issue/1624): This test fails on VFS1. + SKIP_IF(IsRunningWithVFS1()); + + const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + const TempPath file = + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path())); + std::string path1 = file.path(); + std::string path2 = NewTempAbsPathInDir(dir.path()); + + const int rc = link(path1.c_str(), path2.c_str()); + // NOTE(b/34861058): link(2) is only supported on tmpfs in the sandbox. + SKIP_IF(IsRunningOnGvisor() && rc != 0 && + (errno == EPERM || errno == ENOENT)); + ASSERT_THAT(rc, SyscallSucceeds()); + const FileDescriptor fd1 = + ASSERT_NO_ERRNO_AND_VALUE(Open(path1.c_str(), O_RDWR)); + const FileDescriptor fd2 = + ASSERT_NO_ERRNO_AND_VALUE(Open(path2.c_str(), O_RDWR)); + + const FileDescriptor inotify_fd = + ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); + const int wd = ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch( + inotify_fd.get(), dir.path(), IN_ALL_EVENTS | IN_EXCL_UNLINK)); + + // After unlinking path1, only events on the fd for path2 should be generated. + ASSERT_THAT(unlink(path1.c_str()), SyscallSucceeds()); + ASSERT_THAT(write(fd1.get(), "x", 1), SyscallSucceeds()); + ASSERT_THAT(write(fd2.get(), "x", 1), SyscallSucceeds()); + + const std::vector<Event> events = + ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); + EXPECT_THAT(events, Are({ + Event(IN_DELETE, wd, Basename(path1)), + Event(IN_MODIFY, wd, Basename(path2)), + })); +} + +// On native Linux, actions of data type FSNOTIFY_EVENT_INODE are not affected +// by IN_EXCL_UNLINK (see +// fs/notify/inotify/inotify_fsnotify.c:inotify_handle_event). Inode-level +// events include changes to metadata and extended attributes. +// +// We need to disable S/R because there are filesystems where we cannot re-open +// fds to an unlinked file across S/R, e.g. gofer-backed filesytems. +TEST(Inotify, ExcludeUnlinkInodeEvents_NoRandomSave) { + const DisableSave ds; + + const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + const TempPath file = + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path())); + + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path().c_str(), O_RDWR)); + + // NOTE(b/157163751): Create another link before unlinking. This is needed for + // the gofer filesystem in gVisor, where open fds will not work once the link + // count hits zero. In VFS2, we end up skipping the gofer test anyway, because + // hard links are not supported for gofer fs. + if (IsRunningOnGvisor()) { + std::string link_path = NewTempAbsPath(); + const int rc = link(file.path().c_str(), link_path.c_str()); + // NOTE(b/34861058): link(2) is only supported on tmpfs in the sandbox. + SKIP_IF(rc != 0 && (errno == EPERM || errno == ENOENT)); + ASSERT_THAT(rc, SyscallSucceeds()); + } + + const FileDescriptor inotify_fd = + ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); + const int dir_wd = ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch( + inotify_fd.get(), dir.path(), IN_ALL_EVENTS | IN_EXCL_UNLINK)); + const int file_wd = ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch( + inotify_fd.get(), file.path(), IN_ALL_EVENTS | IN_EXCL_UNLINK)); + + // Even after unlinking, inode-level operations will trigger events regardless + // of IN_EXCL_UNLINK. + ASSERT_THAT(unlink(file.path().c_str()), SyscallSucceeds()); + + // Perform various actions on fd. + ASSERT_THAT(ftruncate(fd.get(), 12345), SyscallSucceeds()); + std::vector<Event> events = + ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); + EXPECT_THAT(events, Are({ + Event(IN_ATTRIB, file_wd), + Event(IN_DELETE, dir_wd, Basename(file.path())), + Event(IN_MODIFY, dir_wd, Basename(file.path())), + Event(IN_MODIFY, file_wd), + })); + + const struct timeval times[2] = {{1, 0}, {2, 0}}; + ASSERT_THAT(futimes(fd.get(), times), SyscallSucceeds()); + events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); + EXPECT_THAT(events, Are({ + Event(IN_ATTRIB, dir_wd, Basename(file.path())), + Event(IN_ATTRIB, file_wd), + })); + + // S/R is disabled on this entire test due to behavior with unlink; it must + // also be disabled after this point because of fchmod. + ASSERT_THAT(fchmod(fd.get(), 0777), SyscallSucceeds()); + events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); + EXPECT_THAT(events, Are({ + Event(IN_ATTRIB, dir_wd, Basename(file.path())), + Event(IN_ATTRIB, file_wd), + })); +} + +TEST(Inotify, OneShot) { + // TODO(gvisor.dev/issue/1624): IN_ONESHOT not supported in VFS1. + SKIP_IF(IsRunningWithVFS1()); + + const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor inotify_fd = + ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); + + const int wd = ASSERT_NO_ERRNO_AND_VALUE( + InotifyAddWatch(inotify_fd.get(), file.path(), IN_MODIFY | IN_ONESHOT)); + + // Open an fd, write to it, and then close it. + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_WRONLY)); + ASSERT_THAT(write(fd.get(), "x", 1), SyscallSucceedsWithValue(1)); + fd.reset(); + + // We should get a single event followed by IN_IGNORED indicating removal + // of the one-shot watch. Prior activity (i.e. open) that is not in the mask + // should not trigger removal, and activity after removal (i.e. close) should + // not generate events. + std::vector<Event> events = + ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); + EXPECT_THAT(events, Are({ + Event(IN_MODIFY, wd), + Event(IN_IGNORED, wd), + })); + + // The watch should already have been removed. + EXPECT_THAT(inotify_rm_watch(inotify_fd.get(), wd), + SyscallFailsWithErrno(EINVAL)); +} + +// This test helps verify that the lock order of filesystem and inotify locks +// is respected when inotify instances and watch targets are concurrently being +// destroyed. +TEST(InotifyTest, InotifyAndTargetDestructionDoNotDeadlock_NoRandomSave) { + const DisableSave ds; // Too many syscalls. + + // A file descriptor protected by a mutex. This ensures that while a + // descriptor is in use, it cannot be closed and reused for a different file + // description. + struct atomic_fd { + int fd; + absl::Mutex mu; + }; + + // Set up initial inotify instances. + constexpr int num_fds = 3; + std::vector<atomic_fd> fds(num_fds); + for (int i = 0; i < num_fds; i++) { + int fd; + ASSERT_THAT(fd = inotify_init1(IN_NONBLOCK), SyscallSucceeds()); + fds[i].fd = fd; + } + + // Set up initial watch targets. + std::vector<std::string> paths; + for (int i = 0; i < 3; i++) { + paths.push_back(NewTempAbsPath()); + ASSERT_THAT(mknod(paths[i].c_str(), S_IFREG | 0600, 0), SyscallSucceeds()); + } + + constexpr absl::Duration runtime = absl::Seconds(4); + + // Constantly replace each inotify instance with a new one. + auto replace_fds = [&] { + for (auto start = absl::Now(); absl::Now() - start < runtime;) { + for (auto& afd : fds) { + int new_fd; + ASSERT_THAT(new_fd = inotify_init1(IN_NONBLOCK), SyscallSucceeds()); + absl::MutexLock l(&afd.mu); + ASSERT_THAT(close(afd.fd), SyscallSucceeds()); + afd.fd = new_fd; + for (auto& p : paths) { + // inotify_add_watch may fail if the file at p was deleted. + ASSERT_THAT(inotify_add_watch(afd.fd, p.c_str(), IN_ALL_EVENTS), + AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(ENOENT))); + } + } + sched_yield(); + } + }; + + std::list<ScopedThread> ts; + for (int i = 0; i < 3; i++) { + ts.emplace_back(replace_fds); + } + + // Constantly replace each watch target with a new one. + for (auto start = absl::Now(); absl::Now() - start < runtime;) { + for (auto& p : paths) { + ASSERT_THAT(unlink(p.c_str()), SyscallSucceeds()); + ASSERT_THAT(mknod(p.c_str(), S_IFREG | 0600, 0), SyscallSucceeds()); + } + sched_yield(); + } +} + +// This test helps verify that the lock order of filesystem and inotify locks +// is respected when adding/removing watches occurs concurrently with the +// removal of their targets. +TEST(InotifyTest, AddRemoveUnlinkDoNotDeadlock_NoRandomSave) { + const DisableSave ds; // Too many syscalls. + + // Set up inotify instances. + constexpr int num_fds = 3; + std::vector<int> fds(num_fds); + for (int i = 0; i < num_fds; i++) { + ASSERT_THAT(fds[i] = inotify_init1(IN_NONBLOCK), SyscallSucceeds()); + } + + // Set up initial watch targets. + std::vector<std::string> paths; + for (int i = 0; i < 3; i++) { + paths.push_back(NewTempAbsPath()); + ASSERT_THAT(mknod(paths[i].c_str(), S_IFREG | 0600, 0), SyscallSucceeds()); + } + + constexpr absl::Duration runtime = absl::Seconds(1); + + // Constantly add/remove watches for each inotify instance/watch target pair. + auto add_remove_watches = [&] { + for (auto start = absl::Now(); absl::Now() - start < runtime;) { + for (int fd : fds) { + for (auto& p : paths) { + // Do not assert on inotify_add_watch and inotify_rm_watch. They may + // fail if the file at p was deleted. inotify_add_watch may also fail + // if another thread beat us to adding a watch. + const int wd = inotify_add_watch(fd, p.c_str(), IN_ALL_EVENTS); + if (wd > 0) { + inotify_rm_watch(fd, wd); + } + } + } + sched_yield(); + } + }; + + std::list<ScopedThread> ts; + for (int i = 0; i < 15; i++) { + ts.emplace_back(add_remove_watches); + } + + // Constantly replace each watch target with a new one. + for (auto start = absl::Now(); absl::Now() - start < runtime;) { + for (auto& p : paths) { + ASSERT_THAT(unlink(p.c_str()), SyscallSucceeds()); + ASSERT_THAT(mknod(p.c_str(), S_IFREG | 0600, 0), SyscallSucceeds()); + } + sched_yield(); + } +} + +// This test helps verify that the lock order of filesystem and inotify locks +// is respected when many inotify events and filesystem operations occur +// simultaneously. +TEST(InotifyTest, NotifyNoDeadlock_NoRandomSave) { + const DisableSave ds; // Too many syscalls. + + const TempPath parent = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + const std::string dir = parent.path(); + + // mu protects file, which will change on rename. + absl::Mutex mu; + std::string file = NewTempAbsPathInDir(dir); + ASSERT_THAT(mknod(file.c_str(), 0644 | S_IFREG, 0), SyscallSucceeds()); + + const absl::Duration runtime = absl::Milliseconds(300); + + // Add/remove watches on dir and file. + ScopedThread add_remove_watches([&] { + const FileDescriptor ifd = + ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); + int dir_wd = ASSERT_NO_ERRNO_AND_VALUE( + InotifyAddWatch(ifd.get(), dir, IN_ALL_EVENTS)); + int file_wd; + { + absl::ReaderMutexLock l(&mu); + file_wd = ASSERT_NO_ERRNO_AND_VALUE( + InotifyAddWatch(ifd.get(), file, IN_ALL_EVENTS)); + } + for (auto start = absl::Now(); absl::Now() - start < runtime;) { + ASSERT_THAT(inotify_rm_watch(ifd.get(), file_wd), SyscallSucceeds()); + ASSERT_THAT(inotify_rm_watch(ifd.get(), dir_wd), SyscallSucceeds()); + dir_wd = ASSERT_NO_ERRNO_AND_VALUE( + InotifyAddWatch(ifd.get(), dir, IN_ALL_EVENTS)); + { + absl::ReaderMutexLock l(&mu); + file_wd = ASSERT_NO_ERRNO_AND_VALUE( + InotifyAddWatch(ifd.get(), file, IN_ALL_EVENTS)); + } + sched_yield(); + } + }); + + // Modify attributes on dir and file. + ScopedThread stats([&] { + int fd, dir_fd; + { + absl::ReaderMutexLock l(&mu); + ASSERT_THAT(fd = open(file.c_str(), O_RDONLY), SyscallSucceeds()); + } + ASSERT_THAT(dir_fd = open(dir.c_str(), O_RDONLY | O_DIRECTORY), + SyscallSucceeds()); + const struct timeval times[2] = {{1, 0}, {2, 0}}; + + for (auto start = absl::Now(); absl::Now() - start < runtime;) { + { + absl::ReaderMutexLock l(&mu); + EXPECT_THAT(utimes(file.c_str(), times), SyscallSucceeds()); + } + EXPECT_THAT(futimes(fd, times), SyscallSucceeds()); + EXPECT_THAT(utimes(dir.c_str(), times), SyscallSucceeds()); + EXPECT_THAT(futimes(dir_fd, times), SyscallSucceeds()); + sched_yield(); + } + }); + + // Modify extended attributes on dir and file. + ScopedThread xattrs([&] { + // TODO(gvisor.dev/issue/1636): Support extended attributes in runsc gofer. + if (!IsRunningOnGvisor()) { + int fd; + { + absl::ReaderMutexLock l(&mu); + ASSERT_THAT(fd = open(file.c_str(), O_RDONLY), SyscallSucceeds()); + } + + const char* name = "user.test"; + int val = 123; + for (auto start = absl::Now(); absl::Now() - start < runtime;) { + { + absl::ReaderMutexLock l(&mu); + ASSERT_THAT( + setxattr(file.c_str(), name, &val, sizeof(val), /*flags=*/0), + SyscallSucceeds()); + ASSERT_THAT(removexattr(file.c_str(), name), SyscallSucceeds()); + } + + ASSERT_THAT(fsetxattr(fd, name, &val, sizeof(val), /*flags=*/0), + SyscallSucceeds()); + ASSERT_THAT(fremovexattr(fd, name), SyscallSucceeds()); + sched_yield(); + } + } + }); + + // Read and write file's contents. Read and write dir's entries. + ScopedThread read_write([&] { + int fd; + { + absl::ReaderMutexLock l(&mu); + ASSERT_THAT(fd = open(file.c_str(), O_RDWR), SyscallSucceeds()); + } + for (auto start = absl::Now(); absl::Now() - start < runtime;) { + int val = 123; + ASSERT_THAT(write(fd, &val, sizeof(val)), SyscallSucceeds()); + ASSERT_THAT(read(fd, &val, sizeof(val)), SyscallSucceeds()); + TempPath new_file = + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir)); + ASSERT_NO_ERRNO(ListDir(dir, false)); + new_file.reset(); + sched_yield(); + } + }); + + // Rename file. + for (auto start = absl::Now(); absl::Now() - start < runtime;) { + const std::string new_path = NewTempAbsPathInDir(dir); + { + absl::WriterMutexLock l(&mu); + ASSERT_THAT(rename(file.c_str(), new_path.c_str()), SyscallSucceeds()); + file = new_path; + } + sched_yield(); + } +} + } // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/ioctl.cc b/test/syscalls/linux/ioctl.cc index c4f8bff08..b0a07a064 100644 --- a/test/syscalls/linux/ioctl.cc +++ b/test/syscalls/linux/ioctl.cc @@ -215,7 +215,8 @@ TEST_F(IoctlTest, FIOASYNCSelfTarget2) { auto mask_cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, SIGIO)); - pid_t pid = getpid(); + pid_t pid = -1; + EXPECT_THAT(pid = getpid(), SyscallSucceeds()); EXPECT_THAT(ioctl(pair->second_fd(), FIOSETOWN, &pid), SyscallSucceeds()); int set = 1; diff --git a/test/syscalls/linux/ip_socket_test_util.cc b/test/syscalls/linux/ip_socket_test_util.cc index 57e99596f..98d07ae85 100644 --- a/test/syscalls/linux/ip_socket_test_util.cc +++ b/test/syscalls/linux/ip_socket_test_util.cc @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "test/syscalls/linux/ip_socket_test_util.h" + #include <net/if.h> #include <netinet/in.h> -#include <sys/ioctl.h> #include <sys/socket.h> -#include <cstring> -#include "test/syscalls/linux/ip_socket_test_util.h" +#include <cstring> namespace gvisor { namespace testing { @@ -34,12 +34,11 @@ uint16_t PortFromInetSockaddr(const struct sockaddr* addr) { } PosixErrorOr<int> InterfaceIndex(std::string name) { - // TODO(igudger): Consider using netlink. - ifreq req = {}; - memcpy(req.ifr_name, name.c_str(), name.size()); - ASSIGN_OR_RETURN_ERRNO(auto sock, Socket(AF_INET, SOCK_DGRAM, 0)); - RETURN_ERROR_IF_SYSCALL_FAIL(ioctl(sock.get(), SIOCGIFINDEX, &req)); - return req.ifr_ifindex; + int index = if_nametoindex(name.c_str()); + if (index) { + return index; + } + return PosixError(errno); } namespace { @@ -78,6 +77,33 @@ SocketPairKind DualStackTCPAcceptBindSocketPair(int type) { /* dual_stack = */ true)}; } +SocketPairKind IPv6TCPAcceptBindPersistentListenerSocketPair(int type) { + std::string description = + absl::StrCat(DescribeSocketType(type), "connected IPv6 TCP socket"); + return SocketPairKind{description, AF_INET6, type | SOCK_STREAM, IPPROTO_TCP, + TCPAcceptBindPersistentListenerSocketPairCreator( + AF_INET6, type | SOCK_STREAM, 0, + /* dual_stack = */ false)}; +} + +SocketPairKind IPv4TCPAcceptBindPersistentListenerSocketPair(int type) { + std::string description = + absl::StrCat(DescribeSocketType(type), "connected IPv4 TCP socket"); + return SocketPairKind{description, AF_INET, type | SOCK_STREAM, IPPROTO_TCP, + TCPAcceptBindPersistentListenerSocketPairCreator( + AF_INET, type | SOCK_STREAM, 0, + /* dual_stack = */ false)}; +} + +SocketPairKind DualStackTCPAcceptBindPersistentListenerSocketPair(int type) { + std::string description = + absl::StrCat(DescribeSocketType(type), "connected dual stack TCP socket"); + return SocketPairKind{description, AF_INET6, type | SOCK_STREAM, IPPROTO_TCP, + TCPAcceptBindPersistentListenerSocketPairCreator( + AF_INET6, type | SOCK_STREAM, 0, + /* dual_stack = */ true)}; +} + SocketPairKind IPv6UDPBidirectionalBindSocketPair(int type) { std::string description = absl::StrCat(DescribeSocketType(type), "connected IPv6 UDP socket"); @@ -149,17 +175,17 @@ SocketKind IPv6TCPUnboundSocket(int type) { PosixError IfAddrHelper::Load() { Release(); RETURN_ERROR_IF_SYSCALL_FAIL(getifaddrs(&ifaddr_)); - return PosixError(0); + return NoError(); } void IfAddrHelper::Release() { if (ifaddr_) { freeifaddrs(ifaddr_); + ifaddr_ = nullptr; } - ifaddr_ = nullptr; } -std::vector<std::string> IfAddrHelper::InterfaceList(int family) { +std::vector<std::string> IfAddrHelper::InterfaceList(int family) const { std::vector<std::string> names; for (auto ifa = ifaddr_; ifa != NULL; ifa = ifa->ifa_next) { if (ifa->ifa_addr == NULL || ifa->ifa_addr->sa_family != family) { @@ -170,7 +196,7 @@ std::vector<std::string> IfAddrHelper::InterfaceList(int family) { return names; } -sockaddr* IfAddrHelper::GetAddr(int family, std::string name) { +const sockaddr* IfAddrHelper::GetAddr(int family, std::string name) const { for (auto ifa = ifaddr_; ifa != NULL; ifa = ifa->ifa_next) { if (ifa->ifa_addr == NULL || ifa->ifa_addr->sa_family != family) { continue; @@ -182,28 +208,28 @@ sockaddr* IfAddrHelper::GetAddr(int family, std::string name) { return nullptr; } -PosixErrorOr<int> IfAddrHelper::GetIndex(std::string name) { +PosixErrorOr<int> IfAddrHelper::GetIndex(std::string name) const { return InterfaceIndex(name); } -std::string GetAddr4Str(in_addr* a) { +std::string GetAddr4Str(const in_addr* a) { char str[INET_ADDRSTRLEN]; inet_ntop(AF_INET, a, str, sizeof(str)); return std::string(str); } -std::string GetAddr6Str(in6_addr* a) { +std::string GetAddr6Str(const in6_addr* a) { char str[INET6_ADDRSTRLEN]; inet_ntop(AF_INET6, a, str, sizeof(str)); return std::string(str); } -std::string GetAddrStr(sockaddr* a) { +std::string GetAddrStr(const sockaddr* a) { if (a->sa_family == AF_INET) { - auto src = &(reinterpret_cast<sockaddr_in*>(a)->sin_addr); + auto src = &(reinterpret_cast<const sockaddr_in*>(a)->sin_addr); return GetAddr4Str(src); } else if (a->sa_family == AF_INET6) { - auto src = &(reinterpret_cast<sockaddr_in6*>(a)->sin6_addr); + auto src = &(reinterpret_cast<const sockaddr_in6*>(a)->sin6_addr); return GetAddr6Str(src); } return std::string("<invalid>"); diff --git a/test/syscalls/linux/ip_socket_test_util.h b/test/syscalls/linux/ip_socket_test_util.h index 072230d85..9c3859fcd 100644 --- a/test/syscalls/linux/ip_socket_test_util.h +++ b/test/syscalls/linux/ip_socket_test_util.h @@ -26,25 +26,6 @@ namespace gvisor { namespace testing { -// Possible values of the "st" field in a /proc/net/{tcp,udp} entry. Source: -// Linux kernel, include/net/tcp_states.h. -enum { - TCP_ESTABLISHED = 1, - TCP_SYN_SENT, - TCP_SYN_RECV, - TCP_FIN_WAIT1, - TCP_FIN_WAIT2, - TCP_TIME_WAIT, - TCP_CLOSE, - TCP_CLOSE_WAIT, - TCP_LAST_ACK, - TCP_LISTEN, - TCP_CLOSING, - TCP_NEW_SYN_RECV, - - TCP_MAX_STATES -}; - // Extracts the IP address from an inet sockaddr in network byte order. uint32_t IPFromInetSockaddr(const struct sockaddr* addr); @@ -69,6 +50,21 @@ SocketPairKind IPv4TCPAcceptBindSocketPair(int type); // given type bound to the IPv4 loopback. SocketPairKind DualStackTCPAcceptBindSocketPair(int type); +// IPv6TCPAcceptBindPersistentListenerSocketPair is like +// IPv6TCPAcceptBindSocketPair except it uses a persistent listening socket to +// create all socket pairs. +SocketPairKind IPv6TCPAcceptBindPersistentListenerSocketPair(int type); + +// IPv4TCPAcceptBindPersistentListenerSocketPair is like +// IPv4TCPAcceptBindSocketPair except it uses a persistent listening socket to +// create all socket pairs. +SocketPairKind IPv4TCPAcceptBindPersistentListenerSocketPair(int type); + +// DualStackTCPAcceptBindPersistentListenerSocketPair is like +// DualStackTCPAcceptBindSocketPair except it uses a persistent listening socket +// to create all socket pairs. +SocketPairKind DualStackTCPAcceptBindPersistentListenerSocketPair(int type); + // IPv6UDPBidirectionalBindSocketPair returns a SocketPairKind that represents // SocketPairs created with bind() and connect() syscalls with AF_INET6 and the // given type bound to the IPv6 loopback. @@ -88,20 +84,20 @@ SocketPairKind DualStackUDPBidirectionalBindSocketPair(int type); // SocketPairs created with AF_INET and the given type. SocketPairKind IPv4UDPUnboundSocketPair(int type); -// IPv4UDPUnboundSocketPair returns a SocketKind that represents -// a SimpleSocket created with AF_INET, SOCK_DGRAM, and the given type. +// IPv4UDPUnboundSocket returns a SocketKind that represents a SimpleSocket +// created with AF_INET, SOCK_DGRAM, and the given type. SocketKind IPv4UDPUnboundSocket(int type); -// IPv6UDPUnboundSocketPair returns a SocketKind that represents -// a SimpleSocket created with AF_INET6, SOCK_DGRAM, and the given type. +// IPv6UDPUnboundSocket returns a SocketKind that represents a SimpleSocket +// created with AF_INET6, SOCK_DGRAM, and the given type. SocketKind IPv6UDPUnboundSocket(int type); -// IPv4TCPUnboundSocketPair returns a SocketKind that represents -// a SimpleSocket created with AF_INET, SOCK_STREAM and the given type. +// IPv4TCPUnboundSocket returns a SocketKind that represents a SimpleSocket +// created with AF_INET, SOCK_STREAM and the given type. SocketKind IPv4TCPUnboundSocket(int type); -// IPv6TCPUnboundSocketPair returns a SocketKind that represents -// a SimpleSocket created with AF_INET6, SOCK_STREAM and the given type. +// IPv6TCPUnboundSocket returns a SocketKind that represents a SimpleSocket +// created with AF_INET6, SOCK_STREAM and the given type. SocketKind IPv6TCPUnboundSocket(int type); // IfAddrHelper is a helper class that determines the local interfaces present @@ -114,24 +110,24 @@ class IfAddrHelper { PosixError Load(); void Release(); - std::vector<std::string> InterfaceList(int family); + std::vector<std::string> InterfaceList(int family) const; - struct sockaddr* GetAddr(int family, std::string name); - PosixErrorOr<int> GetIndex(std::string name); + const sockaddr* GetAddr(int family, std::string name) const; + PosixErrorOr<int> GetIndex(std::string name) const; private: struct ifaddrs* ifaddr_; }; // GetAddr4Str returns the given IPv4 network address structure as a string. -std::string GetAddr4Str(in_addr* a); +std::string GetAddr4Str(const in_addr* a); // GetAddr6Str returns the given IPv6 network address structure as a string. -std::string GetAddr6Str(in6_addr* a); +std::string GetAddr6Str(const in6_addr* a); // GetAddrStr returns the given IPv4 or IPv6 network address structure as a // string. -std::string GetAddrStr(sockaddr* a); +std::string GetAddrStr(const sockaddr* a); } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/iptables.h b/test/syscalls/linux/iptables.h index 616bea550..0719c60a4 100644 --- a/test/syscalls/linux/iptables.h +++ b/test/syscalls/linux/iptables.h @@ -188,7 +188,7 @@ struct ipt_replace { unsigned int num_counters; // The unchanged values from each ipt_entry's counters. - struct xt_counters *counters; + struct xt_counters* counters; // The entries to write to the table. This will run past the size defined by // sizeof(srtuct ipt_replace); diff --git a/test/syscalls/linux/itimer.cc b/test/syscalls/linux/itimer.cc index 930d2b940..e397d5f57 100644 --- a/test/syscalls/linux/itimer.cc +++ b/test/syscalls/linux/itimer.cc @@ -246,7 +246,7 @@ int TestSIGPROFFairness(absl::Duration sleep) { // The number of samples on the main thread should be very low as it did // nothing. - TEST_CHECK(result.main_thread_samples < 60); + TEST_CHECK(result.main_thread_samples < 80); // Both workers should get roughly equal number of samples. TEST_CHECK(result.worker_samples.size() == 2); @@ -267,6 +267,20 @@ int TestSIGPROFFairness(absl::Duration sleep) { // Random save/restore is disabled as it introduces additional latency and // unpredictable distribution patterns. TEST(ItimerTest, DeliversSIGPROFToThreadsRoughlyFairlyActive_NoRandomSave) { + // On the KVM and ptrace platforms, switches between sentry and application + // context are sometimes extremely slow, causing the itimer to send SIGPROF to + // a thread that either already has one pending or has had SIGPROF delivered, + // but hasn't handled it yet (and thus therefore still has SIGPROF masked). In + // either case, since itimer signals are group-directed, signal sending falls + // back to notifying the thread group leader. ItimerSignalTest() fails if "too + // many" signals are delivered to the thread group leader, so these tests are + // flaky on these platforms. + // + // TODO(b/143247272): Clarify why context switches are so slow on KVM. + const auto gvisor_platform = GvisorPlatform(); + SKIP_IF(gvisor_platform == Platform::kKVM || + gvisor_platform == Platform::kPtrace); + pid_t child; int execve_errno; auto kill = ASSERT_NO_ERRNO_AND_VALUE( @@ -288,6 +302,11 @@ TEST(ItimerTest, DeliversSIGPROFToThreadsRoughlyFairlyActive_NoRandomSave) { // Random save/restore is disabled as it introduces additional latency and // unpredictable distribution patterns. TEST(ItimerTest, DeliversSIGPROFToThreadsRoughlyFairlyIdle_NoRandomSave) { + // See comment in DeliversSIGPROFToThreadsRoughlyFairlyActive. + const auto gvisor_platform = GvisorPlatform(); + SKIP_IF(gvisor_platform == Platform::kKVM || + gvisor_platform == Platform::kPtrace); + pid_t child; int execve_errno; auto kill = ASSERT_NO_ERRNO_AND_VALUE( @@ -343,6 +362,5 @@ int main(int argc, char** argv) { } gvisor::testing::TestInit(&argc, &argv); - - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/link.cc b/test/syscalls/linux/link.cc index dd5352954..544681168 100644 --- a/test/syscalls/linux/link.cc +++ b/test/syscalls/linux/link.cc @@ -55,7 +55,8 @@ TEST(LinkTest, CanCreateLinkFile) { const std::string newname = NewTempAbsPath(); // Get the initial link count. - uint64_t initial_link_count = ASSERT_NO_ERRNO_AND_VALUE(Links(oldfile.path())); + uint64_t initial_link_count = + ASSERT_NO_ERRNO_AND_VALUE(Links(oldfile.path())); EXPECT_THAT(link(oldfile.path().c_str(), newname.c_str()), SyscallSucceeds()); @@ -78,8 +79,13 @@ TEST(LinkTest, PermissionDenied) { // Make the file "unsafe" to link by making it only readable, but not // writable. - const auto oldfile = + const auto unwriteable_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0400)); + const std::string special_path = NewTempAbsPath(); + ASSERT_THAT(mkfifo(special_path.c_str(), 0666), SyscallSucceeds()); + const auto setuid_file = + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0666 | S_ISUID)); + const std::string newname = NewTempAbsPath(); // Do setuid in a separate thread so that after finishing this test, the @@ -96,8 +102,14 @@ TEST(LinkTest, PermissionDenied) { EXPECT_THAT(syscall(SYS_setuid, absl::GetFlag(FLAGS_scratch_uid)), SyscallSucceeds()); - EXPECT_THAT(link(oldfile.path().c_str(), newname.c_str()), + EXPECT_THAT(link(unwriteable_file.path().c_str(), newname.c_str()), + SyscallFailsWithErrno(EPERM)); + EXPECT_THAT(link(special_path.c_str(), newname.c_str()), SyscallFailsWithErrno(EPERM)); + if (!IsRunningWithVFS1()) { + EXPECT_THAT(link(setuid_file.path().c_str(), newname.c_str()), + SyscallFailsWithErrno(EPERM)); + } }); } diff --git a/test/syscalls/linux/lseek.cc b/test/syscalls/linux/lseek.cc index a8af8e545..6ce1e6cc3 100644 --- a/test/syscalls/linux/lseek.cc +++ b/test/syscalls/linux/lseek.cc @@ -53,7 +53,7 @@ TEST(LseekTest, NegativeOffset) { // A 32-bit off_t is not large enough to represent an offset larger than // maximum file size on standard file systems, so it isn't possible to cause // overflow. -#ifdef __x86_64__ +#if defined(__x86_64__) || defined(__aarch64__) TEST(LseekTest, Overflow) { // HA! Classic Linux. We really should have an EOVERFLOW // here, since we're seeking to something that cannot be diff --git a/test/syscalls/linux/madvise.cc b/test/syscalls/linux/madvise.cc index 7fd0ea20c..5a1973f60 100644 --- a/test/syscalls/linux/madvise.cc +++ b/test/syscalls/linux/madvise.cc @@ -139,7 +139,7 @@ TEST(MadviseDontneedTest, IgnoresPermissions) { TEST(MadviseDontforkTest, AddressLength) { auto m = ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(kPageSize, PROT_NONE, MAP_PRIVATE)); - char *addr = static_cast<char *>(m.ptr()); + char* addr = static_cast<char*>(m.ptr()); // Address must be page aligned. EXPECT_THAT(madvise(addr + 1, kPageSize, MADV_DONTFORK), @@ -168,9 +168,9 @@ TEST(MadviseDontforkTest, DontforkShared) { Mapping m = ASSERT_NO_ERRNO_AND_VALUE(Mmap( nullptr, kPageSize * 2, PROT_READ | PROT_WRITE, MAP_SHARED, fd.get(), 0)); - const Mapping ms1 = Mapping(reinterpret_cast<void *>(m.addr()), kPageSize); + const Mapping ms1 = Mapping(reinterpret_cast<void*>(m.addr()), kPageSize); const Mapping ms2 = - Mapping(reinterpret_cast<void *>(m.addr() + kPageSize), kPageSize); + Mapping(reinterpret_cast<void*>(m.addr() + kPageSize), kPageSize); m.release(); ASSERT_THAT(madvise(ms2.ptr(), kPageSize, MADV_DONTFORK), SyscallSucceeds()); @@ -197,11 +197,11 @@ TEST(MadviseDontforkTest, DontforkAnonPrivate) { // Mmap three anonymous pages and MADV_DONTFORK the middle page. Mapping m = ASSERT_NO_ERRNO_AND_VALUE( MmapAnon(kPageSize * 3, PROT_READ | PROT_WRITE, MAP_PRIVATE)); - const Mapping mp1 = Mapping(reinterpret_cast<void *>(m.addr()), kPageSize); + const Mapping mp1 = Mapping(reinterpret_cast<void*>(m.addr()), kPageSize); const Mapping mp2 = - Mapping(reinterpret_cast<void *>(m.addr() + kPageSize), kPageSize); + Mapping(reinterpret_cast<void*>(m.addr() + kPageSize), kPageSize); const Mapping mp3 = - Mapping(reinterpret_cast<void *>(m.addr() + 2 * kPageSize), kPageSize); + Mapping(reinterpret_cast<void*>(m.addr() + 2 * kPageSize), kPageSize); m.release(); ASSERT_THAT(madvise(mp2.ptr(), kPageSize, MADV_DONTFORK), SyscallSucceeds()); diff --git a/test/syscalls/linux/memfd.cc b/test/syscalls/linux/memfd.cc index e57b49a4a..f8b7f7938 100644 --- a/test/syscalls/linux/memfd.cc +++ b/test/syscalls/linux/memfd.cc @@ -16,6 +16,7 @@ #include <fcntl.h> #include <linux/magic.h> #include <linux/memfd.h> +#include <linux/unistd.h> #include <string.h> #include <sys/mman.h> #include <sys/statfs.h> diff --git a/test/syscalls/linux/memory_accounting.cc b/test/syscalls/linux/memory_accounting.cc index ff2f49863..94aea4077 100644 --- a/test/syscalls/linux/memory_accounting.cc +++ b/test/syscalls/linux/memory_accounting.cc @@ -13,6 +13,7 @@ // limitations under the License. #include <sys/mman.h> + #include <map> #include "gtest/gtest.h" diff --git a/test/syscalls/linux/mempolicy.cc b/test/syscalls/linux/mempolicy.cc index 9d5f47651..059fad598 100644 --- a/test/syscalls/linux/mempolicy.cc +++ b/test/syscalls/linux/mempolicy.cc @@ -43,17 +43,17 @@ namespace { #define MPOL_MF_MOVE (1 << 1) #define MPOL_MF_MOVE_ALL (1 << 2) -int get_mempolicy(int *policy, uint64_t *nmask, uint64_t maxnode, void *addr, +int get_mempolicy(int* policy, uint64_t* nmask, uint64_t maxnode, void* addr, int flags) { return syscall(SYS_get_mempolicy, policy, nmask, maxnode, addr, flags); } -int set_mempolicy(int mode, uint64_t *nmask, uint64_t maxnode) { +int set_mempolicy(int mode, uint64_t* nmask, uint64_t maxnode) { return syscall(SYS_set_mempolicy, mode, nmask, maxnode); } -int mbind(void *addr, unsigned long len, int mode, - const unsigned long *nodemask, unsigned long maxnode, +int mbind(void* addr, unsigned long len, int mode, + const unsigned long* nodemask, unsigned long maxnode, unsigned flags) { return syscall(SYS_mbind, addr, len, mode, nodemask, maxnode, flags); } @@ -68,7 +68,7 @@ Cleanup ScopedMempolicy() { // Temporarily change the memory policy for the calling thread within the // caller's scope. -PosixErrorOr<Cleanup> ScopedSetMempolicy(int mode, uint64_t *nmask, +PosixErrorOr<Cleanup> ScopedSetMempolicy(int mode, uint64_t* nmask, uint64_t maxnode) { if (set_mempolicy(mode, nmask, maxnode)) { return PosixError(errno, "set_mempolicy"); diff --git a/test/syscalls/linux/mkdir.cc b/test/syscalls/linux/mkdir.cc index cf138d328..4036a9275 100644 --- a/test/syscalls/linux/mkdir.cc +++ b/test/syscalls/linux/mkdir.cc @@ -18,10 +18,10 @@ #include <unistd.h> #include "gtest/gtest.h" -#include "test/syscalls/linux/temp_umask.h" #include "test/util/capability_util.h" #include "test/util/fs_util.h" #include "test/util/temp_path.h" +#include "test/util/temp_umask.h" #include "test/util/test_util.h" namespace gvisor { @@ -36,21 +36,12 @@ class MkdirTest : public ::testing::Test { // TearDown unlinks created files. void TearDown() override { - // FIXME(edahlgren): We don't currently implement rmdir. - // We do this unconditionally because there's no harm in trying. - rmdir(dirname_.c_str()); + EXPECT_THAT(rmdir(dirname_.c_str()), SyscallSucceeds()); } std::string dirname_; }; -TEST_F(MkdirTest, DISABLED_CanCreateReadbleDir) { - ASSERT_THAT(mkdir(dirname_.c_str(), 0444), SyscallSucceeds()); - ASSERT_THAT( - open(JoinPath(dirname_, "anything").c_str(), O_RDWR | O_CREAT, 0666), - SyscallFailsWithErrno(EACCES)); -} - TEST_F(MkdirTest, CanCreateWritableDir) { ASSERT_THAT(mkdir(dirname_.c_str(), 0777), SyscallSucceeds()); std::string filename = JoinPath(dirname_, "anything"); @@ -84,10 +75,11 @@ TEST_F(MkdirTest, FailsOnDirWithoutWritePerms) { ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); - auto parent = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0555)); - auto dir = JoinPath(parent.path(), "foo"); - ASSERT_THAT(mkdir(dir.c_str(), 0777), SyscallFailsWithErrno(EACCES)); + ASSERT_THAT(mkdir(dirname_.c_str(), 0555), SyscallSucceeds()); + auto dir = JoinPath(dirname_.c_str(), "foo"); + EXPECT_THAT(mkdir(dir.c_str(), 0777), SyscallFailsWithErrno(EACCES)); + EXPECT_THAT(open(JoinPath(dirname_, "file").c_str(), O_RDWR | O_CREAT, 0666), + SyscallFailsWithErrno(EACCES)); } } // namespace diff --git a/test/syscalls/linux/mknod.cc b/test/syscalls/linux/mknod.cc index 4c45766c7..05dfb375a 100644 --- a/test/syscalls/linux/mknod.cc +++ b/test/syscalls/linux/mknod.cc @@ -15,6 +15,7 @@ #include <errno.h> #include <fcntl.h> #include <sys/stat.h> +#include <sys/types.h> #include <sys/un.h> #include <unistd.h> @@ -39,7 +40,28 @@ TEST(MknodTest, RegularFile) { EXPECT_THAT(mknod(node1.c_str(), 0, 0), SyscallSucceeds()); } -TEST(MknodTest, MknodAtRegularFile) { +TEST(MknodTest, RegularFilePermissions) { + const std::string node = NewTempAbsPath(); + mode_t newUmask = 0077; + umask(newUmask); + + // Attempt to open file with mode 0777. Not specifying file type should create + // a regualar file. + mode_t perms = S_IRWXU | S_IRWXG | S_IRWXO; + EXPECT_THAT(mknod(node.c_str(), perms, 0), SyscallSucceeds()); + + // In the absence of a default ACL, the permissions of the created node are + // (mode & ~umask). -- mknod(2) + mode_t wantPerms = perms & ~newUmask; + struct stat st; + ASSERT_THAT(stat(node.c_str(), &st), SyscallSucceeds()); + ASSERT_EQ(st.st_mode & 0777, wantPerms); + + // "Zero file type is equivalent to type S_IFREG." - mknod(2) + ASSERT_EQ(st.st_mode & S_IFMT, S_IFREG); +} + +TEST(MknodTest, MknodAtFIFO) { const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); const std::string fifo_relpath = NewTempRelPath(); const std::string fifo = JoinPath(dir.path(), fifo_relpath); @@ -72,7 +94,7 @@ TEST(MknodTest, MknodOnExistingPathFails) { TEST(MknodTest, UnimplementedTypesReturnError) { const std::string path = NewTempAbsPath(); - if (IsRunningOnGvisor()) { + if (IsRunningWithVFS1()) { ASSERT_THAT(mknod(path.c_str(), S_IFSOCK, 0), SyscallFailsWithErrno(EOPNOTSUPP)); } diff --git a/test/syscalls/linux/mlock.cc b/test/syscalls/linux/mlock.cc index 283c21ed3..78ac96bed 100644 --- a/test/syscalls/linux/mlock.cc +++ b/test/syscalls/linux/mlock.cc @@ -16,6 +16,7 @@ #include <sys/resource.h> #include <sys/syscall.h> #include <unistd.h> + #include <cerrno> #include <cstring> @@ -59,7 +60,6 @@ bool IsPageMlocked(uintptr_t addr) { return true; } - TEST(MlockTest, Basic) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanMlock())); auto const mapping = ASSERT_NO_ERRNO_AND_VALUE( @@ -199,8 +199,10 @@ TEST(MunlockallTest, Basic) { } #ifndef SYS_mlock2 -#ifdef __x86_64__ +#if defined(__x86_64__) #define SYS_mlock2 325 +#elif defined(__aarch64__) +#define SYS_mlock2 284 #endif #endif diff --git a/test/syscalls/linux/mmap.cc b/test/syscalls/linux/mmap.cc index a112316e9..6d3227ab6 100644 --- a/test/syscalls/linux/mmap.cc +++ b/test/syscalls/linux/mmap.cc @@ -28,6 +28,7 @@ #include <sys/types.h> #include <sys/wait.h> #include <unistd.h> + #include <vector> #include "gmock/gmock.h" @@ -360,7 +361,7 @@ TEST_F(MMapTest, MapFixed) { } // 64-bit addresses work too -#ifdef __x86_64__ +#if defined(__x86_64__) || defined(__aarch64__) TEST_F(MMapTest, MapFixed64) { EXPECT_THAT(Map(0x300000000000, kPageSize, PROT_NONE, MAP_PRIVATE | MAP_ANONYMOUS | MAP_FIXED, -1, 0), @@ -570,6 +571,12 @@ const uint8_t machine_code[] = { 0xb8, 0x2a, 0x00, 0x00, 0x00, // movl $42, %eax 0xc3, // retq }; +#elif defined(__aarch64__) +const uint8_t machine_code[] = { + 0x40, 0x05, 0x80, 0x52, // mov w0, #42 + 0xc0, 0x03, 0x5f, 0xd6, // ret +}; +#endif // PROT_EXEC allows code execution TEST_F(MMapTest, ProtExec) { @@ -604,7 +611,6 @@ TEST_F(MMapTest, NoProtExecDeath) { EXPECT_EXIT(func(), ::testing::KilledBySignal(SIGSEGV), ""); } -#endif TEST_F(MMapTest, NoExceedLimitData) { void* prevbrk; @@ -813,23 +819,27 @@ class MMapFileTest : public MMapTest { } }; +class MMapFileParamTest + : public MMapFileTest, + public ::testing::WithParamInterface<std::tuple<int, int>> { + protected: + int prot() const { return std::get<0>(GetParam()); } + + int flags() const { return std::get<1>(GetParam()); } +}; + // MAP_POPULATE allowed. // There isn't a good way to verify it actually did anything. -// -// FIXME(b/37222275): Parameterize. -TEST_F(MMapFileTest, MapPopulate) { - ASSERT_THAT( - Map(0, kPageSize, PROT_READ, MAP_PRIVATE | MAP_POPULATE, fd_.get(), 0), - SyscallSucceeds()); +TEST_P(MMapFileParamTest, MapPopulate) { + ASSERT_THAT(Map(0, kPageSize, prot(), flags() | MAP_POPULATE, fd_.get(), 0), + SyscallSucceeds()); } // MAP_POPULATE on a short file. -// -// FIXME(b/37222275): Parameterize. -TEST_F(MMapFileTest, MapPopulateShort) { - ASSERT_THAT(Map(0, 2 * kPageSize, PROT_READ, MAP_PRIVATE | MAP_POPULATE, - fd_.get(), 0), - SyscallSucceeds()); +TEST_P(MMapFileParamTest, MapPopulateShort) { + ASSERT_THAT( + Map(0, 2 * kPageSize, prot(), flags() | MAP_POPULATE, fd_.get(), 0), + SyscallSucceeds()); } // Read contents from mapped file. @@ -900,16 +910,6 @@ TEST_F(MMapFileTest, WritePrivateOnReadOnlyFd) { reinterpret_cast<volatile char*>(addr)); } -// MAP_PRIVATE PROT_READ is not allowed on write-only FDs. -TEST_F(MMapFileTest, ReadPrivateOnWriteOnlyFd) { - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(filename_, O_WRONLY)); - - uintptr_t addr; - EXPECT_THAT(addr = Map(0, kPageSize, PROT_READ, MAP_PRIVATE, fd.get(), 0), - SyscallFailsWithErrno(EACCES)); -} - // MAP_SHARED PROT_WRITE not allowed on read-only FDs. TEST_F(MMapFileTest, WriteSharedOnReadOnlyFd) { const FileDescriptor fd = @@ -921,28 +921,13 @@ TEST_F(MMapFileTest, WriteSharedOnReadOnlyFd) { SyscallFailsWithErrno(EACCES)); } -// MAP_SHARED PROT_READ not allowed on write-only FDs. -// -// FIXME(b/37222275): Parameterize. -TEST_F(MMapFileTest, ReadSharedOnWriteOnlyFd) { - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(filename_, O_WRONLY)); - - uintptr_t addr; - EXPECT_THAT(addr = Map(0, kPageSize, PROT_READ, MAP_SHARED, fd.get(), 0), - SyscallFailsWithErrno(EACCES)); -} - -// MAP_SHARED PROT_WRITE not allowed on write-only FDs. -// The FD must always be readable. -// -// FIXME(b/37222275): Parameterize. -TEST_F(MMapFileTest, WriteSharedOnWriteOnlyFd) { +// The FD must be readable. +TEST_P(MMapFileParamTest, WriteOnlyFd) { const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(filename_, O_WRONLY)); uintptr_t addr; - EXPECT_THAT(addr = Map(0, kPageSize, PROT_WRITE, MAP_SHARED, fd.get(), 0), + EXPECT_THAT(addr = Map(0, kPageSize, prot(), flags(), fd.get(), 0), SyscallFailsWithErrno(EACCES)); } @@ -1181,7 +1166,7 @@ TEST_F(MMapFileTest, ReadSharedTruncateDownThenUp) { ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ, MAP_SHARED, fd_.get(), 0), SyscallSucceeds()); - // Check that the memory contains he file data. + // Check that the memory contains the file data. EXPECT_EQ(0, memcmp(reinterpret_cast<void*>(addr), buf.c_str(), kPageSize)); // Truncate down, then up. @@ -1370,132 +1355,75 @@ TEST_F(MMapFileTest, WritePrivate) { EqualsMemory(std::string(len, '\0'))); } -// SIGBUS raised when writing past end of file to a private mapping. -// -// FIXME(b/37222275): Parameterize. -TEST_F(MMapFileTest, SigBusDeathWritePrivate) { - SetupGvisorDeathTest(); - - uintptr_t addr; - ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE, - fd_.get(), 0), - SyscallSucceeds()); - - // MMapFileTest makes a file kPageSize/2 long. The entire first page will be - // accessible. Write just beyond that. - size_t len = strlen(kFileContents); - EXPECT_EXIT(std::copy(kFileContents, kFileContents + len, - reinterpret_cast<volatile char*>(addr + kPageSize)), - ::testing::KilledBySignal(SIGBUS), ""); -} - -// SIGBUS raised when reading past end of file on a shared mapping. -// -// FIXME(b/37222275): Parameterize. -TEST_F(MMapFileTest, SigBusDeathReadShared) { +// SIGBUS raised when reading or writing past end of a mapped file. +TEST_P(MMapFileParamTest, SigBusDeath) { SetupGvisorDeathTest(); uintptr_t addr; - ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ, MAP_SHARED, fd_.get(), 0), + ASSERT_THAT(addr = Map(0, 2 * kPageSize, prot(), flags(), fd_.get(), 0), SyscallSucceeds()); - // MMapFileTest makes a file kPageSize/2 long. The entire first page will be - // accessible. Read just beyond that. - std::vector<char> in(kPageSize); - EXPECT_EXIT( - std::copy(reinterpret_cast<volatile char*>(addr + kPageSize), - reinterpret_cast<volatile char*>(addr + kPageSize) + kPageSize, - in.data()), - ::testing::KilledBySignal(SIGBUS), ""); + auto* start = reinterpret_cast<volatile char*>(addr + kPageSize); + + // MMapFileTest makes a file kPageSize/2 long. The entire first page should be + // accessible, but anything beyond it should not. + if (prot() & PROT_WRITE) { + // Write beyond first page. + size_t len = strlen(kFileContents); + EXPECT_EXIT(std::copy(kFileContents, kFileContents + len, start), + ::testing::KilledBySignal(SIGBUS), ""); + } else { + // Read beyond first page. + std::vector<char> in(kPageSize); + EXPECT_EXIT(std::copy(start, start + kPageSize, in.data()), + ::testing::KilledBySignal(SIGBUS), ""); + } } -// SIGBUS raised when reading past end of file on a shared mapping. +// Tests that SIGBUS is not raised when reading or writing to a file-mapped +// page before EOF, even if part of the mapping extends beyond EOF. // -// FIXME(b/37222275): Parameterize. -TEST_F(MMapFileTest, SigBusDeathWriteShared) { - SetupGvisorDeathTest(); - +// See b/27877699. +TEST_P(MMapFileParamTest, NoSigBusOnPagesBeforeEOF) { uintptr_t addr; - ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, - fd_.get(), 0), - SyscallSucceeds()); - - // MMapFileTest makes a file kPageSize/2 long. The entire first page will be - // accessible. Write just beyond that. - size_t len = strlen(kFileContents); - EXPECT_EXIT(std::copy(kFileContents, kFileContents + len, - reinterpret_cast<volatile char*>(addr + kPageSize)), - ::testing::KilledBySignal(SIGBUS), ""); -} - -// Tests that SIGBUS is not raised when writing to a file-mapped page before -// EOF, even if part of the mapping extends beyond EOF. -TEST_F(MMapFileTest, NoSigBusOnPagesBeforeEOF) { - uintptr_t addr; - ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE, - fd_.get(), 0), + ASSERT_THAT(addr = Map(0, 2 * kPageSize, prot(), flags(), fd_.get(), 0), SyscallSucceeds()); // The test passes if this survives. - size_t len = strlen(kFileContents); - std::copy(kFileContents, kFileContents + len, - reinterpret_cast<volatile char*>(addr)); -} - -// Tests that SIGBUS is not raised when writing to a file-mapped page containing -// EOF, *after* the EOF for a private mapping. -TEST_F(MMapFileTest, NoSigBusOnPageContainingEOFWritePrivate) { - uintptr_t addr; - ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE, - fd_.get(), 0), - SyscallSucceeds()); - - // The test passes if this survives. (Technically addr+kPageSize/2 is already - // beyond EOF, but +1 to check for fencepost errors.) - size_t len = strlen(kFileContents); - std::copy(kFileContents, kFileContents + len, - reinterpret_cast<volatile char*>(addr + (kPageSize / 2) + 1)); -} - -// Tests that SIGBUS is not raised when reading from a file-mapped page -// containing EOF, *after* the EOF for a shared mapping. -// -// FIXME(b/37222275): Parameterize. -TEST_F(MMapFileTest, NoSigBusOnPageContainingEOFReadShared) { - uintptr_t addr; - ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ, MAP_SHARED, fd_.get(), 0), - SyscallSucceeds()); - - // The test passes if this survives. (Technically addr+kPageSize/2 is already - // beyond EOF, but +1 to check for fencepost errors.) auto* start = reinterpret_cast<volatile char*>(addr + (kPageSize / 2) + 1); size_t len = strlen(kFileContents); - std::vector<char> in(len); - std::copy(start, start + len, in.data()); + if (prot() & PROT_WRITE) { + std::copy(kFileContents, kFileContents + len, start); + } else { + std::vector<char> in(len); + std::copy(start, start + len, in.data()); + } } -// Tests that SIGBUS is not raised when writing to a file-mapped page containing -// EOF, *after* the EOF for a shared mapping. -// -// FIXME(b/37222275): Parameterize. -TEST_F(MMapFileTest, NoSigBusOnPageContainingEOFWriteShared) { +// Tests that SIGBUS is not raised when reading or writing from a file-mapped +// page containing EOF, *after* the EOF. +TEST_P(MMapFileParamTest, NoSigBusOnPageContainingEOF) { uintptr_t addr; - ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, - fd_.get(), 0), + ASSERT_THAT(addr = Map(0, 2 * kPageSize, prot(), flags(), fd_.get(), 0), SyscallSucceeds()); // The test passes if this survives. (Technically addr+kPageSize/2 is already // beyond EOF, but +1 to check for fencepost errors.) + auto* start = reinterpret_cast<volatile char*>(addr + (kPageSize / 2) + 1); size_t len = strlen(kFileContents); - std::copy(kFileContents, kFileContents + len, - reinterpret_cast<volatile char*>(addr + (kPageSize / 2) + 1)); + if (prot() & PROT_WRITE) { + std::copy(kFileContents, kFileContents + len, start); + } else { + std::vector<char> in(len); + std::copy(start, start + len, in.data()); + } } // Tests that reading from writable shared file-mapped pages succeeds. // // On most platforms this is trivial, but when the file is mapped via the sentry // page cache (which does not yet support writing to shared mappings), a bug -// caused reads to fail unnecessarily on such mappings. +// caused reads to fail unnecessarily on such mappings. See b/28913513. TEST_F(MMapFileTest, ReadingWritableSharedFilePageSucceeds) { uintptr_t addr; size_t len = strlen(kFileContents); @@ -1512,7 +1440,7 @@ TEST_F(MMapFileTest, ReadingWritableSharedFilePageSucceeds) { // Tests that EFAULT is returned when invoking a syscall that requires the OS to // read past end of file (resulting in a fault in sentry context in the gVisor -// case). +// case). See b/28913513. TEST_F(MMapFileTest, InternalSigBus) { uintptr_t addr; ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE, @@ -1655,7 +1583,7 @@ TEST_F(MMapFileTest, Bug38498194) { } // Tests that reading from a file to a memory mapping of the same file does not -// deadlock. +// deadlock. See b/34813270. TEST_F(MMapFileTest, SelfRead) { uintptr_t addr; ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, @@ -1667,7 +1595,7 @@ TEST_F(MMapFileTest, SelfRead) { } // Tests that writing to a file from a memory mapping of the same file does not -// deadlock. +// deadlock. Regression test for b/34813270. TEST_F(MMapFileTest, SelfWrite) { uintptr_t addr; ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ, MAP_SHARED, fd_.get(), 0), @@ -1721,6 +1649,7 @@ TEST(MMapNoFixtureTest, MapReadOnlyAfterCreateWriteOnly) { } // Conditional on MAP_32BIT. +// This flag is supported only on x86-64, for 64-bit programs. #ifdef __x86_64__ TEST(MMapNoFixtureTest, Map32Bit) { @@ -1732,6 +1661,15 @@ TEST(MMapNoFixtureTest, Map32Bit) { #endif // defined(__x86_64__) +INSTANTIATE_TEST_SUITE_P( + ReadWriteSharedPrivate, MMapFileParamTest, + ::testing::Combine(::testing::ValuesIn({ + PROT_READ, + PROT_WRITE, + PROT_READ | PROT_WRITE, + }), + ::testing::ValuesIn({MAP_SHARED, MAP_PRIVATE}))); + } // namespace } // namespace testing diff --git a/test/syscalls/linux/mount.cc b/test/syscalls/linux/mount.cc index e35be3cab..46b6f38db 100644 --- a/test/syscalls/linux/mount.cc +++ b/test/syscalls/linux/mount.cc @@ -18,6 +18,7 @@ #include <sys/mount.h> #include <sys/stat.h> #include <unistd.h> + #include <functional> #include <memory> #include <string> @@ -320,6 +321,42 @@ TEST(MountTest, RenameRemoveMountPoint) { ASSERT_THAT(rmdir(dir.path().c_str()), SyscallFailsWithErrno(EBUSY)); } +TEST(MountTest, MountFuseFilesystemNoDevice) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); + SKIP_IF(IsRunningOnGvisor() && !IsFUSEEnabled()); + + auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + + // Before kernel version 4.16-rc6, FUSE mount is protected by + // capable(CAP_SYS_ADMIN). After this version, it uses + // ns_capable(CAP_SYS_ADMIN) to protect. Before the 4.16 kernel, it was not + // allowed to mount fuse file systems without the global CAP_SYS_ADMIN. + int res = mount("", dir.path().c_str(), "fuse", 0, ""); + SKIP_IF(!IsRunningOnGvisor() && res == -1 && errno == EPERM); + + EXPECT_THAT(mount("", dir.path().c_str(), "fuse", 0, ""), + SyscallFailsWithErrno(EINVAL)); +} + +TEST(MountTest, MountFuseFilesystem) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); + SKIP_IF(IsRunningOnGvisor() && !IsFUSEEnabled()); + + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/fuse", O_WRONLY)); + std::string mopts = "fd=" + std::to_string(fd.get()); + + auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + + // See comments in MountFuseFilesystemNoDevice for the reason why we skip + // EPERM when running on Linux. + int res = mount("", dir.path().c_str(), "fuse", 0, ""); + SKIP_IF(!IsRunningOnGvisor() && res == -1 && errno == EPERM); + + auto const mount = + ASSERT_NO_ERRNO_AND_VALUE(Mount("", dir.path(), "fuse", 0, mopts, 0)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/msync.cc b/test/syscalls/linux/msync.cc index ac7146017..2b2b6aef9 100644 --- a/test/syscalls/linux/msync.cc +++ b/test/syscalls/linux/msync.cc @@ -60,9 +60,7 @@ std::vector<std::function<PosixErrorOr<Mapping>()>> SyncableMappings() { for (int const mflags : {MAP_PRIVATE, MAP_SHARED}) { int const prot = PROT_READ | (writable ? PROT_WRITE : 0); int const oflags = O_CREAT | (writable ? O_RDWR : O_RDONLY); - funcs.push_back([=] { - return MmapAnon(kPageSize, prot, mflags); - }); + funcs.push_back([=] { return MmapAnon(kPageSize, prot, mflags); }); funcs.push_back([=]() -> PosixErrorOr<Mapping> { std::string const path = NewTempAbsPath(); ASSIGN_OR_RETURN_ERRNO(auto fd, Open(path, oflags, 0644)); diff --git a/test/syscalls/linux/network_namespace.cc b/test/syscalls/linux/network_namespace.cc new file mode 100644 index 000000000..133fdecf0 --- /dev/null +++ b/test/syscalls/linux/network_namespace.cc @@ -0,0 +1,52 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <net/if.h> +#include <sched.h> +#include <sys/ioctl.h> +#include <sys/socket.h> +#include <sys/types.h> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/capability_util.h" +#include "test/util/posix_error.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { +namespace { + +TEST(NetworkNamespaceTest, LoopbackExists) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + ScopedThread t([&] { + ASSERT_THAT(unshare(CLONE_NEWNET), SyscallSucceedsWithValue(0)); + + // TODO(gvisor.dev/issue/1833): Update this to test that only "lo" exists. + // Check loopback device exists. + int sock = socket(AF_INET, SOCK_DGRAM, 0); + ASSERT_THAT(sock, SyscallSucceeds()); + struct ifreq ifr; + strncpy(ifr.ifr_name, "lo", IFNAMSIZ); + EXPECT_THAT(ioctl(sock, SIOCGIFINDEX, &ifr), SyscallSucceeds()) + << "lo cannot be found"; + }); +} + +} // namespace +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/open.cc b/test/syscalls/linux/open.cc index 2b1df52ce..77f390f3c 100644 --- a/test/syscalls/linux/open.cc +++ b/test/syscalls/linux/open.cc @@ -27,6 +27,7 @@ #include "test/util/cleanup.h" #include "test/util/file_descriptor.h" #include "test/util/fs_util.h" +#include "test/util/posix_error.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" @@ -73,6 +74,60 @@ class OpenTest : public FileTest { const std::string test_data_ = "hello world\n"; }; +TEST_F(OpenTest, OTrunc) { + auto dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd"); + ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds()); + ASSERT_THAT(open(dirpath.c_str(), O_TRUNC, 0666), + SyscallFailsWithErrno(EISDIR)); +} + +TEST_F(OpenTest, OTruncAndReadOnlyDir) { + auto dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd"); + ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds()); + ASSERT_THAT(open(dirpath.c_str(), O_TRUNC | O_RDONLY, 0666), + SyscallFailsWithErrno(EISDIR)); +} + +TEST_F(OpenTest, OTruncAndReadOnlyFile) { + auto dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncfile"); + const FileDescriptor existing = + ASSERT_NO_ERRNO_AND_VALUE(Open(dirpath.c_str(), O_RDWR | O_CREAT, 0666)); + const FileDescriptor otrunc = ASSERT_NO_ERRNO_AND_VALUE( + Open(dirpath.c_str(), O_TRUNC | O_RDONLY, 0666)); +} + +TEST_F(OpenTest, OCreateDirectory) { + SKIP_IF(IsRunningWithVFS1()); + auto dirpath = GetAbsoluteTestTmpdir(); + + // Normal case: existing directory. + ASSERT_THAT(open(dirpath.c_str(), O_RDWR | O_CREAT, 0666), + SyscallFailsWithErrno(EISDIR)); + // Trailing separator on existing directory. + ASSERT_THAT(open(dirpath.append("/").c_str(), O_RDWR | O_CREAT, 0666), + SyscallFailsWithErrno(EISDIR)); + // Trailing separator on non-existing directory. + ASSERT_THAT(open(JoinPath(dirpath, "non-existent").append("/").c_str(), + O_RDWR | O_CREAT, 0666), + SyscallFailsWithErrno(EISDIR)); + // "." special case. + ASSERT_THAT(open(JoinPath(dirpath, ".").c_str(), O_RDWR | O_CREAT, 0666), + SyscallFailsWithErrno(EISDIR)); +} + +TEST_F(OpenTest, MustCreateExisting) { + auto dirPath = GetAbsoluteTestTmpdir(); + + // Existing directory. + ASSERT_THAT(open(dirPath.c_str(), O_RDWR | O_CREAT | O_EXCL, 0666), + SyscallFailsWithErrno(EEXIST)); + + // Existing file. + auto newFile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dirPath)); + ASSERT_THAT(open(newFile.path().c_str(), O_RDWR | O_CREAT | O_EXCL, 0666), + SyscallFailsWithErrno(EEXIST)); +} + TEST_F(OpenTest, ReadOnly) { char buf; const FileDescriptor ro_file = @@ -93,6 +148,26 @@ TEST_F(OpenTest, WriteOnly) { EXPECT_THAT(write(wo_file.get(), &buf, 1), SyscallSucceedsWithValue(1)); } +TEST_F(OpenTest, CreateWithAppend) { + std::string data = "text"; + std::string new_file = NewTempAbsPath(); + const FileDescriptor file = ASSERT_NO_ERRNO_AND_VALUE( + Open(new_file, O_WRONLY | O_APPEND | O_CREAT, 0666)); + EXPECT_THAT(write(file.get(), data.c_str(), data.size()), + SyscallSucceedsWithValue(data.size())); + EXPECT_THAT(lseek(file.get(), 0, SEEK_SET), SyscallSucceeds()); + EXPECT_THAT(write(file.get(), data.c_str(), data.size()), + SyscallSucceedsWithValue(data.size())); + + // Check that the size of the file is correct and that the offset has been + // incremented to that size. + struct stat s0; + EXPECT_THAT(fstat(file.get(), &s0), SyscallSucceeds()); + EXPECT_EQ(s0.st_size, 2 * data.size()); + EXPECT_THAT(lseek(file.get(), 0, SEEK_CUR), + SyscallSucceedsWithValue(2 * data.size())); +} + TEST_F(OpenTest, ReadWrite) { char buf; const FileDescriptor rw_file = @@ -164,6 +239,28 @@ TEST_F(OpenTest, OpenNoFollowStillFollowsLinksInPath) { ASSERT_NO_ERRNO_AND_VALUE(Open(path_via_symlink, O_RDONLY | O_NOFOLLOW)); } +// Test that open(2) can follow symlinks that point back to the same tree. +// Test sets up files as follows: +// root/child/symlink => redirects to ../.. +// root/child/target => regular file +// +// open("root/child/symlink/root/child/file") +TEST_F(OpenTest, SymlinkRecurse) { + auto root = + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(GetAbsoluteTestTmpdir())); + auto child = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(root.path())); + auto symlink = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateSymlinkTo(child.path(), "../..")); + auto target = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateFileWith(child.path(), "abc", 0644)); + auto path_via_symlink = + JoinPath(symlink.path(), Basename(root.path()), Basename(child.path()), + Basename(target.path())); + const auto contents = + ASSERT_NO_ERRNO_AND_VALUE(GetContents(path_via_symlink)); + ASSERT_EQ(contents, "abc"); +} + TEST_F(OpenTest, Fault) { char* totally_not_null = nullptr; ASSERT_THAT(open(totally_not_null, O_RDONLY), SyscallFailsWithErrno(EFAULT)); @@ -191,7 +288,7 @@ TEST_F(OpenTest, AppendOnly) { ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR | O_APPEND)); EXPECT_THAT(lseek(fd2.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(0)); - // Then try to write to the first file and make sure the bytes are appended. + // Then try to write to the first fd and make sure the bytes are appended. EXPECT_THAT(WriteFd(fd1.get(), buf.data(), buf.size()), SyscallSucceedsWithValue(buf.size())); @@ -203,7 +300,7 @@ TEST_F(OpenTest, AppendOnly) { EXPECT_THAT(lseek(fd1.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(kBufSize * 2)); - // Then try to write to the second file and make sure the bytes are appended. + // Then try to write to the second fd and make sure the bytes are appended. EXPECT_THAT(WriteFd(fd2.get(), buf.data(), buf.size()), SyscallSucceedsWithValue(buf.size())); @@ -312,6 +409,13 @@ TEST_F(OpenTest, FileNotDirectory) { SyscallFailsWithErrno(ENOTDIR)); } +TEST_F(OpenTest, SymlinkDirectory) { + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + std::string link = NewTempAbsPath(); + ASSERT_THAT(symlink(dir.path().c_str(), link.c_str()), SyscallSucceeds()); + ASSERT_NO_ERRNO(Open(link, O_RDONLY | O_DIRECTORY)); +} + TEST_F(OpenTest, Null) { char c = '\0'; ASSERT_THAT(open(&c, O_RDONLY), SyscallFailsWithErrno(ENOENT)); @@ -372,6 +476,35 @@ TEST_F(OpenTest, CanTruncateWriteOnlyNoReadPermission_NoRandomSave) { EXPECT_EQ(stat.st_size, 0); } +TEST_F(OpenTest, CanTruncateWithStrangePermissions) { + ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + const DisableSave ds; // Permissions are dropped. + std::string path = NewTempAbsPath(); + int fd; + // Create a file without user permissions. + EXPECT_THAT( // SAVE_BELOW + fd = open(path.c_str(), O_CREAT | O_TRUNC | O_WRONLY, 055), + SyscallSucceeds()); + EXPECT_THAT(close(fd), SyscallSucceeds()); + + // Cannot open file because we are owner and have no permissions set. + EXPECT_THAT(open(path.c_str(), O_RDONLY), SyscallFailsWithErrno(EACCES)); + + // We *can* chmod the file, because we are the owner. + EXPECT_THAT(chmod(path.c_str(), 0755), SyscallSucceeds()); + + // Now we can open the file again. + EXPECT_THAT(fd = open(path.c_str(), O_RDWR), SyscallSucceeds()); + EXPECT_THAT(close(fd), SyscallSucceeds()); +} + +TEST_F(OpenTest, OpenNonDirectoryWithTrailingSlash) { + const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const std::string bad_path = file.path() + "/"; + EXPECT_THAT(open(bad_path.c_str(), O_RDONLY), SyscallFailsWithErrno(ENOTDIR)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/open_create.cc b/test/syscalls/linux/open_create.cc index e5a85ef9d..51eacf3f2 100644 --- a/test/syscalls/linux/open_create.cc +++ b/test/syscalls/linux/open_create.cc @@ -19,11 +19,11 @@ #include <unistd.h> #include "gtest/gtest.h" -#include "test/syscalls/linux/temp_umask.h" #include "test/util/capability_util.h" #include "test/util/file_descriptor.h" #include "test/util/fs_util.h" #include "test/util/temp_path.h" +#include "test/util/temp_umask.h" #include "test/util/test_util.h" namespace gvisor { @@ -88,6 +88,30 @@ TEST(CreateTest, CreateExclusively) { SyscallFailsWithErrno(EEXIST)); } +TEST(CreateTeast, CreatWithOTrunc) { + std::string dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd"); + ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds()); + ASSERT_THAT(open(dirpath.c_str(), O_CREAT | O_TRUNC, 0666), + SyscallFailsWithErrno(EISDIR)); +} + +TEST(CreateTeast, CreatDirWithOTruncAndReadOnly) { + std::string dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd"); + ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds()); + ASSERT_THAT(open(dirpath.c_str(), O_CREAT | O_TRUNC | O_RDONLY, 0666), + SyscallFailsWithErrno(EISDIR)); +} + +TEST(CreateTeast, CreatFileWithOTruncAndReadOnly) { + std::string dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncfile"); + int dirfd; + ASSERT_THAT(dirfd = open(dirpath.c_str(), O_RDWR | O_CREAT, 0666), + SyscallSucceeds()); + ASSERT_THAT(open(dirpath.c_str(), O_CREAT | O_TRUNC | O_RDONLY, 0666), + SyscallSucceeds()); + ASSERT_THAT(close(dirfd), SyscallSucceeds()); +} + TEST(CreateTest, CreateFailsOnUnpermittedDir) { // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to // always override directory permissions. @@ -108,6 +132,7 @@ TEST(CreateTest, CreateFailsOnDirWithoutWritePerms) { } // A file originally created RW, but opened RO can later be opened RW. +// Regression test for b/65385065. TEST(CreateTest, OpenCreateROThenRW) { TempPath file(NewTempAbsPath()); diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc index 92ae55eec..861617ff7 100644 --- a/test/syscalls/linux/packet_socket.cc +++ b/test/syscalls/linux/packet_socket.cc @@ -13,6 +13,7 @@ // limitations under the License. #include <arpa/inet.h> +#include <ifaddrs.h> #include <linux/capability.h> #include <linux/if_arp.h> #include <linux/if_packet.h> @@ -163,16 +164,11 @@ int CookedPacketTest::GetLoopbackIndex() { return ifr.ifr_ifindex; } -// Receive via a packet socket. -TEST_P(CookedPacketTest, Receive) { - // Let's use a simple IP payload: a UDP datagram. - FileDescriptor udp_sock = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); - SendUDPMessage(udp_sock.get()); - +// Receive and verify the message via packet socket on interface. +void ReceiveMessage(int sock, int ifindex) { // Wait for the socket to become readable. struct pollfd pfd = {}; - pfd.fd = socket_; + pfd.fd = sock; pfd.events = POLLIN; EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 2000), SyscallSucceedsWithValue(1)); @@ -182,20 +178,22 @@ TEST_P(CookedPacketTest, Receive) { char buf[64]; struct sockaddr_ll src = {}; socklen_t src_len = sizeof(src); - ASSERT_THAT(recvfrom(socket_, buf, sizeof(buf), 0, + ASSERT_THAT(recvfrom(sock, buf, sizeof(buf), 0, reinterpret_cast<struct sockaddr*>(&src), &src_len), SyscallSucceedsWithValue(packet_size)); + // sockaddr_ll ends with an 8 byte physical address field, but ethernet // addresses only use 6 bytes. Linux used to return sizeof(sockaddr_ll)-2 // here, but since commit b2cf86e1563e33a14a1c69b3e508d15dc12f804c returns // sizeof(sockaddr_ll). ASSERT_THAT(src_len, AnyOf(Eq(sizeof(src)), Eq(sizeof(src) - 2))); - // TODO(b/129292371): Verify protocol once we return it. + // TODO(gvisor.dev/issue/173): Verify protocol once we return it. // Verify the source address. EXPECT_EQ(src.sll_family, AF_PACKET); - EXPECT_EQ(src.sll_ifindex, GetLoopbackIndex()); + EXPECT_EQ(src.sll_ifindex, ifindex); EXPECT_EQ(src.sll_halen, ETH_ALEN); + EXPECT_EQ(ntohs(src.sll_protocol), ETH_P_IP); // This came from the loopback device, so the address is all 0s. for (int i = 0; i < src.sll_halen; i++) { EXPECT_EQ(src.sll_addr[i], 0); @@ -222,9 +220,21 @@ TEST_P(CookedPacketTest, Receive) { EXPECT_EQ(strncmp(payload, kMessage, sizeof(kMessage)), 0); } +// Receive via a packet socket. +TEST_P(CookedPacketTest, Receive) { + // Let's use a simple IP payload: a UDP datagram. + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + SendUDPMessage(udp_sock.get()); + + // Receive and verify the data. + int loopback_index = GetLoopbackIndex(); + ReceiveMessage(socket_, loopback_index); +} + // Send via a packet socket. TEST_P(CookedPacketTest, Send) { - // TODO(b/129292371): Remove once we support packet socket writing. + // TODO(gvisor.dev/issue/173): Remove once we support packet socket writing. SKIP_IF(IsRunningOnGvisor()); // Let's send a UDP packet and receive it using a regular UDP socket. @@ -313,6 +323,230 @@ TEST_P(CookedPacketTest, Send) { EXPECT_EQ(src.sin_addr.s_addr, htonl(INADDR_LOOPBACK)); } +// Bind and receive via packet socket. +TEST_P(CookedPacketTest, BindReceive) { + struct sockaddr_ll bind_addr = {}; + bind_addr.sll_family = AF_PACKET; + bind_addr.sll_protocol = htons(GetParam()); + bind_addr.sll_ifindex = GetLoopbackIndex(); + + ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr), + sizeof(bind_addr)), + SyscallSucceeds()); + + // Let's use a simple IP payload: a UDP datagram. + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + SendUDPMessage(udp_sock.get()); + + // Receive and verify the data. + ReceiveMessage(socket_, bind_addr.sll_ifindex); +} + +// Double Bind socket. +TEST_P(CookedPacketTest, DoubleBindSucceeds) { + struct sockaddr_ll bind_addr = {}; + bind_addr.sll_family = AF_PACKET; + bind_addr.sll_protocol = htons(GetParam()); + bind_addr.sll_ifindex = GetLoopbackIndex(); + + ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr), + sizeof(bind_addr)), + SyscallSucceeds()); + + // Binding socket again should fail. + ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr), + sizeof(bind_addr)), + // Linux 4.09 returns EINVAL here, but some time before 4.19 it + // switched to EADDRINUSE. + SyscallSucceeds()); +} + +// Bind and verify we do not receive data on interface which is not bound +TEST_P(CookedPacketTest, BindDrop) { + // Let's use a simple IP payload: a UDP datagram. + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + + struct ifaddrs* if_addr_list = nullptr; + auto cleanup = Cleanup([&if_addr_list]() { freeifaddrs(if_addr_list); }); + + ASSERT_THAT(getifaddrs(&if_addr_list), SyscallSucceeds()); + + // Get interface other than loopback. + struct ifreq ifr = {}; + for (struct ifaddrs* i = if_addr_list; i; i = i->ifa_next) { + if (strcmp(i->ifa_name, "lo") != 0) { + strncpy(ifr.ifr_name, i->ifa_name, sizeof(ifr.ifr_name)); + break; + } + } + + // Skip if no interface is available other than loopback. + if (strlen(ifr.ifr_name) == 0) { + GTEST_SKIP(); + } + + // Get interface index. + EXPECT_THAT(ioctl(socket_, SIOCGIFINDEX, &ifr), SyscallSucceeds()); + EXPECT_NE(ifr.ifr_ifindex, 0); + + // Bind to packet socket requires only family, protocol and ifindex. + struct sockaddr_ll bind_addr = {}; + bind_addr.sll_family = AF_PACKET; + bind_addr.sll_protocol = htons(GetParam()); + bind_addr.sll_ifindex = ifr.ifr_ifindex; + + ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr), + sizeof(bind_addr)), + SyscallSucceeds()); + + // Send to loopback interface. + struct sockaddr_in dest = {}; + dest.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + dest.sin_family = AF_INET; + dest.sin_port = kPort; + EXPECT_THAT(sendto(udp_sock.get(), kMessage, sizeof(kMessage), 0, + reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)), + SyscallSucceedsWithValue(sizeof(kMessage))); + + // Wait and make sure the socket never receives any data. + struct pollfd pfd = {}; + pfd.fd = socket_; + pfd.events = POLLIN; + EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(0)); +} + +// Verify that we receive outbound packets. This test requires at least one +// non loopback interface so that we can actually capture an outgoing packet. +TEST_P(CookedPacketTest, ReceiveOutbound) { + // Only ETH_P_ALL sockets can receive outbound packets on linux. + SKIP_IF(GetParam() != ETH_P_ALL); + + // Let's use a simple IP payload: a UDP datagram. + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + + struct ifaddrs* if_addr_list = nullptr; + auto cleanup = Cleanup([&if_addr_list]() { freeifaddrs(if_addr_list); }); + + ASSERT_THAT(getifaddrs(&if_addr_list), SyscallSucceeds()); + + // Get interface other than loopback. + struct ifreq ifr = {}; + for (struct ifaddrs* i = if_addr_list; i; i = i->ifa_next) { + if (strcmp(i->ifa_name, "lo") != 0) { + strncpy(ifr.ifr_name, i->ifa_name, sizeof(ifr.ifr_name)); + break; + } + } + + // Skip if no interface is available other than loopback. + if (strlen(ifr.ifr_name) == 0) { + GTEST_SKIP(); + } + + // Get interface index and name. + EXPECT_THAT(ioctl(socket_, SIOCGIFINDEX, &ifr), SyscallSucceeds()); + EXPECT_NE(ifr.ifr_ifindex, 0); + int ifindex = ifr.ifr_ifindex; + + constexpr int kMACSize = 6; + char hwaddr[kMACSize]; + // Get interface address. + ASSERT_THAT(ioctl(socket_, SIOCGIFHWADDR, &ifr), SyscallSucceeds()); + ASSERT_THAT(ifr.ifr_hwaddr.sa_family, + AnyOf(Eq(ARPHRD_NONE), Eq(ARPHRD_ETHER))); + memcpy(hwaddr, ifr.ifr_hwaddr.sa_data, kMACSize); + + // Just send it to the google dns server 8.8.8.8. It's UDP we don't care + // if it actually gets to the DNS Server we just want to see that we receive + // it on our AF_PACKET socket. + // + // NOTE: We just want to pick an IP that is non-local to avoid having to + // handle ARP as this should cause the UDP packet to be sent to the default + // gateway configured for the system under test. Otherwise the only packet we + // will see is the ARP query unless we picked an IP which will actually + // resolve. The test is a bit brittle but this was the best compromise for + // now. + struct sockaddr_in dest = {}; + ASSERT_EQ(inet_pton(AF_INET, "8.8.8.8", &dest.sin_addr.s_addr), 1); + dest.sin_family = AF_INET; + dest.sin_port = kPort; + EXPECT_THAT(sendto(udp_sock.get(), kMessage, sizeof(kMessage), 0, + reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)), + SyscallSucceedsWithValue(sizeof(kMessage))); + + // Wait and make sure the socket receives the data. + struct pollfd pfd = {}; + pfd.fd = socket_; + pfd.events = POLLIN; + EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(1)); + + // Now read and check that the packet is the one we just sent. + // Read and verify the data. + constexpr size_t packet_size = + sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kMessage); + char buf[64]; + struct sockaddr_ll src = {}; + socklen_t src_len = sizeof(src); + ASSERT_THAT(recvfrom(socket_, buf, sizeof(buf), 0, + reinterpret_cast<struct sockaddr*>(&src), &src_len), + SyscallSucceedsWithValue(packet_size)); + + // sockaddr_ll ends with an 8 byte physical address field, but ethernet + // addresses only use 6 bytes. Linux used to return sizeof(sockaddr_ll)-2 + // here, but since commit b2cf86e1563e33a14a1c69b3e508d15dc12f804c returns + // sizeof(sockaddr_ll). + ASSERT_THAT(src_len, AnyOf(Eq(sizeof(src)), Eq(sizeof(src) - 2))); + + // Verify the source address. + EXPECT_EQ(src.sll_family, AF_PACKET); + EXPECT_EQ(src.sll_ifindex, ifindex); + EXPECT_EQ(src.sll_halen, ETH_ALEN); + EXPECT_EQ(ntohs(src.sll_protocol), ETH_P_IP); + EXPECT_EQ(src.sll_pkttype, PACKET_OUTGOING); + // Verify the link address of the interface matches that of the non + // non loopback interface address we stored above. + for (int i = 0; i < src.sll_halen; i++) { + EXPECT_EQ(src.sll_addr[i], hwaddr[i]); + } + + // Verify the IP header. + struct iphdr ip = {}; + memcpy(&ip, buf, sizeof(ip)); + EXPECT_EQ(ip.ihl, 5); + EXPECT_EQ(ip.version, 4); + EXPECT_EQ(ip.tot_len, htons(packet_size)); + EXPECT_EQ(ip.protocol, IPPROTO_UDP); + EXPECT_EQ(ip.daddr, dest.sin_addr.s_addr); + EXPECT_NE(ip.saddr, htonl(INADDR_LOOPBACK)); + + // Verify the UDP header. + struct udphdr udp = {}; + memcpy(&udp, buf + sizeof(iphdr), sizeof(udp)); + EXPECT_EQ(udp.dest, kPort); + EXPECT_EQ(udp.len, htons(sizeof(udphdr) + sizeof(kMessage))); + + // Verify the payload. + char* payload = reinterpret_cast<char*>(buf + sizeof(iphdr) + sizeof(udphdr)); + EXPECT_EQ(strncmp(payload, kMessage, sizeof(kMessage)), 0); +} + +// Bind with invalid address. +TEST_P(CookedPacketTest, BindFail) { + // Null address. + ASSERT_THAT( + bind(socket_, nullptr, sizeof(struct sockaddr)), + AnyOf(SyscallFailsWithErrno(EFAULT), SyscallFailsWithErrno(EINVAL))); + + // Address of size 1. + uint8_t addr = 0; + ASSERT_THAT( + bind(socket_, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)), + SyscallFailsWithErrno(EINVAL)); +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, CookedPacketTest, ::testing::Values(ETH_P_IP, ETH_P_ALL)); diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc index d258d353c..a11a03415 100644 --- a/test/syscalls/linux/packet_socket_raw.cc +++ b/test/syscalls/linux/packet_socket_raw.cc @@ -14,6 +14,9 @@ #include <arpa/inet.h> #include <linux/capability.h> +#ifndef __fuchsia__ +#include <linux/filter.h> +#endif // __fuchsia__ #include <linux/if_arp.h> #include <linux/if_packet.h> #include <net/ethernet.h> @@ -97,7 +100,7 @@ class RawPacketTest : public ::testing::TestWithParam<int> { int GetLoopbackIndex(); // The socket used for both reading and writing. - int socket_; + int s_; }; void RawPacketTest::SetUp() { @@ -108,34 +111,58 @@ void RawPacketTest::SetUp() { } if (!IsRunningOnGvisor()) { + // Ensure that looped back packets aren't rejected by the kernel. FileDescriptor acceptLocal = ASSERT_NO_ERRNO_AND_VALUE( - Open("/proc/sys/net/ipv4/conf/lo/accept_local", O_RDONLY)); + Open("/proc/sys/net/ipv4/conf/lo/accept_local", O_RDWR)); FileDescriptor routeLocalnet = ASSERT_NO_ERRNO_AND_VALUE( - Open("/proc/sys/net/ipv4/conf/lo/route_localnet", O_RDONLY)); + Open("/proc/sys/net/ipv4/conf/lo/route_localnet", O_RDWR)); char enabled; ASSERT_THAT(read(acceptLocal.get(), &enabled, 1), SyscallSucceedsWithValue(1)); - ASSERT_EQ(enabled, '1'); + if (enabled != '1') { + enabled = '1'; + ASSERT_THAT(lseek(acceptLocal.get(), 0, SEEK_SET), + SyscallSucceedsWithValue(0)); + ASSERT_THAT(write(acceptLocal.get(), &enabled, 1), + SyscallSucceedsWithValue(1)); + ASSERT_THAT(lseek(acceptLocal.get(), 0, SEEK_SET), + SyscallSucceedsWithValue(0)); + ASSERT_THAT(read(acceptLocal.get(), &enabled, 1), + SyscallSucceedsWithValue(1)); + ASSERT_EQ(enabled, '1'); + } + ASSERT_THAT(read(routeLocalnet.get(), &enabled, 1), SyscallSucceedsWithValue(1)); - ASSERT_EQ(enabled, '1'); + if (enabled != '1') { + enabled = '1'; + ASSERT_THAT(lseek(routeLocalnet.get(), 0, SEEK_SET), + SyscallSucceedsWithValue(0)); + ASSERT_THAT(write(routeLocalnet.get(), &enabled, 1), + SyscallSucceedsWithValue(1)); + ASSERT_THAT(lseek(routeLocalnet.get(), 0, SEEK_SET), + SyscallSucceedsWithValue(0)); + ASSERT_THAT(read(routeLocalnet.get(), &enabled, 1), + SyscallSucceedsWithValue(1)); + ASSERT_EQ(enabled, '1'); + } } - ASSERT_THAT(socket_ = socket(AF_PACKET, SOCK_RAW, htons(GetParam())), + ASSERT_THAT(s_ = socket(AF_PACKET, SOCK_RAW, htons(GetParam())), SyscallSucceeds()); } void RawPacketTest::TearDown() { // TearDown will be run even if we skip the test. if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { - EXPECT_THAT(close(socket_), SyscallSucceeds()); + EXPECT_THAT(close(s_), SyscallSucceeds()); } } int RawPacketTest::GetLoopbackIndex() { struct ifreq ifr; snprintf(ifr.ifr_name, IFNAMSIZ, "lo"); - EXPECT_THAT(ioctl(socket_, SIOCGIFINDEX, &ifr), SyscallSucceeds()); + EXPECT_THAT(ioctl(s_, SIOCGIFINDEX, &ifr), SyscallSucceeds()); EXPECT_NE(ifr.ifr_ifindex, 0); return ifr.ifr_ifindex; } @@ -149,7 +176,7 @@ TEST_P(RawPacketTest, Receive) { // Wait for the socket to become readable. struct pollfd pfd = {}; - pfd.fd = socket_; + pfd.fd = s_; pfd.events = POLLIN; EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 2000), SyscallSucceedsWithValue(1)); @@ -159,7 +186,7 @@ TEST_P(RawPacketTest, Receive) { char buf[64]; struct sockaddr_ll src = {}; socklen_t src_len = sizeof(src); - ASSERT_THAT(recvfrom(socket_, buf, sizeof(buf), 0, + ASSERT_THAT(recvfrom(s_, buf, sizeof(buf), 0, reinterpret_cast<struct sockaddr*>(&src), &src_len), SyscallSucceedsWithValue(packet_size)); // sockaddr_ll ends with an 8 byte physical address field, but ethernet @@ -168,11 +195,12 @@ TEST_P(RawPacketTest, Receive) { // sizeof(sockaddr_ll). ASSERT_THAT(src_len, AnyOf(Eq(sizeof(src)), Eq(sizeof(src) - 2))); - // TODO(b/129292371): Verify protocol once we return it. + // TODO(gvisor.dev/issue/173): Verify protocol once we return it. // Verify the source address. EXPECT_EQ(src.sll_family, AF_PACKET); EXPECT_EQ(src.sll_ifindex, GetLoopbackIndex()); EXPECT_EQ(src.sll_halen, ETH_ALEN); + EXPECT_EQ(ntohs(src.sll_protocol), ETH_P_IP); // This came from the loopback device, so the address is all 0s. for (int i = 0; i < src.sll_halen; i++) { EXPECT_EQ(src.sll_addr[i], 0); @@ -212,7 +240,7 @@ TEST_P(RawPacketTest, Receive) { // Send via a packet socket. TEST_P(RawPacketTest, Send) { - // TODO(b/129292371): Remove once we support packet socket writing. + // TODO(gvisor.dev/issue/173): Remove once we support packet socket writing. SKIP_IF(IsRunningOnGvisor()); // Let's send a UDP packet and receive it using a regular UDP socket. @@ -277,7 +305,7 @@ TEST_P(RawPacketTest, Send) { sizeof(kMessage)); // Send it. - ASSERT_THAT(sendto(socket_, send_buf, sizeof(send_buf), 0, + ASSERT_THAT(sendto(s_, send_buf, sizeof(send_buf), 0, reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)), SyscallSucceedsWithValue(sizeof(send_buf))); @@ -286,13 +314,13 @@ TEST_P(RawPacketTest, Send) { pfd.fd = udp_sock.get(); pfd.events = POLLIN; ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 5000), SyscallSucceedsWithValue(1)); - pfd.fd = socket_; + pfd.fd = s_; pfd.events = POLLIN; ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 5000), SyscallSucceedsWithValue(1)); // Receive on the packet socket. char recv_buf[sizeof(send_buf)]; - ASSERT_THAT(recv(socket_, recv_buf, sizeof(recv_buf), 0), + ASSERT_THAT(recv(s_, recv_buf, sizeof(recv_buf), 0), SyscallSucceedsWithValue(sizeof(recv_buf))); ASSERT_EQ(memcmp(recv_buf, send_buf, sizeof(send_buf)), 0); @@ -309,6 +337,318 @@ TEST_P(RawPacketTest, Send) { EXPECT_EQ(src.sin_addr.s_addr, htonl(INADDR_LOOPBACK)); } +// Check that setting SO_RCVBUF below min is clamped to the minimum +// receive buffer size. +TEST_P(RawPacketTest, SetSocketRecvBufBelowMin) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + // Discover minimum receive buf size by trying to set it to zero. + // See: + // https://github.com/torvalds/linux/blob/a5dc8300df75e8b8384b4c82225f1e4a0b4d9b55/net/core/sock.c#L820 + constexpr int kRcvBufSz = 0; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)), + SyscallSucceeds()); + + int min = 0; + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &min, &min_len), + SyscallSucceeds()); + + // Linux doubles the value so let's use a value that when doubled will still + // be smaller than min. + int below_min = min / 2 - 1; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &below_min, sizeof(below_min)), + SyscallSucceeds()); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &val, &val_len), + SyscallSucceeds()); + + ASSERT_EQ(min, val); +} + +// Check that setting SO_RCVBUF above max is clamped to the maximum +// receive buffer size. +TEST_P(RawPacketTest, SetSocketRecvBufAboveMax) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + // Discover max buf size by trying to set the largest possible buffer size. + constexpr int kRcvBufSz = 0xffffffff; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)), + SyscallSucceeds()); + + int max = 0; + socklen_t max_len = sizeof(max); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &max, &max_len), + SyscallSucceeds()); + + int above_max = max + 1; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &above_max, sizeof(above_max)), + SyscallSucceeds()); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &val, &val_len), + SyscallSucceeds()); + ASSERT_EQ(max, val); +} + +// Check that setting SO_RCVBUF min <= kRcvBufSz <= max is honored. +TEST_P(RawPacketTest, SetSocketRecvBuf) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int max = 0; + int min = 0; + { + // Discover max buf size by trying to set a really large buffer size. + constexpr int kRcvBufSz = 0xffffffff; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)), + SyscallSucceeds()); + + max = 0; + socklen_t max_len = sizeof(max); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &max, &max_len), + SyscallSucceeds()); + } + + { + // Discover minimum buffer size by trying to set a zero size receive buffer + // size. + // See: + // https://github.com/torvalds/linux/blob/a5dc8300df75e8b8384b4c82225f1e4a0b4d9b55/net/core/sock.c#L820 + constexpr int kRcvBufSz = 0; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)), + SyscallSucceeds()); + + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &min, &min_len), + SyscallSucceeds()); + } + + int quarter_sz = min + (max - min) / 4; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &quarter_sz, sizeof(quarter_sz)), + SyscallSucceeds()); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &val, &val_len), + SyscallSucceeds()); + + // Linux doubles the value set by SO_SNDBUF/SO_RCVBUF. + // TODO(gvisor.dev/issue/2926): Remove when Netstack matches linux behavior. + if (!IsRunningOnGvisor()) { + quarter_sz *= 2; + } + ASSERT_EQ(quarter_sz, val); +} + +// Check that setting SO_SNDBUF below min is clamped to the minimum +// receive buffer size. +TEST_P(RawPacketTest, SetSocketSendBufBelowMin) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + // Discover minimum buffer size by trying to set it to zero. + constexpr int kSndBufSz = 0; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &kSndBufSz, sizeof(kSndBufSz)), + SyscallSucceeds()); + + int min = 0; + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &min, &min_len), + SyscallSucceeds()); + + // Linux doubles the value so let's use a value that when doubled will still + // be smaller than min. + int below_min = min / 2 - 1; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &below_min, sizeof(below_min)), + SyscallSucceeds()); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &val, &val_len), + SyscallSucceeds()); + + ASSERT_EQ(min, val); +} + +// Check that setting SO_SNDBUF above max is clamped to the maximum +// send buffer size. +TEST_P(RawPacketTest, SetSocketSendBufAboveMax) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + // Discover maximum buffer size by trying to set it to a large value. + constexpr int kSndBufSz = 0xffffffff; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &kSndBufSz, sizeof(kSndBufSz)), + SyscallSucceeds()); + + int max = 0; + socklen_t max_len = sizeof(max); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &max, &max_len), + SyscallSucceeds()); + + int above_max = max + 1; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &above_max, sizeof(above_max)), + SyscallSucceeds()); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &val, &val_len), + SyscallSucceeds()); + ASSERT_EQ(max, val); +} + +// Check that setting SO_SNDBUF min <= kSndBufSz <= max is honored. +TEST_P(RawPacketTest, SetSocketSendBuf) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int max = 0; + int min = 0; + { + // Discover maximum buffer size by trying to set it to a large value. + constexpr int kSndBufSz = 0xffffffff; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &kSndBufSz, sizeof(kSndBufSz)), + SyscallSucceeds()); + + max = 0; + socklen_t max_len = sizeof(max); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &max, &max_len), + SyscallSucceeds()); + } + + { + // Discover minimum buffer size by trying to set it to zero. + constexpr int kSndBufSz = 0; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &kSndBufSz, sizeof(kSndBufSz)), + SyscallSucceeds()); + + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &min, &min_len), + SyscallSucceeds()); + } + + int quarter_sz = min + (max - min) / 4; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &quarter_sz, sizeof(quarter_sz)), + SyscallSucceeds()); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &val, &val_len), + SyscallSucceeds()); + + // Linux doubles the value set by SO_SNDBUF/SO_RCVBUF. + // TODO(gvisor.dev/issue/2926): Remove the gvisor special casing when Netstack + // matches linux behavior. + if (!IsRunningOnGvisor()) { + quarter_sz *= 2; + } + + ASSERT_EQ(quarter_sz, val); +} + +TEST_P(RawPacketTest, GetSocketError) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_ERROR, &val, &val_len), + SyscallSucceeds()); + ASSERT_EQ(val, 0); +} + +TEST_P(RawPacketTest, GetSocketErrorBind) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + { + // Bind to the loopback device. + struct sockaddr_ll bind_addr = {}; + bind_addr.sll_family = AF_PACKET; + bind_addr.sll_protocol = htons(GetParam()); + bind_addr.sll_ifindex = GetLoopbackIndex(); + + ASSERT_THAT(bind(s_, reinterpret_cast<struct sockaddr*>(&bind_addr), + sizeof(bind_addr)), + SyscallSucceeds()); + + // SO_ERROR should return no errors. + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_ERROR, &val, &val_len), + SyscallSucceeds()); + ASSERT_EQ(val, 0); + } + + { + // Now try binding to an invalid interface. + struct sockaddr_ll bind_addr = {}; + bind_addr.sll_family = AF_PACKET; + bind_addr.sll_protocol = htons(GetParam()); + bind_addr.sll_ifindex = 0xffff; // Just pick a really large number. + + // Binding should fail with EINVAL + ASSERT_THAT(bind(s_, reinterpret_cast<struct sockaddr*>(&bind_addr), + sizeof(bind_addr)), + SyscallFailsWithErrno(ENODEV)); + + // SO_ERROR does not return error when the device is invalid. + // On Linux there is just one odd ball condition where this can return + // an error where the device was valid and then removed or disabled + // between the first check for index and the actual registration of + // the packet endpoint. On Netstack this is not possible as the stack + // global mutex is held during registration and check. + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_ERROR, &val, &val_len), + SyscallSucceeds()); + ASSERT_EQ(val, 0); + } +} + +#ifndef __fuchsia__ + +TEST_P(RawPacketTest, SetSocketDetachFilterNoInstalledFilter) { + // TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. + // + // gVisor returns no error on SO_DETACH_FILTER even if there is no filter + // attached unlike linux which does return ENOENT in such cases. This is + // because gVisor doesn't support SO_ATTACH_FILTER and just silently returns + // success. + if (IsRunningOnGvisor()) { + constexpr int val = 0; + ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), + SyscallSucceeds()); + return; + } + constexpr int val = 0; + ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), + SyscallFailsWithErrno(ENOENT)); +} + +TEST_P(RawPacketTest, GetSocketDetachFilter) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_DETACH_FILTER, &val, &val_len), + SyscallFailsWithErrno(ENOPROTOOPT)); +} + +#endif // __fuchsia__ + INSTANTIATE_TEST_SUITE_P(AllInetTests, RawPacketTest, ::testing::Values(ETH_P_IP, ETH_P_ALL)); diff --git a/test/syscalls/linux/partial_bad_buffer.cc b/test/syscalls/linux/partial_bad_buffer.cc index 33822ee57..df7129acc 100644 --- a/test/syscalls/linux/partial_bad_buffer.cc +++ b/test/syscalls/linux/partial_bad_buffer.cc @@ -18,7 +18,9 @@ #include <netinet/tcp.h> #include <sys/mman.h> #include <sys/socket.h> +#include <sys/stat.h> #include <sys/syscall.h> +#include <sys/types.h> #include <sys/uio.h> #include <unistd.h> @@ -62,9 +64,9 @@ class PartialBadBufferTest : public ::testing::Test { // Write some initial data. size_t size = sizeof(kMessage) - 1; EXPECT_THAT(WriteFd(fd_, &kMessage, size), SyscallSucceedsWithValue(size)); - ASSERT_THAT(lseek(fd_, 0, SEEK_SET), SyscallSucceeds()); + // Map a useable buffer. addr_ = mmap(0, 2 * kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); ASSERT_NE(addr_, MAP_FAILED); @@ -79,6 +81,15 @@ class PartialBadBufferTest : public ::testing::Test { bad_buffer_ = buf + kPageSize - 1; } + off_t Size() { + struct stat st; + int rc = fstat(fd_, &st); + if (rc < 0) { + return static_cast<off_t>(rc); + } + return st.st_size; + } + void TearDown() override { EXPECT_THAT(munmap(addr_, 2 * kPageSize), SyscallSucceeds()) << addr_; EXPECT_THAT(close(fd_), SyscallSucceeds()); @@ -165,97 +176,99 @@ TEST_F(PartialBadBufferTest, PreadvSmall) { } TEST_F(PartialBadBufferTest, WriteBig) { - // FIXME(b/24788078): The sentry write syscalls will return immediately - // if Access returns an error, but Access may not return an error - // and the sentry will instead perform a partial write. - SKIP_IF(IsRunningOnGvisor()); + off_t orig_size = Size(); + int n; - EXPECT_THAT(RetryEINTR(write)(fd_, bad_buffer_, kPageSize), - SyscallFailsWithErrno(EFAULT)); + ASSERT_THAT(lseek(fd_, orig_size, SEEK_SET), SyscallSucceeds()); + EXPECT_THAT( + (n = RetryEINTR(write)(fd_, bad_buffer_, kPageSize)), + AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1))); + EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0)); } TEST_F(PartialBadBufferTest, WriteSmall) { - // FIXME(b/24788078): The sentry write syscalls will return immediately - // if Access returns an error, but Access may not return an error - // and the sentry will instead perform a partial write. - SKIP_IF(IsRunningOnGvisor()); + off_t orig_size = Size(); + int n; - EXPECT_THAT(RetryEINTR(write)(fd_, bad_buffer_, 10), - SyscallFailsWithErrno(EFAULT)); + ASSERT_THAT(lseek(fd_, orig_size, SEEK_SET), SyscallSucceeds()); + EXPECT_THAT( + (n = RetryEINTR(write)(fd_, bad_buffer_, 10)), + AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1))); + EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0)); } TEST_F(PartialBadBufferTest, PwriteBig) { - // FIXME(b/24788078): The sentry write syscalls will return immediately - // if Access returns an error, but Access may not return an error - // and the sentry will instead perform a partial write. - SKIP_IF(IsRunningOnGvisor()); + off_t orig_size = Size(); + int n; - EXPECT_THAT(RetryEINTR(pwrite)(fd_, bad_buffer_, kPageSize, 0), - SyscallFailsWithErrno(EFAULT)); + EXPECT_THAT( + (n = RetryEINTR(pwrite)(fd_, bad_buffer_, kPageSize, orig_size)), + AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1))); + EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0)); } TEST_F(PartialBadBufferTest, PwriteSmall) { - // FIXME(b/24788078): The sentry write syscalls will return immediately - // if Access returns an error, but Access may not return an error - // and the sentry will instead perform a partial write. - SKIP_IF(IsRunningOnGvisor()); + off_t orig_size = Size(); + int n; - EXPECT_THAT(RetryEINTR(pwrite)(fd_, bad_buffer_, 10, 0), - SyscallFailsWithErrno(EFAULT)); + EXPECT_THAT( + (n = RetryEINTR(pwrite)(fd_, bad_buffer_, 10, orig_size)), + AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1))); + EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0)); } TEST_F(PartialBadBufferTest, WritevBig) { - // FIXME(b/24788078): The sentry write syscalls will return immediately - // if Access returns an error, but Access may not return an error - // and the sentry will instead perform a partial write. - SKIP_IF(IsRunningOnGvisor()); - struct iovec vec; vec.iov_base = bad_buffer_; vec.iov_len = kPageSize; + off_t orig_size = Size(); + int n; - EXPECT_THAT(RetryEINTR(writev)(fd_, &vec, 1), SyscallFailsWithErrno(EFAULT)); + ASSERT_THAT(lseek(fd_, orig_size, SEEK_SET), SyscallSucceeds()); + EXPECT_THAT( + (n = RetryEINTR(writev)(fd_, &vec, 1)), + AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1))); + EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0)); } TEST_F(PartialBadBufferTest, WritevSmall) { - // FIXME(b/24788078): The sentry write syscalls will return immediately - // if Access returns an error, but Access may not return an error - // and the sentry will instead perform a partial write. - SKIP_IF(IsRunningOnGvisor()); - struct iovec vec; vec.iov_base = bad_buffer_; vec.iov_len = 10; + off_t orig_size = Size(); + int n; - EXPECT_THAT(RetryEINTR(writev)(fd_, &vec, 1), SyscallFailsWithErrno(EFAULT)); + ASSERT_THAT(lseek(fd_, orig_size, SEEK_SET), SyscallSucceeds()); + EXPECT_THAT( + (n = RetryEINTR(writev)(fd_, &vec, 1)), + AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1))); + EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0)); } TEST_F(PartialBadBufferTest, PwritevBig) { - // FIXME(b/24788078): The sentry write syscalls will return immediately - // if Access returns an error, but Access may not return an error - // and the sentry will instead perform a partial write. - SKIP_IF(IsRunningOnGvisor()); - struct iovec vec; vec.iov_base = bad_buffer_; vec.iov_len = kPageSize; + off_t orig_size = Size(); + int n; - EXPECT_THAT(RetryEINTR(pwritev)(fd_, &vec, 1, 0), - SyscallFailsWithErrno(EFAULT)); + EXPECT_THAT( + (n = RetryEINTR(pwritev)(fd_, &vec, 1, orig_size)), + AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1))); + EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0)); } TEST_F(PartialBadBufferTest, PwritevSmall) { - // FIXME(b/24788078): The sentry write syscalls will return immediately - // if Access returns an error, but Access may not return an error - // and the sentry will instead perform a partial write. - SKIP_IF(IsRunningOnGvisor()); - struct iovec vec; vec.iov_base = bad_buffer_; vec.iov_len = 10; + off_t orig_size = Size(); + int n; - EXPECT_THAT(RetryEINTR(pwritev)(fd_, &vec, 1, 0), - SyscallFailsWithErrno(EFAULT)); + EXPECT_THAT( + (n = RetryEINTR(pwritev)(fd_, &vec, 1, orig_size)), + AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1))); + EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0)); } // getdents returns EFAULT when the you claim the buffer is large enough, but @@ -283,29 +296,6 @@ TEST_F(PartialBadBufferTest, GetdentsOneEntry) { SyscallSucceedsWithValue(Gt(0))); } -// Verify that when write returns EFAULT the kernel hasn't silently written -// the initial valid bytes. -TEST_F(PartialBadBufferTest, WriteEfaultIsntPartial) { - // FIXME(b/24788078): The sentry write syscalls will return immediately - // if Access returns an error, but Access may not return an error - // and the sentry will instead perform a partial write. - SKIP_IF(IsRunningOnGvisor()); - - bad_buffer_[0] = 'A'; - EXPECT_THAT(RetryEINTR(write)(fd_, bad_buffer_, 10), - SyscallFailsWithErrno(EFAULT)); - - size_t size = 255; - char buf[255]; - memset(buf, 0, size); - - EXPECT_THAT(RetryEINTR(pread)(fd_, buf, size, 0), - SyscallSucceedsWithValue(sizeof(kMessage) - 1)); - - // 'A' has not been written. - EXPECT_STREQ(buf, kMessage); -} - PosixErrorOr<sockaddr_storage> InetLoopbackAddr(int family) { struct sockaddr_storage addr; memset(&addr, 0, sizeof(addr)); diff --git a/test/syscalls/linux/ping_socket.cc b/test/syscalls/linux/ping_socket.cc new file mode 100644 index 000000000..a9bfdb37b --- /dev/null +++ b/test/syscalls/linux/ping_socket.cc @@ -0,0 +1,91 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <netinet/in.h> +#include <netinet/ip.h> +#include <netinet/ip_icmp.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <unistd.h> + +#include <vector> + +#include "gtest/gtest.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/save_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { +namespace { + +class PingSocket : public ::testing::Test { + protected: + // Creates a socket to be used in tests. + void SetUp() override; + + // Closes the socket created by SetUp(). + void TearDown() override; + + // The loopback address. + struct sockaddr_in addr_; +}; + +void PingSocket::SetUp() { + // On some hosts ping sockets are restricted to specific groups using the + // sysctl "ping_group_range". + int s = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP); + if (s < 0 && errno == EPERM) { + GTEST_SKIP(); + } + close(s); + + addr_ = {}; + // Just a random port as the destination port number is irrelevant for ping + // sockets. + addr_.sin_port = 12345; + addr_.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + addr_.sin_family = AF_INET; +} + +void PingSocket::TearDown() {} + +// Test ICMP port exhaustion returns EAGAIN. +// +// We disable both random/cooperative S/R for this test as it makes way too many +// syscalls. +TEST_F(PingSocket, ICMPPortExhaustion_NoRandomSave) { + DisableSave ds; + std::vector<FileDescriptor> sockets; + constexpr int kSockets = 65536; + addr_.sin_port = 0; + for (int i = 0; i < kSockets; i++) { + auto s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, IPPROTO_ICMP)); + int ret = connect(s.get(), reinterpret_cast<struct sockaddr*>(&addr_), + sizeof(addr_)); + if (ret == 0) { + sockets.push_back(std::move(s)); + continue; + } + ASSERT_THAT(ret, SyscallFailsWithErrno(EAGAIN)); + break; + } +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/pipe.cc b/test/syscalls/linux/pipe.cc index c0b354e65..34291850d 100644 --- a/test/syscalls/linux/pipe.cc +++ b/test/syscalls/linux/pipe.cc @@ -25,6 +25,7 @@ #include "absl/time/clock.h" #include "absl/time/time.h" #include "test/util/file_descriptor.h" +#include "test/util/fs_util.h" #include "test/util/posix_error.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" @@ -144,11 +145,10 @@ TEST_P(PipeTest, Flags) { if (IsNamedPipe()) { // May be stubbed to zero; define locally. - constexpr int kLargefile = 0100000; EXPECT_THAT(fcntl(rfd_.get(), F_GETFL), - SyscallSucceedsWithValue(kLargefile | O_RDONLY)); + SyscallSucceedsWithValue(kOLargeFile | O_RDONLY)); EXPECT_THAT(fcntl(wfd_.get(), F_GETFL), - SyscallSucceedsWithValue(kLargefile | O_WRONLY)); + SyscallSucceedsWithValue(kOLargeFile | O_WRONLY)); } else { EXPECT_THAT(fcntl(rfd_.get(), F_GETFL), SyscallSucceedsWithValue(O_RDONLY)); EXPECT_THAT(fcntl(wfd_.get(), F_GETFL), SyscallSucceedsWithValue(O_WRONLY)); @@ -212,6 +212,20 @@ TEST(Pipe2Test, BadOptions) { EXPECT_THAT(pipe2(fds, 0xDEAD), SyscallFailsWithErrno(EINVAL)); } +// Tests that opening named pipes with O_TRUNC shouldn't cause an error, but +// calls to (f)truncate should. +TEST(NamedPipeTest, Truncate) { + const std::string tmp_path = NewTempAbsPath(); + SKIP_IF(mkfifo(tmp_path.c_str(), 0644) != 0); + + ASSERT_THAT(open(tmp_path.c_str(), O_NONBLOCK | O_RDONLY), SyscallSucceeds()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE( + Open(tmp_path.c_str(), O_RDWR | O_NONBLOCK | O_TRUNC)); + + ASSERT_THAT(truncate(tmp_path.c_str(), 0), SyscallFailsWithErrno(EINVAL)); + ASSERT_THAT(ftruncate(fd.get(), 0), SyscallFailsWithErrno(EINVAL)); +} + TEST_P(PipeTest, Seek) { SKIP_IF(!CreateBlocking()); @@ -251,6 +265,8 @@ TEST_P(PipeTest, OffsetCalls) { SyscallFailsWithErrno(ESPIPE)); struct iovec iov; + iov.iov_base = &buf; + iov.iov_len = sizeof(buf); EXPECT_THAT(preadv(wfd_.get(), &iov, 1, 0), SyscallFailsWithErrno(ESPIPE)); EXPECT_THAT(pwritev(rfd_.get(), &iov, 1, 0), SyscallFailsWithErrno(ESPIPE)); } @@ -615,11 +631,14 @@ INSTANTIATE_TEST_SUITE_P( "namednonblocking", [](int fds[2], bool* is_blocking, bool* is_namedpipe) { // Create a new file-based pipe (non-blocking). - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - ASSERT_THAT(unlink(file.path().c_str()), SyscallSucceeds()); - SKIP_IF(mkfifo(file.path().c_str(), 0644) != 0); - fds[0] = open(file.path().c_str(), O_NONBLOCK | O_RDONLY); - fds[1] = open(file.path().c_str(), O_NONBLOCK | O_WRONLY); + std::string path; + { + auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + path = file.path(); + } + SKIP_IF(mkfifo(path.c_str(), 0644) != 0); + fds[0] = open(path.c_str(), O_NONBLOCK | O_RDONLY); + fds[1] = open(path.c_str(), O_NONBLOCK | O_WRONLY); MaybeSave(); *is_blocking = false; *is_namedpipe = true; @@ -629,13 +648,15 @@ INSTANTIATE_TEST_SUITE_P( "namedblocking", [](int fds[2], bool* is_blocking, bool* is_namedpipe) { // Create a new file-based pipe (blocking). - auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - ASSERT_THAT(unlink(file.path().c_str()), SyscallSucceeds()); - SKIP_IF(mkfifo(file.path().c_str(), 0644) != 0); - ScopedThread t([&file, &fds]() { - fds[1] = open(file.path().c_str(), O_WRONLY); - }); - fds[0] = open(file.path().c_str(), O_RDONLY); + std::string path; + { + auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + path = file.path(); + } + SKIP_IF(mkfifo(path.c_str(), 0644) != 0); + ScopedThread t( + [&path, &fds]() { fds[1] = open(path.c_str(), O_WRONLY); }); + fds[0] = open(path.c_str(), O_RDONLY); t.Join(); MaybeSave(); *is_blocking = true; diff --git a/test/syscalls/linux/poll.cc b/test/syscalls/linux/poll.cc index 9e5aa7fd0..7a316427d 100644 --- a/test/syscalls/linux/poll.cc +++ b/test/syscalls/linux/poll.cc @@ -259,14 +259,14 @@ TEST_F(PollTest, Nfds) { TEST_PCHECK(getrlimit(RLIMIT_NOFILE, &rlim) == 0); // gVisor caps the number of FDs that epoll can use beyond RLIMIT_NOFILE. - constexpr rlim_t gVisorMax = 1048576; - if (rlim.rlim_cur > gVisorMax) { - rlim.rlim_cur = gVisorMax; + constexpr rlim_t maxFD = 4096; + if (rlim.rlim_cur > maxFD) { + rlim.rlim_cur = maxFD; TEST_PCHECK(setrlimit(RLIMIT_NOFILE, &rlim) == 0); } rlim_t max_fds = rlim.rlim_cur; - std::cout << "Using limit: " << max_fds; + std::cout << "Using limit: " << max_fds << std::endl; // Create an eventfd. Since its value is initially zero, it is writable. FileDescriptor efd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD()); @@ -275,7 +275,8 @@ TEST_F(PollTest, Nfds) { // Each entry in the 'fds' array refers to the eventfd and polls for // "writable" events (events=POLLOUT). This essentially guarantees that the // poll() is a no-op and allows negative testing of the 'nfds' parameter. - std::vector<struct pollfd> fds(max_fds, {.fd = efd.get(), .events = POLLOUT}); + std::vector<struct pollfd> fds(max_fds + 1, + {.fd = efd.get(), .events = POLLOUT}); // Verify that 'nfds' up to RLIMIT_NOFILE are allowed. EXPECT_THAT(RetryEINTR(poll)(fds.data(), 1, 1), SyscallSucceedsWithValue(1)); diff --git a/test/syscalls/linux/prctl.cc b/test/syscalls/linux/prctl.cc index d07571a5f..04c5161f5 100644 --- a/test/syscalls/linux/prctl.cc +++ b/test/syscalls/linux/prctl.cc @@ -226,5 +226,5 @@ int main(int argc, char** argv) { prctl(PR_GET_NO_NEW_PRIVS, 0, 0, 0, 0)); } - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/prctl_setuid.cc b/test/syscalls/linux/prctl_setuid.cc index 30f0d75b3..c4e9cf528 100644 --- a/test/syscalls/linux/prctl_setuid.cc +++ b/test/syscalls/linux/prctl_setuid.cc @@ -264,5 +264,5 @@ int main(int argc, char** argv) { prctl(PR_GET_KEEPCAPS, 0, 0, 0, 0); } - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/pread64.cc b/test/syscalls/linux/pread64.cc index 2cecf2e5f..bcdbbb044 100644 --- a/test/syscalls/linux/pread64.cc +++ b/test/syscalls/linux/pread64.cc @@ -14,6 +14,7 @@ #include <errno.h> #include <fcntl.h> +#include <linux/unistd.h> #include <sys/mman.h> #include <sys/socket.h> #include <sys/types.h> @@ -118,6 +119,21 @@ TEST_F(Pread64Test, EndOfFile) { EXPECT_THAT(pread64(fd.get(), buf, 1024, 0), SyscallSucceedsWithValue(0)); } +int memfd_create(const std::string& name, unsigned int flags) { + return syscall(__NR_memfd_create, name.c_str(), flags); +} + +TEST_F(Pread64Test, Overflow) { + int f = memfd_create("negative", 0); + const FileDescriptor fd(f); + + EXPECT_THAT(ftruncate(fd.get(), 0x7fffffffffffffffull), SyscallSucceeds()); + + char buf[10]; + EXPECT_THAT(pread64(fd.get(), buf, sizeof(buf), 0x7fffffffffffffffull), + SyscallFailsWithErrno(EINVAL)); +} + TEST(Pread64TestNoTempFile, CantReadSocketPair_NoRandomSave) { int sock_fds[2]; EXPECT_THAT(socketpair(AF_UNIX, SOCK_STREAM, 0, sock_fds), SyscallSucceeds()); diff --git a/test/syscalls/linux/preadv.cc b/test/syscalls/linux/preadv.cc index f7ea44054..5b0743fe9 100644 --- a/test/syscalls/linux/preadv.cc +++ b/test/syscalls/linux/preadv.cc @@ -37,6 +37,7 @@ namespace testing { namespace { +// Stress copy-on-write. Attempts to reproduce b/38430174. TEST(PreadvTest, MMConcurrencyStress) { // Fill a one-page file with zeroes (the contents don't really matter). const auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( diff --git a/test/syscalls/linux/preadv2.cc b/test/syscalls/linux/preadv2.cc index c9246367d..4a9acd7ae 100644 --- a/test/syscalls/linux/preadv2.cc +++ b/test/syscalls/linux/preadv2.cc @@ -35,6 +35,8 @@ namespace { #ifndef SYS_preadv2 #if defined(__x86_64__) #define SYS_preadv2 327 +#elif defined(__aarch64__) +#define SYS_preadv2 286 #else #error "Unknown architecture" #endif @@ -202,7 +204,7 @@ TEST(Preadv2Test, TestInvalidOffset) { iov[0].iov_len = 0; EXPECT_THAT(preadv2(fd.get(), iov.get(), /*iovcnt=*/1, /*offset=*/-8, - /*flags=*/RWF_HIPRI), + /*flags=*/0), SyscallFailsWithErrno(EINVAL)); } diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc index e4c030bbb..d6b875dbf 100644 --- a/test/syscalls/linux/proc.cc +++ b/test/syscalls/linux/proc.cc @@ -37,6 +37,7 @@ #include <map> #include <memory> #include <ostream> +#include <regex> #include <string> #include <unordered_set> #include <utility> @@ -51,6 +52,7 @@ #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" +#include "absl/synchronization/notification.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "test/util/capability_util.h" @@ -98,9 +100,39 @@ namespace { #define SUID_DUMP_ROOT 2 #endif /* SUID_DUMP_ROOT */ -// O_LARGEFILE as defined by Linux. glibc tries to be clever by setting it to 0 -// because "it isn't needed", even though Linux can return it via F_GETFL. -constexpr int kOLargeFile = 00100000; +#if defined(__x86_64__) || defined(__i386__) +// This list of "required" fields is taken from reading the file +// arch/x86/kernel/cpu/proc.c and seeing which fields will be unconditionally +// printed by the kernel. +static const char* required_fields[] = { + "processor", + "vendor_id", + "cpu family", + "model\t\t:", + "model name", + "stepping", + "cpu MHz", + "fpu\t\t:", + "fpu_exception", + "cpuid level", + "wp", + "bogomips", + "clflush size", + "cache_alignment", + "address sizes", + "power management", +}; +#elif __aarch64__ +// This list of "required" fields is taken from reading the file +// arch/arm64/kernel/cpuinfo.c and seeing which fields will be unconditionally +// printed by the kernel. +static const char* required_fields[] = { + "processor", "BogoMIPS", "Features", "CPU implementer", + "CPU architecture", "CPU variant", "CPU part", "CPU revision", +}; +#else +#error "Unknown architecture" +#endif // Takes the subprocess command line and pid. // If it returns !OK, WithSubprocess returns immediately. @@ -183,7 +215,8 @@ PosixError WithSubprocess(SubprocessCallback const& running, siginfo_t info; // Wait until the child process has exited (WEXITED flag) but don't // reap the child (WNOWAIT flag). - waitid(P_PID, child_pid, &info, WNOWAIT | WEXITED); + EXPECT_THAT(waitid(P_PID, child_pid, &info, WNOWAIT | WEXITED), + SyscallSucceeds()); if (zombied) { // Arg of "Z" refers to a Zombied Process. @@ -714,28 +747,6 @@ TEST(ProcCpuinfo, RequiredFieldsArePresent) { ASSERT_FALSE(proc_cpuinfo.empty()); std::vector<std::string> cpuinfo_fields = absl::StrSplit(proc_cpuinfo, '\n'); - // This list of "required" fields is taken from reading the file - // arch/x86/kernel/cpu/proc.c and seeing which fields will be unconditionally - // printed by the kernel. - static const char* required_fields[] = { - "processor", - "vendor_id", - "cpu family", - "model\t\t:", - "model name", - "stepping", - "cpu MHz", - "fpu\t\t:", - "fpu_exception", - "cpuid level", - "wp", - "bogomips", - "clflush size", - "cache_alignment", - "address sizes", - "power management", - }; - // Check that the usual fields are there. We don't really care about the // contents. for (const std::string& field : required_fields) { @@ -743,8 +754,53 @@ TEST(ProcCpuinfo, RequiredFieldsArePresent) { } } -TEST(ProcCpuinfo, DeniesWrite) { - EXPECT_THAT(open("/proc/cpuinfo", O_WRONLY), SyscallFailsWithErrno(EACCES)); +TEST(ProcCpuinfo, DeniesWriteNonRoot) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_FOWNER))); + + // Do setuid in a separate thread so that after finishing this test, the + // process can still open files the test harness created before starting this + // test. Otherwise, the files are created by root (UID before the test), but + // cannot be opened by the `uid` set below after the test. After calling + // setuid(non-zero-UID), there is no way to get root privileges back. + ScopedThread([&] { + // Use syscall instead of glibc setuid wrapper because we want this setuid + // call to only apply to this task. POSIX threads, however, require that all + // threads have the same UIDs, so using the setuid wrapper sets all threads' + // real UID. + // Also drops capabilities. + constexpr int kNobody = 65534; + EXPECT_THAT(syscall(SYS_setuid, kNobody), SyscallSucceeds()); + EXPECT_THAT(open("/proc/cpuinfo", O_WRONLY), SyscallFailsWithErrno(EACCES)); + // TODO(gvisor.dev/issue/1193): Properly support setting size attributes in + // kernfs. + if (!IsRunningOnGvisor() || IsRunningWithVFS1()) { + EXPECT_THAT(truncate("/proc/cpuinfo", 123), + SyscallFailsWithErrno(EACCES)); + } + }); +} + +// With root privileges, it is possible to open /proc/cpuinfo with write mode, +// but all write operations will return EIO. +TEST(ProcCpuinfo, DeniesWriteRoot) { + // VFS1 does not behave differently for root/non-root. + SKIP_IF(IsRunningWithVFS1()); + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_FOWNER))); + + int fd; + EXPECT_THAT(fd = open("/proc/cpuinfo", O_WRONLY), SyscallSucceeds()); + if (fd > 0) { + EXPECT_THAT(write(fd, "x", 1), SyscallFailsWithErrno(EIO)); + EXPECT_THAT(pwrite(fd, "x", 1, 123), SyscallFailsWithErrno(EIO)); + } + // TODO(gvisor.dev/issue/1193): Properly support setting size attributes in + // kernfs. + if (!IsRunningOnGvisor() || IsRunningWithVFS1()) { + if (fd > 0) { + EXPECT_THAT(ftruncate(fd, 123), SyscallFailsWithErrno(EIO)); + } + EXPECT_THAT(truncate("/proc/cpuinfo", 123), SyscallFailsWithErrno(EIO)); + } } // Sanity checks that uptime is present. @@ -983,7 +1039,7 @@ constexpr uint64_t kMappingSize = 100 << 20; // Tolerance on RSS comparisons to account for background thread mappings, // reclaimed pages, newly faulted pages, etc. -constexpr uint64_t kRSSTolerance = 5 << 20; +constexpr uint64_t kRSSTolerance = 10 << 20; // Capture RSS before and after an anonymous mapping with passed prot. void MapPopulateRSS(int prot, uint64_t* before, uint64_t* after) { @@ -1315,8 +1371,6 @@ TEST(ProcPidSymlink, SubprocessRunning) { SyscallSucceedsWithValue(sizeof(buf))); } -// FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux -// on proc files. TEST(ProcPidSymlink, SubprocessZombied) { ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); @@ -1326,7 +1380,7 @@ TEST(ProcPidSymlink, SubprocessZombied) { int want = EACCES; if (!IsRunningOnGvisor()) { auto version = ASSERT_NO_ERRNO_AND_VALUE(GetKernelVersion()); - if (version.major == 4 && version.minor > 3) { + if (version.major > 4 || (version.major == 4 && version.minor > 3)) { want = ENOENT; } } @@ -1339,24 +1393,25 @@ TEST(ProcPidSymlink, SubprocessZombied) { SyscallFailsWithErrno(want)); } - // FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux - // on proc files. - // 4.17 & gVisor: Syscall succeeds and returns 1 - // EXPECT_THAT(ReadlinkWhileZombied("ns/pid", buf, sizeof(buf)), - // SyscallFailsWithErrno(EACCES)); + // FIXME(gvisor.dev/issue/164): Inconsistent behavior between linux on proc + // files. + // + // ~4.3: Syscall fails with EACCES. + // 4.17: Syscall succeeds and returns 1. + // + if (!IsRunningOnGvisor()) { + return; + } - // FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux - // on proc files. - // 4.17 & gVisor: Syscall succeeds and returns 1. - // EXPECT_THAT(ReadlinkWhileZombied("ns/user", buf, sizeof(buf)), - // SyscallFailsWithErrno(EACCES)); + EXPECT_THAT(ReadlinkWhileZombied("ns/pid", buf, sizeof(buf)), + SyscallFailsWithErrno(want)); + + EXPECT_THAT(ReadlinkWhileZombied("ns/user", buf, sizeof(buf)), + SyscallFailsWithErrno(want)); } // Test whether /proc/PID/ symlinks can be read for an exited process. TEST(ProcPidSymlink, SubprocessExited) { - // FIXME(gvisor.dev/issue/164): These all succeed on gVisor. - SKIP_IF(IsRunningOnGvisor()); - char buf[1]; EXPECT_THAT(ReadlinkWhileExited("exe", buf, sizeof(buf)), @@ -1414,14 +1469,24 @@ TEST(ProcPidFile, SubprocessRunning) { EXPECT_THAT(ReadWhileRunning("uid_map", buf, sizeof(buf)), SyscallSucceedsWithValue(sizeof(buf))); + + EXPECT_THAT(ReadWhileRunning("oom_score", buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + + EXPECT_THAT(ReadWhileRunning("oom_score_adj", buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); } // Test whether /proc/PID/ files can be read for a zombie process. TEST(ProcPidFile, SubprocessZombie) { char buf[1]; - // 4.17: Succeeds and returns 1 - // gVisor: Succeeds and returns 0 + // FIXME(gvisor.dev/issue/164): Loosen requirement due to inconsistent + // behavior on different kernels. + // + // ~4.3: Succeds and returns 0. + // 4.17: Succeeds and returns 1. + // gVisor: Succeeds and returns 0. EXPECT_THAT(ReadWhileZombied("auxv", buf, sizeof(buf)), SyscallSucceeds()); EXPECT_THAT(ReadWhileZombied("cmdline", buf, sizeof(buf)), @@ -1445,9 +1510,18 @@ TEST(ProcPidFile, SubprocessZombie) { EXPECT_THAT(ReadWhileZombied("uid_map", buf, sizeof(buf)), SyscallSucceedsWithValue(sizeof(buf))); + EXPECT_THAT(ReadWhileZombied("oom_score", buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + + EXPECT_THAT(ReadWhileZombied("oom_score_adj", buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + // FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux // on proc files. + // + // ~4.3: Fails and returns EACCES. // gVisor & 4.17: Succeeds and returns 1. + // // EXPECT_THAT(ReadWhileZombied("io", buf, sizeof(buf)), // SyscallFailsWithErrno(EACCES)); } @@ -1456,9 +1530,12 @@ TEST(ProcPidFile, SubprocessZombie) { TEST(ProcPidFile, SubprocessExited) { char buf[1]; - // FIXME(gvisor.dev/issue/164): Inconsistent behavior between kernels + // FIXME(gvisor.dev/issue/164): Inconsistent behavior between kernels. + // + // ~4.3: Fails and returns ESRCH. // gVisor: Fails with ESRCH. // 4.17: Succeeds and returns 1. + // // EXPECT_THAT(ReadWhileExited("auxv", buf, sizeof(buf)), // SyscallFailsWithErrno(ESRCH)); @@ -1500,6 +1577,15 @@ TEST(ProcPidFile, SubprocessExited) { EXPECT_THAT(ReadWhileExited("uid_map", buf, sizeof(buf)), SyscallSucceedsWithValue(sizeof(buf))); + + if (!IsRunningOnGvisor()) { + // FIXME(gvisor.dev/issue/164): Succeeds on gVisor. + EXPECT_THAT(ReadWhileExited("oom_score", buf, sizeof(buf)), + SyscallFailsWithErrno(ESRCH)); + } + + EXPECT_THAT(ReadWhileExited("oom_score_adj", buf, sizeof(buf)), + SyscallFailsWithErrno(ESRCH)); } PosixError DirContainsImpl(absl::string_view path, @@ -1630,7 +1716,7 @@ TEST(ProcTask, KilledThreadsDisappear) { EXPECT_NO_ERRNO(DirContainsExactly("/proc/self/task", TaskFiles(initial, {child1.Tid()}))); - // Stat child1's task file. + // Stat child1's task file. Regression test for b/32097707. struct stat statbuf; const std::string child1_task_file = absl::StrCat("/proc/self/task/", child1.Tid()); @@ -1658,7 +1744,7 @@ TEST(ProcTask, KilledThreadsDisappear) { EXPECT_NO_ERRNO(EventuallyDirContainsExactly( "/proc/self/task", TaskFiles(initial, {child3.Tid(), child5.Tid()}))); - // Stat child1's task file again. This time it should fail. + // Stat child1's task file again. This time it should fail. See b/32097707. EXPECT_THAT(stat(child1_task_file.c_str(), &statbuf), SyscallFailsWithErrno(ENOENT)); @@ -1813,7 +1899,7 @@ TEST(ProcSysVmOvercommitMemory, HasNumericValue) { } // Check that link for proc fd entries point the target node, not the -// symlink itself. +// symlink itself. Regression test for b/31155070. TEST(ProcTaskFd, FstatatFollowsSymlink) { const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); const FileDescriptor fd = @@ -1872,6 +1958,20 @@ TEST(ProcMounts, IsSymlink) { EXPECT_EQ(link, "self/mounts"); } +TEST(ProcSelfMountinfo, RequiredFieldsArePresent) { + auto mountinfo = + ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/mountinfo")); + EXPECT_THAT( + mountinfo, + AllOf( + // Root mount. + ContainsRegex( + R"([0-9]+ [0-9]+ [0-9]+:[0-9]+ /\S* / (rw|ro).*- \S+ \S+ (rw|ro)\S*)"), + // Proc mount - always rw. + ContainsRegex( + R"([0-9]+ [0-9]+ [0-9]+:[0-9]+ / /proc rw.*- \S+ \S+ rw\S*)"))); +} + // Check that /proc/self/mounts looks something like a real mounts file. TEST(ProcSelfMounts, RequiredFieldsArePresent) { auto mounts = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/mounts")); @@ -1884,43 +1984,77 @@ TEST(ProcSelfMounts, RequiredFieldsArePresent) { } void CheckDuplicatesRecursively(std::string path) { - errno = 0; - DIR* dir = opendir(path.c_str()); - if (dir == nullptr) { - // Ignore any directories we can't read or missing directories as the - // directory could have been deleted/mutated from the time the parent - // directory contents were read. - return; - } - auto dir_closer = Cleanup([&dir]() { closedir(dir); }); - std::unordered_set<std::string> children; - while (true) { - // Readdir(3): If the end of the directory stream is reached, NULL is - // returned and errno is not changed. If an error occurs, NULL is returned - // and errno is set appropriately. To distinguish end of stream and from an - // error, set errno to zero before calling readdir() and then check the - // value of errno if NULL is returned. + std::vector<std::string> child_dirs; + + // There is the known issue of the linux procfs, that two consequent calls of + // readdir can return the same entry twice if between these calls one or more + // entries have been removed from this directory. + int max_attempts = 5; + for (int i = 0; i < max_attempts; i++) { + child_dirs.clear(); errno = 0; - struct dirent* dp = readdir(dir); - if (dp == nullptr) { - ASSERT_EQ(errno, 0) << path; - break; // We're done. + bool success = true; + DIR* dir = opendir(path.c_str()); + if (dir == nullptr) { + // Ignore any directories we can't read or missing directories as the + // directory could have been deleted/mutated from the time the parent + // directory contents were read. + return; } + auto dir_closer = Cleanup([&dir]() { closedir(dir); }); + std::unordered_set<std::string> children; + while (true) { + // Readdir(3): If the end of the directory stream is reached, NULL is + // returned and errno is not changed. If an error occurs, NULL is + // returned and errno is set appropriately. To distinguish end of stream + // and from an error, set errno to zero before calling readdir() and then + // check the value of errno if NULL is returned. + errno = 0; + struct dirent* dp = readdir(dir); + if (dp == nullptr) { + // Linux will return EINVAL when calling getdents on a /proc/tid/net + // file corresponding to a zombie task. + // See fs/proc/proc_net.c:proc_tgid_net_readdir(). + // + // We just ignore the directory in this case. + if (errno == EINVAL && absl::StartsWith(path, "/proc/") && + absl::EndsWith(path, "/net")) { + break; + } - if (strcmp(dp->d_name, ".") == 0 || strcmp(dp->d_name, "..") == 0) { - continue; + // Otherwise, no errors are allowed. + ASSERT_EQ(errno, 0) << path; + break; // We're done. + } + + const std::string name = dp->d_name; + + if (name == "." || name == "..") { + continue; + } + + // Ignore a duplicate entry if it isn't the last attempt. + if (i == max_attempts - 1) { + ASSERT_EQ(children.find(name), children.end()) + << absl::StrCat(path, "/", name); + } else if (children.find(name) != children.end()) { + std::cerr << "Duplicate entry: " << i << ":" + << absl::StrCat(path, "/", name) << std::endl; + success = false; + break; + } + children.insert(name); + + if (dp->d_type == DT_DIR) { + child_dirs.push_back(name); + } } - - ASSERT_EQ(children.find(std::string(dp->d_name)), children.end()) - << dp->d_name; - children.insert(std::string(dp->d_name)); - - ASSERT_NE(dp->d_type, DT_UNKNOWN); - - if (dp->d_type != DT_DIR) { - continue; + if (success) { + break; } - CheckDuplicatesRecursively(absl::StrCat(path, "/", dp->d_name)); + } + for (auto dname = child_dirs.begin(); dname != child_dirs.end(); dname++) { + CheckDuplicatesRecursively(absl::StrCat(path, "/", *dname)); } } @@ -1983,10 +2117,48 @@ TEST(Proc, GetdentsEnoent) { }, nullptr, nullptr)); char buf[1024]; - ASSERT_THAT(syscall(SYS_getdents, fd.get(), buf, sizeof(buf)), + ASSERT_THAT(syscall(SYS_getdents64, fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(ENOENT)); } +void CheckSyscwFromIOFile(const std::string& path, const std::string& regex) { + std::string output; + ASSERT_NO_ERRNO(GetContents(path, &output)); + ASSERT_THAT(output, ContainsRegex(absl::StrCat("syscw:\\s+", regex, "\n"))); +} + +// Checks that there is variable accounting of IO between threads/tasks. +TEST(Proc, PidTidIOAccounting) { + absl::Notification notification; + + // Run a thread with a bunch of writes. Check that io account records exactly + // the number of write calls. File open/close is there to prevent buffering. + ScopedThread writer([¬ification] { + const int num_writes = 100; + for (int i = 0; i < num_writes; i++) { + auto path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + ASSERT_NO_ERRNO(SetContents(path.path(), "a")); + } + notification.Notify(); + const std::string& writer_dir = + absl::StrCat("/proc/", getpid(), "/task/", gettid(), "/io"); + + CheckSyscwFromIOFile(writer_dir, std::to_string(num_writes)); + }); + + // Run a thread and do no writes. Check that no writes are recorded. + ScopedThread noop([¬ification] { + notification.WaitForNotification(); + const std::string& noop_dir = + absl::StrCat("/proc/", getpid(), "/task/", gettid(), "/io"); + + CheckSyscwFromIOFile(noop_dir, "0"); + }); + + writer.Join(); + noop.Join(); +} + } // namespace } // namespace testing } // namespace gvisor @@ -1997,5 +2169,5 @@ int main(int argc, char** argv) { } gvisor::testing::TestInit(&argc, &argv); - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/proc_net.cc b/test/syscalls/linux/proc_net.cc index 897cf4950..b9a5a99bd 100644 --- a/test/syscalls/linux/proc_net.cc +++ b/test/syscalls/linux/proc_net.cc @@ -20,8 +20,13 @@ #include <sys/syscall.h> #include <sys/types.h> +#include <vector> + #include "gtest/gtest.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" #include "absl/time/clock.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/capability_util.h" @@ -33,6 +38,31 @@ namespace gvisor { namespace testing { namespace { +constexpr const char kProcNet[] = "/proc/net"; + +TEST(ProcNetSymlinkTarget, FileMode) { + struct stat s; + ASSERT_THAT(stat(kProcNet, &s), SyscallSucceeds()); + EXPECT_EQ(s.st_mode & S_IFMT, S_IFDIR); + EXPECT_EQ(s.st_mode & 0777, 0555); +} + +TEST(ProcNetSymlink, FileMode) { + struct stat s; + ASSERT_THAT(lstat(kProcNet, &s), SyscallSucceeds()); + EXPECT_EQ(s.st_mode & S_IFMT, S_IFLNK); + EXPECT_EQ(s.st_mode & 0777, 0777); +} + +TEST(ProcNetSymlink, Contents) { + char buf[40] = {}; + int n = readlink(kProcNet, buf, sizeof(buf)); + ASSERT_THAT(n, SyscallSucceeds()); + + buf[n] = 0; + EXPECT_STREQ(buf, "self/net"); +} + TEST(ProcNetIfInet6, Format) { auto ifinet6 = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/if_inet6")); EXPECT_THAT(ifinet6, @@ -67,9 +97,62 @@ TEST(ProcSysNetIpv4Sack, CanReadAndWrite) { EXPECT_EQ(buf, to_write); } +// DeviceEntry is an entry in /proc/net/dev +struct DeviceEntry { + std::string name; + uint64_t stats[16]; +}; + +PosixErrorOr<std::vector<DeviceEntry>> GetDeviceMetricsFromProc( + const std::string dev) { + std::vector<std::string> lines = absl::StrSplit(dev, '\n'); + std::vector<DeviceEntry> entries; + + // /proc/net/dev prints 2 lines of headers followed by a line of metrics for + // each network interface. + for (unsigned i = 2; i < lines.size(); i++) { + // Ignore empty lines. + if (lines[i].empty()) { + continue; + } + + std::vector<std::string> values = + absl::StrSplit(lines[i], ' ', absl::SkipWhitespace()); + + // Interface name + 16 values. + if (values.size() != 17) { + return PosixError(EINVAL, "invalid line: " + lines[i]); + } + + DeviceEntry entry; + entry.name = values[0]; + // Skip the interface name and read only the values. + for (unsigned j = 1; j < 17; j++) { + uint64_t num; + if (!absl::SimpleAtoi(values[j], &num)) { + return PosixError(EINVAL, "invalid value: " + values[j]); + } + entry.stats[j - 1] = num; + } + + entries.push_back(entry); + } + + return entries; +} + +// TEST(ProcNetDev, Format) tests that /proc/net/dev is parsable and +// contains at least one entry. +TEST(ProcNetDev, Format) { + auto dev = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/dev")); + auto entries = ASSERT_NO_ERRNO_AND_VALUE(GetDeviceMetricsFromProc(dev)); + + EXPECT_GT(entries.size(), 0); +} + PosixErrorOr<uint64_t> GetSNMPMetricFromProc(const std::string snmp, - const std::string &type, - const std::string &item) { + const std::string& type, + const std::string& item) { std::vector<std::string> snmp_vec = absl::StrSplit(snmp, '\n'); // /proc/net/snmp prints a line of headers followed by a line of metrics. @@ -127,7 +210,7 @@ TEST(ProcNetSnmp, TcpReset_NoRandomSave) { }; ASSERT_EQ(inet_pton(AF_INET, "127.0.0.1", &(sin.sin_addr)), 1); - ASSERT_THAT(connect(s.get(), (struct sockaddr *)&sin, sizeof(sin)), + ASSERT_THAT(connect(s.get(), (struct sockaddr*)&sin, sizeof(sin)), SyscallFailsWithErrno(ECONNREFUSED)); uint64_t newAttemptFails; @@ -172,19 +255,19 @@ TEST(ProcNetSnmp, TcpEstab_NoRandomSave) { }; ASSERT_EQ(inet_pton(AF_INET, "127.0.0.1", &(sin.sin_addr)), 1); - ASSERT_THAT(bind(s_listen.get(), (struct sockaddr *)&sin, sizeof(sin)), + ASSERT_THAT(bind(s_listen.get(), (struct sockaddr*)&sin, sizeof(sin)), SyscallSucceeds()); ASSERT_THAT(listen(s_listen.get(), 1), SyscallSucceeds()); // Get the port bound by the listening socket. socklen_t addrlen = sizeof(sin); ASSERT_THAT( - getsockname(s_listen.get(), reinterpret_cast<sockaddr *>(&sin), &addrlen), + getsockname(s_listen.get(), reinterpret_cast<sockaddr*>(&sin), &addrlen), SyscallSucceeds()); FileDescriptor s_connect = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, 0)); - ASSERT_THAT(connect(s_connect.get(), (struct sockaddr *)&sin, sizeof(sin)), + ASSERT_THAT(connect(s_connect.get(), (struct sockaddr*)&sin, sizeof(sin)), SyscallSucceeds()); auto s_accept = @@ -260,7 +343,7 @@ TEST(ProcNetSnmp, UdpNoPorts_NoRandomSave) { .sin_port = htons(4444), }; ASSERT_EQ(inet_pton(AF_INET, "127.0.0.1", &(sin.sin_addr)), 1); - ASSERT_THAT(sendto(s.get(), "a", 1, 0, (struct sockaddr *)&sin, sizeof(sin)), + ASSERT_THAT(sendto(s.get(), "a", 1, 0, (struct sockaddr*)&sin, sizeof(sin)), SyscallSucceedsWithValue(1)); uint64_t newOutDatagrams; @@ -275,7 +358,7 @@ TEST(ProcNetSnmp, UdpNoPorts_NoRandomSave) { EXPECT_EQ(oldNoPorts, newNoPorts - 1); } -TEST(ProcNetSnmp, UdpIn) { +TEST(ProcNetSnmp, UdpIn_NoRandomSave) { // TODO(gvisor.dev/issue/866): epsocket metrics are not savable. const DisableSave ds; @@ -295,18 +378,18 @@ TEST(ProcNetSnmp, UdpIn) { .sin_port = htons(0), }; ASSERT_EQ(inet_pton(AF_INET, "127.0.0.1", &(sin.sin_addr)), 1); - ASSERT_THAT(bind(server.get(), (struct sockaddr *)&sin, sizeof(sin)), + ASSERT_THAT(bind(server.get(), (struct sockaddr*)&sin, sizeof(sin)), SyscallSucceeds()); // Get the port bound by the server socket. socklen_t addrlen = sizeof(sin); ASSERT_THAT( - getsockname(server.get(), reinterpret_cast<sockaddr *>(&sin), &addrlen), + getsockname(server.get(), reinterpret_cast<sockaddr*>(&sin), &addrlen), SyscallSucceeds()); FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); ASSERT_THAT( - sendto(client.get(), "a", 1, 0, (struct sockaddr *)&sin, sizeof(sin)), + sendto(client.get(), "a", 1, 0, (struct sockaddr*)&sin, sizeof(sin)), SyscallSucceedsWithValue(1)); char buf[128]; @@ -326,6 +409,113 @@ TEST(ProcNetSnmp, UdpIn) { EXPECT_EQ(oldInDatagrams, newInDatagrams - 1); } +TEST(ProcNetSnmp, CheckNetStat) { + // TODO(b/155123175): SNMP and netstat don't work on gVisor. + SKIP_IF(IsRunningOnGvisor()); + + std::string contents = + ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/netstat")); + + int name_count = 0; + int value_count = 0; + std::vector<absl::string_view> lines = absl::StrSplit(contents, '\n'); + for (int i = 0; i + 1 < lines.size(); i += 2) { + std::vector<absl::string_view> names = + absl::StrSplit(lines[i], absl::ByAnyChar("\t ")); + std::vector<absl::string_view> values = + absl::StrSplit(lines[i + 1], absl::ByAnyChar("\t ")); + EXPECT_EQ(names.size(), values.size()) << " mismatch in lines '" << lines[i] + << "' and '" << lines[i + 1] << "'"; + for (int j = 0; j < names.size() && j < values.size(); ++j) { + if (names[j] == "TCPOrigDataSent" || names[j] == "TCPSynRetrans" || + names[j] == "TCPDSACKRecv" || names[j] == "TCPDSACKOfoRecv") { + ++name_count; + int64_t val; + if (absl::SimpleAtoi(values[j], &val)) { + ++value_count; + } + } + } + } + EXPECT_EQ(name_count, 4); + EXPECT_EQ(value_count, 4); +} + +TEST(ProcNetSnmp, Stat) { + struct stat st = {}; + ASSERT_THAT(stat("/proc/net/snmp", &st), SyscallSucceeds()); +} + +TEST(ProcNetSnmp, CheckSnmp) { + // TODO(b/155123175): SNMP and netstat don't work on gVisor. + SKIP_IF(IsRunningOnGvisor()); + + std::string contents = + ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp")); + + int name_count = 0; + int value_count = 0; + std::vector<absl::string_view> lines = absl::StrSplit(contents, '\n'); + for (int i = 0; i + 1 < lines.size(); i += 2) { + std::vector<absl::string_view> names = + absl::StrSplit(lines[i], absl::ByAnyChar("\t ")); + std::vector<absl::string_view> values = + absl::StrSplit(lines[i + 1], absl::ByAnyChar("\t ")); + EXPECT_EQ(names.size(), values.size()) << " mismatch in lines '" << lines[i] + << "' and '" << lines[i + 1] << "'"; + for (int j = 0; j < names.size() && j < values.size(); ++j) { + if (names[j] == "RetransSegs") { + ++name_count; + int64_t val; + if (absl::SimpleAtoi(values[j], &val)) { + ++value_count; + } + } + } + } + EXPECT_EQ(name_count, 1); + EXPECT_EQ(value_count, 1); +} + +TEST(ProcSysNetIpv4Recovery, Exists) { + EXPECT_THAT(open("/proc/sys/net/ipv4/tcp_recovery", O_RDONLY), + SyscallSucceeds()); +} + +TEST(ProcSysNetIpv4Recovery, CanReadAndWrite) { + // TODO(b/162988252): Enable save/restore for this test after the bug is + // fixed. + DisableSave ds; + + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability((CAP_DAC_OVERRIDE)))); + + auto const fd = ASSERT_NO_ERRNO_AND_VALUE( + Open("/proc/sys/net/ipv4/tcp_recovery", O_RDWR)); + + char buf[10] = {'\0'}; + char to_write = '2'; + + // Check initial value is set to 1. + EXPECT_THAT(PreadFd(fd.get(), &buf, sizeof(buf), 0), + SyscallSucceedsWithValue(sizeof(to_write) + 1)); + EXPECT_EQ(strcmp(buf, "1\n"), 0); + + // Set tcp_recovery to one of the allowed constants. + EXPECT_THAT(PwriteFd(fd.get(), &to_write, sizeof(to_write), 0), + SyscallSucceedsWithValue(sizeof(to_write))); + EXPECT_THAT(PreadFd(fd.get(), &buf, sizeof(buf), 0), + SyscallSucceedsWithValue(sizeof(to_write) + 1)); + EXPECT_EQ(strcmp(buf, "2\n"), 0); + + // Set tcp_recovery to any random value. + char kMessage[] = "100"; + EXPECT_THAT(PwriteFd(fd.get(), kMessage, strlen(kMessage), 0), + SyscallSucceedsWithValue(strlen(kMessage))); + EXPECT_THAT(PreadFd(fd.get(), buf, sizeof(kMessage), 0), + SyscallSucceedsWithValue(sizeof(kMessage))); + EXPECT_EQ(strcmp(buf, "100\n"), 0); +} + TEST(ProcSysNetIpv4IpForward, Exists) { auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/sys/net/ipv4/ip_forward", O_RDWR)); diff --git a/test/syscalls/linux/proc_net_tcp.cc b/test/syscalls/linux/proc_net_tcp.cc index 2659f6a98..5b6e3e3cd 100644 --- a/test/syscalls/linux/proc_net_tcp.cc +++ b/test/syscalls/linux/proc_net_tcp.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include <netinet/tcp.h> #include <sys/socket.h> #include <sys/stat.h> #include <sys/types.h> diff --git a/test/syscalls/linux/proc_net_udp.cc b/test/syscalls/linux/proc_net_udp.cc index f06f1a24b..786b4b4af 100644 --- a/test/syscalls/linux/proc_net_udp.cc +++ b/test/syscalls/linux/proc_net_udp.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include <netinet/tcp.h> #include <sys/socket.h> #include <sys/stat.h> #include <sys/types.h> diff --git a/test/syscalls/linux/proc_net_unix.cc b/test/syscalls/linux/proc_net_unix.cc index 66db0acaa..a63067586 100644 --- a/test/syscalls/linux/proc_net_unix.cc +++ b/test/syscalls/linux/proc_net_unix.cc @@ -106,7 +106,7 @@ PosixErrorOr<std::vector<UnixEntry>> ProcNetUnixEntries() { std::vector<UnixEntry> entries; std::vector<std::string> lines = absl::StrSplit(content, '\n'); std::cerr << "<contents of /proc/net/unix>" << std::endl; - for (std::string line : lines) { + for (const std::string& line : lines) { // Emit the proc entry to the test output to provide context for the test // results. std::cerr << line << std::endl; @@ -374,7 +374,7 @@ TEST(ProcNetUnix, DgramSocketStateDisconnectingOnBind) { // corresponding entries, as they don't have an address yet. if (IsRunningOnGvisor()) { ASSERT_EQ(entries.size(), 2); - for (auto e : entries) { + for (const auto& e : entries) { ASSERT_EQ(e.state, SS_DISCONNECTING); } } @@ -403,7 +403,7 @@ TEST(ProcNetUnix, DgramSocketStateConnectingOnConnect) { // corresponding entries, as they don't have an address yet. if (IsRunningOnGvisor()) { ASSERT_EQ(entries.size(), 2); - for (auto e : entries) { + for (const auto& e : entries) { ASSERT_EQ(e.state, SS_DISCONNECTING); } } diff --git a/test/syscalls/linux/proc_pid_oomscore.cc b/test/syscalls/linux/proc_pid_oomscore.cc new file mode 100644 index 000000000..707821a3f --- /dev/null +++ b/test/syscalls/linux/proc_pid_oomscore.cc @@ -0,0 +1,72 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> + +#include <exception> +#include <iostream> +#include <string> + +#include "test/util/fs_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +PosixErrorOr<int> ReadProcNumber(std::string path) { + ASSIGN_OR_RETURN_ERRNO(std::string contents, GetContents(path)); + EXPECT_EQ(contents[contents.length() - 1], '\n'); + + int num; + if (!absl::SimpleAtoi(contents, &num)) { + return PosixError(EINVAL, "invalid value: " + contents); + } + + return num; +} + +TEST(ProcPidOomscoreTest, BasicRead) { + auto const oom_score = + ASSERT_NO_ERRNO_AND_VALUE(ReadProcNumber("/proc/self/oom_score")); + EXPECT_LE(oom_score, 1000); + EXPECT_GE(oom_score, -1000); +} + +TEST(ProcPidOomscoreAdjTest, BasicRead) { + auto const oom_score = + ASSERT_NO_ERRNO_AND_VALUE(ReadProcNumber("/proc/self/oom_score_adj")); + + // oom_score_adj defaults to 0. + EXPECT_EQ(oom_score, 0); +} + +TEST(ProcPidOomscoreAdjTest, BasicWrite) { + constexpr int test_value = 7; + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/oom_score_adj", O_WRONLY)); + ASSERT_THAT( + RetryEINTR(write)(fd.get(), std::to_string(test_value).c_str(), 1), + SyscallSucceeds()); + + auto const oom_score = + ASSERT_NO_ERRNO_AND_VALUE(ReadProcNumber("/proc/self/oom_score_adj")); + EXPECT_EQ(oom_score, test_value); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/proc_pid_smaps.cc b/test/syscalls/linux/proc_pid_smaps.cc index 7f2e8f203..9fb1b3a2c 100644 --- a/test/syscalls/linux/proc_pid_smaps.cc +++ b/test/syscalls/linux/proc_pid_smaps.cc @@ -173,7 +173,7 @@ PosixErrorOr<std::vector<ProcPidSmapsEntry>> ParseProcPidSmaps( return; } unknown_fields.insert(std::string(key)); - std::cerr << "skipping unknown smaps field " << key; + std::cerr << "skipping unknown smaps field " << key << std::endl; }; auto lines = absl::StrSplit(contents, '\n', absl::SkipEmpty()); @@ -191,7 +191,7 @@ PosixErrorOr<std::vector<ProcPidSmapsEntry>> ParseProcPidSmaps( // amount of whitespace). if (!entry) { std::cerr << "smaps line not considered a maps line: " - << maybe_maps_entry.error_message(); + << maybe_maps_entry.error_message() << std::endl; return PosixError( EINVAL, absl::StrCat("smaps field line without preceding maps line: ", l)); diff --git a/test/syscalls/linux/ptrace.cc b/test/syscalls/linux/ptrace.cc index 8f3800380..926690eb8 100644 --- a/test/syscalls/linux/ptrace.cc +++ b/test/syscalls/linux/ptrace.cc @@ -32,6 +32,7 @@ #include "absl/time/time.h" #include "test/util/logging.h" #include "test/util/multiprocess_util.h" +#include "test/util/platform_util.h" #include "test/util/signal_util.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" @@ -178,7 +179,8 @@ TEST(PtraceTest, GetSigMask) { // Install a signal handler for kBlockSignal to avoid termination and block // it. - TEST_PCHECK(signal(kBlockSignal, +[](int signo) {}) != SIG_ERR); + TEST_PCHECK(signal( + kBlockSignal, +[](int signo) {}) != SIG_ERR); MaybeSave(); TEST_PCHECK(sigprocmask(SIG_SETMASK, &blocked, nullptr) == 0); MaybeSave(); @@ -398,9 +400,11 @@ TEST(PtraceTest, GetRegSet) { // Read exactly the full register set. EXPECT_EQ(iov.iov_len, sizeof(regs)); -#ifdef __x86_64__ +#if defined(__x86_64__) // Child called kill(2), with SIGSTOP as arg 2. EXPECT_EQ(regs.rsi, SIGSTOP); +#elif defined(__aarch64__) + EXPECT_EQ(regs.regs[1], SIGSTOP); #endif // Suppress SIGSTOP and resume the child. @@ -750,15 +754,23 @@ TEST(PtraceTest, SyscallSucceeds()); EXPECT_TRUE(siginfo.si_code == SIGTRAP || siginfo.si_code == (SIGTRAP | 0x80)) << "si_code = " << siginfo.si_code; -#ifdef __x86_64__ + { struct user_regs_struct regs = {}; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child_pid, 0, ®s), SyscallSucceeds()); + struct iovec iov; + iov.iov_base = ®s; + iov.iov_len = sizeof(regs); + EXPECT_THAT(ptrace(PTRACE_GETREGSET, child_pid, NT_PRSTATUS, &iov), + SyscallSucceeds()); +#if defined(__x86_64__) EXPECT_TRUE(regs.orig_rax == SYS_vfork || regs.orig_rax == SYS_clone) << "orig_rax = " << regs.orig_rax; EXPECT_EQ(grandchild_pid, regs.rax); - } +#elif defined(__aarch64__) + EXPECT_TRUE(regs.regs[8] == SYS_clone) << "regs[8] = " << regs.regs[8]; + EXPECT_EQ(grandchild_pid, regs.regs[0]); #endif // defined(__x86_64__) + } // After this point, the child will be making wait4 syscalls that will be // interrupted by saving, so saving is not permitted. Note that this is @@ -803,14 +815,21 @@ TEST(PtraceTest, SyscallSucceedsWithValue(child_pid)); EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == (SIGTRAP | 0x80)) << " status " << status; -#ifdef __x86_64__ { struct user_regs_struct regs = {}; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child_pid, 0, ®s), SyscallSucceeds()); + struct iovec iov; + iov.iov_base = ®s; + iov.iov_len = sizeof(regs); + EXPECT_THAT(ptrace(PTRACE_GETREGSET, child_pid, NT_PRSTATUS, &iov), + SyscallSucceeds()); +#if defined(__x86_64__) EXPECT_EQ(SYS_wait4, regs.orig_rax); EXPECT_EQ(grandchild_pid, regs.rax); - } +#elif defined(__aarch64__) + EXPECT_EQ(SYS_wait4, regs.regs[8]); + EXPECT_EQ(grandchild_pid, regs.regs[0]); #endif // defined(__x86_64__) + } // Detach from the child and wait for it to exit. ASSERT_THAT(ptrace(PTRACE_DETACH, child_pid, 0, 0), SyscallSucceeds()); @@ -823,13 +842,8 @@ TEST(PtraceTest, // These tests requires knowledge of architecture-specific syscall convention. #ifdef __x86_64__ TEST(PtraceTest, Int3) { - switch (GvisorPlatform()) { - case Platform::kKVM: - // TODO(b/124248694): int3 isn't handled properly. - return; - default: - break; - } + SKIP_IF(PlatformSupportInt3() == PlatformSupport::NotSupported); + pid_t const child_pid = fork(); if (child_pid == 0) { // In child process. @@ -1191,7 +1205,7 @@ TEST(PtraceTest, SeizeSetOptions) { // gVisor is not susceptible to this race because // kernel.Task.waitCollectTraceeStopLocked() checks specifically for an // active ptraceStop, which is not initiated if SIGKILL is pending. - std::cout << "Observed syscall-exit after SIGKILL"; + std::cout << "Observed syscall-exit after SIGKILL" << std::endl; ASSERT_THAT(waitpid(child_pid, &status, 0), SyscallSucceedsWithValue(child_pid)); } @@ -1211,5 +1225,5 @@ int main(int argc, char** argv) { gvisor::testing::RunExecveChild(); } - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/pty.cc b/test/syscalls/linux/pty.cc index 99a0df235..f9392b9e0 100644 --- a/test/syscalls/linux/pty.cc +++ b/test/syscalls/linux/pty.cc @@ -70,6 +70,8 @@ constexpr absl::Duration kTimeout = absl::Seconds(20); // The maximum line size in bytes returned per read from a pty file. constexpr int kMaxLineSize = 4096; +constexpr char kMasterPath[] = "/dev/ptmx"; + // glibc defines its own, different, version of struct termios. We care about // what the kernel does, not glibc. #define KERNEL_NCCS 19 @@ -362,6 +364,12 @@ PosixErrorOr<size_t> PollAndReadFd(int fd, void* buf, size_t count, ssize_t n = ReadFd(fd, static_cast<char*>(buf) + completed, count - completed); if (n < 0) { + if (errno == EAGAIN) { + // Linux sometimes returns EAGAIN from this read, despite the fact that + // poll returned success. Let's just do what do as we are told and try + // again. + continue; + } return PosixError(errno, "read failed"); } completed += n; @@ -376,9 +384,25 @@ PosixErrorOr<size_t> PollAndReadFd(int fd, void* buf, size_t count, return PosixError(ETIMEDOUT, "Poll timed out"); } +TEST(PtyTrunc, Truncate) { + // Opening PTYs with O_TRUNC shouldn't cause an error, but calls to + // (f)truncate should. + FileDescriptor master = + ASSERT_NO_ERRNO_AND_VALUE(Open(kMasterPath, O_RDWR | O_TRUNC)); + int n = ASSERT_NO_ERRNO_AND_VALUE(SlaveID(master)); + std::string spath = absl::StrCat("/dev/pts/", n); + FileDescriptor slave = + ASSERT_NO_ERRNO_AND_VALUE(Open(spath, O_RDWR | O_NONBLOCK | O_TRUNC)); + + EXPECT_THAT(truncate(kMasterPath, 0), SyscallFailsWithErrno(EINVAL)); + EXPECT_THAT(truncate(spath.c_str(), 0), SyscallFailsWithErrno(EINVAL)); + EXPECT_THAT(ftruncate(master.get(), 0), SyscallFailsWithErrno(EINVAL)); + EXPECT_THAT(ftruncate(slave.get(), 0), SyscallFailsWithErrno(EINVAL)); +} + TEST(BasicPtyTest, StatUnopenedMaster) { struct stat s; - ASSERT_THAT(stat("/dev/ptmx", &s), SyscallSucceeds()); + ASSERT_THAT(stat(kMasterPath, &s), SyscallSucceeds()); EXPECT_EQ(s.st_rdev, makedev(TTYAUX_MAJOR, kPtmxMinor)); EXPECT_EQ(s.st_size, 0); @@ -610,6 +634,11 @@ TEST_F(PtyTest, TermiosAffectsSlave) { // Verify this by setting ICRNL (which rewrites input \r to \n) and verify that // it has no effect on the master. TEST_F(PtyTest, MasterTermiosUnchangable) { + struct kernel_termios master_termios = {}; + EXPECT_THAT(ioctl(master_.get(), TCGETS, &master_termios), SyscallSucceeds()); + master_termios.c_lflag |= ICRNL; + EXPECT_THAT(ioctl(master_.get(), TCSETS, &master_termios), SyscallSucceeds()); + char c = '\r'; ASSERT_THAT(WriteFd(slave_.get(), &c, 1), SyscallSucceedsWithValue(1)); @@ -1108,7 +1137,7 @@ TEST_F(PtyTest, SwitchTwiceMultiline) { std::string kExpected = "GO\nBLUE\n!"; // Write each line. - for (std::string input : kInputs) { + for (const std::string& input : kInputs) { ASSERT_THAT(WriteFd(master_.get(), input.c_str(), input.size()), SyscallSucceedsWithValue(input.size())); } diff --git a/test/syscalls/linux/pty_root.cc b/test/syscalls/linux/pty_root.cc index 14a4af980..1d7dbefdb 100644 --- a/test/syscalls/linux/pty_root.cc +++ b/test/syscalls/linux/pty_root.cc @@ -25,16 +25,26 @@ namespace gvisor { namespace testing { -// These tests should be run as root. namespace { +// StealTTY tests whether privileged processes can steal controlling terminals. +// If the stealing process has CAP_SYS_ADMIN in the root user namespace, the +// test ensures that stealing works. If it has non-root CAP_SYS_ADMIN, it +// ensures stealing fails. TEST(JobControlRootTest, StealTTY) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); - // Make this a session leader, which also drops the controlling terminal. - // In the gVisor test environment, this test will be run as the session - // leader already (as the sentry init process). + bool true_root = true; if (!IsRunningOnGvisor()) { + // If running in Linux, we may only have CAP_SYS_ADMIN in a non-root user + // namespace (i.e. we are not truly root). We use init_module as a proxy for + // whether we are true root, as it returns EPERM immediately. + ASSERT_THAT(syscall(SYS_init_module, nullptr, 0, nullptr), SyscallFails()); + true_root = errno != EPERM; + + // Make this a session leader, which also drops the controlling terminal. + // In the gVisor test environment, this test will be run as the session + // leader already (as the sentry init process). ASSERT_THAT(setsid(), SyscallSucceeds()); } @@ -53,8 +63,8 @@ TEST(JobControlRootTest, StealTTY) { ASSERT_THAT(setsid(), SyscallSucceeds()); // We shouldn't be able to steal the terminal with the wrong arg value. TEST_PCHECK(ioctl(slave.get(), TIOCSCTTY, 0)); - // We should be able to steal it here. - TEST_PCHECK(!ioctl(slave.get(), TIOCSCTTY, 1)); + // We should be able to steal it if we are true root. + TEST_PCHECK(true_root == !ioctl(slave.get(), TIOCSCTTY, 1)); _exit(0); } diff --git a/test/syscalls/linux/pwrite64.cc b/test/syscalls/linux/pwrite64.cc index b48fe540d..e69794910 100644 --- a/test/syscalls/linux/pwrite64.cc +++ b/test/syscalls/linux/pwrite64.cc @@ -14,6 +14,7 @@ #include <errno.h> #include <fcntl.h> +#include <linux/unistd.h> #include <sys/socket.h> #include <sys/types.h> #include <unistd.h> @@ -27,14 +28,7 @@ namespace testing { namespace { -// This test is currently very rudimentary. -// -// TODO(edahlgren): -// * bad buffer states (EFAULT). -// * bad fds (wrong permission, wrong type of file, EBADF). -// * check offset is not incremented. -// * check for EOF. -// * writing to pipes, symlinks, special files. +// TODO(gvisor.dev/issue/2370): This test is currently very rudimentary. class Pwrite64 : public ::testing::Test { void SetUp() override { name_ = NewTempAbsPath(); @@ -72,6 +66,17 @@ TEST_F(Pwrite64, InvalidArgs) { EXPECT_THAT(close(fd), SyscallSucceeds()); } +TEST_F(Pwrite64, Overflow) { + int fd; + ASSERT_THAT(fd = open(name_.c_str(), O_APPEND | O_RDWR), SyscallSucceeds()); + constexpr int64_t kBufSize = 1024; + std::vector<char> buf(kBufSize); + std::fill(buf.begin(), buf.end(), 'a'); + EXPECT_THAT(PwriteFd(fd, buf.data(), buf.size(), 0x7fffffffffffffffull), + SyscallFailsWithErrno(EINVAL)); + EXPECT_THAT(close(fd), SyscallSucceeds()); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/pwritev2.cc b/test/syscalls/linux/pwritev2.cc index 1dbc0d6df..63b686c62 100644 --- a/test/syscalls/linux/pwritev2.cc +++ b/test/syscalls/linux/pwritev2.cc @@ -34,6 +34,8 @@ namespace { #ifndef SYS_pwritev2 #if defined(__x86_64__) #define SYS_pwritev2 328 +#elif defined(__aarch64__) +#define SYS_pwritev2 287 #else #error "Unknown architecture" #endif @@ -67,7 +69,7 @@ ssize_t pwritev2(unsigned long fd, const struct iovec* iov, } // This test is the base case where we call pwritev (no offset, no flags). -TEST(Writev2Test, TestBaseCall) { +TEST(Writev2Test, BaseCall) { SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( @@ -95,7 +97,7 @@ TEST(Writev2Test, TestBaseCall) { } // This test is where we call pwritev2 with a positive offset and no flags. -TEST(Pwritev2Test, TestValidPositiveOffset) { +TEST(Pwritev2Test, ValidPositiveOffset) { SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); std::string prefix(kBufSize, '0'); @@ -127,7 +129,7 @@ TEST(Pwritev2Test, TestValidPositiveOffset) { // This test is the base case where we call writev by using -1 as the offset. // The write should use the file offset, so the test increments the file offset // prior to call pwritev2. -TEST(Pwritev2Test, TestNegativeOneOffset) { +TEST(Pwritev2Test, NegativeOneOffset) { SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); const std::string prefix = "00"; @@ -162,7 +164,7 @@ TEST(Pwritev2Test, TestNegativeOneOffset) { // pwritev2 requires if the RWF_HIPRI flag is passed, the fd must be opened with // O_DIRECT. This test implements a correct call with the RWF_HIPRI flag. -TEST(Pwritev2Test, TestCallWithRWF_HIPRI) { +TEST(Pwritev2Test, CallWithRWF_HIPRI) { SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( @@ -187,47 +189,8 @@ TEST(Pwritev2Test, TestCallWithRWF_HIPRI) { EXPECT_EQ(buf, content); } -// This test checks that pwritev2 can be called with valid flags -TEST(Pwritev2Test, TestCallWithValidFlags) { - SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); - - const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), "", TempPath::kDefaultFileMode)); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR)); - - std::vector<char> content(kBufSize, '0'); - struct iovec iov; - iov.iov_base = content.data(); - iov.iov_len = content.size(); - - EXPECT_THAT(pwritev2(fd.get(), &iov, /*iovcnt=*/1, - /*offset=*/0, /*flags=*/RWF_DSYNC), - SyscallSucceedsWithValue(kBufSize)); - - std::vector<char> buf(content.size()); - EXPECT_THAT(read(fd.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(buf.size())); - - EXPECT_EQ(buf, content); - - SetContent(content); - - EXPECT_THAT(pwritev2(fd.get(), &iov, /*iovcnt=*/1, - /*offset=*/0, /*flags=*/0x4), - SyscallSucceedsWithValue(kBufSize)); - - ASSERT_THAT(lseek(fd.get(), 0, SEEK_CUR), - SyscallSucceedsWithValue(content.size())); - - EXPECT_THAT(pread(fd.get(), buf.data(), buf.size(), /*offset=*/0), - SyscallSucceedsWithValue(buf.size())); - - EXPECT_EQ(buf, content); -} - // This test calls pwritev2 with a bad file descriptor. -TEST(Writev2Test, TestBadFile) { +TEST(Writev2Test, BadFile) { SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); ASSERT_THAT(pwritev2(/*fd=*/-1, /*iov=*/nullptr, /*iovcnt=*/0, /*offset=*/0, /*flags=*/0), @@ -235,7 +198,7 @@ TEST(Writev2Test, TestBadFile) { } // This test calls pwrite2 with an invalid offset. -TEST(Pwritev2Test, TestInvalidOffset) { +TEST(Pwritev2Test, InvalidOffset) { SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( @@ -253,7 +216,7 @@ TEST(Pwritev2Test, TestInvalidOffset) { SyscallFailsWithErrno(EINVAL)); } -TEST(Pwritev2Test, TestUnseekableFileValid) { +TEST(Pwritev2Test, UnseekableFileValid) { SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); int pipe_fds[2]; @@ -283,7 +246,7 @@ TEST(Pwritev2Test, TestUnseekableFileValid) { // Calling pwritev2 with a non-negative offset calls pwritev. Calling pwritev // with an unseekable file is not allowed. A pipe is used for an unseekable // file. -TEST(Pwritev2Test, TestUnseekableFileInValid) { +TEST(Pwritev2Test, UnseekableFileInvalid) { SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); int pipe_fds[2]; @@ -302,7 +265,7 @@ TEST(Pwritev2Test, TestUnseekableFileInValid) { EXPECT_THAT(close(pipe_fds[1]), SyscallSucceeds()); } -TEST(Pwritev2Test, TestReadOnlyFile) { +TEST(Pwritev2Test, ReadOnlyFile) { SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( @@ -321,7 +284,7 @@ TEST(Pwritev2Test, TestReadOnlyFile) { } // This test calls pwritev2 with an invalid flag. -TEST(Pwritev2Test, TestInvalidFlag) { +TEST(Pwritev2Test, InvalidFlag) { SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( diff --git a/test/syscalls/linux/raw_socket.cc b/test/syscalls/linux/raw_socket.cc new file mode 100644 index 000000000..8d6e5c913 --- /dev/null +++ b/test/syscalls/linux/raw_socket.cc @@ -0,0 +1,869 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <linux/capability.h> +#ifndef __fuchsia__ +#include <linux/filter.h> +#endif // __fuchsia__ +#include <netinet/in.h> +#include <netinet/ip.h> +#include <netinet/ip6.h> +#include <netinet/ip_icmp.h> +#include <poll.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <unistd.h> + +#include <algorithm> + +#include "gtest/gtest.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/syscalls/linux/unix_domain_socket_test_util.h" +#include "test/util/capability_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/test_util.h" + +// Note: in order to run these tests, /proc/sys/net/ipv4/ping_group_range will +// need to be configured to let the superuser create ping sockets (see icmp(7)). + +namespace gvisor { +namespace testing { + +namespace { + +// Fixture for tests parameterized by protocol. +class RawSocketTest : public ::testing::TestWithParam<std::tuple<int, int>> { + protected: + // Creates a socket to be used in tests. + void SetUp() override; + + // Closes the socket created by SetUp(). + void TearDown() override; + + // Sends buf via s_. + void SendBuf(const char* buf, int buf_len); + + // Reads from s_ into recv_buf. + void ReceiveBuf(char* recv_buf, size_t recv_buf_len); + + void ReceiveBufFrom(int sock, char* recv_buf, size_t recv_buf_len); + + int Protocol() { return std::get<0>(GetParam()); } + + int Family() { return std::get<1>(GetParam()); } + + socklen_t AddrLen() { + if (Family() == AF_INET) { + return sizeof(sockaddr_in); + } + return sizeof(sockaddr_in6); + } + + int HdrLen() { + if (Family() == AF_INET) { + return sizeof(struct iphdr); + } + // IPv6 raw sockets don't include the header. + return 0; + } + + // The socket used for both reading and writing. + int s_; + + // The loopback address. + struct sockaddr_storage addr_; +}; + +void RawSocketTest::SetUp() { + if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { + ASSERT_THAT(socket(Family(), SOCK_RAW, Protocol()), + SyscallFailsWithErrno(EPERM)); + GTEST_SKIP(); + } + + ASSERT_THAT(s_ = socket(Family(), SOCK_RAW, Protocol()), SyscallSucceeds()); + + addr_ = {}; + + // We don't set ports because raw sockets don't have a notion of ports. + if (Family() == AF_INET) { + struct sockaddr_in* sin = reinterpret_cast<struct sockaddr_in*>(&addr_); + sin->sin_family = AF_INET; + sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK); + } else { + struct sockaddr_in6* sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr_); + sin6->sin6_family = AF_INET6; + sin6->sin6_addr = in6addr_loopback; + } +} + +void RawSocketTest::TearDown() { + // TearDown will be run even if we skip the test. + if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { + EXPECT_THAT(close(s_), SyscallSucceeds()); + } +} + +// We should be able to create multiple raw sockets for the same protocol. +// BasicRawSocket::Setup creates the first one, so we only have to create one +// more here. +TEST_P(RawSocketTest, MultipleCreation) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int s2; + ASSERT_THAT(s2 = socket(Family(), SOCK_RAW, Protocol()), SyscallSucceeds()); + + ASSERT_THAT(close(s2), SyscallSucceeds()); +} + +// Test that shutting down an unconnected socket fails. +TEST_P(RawSocketTest, FailShutdownWithoutConnect) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + ASSERT_THAT(shutdown(s_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); + ASSERT_THAT(shutdown(s_, SHUT_RD), SyscallFailsWithErrno(ENOTCONN)); +} + +// Shutdown is a no-op for raw sockets (and datagram sockets in general). +TEST_P(RawSocketTest, ShutdownWriteNoop) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + ASSERT_THAT( + connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), + SyscallSucceeds()); + ASSERT_THAT(shutdown(s_, SHUT_WR), SyscallSucceeds()); + + // Arbitrary. + constexpr char kBuf[] = "noop"; + ASSERT_THAT(RetryEINTR(write)(s_, kBuf, sizeof(kBuf)), + SyscallSucceedsWithValue(sizeof(kBuf))); +} + +// Shutdown is a no-op for raw sockets (and datagram sockets in general). +TEST_P(RawSocketTest, ShutdownReadNoop) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + ASSERT_THAT( + connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), + SyscallSucceeds()); + ASSERT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds()); + + // Arbitrary. + constexpr char kBuf[] = "gdg"; + ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf))); + + std::vector<char> c(sizeof(kBuf) + HdrLen()); + ASSERT_THAT(read(s_, c.data(), c.size()), SyscallSucceedsWithValue(c.size())); +} + +// Test that listen() fails. +TEST_P(RawSocketTest, FailListen) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + ASSERT_THAT(listen(s_, 1), SyscallFailsWithErrno(ENOTSUP)); +} + +// Test that accept() fails. +TEST_P(RawSocketTest, FailAccept) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + struct sockaddr saddr; + socklen_t addrlen; + ASSERT_THAT(accept(s_, &saddr, &addrlen), SyscallFailsWithErrno(ENOTSUP)); +} + +// Test that getpeername() returns nothing before connect(). +TEST_P(RawSocketTest, FailGetPeerNameBeforeConnect) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + struct sockaddr saddr; + socklen_t addrlen = sizeof(saddr); + ASSERT_THAT(getpeername(s_, &saddr, &addrlen), + SyscallFailsWithErrno(ENOTCONN)); +} + +// Test that getpeername() returns something after connect(). +TEST_P(RawSocketTest, GetPeerName) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + ASSERT_THAT( + connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), + SyscallSucceeds()); + struct sockaddr saddr; + socklen_t addrlen = sizeof(saddr); + ASSERT_THAT(getpeername(s_, &saddr, &addrlen), + SyscallFailsWithErrno(ENOTCONN)); + ASSERT_GT(addrlen, 0); +} + +// Test that the socket is writable immediately. +TEST_P(RawSocketTest, PollWritableImmediately) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + struct pollfd pfd = {}; + pfd.fd = s_; + pfd.events = POLLOUT; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 10000), SyscallSucceedsWithValue(1)); +} + +// Test that the socket isn't readable before receiving anything. +TEST_P(RawSocketTest, PollNotReadableInitially) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + // Try to receive data with MSG_DONTWAIT, which returns immediately if there's + // nothing to be read. + char buf[117]; + ASSERT_THAT(RetryEINTR(recv)(s_, buf, sizeof(buf), MSG_DONTWAIT), + SyscallFailsWithErrno(EAGAIN)); +} + +// Test that the socket becomes readable once something is written to it. +TEST_P(RawSocketTest, PollTriggeredOnWrite) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + // Write something so that there's data to be read. + // Arbitrary. + constexpr char kBuf[] = "JP5"; + ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf))); + + struct pollfd pfd = {}; + pfd.fd = s_; + pfd.events = POLLIN; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 10000), SyscallSucceedsWithValue(1)); +} + +// Test that we can connect() to a valid IP (loopback). +TEST_P(RawSocketTest, ConnectToLoopback) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + ASSERT_THAT( + connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), + SyscallSucceeds()); +} + +// Test that calling send() without connect() fails. +TEST_P(RawSocketTest, SendWithoutConnectFails) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + // Arbitrary. + constexpr char kBuf[] = "Endgame was good"; + ASSERT_THAT(send(s_, kBuf, sizeof(kBuf), 0), + SyscallFailsWithErrno(EDESTADDRREQ)); +} + +// Wildcard Bind. +TEST_P(RawSocketTest, BindToWildcard) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + struct sockaddr_storage addr; + addr = {}; + + // We don't set ports because raw sockets don't have a notion of ports. + if (Family() == AF_INET) { + struct sockaddr_in* sin = reinterpret_cast<struct sockaddr_in*>(&addr); + sin->sin_family = AF_INET; + sin->sin_addr.s_addr = htonl(INADDR_ANY); + } else { + struct sockaddr_in6* sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr); + sin6->sin6_family = AF_INET6; + sin6->sin6_addr = in6addr_any; + } + + ASSERT_THAT(bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), + SyscallSucceeds()); +} + +// Bind to localhost. +TEST_P(RawSocketTest, BindToLocalhost) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + ASSERT_THAT( + bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), + SyscallSucceeds()); +} + +// Bind to a different address. +TEST_P(RawSocketTest, BindToInvalid) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + struct sockaddr_storage bind_addr = addr_; + if (Family() == AF_INET) { + struct sockaddr_in* sin = reinterpret_cast<struct sockaddr_in*>(&bind_addr); + sin->sin_addr = {1}; // 1.0.0.0 - An address that we can't bind to. + } else { + struct sockaddr_in6* sin6 = + reinterpret_cast<struct sockaddr_in6*>(&bind_addr); + memset(&sin6->sin6_addr.s6_addr, 0, sizeof(sin6->sin6_addr.s6_addr)); + sin6->sin6_addr.s6_addr[0] = 1; // 1: - An address that we can't bind to. + } + ASSERT_THAT(bind(s_, reinterpret_cast<struct sockaddr*>(&bind_addr), + AddrLen()), SyscallFailsWithErrno(EADDRNOTAVAIL)); +} + +// Send and receive an packet. +TEST_P(RawSocketTest, SendAndReceive) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + // Arbitrary. + constexpr char kBuf[] = "TB12"; + ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf))); + + // Receive the packet and make sure it's identical. + std::vector<char> recv_buf(sizeof(kBuf) + HdrLen()); + ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size())); + EXPECT_EQ(memcmp(recv_buf.data() + HdrLen(), kBuf, sizeof(kBuf)), 0); +} + +// We should be able to create multiple raw sockets for the same protocol and +// receive the same packet on both. +TEST_P(RawSocketTest, MultipleSocketReceive) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int s2; + ASSERT_THAT(s2 = socket(Family(), SOCK_RAW, Protocol()), SyscallSucceeds()); + + // Arbitrary. + constexpr char kBuf[] = "TB10"; + ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf))); + + // Receive it on socket 1. + std::vector<char> recv_buf1(sizeof(kBuf) + HdrLen()); + ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf1.data(), recv_buf1.size())); + + // Receive it on socket 2. + std::vector<char> recv_buf2(sizeof(kBuf) + HdrLen()); + ASSERT_NO_FATAL_FAILURE(ReceiveBufFrom(s2, recv_buf2.data(), + recv_buf2.size())); + + EXPECT_EQ(memcmp(recv_buf1.data() + HdrLen(), + recv_buf2.data() + HdrLen(), sizeof(kBuf)), + 0); + + ASSERT_THAT(close(s2), SyscallSucceeds()); +} + +// Test that connect sends packets to the right place. +TEST_P(RawSocketTest, SendAndReceiveViaConnect) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + ASSERT_THAT( + connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), + SyscallSucceeds()); + + // Arbitrary. + constexpr char kBuf[] = "JH4"; + ASSERT_THAT(send(s_, kBuf, sizeof(kBuf), 0), + SyscallSucceedsWithValue(sizeof(kBuf))); + + // Receive the packet and make sure it's identical. + std::vector<char> recv_buf(sizeof(kBuf) + HdrLen()); + ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size())); + EXPECT_EQ(memcmp(recv_buf.data() + HdrLen(), kBuf, sizeof(kBuf)), 0); +} + +// Bind to localhost, then send and receive packets. +TEST_P(RawSocketTest, BindSendAndReceive) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + ASSERT_THAT( + bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), + SyscallSucceeds()); + + // Arbitrary. + constexpr char kBuf[] = "DR16"; + ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf))); + + // Receive the packet and make sure it's identical. + std::vector<char> recv_buf(sizeof(kBuf) + HdrLen()); + ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size())); + EXPECT_EQ(memcmp(recv_buf.data() + HdrLen(), kBuf, sizeof(kBuf)), 0); +} + +// Bind and connect to localhost and send/receive packets. +TEST_P(RawSocketTest, BindConnectSendAndReceive) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + ASSERT_THAT( + bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), + SyscallSucceeds()); + ASSERT_THAT( + connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), + SyscallSucceeds()); + + // Arbitrary. + constexpr char kBuf[] = "DG88"; + ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf))); + + // Receive the packet and make sure it's identical. + std::vector<char> recv_buf(sizeof(kBuf) + HdrLen()); + ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size())); + EXPECT_EQ(memcmp(recv_buf.data() + HdrLen(), kBuf, sizeof(kBuf)), 0); +} + +// Check that setting SO_RCVBUF below min is clamped to the minimum +// receive buffer size. +TEST_P(RawSocketTest, SetSocketRecvBufBelowMin) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + // Discover minimum receive buf size by trying to set it to zero. + // See: + // https://github.com/torvalds/linux/blob/a5dc8300df75e8b8384b4c82225f1e4a0b4d9b55/net/core/sock.c#L820 + constexpr int kRcvBufSz = 0; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)), + SyscallSucceeds()); + + int min = 0; + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &min, &min_len), + SyscallSucceeds()); + + // Linux doubles the value so let's use a value that when doubled will still + // be smaller than min. + int below_min = min / 2 - 1; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &below_min, sizeof(below_min)), + SyscallSucceeds()); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &val, &val_len), + SyscallSucceeds()); + + ASSERT_EQ(min, val); +} + +// Check that setting SO_RCVBUF above max is clamped to the maximum +// receive buffer size. +TEST_P(RawSocketTest, SetSocketRecvBufAboveMax) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + // Discover max buf size by trying to set the largest possible buffer size. + constexpr int kRcvBufSz = 0xffffffff; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)), + SyscallSucceeds()); + + int max = 0; + socklen_t max_len = sizeof(max); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &max, &max_len), + SyscallSucceeds()); + + int above_max = max + 1; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &above_max, sizeof(above_max)), + SyscallSucceeds()); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &val, &val_len), + SyscallSucceeds()); + ASSERT_EQ(max, val); +} + +// Check that setting SO_RCVBUF min <= kRcvBufSz <= max is honored. +TEST_P(RawSocketTest, SetSocketRecvBuf) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int max = 0; + int min = 0; + { + // Discover max buf size by trying to set a really large buffer size. + constexpr int kRcvBufSz = 0xffffffff; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)), + SyscallSucceeds()); + + max = 0; + socklen_t max_len = sizeof(max); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &max, &max_len), + SyscallSucceeds()); + } + + { + // Discover minimum buffer size by trying to set a zero size receive buffer + // size. + // See: + // https://github.com/torvalds/linux/blob/a5dc8300df75e8b8384b4c82225f1e4a0b4d9b55/net/core/sock.c#L820 + constexpr int kRcvBufSz = 0; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)), + SyscallSucceeds()); + + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &min, &min_len), + SyscallSucceeds()); + } + + int quarter_sz = min + (max - min) / 4; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &quarter_sz, sizeof(quarter_sz)), + SyscallSucceeds()); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &val, &val_len), + SyscallSucceeds()); + + // Linux doubles the value set by SO_SNDBUF/SO_RCVBUF. + // TODO(gvisor.dev/issue/2926): Remove when Netstack matches linux behavior. + if (!IsRunningOnGvisor()) { + quarter_sz *= 2; + } + ASSERT_EQ(quarter_sz, val); +} + +// Check that setting SO_SNDBUF below min is clamped to the minimum +// receive buffer size. +TEST_P(RawSocketTest, SetSocketSendBufBelowMin) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + // Discover minimum buffer size by trying to set it to zero. + constexpr int kSndBufSz = 0; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &kSndBufSz, sizeof(kSndBufSz)), + SyscallSucceeds()); + + int min = 0; + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &min, &min_len), + SyscallSucceeds()); + + // Linux doubles the value so let's use a value that when doubled will still + // be smaller than min. + int below_min = min / 2 - 1; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &below_min, sizeof(below_min)), + SyscallSucceeds()); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &val, &val_len), + SyscallSucceeds()); + + ASSERT_EQ(min, val); +} + +// Check that setting SO_SNDBUF above max is clamped to the maximum +// send buffer size. +TEST_P(RawSocketTest, SetSocketSendBufAboveMax) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + // Discover maximum buffer size by trying to set it to a large value. + constexpr int kSndBufSz = 0xffffffff; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &kSndBufSz, sizeof(kSndBufSz)), + SyscallSucceeds()); + + int max = 0; + socklen_t max_len = sizeof(max); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &max, &max_len), + SyscallSucceeds()); + + int above_max = max + 1; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &above_max, sizeof(above_max)), + SyscallSucceeds()); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &val, &val_len), + SyscallSucceeds()); + ASSERT_EQ(max, val); +} + +// Check that setting SO_SNDBUF min <= kSndBufSz <= max is honored. +TEST_P(RawSocketTest, SetSocketSendBuf) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int max = 0; + int min = 0; + { + // Discover maximum buffer size by trying to set it to a large value. + constexpr int kSndBufSz = 0xffffffff; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &kSndBufSz, sizeof(kSndBufSz)), + SyscallSucceeds()); + + max = 0; + socklen_t max_len = sizeof(max); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &max, &max_len), + SyscallSucceeds()); + } + + { + // Discover minimum buffer size by trying to set it to zero. + constexpr int kSndBufSz = 0; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &kSndBufSz, sizeof(kSndBufSz)), + SyscallSucceeds()); + + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &min, &min_len), + SyscallSucceeds()); + } + + int quarter_sz = min + (max - min) / 4; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &quarter_sz, sizeof(quarter_sz)), + SyscallSucceeds()); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &val, &val_len), + SyscallSucceeds()); + + // Linux doubles the value set by SO_SNDBUF/SO_RCVBUF. + // TODO(gvisor.dev/issue/2926): Remove the gvisor special casing when Netstack + // matches linux behavior. + if (!IsRunningOnGvisor()) { + quarter_sz *= 2; + } + + ASSERT_EQ(quarter_sz, val); +} + +// Test that receive buffer limits are not enforced when the recv buffer is +// empty. +TEST_P(RawSocketTest, RecvBufLimitsEmptyRecvBuffer) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + ASSERT_THAT( + bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), + SyscallSucceeds()); + ASSERT_THAT( + connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), + SyscallSucceeds()); + + int min = 0; + { + // Discover minimum buffer size by trying to set it to zero. + constexpr int kRcvBufSz = 0; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)), + SyscallSucceeds()); + + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &min, &min_len), + SyscallSucceeds()); + } + + { + // Send data of size min and verify that it's received. + std::vector<char> buf(min); + RandomizeBuffer(buf.data(), buf.size()); + ASSERT_NO_FATAL_FAILURE(SendBuf(buf.data(), buf.size())); + + // Receive the packet and make sure it's identical. + std::vector<char> recv_buf(buf.size() + HdrLen()); + ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size())); + EXPECT_EQ( + memcmp(recv_buf.data() + HdrLen(), buf.data(), buf.size()), + 0); + } + + { + // Send data of size min + 1 and verify that its received. Both linux and + // Netstack accept a dgram that exceeds rcvBuf limits if the receive buffer + // is currently empty. + std::vector<char> buf(min + 1); + RandomizeBuffer(buf.data(), buf.size()); + ASSERT_NO_FATAL_FAILURE(SendBuf(buf.data(), buf.size())); + // Receive the packet and make sure it's identical. + std::vector<char> recv_buf(buf.size() + HdrLen()); + ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size())); + EXPECT_EQ( + memcmp(recv_buf.data() + HdrLen(), buf.data(), buf.size()), + 0); + } +} + +TEST_P(RawSocketTest, RecvBufLimits) { + // TCP stack generates RSTs for unknown endpoints and it complicates the test + // as we have to deal with the RST packets as well. For testing the raw socket + // endpoints buffer limit enforcement we can just test for UDP. + // + // We don't use SKIP_IF here because root_test_runner explicitly fails if a + // test is skipped. + if (Protocol() == IPPROTO_TCP) { + return; + } + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + ASSERT_THAT( + bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), + SyscallSucceeds()); + ASSERT_THAT( + connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), + SyscallSucceeds()); + + int min = 0; + { + // Discover minimum buffer size by trying to set it to zero. + constexpr int kRcvBufSz = 0; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)), + SyscallSucceeds()); + + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &min, &min_len), + SyscallSucceeds()); + } + + // Now set the limit to min * 2. + int new_rcv_buf_sz = min * 4; + if (!IsRunningOnGvisor()) { + // Linux doubles the value specified so just set to min. + new_rcv_buf_sz = min * 2; + } + + ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &new_rcv_buf_sz, + sizeof(new_rcv_buf_sz)), + SyscallSucceeds()); + int rcv_buf_sz = 0; + { + socklen_t rcv_buf_len = sizeof(rcv_buf_sz); + ASSERT_THAT( + getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &rcv_buf_sz, &rcv_buf_len), + SyscallSucceeds()); + } + + // Set a receive timeout so that we don't block forever on reads if the test + // fails. + struct timeval tv { + .tv_sec = 1, .tv_usec = 0, + }; + ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)), + SyscallSucceeds()); + + { + std::vector<char> buf(min); + RandomizeBuffer(buf.data(), buf.size()); + + ASSERT_NO_FATAL_FAILURE(SendBuf(buf.data(), buf.size())); + ASSERT_NO_FATAL_FAILURE(SendBuf(buf.data(), buf.size())); + ASSERT_NO_FATAL_FAILURE(SendBuf(buf.data(), buf.size())); + ASSERT_NO_FATAL_FAILURE(SendBuf(buf.data(), buf.size())); + int sent = 4; + if (IsRunningOnGvisor()) { + // Linux seems to drop the 4th packet even though technically it should + // fit in the receive buffer. + ASSERT_NO_FATAL_FAILURE(SendBuf(buf.data(), buf.size())); + sent++; + } + + // Verify that the expected number of packets are available to be read. + for (int i = 0; i < sent - 1; i++) { + // Receive the packet and make sure it's identical. + std::vector<char> recv_buf(buf.size() + HdrLen()); + ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size())); + EXPECT_EQ(memcmp(recv_buf.data() + HdrLen(), buf.data(), + buf.size()), + 0); + } + + // Assert that the last packet is dropped because the receive buffer should + // be full after the first four packets. + std::vector<char> recv_buf(buf.size() + HdrLen()); + struct iovec iov = {}; + iov.iov_base = static_cast<void*>(const_cast<char*>(recv_buf.data())); + iov.iov_len = buf.size(); + struct msghdr msg = {}; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + msg.msg_control = NULL; + msg.msg_controllen = 0; + msg.msg_flags = 0; + ASSERT_THAT(RetryEINTR(recvmsg)(s_, &msg, MSG_DONTWAIT), + SyscallFailsWithErrno(EAGAIN)); + } +} + +void RawSocketTest::SendBuf(const char* buf, int buf_len) { + // It's safe to use const_cast here because sendmsg won't modify the iovec or + // address. + struct iovec iov = {}; + iov.iov_base = static_cast<void*>(const_cast<char*>(buf)); + iov.iov_len = static_cast<size_t>(buf_len); + struct msghdr msg = {}; + msg.msg_name = static_cast<void*>(&addr_); + msg.msg_namelen = AddrLen(); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + msg.msg_control = NULL; + msg.msg_controllen = 0; + msg.msg_flags = 0; + ASSERT_THAT(sendmsg(s_, &msg, 0), SyscallSucceedsWithValue(buf_len)); +} + +void RawSocketTest::ReceiveBuf(char* recv_buf, size_t recv_buf_len) { + ASSERT_NO_FATAL_FAILURE(ReceiveBufFrom(s_, recv_buf, recv_buf_len)); +} + +void RawSocketTest::ReceiveBufFrom(int sock, char* recv_buf, + size_t recv_buf_len) { + ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(sock, recv_buf, recv_buf_len)); +} + +#ifndef __fuchsia__ + +TEST_P(RawSocketTest, SetSocketDetachFilterNoInstalledFilter) { + // TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. + if (IsRunningOnGvisor()) { + constexpr int val = 0; + ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), + SyscallSucceeds()); + return; + } + + constexpr int val = 0; + ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), + SyscallFailsWithErrno(ENOENT)); +} + +TEST_P(RawSocketTest, GetSocketDetachFilter) { + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_DETACH_FILTER, &val, &val_len), + SyscallFailsWithErrno(ENOPROTOOPT)); +} + +#endif // __fuchsia__ + +// AF_INET6+SOCK_RAW+IPPROTO_RAW sockets can be created, but not written to. +TEST(RawSocketTest, IPv6ProtoRaw) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int sock; + ASSERT_THAT(sock = socket(AF_INET6, SOCK_RAW, IPPROTO_RAW), + SyscallSucceeds()); + + // Verify that writing yields EINVAL. + char buf[] = "This is such a weird little edge case"; + struct sockaddr_in6 sin6 = {}; + sin6.sin6_family = AF_INET6; + sin6.sin6_addr = in6addr_loopback; + ASSERT_THAT(sendto(sock, buf, sizeof(buf), 0 /* flags */, + reinterpret_cast<struct sockaddr*>(&sin6), sizeof(sin6)), + SyscallFailsWithErrno(EINVAL)); +} + +INSTANTIATE_TEST_SUITE_P( + AllInetTests, RawSocketTest, + ::testing::Combine(::testing::Values(IPPROTO_TCP, IPPROTO_UDP), + ::testing::Values(AF_INET, AF_INET6))); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/raw_socket_hdrincl.cc b/test/syscalls/linux/raw_socket_hdrincl.cc index 0a27506aa..2f25aceb2 100644 --- a/test/syscalls/linux/raw_socket_hdrincl.cc +++ b/test/syscalls/linux/raw_socket_hdrincl.cc @@ -167,7 +167,7 @@ TEST_F(RawHDRINCL, NotReadable) { // nothing to be read. char buf[117]; ASSERT_THAT(RetryEINTR(recv)(socket_, buf, sizeof(buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EINVAL)); + SyscallFailsWithErrno(EAGAIN)); } // Test that we can connect() to a valid IP (loopback). @@ -178,6 +178,9 @@ TEST_F(RawHDRINCL, ConnectToLoopback) { } TEST_F(RawHDRINCL, SendWithoutConnectSucceeds) { + // FIXME(gvisor.dev/issue/3159): Test currently flaky. + SKIP_IF(true); + struct iphdr hdr = LoopbackHeader(); ASSERT_THAT(send(socket_, &hdr, sizeof(hdr), 0), SyscallSucceedsWithValue(sizeof(hdr))); @@ -273,14 +276,17 @@ TEST_F(RawHDRINCL, SendAndReceive) { // The network stack should have set the source address. EXPECT_EQ(src.sin_family, AF_INET); EXPECT_EQ(absl::gbswap_32(src.sin_addr.s_addr), INADDR_LOOPBACK); - // The packet ID should be 0, as the packet is less than 68 bytes. - struct iphdr iphdr = {}; - memcpy(&iphdr, recv_buf, sizeof(iphdr)); - EXPECT_EQ(iphdr.id, 0); + // The packet ID should not be 0, as the packet has DF=0. + struct iphdr* iphdr = reinterpret_cast<struct iphdr*>(recv_buf); + EXPECT_NE(iphdr->id, 0); } -// Send and receive a packet with nonzero IP ID. -TEST_F(RawHDRINCL, SendAndReceiveNonzeroID) { +// Send and receive a packet where the sendto address is not the same as the +// provided destination. +TEST_F(RawHDRINCL, SendAndReceiveDifferentAddress) { + // FIXME(gvisor.dev/issue/3160): Test currently flaky. + SKIP_IF(true); + int port = 40000; if (!IsRunningOnGvisor()) { port = static_cast<short>(ASSERT_NO_ERRNO_AND_VALUE( @@ -292,19 +298,24 @@ TEST_F(RawHDRINCL, SendAndReceiveNonzeroID) { FileDescriptor udp_sock = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, IPPROTO_UDP)); - // Construct a packet with an IP header, UDP header, and payload. Make the - // payload large enough to force an IP ID to be assigned. - constexpr char kPayload[128] = {}; + // Construct a packet with an IP header, UDP header, and payload. + constexpr char kPayload[] = "toto"; char packet[sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kPayload)]; ASSERT_TRUE( FillPacket(packet, sizeof(packet), port, kPayload, sizeof(kPayload))); + // Overwrite the IP destination address with an IP we can't get to. + struct iphdr iphdr = {}; + memcpy(&iphdr, packet, sizeof(iphdr)); + iphdr.daddr = 42; + memcpy(packet, &iphdr, sizeof(iphdr)); socklen_t addrlen = sizeof(addr_); ASSERT_NO_FATAL_FAILURE(sendto(socket_, &packet, sizeof(packet), 0, reinterpret_cast<struct sockaddr*>(&addr_), addrlen)); - // Receive the payload. + // Receive the payload, since sendto should replace the bad destination with + // localhost. char recv_buf[sizeof(packet)]; struct sockaddr_in src; socklen_t src_size = sizeof(src); @@ -318,47 +329,58 @@ TEST_F(RawHDRINCL, SendAndReceiveNonzeroID) { // The network stack should have set the source address. EXPECT_EQ(src.sin_family, AF_INET); EXPECT_EQ(absl::gbswap_32(src.sin_addr.s_addr), INADDR_LOOPBACK); - // The packet ID should not be 0, as the packet was more than 68 bytes. - struct iphdr* iphdr = reinterpret_cast<struct iphdr*>(recv_buf); - EXPECT_NE(iphdr->id, 0); + // The packet ID should not be 0, as the packet has DF=0. + struct iphdr recv_iphdr = {}; + memcpy(&recv_iphdr, recv_buf, sizeof(recv_iphdr)); + EXPECT_NE(recv_iphdr.id, 0); + // The destination address should be localhost, not the bad IP we set + // initially. + EXPECT_EQ(absl::gbswap_32(recv_iphdr.daddr), INADDR_LOOPBACK); } -// Send and receive a packet where the sendto address is not the same as the -// provided destination. -TEST_F(RawHDRINCL, SendAndReceiveDifferentAddress) { +// Send and receive a packet w/ the IP_HDRINCL option set. +TEST_F(RawHDRINCL, SendAndReceiveIPHdrIncl) { int port = 40000; if (!IsRunningOnGvisor()) { port = static_cast<short>(ASSERT_NO_ERRNO_AND_VALUE( PortAvailable(0, AddressFamily::kIpv4, SocketType::kUdp, false))); } - // IPPROTO_RAW sockets are write-only. We'll have to open another socket to - // read what we write. - FileDescriptor udp_sock = + FileDescriptor recv_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, IPPROTO_UDP)); + + FileDescriptor send_sock = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, IPPROTO_UDP)); + // Enable IP_HDRINCL option so that we can build and send w/ an IP + // header. + constexpr int kSockOptOn = 1; + ASSERT_THAT(setsockopt(send_sock.get(), SOL_IP, IP_HDRINCL, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + // This is not strictly required but we do it to make sure that setting + // IP_HDRINCL on a non IPPROTO_RAW socket does not prevent it from receiving + // packets. + ASSERT_THAT(setsockopt(recv_sock.get(), SOL_IP, IP_HDRINCL, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + // Construct a packet with an IP header, UDP header, and payload. constexpr char kPayload[] = "toto"; char packet[sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kPayload)]; ASSERT_TRUE( FillPacket(packet, sizeof(packet), port, kPayload, sizeof(kPayload))); - // Overwrite the IP destination address with an IP we can't get to. - struct iphdr iphdr = {}; - memcpy(&iphdr, packet, sizeof(iphdr)); - iphdr.daddr = 42; - memcpy(packet, &iphdr, sizeof(iphdr)); socklen_t addrlen = sizeof(addr_); - ASSERT_NO_FATAL_FAILURE(sendto(socket_, &packet, sizeof(packet), 0, + ASSERT_NO_FATAL_FAILURE(sendto(send_sock.get(), &packet, sizeof(packet), 0, reinterpret_cast<struct sockaddr*>(&addr_), addrlen)); - // Receive the payload, since sendto should replace the bad destination with - // localhost. + // Receive the payload. char recv_buf[sizeof(packet)]; struct sockaddr_in src; socklen_t src_size = sizeof(src); - ASSERT_THAT(recvfrom(udp_sock.get(), recv_buf, sizeof(recv_buf), 0, + ASSERT_THAT(recvfrom(recv_sock.get(), recv_buf, sizeof(recv_buf), 0, reinterpret_cast<struct sockaddr*>(&src), &src_size), SyscallSucceedsWithValue(sizeof(packet))); EXPECT_EQ( @@ -368,13 +390,20 @@ TEST_F(RawHDRINCL, SendAndReceiveDifferentAddress) { // The network stack should have set the source address. EXPECT_EQ(src.sin_family, AF_INET); EXPECT_EQ(absl::gbswap_32(src.sin_addr.s_addr), INADDR_LOOPBACK); - // The packet ID should be 0, as the packet is less than 68 bytes. - struct iphdr recv_iphdr = {}; - memcpy(&recv_iphdr, recv_buf, sizeof(recv_iphdr)); - EXPECT_EQ(recv_iphdr.id, 0); - // The destination address should be localhost, not the bad IP we set - // initially. - EXPECT_EQ(absl::gbswap_32(recv_iphdr.daddr), INADDR_LOOPBACK); + struct iphdr iphdr = {}; + memcpy(&iphdr, recv_buf, sizeof(iphdr)); + EXPECT_NE(iphdr.id, 0); + + // Also verify that the packet we just sent was not delivered to the + // IPPROTO_RAW socket. + { + char recv_buf[sizeof(packet)]; + struct sockaddr_in src; + socklen_t src_size = sizeof(src); + ASSERT_THAT(recvfrom(socket_, recv_buf, sizeof(recv_buf), MSG_DONTWAIT, + reinterpret_cast<struct sockaddr*>(&src), &src_size), + SyscallFailsWithErrno(EAGAIN)); + } } } // namespace diff --git a/test/syscalls/linux/raw_socket_icmp.cc b/test/syscalls/linux/raw_socket_icmp.cc index 8bcaba6f1..3de898df7 100644 --- a/test/syscalls/linux/raw_socket_icmp.cc +++ b/test/syscalls/linux/raw_socket_icmp.cc @@ -129,7 +129,7 @@ TEST_F(RawSocketICMPTest, SendAndReceiveBadChecksum) { EXPECT_THAT(RetryEINTR(recv)(s_, recv_buf, sizeof(recv_buf), MSG_DONTWAIT), SyscallFailsWithErrno(EAGAIN)); } -// + // Send and receive an ICMP packet. TEST_F(RawSocketICMPTest, SendAndReceive) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); diff --git a/test/syscalls/linux/raw_socket_ipv4.cc b/test/syscalls/linux/raw_socket_ipv4.cc deleted file mode 100644 index cde2f07c9..000000000 --- a/test/syscalls/linux/raw_socket_ipv4.cc +++ /dev/null @@ -1,392 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <linux/capability.h> -#include <netinet/in.h> -#include <netinet/ip.h> -#include <netinet/ip_icmp.h> -#include <poll.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <unistd.h> - -#include <algorithm> - -#include "gtest/gtest.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/capability_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/test_util.h" - -// Note: in order to run these tests, /proc/sys/net/ipv4/ping_group_range will -// need to be configured to let the superuser create ping sockets (see icmp(7)). - -namespace gvisor { -namespace testing { - -namespace { - -// Fixture for tests parameterized by protocol. -class RawSocketTest : public ::testing::TestWithParam<int> { - protected: - // Creates a socket to be used in tests. - void SetUp() override; - - // Closes the socket created by SetUp(). - void TearDown() override; - - // Sends buf via s_. - void SendBuf(const char* buf, int buf_len); - - // Sends buf to the provided address via the provided socket. - void SendBufTo(int sock, const struct sockaddr_in& addr, const char* buf, - int buf_len); - - // Reads from s_ into recv_buf. - void ReceiveBuf(char* recv_buf, size_t recv_buf_len); - - int Protocol() { return GetParam(); } - - // The socket used for both reading and writing. - int s_; - - // The loopback address. - struct sockaddr_in addr_; -}; - -void RawSocketTest::SetUp() { - if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { - ASSERT_THAT(socket(AF_INET, SOCK_RAW, Protocol()), - SyscallFailsWithErrno(EPERM)); - GTEST_SKIP(); - } - - ASSERT_THAT(s_ = socket(AF_INET, SOCK_RAW, Protocol()), SyscallSucceeds()); - - addr_ = {}; - - // We don't set ports because raw sockets don't have a notion of ports. - addr_.sin_addr.s_addr = htonl(INADDR_LOOPBACK); - addr_.sin_family = AF_INET; -} - -void RawSocketTest::TearDown() { - // TearDown will be run even if we skip the test. - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { - EXPECT_THAT(close(s_), SyscallSucceeds()); - } -} - -// We should be able to create multiple raw sockets for the same protocol. -// BasicRawSocket::Setup creates the first one, so we only have to create one -// more here. -TEST_P(RawSocketTest, MultipleCreation) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - int s2; - ASSERT_THAT(s2 = socket(AF_INET, SOCK_RAW, Protocol()), SyscallSucceeds()); - - ASSERT_THAT(close(s2), SyscallSucceeds()); -} - -// Test that shutting down an unconnected socket fails. -TEST_P(RawSocketTest, FailShutdownWithoutConnect) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - ASSERT_THAT(shutdown(s_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); - ASSERT_THAT(shutdown(s_, SHUT_RD), SyscallFailsWithErrno(ENOTCONN)); -} - -// Shutdown is a no-op for raw sockets (and datagram sockets in general). -TEST_P(RawSocketTest, ShutdownWriteNoop) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - ASSERT_THAT( - connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), - SyscallSucceeds()); - ASSERT_THAT(shutdown(s_, SHUT_WR), SyscallSucceeds()); - - // Arbitrary. - constexpr char kBuf[] = "noop"; - ASSERT_THAT(RetryEINTR(write)(s_, kBuf, sizeof(kBuf)), - SyscallSucceedsWithValue(sizeof(kBuf))); -} - -// Shutdown is a no-op for raw sockets (and datagram sockets in general). -TEST_P(RawSocketTest, ShutdownReadNoop) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - ASSERT_THAT( - connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), - SyscallSucceeds()); - ASSERT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds()); - - // Arbitrary. - constexpr char kBuf[] = "gdg"; - ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf))); - - constexpr size_t kReadSize = sizeof(kBuf) + sizeof(struct iphdr); - char c[kReadSize]; - ASSERT_THAT(read(s_, &c, sizeof(c)), SyscallSucceedsWithValue(kReadSize)); -} - -// Test that listen() fails. -TEST_P(RawSocketTest, FailListen) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - ASSERT_THAT(listen(s_, 1), SyscallFailsWithErrno(ENOTSUP)); -} - -// Test that accept() fails. -TEST_P(RawSocketTest, FailAccept) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - struct sockaddr saddr; - socklen_t addrlen; - ASSERT_THAT(accept(s_, &saddr, &addrlen), SyscallFailsWithErrno(ENOTSUP)); -} - -// Test that getpeername() returns nothing before connect(). -TEST_P(RawSocketTest, FailGetPeerNameBeforeConnect) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - struct sockaddr saddr; - socklen_t addrlen = sizeof(saddr); - ASSERT_THAT(getpeername(s_, &saddr, &addrlen), - SyscallFailsWithErrno(ENOTCONN)); -} - -// Test that getpeername() returns something after connect(). -TEST_P(RawSocketTest, GetPeerName) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - ASSERT_THAT( - connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), - SyscallSucceeds()); - struct sockaddr saddr; - socklen_t addrlen = sizeof(saddr); - ASSERT_THAT(getpeername(s_, &saddr, &addrlen), - SyscallFailsWithErrno(ENOTCONN)); - ASSERT_GT(addrlen, 0); -} - -// Test that the socket is writable immediately. -TEST_P(RawSocketTest, PollWritableImmediately) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - struct pollfd pfd = {}; - pfd.fd = s_; - pfd.events = POLLOUT; - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 10000), SyscallSucceedsWithValue(1)); -} - -// Test that the socket isn't readable before receiving anything. -TEST_P(RawSocketTest, PollNotReadableInitially) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - // Try to receive data with MSG_DONTWAIT, which returns immediately if there's - // nothing to be read. - char buf[117]; - ASSERT_THAT(RetryEINTR(recv)(s_, buf, sizeof(buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); -} - -// Test that the socket becomes readable once something is written to it. -TEST_P(RawSocketTest, PollTriggeredOnWrite) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - // Write something so that there's data to be read. - // Arbitrary. - constexpr char kBuf[] = "JP5"; - ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf))); - - struct pollfd pfd = {}; - pfd.fd = s_; - pfd.events = POLLIN; - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 10000), SyscallSucceedsWithValue(1)); -} - -// Test that we can connect() to a valid IP (loopback). -TEST_P(RawSocketTest, ConnectToLoopback) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - ASSERT_THAT( - connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), - SyscallSucceeds()); -} - -// Test that calling send() without connect() fails. -TEST_P(RawSocketTest, SendWithoutConnectFails) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - // Arbitrary. - constexpr char kBuf[] = "Endgame was good"; - ASSERT_THAT(send(s_, kBuf, sizeof(kBuf), 0), - SyscallFailsWithErrno(EDESTADDRREQ)); -} - -// Bind to localhost. -TEST_P(RawSocketTest, BindToLocalhost) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - ASSERT_THAT( - bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), - SyscallSucceeds()); -} - -// Bind to a different address. -TEST_P(RawSocketTest, BindToInvalid) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - struct sockaddr_in bind_addr = {}; - bind_addr.sin_family = AF_INET; - bind_addr.sin_addr = {1}; // 1.0.0.0 - An address that we can't bind to. - ASSERT_THAT(bind(s_, reinterpret_cast<struct sockaddr*>(&bind_addr), - sizeof(bind_addr)), - SyscallFailsWithErrno(EADDRNOTAVAIL)); -} - -// Send and receive an packet. -TEST_P(RawSocketTest, SendAndReceive) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - // Arbitrary. - constexpr char kBuf[] = "TB12"; - ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf))); - - // Receive the packet and make sure it's identical. - char recv_buf[sizeof(kBuf) + sizeof(struct iphdr)]; - ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf, sizeof(recv_buf))); - EXPECT_EQ(memcmp(recv_buf + sizeof(struct iphdr), kBuf, sizeof(kBuf)), 0); -} - -// We should be able to create multiple raw sockets for the same protocol and -// receive the same packet on both. -TEST_P(RawSocketTest, MultipleSocketReceive) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - int s2; - ASSERT_THAT(s2 = socket(AF_INET, SOCK_RAW, Protocol()), SyscallSucceeds()); - - // Arbitrary. - constexpr char kBuf[] = "TB10"; - ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf))); - - // Receive it on socket 1. - char recv_buf1[sizeof(kBuf) + sizeof(struct iphdr)]; - ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf1, sizeof(recv_buf1))); - - // Receive it on socket 2. - char recv_buf2[sizeof(kBuf) + sizeof(struct iphdr)]; - ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(s2, recv_buf2, sizeof(recv_buf2))); - - EXPECT_EQ(memcmp(recv_buf1 + sizeof(struct iphdr), - recv_buf2 + sizeof(struct iphdr), sizeof(kBuf)), - 0); - - ASSERT_THAT(close(s2), SyscallSucceeds()); -} - -// Test that connect sends packets to the right place. -TEST_P(RawSocketTest, SendAndReceiveViaConnect) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - ASSERT_THAT( - connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), - SyscallSucceeds()); - - // Arbitrary. - constexpr char kBuf[] = "JH4"; - ASSERT_THAT(send(s_, kBuf, sizeof(kBuf), 0), - SyscallSucceedsWithValue(sizeof(kBuf))); - - // Receive the packet and make sure it's identical. - char recv_buf[sizeof(kBuf) + sizeof(struct iphdr)]; - ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf, sizeof(recv_buf))); - EXPECT_EQ(memcmp(recv_buf + sizeof(struct iphdr), kBuf, sizeof(kBuf)), 0); -} - -// Bind to localhost, then send and receive packets. -TEST_P(RawSocketTest, BindSendAndReceive) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - ASSERT_THAT( - bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), - SyscallSucceeds()); - - // Arbitrary. - constexpr char kBuf[] = "DR16"; - ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf))); - - // Receive the packet and make sure it's identical. - char recv_buf[sizeof(kBuf) + sizeof(struct iphdr)]; - ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf, sizeof(recv_buf))); - EXPECT_EQ(memcmp(recv_buf + sizeof(struct iphdr), kBuf, sizeof(kBuf)), 0); -} - -// Bind and connect to localhost and send/receive packets. -TEST_P(RawSocketTest, BindConnectSendAndReceive) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - ASSERT_THAT( - bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), - SyscallSucceeds()); - ASSERT_THAT( - connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), - SyscallSucceeds()); - - // Arbitrary. - constexpr char kBuf[] = "DG88"; - ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf))); - - // Receive the packet and make sure it's identical. - char recv_buf[sizeof(kBuf) + sizeof(struct iphdr)]; - ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf, sizeof(recv_buf))); - EXPECT_EQ(memcmp(recv_buf + sizeof(struct iphdr), kBuf, sizeof(kBuf)), 0); -} - -void RawSocketTest::SendBuf(const char* buf, int buf_len) { - ASSERT_NO_FATAL_FAILURE(SendBufTo(s_, addr_, buf, buf_len)); -} - -void RawSocketTest::SendBufTo(int sock, const struct sockaddr_in& addr, - const char* buf, int buf_len) { - // It's safe to use const_cast here because sendmsg won't modify the iovec or - // address. - struct iovec iov = {}; - iov.iov_base = static_cast<void*>(const_cast<char*>(buf)); - iov.iov_len = static_cast<size_t>(buf_len); - struct msghdr msg = {}; - msg.msg_name = static_cast<void*>(const_cast<struct sockaddr_in*>(&addr)); - msg.msg_namelen = sizeof(addr); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - msg.msg_control = NULL; - msg.msg_controllen = 0; - msg.msg_flags = 0; - ASSERT_THAT(sendmsg(sock, &msg, 0), SyscallSucceedsWithValue(buf_len)); -} - -void RawSocketTest::ReceiveBuf(char* recv_buf, size_t recv_buf_len) { - ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(s_, recv_buf, recv_buf_len)); -} - -INSTANTIATE_TEST_SUITE_P(AllInetTests, RawSocketTest, - ::testing::Values(IPPROTO_TCP, IPPROTO_UDP)); - -} // namespace - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/read.cc b/test/syscalls/linux/read.cc index 4430fa3c2..2633ba31b 100644 --- a/test/syscalls/linux/read.cc +++ b/test/syscalls/linux/read.cc @@ -14,6 +14,7 @@ #include <fcntl.h> #include <unistd.h> + #include <vector> #include "gtest/gtest.h" diff --git a/test/syscalls/linux/readv.cc b/test/syscalls/linux/readv.cc index 4069cbc7e..baaf9f757 100644 --- a/test/syscalls/linux/readv.cc +++ b/test/syscalls/linux/readv.cc @@ -254,7 +254,9 @@ TEST_F(ReadvTest, IovecOutsideTaskAddressRangeInNonemptyArray) { // This test depends on the maximum extent of a single readv() syscall, so // we can't tolerate interruption from saving. TEST(ReadvTestNoFixture, TruncatedAtMax_NoRandomSave) { - // Ensure that we won't be interrupted by ITIMER_PROF. + // Ensure that we won't be interrupted by ITIMER_PROF. This is particularly + // important in environments where automated profiling tools may start + // ITIMER_PROF automatically. struct itimerval itv = {}; auto const cleanup_itimer = ASSERT_NO_ERRNO_AND_VALUE(ScopedItimer(ITIMER_PROF, itv)); diff --git a/test/syscalls/linux/readv_common.cc b/test/syscalls/linux/readv_common.cc index 9658f7d42..2694dc64f 100644 --- a/test/syscalls/linux/readv_common.cc +++ b/test/syscalls/linux/readv_common.cc @@ -19,12 +19,53 @@ #include <unistd.h> #include "gtest/gtest.h" -#include "test/syscalls/linux/file_base.h" #include "test/util/test_util.h" namespace gvisor { namespace testing { +// MatchesStringLength checks that a tuple argument of (struct iovec *, int) +// corresponding to an iovec array and its length, contains data that matches +// the string length strlen. +MATCHER_P(MatchesStringLength, strlen, "") { + struct iovec* iovs = arg.first; + int niov = arg.second; + int offset = 0; + for (int i = 0; i < niov; i++) { + offset += iovs[i].iov_len; + } + if (offset != static_cast<int>(strlen)) { + *result_listener << offset; + return false; + } + return true; +} + +// MatchesStringValue checks that a tuple argument of (struct iovec *, int) +// corresponding to an iovec array and its length, contains data that matches +// the string value str. +MATCHER_P(MatchesStringValue, str, "") { + struct iovec* iovs = arg.first; + int len = strlen(str); + int niov = arg.second; + int offset = 0; + for (int i = 0; i < niov; i++) { + struct iovec iov = iovs[i]; + if (len < offset) { + *result_listener << "strlen " << len << " < offset " << offset; + return false; + } + if (strncmp(static_cast<char*>(iov.iov_base), &str[offset], iov.iov_len)) { + absl::string_view iovec_string(static_cast<char*>(iov.iov_base), + iov.iov_len); + *result_listener << iovec_string << " @offset " << offset; + return false; + } + offset += iov.iov_len; + } + return true; +} + extern const char kReadvTestData[] = "127.0.0.1 localhost" "" @@ -113,7 +154,7 @@ void ReadBuffersOverlapping(int fd) { char* expected_ptr = expected.data(); memcpy(expected_ptr, &kReadvTestData[overlap_bytes], overlap_bytes); memcpy(&expected_ptr[overlap_bytes], &kReadvTestData[overlap_bytes], - kReadvTestDataSize); + kReadvTestDataSize - overlap_bytes); struct iovec iovs[2]; iovs[0].iov_base = buffer.data(); diff --git a/test/syscalls/linux/readv_socket.cc b/test/syscalls/linux/readv_socket.cc index 9b6972201..dd6fb7008 100644 --- a/test/syscalls/linux/readv_socket.cc +++ b/test/syscalls/linux/readv_socket.cc @@ -19,7 +19,6 @@ #include <unistd.h> #include "gtest/gtest.h" -#include "test/syscalls/linux/file_base.h" #include "test/syscalls/linux/readv_common.h" #include "test/util/test_util.h" @@ -28,9 +27,30 @@ namespace testing { namespace { -class ReadvSocketTest : public SocketTest { +class ReadvSocketTest : public ::testing::Test { + public: void SetUp() override { - SocketTest::SetUp(); + test_unix_stream_socket_[0] = -1; + test_unix_stream_socket_[1] = -1; + test_unix_dgram_socket_[0] = -1; + test_unix_dgram_socket_[1] = -1; + test_unix_seqpacket_socket_[0] = -1; + test_unix_seqpacket_socket_[1] = -1; + + ASSERT_THAT(socketpair(AF_UNIX, SOCK_STREAM, 0, test_unix_stream_socket_), + SyscallSucceeds()); + ASSERT_THAT(fcntl(test_unix_stream_socket_[0], F_SETFL, O_NONBLOCK), + SyscallSucceeds()); + ASSERT_THAT(socketpair(AF_UNIX, SOCK_DGRAM, 0, test_unix_dgram_socket_), + SyscallSucceeds()); + ASSERT_THAT(fcntl(test_unix_dgram_socket_[0], F_SETFL, O_NONBLOCK), + SyscallSucceeds()); + ASSERT_THAT( + socketpair(AF_UNIX, SOCK_SEQPACKET, 0, test_unix_seqpacket_socket_), + SyscallSucceeds()); + ASSERT_THAT(fcntl(test_unix_seqpacket_socket_[0], F_SETFL, O_NONBLOCK), + SyscallSucceeds()); + ASSERT_THAT( write(test_unix_stream_socket_[1], kReadvTestData, kReadvTestDataSize), SyscallSucceedsWithValue(kReadvTestDataSize)); @@ -40,11 +60,22 @@ class ReadvSocketTest : public SocketTest { ASSERT_THAT(write(test_unix_seqpacket_socket_[1], kReadvTestData, kReadvTestDataSize), SyscallSucceedsWithValue(kReadvTestDataSize)); - // FIXME(b/69821513): Enable when possible. - // ASSERT_THAT(write(test_tcp_socket_[1], kReadvTestData, - // kReadvTestDataSize), - // SyscallSucceedsWithValue(kReadvTestDataSize)); } + + void TearDown() override { + close(test_unix_stream_socket_[0]); + close(test_unix_stream_socket_[1]); + + close(test_unix_dgram_socket_[0]); + close(test_unix_dgram_socket_[1]); + + close(test_unix_seqpacket_socket_[0]); + close(test_unix_seqpacket_socket_[1]); + } + + int test_unix_stream_socket_[2]; + int test_unix_dgram_socket_[2]; + int test_unix_seqpacket_socket_[2]; }; TEST_F(ReadvSocketTest, ReadOneBufferPerByte_StreamSocket) { diff --git a/test/syscalls/linux/rename.cc b/test/syscalls/linux/rename.cc index 5b474ff32..833c0dc4f 100644 --- a/test/syscalls/linux/rename.cc +++ b/test/syscalls/linux/rename.cc @@ -14,6 +14,7 @@ #include <fcntl.h> #include <stdio.h> + #include <string> #include "gtest/gtest.h" diff --git a/test/syscalls/linux/rseq.cc b/test/syscalls/linux/rseq.cc new file mode 100644 index 000000000..4bfb1ff56 --- /dev/null +++ b/test/syscalls/linux/rseq.cc @@ -0,0 +1,198 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <signal.h> +#include <sys/syscall.h> +#include <sys/types.h> +#include <sys/wait.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "test/syscalls/linux/rseq/test.h" +#include "test/syscalls/linux/rseq/uapi.h" +#include "test/util/logging.h" +#include "test/util/multiprocess_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +// Syscall test for rseq (restartable sequences). +// +// We must be very careful about how these tests are written. Each thread may +// only have one struct rseq registration, which may be done automatically at +// thread start (as of 2019-11-13, glibc does *not* support rseq and thus does +// not do so, but other libraries do). +// +// Testing of rseq is thus done primarily in a child process with no +// registration. This means exec'ing a nostdlib binary, as rseq registration can +// only be cleared by execve (or knowing the old rseq address), and glibc (based +// on the current unmerged patches) register rseq before calling main()). + +int RSeq(struct rseq* rseq, uint32_t rseq_len, int flags, uint32_t sig) { + return syscall(kRseqSyscall, rseq, rseq_len, flags, sig); +} + +// Returns true if this kernel supports the rseq syscall. +PosixErrorOr<bool> RSeqSupported() { + // We have to be careful here, there are three possible cases: + // + // 1. rseq is not supported -> ENOSYS + // 2. rseq is supported and not registered -> success, but we should + // unregister. + // 3. rseq is supported and registered -> EINVAL (most likely). + + // The only validation done on new registrations is that rseq is aligned and + // writable. + rseq rseq = {}; + int ret = RSeq(&rseq, sizeof(rseq), 0, 0); + if (ret == 0) { + // Successfully registered, rseq is supported. Unregister. + ret = RSeq(&rseq, sizeof(rseq), kRseqFlagUnregister, 0); + if (ret != 0) { + return PosixError(errno); + } + return true; + } + + switch (errno) { + case ENOSYS: + // Not supported. + return false; + case EINVAL: + // Supported, but already registered. EINVAL returned because we provided + // a different address. + return true; + default: + // Unknown error. + return PosixError(errno); + } +} + +constexpr char kRseqBinary[] = "test/syscalls/linux/rseq/rseq"; + +void RunChildTest(std::string test_case, int want_status) { + std::string path = RunfilePath(kRseqBinary); + + pid_t child_pid = -1; + int execve_errno = 0; + auto cleanup = ASSERT_NO_ERRNO_AND_VALUE( + ForkAndExec(path, {path, test_case}, {}, &child_pid, &execve_errno)); + + ASSERT_GT(child_pid, 0); + ASSERT_EQ(execve_errno, 0); + + int status = 0; + ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds()); + ASSERT_EQ(status, want_status); +} + +// Test that rseq must be aligned. +TEST(RseqTest, Unaligned) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestUnaligned, 0); +} + +// Sanity test that registration works. +TEST(RseqTest, Register) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestRegister, 0); +} + +// Registration can't be done twice. +TEST(RseqTest, DoubleRegister) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestDoubleRegister, 0); +} + +// Registration can be done again after unregister. +TEST(RseqTest, RegisterUnregister) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestRegisterUnregister, 0); +} + +// The pointer to rseq must match on register/unregister. +TEST(RseqTest, UnregisterDifferentPtr) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestUnregisterDifferentPtr, 0); +} + +// The signature must match on register/unregister. +TEST(RseqTest, UnregisterDifferentSignature) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestUnregisterDifferentSignature, 0); +} + +// The CPU ID is initialized. +TEST(RseqTest, CPU) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestCPU, 0); +} + +// Critical section is eventually aborted. +TEST(RseqTest, Abort) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestAbort, 0); +} + +// Abort may be before the critical section. +TEST(RseqTest, AbortBefore) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestAbortBefore, 0); +} + +// Signature must match. +TEST(RseqTest, AbortSignature) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestAbortSignature, SIGSEGV); +} + +// Abort must not be in the critical section. +TEST(RseqTest, AbortPreCommit) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestAbortPreCommit, SIGSEGV); +} + +// rseq.rseq_cs is cleared on abort. +TEST(RseqTest, AbortClearsCS) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestAbortClearsCS, 0); +} + +// rseq.rseq_cs is cleared on abort outside of critical section. +TEST(RseqTest, InvalidAbortClearsCS) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported())); + + RunChildTest(kRseqTestInvalidAbortClearsCS, 0); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/rseq/BUILD b/test/syscalls/linux/rseq/BUILD new file mode 100644 index 000000000..853258b04 --- /dev/null +++ b/test/syscalls/linux/rseq/BUILD @@ -0,0 +1,61 @@ +# This package contains a standalone rseq test binary. This binary must not +# depend on libc, which might use rseq itself. + +load("//tools:defs.bzl", "cc_flags_supplier", "cc_library", "cc_toolchain", "select_arch") + +package(licenses = ["notice"]) + +genrule( + name = "rseq_binary", + srcs = [ + "critical.h", + "critical_amd64.S", + "critical_arm64.S", + "rseq.cc", + "syscalls.h", + "start_amd64.S", + "start_arm64.S", + "test.h", + "types.h", + "uapi.h", + ], + outs = ["rseq"], + cmd = "$(CC) " + + "$(CC_FLAGS) " + + "-I. " + + "-Wall " + + "-Werror " + + "-O2 " + + "-std=c++17 " + + "-static " + + "-nostdlib " + + "-ffreestanding " + + "-o " + + "$(location rseq) " + + select_arch( + amd64 = "$(location critical_amd64.S) $(location start_amd64.S) ", + arm64 = "$(location critical_arm64.S) $(location start_arm64.S) ", + no_match_error = "unsupported architecture", + ) + + "$(location rseq.cc)", + toolchains = [ + cc_toolchain, + ":no_pie_cc_flags", + ], + visibility = ["//:sandbox"], +) + +cc_flags_supplier( + name = "no_pie_cc_flags", + features = ["-pie"], +) + +cc_library( + name = "lib", + testonly = 1, + hdrs = [ + "test.h", + "uapi.h", + ], + visibility = ["//:sandbox"], +) diff --git a/test/syscalls/linux/rseq/critical.h b/test/syscalls/linux/rseq/critical.h new file mode 100644 index 000000000..ac987a25e --- /dev/null +++ b/test/syscalls/linux/rseq/critical.h @@ -0,0 +1,39 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_LINUX_RSEQ_CRITICAL_H_ +#define GVISOR_TEST_SYSCALLS_LINUX_RSEQ_CRITICAL_H_ + +#include "test/syscalls/linux/rseq/types.h" +#include "test/syscalls/linux/rseq/uapi.h" + +constexpr uint32_t kRseqSignature = 0x90909090; + +extern "C" { + +extern void rseq_loop(struct rseq* r, struct rseq_cs* cs); +extern void* rseq_loop_early_abort; +extern void* rseq_loop_start; +extern void* rseq_loop_pre_commit; +extern void* rseq_loop_post_commit; +extern void* rseq_loop_abort; + +extern int rseq_getpid(struct rseq* r, struct rseq_cs* cs); +extern void* rseq_getpid_start; +extern void* rseq_getpid_post_commit; +extern void* rseq_getpid_abort; + +} // extern "C" + +#endif // GVISOR_TEST_SYSCALLS_LINUX_RSEQ_CRITICAL_H_ diff --git a/test/syscalls/linux/rseq/critical_amd64.S b/test/syscalls/linux/rseq/critical_amd64.S new file mode 100644 index 000000000..8c0687e6d --- /dev/null +++ b/test/syscalls/linux/rseq/critical_amd64.S @@ -0,0 +1,66 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Restartable sequences critical sections. + +// Loops continuously until aborted. +// +// void rseq_loop(struct rseq* r, struct rseq_cs* cs) + + .text + .globl rseq_loop + .type rseq_loop, @function + +rseq_loop: + jmp begin + + // Abort block before the critical section. + // Abort signature is 4 nops for simplicity. + .byte 0x90, 0x90, 0x90, 0x90 + .globl rseq_loop_early_abort +rseq_loop_early_abort: + ret + +begin: + // r->rseq_cs = cs + movq %rsi, 8(%rdi) + + // N.B. rseq_cs will be cleared by any preempt, even outside the critical + // section. Thus it must be set in or immediately before the critical section + // to ensure it is not cleared before the section begins. + .globl rseq_loop_start +rseq_loop_start: + jmp rseq_loop_start + + // "Pre-commit": extra instructions inside the critical section. These are + // used as the abort point in TestAbortPreCommit, which is not valid. + .globl rseq_loop_pre_commit +rseq_loop_pre_commit: + // Extra abort signature + nop for TestAbortPostCommit. + .byte 0x90, 0x90, 0x90, 0x90 + nop + + // "Post-commit": never reached in this case. + .globl rseq_loop_post_commit +rseq_loop_post_commit: + + // Abort signature is 4 nops for simplicity. + .byte 0x90, 0x90, 0x90, 0x90 + + .globl rseq_loop_abort +rseq_loop_abort: + ret + + .size rseq_loop,.-rseq_loop + .section .note.GNU-stack,"",@progbits diff --git a/test/syscalls/linux/rseq/critical_arm64.S b/test/syscalls/linux/rseq/critical_arm64.S new file mode 100644 index 000000000..bfe7e8307 --- /dev/null +++ b/test/syscalls/linux/rseq/critical_arm64.S @@ -0,0 +1,66 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Restartable sequences critical sections. + +// Loops continuously until aborted. +// +// void rseq_loop(struct rseq* r, struct rseq_cs* cs) + + .text + .globl rseq_loop + .type rseq_loop, @function + +rseq_loop: + b begin + + // Abort block before the critical section. + // Abort signature. + .byte 0x90, 0x90, 0x90, 0x90 + .globl rseq_loop_early_abort +rseq_loop_early_abort: + ret + +begin: + // r->rseq_cs = cs + str x1, [x0, #8] + + // N.B. rseq_cs will be cleared by any preempt, even outside the critical + // section. Thus it must be set in or immediately before the critical section + // to ensure it is not cleared before the section begins. + .globl rseq_loop_start +rseq_loop_start: + b rseq_loop_start + + // "Pre-commit": extra instructions inside the critical section. These are + // used as the abort point in TestAbortPreCommit, which is not valid. + .globl rseq_loop_pre_commit +rseq_loop_pre_commit: + // Extra abort signature + nop for TestAbortPostCommit. + .byte 0x90, 0x90, 0x90, 0x90 + nop + + // "Post-commit": never reached in this case. + .globl rseq_loop_post_commit +rseq_loop_post_commit: + + // Abort signature. + .byte 0x90, 0x90, 0x90, 0x90 + + .globl rseq_loop_abort +rseq_loop_abort: + ret + + .size rseq_loop,.-rseq_loop + .section .note.GNU-stack,"",@progbits diff --git a/test/syscalls/linux/rseq/rseq.cc b/test/syscalls/linux/rseq/rseq.cc new file mode 100644 index 000000000..f036db26d --- /dev/null +++ b/test/syscalls/linux/rseq/rseq.cc @@ -0,0 +1,366 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/syscalls/linux/rseq/critical.h" +#include "test/syscalls/linux/rseq/syscalls.h" +#include "test/syscalls/linux/rseq/test.h" +#include "test/syscalls/linux/rseq/types.h" +#include "test/syscalls/linux/rseq/uapi.h" + +namespace gvisor { +namespace testing { + +extern "C" int main(int argc, char** argv, char** envp); + +// Standalone initialization before calling main(). +extern "C" void __init(uintptr_t* sp) { + int argc = sp[0]; + char** argv = reinterpret_cast<char**>(&sp[1]); + char** envp = &argv[argc + 1]; + + // Call main() and exit. + sys_exit_group(main(argc, argv, envp)); + + // sys_exit_group does not return +} + +int strcmp(const char* s1, const char* s2) { + const unsigned char* p1 = reinterpret_cast<const unsigned char*>(s1); + const unsigned char* p2 = reinterpret_cast<const unsigned char*>(s2); + + while (*p1 == *p2) { + if (!*p1) { + return 0; + } + ++p1; + ++p2; + } + return static_cast<int>(*p1) - static_cast<int>(*p2); +} + +int sys_rseq(struct rseq* rseq, uint32_t rseq_len, int flags, uint32_t sig) { + return raw_syscall(kRseqSyscall, rseq, rseq_len, flags, sig); +} + +// Test that rseq must be aligned. +int TestUnaligned() { + constexpr uintptr_t kRequiredAlignment = alignof(rseq); + + char buf[2 * kRequiredAlignment] = {}; + uintptr_t ptr = reinterpret_cast<uintptr_t>(&buf[0]); + if ((ptr & (kRequiredAlignment - 1)) == 0) { + // buf is already aligned. Misalign it. + ptr++; + } + + int ret = sys_rseq(reinterpret_cast<rseq*>(ptr), sizeof(rseq), 0, 0); + if (sys_errno(ret) != EINVAL) { + return 1; + } + return 0; +} + +// Sanity test that registration works. +int TestRegister() { + struct rseq r = {}; + if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { + return 1; + } + return 0; +}; + +// Registration can't be done twice. +int TestDoubleRegister() { + struct rseq r = {}; + if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { + return 1; + } + + if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != EBUSY) { + return 1; + } + + return 0; +}; + +// Registration can be done again after unregister. +int TestRegisterUnregister() { + struct rseq r = {}; + if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { + return 1; + } + + if (int ret = sys_rseq(&r, sizeof(r), kRseqFlagUnregister, 0); + sys_errno(ret) != 0) { + return 1; + } + + if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { + return 1; + } + + return 0; +}; + +// The pointer to rseq must match on register/unregister. +int TestUnregisterDifferentPtr() { + struct rseq r = {}; + if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { + return 1; + } + + struct rseq r2 = {}; + if (int ret = sys_rseq(&r2, sizeof(r2), kRseqFlagUnregister, 0); + sys_errno(ret) != EINVAL) { + return 1; + } + + return 0; +}; + +// The signature must match on register/unregister. +int TestUnregisterDifferentSignature() { + constexpr int kSignature = 0; + + struct rseq r = {}; + if (int ret = sys_rseq(&r, sizeof(r), 0, kSignature); sys_errno(ret) != 0) { + return 1; + } + + if (int ret = sys_rseq(&r, sizeof(r), kRseqFlagUnregister, kSignature + 1); + sys_errno(ret) != EPERM) { + return 1; + } + + return 0; +}; + +// The CPU ID is initialized. +int TestCPU() { + struct rseq r = {}; + r.cpu_id = kRseqCPUIDUninitialized; + + if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { + return 1; + } + + if (__atomic_load_n(&r.cpu_id, __ATOMIC_RELAXED) < 0) { + return 1; + } + if (__atomic_load_n(&r.cpu_id_start, __ATOMIC_RELAXED) < 0) { + return 1; + } + + return 0; +}; + +// Critical section is eventually aborted. +int TestAbort() { + struct rseq r = {}; + if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); + sys_errno(ret) != 0) { + return 1; + } + + struct rseq_cs cs = {}; + cs.version = 0; + cs.flags = 0; + cs.start_ip = reinterpret_cast<uint64_t>(&rseq_loop_start); + cs.post_commit_offset = reinterpret_cast<uint64_t>(&rseq_loop_post_commit) - + reinterpret_cast<uint64_t>(&rseq_loop_start); + cs.abort_ip = reinterpret_cast<uint64_t>(&rseq_loop_abort); + + // Loops until abort. If this returns then abort occurred. + rseq_loop(&r, &cs); + + return 0; +}; + +// Abort may be before the critical section. +int TestAbortBefore() { + struct rseq r = {}; + if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); + sys_errno(ret) != 0) { + return 1; + } + + struct rseq_cs cs = {}; + cs.version = 0; + cs.flags = 0; + cs.start_ip = reinterpret_cast<uint64_t>(&rseq_loop_start); + cs.post_commit_offset = reinterpret_cast<uint64_t>(&rseq_loop_post_commit) - + reinterpret_cast<uint64_t>(&rseq_loop_start); + cs.abort_ip = reinterpret_cast<uint64_t>(&rseq_loop_early_abort); + + // Loops until abort. If this returns then abort occurred. + rseq_loop(&r, &cs); + + return 0; +}; + +// Signature must match. +int TestAbortSignature() { + struct rseq r = {}; + if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature + 1); + sys_errno(ret) != 0) { + return 1; + } + + struct rseq_cs cs = {}; + cs.version = 0; + cs.flags = 0; + cs.start_ip = reinterpret_cast<uint64_t>(&rseq_loop_start); + cs.post_commit_offset = reinterpret_cast<uint64_t>(&rseq_loop_post_commit) - + reinterpret_cast<uint64_t>(&rseq_loop_start); + cs.abort_ip = reinterpret_cast<uint64_t>(&rseq_loop_abort); + + // Loops until abort. This should SIGSEGV on abort. + rseq_loop(&r, &cs); + + return 1; +}; + +// Abort must not be in the critical section. +int TestAbortPreCommit() { + struct rseq r = {}; + if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature + 1); + sys_errno(ret) != 0) { + return 1; + } + + struct rseq_cs cs = {}; + cs.version = 0; + cs.flags = 0; + cs.start_ip = reinterpret_cast<uint64_t>(&rseq_loop_start); + cs.post_commit_offset = reinterpret_cast<uint64_t>(&rseq_loop_post_commit) - + reinterpret_cast<uint64_t>(&rseq_loop_start); + cs.abort_ip = reinterpret_cast<uint64_t>(&rseq_loop_pre_commit); + + // Loops until abort. This should SIGSEGV on abort. + rseq_loop(&r, &cs); + + return 1; +}; + +// rseq.rseq_cs is cleared on abort. +int TestAbortClearsCS() { + struct rseq r = {}; + if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); + sys_errno(ret) != 0) { + return 1; + } + + struct rseq_cs cs = {}; + cs.version = 0; + cs.flags = 0; + cs.start_ip = reinterpret_cast<uint64_t>(&rseq_loop_start); + cs.post_commit_offset = reinterpret_cast<uint64_t>(&rseq_loop_post_commit) - + reinterpret_cast<uint64_t>(&rseq_loop_start); + cs.abort_ip = reinterpret_cast<uint64_t>(&rseq_loop_abort); + + // Loops until abort. If this returns then abort occurred. + rseq_loop(&r, &cs); + + if (__atomic_load_n(&r.rseq_cs, __ATOMIC_RELAXED)) { + return 1; + } + + return 0; +}; + +// rseq.rseq_cs is cleared on abort outside of critical section. +int TestInvalidAbortClearsCS() { + struct rseq r = {}; + if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); + sys_errno(ret) != 0) { + return 1; + } + + struct rseq_cs cs = {}; + cs.version = 0; + cs.flags = 0; + cs.start_ip = reinterpret_cast<uint64_t>(&rseq_loop_start); + cs.post_commit_offset = reinterpret_cast<uint64_t>(&rseq_loop_post_commit) - + reinterpret_cast<uint64_t>(&rseq_loop_start); + cs.abort_ip = reinterpret_cast<uint64_t>(&rseq_loop_abort); + + __atomic_store_n(&r.rseq_cs, &cs, __ATOMIC_RELAXED); + + // When the next abort condition occurs, the kernel will clear cs once it + // determines we aren't in the critical section. + while (1) { + if (!__atomic_load_n(&r.rseq_cs, __ATOMIC_RELAXED)) { + break; + } + } + + return 0; +}; + +// Exit codes: +// 0 - Pass +// 1 - Fail +// 2 - Missing argument +// 3 - Unknown test case +extern "C" int main(int argc, char** argv, char** envp) { + if (argc != 2) { + // Usage: rseq <test case> + return 2; + } + + if (strcmp(argv[1], kRseqTestUnaligned) == 0) { + return TestUnaligned(); + } + if (strcmp(argv[1], kRseqTestRegister) == 0) { + return TestRegister(); + } + if (strcmp(argv[1], kRseqTestDoubleRegister) == 0) { + return TestDoubleRegister(); + } + if (strcmp(argv[1], kRseqTestRegisterUnregister) == 0) { + return TestRegisterUnregister(); + } + if (strcmp(argv[1], kRseqTestUnregisterDifferentPtr) == 0) { + return TestUnregisterDifferentPtr(); + } + if (strcmp(argv[1], kRseqTestUnregisterDifferentSignature) == 0) { + return TestUnregisterDifferentSignature(); + } + if (strcmp(argv[1], kRseqTestCPU) == 0) { + return TestCPU(); + } + if (strcmp(argv[1], kRseqTestAbort) == 0) { + return TestAbort(); + } + if (strcmp(argv[1], kRseqTestAbortBefore) == 0) { + return TestAbortBefore(); + } + if (strcmp(argv[1], kRseqTestAbortSignature) == 0) { + return TestAbortSignature(); + } + if (strcmp(argv[1], kRseqTestAbortPreCommit) == 0) { + return TestAbortPreCommit(); + } + if (strcmp(argv[1], kRseqTestAbortClearsCS) == 0) { + return TestAbortClearsCS(); + } + if (strcmp(argv[1], kRseqTestInvalidAbortClearsCS) == 0) { + return TestInvalidAbortClearsCS(); + } + + return 3; +} + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/rseq/start_amd64.S b/test/syscalls/linux/rseq/start_amd64.S new file mode 100644 index 000000000..b9611b276 --- /dev/null +++ b/test/syscalls/linux/rseq/start_amd64.S @@ -0,0 +1,45 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + + .text + .align 4 + .type _start,@function + .globl _start + +_start: + movq %rsp,%rdi + call __init + hlt + + .size _start,.-_start + .section .note.GNU-stack,"",@progbits + + .text + .globl raw_syscall + .type raw_syscall, @function + +raw_syscall: + mov %rdi,%rax // syscall # + mov %rsi,%rdi // arg0 + mov %rdx,%rsi // arg1 + mov %rcx,%rdx // arg2 + mov %r8,%r10 // arg3 (goes in r10 instead of rcx for system calls) + mov %r9,%r8 // arg4 + mov 0x8(%rsp),%r9 // arg5 + syscall + ret + + .size raw_syscall,.-raw_syscall + .section .note.GNU-stack,"",@progbits diff --git a/test/syscalls/linux/rseq/start_arm64.S b/test/syscalls/linux/rseq/start_arm64.S new file mode 100644 index 000000000..693c1c6eb --- /dev/null +++ b/test/syscalls/linux/rseq/start_arm64.S @@ -0,0 +1,45 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + + .text + .align 4 + .type _start,@function + .globl _start + +_start: + mov x29, sp + bl __init + wfi + + .size _start,.-_start + .section .note.GNU-stack,"",@progbits + + .text + .globl raw_syscall + .type raw_syscall, @function + +raw_syscall: + mov x8,x0 // syscall # + mov x0,x1 // arg0 + mov x1,x2 // arg1 + mov x2,x3 // arg2 + mov x3,x4 // arg3 + mov x4,x5 // arg4 + mov x5,x6 // arg5 + svc #0 + ret + + .size raw_syscall,.-raw_syscall + .section .note.GNU-stack,"",@progbits diff --git a/test/syscalls/linux/rseq/syscalls.h b/test/syscalls/linux/rseq/syscalls.h new file mode 100644 index 000000000..c4118e6c5 --- /dev/null +++ b/test/syscalls/linux/rseq/syscalls.h @@ -0,0 +1,69 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_LINUX_RSEQ_SYSCALLS_H_ +#define GVISOR_TEST_SYSCALLS_LINUX_RSEQ_SYSCALLS_H_ + +#include "test/syscalls/linux/rseq/types.h" + +// Syscall numbers. +#if defined(__x86_64__) +constexpr int kGetpid = 39; +constexpr int kExitGroup = 231; +#elif defined(__aarch64__) +constexpr int kGetpid = 172; +constexpr int kExitGroup = 94; +#else +#error "Unknown architecture" +#endif + +namespace gvisor { +namespace testing { + +// Standalone system call interfaces. +// Note that these are all "raw" system call interfaces which encode +// errors by setting the return value to a small negative number. +// Use sys_errno() to check system call return values for errors. + +// Maximum Linux error number. +constexpr int kMaxErrno = 4095; + +// Errno values. +#define EPERM 1 +#define EFAULT 14 +#define EBUSY 16 +#define EINVAL 22 + +// Get the error number from a raw system call return value. +// Returns a positive error number or 0 if there was no error. +static inline int sys_errno(uintptr_t rval) { + if (rval >= static_cast<uintptr_t>(-kMaxErrno)) { + return -static_cast<int>(rval); + } + return 0; +} + +extern "C" uintptr_t raw_syscall(int number, ...); + +static inline void sys_exit_group(int status) { + raw_syscall(kExitGroup, status); +} +static inline int sys_getpid() { + return static_cast<int>(raw_syscall(kGetpid)); +} + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_SYSCALLS_LINUX_RSEQ_SYSCALLS_H_ diff --git a/test/syscalls/linux/rseq/test.h b/test/syscalls/linux/rseq/test.h new file mode 100644 index 000000000..3b7bb74b1 --- /dev/null +++ b/test/syscalls/linux/rseq/test.h @@ -0,0 +1,43 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_LINUX_RSEQ_TEST_H_ +#define GVISOR_TEST_SYSCALLS_LINUX_RSEQ_TEST_H_ + +namespace gvisor { +namespace testing { + +// Test cases supported by rseq binary. + +inline constexpr char kRseqTestUnaligned[] = "unaligned"; +inline constexpr char kRseqTestRegister[] = "register"; +inline constexpr char kRseqTestDoubleRegister[] = "double-register"; +inline constexpr char kRseqTestRegisterUnregister[] = "register-unregister"; +inline constexpr char kRseqTestUnregisterDifferentPtr[] = + "unregister-different-ptr"; +inline constexpr char kRseqTestUnregisterDifferentSignature[] = + "unregister-different-signature"; +inline constexpr char kRseqTestCPU[] = "cpu"; +inline constexpr char kRseqTestAbort[] = "abort"; +inline constexpr char kRseqTestAbortBefore[] = "abort-before"; +inline constexpr char kRseqTestAbortSignature[] = "abort-signature"; +inline constexpr char kRseqTestAbortPreCommit[] = "abort-precommit"; +inline constexpr char kRseqTestAbortClearsCS[] = "abort-clears-cs"; +inline constexpr char kRseqTestInvalidAbortClearsCS[] = + "invalid-abort-clears-cs"; + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_SYSCALLS_LINUX_RSEQ_TEST_H_ diff --git a/test/syscalls/linux/rseq/types.h b/test/syscalls/linux/rseq/types.h new file mode 100644 index 000000000..b6afe9817 --- /dev/null +++ b/test/syscalls/linux/rseq/types.h @@ -0,0 +1,31 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_LINUX_RSEQ_TYPES_H_ +#define GVISOR_TEST_SYSCALLS_LINUX_RSEQ_TYPES_H_ + +using size_t = __SIZE_TYPE__; +using uintptr_t = __UINTPTR_TYPE__; + +using uint8_t = __UINT8_TYPE__; +using uint16_t = __UINT16_TYPE__; +using uint32_t = __UINT32_TYPE__; +using uint64_t = __UINT64_TYPE__; + +using int8_t = __INT8_TYPE__; +using int16_t = __INT16_TYPE__; +using int32_t = __INT32_TYPE__; +using int64_t = __INT64_TYPE__; + +#endif // GVISOR_TEST_SYSCALLS_LINUX_RSEQ_TYPES_H_ diff --git a/test/syscalls/linux/rseq/uapi.h b/test/syscalls/linux/rseq/uapi.h new file mode 100644 index 000000000..d3e60d0a4 --- /dev/null +++ b/test/syscalls/linux/rseq/uapi.h @@ -0,0 +1,51 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_LINUX_RSEQ_UAPI_H_ +#define GVISOR_TEST_SYSCALLS_LINUX_RSEQ_UAPI_H_ + +#include <stdint.h> + +// User-kernel ABI for restartable sequences. + +// Syscall numbers. +#if defined(__x86_64__) +constexpr int kRseqSyscall = 334; +#elif defined(__aarch64__) +constexpr int kRseqSyscall = 293; +#else +#error "Unknown architecture" +#endif // __x86_64__ + +struct rseq_cs { + uint32_t version; + uint32_t flags; + uint64_t start_ip; + uint64_t post_commit_offset; + uint64_t abort_ip; +} __attribute__((aligned(4 * sizeof(uint64_t)))); + +// N.B. alignment is enforced by the kernel. +struct rseq { + uint32_t cpu_id_start; + uint32_t cpu_id; + struct rseq_cs* rseq_cs; + uint32_t flags; +} __attribute__((aligned(4 * sizeof(uint64_t)))); + +constexpr int kRseqFlagUnregister = 1 << 0; + +constexpr int kRseqCPUIDUninitialized = -1; + +#endif // GVISOR_TEST_SYSCALLS_LINUX_RSEQ_UAPI_H_ diff --git a/test/syscalls/linux/rtsignal.cc b/test/syscalls/linux/rtsignal.cc index 81d193ffd..ed27e2566 100644 --- a/test/syscalls/linux/rtsignal.cc +++ b/test/syscalls/linux/rtsignal.cc @@ -167,6 +167,5 @@ int main(int argc, char** argv) { TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0); gvisor::testing::TestInit(&argc, &argv); - - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/seccomp.cc b/test/syscalls/linux/seccomp.cc index e77586852..ce88d90dd 100644 --- a/test/syscalls/linux/seccomp.cc +++ b/test/syscalls/linux/seccomp.cc @@ -25,6 +25,7 @@ #include <time.h> #include <ucontext.h> #include <unistd.h> + #include <atomic> #include "gmock/gmock.h" @@ -48,7 +49,12 @@ namespace testing { namespace { // A syscall not implemented by Linux that we don't expect to be called. +#ifdef __x86_64__ constexpr uint32_t kFilteredSyscall = SYS_vserver; +#elif __aarch64__ +// Use the last of arch_specific_syscalls which are not implemented on arm64. +constexpr uint32_t kFilteredSyscall = __NR_arch_specific_syscall + 15; +#endif // Applies a seccomp-bpf filter that returns `filtered_result` for // `sysno` and allows all other syscalls. Async-signal-safe. @@ -64,20 +70,27 @@ void ApplySeccompFilter(uint32_t sysno, uint32_t filtered_result, MaybeSave(); struct sock_filter filter[] = { - // A = seccomp_data.arch - BPF_STMT(BPF_LD | BPF_ABS | BPF_W, 4), - // if (A != AUDIT_ARCH_X86_64) goto kill - BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, AUDIT_ARCH_X86_64, 0, 4), - // A = seccomp_data.nr - BPF_STMT(BPF_LD | BPF_ABS | BPF_W, 0), - // if (A != sysno) goto allow - BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, sysno, 0, 1), - // return filtered_result - BPF_STMT(BPF_RET | BPF_K, filtered_result), - // allow: return SECCOMP_RET_ALLOW - BPF_STMT(BPF_RET | BPF_K, SECCOMP_RET_ALLOW), - // kill: return SECCOMP_RET_KILL - BPF_STMT(BPF_RET | BPF_K, SECCOMP_RET_KILL), + // A = seccomp_data.arch + BPF_STMT(BPF_LD | BPF_ABS | BPF_W, 4), +#if defined(__x86_64__) + // if (A != AUDIT_ARCH_X86_64) goto kill + BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, AUDIT_ARCH_X86_64, 0, 4), +#elif defined(__aarch64__) + // if (A != AUDIT_ARCH_AARCH64) goto kill + BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, AUDIT_ARCH_AARCH64, 0, 4), +#else +#error "Unknown architecture" +#endif + // A = seccomp_data.nr + BPF_STMT(BPF_LD | BPF_ABS | BPF_W, 0), + // if (A != sysno) goto allow + BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, sysno, 0, 1), + // return filtered_result + BPF_STMT(BPF_RET | BPF_K, filtered_result), + // allow: return SECCOMP_RET_ALLOW + BPF_STMT(BPF_RET | BPF_K, SECCOMP_RET_ALLOW), + // kill: return SECCOMP_RET_KILL + BPF_STMT(BPF_RET | BPF_K, SECCOMP_RET_KILL), }; struct sock_fprog prog; prog.len = ABSL_ARRAYSIZE(filter); @@ -112,7 +125,8 @@ TEST(SeccompTest, RetKillCausesDeathBySIGSYS) { pid_t const pid = fork(); if (pid == 0) { // Register a signal handler for SIGSYS that we don't expect to be invoked. - RegisterSignalHandler(SIGSYS, +[](int, siginfo_t*, void*) { _exit(1); }); + RegisterSignalHandler( + SIGSYS, +[](int, siginfo_t*, void*) { _exit(1); }); ApplySeccompFilter(kFilteredSyscall, SECCOMP_RET_KILL); syscall(kFilteredSyscall); TEST_CHECK_MSG(false, "Survived invocation of test syscall"); @@ -131,7 +145,8 @@ TEST(SeccompTest, RetKillOnlyKillsOneThread) { pid_t const pid = fork(); if (pid == 0) { // Register a signal handler for SIGSYS that we don't expect to be invoked. - RegisterSignalHandler(SIGSYS, +[](int, siginfo_t*, void*) { _exit(1); }); + RegisterSignalHandler( + SIGSYS, +[](int, siginfo_t*, void*) { _exit(1); }); ApplySeccompFilter(kFilteredSyscall, SECCOMP_RET_KILL); // Pass CLONE_VFORK to block the original thread in the child process until // the clone thread exits with SIGSYS. @@ -171,9 +186,12 @@ TEST(SeccompTest, RetTrapCausesSIGSYS) { TEST_CHECK(info->si_errno == kTrapValue); TEST_CHECK(info->si_call_addr != nullptr); TEST_CHECK(info->si_syscall == kFilteredSyscall); -#ifdef __x86_64__ +#if defined(__x86_64__) TEST_CHECK(info->si_arch == AUDIT_ARCH_X86_64); TEST_CHECK(uc->uc_mcontext.gregs[REG_RAX] == kFilteredSyscall); +#elif defined(__aarch64__) + TEST_CHECK(info->si_arch == AUDIT_ARCH_AARCH64); + TEST_CHECK(uc->uc_mcontext.regs[8] == kFilteredSyscall); #endif // defined(__x86_64__) _exit(0); }); @@ -345,7 +363,8 @@ TEST(SeccompTest, LeastPermissiveFilterReturnValueApplies) { // one that causes the kill that should be ignored. pid_t const pid = fork(); if (pid == 0) { - RegisterSignalHandler(SIGSYS, +[](int, siginfo_t*, void*) { _exit(1); }); + RegisterSignalHandler( + SIGSYS, +[](int, siginfo_t*, void*) { _exit(1); }); ApplySeccompFilter(kFilteredSyscall, SECCOMP_RET_TRACE); ApplySeccompFilter(kFilteredSyscall, SECCOMP_RET_KILL); ApplySeccompFilter(kFilteredSyscall, SECCOMP_RET_ERRNO | ENOTNAM); @@ -402,5 +421,5 @@ int main(int argc, char** argv) { } gvisor::testing::TestInit(&argc, &argv); - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/select.cc b/test/syscalls/linux/select.cc index e06a2666d..be2364fb8 100644 --- a/test/syscalls/linux/select.cc +++ b/test/syscalls/linux/select.cc @@ -16,6 +16,7 @@ #include <sys/resource.h> #include <sys/select.h> #include <sys/time.h> + #include <climits> #include <csignal> #include <cstdio> @@ -145,7 +146,7 @@ TEST_F(SelectTest, IgnoreBitsAboveNfds) { // This test illustrates Linux's behavior of 'select' calls passing after // setrlimit RLIMIT_NOFILE is called. In particular, versions of sshd rely on -// this behavior. +// this behavior. See b/122318458. TEST_F(SelectTest, SetrlimitCallNOFILE) { fd_set read_set; FD_ZERO(&read_set); diff --git a/test/syscalls/linux/semaphore.cc b/test/syscalls/linux/semaphore.cc index 40c57f543..e9b131ca9 100644 --- a/test/syscalls/linux/semaphore.cc +++ b/test/syscalls/linux/semaphore.cc @@ -447,9 +447,8 @@ TEST(SemaphoreTest, SemCtlGetPidFork) { const pid_t child_pid = fork(); if (child_pid == 0) { - ASSERT_THAT(semctl(sem.get(), 0, SETVAL, 1), SyscallSucceeds()); - ASSERT_THAT(semctl(sem.get(), 0, GETPID), - SyscallSucceedsWithValue(getpid())); + TEST_PCHECK(semctl(sem.get(), 0, SETVAL, 1) == 0); + TEST_PCHECK(semctl(sem.get(), 0, GETPID) == getpid()); _exit(0); } diff --git a/test/syscalls/linux/sendfile.cc b/test/syscalls/linux/sendfile.cc index 580ab5193..64123e904 100644 --- a/test/syscalls/linux/sendfile.cc +++ b/test/syscalls/linux/sendfile.cc @@ -13,6 +13,7 @@ // limitations under the License. #include <fcntl.h> +#include <linux/unistd.h> #include <sys/eventfd.h> #include <sys/sendfile.h> #include <unistd.h> @@ -70,6 +71,28 @@ TEST(SendFileTest, InvalidOffset) { SyscallFailsWithErrno(EINVAL)); } +int memfd_create(const std::string& name, unsigned int flags) { + return syscall(__NR_memfd_create, name.c_str(), flags); +} + +TEST(SendFileTest, Overflow) { + // Create input file. + const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor inf = + ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); + + // Open the output file. + int fd; + EXPECT_THAT(fd = memfd_create("overflow", 0), SyscallSucceeds()); + const FileDescriptor outf(fd); + + // out_offset + kSize overflows INT64_MAX. + loff_t out_offset = 0x7ffffffffffffffeull; + constexpr int kSize = 3; + EXPECT_THAT(sendfile(outf.get(), inf.get(), &out_offset, kSize), + SyscallFailsWithErrno(EINVAL)); +} + TEST(SendFileTest, SendTrivially) { // Create temp files. constexpr char kData[] = "To be, or not to be, that is the question:"; @@ -530,6 +553,34 @@ TEST(SendFileTest, SendToSpecialFile) { SyscallSucceedsWithValue(kSize & (~7))); } +TEST(SendFileTest, SendFileToPipe) { + // Create temp file. + constexpr char kData[] = "<insert-quote-here>"; + constexpr int kDataSize = sizeof(kData) - 1; + const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode)); + const FileDescriptor inf = + ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); + + // Create a pipe for sending to a pipe. + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + const FileDescriptor rfd(fds[0]); + const FileDescriptor wfd(fds[1]); + + // Expect to read up to the given size. + std::vector<char> buf(kDataSize); + ScopedThread t([&]() { + absl::SleepFor(absl::Milliseconds(100)); + ASSERT_THAT(read(rfd.get(), buf.data(), buf.size()), + SyscallSucceedsWithValue(kDataSize)); + }); + + // Send with twice the size of the file, which should hit EOF. + EXPECT_THAT(sendfile(wfd.get(), inf.get(), nullptr, kDataSize * 2), + SyscallSucceedsWithValue(kDataSize)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/sendfile_socket.cc b/test/syscalls/linux/sendfile_socket.cc index 3331288b7..c101fe9d2 100644 --- a/test/syscalls/linux/sendfile_socket.cc +++ b/test/syscalls/linux/sendfile_socket.cc @@ -23,6 +23,7 @@ #include "gtest/gtest.h" #include "absl/strings/string_view.h" +#include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/file_descriptor.h" #include "test/util/temp_path.h" @@ -35,61 +36,39 @@ namespace { class SendFileTest : public ::testing::TestWithParam<int> { protected: - PosixErrorOr<std::tuple<int, int>> Sockets() { + PosixErrorOr<std::unique_ptr<SocketPair>> Sockets(int type) { // Bind a server socket. int family = GetParam(); - struct sockaddr server_addr = {}; switch (family) { case AF_INET: { - struct sockaddr_in *server_addr_in = - reinterpret_cast<struct sockaddr_in *>(&server_addr); - server_addr_in->sin_family = family; - server_addr_in->sin_addr.s_addr = INADDR_ANY; - break; + if (type == SOCK_STREAM) { + return SocketPairKind{ + "TCP", AF_INET, type, 0, + TCPAcceptBindSocketPairCreator(AF_INET, type, 0, false)} + .Create(); + } else { + return SocketPairKind{ + "UDP", AF_INET, type, 0, + UDPBidirectionalBindSocketPairCreator(AF_INET, type, 0, false)} + .Create(); + } } case AF_UNIX: { - struct sockaddr_un *server_addr_un = - reinterpret_cast<struct sockaddr_un *>(&server_addr); - server_addr_un->sun_family = family; - server_addr_un->sun_path[0] = '\0'; - break; + if (type == SOCK_STREAM) { + return SocketPairKind{ + "UNIX", AF_UNIX, type, 0, + FilesystemAcceptBindSocketPairCreator(AF_UNIX, type, 0)} + .Create(); + } else { + return SocketPairKind{ + "UNIX", AF_UNIX, type, 0, + FilesystemBidirectionalBindSocketPairCreator(AF_UNIX, type, 0)} + .Create(); + } } default: return PosixError(EINVAL); } - int server = socket(family, SOCK_STREAM, 0); - if (bind(server, &server_addr, sizeof(server_addr)) < 0) { - return PosixError(errno); - } - if (listen(server, 1) < 0) { - close(server); - return PosixError(errno); - } - - // Fetch the address; both are anonymous. - socklen_t length = sizeof(server_addr); - if (getsockname(server, &server_addr, &length) < 0) { - close(server); - return PosixError(errno); - } - - // Connect the client. - int client = socket(family, SOCK_STREAM, 0); - if (connect(client, &server_addr, length) < 0) { - close(server); - close(client); - return PosixError(errno); - } - - // Accept on the server. - int server_client = accept(server, nullptr, 0); - if (server_client < 0) { - close(server); - close(client); - return PosixError(errno); - } - close(server); - return std::make_tuple(client, server_client); } }; @@ -106,9 +85,7 @@ TEST_P(SendFileTest, SendMultiple) { const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); // Create sockets. - std::tuple<int, int> fds = ASSERT_NO_ERRNO_AND_VALUE(Sockets()); - const FileDescriptor server(std::get<0>(fds)); - FileDescriptor client(std::get<1>(fds)); // non-const, reset is used. + auto socks = ASSERT_NO_ERRNO_AND_VALUE(Sockets(SOCK_STREAM)); // Thread that reads data from socket and dumps to a file. ScopedThread th([&] { @@ -118,7 +95,7 @@ TEST_P(SendFileTest, SendMultiple) { // Read until socket is closed. char buf[10240]; for (int cnt = 0;; cnt++) { - int r = RetryEINTR(read)(server.get(), buf, sizeof(buf)); + int r = RetryEINTR(read)(socks->first_fd(), buf, sizeof(buf)); // We cannot afford to save on every read() call. if (cnt % 1000 == 0) { ASSERT_THAT(r, SyscallSucceeds()); @@ -149,10 +126,10 @@ TEST_P(SendFileTest, SendMultiple) { for (size_t sent = 0; sent < data.size(); cnt++) { const size_t remain = data.size() - sent; std::cout << "sendfile, size=" << data.size() << ", sent=" << sent - << ", remain=" << remain; + << ", remain=" << remain << std::endl; // Send data and verify that sendfile returns the correct value. - int res = sendfile(client.get(), inf.get(), nullptr, remain); + int res = sendfile(socks->second_fd(), inf.get(), nullptr, remain); // We cannot afford to save on every sendfile() call. if (cnt % 120 == 0) { MaybeSave(); @@ -169,7 +146,7 @@ TEST_P(SendFileTest, SendMultiple) { } // Close socket to stop thread. - client.reset(); + close(socks->release_second_fd()); th.Join(); // Verify that the output file has the correct data. @@ -183,9 +160,7 @@ TEST_P(SendFileTest, SendMultiple) { TEST_P(SendFileTest, Shutdown) { // Create a socket. - std::tuple<int, int> fds = ASSERT_NO_ERRNO_AND_VALUE(Sockets()); - const FileDescriptor client(std::get<0>(fds)); - FileDescriptor server(std::get<1>(fds)); // non-const, reset below. + auto socks = ASSERT_NO_ERRNO_AND_VALUE(Sockets(SOCK_STREAM)); // If this is a TCP socket, then turn off linger. if (GetParam() == AF_INET) { @@ -193,7 +168,7 @@ TEST_P(SendFileTest, Shutdown) { sl.l_onoff = 1; sl.l_linger = 0; ASSERT_THAT( - setsockopt(server.get(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)), + setsockopt(socks->first_fd(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)), SyscallSucceeds()); } @@ -212,12 +187,12 @@ TEST_P(SendFileTest, Shutdown) { ScopedThread t([&]() { size_t done = 0; while (done < data.size()) { - int n = RetryEINTR(read)(server.get(), data.data(), data.size()); + int n = RetryEINTR(read)(socks->first_fd(), data.data(), data.size()); ASSERT_THAT(n, SyscallSucceeds()); done += n; } // Close the server side socket. - server.reset(); + close(socks->release_first_fd()); }); // Continuously stream from the file to the socket. Note we do not assert @@ -225,7 +200,7 @@ TEST_P(SendFileTest, Shutdown) { // data is written. Eventually, we should get a connection reset error. while (1) { off_t offset = 0; // Always read from the start. - int n = sendfile(client.get(), inf.get(), &offset, data.size()); + int n = sendfile(socks->second_fd(), inf.get(), &offset, data.size()); EXPECT_THAT(n, AnyOf(SyscallFailsWithErrno(ECONNRESET), SyscallFailsWithErrno(EPIPE), SyscallSucceeds())); if (n <= 0) { @@ -234,6 +209,20 @@ TEST_P(SendFileTest, Shutdown) { } } +TEST_P(SendFileTest, SendpageFromEmptyFileToUDP) { + auto socks = ASSERT_NO_ERRNO_AND_VALUE(Sockets(SOCK_DGRAM)); + + TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR)); + + // The value to the count argument has to be so that it is impossible to + // allocate a buffer of this size. In Linux, sendfile transfer at most + // 0x7ffff000 (MAX_RW_COUNT) bytes. + EXPECT_THAT(sendfile(socks->first_fd(), fd.get(), 0x0, 0x8000000000004), + SyscallSucceedsWithValue(0)); +} + INSTANTIATE_TEST_SUITE_P(AddressFamily, SendFileTest, ::testing::Values(AF_UNIX, AF_INET)); diff --git a/test/syscalls/linux/shm.cc b/test/syscalls/linux/shm.cc index eb7a3966f..c7fdbb924 100644 --- a/test/syscalls/linux/shm.cc +++ b/test/syscalls/linux/shm.cc @@ -13,7 +13,6 @@ // limitations under the License. #include <stdio.h> - #include <sys/ipc.h> #include <sys/mman.h> #include <sys/shm.h> @@ -474,7 +473,7 @@ TEST(ShmTest, PartialUnmap) { } // Check that sentry does not panic when asked for a zero-length private shm -// segment. +// segment. Regression test for b/110694797. TEST(ShmTest, GracefullyFailOnZeroLenSegmentCreation) { EXPECT_THAT(Shmget(IPC_PRIVATE, 0, 0), PosixErrorIs(EINVAL, _)); } diff --git a/test/syscalls/linux/sigaction.cc b/test/syscalls/linux/sigaction.cc index 9a53fd3e0..9d9dd57a8 100644 --- a/test/syscalls/linux/sigaction.cc +++ b/test/syscalls/linux/sigaction.cc @@ -13,6 +13,7 @@ // limitations under the License. #include <signal.h> +#include <sys/syscall.h> #include "gtest/gtest.h" #include "test/util/test_util.h" @@ -23,45 +24,53 @@ namespace testing { namespace { TEST(SigactionTest, GetLessThanOrEqualToZeroFails) { - struct sigaction act; - memset(&act, 0, sizeof(act)); - ASSERT_THAT(sigaction(-1, NULL, &act), SyscallFailsWithErrno(EINVAL)); - ASSERT_THAT(sigaction(0, NULL, &act), SyscallFailsWithErrno(EINVAL)); + struct sigaction act = {}; + ASSERT_THAT(sigaction(-1, nullptr, &act), SyscallFailsWithErrno(EINVAL)); + ASSERT_THAT(sigaction(0, nullptr, &act), SyscallFailsWithErrno(EINVAL)); } TEST(SigactionTest, SetLessThanOrEqualToZeroFails) { - struct sigaction act; - memset(&act, 0, sizeof(act)); - ASSERT_THAT(sigaction(0, &act, NULL), SyscallFailsWithErrno(EINVAL)); - ASSERT_THAT(sigaction(0, &act, NULL), SyscallFailsWithErrno(EINVAL)); + struct sigaction act = {}; + ASSERT_THAT(sigaction(0, &act, nullptr), SyscallFailsWithErrno(EINVAL)); + ASSERT_THAT(sigaction(0, &act, nullptr), SyscallFailsWithErrno(EINVAL)); } TEST(SigactionTest, GetGreaterThanMaxFails) { - struct sigaction act; - memset(&act, 0, sizeof(act)); - ASSERT_THAT(sigaction(SIGRTMAX + 1, NULL, &act), + struct sigaction act = {}; + ASSERT_THAT(sigaction(SIGRTMAX + 1, nullptr, &act), SyscallFailsWithErrno(EINVAL)); } TEST(SigactionTest, SetGreaterThanMaxFails) { - struct sigaction act; - memset(&act, 0, sizeof(act)); - ASSERT_THAT(sigaction(SIGRTMAX + 1, &act, NULL), + struct sigaction act = {}; + ASSERT_THAT(sigaction(SIGRTMAX + 1, &act, nullptr), SyscallFailsWithErrno(EINVAL)); } TEST(SigactionTest, SetSigkillFails) { - struct sigaction act; - memset(&act, 0, sizeof(act)); - ASSERT_THAT(sigaction(SIGKILL, NULL, &act), SyscallSucceeds()); - ASSERT_THAT(sigaction(SIGKILL, &act, NULL), SyscallFailsWithErrno(EINVAL)); + struct sigaction act = {}; + ASSERT_THAT(sigaction(SIGKILL, nullptr, &act), SyscallSucceeds()); + ASSERT_THAT(sigaction(SIGKILL, &act, nullptr), SyscallFailsWithErrno(EINVAL)); } TEST(SigactionTest, SetSigstopFails) { - struct sigaction act; - memset(&act, 0, sizeof(act)); - ASSERT_THAT(sigaction(SIGSTOP, NULL, &act), SyscallSucceeds()); - ASSERT_THAT(sigaction(SIGSTOP, &act, NULL), SyscallFailsWithErrno(EINVAL)); + struct sigaction act = {}; + ASSERT_THAT(sigaction(SIGSTOP, nullptr, &act), SyscallSucceeds()); + ASSERT_THAT(sigaction(SIGSTOP, &act, nullptr), SyscallFailsWithErrno(EINVAL)); +} + +TEST(SigactionTest, BadSigsetFails) { + constexpr size_t kWrongSigSetSize = 43; + + struct sigaction act = {}; + + // The syscall itself (rather than the libc wrapper) takes the sigset_t size. + ASSERT_THAT( + syscall(SYS_rt_sigaction, SIGTERM, nullptr, &act, kWrongSigSetSize), + SyscallFailsWithErrno(EINVAL)); + ASSERT_THAT( + syscall(SYS_rt_sigaction, SIGTERM, &act, nullptr, kWrongSigSetSize), + SyscallFailsWithErrno(EINVAL)); } } // namespace diff --git a/test/syscalls/linux/sigaltstack.cc b/test/syscalls/linux/sigaltstack.cc index 6fd3989a4..24e7c4960 100644 --- a/test/syscalls/linux/sigaltstack.cc +++ b/test/syscalls/linux/sigaltstack.cc @@ -95,13 +95,7 @@ TEST(SigaltstackTest, ResetByExecve) { auto const cleanup_sigstack = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaltstack(stack)); - std::string full_path; - char* test_src = getenv("TEST_SRCDIR"); - if (test_src) { - full_path = JoinPath(test_src, "../../linux/sigaltstack_check"); - } - - ASSERT_FALSE(full_path.empty()); + std::string full_path = RunfilePath("test/syscalls/linux/sigaltstack_check"); pid_t child_pid = -1; int execve_errno = 0; @@ -120,7 +114,7 @@ TEST(SigaltstackTest, ResetByExecve) { volatile bool badhandler_on_sigaltstack = true; // Set by the handler. char* volatile badhandler_low_water_mark = nullptr; // Set by the handler. -volatile uint8_t badhandler_recursive_faults = 0; // Consumed by the handler. +volatile uint8_t badhandler_recursive_faults = 0; // Consumed by the handler. void badhandler(int sig, siginfo_t* siginfo, void* arg) { char stack_var = 0; @@ -174,8 +168,8 @@ TEST(SigaltstackTest, WalksOffBottom) { // Trigger a single fault. badhandler_low_water_mark = - static_cast<char*>(stack.ss_sp) + SIGSTKSZ; // Expected top. - badhandler_recursive_faults = 0; // Disable refault. + static_cast<char*>(stack.ss_sp) + SIGSTKSZ; // Expected top. + badhandler_recursive_faults = 0; // Disable refault. Fault(); EXPECT_TRUE(badhandler_on_sigaltstack); EXPECT_THAT(sigaltstack(nullptr, &stack), SyscallSucceeds()); diff --git a/test/syscalls/linux/sigiret.cc b/test/syscalls/linux/sigiret.cc index a47c781ea..6227774a4 100644 --- a/test/syscalls/linux/sigiret.cc +++ b/test/syscalls/linux/sigiret.cc @@ -78,8 +78,8 @@ TEST(SigIretTest, CheckRcxR11) { "1: pause; cmpl $0, %[gotvtalrm]; je 1b;" // while (!gotvtalrm); "movq %%rcx, %[rcx];" // rcx = %rcx "movq %%r11, %[r11];" // r11 = %r11 - : [ready] "=m"(ready), [rcx] "+m"(rcx), [r11] "+m"(r11) - : [gotvtalrm] "m"(gotvtalrm) + : [ ready ] "=m"(ready), [ rcx ] "+m"(rcx), [ r11 ] "+m"(r11) + : [ gotvtalrm ] "m"(gotvtalrm) : "cc", "memory", "rcx", "r11"); // If sigreturn(2) returns via 'sysret' then %rcx and %r11 will be @@ -132,6 +132,5 @@ int main(int argc, char** argv) { TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0); gvisor::testing::TestInit(&argc, &argv); - - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/signalfd.cc b/test/syscalls/linux/signalfd.cc index 09ecad34a..389e5fca2 100644 --- a/test/syscalls/linux/signalfd.cc +++ b/test/syscalls/linux/signalfd.cc @@ -39,6 +39,7 @@ namespace testing { namespace { constexpr int kSigno = SIGUSR1; +constexpr int kSignoMax = 64; // SIGRTMAX constexpr int kSignoAlt = SIGUSR2; // Returns a new signalfd. @@ -51,41 +52,45 @@ inline PosixErrorOr<FileDescriptor> NewSignalFD(sigset_t* mask, int flags = 0) { return FileDescriptor(fd); } -TEST(Signalfd, Basic) { +class SignalfdTest : public ::testing::TestWithParam<int> {}; + +TEST_P(SignalfdTest, Basic) { + int signo = GetParam(); // Create the signalfd. sigset_t mask; sigemptyset(&mask); - sigaddset(&mask, kSigno); + sigaddset(&mask, signo); FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, 0)); // Deliver the blocked signal. const auto scoped_sigmask = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, kSigno)); - ASSERT_THAT(tgkill(getpid(), gettid(), kSigno), SyscallSucceeds()); + ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, signo)); + ASSERT_THAT(tgkill(getpid(), gettid(), signo), SyscallSucceeds()); // We should now read the signal. struct signalfd_siginfo rbuf; ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)), SyscallSucceedsWithValue(sizeof(rbuf))); - EXPECT_EQ(rbuf.ssi_signo, kSigno); + EXPECT_EQ(rbuf.ssi_signo, signo); } -TEST(Signalfd, MaskWorks) { +TEST_P(SignalfdTest, MaskWorks) { + int signo = GetParam(); // Create two signalfds with different masks. sigset_t mask1, mask2; sigemptyset(&mask1); sigemptyset(&mask2); - sigaddset(&mask1, kSigno); + sigaddset(&mask1, signo); sigaddset(&mask2, kSignoAlt); FileDescriptor fd1 = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask1, 0)); FileDescriptor fd2 = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask2, 0)); // Deliver the two signals. const auto scoped_sigmask1 = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, kSigno)); + ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, signo)); const auto scoped_sigmask2 = ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, kSignoAlt)); - ASSERT_THAT(tgkill(getpid(), gettid(), kSigno), SyscallSucceeds()); + ASSERT_THAT(tgkill(getpid(), gettid(), signo), SyscallSucceeds()); ASSERT_THAT(tgkill(getpid(), gettid(), kSignoAlt), SyscallSucceeds()); // We should see the signals on the appropriate signalfds. @@ -98,7 +103,7 @@ TEST(Signalfd, MaskWorks) { EXPECT_EQ(rbuf2.ssi_signo, kSignoAlt); ASSERT_THAT(read(fd1.get(), &rbuf1, sizeof(rbuf1)), SyscallSucceedsWithValue(sizeof(rbuf1))); - EXPECT_EQ(rbuf1.ssi_signo, kSigno); + EXPECT_EQ(rbuf1.ssi_signo, signo); } TEST(Signalfd, Cloexec) { @@ -111,11 +116,12 @@ TEST(Signalfd, Cloexec) { EXPECT_THAT(fcntl(fd.get(), F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC)); } -TEST(Signalfd, Blocking) { +TEST_P(SignalfdTest, Blocking) { + int signo = GetParam(); // Create the signalfd in blocking mode. sigset_t mask; sigemptyset(&mask); - sigaddset(&mask, kSigno); + sigaddset(&mask, signo); FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, 0)); // Shared tid variable. @@ -136,7 +142,7 @@ TEST(Signalfd, Blocking) { struct signalfd_siginfo rbuf; ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)), SyscallSucceedsWithValue(sizeof(rbuf))); - EXPECT_EQ(rbuf.ssi_signo, kSigno); + EXPECT_EQ(rbuf.ssi_signo, signo); }); // Wait until blocked. @@ -149,20 +155,21 @@ TEST(Signalfd, Blocking) { // // See gvisor.dev/issue/139. if (IsRunningOnGvisor()) { - ASSERT_THAT(tgkill(getpid(), gettid(), kSigno), SyscallSucceeds()); + ASSERT_THAT(tgkill(getpid(), gettid(), signo), SyscallSucceeds()); } else { - ASSERT_THAT(tgkill(getpid(), tid, kSigno), SyscallSucceeds()); + ASSERT_THAT(tgkill(getpid(), tid, signo), SyscallSucceeds()); } // Ensure that it was received. t.Join(); } -TEST(Signalfd, ThreadGroup) { +TEST_P(SignalfdTest, ThreadGroup) { + int signo = GetParam(); // Create the signalfd in blocking mode. sigset_t mask; sigemptyset(&mask); - sigaddset(&mask, kSigno); + sigaddset(&mask, signo); FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, 0)); // Shared variable. @@ -176,7 +183,7 @@ TEST(Signalfd, ThreadGroup) { struct signalfd_siginfo rbuf; ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)), SyscallSucceedsWithValue(sizeof(rbuf))); - EXPECT_EQ(rbuf.ssi_signo, kSigno); + EXPECT_EQ(rbuf.ssi_signo, signo); // Wait for the other thread. absl::MutexLock ml(&mu); @@ -185,7 +192,7 @@ TEST(Signalfd, ThreadGroup) { }); // Deliver the signal to the threadgroup. - ASSERT_THAT(kill(getpid(), kSigno), SyscallSucceeds()); + ASSERT_THAT(kill(getpid(), signo), SyscallSucceeds()); // Wait for the first thread to process. { @@ -194,13 +201,13 @@ TEST(Signalfd, ThreadGroup) { } // Deliver to the thread group again (other thread still exists). - ASSERT_THAT(kill(getpid(), kSigno), SyscallSucceeds()); + ASSERT_THAT(kill(getpid(), signo), SyscallSucceeds()); // Ensure that we can also receive it. struct signalfd_siginfo rbuf; ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)), SyscallSucceedsWithValue(sizeof(rbuf))); - EXPECT_EQ(rbuf.ssi_signo, kSigno); + EXPECT_EQ(rbuf.ssi_signo, signo); // Mark the test as done. { @@ -212,11 +219,12 @@ TEST(Signalfd, ThreadGroup) { t.Join(); } -TEST(Signalfd, Nonblock) { +TEST_P(SignalfdTest, Nonblock) { + int signo = GetParam(); // Create the signalfd in non-blocking mode. sigset_t mask; sigemptyset(&mask); - sigaddset(&mask, kSigno); + sigaddset(&mask, signo); FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, SFD_NONBLOCK)); @@ -227,20 +235,21 @@ TEST(Signalfd, Nonblock) { // Block and deliver the signal. const auto scoped_sigmask = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, kSigno)); - ASSERT_THAT(tgkill(getpid(), gettid(), kSigno), SyscallSucceeds()); + ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, signo)); + ASSERT_THAT(tgkill(getpid(), gettid(), signo), SyscallSucceeds()); // Ensure that a read actually works. ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)), SyscallSucceedsWithValue(sizeof(rbuf))); - EXPECT_EQ(rbuf.ssi_signo, kSigno); + EXPECT_EQ(rbuf.ssi_signo, signo); // Should block again. EXPECT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)), SyscallFailsWithErrno(EWOULDBLOCK)); } -TEST(Signalfd, SetMask) { +TEST_P(SignalfdTest, SetMask) { + int signo = GetParam(); // Create the signalfd matching nothing. sigset_t mask; sigemptyset(&mask); @@ -249,8 +258,8 @@ TEST(Signalfd, SetMask) { // Block and deliver a signal. const auto scoped_sigmask = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, kSigno)); - ASSERT_THAT(tgkill(getpid(), gettid(), kSigno), SyscallSucceeds()); + ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, signo)); + ASSERT_THAT(tgkill(getpid(), gettid(), signo), SyscallSucceeds()); // We should have nothing. struct signalfd_siginfo rbuf; @@ -258,29 +267,30 @@ TEST(Signalfd, SetMask) { SyscallFailsWithErrno(EWOULDBLOCK)); // Change the signal mask. - sigaddset(&mask, kSigno); + sigaddset(&mask, signo); ASSERT_THAT(signalfd(fd.get(), &mask, 0), SyscallSucceeds()); // We should now have the signal. ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)), SyscallSucceedsWithValue(sizeof(rbuf))); - EXPECT_EQ(rbuf.ssi_signo, kSigno); + EXPECT_EQ(rbuf.ssi_signo, signo); } -TEST(Signalfd, Poll) { +TEST_P(SignalfdTest, Poll) { + int signo = GetParam(); // Create the signalfd. sigset_t mask; sigemptyset(&mask); - sigaddset(&mask, kSigno); + sigaddset(&mask, signo); FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, 0)); // Block the signal, and start a thread to deliver it. const auto scoped_sigmask = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, kSigno)); + ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, signo)); pid_t orig_tid = gettid(); ScopedThread t([&] { absl::SleepFor(absl::Seconds(5)); - ASSERT_THAT(tgkill(getpid(), orig_tid, kSigno), SyscallSucceeds()); + ASSERT_THAT(tgkill(getpid(), orig_tid, signo), SyscallSucceeds()); }); // Start polling for the signal. We expect that it is not available at the @@ -297,19 +307,18 @@ TEST(Signalfd, Poll) { SyscallSucceedsWithValue(sizeof(rbuf))); } -TEST(Signalfd, KillStillKills) { - sigset_t mask; - sigemptyset(&mask); - sigaddset(&mask, SIGKILL); - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, SFD_CLOEXEC)); - - // Just because there is a signalfd, we shouldn't see any change in behavior - // for unblockable signals. It's easier to test this with SIGKILL. - const auto scoped_sigmask = - ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, SIGKILL)); - EXPECT_EXIT(tgkill(getpid(), gettid(), SIGKILL), KilledBySignal(SIGKILL), ""); +std::string PrintSigno(::testing::TestParamInfo<int> info) { + switch (info.param) { + case kSigno: + return "kSigno"; + case kSignoMax: + return "kSignoMax"; + default: + return absl::StrCat(info.param); + } } +INSTANTIATE_TEST_SUITE_P(Signalfd, SignalfdTest, + ::testing::Values(kSigno, kSignoMax), PrintSigno); TEST(Signalfd, Ppoll) { sigset_t mask; @@ -328,6 +337,20 @@ TEST(Signalfd, Ppoll) { SyscallSucceedsWithValue(0)); } +TEST(Signalfd, KillStillKills) { + sigset_t mask; + sigemptyset(&mask); + sigaddset(&mask, SIGKILL); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, SFD_CLOEXEC)); + + // Just because there is a signalfd, we shouldn't see any change in behavior + // for unblockable signals. It's easier to test this with SIGKILL. + const auto scoped_sigmask = + ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, SIGKILL)); + EXPECT_EXIT(tgkill(getpid(), gettid(), SIGKILL), KilledBySignal(SIGKILL), ""); +} + } // namespace } // namespace testing @@ -340,10 +363,11 @@ int main(int argc, char** argv) { sigset_t set; sigemptyset(&set); sigaddset(&set, gvisor::testing::kSigno); + sigaddset(&set, gvisor::testing::kSignoMax); sigaddset(&set, gvisor::testing::kSignoAlt); TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0); gvisor::testing::TestInit(&argc, &argv); - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/sigprocmask.cc b/test/syscalls/linux/sigprocmask.cc index 654c6a47f..a603fc1d1 100644 --- a/test/syscalls/linux/sigprocmask.cc +++ b/test/syscalls/linux/sigprocmask.cc @@ -237,7 +237,7 @@ TEST_F(SigProcMaskTest, SignalHandler) { } // Check that sigprocmask correctly handles aliasing of the set and oldset -// pointers. +// pointers. Regression test for b/30502311. TEST_F(SigProcMaskTest, AliasedSets) { sigset_t mask; diff --git a/test/syscalls/linux/sigstop.cc b/test/syscalls/linux/sigstop.cc index 7db57d968..b2fcedd62 100644 --- a/test/syscalls/linux/sigstop.cc +++ b/test/syscalls/linux/sigstop.cc @@ -147,5 +147,5 @@ int main(int argc, char** argv) { return 1; } - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/sigtimedwait.cc b/test/syscalls/linux/sigtimedwait.cc index 1e5bf5942..4f8afff15 100644 --- a/test/syscalls/linux/sigtimedwait.cc +++ b/test/syscalls/linux/sigtimedwait.cc @@ -319,6 +319,5 @@ int main(int argc, char** argv) { TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0); gvisor::testing::TestInit(&argc, &argv); - - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/socket.cc b/test/syscalls/linux/socket.cc index 3a07ac8d2..c20cd3fcc 100644 --- a/test/syscalls/linux/socket.cc +++ b/test/syscalls/linux/socket.cc @@ -13,11 +13,14 @@ // limitations under the License. #include <sys/socket.h> +#include <sys/stat.h> +#include <sys/types.h> #include <unistd.h> #include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/file_descriptor.h" +#include "test/util/temp_umask.h" #include "test/util/test_util.h" namespace gvisor { @@ -58,12 +61,45 @@ TEST(SocketTest, ProtocolInet) { } } +TEST(SocketTest, UnixSocketStat) { + SKIP_IF(IsRunningWithVFS1()); + + FileDescriptor bound = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, SOCK_STREAM, PF_UNIX)); + + // The permissions of the file created with bind(2) should be defined by the + // permissions of the bound socket and the umask. + mode_t sock_perm = 0765, mask = 0123; + ASSERT_THAT(fchmod(bound.get(), sock_perm), SyscallSucceeds()); + TempUmask m(mask); + + struct sockaddr_un addr = + ASSERT_NO_ERRNO_AND_VALUE(UniqueUnixAddr(/*abstract=*/false, AF_UNIX)); + ASSERT_THAT(bind(bound.get(), reinterpret_cast<struct sockaddr*>(&addr), + sizeof(addr)), + SyscallSucceeds()); + + struct stat statbuf = {}; + ASSERT_THAT(stat(addr.sun_path, &statbuf), SyscallSucceeds()); + + // Mode should be S_IFSOCK. + EXPECT_EQ(statbuf.st_mode, S_IFSOCK | sock_perm & ~mask); + + // Timestamps should be equal and non-zero. + // TODO(b/158882152): Sockets currently don't implement timestamps. + if (!IsRunningOnGvisor()) { + EXPECT_NE(statbuf.st_atime, 0); + EXPECT_EQ(statbuf.st_atime, statbuf.st_mtime); + EXPECT_EQ(statbuf.st_atime, statbuf.st_ctime); + } +} + using SocketOpenTest = ::testing::TestWithParam<int>; // UDS cannot be opened. TEST_P(SocketOpenTest, Unix) { // FIXME(b/142001530): Open incorrectly succeeds on gVisor. - SKIP_IF(IsRunningOnGvisor()); + SKIP_IF(IsRunningWithVFS1()); FileDescriptor bound = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, SOCK_STREAM, PF_UNIX)); diff --git a/test/syscalls/linux/socket_abstract.cc b/test/syscalls/linux/socket_abstract.cc index 715d87b76..00999f192 100644 --- a/test/syscalls/linux/socket_abstract.cc +++ b/test/syscalls/linux/socket_abstract.cc @@ -23,6 +23,7 @@ namespace gvisor { namespace testing { +namespace { std::vector<SocketPairKind> GetSocketPairs() { return ApplyVec<SocketPairKind>( @@ -43,5 +44,6 @@ INSTANTIATE_TEST_SUITE_P( AbstractUnixSockets, UnixSocketPairCmsgTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); +} // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_bind_to_device_distribution.cc b/test/syscalls/linux/socket_bind_to_device_distribution.cc index 5767181a1..5ed57625c 100644 --- a/test/syscalls/linux/socket_bind_to_device_distribution.cc +++ b/test/syscalls/linux/socket_bind_to_device_distribution.cc @@ -183,7 +183,14 @@ TEST_P(BindToDeviceDistributionTest, Tcp) { } // Receive some data from a socket to be sure that the connect() // system call has been completed on another side. - int data; + // Do a short read and then close the socket to trigger a RST. This + // ensures that both ends of the connection are cleaned up and no + // goroutines hang around in TIME-WAIT. We do this so that this test + // does not timeout under gotsan runs where lots of goroutines can + // cause the test to use absurd amounts of memory. + // + // See: https://tools.ietf.org/html/rfc2525#page-50 section 2.17 + uint16_t data; EXPECT_THAT( RetryEINTR(recv)(fd.ValueOrDie().get(), &data, sizeof(data), 0), SyscallSucceedsWithValue(sizeof(data))); @@ -198,15 +205,29 @@ TEST_P(BindToDeviceDistributionTest, Tcp) { } for (int i = 0; i < kConnectAttempts; i++) { - FileDescriptor const fd = ASSERT_NO_ERRNO_AND_VALUE( + const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE( Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); ASSERT_THAT( RetryEINTR(connect)(fd.get(), reinterpret_cast<sockaddr*>(&conn_addr), connector.addr_len), SyscallSucceeds()); + // Do two separate sends to ensure two segments are received. This is + // required for netstack where read is incorrectly assuming a whole + // segment is read when endpoint.Read() is called which is technically + // incorrect as the syscall that invoked endpoint.Read() may only + // consume it partially. This results in a case where a close() of + // such a socket does not trigger a RST in netstack due to the + // endpoint assuming that the endpoint has no unread data. EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0), SyscallSucceedsWithValue(sizeof(i))); + + // TODO(gvisor.dev/issue/1449): Remove this block once netstack correctly + // generates a RST. + if (IsRunningOnGvisor()) { + EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0), + SyscallSucceedsWithValue(sizeof(i))); + } } // Join threads to be sure that all connections have been counted. diff --git a/test/syscalls/linux/socket_bind_to_device_sequence.cc b/test/syscalls/linux/socket_bind_to_device_sequence.cc index e4641c62e..d3cc71dbf 100644 --- a/test/syscalls/linux/socket_bind_to_device_sequence.cc +++ b/test/syscalls/linux/socket_bind_to_device_sequence.cc @@ -33,6 +33,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "absl/container/node_hash_map.h" #include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_bind_to_device_util.h" #include "test/syscalls/linux/socket_test_util.h" @@ -66,7 +67,7 @@ class BindToDeviceSequenceTest : public ::testing::TestWithParam<SocketKind> { // Gets a device by device_id. If the device_id has been seen before, returns // the previously returned device. If not, finds or creates a new device. // Returns an empty string on failure. - void GetDevice(int device_id, string *device_name) { + void GetDevice(int device_id, string* device_name) { auto device = devices_.find(device_id); if (device != devices_.end()) { *device_name = device->second; @@ -97,12 +98,22 @@ class BindToDeviceSequenceTest : public ::testing::TestWithParam<SocketKind> { sockets_to_close_.erase(socket_id); } - // Bind a socket with the reuse option and bind_to_device options. Checks + // SetDevice changes the bind_to_device option. It does not bind or re-bind. + void SetDevice(int socket_id, int device_id) { + auto socket_fd = sockets_to_close_[socket_id]->get(); + string device_name; + ASSERT_NO_FATAL_FAILURE(GetDevice(device_id, &device_name)); + EXPECT_THAT(setsockopt(socket_fd, SOL_SOCKET, SO_BINDTODEVICE, + device_name.c_str(), device_name.size() + 1), + SyscallSucceedsWithValue(0)); + } + + // Bind a socket with the reuse options and bind_to_device options. Checks // that all steps succeed and that the bind command's error matches want. // Sets the socket_id to uniquely identify the socket bound if it is not // nullptr. - void BindSocket(bool reuse, int device_id = 0, int want = 0, - int *socket_id = nullptr) { + void BindSocket(bool reuse_port, bool reuse_addr, int device_id = 0, + int want = 0, int* socket_id = nullptr) { next_socket_id_++; sockets_to_close_[next_socket_id_] = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto socket_fd = sockets_to_close_[next_socket_id_]->get(); @@ -110,13 +121,20 @@ class BindToDeviceSequenceTest : public ::testing::TestWithParam<SocketKind> { *socket_id = next_socket_id_; } - // If reuse is indicated, do that. - if (reuse) { + // If reuse_port is indicated, do that. + if (reuse_port) { EXPECT_THAT(setsockopt(socket_fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceedsWithValue(0)); } + // If reuse_addr is indicated, do that. + if (reuse_addr) { + EXPECT_THAT(setsockopt(socket_fd, SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)); + } + // If the device is non-zero, bind to that device. if (device_id != 0) { string device_name; @@ -137,12 +155,12 @@ class BindToDeviceSequenceTest : public ::testing::TestWithParam<SocketKind> { addr.sin_port = port_; if (want == 0) { ASSERT_THAT( - bind(socket_fd, reinterpret_cast<const struct sockaddr *>(&addr), + bind(socket_fd, reinterpret_cast<const struct sockaddr*>(&addr), sizeof(addr)), SyscallSucceeds()); } else { ASSERT_THAT( - bind(socket_fd, reinterpret_cast<const struct sockaddr *>(&addr), + bind(socket_fd, reinterpret_cast<const struct sockaddr*>(&addr), sizeof(addr)), SyscallFailsWithErrno(want)); } @@ -152,7 +170,7 @@ class BindToDeviceSequenceTest : public ::testing::TestWithParam<SocketKind> { // remember it for future commands. socklen_t addr_size = sizeof(addr); ASSERT_THAT( - getsockname(socket_fd, reinterpret_cast<struct sockaddr *>(&addr), + getsockname(socket_fd, reinterpret_cast<struct sockaddr*>(&addr), &addr_size), SyscallSucceeds()); port_ = addr.sin_port; @@ -162,7 +180,7 @@ class BindToDeviceSequenceTest : public ::testing::TestWithParam<SocketKind> { private: SocketKind socket_factory_; // devices maps from the device id in the test case to the name of the device. - std::unordered_map<int, string> devices_; + absl::node_hash_map<int, string> devices_; // These are the tunnels that were created for the test and will be destroyed // by the destructor. vector<std::unique_ptr<Tunnel>> tunnels_; @@ -175,136 +193,316 @@ class BindToDeviceSequenceTest : public ::testing::TestWithParam<SocketKind> { in_port_t port_ = 0; // sockets_to_close_ is a map from action index to the socket that was // created. - std::unordered_map<int, - std::unique_ptr<gvisor::testing::FileDescriptor>> + absl::node_hash_map<int, + std::unique_ptr<gvisor::testing::FileDescriptor>> sockets_to_close_; int next_socket_id_ = 0; }; TEST_P(BindToDeviceSequenceTest, BindTwiceWithDeviceFails) { - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 3)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 3, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ false, /* reuse_addr */ false, /* bind_to_device */ 3)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 3, EADDRINUSE)); } TEST_P(BindToDeviceSequenceTest, BindToDevice) { - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 1)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 2)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ false, /* reuse_addr */ false, /* bind_to_device */ 1)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ false, /* reuse_addr */ false, /* bind_to_device */ 2)); } TEST_P(BindToDeviceSequenceTest, BindToDeviceAndThenWithoutDevice) { - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 123)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); } TEST_P(BindToDeviceSequenceTest, BindWithoutDevice) { - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ false)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 123, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); } TEST_P(BindToDeviceSequenceTest, BindWithDevice) { - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 123, 0)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 123, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 456, 0)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 789, 0)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 123, 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 456, 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 789, 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); } TEST_P(BindToDeviceSequenceTest, BindWithReuse) { - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 123)); ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true, /* bind_to_device */ 0)); + BindSocket(/* reusePort */ true, /* reuse_addr */ false)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ false, + /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 0)); } TEST_P(BindToDeviceSequenceTest, BindingWithReuseAndDevice) { - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 123)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 123)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 456)); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 789)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 999, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 456)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse_port */ true, /* reuse_addr */ false)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 789)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 999, EADDRINUSE)); } TEST_P(BindToDeviceSequenceTest, MixingReuseAndNotReuseByBindingToDevice) { - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 123, 0)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 456, 0)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 789, 0)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 999, 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 123, 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 456, 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 789, 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 999, 0)); } TEST_P(BindToDeviceSequenceTest, CannotBindTo0AfterMixingReuseAndNotReuse) { - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 123)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 456)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 456)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); } TEST_P(BindToDeviceSequenceTest, BindAndRelease) { - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 123)); int to_release; - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 0, 0, &to_release)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 345, EADDRINUSE)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 789)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 0, 0, &to_release)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 345, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 789)); // Release the bind to device 0 and try again. ASSERT_NO_FATAL_FAILURE(ReleaseSocket(to_release)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 345)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 345)); } TEST_P(BindToDeviceSequenceTest, BindTwiceWithReuseOnce) { - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ false, /* bind_to_device */ 123)); - ASSERT_NO_FATAL_FAILURE( - BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindWithReuseAddr) { + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reusePort */ false, /* reuse_addr */ true)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ false, /* reuse_addr */ true, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ false, /* reuse_addr */ true, /* bind_to_device */ 0)); +} + +TEST_P(BindToDeviceSequenceTest, + CannotBindTo0AfterMixingReuseAddrAndNotReuseAddr) { + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 456)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ true, + /* bind_to_device */ 0, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindReuseAddrReusePortThenReusePort) { + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ true, + /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ true, + /* bind_to_device */ 0, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindReuseAddrReusePortThenReuseAddr) { + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ true, + /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ true, + /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindDoubleReuseAddrReusePortThenReusePort) { + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ true, /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ true, + /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ true, + /* bind_to_device */ 0, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindDoubleReuseAddrReusePortThenReuseAddr) { + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ true, /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ true, + /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ true, + /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindReusePortThenReuseAddrReusePort) { + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ true, + /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ true, + /* bind_to_device */ 0, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindReuseAddrThenReuseAddr) { + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ false, /* reuse_addr */ true, /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 0, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, + BindReuseAddrThenReuseAddrReusePortThenReuseAddr) { + // The behavior described in this test seems like a Linux bug. It doesn't + // make any sense and it is unlikely that any applications rely on it. + // + // Both SO_REUSEADDR and SO_REUSEPORT allow binding multiple UDP sockets to + // the same address and deliver each packet to exactly one of the bound + // sockets. If both are enabled, one of the strategies is selected to route + // packets. The strategy is selected dynamically based on the settings of the + // currently bound sockets. Usually, the strategy is selected based on the + // common setting (SO_REUSEADDR or SO_REUSEPORT) amongst the sockets, but for + // some reason, Linux allows binding sets of sockets with no overlapping + // settings in some situations. In this case, it is not obvious which strategy + // would be selected as the configured setting is a contradiction. + SKIP_IF(IsRunningOnGvisor()); + + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ false, /* reuse_addr */ true, /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ true, + /* bind_to_device */ 0)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, + /* reuse_addr */ false, + /* bind_to_device */ 0)); +} + +// Repro test for gvisor.dev/issue/1217. Not replicated in ports_test.go as this +// test is different from the others and wouldn't fit well there. +TEST_P(BindToDeviceSequenceTest, BindAndReleaseDifferentDevice) { + int to_release; + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 3, 0, &to_release)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, + /* reuse_addr */ false, + /* bind_to_device */ 3, EADDRINUSE)); + // Change the device. Since the socket was already bound, this should have no + // effect. + SetDevice(to_release, 2); + // Release the bind to device 3 and try again. + ASSERT_NO_FATAL_FAILURE(ReleaseSocket(to_release)); + ASSERT_NO_FATAL_FAILURE(BindSocket( + /* reuse_port */ false, /* reuse_addr */ false, /* bind_to_device */ 3)); } INSTANTIATE_TEST_SUITE_P(BindToDeviceTest, BindToDeviceSequenceTest, diff --git a/test/syscalls/linux/socket_blocking.cc b/test/syscalls/linux/socket_blocking.cc index d7ce57566..7e88aa2d9 100644 --- a/test/syscalls/linux/socket_blocking.cc +++ b/test/syscalls/linux/socket_blocking.cc @@ -17,6 +17,7 @@ #include <sys/socket.h> #include <sys/types.h> #include <sys/un.h> + #include <cstdio> #include "gtest/gtest.h" diff --git a/test/syscalls/linux/socket_capability.cc b/test/syscalls/linux/socket_capability.cc new file mode 100644 index 000000000..84b5b2b21 --- /dev/null +++ b/test/syscalls/linux/socket_capability.cc @@ -0,0 +1,61 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Subset of socket tests that need Linux-specific headers (compared to POSIX +// headers). + +#include "gtest/gtest.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/capability_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +TEST(SocketTest, UnixConnectNeedsWritePerm) { + SKIP_IF(IsRunningWithVFS1()); + + FileDescriptor bound = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, SOCK_STREAM, PF_UNIX)); + + struct sockaddr_un addr = + ASSERT_NO_ERRNO_AND_VALUE(UniqueUnixAddr(/*abstract=*/false, AF_UNIX)); + ASSERT_THAT(bind(bound.get(), reinterpret_cast<struct sockaddr*>(&addr), + sizeof(addr)), + SyscallSucceeds()); + ASSERT_THAT(listen(bound.get(), 1), SyscallSucceeds()); + + // Drop capabilites that allow us to override permision checks. Otherwise if + // the test is run as root, the connect below will bypass permission checks + // and succeed unexpectedly. + ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + + // Connect should fail without write perms. + ASSERT_THAT(chmod(addr.sun_path, 0500), SyscallSucceeds()); + FileDescriptor client = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, SOCK_STREAM, PF_UNIX)); + ASSERT_THAT(connect(client.get(), reinterpret_cast<struct sockaddr*>(&addr), + sizeof(addr)), + SyscallFailsWithErrno(EACCES)); + + // Connect should succeed with write perms. + ASSERT_THAT(chmod(addr.sun_path, 0200), SyscallSucceeds()); + EXPECT_THAT(connect(client.get(), reinterpret_cast<struct sockaddr*>(&addr), + sizeof(addr)), + SyscallSucceeds()); +} + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_filesystem.cc b/test/syscalls/linux/socket_filesystem.cc index 74e262959..287359363 100644 --- a/test/syscalls/linux/socket_filesystem.cc +++ b/test/syscalls/linux/socket_filesystem.cc @@ -23,6 +23,7 @@ namespace gvisor { namespace testing { +namespace { std::vector<SocketPairKind> GetSocketPairs() { return ApplyVec<SocketPairKind>( @@ -43,5 +44,6 @@ INSTANTIATE_TEST_SUITE_P( FilesystemUnixSockets, UnixSocketPairCmsgTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); +} // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_generic.cc b/test/syscalls/linux/socket_generic.cc index e8f24a59e..a6182f0ac 100644 --- a/test/syscalls/linux/socket_generic.cc +++ b/test/syscalls/linux/socket_generic.cc @@ -447,6 +447,62 @@ TEST_P(AllSocketPairTest, RecvTimeoutRecvmsgSucceeds) { SyscallFailsWithErrno(EAGAIN)); } +TEST_P(AllSocketPairTest, SendTimeoutDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + timeval actual_tv = {.tv_sec = -1, .tv_usec = -1}; + socklen_t len = sizeof(actual_tv); + EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, + &actual_tv, &len), + SyscallSucceeds()); + EXPECT_EQ(actual_tv.tv_sec, 0); + EXPECT_EQ(actual_tv.tv_usec, 0); +} + +TEST_P(AllSocketPairTest, SetGetSendTimeout) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // tv_usec should be a multiple of 4000 to work on most systems. + timeval tv = {.tv_sec = 89, .tv_usec = 42000}; + EXPECT_THAT( + setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)), + SyscallSucceeds()); + + timeval actual_tv = {}; + socklen_t len = sizeof(actual_tv); + EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, + &actual_tv, &len), + SyscallSucceeds()); + EXPECT_EQ(actual_tv.tv_sec, tv.tv_sec); + EXPECT_EQ(actual_tv.tv_usec, tv.tv_usec); +} + +TEST_P(AllSocketPairTest, SetGetSendTimeoutLargerArg) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + struct timeval_with_extra { + struct timeval tv; + int64_t extra_data; + } ABSL_ATTRIBUTE_PACKED; + + // tv_usec should be a multiple of 4000 to work on most systems. + timeval_with_extra tv_extra = { + .tv = {.tv_sec = 0, .tv_usec = 124000}, + }; + + EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, + &tv_extra, sizeof(tv_extra)), + SyscallSucceeds()); + + timeval_with_extra actual_tv = {}; + socklen_t len = sizeof(actual_tv); + EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, + &actual_tv, &len), + SyscallSucceeds()); + EXPECT_EQ(actual_tv.tv.tv_sec, tv_extra.tv.tv_sec); + EXPECT_EQ(actual_tv.tv.tv_usec, tv_extra.tv.tv_usec); +} + TEST_P(AllSocketPairTest, SendTimeoutAllowsWrite) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); @@ -491,18 +547,36 @@ TEST_P(AllSocketPairTest, SendTimeoutAllowsSendmsg) { ASSERT_NO_FATAL_FAILURE(SendNullCmsg(sockets->first_fd(), buf, sizeof(buf))); } -TEST_P(AllSocketPairTest, SoRcvTimeoIsSet) { +TEST_P(AllSocketPairTest, RecvTimeoutDefault) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - struct timeval tv { - .tv_sec = 0, .tv_usec = 35 - }; + timeval actual_tv = {.tv_sec = -1, .tv_usec = -1}; + socklen_t len = sizeof(actual_tv); + EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, + &actual_tv, &len), + SyscallSucceeds()); + EXPECT_EQ(actual_tv.tv_sec, 0); + EXPECT_EQ(actual_tv.tv_usec, 0); +} + +TEST_P(AllSocketPairTest, SetGetRecvTimeout) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + timeval tv = {.tv_sec = 123, .tv_usec = 456000}; EXPECT_THAT( setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)), SyscallSucceeds()); + + timeval actual_tv = {}; + socklen_t len = sizeof(actual_tv); + EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, + &actual_tv, &len), + SyscallSucceeds()); + EXPECT_EQ(actual_tv.tv_sec, 123); + EXPECT_EQ(actual_tv.tv_usec, 456000); } -TEST_P(AllSocketPairTest, SoRcvTimeoIsSetLargerArg) { +TEST_P(AllSocketPairTest, SetGetRecvTimeoutLargerArg) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); struct timeval_with_extra { @@ -510,13 +584,21 @@ TEST_P(AllSocketPairTest, SoRcvTimeoIsSetLargerArg) { int64_t extra_data; } ABSL_ATTRIBUTE_PACKED; - timeval_with_extra tv_extra; - tv_extra.tv.tv_sec = 0; - tv_extra.tv.tv_usec = 25; + timeval_with_extra tv_extra = { + .tv = {.tv_sec = 0, .tv_usec = 432000}, + }; EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv_extra, sizeof(tv_extra)), SyscallSucceeds()); + + timeval_with_extra actual_tv = {}; + socklen_t len = sizeof(actual_tv); + EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, + &actual_tv, &len), + SyscallSucceeds()); + EXPECT_EQ(actual_tv.tv.tv_sec, 0); + EXPECT_EQ(actual_tv.tv.tv_usec, 432000); } TEST_P(AllSocketPairTest, RecvTimeoutRecvmsgOneSecondSucceeds) { diff --git a/test/syscalls/linux/socket_generic_stress.cc b/test/syscalls/linux/socket_generic_stress.cc new file mode 100644 index 000000000..19239e9e9 --- /dev/null +++ b/test/syscalls/linux/socket_generic_stress.cc @@ -0,0 +1,130 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <poll.h> +#include <stdio.h> +#include <sys/ioctl.h> +#include <sys/socket.h> +#include <sys/un.h> + +#include "gtest/gtest.h" +#include "test/syscalls/linux/ip_socket_test_util.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +// Test fixture for tests that apply to pairs of connected sockets. +using ConnectStressTest = SocketPairTest; + +TEST_P(ConnectStressTest, Reset65kTimes) { + for (int i = 0; i < 1 << 16; ++i) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Send some data to ensure that the connection gets reset and the port gets + // released immediately. This avoids either end entering TIME-WAIT. + char sent_data[100] = {}; + ASSERT_THAT(write(sockets->first_fd(), sent_data, sizeof(sent_data)), + SyscallSucceedsWithValue(sizeof(sent_data))); + // Poll the other FD to make sure that the data is in the receive buffer + // before closing it to ensure a RST is triggered. + const int kTimeout = 10000; + struct pollfd pfd = { + .fd = sockets->second_fd(), + .events = POLL_IN, + }; + ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); + } +} + +INSTANTIATE_TEST_SUITE_P( + AllConnectedSockets, ConnectStressTest, + ::testing::Values(IPv6UDPBidirectionalBindSocketPair(0), + IPv4UDPBidirectionalBindSocketPair(0), + DualStackUDPBidirectionalBindSocketPair(0), + + // Without REUSEADDR, we get port exhaustion on Linux. + SetSockOpt(SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn)(IPv6TCPAcceptBindSocketPair(0)), + SetSockOpt(SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn)(IPv4TCPAcceptBindSocketPair(0)), + SetSockOpt(SOL_SOCKET, SO_REUSEADDR, &kSockOptOn)( + DualStackTCPAcceptBindSocketPair(0)))); + +// Test fixture for tests that apply to pairs of connected sockets created with +// a persistent listener (if applicable). +using PersistentListenerConnectStressTest = SocketPairTest; + +TEST_P(PersistentListenerConnectStressTest, 65kTimesShutdownCloseFirst) { + for (int i = 0; i < 1 << 16; ++i) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + ASSERT_THAT(shutdown(sockets->first_fd(), SHUT_RDWR), SyscallSucceeds()); + if (GetParam().type == SOCK_STREAM) { + // Poll the other FD to make sure that we see the FIN from the other + // side before closing the second_fd. This ensures that the first_fd + // enters TIME-WAIT and not second_fd. + const int kTimeout = 10000; + struct pollfd pfd = { + .fd = sockets->second_fd(), + .events = POLL_IN, + }; + ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); + } + ASSERT_THAT(shutdown(sockets->second_fd(), SHUT_RDWR), SyscallSucceeds()); + } +} + +TEST_P(PersistentListenerConnectStressTest, 65kTimesShutdownCloseSecond) { + for (int i = 0; i < 1 << 16; ++i) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + ASSERT_THAT(shutdown(sockets->second_fd(), SHUT_RDWR), SyscallSucceeds()); + if (GetParam().type == SOCK_STREAM) { + // Poll the other FD to make sure that we see the FIN from the other + // side before closing the first_fd. This ensures that the second_fd + // enters TIME-WAIT and not first_fd. + const int kTimeout = 10000; + struct pollfd pfd = { + .fd = sockets->first_fd(), + .events = POLL_IN, + }; + ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); + } + ASSERT_THAT(shutdown(sockets->first_fd(), SHUT_RDWR), SyscallSucceeds()); + } +} + +TEST_P(PersistentListenerConnectStressTest, 65kTimesClose) { + for (int i = 0; i < 1 << 16; ++i) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + } +} + +INSTANTIATE_TEST_SUITE_P( + AllConnectedSockets, PersistentListenerConnectStressTest, + ::testing::Values( + IPv6UDPBidirectionalBindSocketPair(0), + IPv4UDPBidirectionalBindSocketPair(0), + DualStackUDPBidirectionalBindSocketPair(0), + + // Without REUSEADDR, we get port exhaustion on Linux. + SetSockOpt(SOL_SOCKET, SO_REUSEADDR, &kSockOptOn)( + IPv6TCPAcceptBindPersistentListenerSocketPair(0)), + SetSockOpt(SOL_SOCKET, SO_REUSEADDR, &kSockOptOn)( + IPv4TCPAcceptBindPersistentListenerSocketPair(0)), + SetSockOpt(SOL_SOCKET, SO_REUSEADDR, &kSockOptOn)( + DualStackTCPAcceptBindPersistentListenerSocketPair(0)))); + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 322ee07ad..c3b42682f 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -14,6 +14,7 @@ #include <arpa/inet.h> #include <netinet/in.h> +#include <netinet/tcp.h> #include <poll.h> #include <string.h> #include <sys/socket.h> @@ -30,7 +31,9 @@ #include "gtest/gtest.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" +#include "absl/time/clock.h" #include "absl/time/time.h" +#include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/file_descriptor.h" #include "test/util/posix_error.h" @@ -43,6 +46,8 @@ namespace testing { namespace { +using ::testing::Gt; + PosixErrorOr<uint16_t> AddrPort(int family, sockaddr_storage const& addr) { switch (family) { case AF_INET: @@ -99,19 +104,172 @@ TEST(BadSocketPairArgs, ValidateErrForBadCallsToSocketPair) { SyscallFailsWithErrno(EAFNOSUPPORT)); } -TEST_P(SocketInetLoopbackTest, TCP) { - auto const& param = GetParam(); +enum class Operation { + Bind, + Connect, + SendTo, +}; - TestAddress const& listener = param.listener; - TestAddress const& connector = param.connector; +std::string OperationToString(Operation operation) { + switch (operation) { + case Operation::Bind: + return "Bind"; + case Operation::Connect: + return "Connect"; + case Operation::SendTo: + return "SendTo"; + } +} + +using OperationSequence = std::vector<Operation>; + +using DualStackSocketTest = + ::testing::TestWithParam<std::tuple<TestAddress, OperationSequence>>; + +TEST_P(DualStackSocketTest, AddressOperations) { + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET6, SOCK_DGRAM, 0)); + + const TestAddress& addr = std::get<0>(GetParam()); + const OperationSequence& operations = std::get<1>(GetParam()); + + auto addr_in = reinterpret_cast<const sockaddr*>(&addr.addr); + + // sockets may only be bound once. Both `connect` and `sendto` cause a socket + // to be bound. + bool bound = false; + for (const Operation& operation : operations) { + bool sockname = false; + bool peername = false; + switch (operation) { + case Operation::Bind: { + ASSERT_NO_ERRNO(SetAddrPort( + addr.family(), const_cast<sockaddr_storage*>(&addr.addr), 0)); + + int bind_ret = bind(fd.get(), addr_in, addr.addr_len); + + // Dual stack sockets may only be bound to AF_INET6. + if (!bound && addr.family() == AF_INET6) { + EXPECT_THAT(bind_ret, SyscallSucceeds()); + bound = true; + + sockname = true; + } else { + EXPECT_THAT(bind_ret, SyscallFailsWithErrno(EINVAL)); + } + break; + } + case Operation::Connect: { + ASSERT_NO_ERRNO(SetAddrPort( + addr.family(), const_cast<sockaddr_storage*>(&addr.addr), 1337)); + + EXPECT_THAT(RetryEINTR(connect)(fd.get(), addr_in, addr.addr_len), + SyscallSucceeds()) + << GetAddrStr(addr_in); + bound = true; + + sockname = true; + peername = true; + + break; + } + case Operation::SendTo: { + const char payload[] = "hello"; + ASSERT_NO_ERRNO(SetAddrPort( + addr.family(), const_cast<sockaddr_storage*>(&addr.addr), 1337)); + + ssize_t sendto_ret = sendto(fd.get(), &payload, sizeof(payload), 0, + addr_in, addr.addr_len); + + EXPECT_THAT(sendto_ret, SyscallSucceedsWithValue(sizeof(payload))); + sockname = !bound; + bound = true; + break; + } + } + + if (sockname) { + sockaddr_storage sock_addr; + socklen_t addrlen = sizeof(sock_addr); + ASSERT_THAT(getsockname(fd.get(), reinterpret_cast<sockaddr*>(&sock_addr), + &addrlen), + SyscallSucceeds()); + ASSERT_EQ(addrlen, sizeof(struct sockaddr_in6)); + + auto sock_addr_in6 = reinterpret_cast<const sockaddr_in6*>(&sock_addr); + + if (operation == Operation::SendTo) { + EXPECT_EQ(sock_addr_in6->sin6_family, AF_INET6); + EXPECT_TRUE(IN6_IS_ADDR_UNSPECIFIED(sock_addr_in6->sin6_addr.s6_addr32)) + << OperationToString(operation) << " getsocknam=" + << GetAddrStr(reinterpret_cast<sockaddr*>(&sock_addr)); + + EXPECT_NE(sock_addr_in6->sin6_port, 0); + } else if (IN6_IS_ADDR_V4MAPPED( + reinterpret_cast<const sockaddr_in6*>(addr_in) + ->sin6_addr.s6_addr32)) { + EXPECT_TRUE(IN6_IS_ADDR_V4MAPPED(sock_addr_in6->sin6_addr.s6_addr32)) + << OperationToString(operation) << " getsocknam=" + << GetAddrStr(reinterpret_cast<sockaddr*>(&sock_addr)); + } + } + + if (peername) { + sockaddr_storage peer_addr; + socklen_t addrlen = sizeof(peer_addr); + ASSERT_THAT(getpeername(fd.get(), reinterpret_cast<sockaddr*>(&peer_addr), + &addrlen), + SyscallSucceeds()); + ASSERT_EQ(addrlen, sizeof(struct sockaddr_in6)); + + if (addr.family() == AF_INET || + IN6_IS_ADDR_V4MAPPED(reinterpret_cast<const sockaddr_in6*>(addr_in) + ->sin6_addr.s6_addr32)) { + EXPECT_TRUE(IN6_IS_ADDR_V4MAPPED( + reinterpret_cast<const sockaddr_in6*>(&peer_addr) + ->sin6_addr.s6_addr32)) + << OperationToString(operation) << " getpeername=" + << GetAddrStr(reinterpret_cast<sockaddr*>(&peer_addr)); + } + } + } +} + +// TODO(gvisor.dev/issue/1556): uncomment V4MappedAny. +INSTANTIATE_TEST_SUITE_P( + All, DualStackSocketTest, + ::testing::Combine( + ::testing::Values(V4Any(), V4Loopback(), /*V4MappedAny(),*/ + V4MappedLoopback(), V6Any(), V6Loopback()), + ::testing::ValuesIn<OperationSequence>( + {{Operation::Bind, Operation::Connect, Operation::SendTo}, + {Operation::Bind, Operation::SendTo, Operation::Connect}, + {Operation::Connect, Operation::Bind, Operation::SendTo}, + {Operation::Connect, Operation::SendTo, Operation::Bind}, + {Operation::SendTo, Operation::Bind, Operation::Connect}, + {Operation::SendTo, Operation::Connect, Operation::Bind}})), + [](::testing::TestParamInfo< + std::tuple<TestAddress, OperationSequence>> const& info) { + const TestAddress& addr = std::get<0>(info.param); + const OperationSequence& operations = std::get<1>(info.param); + std::string s = addr.description; + for (const Operation& operation : operations) { + absl::StrAppend(&s, OperationToString(operation)); + } + return s; + }); +void tcpSimpleConnectTest(TestAddress const& listener, + TestAddress const& connector, bool unbound) { // Create the listening socket. const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); sockaddr_storage listen_addr = listener.addr; - ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), - listener.addr_len), - SyscallSucceeds()); + if (!unbound) { + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + } ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); // Get the port bound by the listening socket. @@ -145,12 +303,31 @@ TEST_P(SocketInetLoopbackTest, TCP) { ASSERT_THAT(shutdown(conn_fd.get(), SHUT_RDWR), SyscallSucceeds()); } -TEST_P(SocketInetLoopbackTest, TCPListenClose) { +TEST_P(SocketInetLoopbackTest, TCP) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + tcpSimpleConnectTest(listener, connector, true); +} + +TEST_P(SocketInetLoopbackTest, TCPListenUnbound) { auto const& param = GetParam(); TestAddress const& listener = param.listener; TestAddress const& connector = param.connector; + tcpSimpleConnectTest(listener, connector, false); +} + +TEST_P(SocketInetLoopbackTest, TCPListenShutdownListen) { + const auto& param = GetParam(); + + const TestAddress& listener = param.listener; + const TestAddress& connector = param.connector; + + constexpr int kBacklog = 5; + // Create the listening socket. FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); @@ -158,7 +335,52 @@ TEST_P(SocketInetLoopbackTest, TCPListenClose) { ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len), SyscallSucceeds()); - ASSERT_THAT(listen(listen_fd.get(), 1001), SyscallSucceeds()); + + ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds()); + ASSERT_THAT(shutdown(listen_fd.get(), SHUT_RD), SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + const uint16_t port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + + for (int i = 0; i < kBacklog; i++) { + auto client = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT(RetryEINTR(connect)(client.get(), + reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + } + for (int i = 0; i < kBacklog; i++) { + ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr), SyscallSucceeds()); + } +} + +TEST_P(SocketInetLoopbackTest, TCPListenShutdown) { + auto const& param = GetParam(); + + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + constexpr int kBacklog = 2; + constexpr int kFDs = kBacklog + 1; + + // Create the listening socket. + FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds()); // Get the port bound by the listening socket. socklen_t addrlen = listener.addr_len; @@ -168,42 +390,169 @@ TEST_P(SocketInetLoopbackTest, TCPListenClose) { uint16_t const port = ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); - DisableSave ds; // Too many system calls. sockaddr_storage conn_addr = connector.addr; ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); - constexpr int kFDs = 2048; - constexpr int kThreadCount = 4; - constexpr int kFDsPerThread = kFDs / kThreadCount; - FileDescriptor clients[kFDs]; - std::unique_ptr<ScopedThread> threads[kThreadCount]; + + // Shutdown the write of the listener, expect to not have any effect. + ASSERT_THAT(shutdown(listen_fd.get(), SHUT_WR), SyscallSucceeds()); + for (int i = 0; i < kFDs; i++) { - clients[i] = ASSERT_NO_ERRNO_AND_VALUE( - Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); + auto client = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT(RetryEINTR(connect)(client.get(), + reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr), SyscallSucceeds()); } - for (int i = 0; i < kThreadCount; i++) { - threads[i] = absl::make_unique<ScopedThread>([&connector, &conn_addr, - &clients, i]() { - for (int j = 0; j < kFDsPerThread; j++) { - int k = i * kFDsPerThread + j; - int ret = - connect(clients[k].get(), reinterpret_cast<sockaddr*>(&conn_addr), - connector.addr_len); - if (ret != 0) { - EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS)); - } - } - }); + + // Shutdown the read of the listener, expect to fail subsequent + // server accepts, binds and client connects. + ASSERT_THAT(shutdown(listen_fd.get(), SHUT_RD), SyscallSucceeds()); + + ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr), + SyscallFailsWithErrno(EINVAL)); + + // Check that shutdown did not release the port. + FileDescriptor new_listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT( + bind(new_listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); + + // Check that subsequent connection attempts receive a RST. + auto client = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + for (int i = 0; i < kFDs; i++) { + auto client = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT(RetryEINTR(connect)(client.get(), + reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallFailsWithErrno(ECONNREFUSED)); } - for (int i = 0; i < kThreadCount; i++) { - threads[i]->Join(); +} + +TEST_P(SocketInetLoopbackTest, TCPListenClose) { + auto const& param = GetParam(); + + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + constexpr int kAcceptCount = 2; + constexpr int kBacklog = kAcceptCount + 2; + constexpr int kFDs = kBacklog * 3; + + // Create the listening socket. + FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + std::vector<FileDescriptor> clients; + for (int i = 0; i < kFDs; i++) { + auto client = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); + int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len); + if (ret != 0) { + EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS)); + } + clients.push_back(std::move(client)); } - for (int i = 0; i < 32; i++) { + for (int i = 0; i < kAcceptCount; i++) { auto accepted = ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); } - // TODO(b/138400178): Fix cooperative S/R failure when ds.reset() is invoked - // before function end. - // ds.reset() +} + +void TestListenWhileConnect(const TestParam& param, + void (*stopListen)(FileDescriptor&)) { + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + constexpr int kBacklog = 2; + constexpr int kClients = kBacklog + 1; + + // Create the listening socket. + FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + std::vector<FileDescriptor> clients; + for (int i = 0; i < kClients; i++) { + FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); + int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len); + if (ret != 0) { + EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS)); + clients.push_back(std::move(client)); + } + } + + stopListen(listen_fd); + + for (auto& client : clients) { + const int kTimeout = 10000; + struct pollfd pfd = { + .fd = client.get(), + .events = POLLIN, + }; + // When the listening socket is closed, then we expect the remote to reset + // the connection. + ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); + ASSERT_EQ(pfd.revents, POLLIN | POLLHUP | POLLERR); + char c; + // Subsequent read can fail with: + // ECONNRESET: If the client connection was established and was reset by the + // remote. + // ECONNREFUSED: If the client connection failed to be established. + ASSERT_THAT(read(client.get(), &c, sizeof(c)), + AnyOf(SyscallFailsWithErrno(ECONNRESET), + SyscallFailsWithErrno(ECONNREFUSED))); + } +} + +TEST_P(SocketInetLoopbackTest, TCPListenCloseWhileConnect) { + TestListenWhileConnect(GetParam(), [](FileDescriptor& f) { + ASSERT_THAT(close(f.release()), SyscallSucceeds()); + }); +} + +TEST_P(SocketInetLoopbackTest, TCPListenShutdownWhileConnect) { + TestListenWhileConnect(GetParam(), [](FileDescriptor& f) { + ASSERT_THAT(shutdown(f.get(), SHUT_RD), SyscallSucceeds()); + }); } TEST_P(SocketInetLoopbackTest, TCPbacklog) { @@ -266,6 +615,649 @@ TEST_P(SocketInetLoopbackTest, TCPbacklog) { } } +// TCPFinWait2Test creates a pair of connected sockets then closes one end to +// trigger FIN_WAIT2 state for the closed endpoint. Then it binds the same local +// IP/port on a new socket and tries to connect. The connect should fail w/ +// an EADDRINUSE. Then we wait till the FIN_WAIT2 timeout is over and try the +// connect again with a new socket and this time it should succeed. +// +// TCP timers are not S/R today, this can cause this test to be flaky when run +// under random S/R due to timer being reset on a restore. +TEST_P(SocketInetLoopbackTest, TCPFinWait2Test_NoRandomSave) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + // Lower FIN_WAIT2 state to 5 seconds for test. + constexpr int kTCPLingerTimeout = 5; + EXPECT_THAT(setsockopt(conn_fd.get(), IPPROTO_TCP, TCP_LINGER2, + &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)), + SyscallSucceedsWithValue(0)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Accept the connection. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + + // Get the address/port bound by the connecting socket. + sockaddr_storage conn_bound_addr; + socklen_t conn_addrlen = connector.addr_len; + ASSERT_THAT( + getsockname(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), + &conn_addrlen), + SyscallSucceeds()); + + // close the connecting FD to trigger FIN_WAIT2 on the connected fd. + conn_fd.reset(); + + // Now bind and connect a new socket. + const FileDescriptor conn_fd2 = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + // Disable cooperative saves after this point. As a save between the first + // bind/connect and the second one can cause the linger timeout timer to + // be restarted causing the final bind/connect to fail. + DisableSave ds; + + ASSERT_THAT(bind(conn_fd2.get(), + reinterpret_cast<sockaddr*>(&conn_bound_addr), conn_addrlen), + SyscallFailsWithErrno(EADDRINUSE)); + + // Sleep for a little over the linger timeout to reduce flakiness in + // save/restore tests. + absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 2)); + + ds.reset(); + + if (!IsRunningOnGvisor()) { + ASSERT_THAT( + bind(conn_fd2.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), + conn_addrlen), + SyscallSucceeds()); + } + ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(), + reinterpret_cast<sockaddr*>(&conn_addr), + conn_addrlen), + SyscallSucceeds()); +} + +// TCPLinger2TimeoutAfterClose creates a pair of connected sockets +// then closes one end to trigger FIN_WAIT2 state for the closed endpont. +// It then sleeps for the TCP_LINGER2 timeout and verifies that bind/ +// connecting the same address succeeds. +// +// TCP timers are not S/R today, this can cause this test to be flaky when run +// under random S/R due to timer being reset on a restore. +TEST_P(SocketInetLoopbackTest, TCPLinger2TimeoutAfterClose_NoRandomSave) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Accept the connection. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + + // Get the address/port bound by the connecting socket. + sockaddr_storage conn_bound_addr; + socklen_t conn_addrlen = connector.addr_len; + ASSERT_THAT( + getsockname(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), + &conn_addrlen), + SyscallSucceeds()); + + // Disable cooperative saves after this point as TCP timers are not restored + // across a S/R. + { + DisableSave ds; + constexpr int kTCPLingerTimeout = 5; + EXPECT_THAT(setsockopt(conn_fd.get(), IPPROTO_TCP, TCP_LINGER2, + &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)), + SyscallSucceedsWithValue(0)); + + // close the connecting FD to trigger FIN_WAIT2 on the connected fd. + conn_fd.reset(); + + absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 1)); + + // ds going out of scope will Re-enable S/R's since at this point the timer + // must have fired and cleaned up the endpoint. + } + + // Now bind and connect a new socket and verify that we can immediately + // rebind the address bound by the conn_fd as it never entered TIME_WAIT. + const FileDescriptor conn_fd2 = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + ASSERT_THAT(bind(conn_fd2.get(), + reinterpret_cast<sockaddr*>(&conn_bound_addr), conn_addrlen), + SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(), + reinterpret_cast<sockaddr*>(&conn_addr), + conn_addrlen), + SyscallSucceeds()); +} + +// TCPResetAfterClose creates a pair of connected sockets then closes +// one end to trigger FIN_WAIT2 state for the closed endpoint verifies +// that we generate RSTs for any new data after the socket is fully +// closed. +TEST_P(SocketInetLoopbackTest, TCPResetAfterClose) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Accept the connection. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + + // close the connecting FD to trigger FIN_WAIT2 on the connected fd. + conn_fd.reset(); + + int data = 1234; + + // Now send data which should trigger a RST as the other end should + // have timed out and closed the socket. + EXPECT_THAT(RetryEINTR(send)(accepted.get(), &data, sizeof(data), 0), + SyscallSucceeds()); + // Sleep for a shortwhile to get a RST back. + absl::SleepFor(absl::Seconds(1)); + + // Try writing again and we should get an EPIPE back. + EXPECT_THAT(RetryEINTR(send)(accepted.get(), &data, sizeof(data), 0), + SyscallFailsWithErrno(EPIPE)); + + // Trying to read should return zero as the other end did send + // us a FIN. We do it twice to verify that the RST does not cause an + // ECONNRESET on the read after EOF has been read by applicaiton. + EXPECT_THAT(RetryEINTR(recv)(accepted.get(), &data, sizeof(data), 0), + SyscallSucceedsWithValue(0)); + EXPECT_THAT(RetryEINTR(recv)(accepted.get(), &data, sizeof(data), 0), + SyscallSucceedsWithValue(0)); +} + +// This test is disabled under random save as the the restore run +// results in the stack.Seed() being different which can cause +// sequence number of final connect to be one that is considered +// old and can cause the test to be flaky. +TEST_P(SocketInetLoopbackTest, TCPPassiveCloseNoTimeWaitTest_NoRandomSave) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + // We disable saves after this point as a S/R causes the netstack seed + // to be regenerated which changes what ports/ISN is picked for a given + // tuple (src ip,src port, dst ip, dst port). This can cause the final + // SYN to use a sequence number that looks like one from the current + // connection in TIME_WAIT and will not be accepted causing the test + // to timeout. + // + // TODO(gvisor.dev/issue/940): S/R portSeed/portHint + DisableSave ds; + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Accept the connection. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + + // Get the address/port bound by the connecting socket. + sockaddr_storage conn_bound_addr; + socklen_t conn_addrlen = connector.addr_len; + ASSERT_THAT( + getsockname(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), + &conn_addrlen), + SyscallSucceeds()); + + // shutdown the accept FD to trigger TIME_WAIT on the accepted socket which + // should cause the conn_fd to follow CLOSE_WAIT->LAST_ACK->CLOSED instead of + // TIME_WAIT. + ASSERT_THAT(shutdown(accepted.get(), SHUT_RDWR), SyscallSucceeds()); + { + const int kTimeout = 10000; + struct pollfd pfd = { + .fd = conn_fd.get(), + .events = POLLIN, + }; + ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); + ASSERT_EQ(pfd.revents, POLLIN); + } + + conn_fd.reset(); + // This sleep is required to give conn_fd time to transition to TIME-WAIT. + absl::SleepFor(absl::Seconds(1)); + + // At this point conn_fd should be the one that moved to CLOSE_WAIT and + // eventually to CLOSED. + + // Now bind and connect a new socket and verify that we can immediately + // rebind the address bound by the conn_fd as it never entered TIME_WAIT. + const FileDescriptor conn_fd2 = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + ASSERT_THAT(bind(conn_fd2.get(), + reinterpret_cast<sockaddr*>(&conn_bound_addr), conn_addrlen), + SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(), + reinterpret_cast<sockaddr*>(&conn_addr), + conn_addrlen), + SyscallSucceeds()); +} + +TEST_P(SocketInetLoopbackTest, TCPActiveCloseTimeWaitTest_NoRandomSave) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + // We disable saves after this point as a S/R causes the netstack seed + // to be regenerated which changes what ports/ISN is picked for a given + // tuple (src ip,src port, dst ip, dst port). This can cause the final + // SYN to use a sequence number that looks like one from the current + // connection in TIME_WAIT and will not be accepted causing the test + // to timeout. + // + // TODO(gvisor.dev/issue/940): S/R portSeed/portHint + DisableSave ds; + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Accept the connection. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + + // Get the address/port bound by the connecting socket. + sockaddr_storage conn_bound_addr; + socklen_t conn_addrlen = connector.addr_len; + ASSERT_THAT( + getsockname(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), + &conn_addrlen), + SyscallSucceeds()); + + // shutdown the conn FD to trigger TIME_WAIT on the connect socket. + ASSERT_THAT(shutdown(conn_fd.get(), SHUT_RDWR), SyscallSucceeds()); + { + const int kTimeout = 10000; + struct pollfd pfd = { + .fd = accepted.get(), + .events = POLLIN, + }; + ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); + ASSERT_EQ(pfd.revents, POLLIN); + } + ScopedThread t([&]() { + constexpr int kTimeout = 10000; + constexpr int16_t want_events = POLLHUP; + struct pollfd pfd = { + .fd = conn_fd.get(), + .events = want_events, + }; + ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); + }); + + accepted.reset(); + t.Join(); + conn_fd.reset(); + + // Now bind and connect a new socket and verify that we can't immediately + // rebind the address bound by the conn_fd as it is in TIME_WAIT. + conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + ASSERT_THAT(bind(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), + conn_addrlen), + SyscallFailsWithErrno(EADDRINUSE)); +} + +TEST_P(SocketInetLoopbackTest, AcceptedInheritsTCPUserTimeout) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + + const uint16_t port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Set the userTimeout on the listening socket. + constexpr int kUserTimeout = 10; + ASSERT_THAT(setsockopt(listen_fd.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &kUserTimeout, sizeof(kUserTimeout)), + SyscallSucceeds()); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Accept the connection. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + // Verify that the accepted socket inherited the user timeout set on + // listening socket. + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(accepted.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, &get, &get_len), + SyscallSucceeds()); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kUserTimeout); +} + +// TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not +// saved. Enable S/R once issue is fixed. +TEST_P(SocketInetLoopbackTest, TCPDeferAccept_NoRandomSave) { + // TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not + // saved. Enable S/R issue is fixed. + DisableSave ds; + + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + + const uint16_t port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Set the TCP_DEFER_ACCEPT on the listening socket. + constexpr int kTCPDeferAccept = 3; + ASSERT_THAT(setsockopt(listen_fd.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, + &kTCPDeferAccept, sizeof(kTCPDeferAccept)), + SyscallSucceeds()); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Set the listening socket to nonblock so that we can verify that there is no + // connection in queue despite the connect above succeeding since the peer has + // sent no data and TCP_DEFER_ACCEPT is set on the listening socket. Set the + // FD to O_NONBLOCK. + int opts; + ASSERT_THAT(opts = fcntl(listen_fd.get(), F_GETFL), SyscallSucceeds()); + opts |= O_NONBLOCK; + ASSERT_THAT(fcntl(listen_fd.get(), F_SETFL, opts), SyscallSucceeds()); + + ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr), + SyscallFailsWithErrno(EWOULDBLOCK)); + + // Set FD back to blocking. + opts &= ~O_NONBLOCK; + ASSERT_THAT(fcntl(listen_fd.get(), F_SETFL, opts), SyscallSucceeds()); + + // Now write some data to the socket. + int data = 0; + ASSERT_THAT(RetryEINTR(write)(conn_fd.get(), &data, sizeof(data)), + SyscallSucceedsWithValue(sizeof(data))); + + // This should now cause the connection to complete and be delivered to the + // accept socket. + + // Accept the connection. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + + // Verify that the accepted socket returns the data written. + int get = -1; + ASSERT_THAT(RetryEINTR(recv)(accepted.get(), &get, sizeof(get), 0), + SyscallSucceedsWithValue(sizeof(get))); + + EXPECT_EQ(get, data); +} + +// TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not +// saved. Enable S/R once issue is fixed. +TEST_P(SocketInetLoopbackTest, TCPDeferAcceptTimeout_NoRandomSave) { + // TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not + // saved. Enable S/R once issue is fixed. + DisableSave ds; + + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + + const uint16_t port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Set the TCP_DEFER_ACCEPT on the listening socket. + constexpr int kTCPDeferAccept = 3; + ASSERT_THAT(setsockopt(listen_fd.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, + &kTCPDeferAccept, sizeof(kTCPDeferAccept)), + SyscallSucceeds()); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Set the listening socket to nonblock so that we can verify that there is no + // connection in queue despite the connect above succeeding since the peer has + // sent no data and TCP_DEFER_ACCEPT is set on the listening socket. Set the + // FD to O_NONBLOCK. + int opts; + ASSERT_THAT(opts = fcntl(listen_fd.get(), F_GETFL), SyscallSucceeds()); + opts |= O_NONBLOCK; + ASSERT_THAT(fcntl(listen_fd.get(), F_SETFL, opts), SyscallSucceeds()); + + // Verify that there is no acceptable connection before TCP_DEFER_ACCEPT + // timeout is hit. + absl::SleepFor(absl::Seconds(kTCPDeferAccept - 1)); + ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr), + SyscallFailsWithErrno(EWOULDBLOCK)); + + // Set FD back to blocking. + opts &= ~O_NONBLOCK; + ASSERT_THAT(fcntl(listen_fd.get(), F_SETFL, opts), SyscallSucceeds()); + + // Now sleep for a little over the TCP_DEFER_ACCEPT duration. When the timeout + // is hit a SYN-ACK should be retransmitted by the listener as a last ditch + // attempt to complete the connection with or without data. + absl::SleepFor(absl::Seconds(2)); + + // Verify that we have a connection that can be accepted even though no + // data was written. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); +} + INSTANTIATE_TEST_SUITE_P( All, SocketInetLoopbackTest, ::testing::Values( @@ -298,7 +1290,9 @@ INSTANTIATE_TEST_SUITE_P( using SocketInetReusePortTest = ::testing::TestWithParam<TestParam>; -TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread) { +// TODO(gvisor.dev/issue/940): Remove _NoRandomSave when portHint/stack.Seed is +// saved/restored. +TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) { auto const& param = GetParam(); TestAddress const& listener = param.listener; @@ -306,6 +1300,7 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread) { sockaddr_storage listen_addr = listener.addr; sockaddr_storage conn_addr = connector.addr; constexpr int kThreadCount = 3; + constexpr int kConnectAttempts = 10000; // Create the listening socket. FileDescriptor listener_fds[kThreadCount]; @@ -339,7 +1334,6 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread) { ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); } - constexpr int kConnectAttempts = 10000; std::atomic<int> connects_received = ATOMIC_VAR_INIT(0); std::unique_ptr<ScopedThread> listen_thread[kThreadCount]; int accept_counts[kThreadCount] = {}; @@ -357,6 +1351,7 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread) { if (connects_received >= kConnectAttempts) { // Another thread have shutdown our read side causing the // accept to fail. + ASSERT_EQ(errno, EINVAL); break; } ASSERT_NO_ERRNO(fd); @@ -364,7 +1359,14 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread) { } // Receive some data from a socket to be sure that the connect() // system call has been completed on another side. - int data; + // Do a short read and then close the socket to trigger a RST. This + // ensures that both ends of the connection are cleaned up and no + // goroutines hang around in TIME-WAIT. We do this so that this test + // does not timeout under gotsan runs where lots of goroutines can + // cause the test to use absurd amounts of memory. + // + // See: https://tools.ietf.org/html/rfc2525#page-50 section 2.17 + uint16_t data; EXPECT_THAT( RetryEINTR(recv)(fd.ValueOrDie().get(), &data, sizeof(data), 0), SyscallSucceedsWithValue(sizeof(data))); @@ -387,8 +1389,22 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread) { connector.addr_len), SyscallSucceeds()); + // Do two separate sends to ensure two segments are received. This is + // required for netstack where read is incorrectly assuming a whole + // segment is read when endpoint.Read() is called which is technically + // incorrect as the syscall that invoked endpoint.Read() may only + // consume it partially. This results in a case where a close() of + // such a socket does not trigger a RST in netstack due to the + // endpoint assuming that the endpoint has no unread data. EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0), SyscallSucceedsWithValue(sizeof(i))); + + // TODO(gvisor.dev/issue/1449): Remove this block once netstack correctly + // generates a RST. + if (IsRunningOnGvisor()) { + EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0), + SyscallSucceedsWithValue(sizeof(i))); + } } }); @@ -403,7 +1419,7 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread) { EquivalentWithin((kConnectAttempts / kThreadCount), 0.10)); } -TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread) { +TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread_NoRandomSave) { auto const& param = GetParam(); TestAddress const& listener = param.listener; @@ -516,6 +1532,115 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread) { EquivalentWithin((kConnectAttempts / kThreadCount), 0.10)); } +TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort_NoRandomSave) { + auto const& param = GetParam(); + + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + sockaddr_storage listen_addr = listener.addr; + sockaddr_storage conn_addr = connector.addr; + constexpr int kThreadCount = 3; + + // TODO(b/141211329): endpointsByNic.seed has to be saved/restored. + const DisableSave ds141211329; + + // Create listening sockets. + FileDescriptor listener_fds[kThreadCount]; + for (int i = 0; i < kThreadCount; i++) { + listener_fds[i] = + ASSERT_NO_ERRNO_AND_VALUE(Socket(listener.family(), SOCK_DGRAM, 0)); + int fd = listener_fds[i].get(); + + ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT( + bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len), + SyscallSucceeds()); + + // On the first bind we need to determine which port was bound. + if (i != 0) { + continue; + } + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT( + getsockname(listener_fds[0].get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + ASSERT_NO_ERRNO(SetAddrPort(listener.family(), &listen_addr, port)); + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + } + + constexpr int kConnectAttempts = 10; + FileDescriptor client_fds[kConnectAttempts]; + + // Do the first run without save/restore. + DisableSave ds; + for (int i = 0; i < kConnectAttempts; i++) { + client_fds[i] = + ASSERT_NO_ERRNO_AND_VALUE(Socket(connector.family(), SOCK_DGRAM, 0)); + EXPECT_THAT(RetryEINTR(sendto)(client_fds[i].get(), &i, sizeof(i), 0, + reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceedsWithValue(sizeof(i))); + } + ds.reset(); + + // Check that a mapping of client and server sockets has + // not been change after save/restore. + for (int i = 0; i < kConnectAttempts; i++) { + EXPECT_THAT(RetryEINTR(sendto)(client_fds[i].get(), &i, sizeof(i), 0, + reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceedsWithValue(sizeof(i))); + } + + struct pollfd pollfds[kThreadCount]; + for (int i = 0; i < kThreadCount; i++) { + pollfds[i].fd = listener_fds[i].get(); + pollfds[i].events = POLLIN; + } + + std::map<uint16_t, int> portToFD; + + int received = 0; + while (received < kConnectAttempts * 2) { + ASSERT_THAT(poll(pollfds, kThreadCount, -1), + SyscallSucceedsWithValue(Gt(0))); + + for (int i = 0; i < kThreadCount; i++) { + if ((pollfds[i].revents & POLLIN) == 0) { + continue; + } + + received++; + + const int fd = pollfds[i].fd; + struct sockaddr_storage addr = {}; + socklen_t addrlen = sizeof(addr); + int data; + EXPECT_THAT(RetryEINTR(recvfrom)( + fd, &data, sizeof(data), 0, + reinterpret_cast<struct sockaddr*>(&addr), &addrlen), + SyscallSucceedsWithValue(sizeof(data))); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(connector.family(), addr)); + auto prev_port = portToFD.find(port); + // Check that all packets from one client have been delivered to the + // same server socket. + if (prev_port == portToFD.end()) { + portToFD[port] = fd; + } else { + EXPECT_EQ(portToFD[port], fd); + } + } + } +} + INSTANTIATE_TEST_SUITE_P( All, SocketInetReusePortTest, ::testing::Values( @@ -702,6 +1827,171 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, DualStackV6AnyReservesEverything) { ASSERT_THAT(bind(fd_v4.get(), reinterpret_cast<sockaddr*>(&addr_v4), test_addr_v4.addr_len), SyscallFailsWithErrno(EADDRINUSE)); + + // Verify that binding the v4 any on the same port with a v4 socket + // fails. + TestAddress const& test_addr_v4_any = V4Any(); + sockaddr_storage addr_v4_any = test_addr_v4_any.addr; + ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4_any.family(), &addr_v4_any, port)); + const FileDescriptor fd_v4_any = ASSERT_NO_ERRNO_AND_VALUE( + Socket(test_addr_v4_any.family(), param.type, 0)); + ASSERT_THAT(bind(fd_v4_any.get(), reinterpret_cast<sockaddr*>(&addr_v4_any), + test_addr_v4_any.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); +} + +TEST_P(SocketMultiProtocolInetLoopbackTest, + DualStackV6AnyReuseAddrDoesNotReserveV4Any) { + auto const& param = GetParam(); + + // Bind the v6 any on a dual stack socket. + TestAddress const& test_addr_dual = V6Any(); + sockaddr_storage addr_dual = test_addr_dual.addr; + const FileDescriptor fd_dual = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_dual.family(), param.type, 0)); + ASSERT_THAT(setsockopt(fd_dual.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast<sockaddr*>(&addr_dual), + test_addr_dual.addr_len), + SyscallSucceeds()); + + // Get the port that we bound. + socklen_t addrlen = test_addr_dual.addr_len; + ASSERT_THAT(getsockname(fd_dual.get(), + reinterpret_cast<sockaddr*>(&addr_dual), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr_dual.family(), addr_dual)); + + // Verify that binding the v4 any on the same port with a v4 socket succeeds. + TestAddress const& test_addr_v4_any = V4Any(); + sockaddr_storage addr_v4_any = test_addr_v4_any.addr; + ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4_any.family(), &addr_v4_any, port)); + const FileDescriptor fd_v4_any = ASSERT_NO_ERRNO_AND_VALUE( + Socket(test_addr_v4_any.family(), param.type, 0)); + ASSERT_THAT(setsockopt(fd_v4_any.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(fd_v4_any.get(), reinterpret_cast<sockaddr*>(&addr_v4_any), + test_addr_v4_any.addr_len), + SyscallSucceeds()); +} + +TEST_P(SocketMultiProtocolInetLoopbackTest, + DualStackV6AnyReuseAddrListenReservesV4Any) { + auto const& param = GetParam(); + + // Only TCP sockets are supported. + SKIP_IF((param.type & SOCK_STREAM) == 0); + + // Bind the v6 any on a dual stack socket. + TestAddress const& test_addr_dual = V6Any(); + sockaddr_storage addr_dual = test_addr_dual.addr; + const FileDescriptor fd_dual = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_dual.family(), param.type, 0)); + ASSERT_THAT(setsockopt(fd_dual.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast<sockaddr*>(&addr_dual), + test_addr_dual.addr_len), + SyscallSucceeds()); + + ASSERT_THAT(listen(fd_dual.get(), 5), SyscallSucceeds()); + + // Get the port that we bound. + socklen_t addrlen = test_addr_dual.addr_len; + ASSERT_THAT(getsockname(fd_dual.get(), + reinterpret_cast<sockaddr*>(&addr_dual), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr_dual.family(), addr_dual)); + + // Verify that binding the v4 any on the same port with a v4 socket succeeds. + TestAddress const& test_addr_v4_any = V4Any(); + sockaddr_storage addr_v4_any = test_addr_v4_any.addr; + ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4_any.family(), &addr_v4_any, port)); + const FileDescriptor fd_v4_any = ASSERT_NO_ERRNO_AND_VALUE( + Socket(test_addr_v4_any.family(), param.type, 0)); + ASSERT_THAT(setsockopt(fd_v4_any.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + ASSERT_THAT(bind(fd_v4_any.get(), reinterpret_cast<sockaddr*>(&addr_v4_any), + test_addr_v4_any.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); +} + +TEST_P(SocketMultiProtocolInetLoopbackTest, + DualStackV6AnyWithListenReservesEverything) { + auto const& param = GetParam(); + + // Only TCP sockets are supported. + SKIP_IF((param.type & SOCK_STREAM) == 0); + + // Bind the v6 any on a dual stack socket. + TestAddress const& test_addr_dual = V6Any(); + sockaddr_storage addr_dual = test_addr_dual.addr; + const FileDescriptor fd_dual = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_dual.family(), param.type, 0)); + ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast<sockaddr*>(&addr_dual), + test_addr_dual.addr_len), + SyscallSucceeds()); + + ASSERT_THAT(listen(fd_dual.get(), 5), SyscallSucceeds()); + + // Get the port that we bound. + socklen_t addrlen = test_addr_dual.addr_len; + ASSERT_THAT(getsockname(fd_dual.get(), + reinterpret_cast<sockaddr*>(&addr_dual), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr_dual.family(), addr_dual)); + + // Verify that binding the v6 loopback with the same port fails. + TestAddress const& test_addr_v6 = V6Loopback(); + sockaddr_storage addr_v6 = test_addr_v6.addr; + ASSERT_NO_ERRNO(SetAddrPort(test_addr_v6.family(), &addr_v6, port)); + const FileDescriptor fd_v6 = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v6.family(), param.type, 0)); + ASSERT_THAT(bind(fd_v6.get(), reinterpret_cast<sockaddr*>(&addr_v6), + test_addr_v6.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); + + // Verify that binding the v4 loopback on the same port with a v6 socket + // fails. + TestAddress const& test_addr_v4_mapped = V4MappedLoopback(); + sockaddr_storage addr_v4_mapped = test_addr_v4_mapped.addr; + ASSERT_NO_ERRNO( + SetAddrPort(test_addr_v4_mapped.family(), &addr_v4_mapped, port)); + const FileDescriptor fd_v4_mapped = ASSERT_NO_ERRNO_AND_VALUE( + Socket(test_addr_v4_mapped.family(), param.type, 0)); + ASSERT_THAT( + bind(fd_v4_mapped.get(), reinterpret_cast<sockaddr*>(&addr_v4_mapped), + test_addr_v4_mapped.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); + + // Verify that binding the v4 loopback on the same port with a v4 socket + // fails. + TestAddress const& test_addr_v4 = V4Loopback(); + sockaddr_storage addr_v4 = test_addr_v4.addr; + ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4.family(), &addr_v4, port)); + const FileDescriptor fd_v4 = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v4.family(), param.type, 0)); + ASSERT_THAT(bind(fd_v4.get(), reinterpret_cast<sockaddr*>(&addr_v4), + test_addr_v4.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); + + // Verify that binding the v4 any on the same port with a v4 socket + // fails. + TestAddress const& test_addr_v4_any = V4Any(); + sockaddr_storage addr_v4_any = test_addr_v4_any.addr; + ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4_any.family(), &addr_v4_any, port)); + const FileDescriptor fd_v4_any = ASSERT_NO_ERRNO_AND_VALUE( + Socket(test_addr_v4_any.family(), param.type, 0)); + ASSERT_THAT(bind(fd_v4_any.get(), reinterpret_cast<sockaddr*>(&addr_v4_any), + test_addr_v4_any.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); } TEST_P(SocketMultiProtocolInetLoopbackTest, V6OnlyV6AnyReservesV6) { @@ -713,10 +2003,9 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6OnlyV6AnyReservesV6) { sockaddr_storage addr_dual = test_addr_dual.addr; const FileDescriptor fd_dual = ASSERT_NO_ERRNO_AND_VALUE( Socket(test_addr_dual.family(), param.type, 0)); - int one = 1; - EXPECT_THAT( - setsockopt(fd_dual.get(), IPPROTO_IPV6, IPV6_V6ONLY, &one, sizeof(one)), - SyscallSucceeds()); + EXPECT_THAT(setsockopt(fd_dual.get(), IPPROTO_IPV6, IPV6_V6ONLY, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast<sockaddr*>(&addr_dual), test_addr_dual.addr_len), SyscallSucceeds()); @@ -764,9 +2053,6 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6OnlyV6AnyReservesV6) { TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReserved) { auto const& param = GetParam(); - // FIXME(b/114268588) - SKIP_IF(IsRunningOnGvisor() && param.type == SOCK_STREAM); - for (int i = 0; true; i++) { // Bind the v6 loopback on a dual stack socket. TestAddress const& test_addr = V6Loopback(); @@ -792,10 +2078,10 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReserved) { // Connect to bind an ephemeral port. const FileDescriptor connected_fd = ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - ASSERT_THAT( - connect(connected_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), - bound_addr_len), - SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(), + reinterpret_cast<sockaddr*>(&bound_addr), + bound_addr_len), + SyscallSucceeds()); // Get the ephemeral port. sockaddr_storage connected_addr = {}; @@ -829,17 +2115,6 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReserved) { test_addr_v6.addr_len), SyscallFailsWithErrno(EADDRINUSE)); - // Verify that binding the v4 any with the same port fails. - TestAddress const& test_addr_v4_any = V4Any(); - sockaddr_storage addr_v4_any = test_addr_v4_any.addr; - ASSERT_NO_ERRNO( - SetAddrPort(test_addr_v4_any.family(), &addr_v4_any, ephemeral_port)); - const FileDescriptor fd_v4_any = ASSERT_NO_ERRNO_AND_VALUE( - Socket(test_addr_v4_any.family(), param.type, 0)); - ASSERT_THAT(bind(fd_v4_any.get(), reinterpret_cast<sockaddr*>(&addr_v4_any), - test_addr_v4_any.addr_len), - SyscallFailsWithErrno(EADDRINUSE)); - // Verify that we can still bind the v4 loopback on the same port. TestAddress const& test_addr_v4_mapped = V4MappedLoopback(); sockaddr_storage addr_v4_mapped = test_addr_v4_mapped.addr; @@ -862,11 +2137,71 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReserved) { } } -TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedEphemeralPortReserved) { +TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReservedReuseAddr) { auto const& param = GetParam(); - // FIXME(b/114268588) - SKIP_IF(IsRunningOnGvisor() && param.type == SOCK_STREAM); + // Bind the v6 loopback on a dual stack socket. + TestAddress const& test_addr = V6Loopback(); + sockaddr_storage bound_addr = test_addr.addr; + const FileDescriptor bound_fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + ASSERT_THAT(bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), + test_addr.addr_len), + SyscallSucceeds()); + ASSERT_THAT(setsockopt(bound_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Listen iff TCP. + if (param.type == SOCK_STREAM) { + ASSERT_THAT(listen(bound_fd.get(), SOMAXCONN), SyscallSucceeds()); + } + + // Get the port that we bound. + socklen_t bound_addr_len = test_addr.addr_len; + ASSERT_THAT( + getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), + &bound_addr_len), + SyscallSucceeds()); + + // Connect to bind an ephemeral port. + const FileDescriptor connected_fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + ASSERT_THAT(setsockopt(connected_fd.get(), SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(), + reinterpret_cast<sockaddr*>(&bound_addr), + bound_addr_len), + SyscallSucceeds()); + + // Get the ephemeral port. + sockaddr_storage connected_addr = {}; + socklen_t connected_addr_len = sizeof(connected_addr); + ASSERT_THAT(getsockname(connected_fd.get(), + reinterpret_cast<sockaddr*>(&connected_addr), + &connected_addr_len), + SyscallSucceeds()); + uint16_t const ephemeral_port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr.family(), connected_addr)); + + // Verify that we actually got an ephemeral port. + ASSERT_NE(ephemeral_port, 0); + + // Verify that the ephemeral port is not reserved. + const FileDescriptor checking_fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + ASSERT_THAT(setsockopt(checking_fd.get(), SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + EXPECT_THAT( + bind(checking_fd.get(), reinterpret_cast<sockaddr*>(&connected_addr), + connected_addr_len), + SyscallSucceeds()); +} + +TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedEphemeralPortReserved) { + auto const& param = GetParam(); for (int i = 0; true; i++) { // Bind the v4 loopback on a dual stack socket. @@ -893,10 +2228,10 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedEphemeralPortReserved) { // Connect to bind an ephemeral port. const FileDescriptor connected_fd = ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - ASSERT_THAT( - connect(connected_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), - bound_addr_len), - SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(), + reinterpret_cast<sockaddr*>(&bound_addr), + bound_addr_len), + SyscallSucceeds()); // Get the ephemeral port. sockaddr_storage connected_addr = {}; @@ -965,9 +2300,8 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedEphemeralPortReserved) { // v6-only socket. const FileDescriptor fd_v6_only_any = ASSERT_NO_ERRNO_AND_VALUE( Socket(test_addr_v6_any.family(), param.type, 0)); - int one = 1; EXPECT_THAT(setsockopt(fd_v6_only_any.get(), IPPROTO_IPV6, IPV6_V6ONLY, - &one, sizeof(one)), + &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); ret = bind(fd_v6_only_any.get(), reinterpret_cast<sockaddr*>(&addr_v6_any), @@ -986,11 +2320,73 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedEphemeralPortReserved) { } } -TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) { +TEST_P(SocketMultiProtocolInetLoopbackTest, + V4MappedEphemeralPortReservedResueAddr) { auto const& param = GetParam(); - // FIXME(b/114268588) - SKIP_IF(IsRunningOnGvisor() && param.type == SOCK_STREAM); + // Bind the v4 loopback on a dual stack socket. + TestAddress const& test_addr = V4MappedLoopback(); + sockaddr_storage bound_addr = test_addr.addr; + const FileDescriptor bound_fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + ASSERT_THAT(bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), + test_addr.addr_len), + SyscallSucceeds()); + + ASSERT_THAT(setsockopt(bound_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Listen iff TCP. + if (param.type == SOCK_STREAM) { + ASSERT_THAT(listen(bound_fd.get(), SOMAXCONN), SyscallSucceeds()); + } + + // Get the port that we bound. + socklen_t bound_addr_len = test_addr.addr_len; + ASSERT_THAT( + getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), + &bound_addr_len), + SyscallSucceeds()); + + // Connect to bind an ephemeral port. + const FileDescriptor connected_fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + ASSERT_THAT(setsockopt(connected_fd.get(), SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(), + reinterpret_cast<sockaddr*>(&bound_addr), + bound_addr_len), + SyscallSucceeds()); + + // Get the ephemeral port. + sockaddr_storage connected_addr = {}; + socklen_t connected_addr_len = sizeof(connected_addr); + ASSERT_THAT(getsockname(connected_fd.get(), + reinterpret_cast<sockaddr*>(&connected_addr), + &connected_addr_len), + SyscallSucceeds()); + uint16_t const ephemeral_port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr.family(), connected_addr)); + + // Verify that we actually got an ephemeral port. + ASSERT_NE(ephemeral_port, 0); + + // Verify that the ephemeral port is not reserved. + const FileDescriptor checking_fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + ASSERT_THAT(setsockopt(checking_fd.get(), SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + EXPECT_THAT( + bind(checking_fd.get(), reinterpret_cast<sockaddr*>(&connected_addr), + connected_addr_len), + SyscallSucceeds()); +} + +TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) { + auto const& param = GetParam(); for (int i = 0; true; i++) { // Bind the v4 loopback on a v4 socket. @@ -1017,10 +2413,10 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) { // Connect to bind an ephemeral port. const FileDescriptor connected_fd = ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - ASSERT_THAT( - connect(connected_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), - bound_addr_len), - SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(), + reinterpret_cast<sockaddr*>(&bound_addr), + bound_addr_len), + SyscallSucceeds()); // Get the ephemeral port. sockaddr_storage connected_addr = {}; @@ -1090,9 +2486,8 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) { // v6-only socket. const FileDescriptor fd_v6_only_any = ASSERT_NO_ERRNO_AND_VALUE( Socket(test_addr_v6_any.family(), param.type, 0)); - int one = 1; EXPECT_THAT(setsockopt(fd_v6_only_any.get(), IPPROTO_IPV6, IPV6_V6ONLY, - &one, sizeof(one)), + &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); ret = bind(fd_v6_only_any.get(), reinterpret_cast<sockaddr*>(&addr_v6_any), @@ -1111,6 +2506,73 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) { } } +TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReservedReuseAddr) { + auto const& param = GetParam(); + + // Bind the v4 loopback on a v4 socket. + TestAddress const& test_addr = V4Loopback(); + sockaddr_storage bound_addr = test_addr.addr; + const FileDescriptor bound_fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + + ASSERT_THAT(setsockopt(bound_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + ASSERT_THAT(bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), + test_addr.addr_len), + SyscallSucceeds()); + + // Listen iff TCP. + if (param.type == SOCK_STREAM) { + ASSERT_THAT(listen(bound_fd.get(), SOMAXCONN), SyscallSucceeds()); + } + + // Get the port that we bound. + socklen_t bound_addr_len = test_addr.addr_len; + ASSERT_THAT( + getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), + &bound_addr_len), + SyscallSucceeds()); + + // Connect to bind an ephemeral port. + const FileDescriptor connected_fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + + ASSERT_THAT(setsockopt(connected_fd.get(), SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + + ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(), + reinterpret_cast<sockaddr*>(&bound_addr), + bound_addr_len), + SyscallSucceeds()); + + // Get the ephemeral port. + sockaddr_storage connected_addr = {}; + socklen_t connected_addr_len = sizeof(connected_addr); + ASSERT_THAT(getsockname(connected_fd.get(), + reinterpret_cast<sockaddr*>(&connected_addr), + &connected_addr_len), + SyscallSucceeds()); + uint16_t const ephemeral_port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr.family(), connected_addr)); + + // Verify that we actually got an ephemeral port. + ASSERT_NE(ephemeral_port, 0); + + // Verify that the ephemeral port is not reserved. + const FileDescriptor checking_fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + ASSERT_THAT(setsockopt(checking_fd.get(), SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + EXPECT_THAT( + bind(checking_fd.get(), reinterpret_cast<sockaddr*>(&connected_addr), + connected_addr_len), + SyscallSucceeds()); +} + TEST_P(SocketMultiProtocolInetLoopbackTest, PortReuseTwoSockets) { auto const& param = GetParam(); TestAddress const& test_addr = V4Loopback(); @@ -1148,7 +2610,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, PortReuseTwoSockets) { setsockopt(fd2, SOL_SOCKET, SO_REUSEPORT, &portreuse2, sizeof(int)), SyscallSucceeds()); - std::cout << portreuse1 << " " << portreuse2; + std::cout << portreuse1 << " " << portreuse2 << std::endl; int ret = bind(fd2, reinterpret_cast<sockaddr*>(&addr), addrlen); // Verify that two sockets can be bound to the same port only if @@ -1197,7 +2659,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, NoReusePortFollowingReusePort) { } INSTANTIATE_TEST_SUITE_P( - AllFamlies, SocketMultiProtocolInetLoopbackTest, + AllFamilies, SocketMultiProtocolInetLoopbackTest, ::testing::Values(ProtocolTestParam{"TCP", SOCK_STREAM}, ProtocolTestParam{"UDP", SOCK_DGRAM}), DescribeProtocolTestParam); diff --git a/test/syscalls/linux/socket_inet_loopback_nogotsan.cc b/test/syscalls/linux/socket_inet_loopback_nogotsan.cc new file mode 100644 index 000000000..791e2bd51 --- /dev/null +++ b/test/syscalls/linux/socket_inet_loopback_nogotsan.cc @@ -0,0 +1,174 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <arpa/inet.h> +#include <netinet/in.h> +#include <netinet/tcp.h> +#include <string.h> + +#include <iostream> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/strings/str_cat.h" +#include "test/syscalls/linux/ip_socket_test_util.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/posix_error.h" +#include "test/util/save_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +using ::testing::Gt; + +PosixErrorOr<uint16_t> AddrPort(int family, sockaddr_storage const& addr) { + switch (family) { + case AF_INET: + return static_cast<uint16_t>( + reinterpret_cast<sockaddr_in const*>(&addr)->sin_port); + case AF_INET6: + return static_cast<uint16_t>( + reinterpret_cast<sockaddr_in6 const*>(&addr)->sin6_port); + default: + return PosixError(EINVAL, + absl::StrCat("unknown socket family: ", family)); + } +} + +PosixError SetAddrPort(int family, sockaddr_storage* addr, uint16_t port) { + switch (family) { + case AF_INET: + reinterpret_cast<sockaddr_in*>(addr)->sin_port = port; + return NoError(); + case AF_INET6: + reinterpret_cast<sockaddr_in6*>(addr)->sin6_port = port; + return NoError(); + default: + return PosixError(EINVAL, + absl::StrCat("unknown socket family: ", family)); + } +} + +struct TestParam { + TestAddress listener; + TestAddress connector; +}; + +std::string DescribeTestParam(::testing::TestParamInfo<TestParam> const& info) { + return absl::StrCat("Listen", info.param.listener.description, "_Connect", + info.param.connector.description); +} + +using SocketInetLoopbackTest = ::testing::TestWithParam<TestParam>; + +// This test verifies that connect returns EADDRNOTAVAIL if all local ephemeral +// ports are already in use for a given destination ip/port. +// +// We disable S/R because this test creates a large number of sockets. +// +// FIXME(b/162475855): This test is failing reliably. +TEST_P(SocketInetLoopbackTest, DISABLED_TestTCPPortExhaustion_NoRandomSave) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + constexpr int kBacklog = 10; + constexpr int kClients = 65536; + + // Create the listening socket. + auto listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Disable cooperative S/R as we are making too many syscalls. + DisableSave ds; + + // Now we keep opening connections till we run out of local ephemeral ports. + // and assert the error we get back. + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + std::vector<FileDescriptor> clients; + std::vector<FileDescriptor> servers; + + for (int i = 0; i < kClients; i++) { + FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len); + if (ret == 0) { + clients.push_back(std::move(client)); + FileDescriptor server = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + servers.push_back(std::move(server)); + continue; + } + ASSERT_THAT(ret, SyscallFailsWithErrno(EADDRNOTAVAIL)); + break; + } +} + +INSTANTIATE_TEST_SUITE_P( + All, SocketInetLoopbackTest, + ::testing::Values( + // Listeners bound to IPv4 addresses refuse connections using IPv6 + // addresses. + TestParam{V4Any(), V4Any()}, TestParam{V4Any(), V4Loopback()}, + TestParam{V4Any(), V4MappedAny()}, + TestParam{V4Any(), V4MappedLoopback()}, + TestParam{V4Loopback(), V4Any()}, TestParam{V4Loopback(), V4Loopback()}, + TestParam{V4Loopback(), V4MappedLoopback()}, + TestParam{V4MappedAny(), V4Any()}, + TestParam{V4MappedAny(), V4Loopback()}, + TestParam{V4MappedAny(), V4MappedAny()}, + TestParam{V4MappedAny(), V4MappedLoopback()}, + TestParam{V4MappedLoopback(), V4Any()}, + TestParam{V4MappedLoopback(), V4Loopback()}, + TestParam{V4MappedLoopback(), V4MappedLoopback()}, + + // Listeners bound to IN6ADDR_ANY accept all connections. + TestParam{V6Any(), V4Any()}, TestParam{V6Any(), V4Loopback()}, + TestParam{V6Any(), V4MappedAny()}, + TestParam{V6Any(), V4MappedLoopback()}, TestParam{V6Any(), V6Any()}, + TestParam{V6Any(), V6Loopback()}, + + // Listeners bound to IN6ADDR_LOOPBACK refuse connections using IPv4 + // addresses. + TestParam{V6Loopback(), V6Any()}, + TestParam{V6Loopback(), V6Loopback()}), + DescribeTestParam); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_loopback_blocking.cc b/test/syscalls/linux/socket_ip_loopback_blocking.cc index d7fc9715b..fda252dd7 100644 --- a/test/syscalls/linux/socket_ip_loopback_blocking.cc +++ b/test/syscalls/linux/socket_ip_loopback_blocking.cc @@ -13,6 +13,7 @@ // limitations under the License. #include <netinet/tcp.h> + #include <vector> #include "test/syscalls/linux/ip_socket_test_util.h" @@ -22,6 +23,7 @@ namespace gvisor { namespace testing { +namespace { std::vector<SocketPairKind> GetSocketPairs() { return VecCat<SocketPairKind>( @@ -42,5 +44,6 @@ INSTANTIATE_TEST_SUITE_P( BlockingIPSockets, BlockingSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); +} // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc index 7e0deda05..53c076787 100644 --- a/test/syscalls/linux/socket_ip_tcp_generic.cc +++ b/test/syscalls/linux/socket_ip_tcp_generic.cc @@ -24,13 +24,20 @@ #include <sys/un.h> #include "gtest/gtest.h" +#include "absl/memory/memory.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/test_util.h" +#include "test/util/thread_util.h" namespace gvisor { namespace testing { -TEST_P(TCPSocketPairTest, TcpInfoSucceedes) { +using ::testing::AnyOf; +using ::testing::Eq; + +TEST_P(TCPSocketPairTest, TcpInfoSucceeds) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); struct tcp_info opt = {}; @@ -39,7 +46,7 @@ TEST_P(TCPSocketPairTest, TcpInfoSucceedes) { SyscallSucceeds()); } -TEST_P(TCPSocketPairTest, ShortTcpInfoSucceedes) { +TEST_P(TCPSocketPairTest, ShortTcpInfoSucceeds) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); struct tcp_info opt = {}; @@ -48,7 +55,7 @@ TEST_P(TCPSocketPairTest, ShortTcpInfoSucceedes) { SyscallSucceeds()); } -TEST_P(TCPSocketPairTest, ZeroTcpInfoSucceedes) { +TEST_P(TCPSocketPairTest, ZeroTcpInfoSucceeds) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); struct tcp_info opt = {}; @@ -243,6 +250,31 @@ TEST_P(TCPSocketPairTest, ShutdownRdAllowsReadOfReceivedDataBeforeEOF) { SyscallSucceedsWithValue(0)); } +// This test verifies that a shutdown(wr) by the server after sending +// data allows the client to still read() the queued data and a client +// close after sending response allows server to read the incoming +// response. +TEST_P(TCPSocketPairTest, ShutdownWrServerClientClose) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + char buf[10] = {}; + ScopedThread t([&]() { + ASSERT_THAT(RetryEINTR(read)(sockets->first_fd(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + ASSERT_THAT(RetryEINTR(write)(sockets->first_fd(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + ASSERT_THAT(close(sockets->release_first_fd()), + SyscallSucceedsWithValue(0)); + }); + ASSERT_THAT(RetryEINTR(write)(sockets->second_fd(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + ASSERT_THAT(RetryEINTR(shutdown)(sockets->second_fd(), SHUT_WR), + SyscallSucceedsWithValue(0)); + t.Join(); + + ASSERT_THAT(RetryEINTR(read)(sockets->second_fd(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); +} + TEST_P(TCPSocketPairTest, ClosedReadNonBlockingSocket) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); @@ -495,6 +527,7 @@ TEST_P(TCPSocketPairTest, SetTCPKeepintvlZero) { // Copied from include/net/tcp.h. constexpr int MAX_TCP_KEEPIDLE = 32767; constexpr int MAX_TCP_KEEPINTVL = 32767; +constexpr int MAX_TCP_KEEPCNT = 127; TEST_P(TCPSocketPairTest, SetTCPKeepidleAboveMax) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); @@ -546,6 +579,78 @@ TEST_P(TCPSocketPairTest, SetTCPKeepintvlToMax) { EXPECT_EQ(get, MAX_TCP_KEEPINTVL); } +TEST_P(TCPSocketPairTest, TCPKeepcountDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, 9); // 9 keepalive probes. +} + +TEST_P(TCPSocketPairTest, SetTCPKeepcountZero) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + constexpr int kZero = 0; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT, &kZero, + sizeof(kZero)), + SyscallFailsWithErrno(EINVAL)); +} + +TEST_P(TCPSocketPairTest, SetTCPKeepcountAboveMax) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + constexpr int kAboveMax = MAX_TCP_KEEPCNT + 1; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT, + &kAboveMax, sizeof(kAboveMax)), + SyscallFailsWithErrno(EINVAL)); +} + +TEST_P(TCPSocketPairTest, SetTCPKeepcountToMax) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT, + &MAX_TCP_KEEPCNT, sizeof(MAX_TCP_KEEPCNT)), + SyscallSucceedsWithValue(0)); + + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, MAX_TCP_KEEPCNT); +} + +TEST_P(TCPSocketPairTest, SetTCPKeepcountToOne) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int keepaliveCount = 1; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT, + &keepaliveCount, sizeof(keepaliveCount)), + SyscallSucceedsWithValue(0)); + + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, keepaliveCount); +} + +TEST_P(TCPSocketPairTest, SetTCPKeepcountToNegative) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int keepaliveCount = -5; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT, + &keepaliveCount, sizeof(keepaliveCount)), + SyscallFailsWithErrno(EINVAL)); +} + TEST_P(TCPSocketPairTest, SetOOBInline) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); @@ -696,5 +801,265 @@ TEST_P(TCPSocketPairTest, SetCongestionControlFailsForUnsupported) { EXPECT_EQ(0, memcmp(got_cc, old_cc, sizeof(old_cc))); } +// Linux and Netstack both default to a 60s TCP_LINGER2 timeout. +constexpr int kDefaultTCPLingerTimeout = 60; +// On Linux, the maximum linger2 timeout was changed from 60sec to 120sec. +constexpr int kMaxTCPLingerTimeout = 120; +constexpr int kOldMaxTCPLingerTimeout = 60; + +TEST_P(TCPSocketPairTest, TCPLingerTimeoutDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kDefaultTCPLingerTimeout); +} + +TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutZeroOrLess) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + constexpr int kZero = 0; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &kZero, + sizeof(kZero)), + SyscallSucceedsWithValue(0)); + + constexpr int kNegative = -1234; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, + &kNegative, sizeof(kNegative)), + SyscallSucceedsWithValue(0)); +} + +TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutAboveMax) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Values above the net.ipv4.tcp_fin_timeout are capped to tcp_fin_timeout + // on linux (defaults to 60 seconds on linux). + constexpr int kAboveDefault = kMaxTCPLingerTimeout + 1; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, + &kAboveDefault, sizeof(kAboveDefault)), + SyscallSucceedsWithValue(0)); + + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + if (IsRunningOnGvisor()) { + EXPECT_EQ(get, kMaxTCPLingerTimeout); + } else { + EXPECT_THAT(get, + AnyOf(Eq(kMaxTCPLingerTimeout), Eq(kOldMaxTCPLingerTimeout))); + } +} + +TEST_P(TCPSocketPairTest, SetTCPLingerTimeout) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Values above the net.ipv4.tcp_fin_timeout are capped to tcp_fin_timeout + // on linux (defaults to 60 seconds on linux). + constexpr int kTCPLingerTimeout = kDefaultTCPLingerTimeout - 1; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, + &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)), + SyscallSucceedsWithValue(0)); + + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kTCPLingerTimeout); +} + +TEST_P(TCPSocketPairTest, TestTCPCloseWithData) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + ScopedThread t([&]() { + // Close one end to trigger sending of a FIN. + ASSERT_THAT(shutdown(sockets->second_fd(), SHUT_WR), SyscallSucceeds()); + char buf[3]; + ASSERT_THAT(read(sockets->second_fd(), buf, 3), + SyscallSucceedsWithValue(3)); + absl::SleepFor(absl::Milliseconds(50)); + ASSERT_THAT(close(sockets->release_second_fd()), SyscallSucceeds()); + }); + + absl::SleepFor(absl::Milliseconds(50)); + // Send some data then close. + constexpr char kStr[] = "abc"; + ASSERT_THAT(write(sockets->first_fd(), kStr, 3), SyscallSucceedsWithValue(3)); + t.Join(); + ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds()); +} + +TEST_P(TCPSocketPairTest, TCPUserTimeoutDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &get, &get_len), + SyscallSucceeds()); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, 0); // 0 ms (disabled). +} + +TEST_P(TCPSocketPairTest, SetTCPUserTimeoutZero) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + constexpr int kZero = 0; + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &kZero, sizeof(kZero)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &get, &get_len), + SyscallSucceeds()); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, 0); // 0 ms (disabled). +} + +TEST_P(TCPSocketPairTest, SetTCPUserTimeoutBelowZero) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + constexpr int kNeg = -10; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &kNeg, sizeof(kNeg)), + SyscallFailsWithErrno(EINVAL)); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &get, &get_len), + SyscallSucceeds()); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, 0); // 0 ms (disabled). +} + +TEST_P(TCPSocketPairTest, SetTCPUserTimeoutAboveZero) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + constexpr int kAbove = 10; + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &kAbove, sizeof(kAbove)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &get, &get_len), + SyscallSucceeds()); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kAbove); +} + +TEST_P(TCPSocketPairTest, SetTCPWindowClampBelowMinRcvBufConnectedSocket) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + // Discover minimum receive buf by setting a really low value + // for the receive buffer. + constexpr int kZero = 0; + EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVBUF, &kZero, + sizeof(kZero)), + SyscallSucceeds()); + + // Now retrieve the minimum value for SO_RCVBUF as the set above should + // have caused SO_RCVBUF for the socket to be set to the minimum. + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVBUF, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + int min_so_rcvbuf = get; + + { + // Setting TCP_WINDOW_CLAMP to zero for a connected socket is not permitted. + constexpr int kZero = 0; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_WINDOW_CLAMP, + &kZero, sizeof(kZero)), + SyscallFailsWithErrno(EINVAL)); + + // Non-zero clamp values below MIN_SO_RCVBUF/2 should result in the clamp + // being set to MIN_SO_RCVBUF/2. + int below_half_min_so_rcvbuf = min_so_rcvbuf / 2 - 1; + EXPECT_THAT( + setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_WINDOW_CLAMP, + &below_half_min_so_rcvbuf, sizeof(below_half_min_so_rcvbuf)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_WINDOW_CLAMP, + &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(min_so_rcvbuf / 2, get); + } +} + +TEST_P(TCPSocketPairTest, IpMulticastTtlDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL, + &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_GT(get, 0); +} + +TEST_P(TCPSocketPairTest, IpMulticastLoopDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, + &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, 1); +} + +TEST_P(TCPSocketPairTest, TCPResetDuringClose_NoRandomSave) { + DisableSave ds; // Too many syscalls. + constexpr int kThreadCount = 1000; + std::unique_ptr<ScopedThread> instances[kThreadCount]; + for (int i = 0; i < kThreadCount; i++) { + instances[i] = absl::make_unique<ScopedThread>([&]() { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + ScopedThread t([&]() { + // Close one end to trigger sending of a FIN. + struct pollfd poll_fd = {sockets->second_fd(), POLLIN | POLLHUP, 0}; + // Wait up to 20 seconds for the data. + constexpr int kPollTimeoutMs = 20000; + ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs), + SyscallSucceedsWithValue(1)); + ASSERT_THAT(close(sockets->release_second_fd()), SyscallSucceeds()); + }); + + // Send some data then close. + constexpr char kStr[] = "abc"; + ASSERT_THAT(write(sockets->first_fd(), kStr, 3), + SyscallSucceedsWithValue(3)); + absl::SleepFor(absl::Milliseconds(10)); + ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds()); + t.Join(); + }); + } + for (int i = 0; i < kThreadCount; i++) { + instances[i]->Join(); + } +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_tcp_generic_loopback.cc b/test/syscalls/linux/socket_ip_tcp_generic_loopback.cc index 0dc274e2d..4e79d21f4 100644 --- a/test/syscalls/linux/socket_ip_tcp_generic_loopback.cc +++ b/test/syscalls/linux/socket_ip_tcp_generic_loopback.cc @@ -13,6 +13,7 @@ // limitations under the License. #include <netinet/tcp.h> + #include <vector> #include "test/syscalls/linux/ip_socket_test_util.h" @@ -22,6 +23,7 @@ namespace gvisor { namespace testing { +namespace { std::vector<SocketPairKind> GetSocketPairs() { return ApplyVecToVec<SocketPairKind>( @@ -38,5 +40,6 @@ INSTANTIATE_TEST_SUITE_P( AllTCPSockets, TCPSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); +} // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_tcp_loopback.cc b/test/syscalls/linux/socket_ip_tcp_loopback.cc index 831de53b8..9db3037bc 100644 --- a/test/syscalls/linux/socket_ip_tcp_loopback.cc +++ b/test/syscalls/linux/socket_ip_tcp_loopback.cc @@ -21,6 +21,7 @@ namespace gvisor { namespace testing { +namespace { std::vector<SocketPairKind> GetSocketPairs() { return { @@ -34,5 +35,6 @@ INSTANTIATE_TEST_SUITE_P( AllUnixDomainSockets, AllSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); +} // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_tcp_loopback_blocking.cc b/test/syscalls/linux/socket_ip_tcp_loopback_blocking.cc index cd3ad97d0..f996b93d2 100644 --- a/test/syscalls/linux/socket_ip_tcp_loopback_blocking.cc +++ b/test/syscalls/linux/socket_ip_tcp_loopback_blocking.cc @@ -13,6 +13,7 @@ // limitations under the License. #include <netinet/tcp.h> + #include <vector> #include "test/syscalls/linux/ip_socket_test_util.h" @@ -22,6 +23,7 @@ namespace gvisor { namespace testing { +namespace { std::vector<SocketPairKind> GetSocketPairs() { return ApplyVecToVec<SocketPairKind>( @@ -38,5 +40,6 @@ INSTANTIATE_TEST_SUITE_P( BlockingTCPSockets, BlockingStreamSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); +} // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_tcp_loopback_nonblock.cc b/test/syscalls/linux/socket_ip_tcp_loopback_nonblock.cc index 1acdecc17..ffa377210 100644 --- a/test/syscalls/linux/socket_ip_tcp_loopback_nonblock.cc +++ b/test/syscalls/linux/socket_ip_tcp_loopback_nonblock.cc @@ -13,6 +13,7 @@ // limitations under the License. #include <netinet/tcp.h> + #include <vector> #include "test/syscalls/linux/ip_socket_test_util.h" @@ -22,6 +23,7 @@ namespace gvisor { namespace testing { +namespace { std::vector<SocketPairKind> GetSocketPairs() { return ApplyVecToVec<SocketPairKind>( @@ -37,5 +39,6 @@ INSTANTIATE_TEST_SUITE_P( NonBlockingTCPSockets, NonBlockingSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); +} // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc index 2a4ed04a5..edb86aded 100644 --- a/test/syscalls/linux/socket_ip_udp_generic.cc +++ b/test/syscalls/linux/socket_ip_udp_generic.cc @@ -14,6 +14,7 @@ #include "test/syscalls/linux/socket_ip_udp_generic.h" +#include <errno.h> #include <netinet/in.h> #include <netinet/tcp.h> #include <poll.h> @@ -35,7 +36,7 @@ TEST_P(UDPSocketPairTest, MulticastTTLDefault) { int get = -1; socklen_t get_len = sizeof(get); - EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL, + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL, &get, &get_len), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_len, sizeof(get)); @@ -52,7 +53,7 @@ TEST_P(UDPSocketPairTest, SetUDPMulticastTTLMin) { int get = -1; socklen_t get_len = sizeof(get); - EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL, + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL, &get, &get_len), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_len, sizeof(get)); @@ -69,7 +70,7 @@ TEST_P(UDPSocketPairTest, SetUDPMulticastTTLMax) { int get = -1; socklen_t get_len = sizeof(get); - EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL, + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL, &get, &get_len), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_len, sizeof(get)); @@ -91,7 +92,7 @@ TEST_P(UDPSocketPairTest, SetUDPMulticastTTLNegativeOne) { int get = -1; socklen_t get_len = sizeof(get); - EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL, + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL, &get, &get_len), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_len, sizeof(get)); @@ -126,7 +127,7 @@ TEST_P(UDPSocketPairTest, SetUDPMulticastTTLChar) { int get = -1; socklen_t get_len = sizeof(get); - EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL, + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL, &get, &get_len), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_len, sizeof(get)); @@ -147,7 +148,7 @@ TEST_P(UDPSocketPairTest, MulticastLoopDefault) { int get = -1; socklen_t get_len = sizeof(get); - EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, &get, &get_len), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_len, sizeof(get)); @@ -163,7 +164,7 @@ TEST_P(UDPSocketPairTest, SetMulticastLoop) { int get = -1; socklen_t get_len = sizeof(get); - EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, &get, &get_len), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_len, sizeof(get)); @@ -173,7 +174,7 @@ TEST_P(UDPSocketPairTest, SetMulticastLoop) { &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, &get, &get_len), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_len, sizeof(get)); @@ -192,7 +193,7 @@ TEST_P(UDPSocketPairTest, SetMulticastLoopChar) { int get = -1; socklen_t get_len = sizeof(get); - EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, &get, &get_len), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_len, sizeof(get)); @@ -202,12 +203,250 @@ TEST_P(UDPSocketPairTest, SetMulticastLoopChar) { &kSockOptOnChar, sizeof(kSockOptOnChar)), SyscallSucceeds()); - EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, &get, &get_len), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_len, sizeof(get)); EXPECT_EQ(get, kSockOptOn); } +TEST_P(UDPSocketPairTest, ReuseAddrDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOff); +} + +TEST_P(UDPSocketPairTest, SetReuseAddr) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOn); + + ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR, + &kSockOptOff, sizeof(kSockOptOff)), + SyscallSucceeds()); + + ASSERT_THAT( + getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOff); +} + +TEST_P(UDPSocketPairTest, ReusePortDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEPORT, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOff); +} + +TEST_P(UDPSocketPairTest, SetReusePort) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEPORT, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEPORT, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOn); + + ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEPORT, + &kSockOptOff, sizeof(kSockOptOff)), + SyscallSucceeds()); + + ASSERT_THAT( + getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEPORT, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOff); +} + +TEST_P(UDPSocketPairTest, SetReuseAddrReusePort) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + + ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEPORT, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOn); + + ASSERT_THAT( + getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEPORT, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOn); +} + +// Test getsockopt for a socket which is not set with IP_PKTINFO option. +TEST_P(UDPSocketPairTest, IPPKTINFODefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int get = -1; + socklen_t get_len = sizeof(get); + + ASSERT_THAT( + getsockopt(sockets->first_fd(), SOL_IP, IP_PKTINFO, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOff); +} + +// Test setsockopt and getsockopt for a socket with IP_PKTINFO option. +TEST_P(UDPSocketPairTest, SetAndGetIPPKTINFO) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int level = SOL_IP; + int type = IP_PKTINFO; + + // Check getsockopt before IP_PKTINFO is set. + int get = -1; + socklen_t get_len = sizeof(get); + + ASSERT_THAT(setsockopt(sockets->first_fd(), level, type, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)); + + ASSERT_THAT(getsockopt(sockets->first_fd(), level, type, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get, kSockOptOn); + EXPECT_EQ(get_len, sizeof(get)); + + ASSERT_THAT(setsockopt(sockets->first_fd(), level, type, &kSockOptOff, + sizeof(kSockOptOff)), + SyscallSucceedsWithValue(0)); + + ASSERT_THAT(getsockopt(sockets->first_fd(), level, type, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get, kSockOptOff); + EXPECT_EQ(get_len, sizeof(get)); +} + +// Holds TOS or TClass information for IPv4 or IPv6 respectively. +struct RecvTosOption { + int level; + int option; +}; + +RecvTosOption GetRecvTosOption(int domain) { + TEST_CHECK(domain == AF_INET || domain == AF_INET6); + RecvTosOption opt; + switch (domain) { + case AF_INET: + opt.level = IPPROTO_IP; + opt.option = IP_RECVTOS; + break; + case AF_INET6: + opt.level = IPPROTO_IPV6; + opt.option = IPV6_RECVTCLASS; + break; + } + return opt; +} + +// Ensure that Receiving TOS or TCLASS is off by default. +TEST_P(UDPSocketPairTest, RecvTosDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + RecvTosOption t = GetRecvTosOption(GetParam().domain); + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOff); +} + +// Test that setting and getting IP_RECVTOS or IPV6_RECVTCLASS works as +// expected. +TEST_P(UDPSocketPairTest, SetRecvTos) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + RecvTosOption t = GetRecvTosOption(GetParam().domain); + + ASSERT_THAT(setsockopt(sockets->first_fd(), t.level, t.option, &kSockOptOff, + sizeof(kSockOptOff)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOff); + + ASSERT_THAT(setsockopt(sockets->first_fd(), t.level, t.option, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + ASSERT_THAT( + getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOn); +} + +// Test that any socket (including IPv6 only) accepts the IPv4 TOS option: this +// mirrors behavior in linux. +TEST_P(UDPSocketPairTest, TOSRecvMismatch) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + RecvTosOption t = GetRecvTosOption(AF_INET); + int get = -1; + socklen_t get_len = sizeof(get); + + ASSERT_THAT( + getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len), + SyscallSucceedsWithValue(0)); +} + +// Test that an IPv4 socket does not support the IPv6 TClass option. +TEST_P(UDPSocketPairTest, TClassRecvMismatch) { + // This should only test AF_INET sockets for the mismatch behavior. + SKIP_IF(GetParam().domain != AF_INET); + + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int get = -1; + socklen_t get_len = sizeof(get); + + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IPV6, IPV6_RECVTCLASS, + &get, &get_len), + SyscallFailsWithErrno(EOPNOTSUPP)); +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_udp_loopback.cc b/test/syscalls/linux/socket_ip_udp_loopback.cc index 1df74a348..c7fa44884 100644 --- a/test/syscalls/linux/socket_ip_udp_loopback.cc +++ b/test/syscalls/linux/socket_ip_udp_loopback.cc @@ -23,6 +23,7 @@ namespace gvisor { namespace testing { +namespace { std::vector<SocketPairKind> GetSocketPairs() { return { @@ -44,5 +45,6 @@ INSTANTIATE_TEST_SUITE_P( AllUDPSockets, UDPSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); +} // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_udp_loopback_blocking.cc b/test/syscalls/linux/socket_ip_udp_loopback_blocking.cc index 1e259efa7..d6925a8df 100644 --- a/test/syscalls/linux/socket_ip_udp_loopback_blocking.cc +++ b/test/syscalls/linux/socket_ip_udp_loopback_blocking.cc @@ -21,6 +21,7 @@ namespace gvisor { namespace testing { +namespace { std::vector<SocketPairKind> GetSocketPairs() { return { @@ -33,5 +34,6 @@ INSTANTIATE_TEST_SUITE_P( BlockingUDPSockets, BlockingNonStreamSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); +} // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_udp_loopback_nonblock.cc b/test/syscalls/linux/socket_ip_udp_loopback_nonblock.cc index 74cbd326d..d675eddc6 100644 --- a/test/syscalls/linux/socket_ip_udp_loopback_nonblock.cc +++ b/test/syscalls/linux/socket_ip_udp_loopback_nonblock.cc @@ -21,6 +21,7 @@ namespace gvisor { namespace testing { +namespace { std::vector<SocketPairKind> GetSocketPairs() { return { @@ -33,5 +34,6 @@ INSTANTIATE_TEST_SUITE_P( NonBlockingUDPSockets, NonBlockingSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); +} // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_unbound.cc b/test/syscalls/linux/socket_ip_unbound.cc index b02872308..1c7b0cf90 100644 --- a/test/syscalls/linux/socket_ip_unbound.cc +++ b/test/syscalls/linux/socket_ip_unbound.cc @@ -40,7 +40,7 @@ TEST_P(IPUnboundSocketTest, TtlDefault) { socklen_t get_sz = sizeof(get); EXPECT_THAT(getsockopt(socket->get(), IPPROTO_IP, IP_TTL, &get, &get_sz), SyscallSucceedsWithValue(0)); - EXPECT_EQ(get, 64); + EXPECT_TRUE(get == 64 || get == 127); EXPECT_EQ(get_sz, sizeof(get)); } @@ -129,6 +129,7 @@ TEST_P(IPUnboundSocketTest, InvalidNegativeTtl) { struct TOSOption { int level; int option; + int cmsg_level; }; constexpr int INET_ECN_MASK = 3; @@ -139,10 +140,12 @@ static TOSOption GetTOSOption(int domain) { case AF_INET: opt.level = IPPROTO_IP; opt.option = IP_TOS; + opt.cmsg_level = SOL_IP; break; case AF_INET6: opt.level = IPPROTO_IPV6; opt.option = IPV6_TCLASS; + opt.cmsg_level = SOL_IPV6; break; } return opt; @@ -154,7 +157,7 @@ TEST_P(IPUnboundSocketTest, TOSDefault) { int get = -1; socklen_t get_sz = sizeof(get); constexpr int kDefaultTOS = 0; - EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), + ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_sz, sizeof(get)); EXPECT_EQ(get, kDefaultTOS); @@ -170,7 +173,7 @@ TEST_P(IPUnboundSocketTest, SetTOS) { int get = -1; socklen_t get_sz = sizeof(get); - EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), + ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_sz, sizeof(get)); EXPECT_EQ(get, set); @@ -185,7 +188,7 @@ TEST_P(IPUnboundSocketTest, ZeroTOS) { SyscallSucceedsWithValue(0)); int get = -1; socklen_t get_sz = sizeof(get); - EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), + ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_sz, sizeof(get)); EXPECT_EQ(get, set); @@ -207,7 +210,7 @@ TEST_P(IPUnboundSocketTest, InvalidLargeTOS) { } int get = -1; socklen_t get_sz = sizeof(get); - EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), + ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_sz, sizeof(get)); EXPECT_EQ(get, kDefaultTOS); @@ -226,7 +229,7 @@ TEST_P(IPUnboundSocketTest, CheckSkipECN) { } int get = -1; socklen_t get_sz = sizeof(get); - EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), + ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_sz, sizeof(get)); EXPECT_EQ(get, expect); @@ -246,7 +249,7 @@ TEST_P(IPUnboundSocketTest, ZeroTOSOptionSize) { } int get = -1; socklen_t get_sz = 0; - EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), + ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_sz, 0); EXPECT_EQ(get, -1); @@ -273,7 +276,7 @@ TEST_P(IPUnboundSocketTest, SmallTOSOptionSize) { } uint get = -1; socklen_t get_sz = i; - EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), + ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_sz, expect_sz); // Account for partial copies by getsockopt, retrieve the lower @@ -294,7 +297,7 @@ TEST_P(IPUnboundSocketTest, LargeTOSOptionSize) { // We expect the system call handler to only copy atmost sizeof(int) bytes // as asserted by the check below. Hence, we do not expect the copy to // overflow in getsockopt. - EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), + ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_sz, sizeof(int)); EXPECT_EQ(get, set); @@ -322,7 +325,7 @@ TEST_P(IPUnboundSocketTest, NegativeTOS) { } int get = -1; socklen_t get_sz = sizeof(get); - EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), + ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_sz, sizeof(get)); EXPECT_EQ(get, expect); @@ -335,25 +338,118 @@ TEST_P(IPUnboundSocketTest, InvalidNegativeTOS) { TOSOption t = GetTOSOption(GetParam().domain); int expect; if (GetParam().domain == AF_INET) { - EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz), + ASSERT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz), SyscallSucceedsWithValue(0)); expect = static_cast<uint8_t>(set); if (GetParam().protocol == IPPROTO_TCP) { expect &= ~INET_ECN_MASK; } } else { - EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz), + ASSERT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz), SyscallFailsWithErrno(EINVAL)); expect = 0; } int get = 0; socklen_t get_sz = sizeof(get); - EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), + ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_sz, sizeof(get)); EXPECT_EQ(get, expect); } +TEST_P(IPUnboundSocketTest, NullTOS) { + auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + TOSOption t = GetTOSOption(GetParam().domain); + int set_sz = sizeof(int); + if (GetParam().domain == AF_INET) { + EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, nullptr, set_sz), + SyscallFailsWithErrno(EFAULT)); + } else { // AF_INET6 + // The AF_INET6 behavior is not yet compatible. gVisor will try to read + // optval from user memory at syscall handler, it needs substantial + // refactoring to implement this behavior just for IPv6. + if (IsRunningOnGvisor()) { + EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, nullptr, set_sz), + SyscallFailsWithErrno(EFAULT)); + } else { + // Linux's IPv6 stack treats nullptr optval as input of 0, so the call + // succeeds. (net/ipv6/ipv6_sockglue.c, do_ipv6_setsockopt()) + // + // Linux's implementation would need fixing as passing a nullptr as optval + // and non-zero optlen may not be valid. + // TODO(b/158666797): Combine the gVisor and linux cases for IPv6. + // Some kernel versions return EFAULT, so we handle both. + EXPECT_THAT( + setsockopt(socket->get(), t.level, t.option, nullptr, set_sz), + AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(0))); + } + } + socklen_t get_sz = sizeof(int); + EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, nullptr, &get_sz), + SyscallFailsWithErrno(EFAULT)); + int get = -1; + EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, nullptr), + SyscallFailsWithErrno(EFAULT)); +} + +TEST_P(IPUnboundSocketTest, InsufficientBufferTOS) { + SKIP_IF(GetParam().protocol == IPPROTO_TCP); + + auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + TOSOption t = GetTOSOption(GetParam().domain); + + in_addr addr4; + in6_addr addr6; + ASSERT_THAT(inet_pton(AF_INET, "127.0.0.1", &addr4), ::testing::Eq(1)); + ASSERT_THAT(inet_pton(AF_INET6, "fe80::", &addr6), ::testing::Eq(1)); + + cmsghdr cmsg = {}; + cmsg.cmsg_len = sizeof(cmsg); + cmsg.cmsg_level = t.cmsg_level; + cmsg.cmsg_type = t.option; + + msghdr msg = {}; + msg.msg_control = &cmsg; + msg.msg_controllen = sizeof(cmsg); + if (GetParam().domain == AF_INET) { + msg.msg_name = &addr4; + msg.msg_namelen = sizeof(addr4); + } else { + msg.msg_name = &addr6; + msg.msg_namelen = sizeof(addr6); + } + + EXPECT_THAT(sendmsg(socket->get(), &msg, 0), SyscallFailsWithErrno(EINVAL)); +} + +TEST_P(IPUnboundSocketTest, ReuseAddrDefault) { + auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + int get = -1; + socklen_t get_sz = sizeof(get); + ASSERT_THAT( + getsockopt(socket->get(), SOL_SOCKET, SO_REUSEADDR, &get, &get_sz), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get, kSockOptOff); + EXPECT_EQ(get_sz, sizeof(get)); +} + +TEST_P(IPUnboundSocketTest, SetReuseAddr) { + auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + ASSERT_THAT(setsockopt(socket->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)); + + int get = -1; + socklen_t get_sz = sizeof(get); + ASSERT_THAT( + getsockopt(socket->get(), SOL_SOCKET, SO_REUSEADDR, &get, &get_sz), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get, kSockOptOn); + EXPECT_EQ(get_sz, sizeof(get)); +} + INSTANTIATE_TEST_SUITE_P( IPUnboundSockets, IPUnboundSocketTest, ::testing::ValuesIn(VecCat<SocketKind>(VecCat<SocketKind>( diff --git a/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.cc b/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.cc index 3c3712b50..80f12b0a9 100644 --- a/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.cc +++ b/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.cc @@ -18,6 +18,7 @@ #include <sys/socket.h> #include <sys/types.h> #include <sys/un.h> + #include <cstdio> #include <cstring> diff --git a/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking_test.cc b/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking_test.cc index 92f03e045..797c4174e 100644 --- a/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking_test.cc +++ b/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking_test.cc @@ -12,15 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.h" + #include <vector> #include "test/syscalls/linux/ip_socket_test_util.h" -#include "test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/test_util.h" namespace gvisor { namespace testing { +namespace { std::vector<SocketKind> GetSockets() { return ApplyVec<SocketKind>( @@ -31,5 +33,7 @@ std::vector<SocketKind> GetSockets() { INSTANTIATE_TEST_SUITE_P(IPv4TCPUnboundSockets, IPv4TCPUnboundExternalNetworkingSocketTest, ::testing::ValuesIn(GetSockets())); + +} // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc index b828b6844..bc005e2bb 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc @@ -15,12 +15,16 @@ #include "test/syscalls/linux/socket_ipv4_udp_unbound.h" #include <arpa/inet.h> +#include <net/if.h> #include <sys/ioctl.h> #include <sys/socket.h> +#include <sys/types.h> #include <sys/un.h> + #include <cstdio> #include "gtest/gtest.h" +#include "absl/memory/memory.h" #include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/test_util.h" @@ -28,49 +32,29 @@ namespace gvisor { namespace testing { -constexpr char kMulticastAddress[] = "224.0.2.1"; -constexpr char kBroadcastAddress[] = "255.255.255.255"; - -TestAddress V4Multicast() { - TestAddress t("V4Multicast"); - t.addr.ss_family = AF_INET; - t.addr_len = sizeof(sockaddr_in); - reinterpret_cast<sockaddr_in*>(&t.addr)->sin_addr.s_addr = - inet_addr(kMulticastAddress); - return t; -} - -TestAddress V4Broadcast() { - TestAddress t("V4Broadcast"); - t.addr.ss_family = AF_INET; - t.addr_len = sizeof(sockaddr_in); - reinterpret_cast<sockaddr_in*>(&t.addr)->sin_addr.s_addr = - inet_addr(kBroadcastAddress); - return t; -} - // Check that packets are not received without a group membership. Default send // interface configured by bind. -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackNoGroup) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNoGroup) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Bind the first FD to the loopback. This is an alternative to // IP_MULTICAST_IF for setting the default send interface. auto sender_addr = V4Loopback(); EXPECT_THAT( - bind(sockets->first_fd(), reinterpret_cast<sockaddr*>(&sender_addr.addr), + bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), sender_addr.addr_len), SyscallSucceeds()); // Bind the second FD to the v4 any address. If multicast worked like unicast, // this would ensure that we get the packet. auto receiver_addr = V4Any(); - EXPECT_THAT(bind(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + EXPECT_THAT( + bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->second_fd(), + ASSERT_THAT(getsockname(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -82,33 +66,33 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackNoGroup) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - EXPECT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + EXPECT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&send_addr.addr), + send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we did not receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - EXPECT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + EXPECT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT), SyscallFailsWithErrno(EAGAIN)); } // Check that not setting a default send interface prevents multicast packets // from being sent. Group membership interface configured by address. -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackAddrNoDefaultSendIf) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackAddrNoDefaultSendIf) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Bind the second FD to the v4 any address to ensure that we can receive any // unicast packet. auto receiver_addr = V4Any(); - EXPECT_THAT(bind(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + EXPECT_THAT( + bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->second_fd(), + ASSERT_THAT(getsockname(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -118,8 +102,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackAddrNoDefaultSendIf) { ip_mreq group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - EXPECT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); // Send a multicast packet. @@ -128,27 +112,27 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackAddrNoDefaultSendIf) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - EXPECT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallFailsWithErrno(ENETUNREACH)); + EXPECT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&send_addr.addr), + send_addr.addr_len), + SyscallFailsWithErrno(ENETUNREACH)); } // Check that not setting a default send interface prevents multicast packets // from being sent. Group membership interface configured by NIC ID. -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackNicNoDefaultSendIf) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNicNoDefaultSendIf) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Bind the second FD to the v4 any address to ensure that we can receive any // unicast packet. auto receiver_addr = V4Any(); - ASSERT_THAT(bind(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->second_fd(), + ASSERT_THAT(getsockname(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -158,8 +142,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackNicNoDefaultSendIf) { ip_mreqn group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - EXPECT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); // Send a multicast packet. @@ -168,35 +152,35 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackNicNoDefaultSendIf) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - EXPECT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallFailsWithErrno(ENETUNREACH)); + EXPECT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&send_addr.addr), + send_addr.addr_len), + SyscallFailsWithErrno(ENETUNREACH)); } // Check that multicast works when the default send interface is configured by // bind and the group membership is configured by address. -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackAddr) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackAddr) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Bind the first FD to the loopback. This is an alternative to // IP_MULTICAST_IF for setting the default send interface. auto sender_addr = V4Loopback(); ASSERT_THAT( - bind(sockets->first_fd(), reinterpret_cast<sockaddr*>(&sender_addr.addr), + bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), sender_addr.addr_len), SyscallSucceeds()); // Bind the second FD to the v4 any address to ensure that we can receive the // multicast packet. auto receiver_addr = V4Any(); - ASSERT_THAT(bind(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->second_fd(), + ASSERT_THAT(getsockname(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -206,8 +190,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackAddr) { ip_mreq group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); // Send a multicast packet. @@ -216,43 +200,42 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackAddr) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&send_addr.addr), + send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT( - RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0), + SyscallSucceedsWithValue(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } // Check that multicast works when the default send interface is configured by // bind and the group membership is configured by NIC ID. -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackNic) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNic) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Bind the first FD to the loopback. This is an alternative to // IP_MULTICAST_IF for setting the default send interface. auto sender_addr = V4Loopback(); ASSERT_THAT( - bind(sockets->first_fd(), reinterpret_cast<sockaddr*>(&sender_addr.addr), + bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), sender_addr.addr_len), SyscallSucceeds()); // Bind the second FD to the v4 any address to ensure that we can receive the // multicast packet. auto receiver_addr = V4Any(); - ASSERT_THAT(bind(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->second_fd(), + ASSERT_THAT(getsockname(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -262,8 +245,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackNic) { ip_mreqn group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); // Send a multicast packet. @@ -272,17 +255,15 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackNic) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&send_addr.addr), + send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT( - RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0), + SyscallSucceedsWithValue(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -290,25 +271,26 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackNic) { // Check that multicast works when the default send interface is configured by // IP_MULTICAST_IF, the send address is specified in sendto, and the group // membership is configured by address. -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddr) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddr) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Set the default send interface. ip_mreq iface = {}; iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, - &iface, sizeof(iface)), + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, + sizeof(iface)), SyscallSucceeds()); // Bind the second FD to the v4 any address to ensure that we can receive the // multicast packet. auto receiver_addr = V4Any(); - ASSERT_THAT(bind(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->second_fd(), + ASSERT_THAT(getsockname(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -318,8 +300,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddr) { ip_mreq group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); // Send a multicast packet. @@ -328,17 +310,15 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddr) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&send_addr.addr), + send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT( - RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0), + SyscallSucceedsWithValue(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -346,25 +326,26 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddr) { // Check that multicast works when the default send interface is configured by // IP_MULTICAST_IF, the send address is specified in sendto, and the group // membership is configured by NIC ID. -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNic) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNic) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Set the default send interface. ip_mreqn iface = {}; iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, - &iface, sizeof(iface)), + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, + sizeof(iface)), SyscallSucceeds()); // Bind the second FD to the v4 any address to ensure that we can receive the // multicast packet. auto receiver_addr = V4Any(); - ASSERT_THAT(bind(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->second_fd(), + ASSERT_THAT(getsockname(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -374,8 +355,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNic) { ip_mreqn group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); // Send a multicast packet. @@ -384,17 +365,15 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNic) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&send_addr.addr), + send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT( - RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0), + SyscallSucceedsWithValue(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -402,25 +381,26 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNic) { // Check that multicast works when the default send interface is configured by // IP_MULTICAST_IF, the send address is specified in connect, and the group // membership is configured by address. -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrConnect) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrConnect) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Set the default send interface. ip_mreq iface = {}; iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, - &iface, sizeof(iface)), + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, + sizeof(iface)), SyscallSucceeds()); // Bind the second FD to the v4 any address to ensure that we can receive the // multicast packet. auto receiver_addr = V4Any(); - ASSERT_THAT(bind(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->second_fd(), + ASSERT_THAT(getsockname(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -430,8 +410,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrConnect) { ip_mreq group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); // Send a multicast packet. @@ -439,22 +419,20 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrConnect) { reinterpret_cast<sockaddr_in*>(&connect_addr.addr)->sin_port = reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; ASSERT_THAT( - RetryEINTR(connect)(sockets->first_fd(), + RetryEINTR(connect)(socket1->get(), reinterpret_cast<sockaddr*>(&connect_addr.addr), connect_addr.addr_len), SyscallSucceeds()); char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT( - RetryEINTR(send)(sockets->first_fd(), send_buf, sizeof(send_buf), 0), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT(RetryEINTR(send)(socket1->get(), send_buf, sizeof(send_buf), 0), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT( - RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0), + SyscallSucceedsWithValue(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -462,25 +440,26 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrConnect) { // Check that multicast works when the default send interface is configured by // IP_MULTICAST_IF, the send address is specified in connect, and the group // membership is configured by NIC ID. -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicConnect) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicConnect) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Set the default send interface. ip_mreqn iface = {}; iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, - &iface, sizeof(iface)), + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, + sizeof(iface)), SyscallSucceeds()); // Bind the second FD to the v4 any address to ensure that we can receive the // multicast packet. auto receiver_addr = V4Any(); - ASSERT_THAT(bind(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->second_fd(), + ASSERT_THAT(getsockname(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -490,8 +469,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicConnect) { ip_mreqn group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); // Send a multicast packet. @@ -499,22 +478,20 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicConnect) { reinterpret_cast<sockaddr_in*>(&connect_addr.addr)->sin_port = reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; ASSERT_THAT( - RetryEINTR(connect)(sockets->first_fd(), + RetryEINTR(connect)(socket1->get(), reinterpret_cast<sockaddr*>(&connect_addr.addr), connect_addr.addr_len), SyscallSucceeds()); char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT( - RetryEINTR(send)(sockets->first_fd(), send_buf, sizeof(send_buf), 0), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT(RetryEINTR(send)(socket1->get(), send_buf, sizeof(send_buf), 0), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT( - RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0), + SyscallSucceedsWithValue(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -522,25 +499,26 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicConnect) { // Check that multicast works when the default send interface is configured by // IP_MULTICAST_IF, the send address is specified in sendto, and the group // membership is configured by address. -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrSelf) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelf) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Set the default send interface. ip_mreq iface = {}; iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, - &iface, sizeof(iface)), + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, + sizeof(iface)), SyscallSucceeds()); // Bind the first FD to the v4 any address to ensure that we can receive the // multicast packet. auto receiver_addr = V4Any(); - ASSERT_THAT(bind(sockets->first_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->first_fd(), + ASSERT_THAT(getsockname(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -550,8 +528,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrSelf) { ip_mreq group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); // Send a multicast packet. @@ -560,17 +538,15 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrSelf) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&send_addr.addr), + send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT( - RetryEINTR(recv)(sockets->first_fd(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), 0), + SyscallSucceedsWithValue(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -578,25 +554,26 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrSelf) { // Check that multicast works when the default send interface is configured by // IP_MULTICAST_IF, the send address is specified in sendto, and the group // membership is configured by NIC ID. -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicSelf) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelf) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Set the default send interface. ip_mreqn iface = {}; iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, - &iface, sizeof(iface)), + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, + sizeof(iface)), SyscallSucceeds()); // Bind the first FD to the v4 any address to ensure that we can receive the // multicast packet. auto receiver_addr = V4Any(); - ASSERT_THAT(bind(sockets->first_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->first_fd(), + ASSERT_THAT(getsockname(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -606,8 +583,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicSelf) { ip_mreqn group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); // Send a multicast packet. @@ -616,17 +593,15 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicSelf) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&send_addr.addr), + send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT( - RetryEINTR(recv)(sockets->first_fd(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), 0), + SyscallSucceedsWithValue(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -634,25 +609,26 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicSelf) { // Check that multicast works when the default send interface is configured by // IP_MULTICAST_IF, the send address is specified in connect, and the group // membership is configured by address. -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrSelfConnect) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelfConnect) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Set the default send interface. ip_mreq iface = {}; iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, - &iface, sizeof(iface)), + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, + sizeof(iface)), SyscallSucceeds()); // Bind the first FD to the v4 any address to ensure that we can receive the // multicast packet. auto receiver_addr = V4Any(); - ASSERT_THAT(bind(sockets->first_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->first_fd(), + ASSERT_THAT(getsockname(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -662,8 +638,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrSelfConnect) { ip_mreq group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); // Send a multicast packet. @@ -671,20 +647,19 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrSelfConnect) { reinterpret_cast<sockaddr_in*>(&connect_addr.addr)->sin_port = reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; EXPECT_THAT( - RetryEINTR(connect)(sockets->first_fd(), + RetryEINTR(connect)(socket1->get(), reinterpret_cast<sockaddr*>(&connect_addr.addr), connect_addr.addr_len), SyscallSucceeds()); char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT( - RetryEINTR(send)(sockets->first_fd(), send_buf, sizeof(send_buf), 0), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT(RetryEINTR(send)(socket1->get(), send_buf, sizeof(send_buf), 0), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we did not receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - EXPECT_THAT(RetryEINTR(recv)(sockets->first_fd(), recv_buf, sizeof(recv_buf), + EXPECT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT), SyscallFailsWithErrno(EAGAIN)); } @@ -692,25 +667,26 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrSelfConnect) { // Check that multicast works when the default send interface is configured by // IP_MULTICAST_IF, the send address is specified in connect, and the group // membership is configured by NIC ID. -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicSelfConnect) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelfConnect) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Set the default send interface. ip_mreqn iface = {}; iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, - &iface, sizeof(iface)), + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, + sizeof(iface)), SyscallSucceeds()); // Bind the first FD to the v4 any address to ensure that we can receive the // multicast packet. auto receiver_addr = V4Any(); - ASSERT_THAT(bind(sockets->first_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->first_fd(), + ASSERT_THAT(getsockname(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -720,8 +696,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicSelfConnect) { ip_mreqn group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); // Send a multicast packet. @@ -729,20 +705,19 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicSelfConnect) { reinterpret_cast<sockaddr_in*>(&connect_addr.addr)->sin_port = reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; ASSERT_THAT( - RetryEINTR(connect)(sockets->first_fd(), + RetryEINTR(connect)(socket1->get(), reinterpret_cast<sockaddr*>(&connect_addr.addr), connect_addr.addr_len), SyscallSucceeds()); char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT( - RetryEINTR(send)(sockets->first_fd(), send_buf, sizeof(send_buf), 0), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT(RetryEINTR(send)(socket1->get(), send_buf, sizeof(send_buf), 0), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we did not receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - EXPECT_THAT(RetryEINTR(recv)(sockets->first_fd(), recv_buf, sizeof(recv_buf), + EXPECT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT), SyscallFailsWithErrno(EAGAIN)); } @@ -750,29 +725,30 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicSelfConnect) { // Check that multicast works when the default send interface is configured by // IP_MULTICAST_IF, the send address is specified in sendto, and the group // membership is configured by address. -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrSelfNoLoop) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelfNoLoop) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Set the default send interface. ip_mreq iface = {}; iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, - &iface, sizeof(iface)), + EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, + sizeof(iface)), SyscallSucceeds()); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_LOOP, &kSockOptOff, sizeof(kSockOptOff)), SyscallSucceeds()); // Bind the first FD to the v4 any address to ensure that we can receive the // multicast packet. auto receiver_addr = V4Any(); - ASSERT_THAT(bind(sockets->first_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->first_fd(), + ASSERT_THAT(getsockname(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -782,8 +758,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrSelfNoLoop) { ip_mreq group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); // Send a multicast packet. @@ -792,17 +768,15 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrSelfNoLoop) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&send_addr.addr), + send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT( - RetryEINTR(recv)(sockets->first_fd(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), 0), + SyscallSucceedsWithValue(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -810,29 +784,30 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrSelfNoLoop) { // Check that multicast works when the default send interface is configured by // IP_MULTICAST_IF, the send address is specified in sendto, and the group // membership is configured by NIC ID. -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicSelfNoLoop) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelfNoLoop) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Set the default send interface. ip_mreqn iface = {}; iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, - &iface, sizeof(iface)), + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, + sizeof(iface)), SyscallSucceeds()); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_LOOP, &kSockOptOff, sizeof(kSockOptOff)), SyscallSucceeds()); // Bind the second FD to the v4 any address to ensure that we can receive the // multicast packet. auto receiver_addr = V4Any(); - ASSERT_THAT(bind(sockets->first_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->first_fd(), + ASSERT_THAT(getsockname(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -842,8 +817,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicSelfNoLoop) { ip_mreqn group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); // Send a multicast packet. @@ -852,57 +827,57 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicSelfNoLoop) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&send_addr.addr), + send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT( - RetryEINTR(recv)(sockets->first_fd(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), 0), + SyscallSucceedsWithValue(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } // Check that dropping a group membership that does not exist fails. -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastInvalidDrop) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastInvalidDrop) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Unregister from a membership that we didn't have. ip_mreq group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_DROP_MEMBERSHIP, - &group, sizeof(group)), + EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group, + sizeof(group)), SyscallFailsWithErrno(EADDRNOTAVAIL)); } // Check that dropping a group membership prevents multicast packets from being // delivered. Default send address configured by bind and group membership // interface configured by address. -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastDropAddr) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastDropAddr) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Bind the first FD to the loopback. This is an alternative to // IP_MULTICAST_IF for setting the default send interface. auto sender_addr = V4Loopback(); EXPECT_THAT( - bind(sockets->first_fd(), reinterpret_cast<sockaddr*>(&sender_addr.addr), + bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), sender_addr.addr_len), SyscallSucceeds()); // Bind the second FD to the v4 any address to ensure that we can receive the // multicast packet. auto receiver_addr = V4Any(); - EXPECT_THAT(bind(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + EXPECT_THAT( + bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->second_fd(), + ASSERT_THAT(getsockname(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -912,11 +887,11 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastDropAddr) { ip_mreq group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - EXPECT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); - EXPECT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_DROP_MEMBERSHIP, - &group, sizeof(group)), + EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); // Send a multicast packet. @@ -925,15 +900,14 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastDropAddr) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - EXPECT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + EXPECT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&send_addr.addr), + send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we did not receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - EXPECT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + EXPECT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT), SyscallFailsWithErrno(EAGAIN)); } @@ -941,26 +915,27 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastDropAddr) { // Check that dropping a group membership prevents multicast packets from being // delivered. Default send address configured by bind and group membership // interface configured by NIC ID. -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastDropNic) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastDropNic) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Bind the first FD to the loopback. This is an alternative to // IP_MULTICAST_IF for setting the default send interface. auto sender_addr = V4Loopback(); EXPECT_THAT( - bind(sockets->first_fd(), reinterpret_cast<sockaddr*>(&sender_addr.addr), + bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), sender_addr.addr_len), SyscallSucceeds()); // Bind the second FD to the v4 any address to ensure that we can receive the // multicast packet. auto receiver_addr = V4Any(); - EXPECT_THAT(bind(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + EXPECT_THAT( + bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->second_fd(), + ASSERT_THAT(getsockname(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -970,11 +945,11 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastDropNic) { ip_mreqn group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - EXPECT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); - EXPECT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_DROP_MEMBERSHIP, - &group, sizeof(group)), + EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); // Send a multicast packet. @@ -983,50 +958,53 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastDropNic) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - EXPECT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + EXPECT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&send_addr.addr), + send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we did not receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - EXPECT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + EXPECT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT), SyscallFailsWithErrno(EAGAIN)); } -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfZero) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfZero) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); ip_mreqn iface = {}; - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, - &iface, sizeof(iface)), + EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, + sizeof(iface)), SyscallSucceeds()); } -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfInvalidNic) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfInvalidNic) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); ip_mreqn iface = {}; iface.imr_ifindex = -1; - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, - &iface, sizeof(iface)), + EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, + sizeof(iface)), SyscallFailsWithErrno(EADDRNOTAVAIL)); } -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfInvalidAddr) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfInvalidAddr) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); ip_mreq iface = {}; iface.imr_interface.s_addr = inet_addr("255.255.255"); - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, - &iface, sizeof(iface)), + EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, + sizeof(iface)), SyscallFailsWithErrno(EADDRNOTAVAIL)); } -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfSetShort) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetShort) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Create a valid full-sized request. ip_mreqn iface = {}; @@ -1034,29 +1012,31 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfSetShort) { // Send an optlen of 1 to check that optlen is enforced. EXPECT_THAT( - setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &iface, 1), + setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, 1), SyscallFailsWithErrno(EINVAL)); } -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfDefault) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfDefault) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); in_addr get = {}; socklen_t size = sizeof(get); ASSERT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), + getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), SyscallSucceeds()); EXPECT_EQ(size, sizeof(get)); EXPECT_EQ(get.s_addr, 0); } -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfDefaultReqn) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfDefaultReqn) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); ip_mreqn get = {}; socklen_t size = sizeof(get); ASSERT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), + getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), SyscallSucceeds()); // getsockopt(IP_MULTICAST_IF) can only return an in_addr, so it treats the @@ -1071,19 +1051,20 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfDefaultReqn) { EXPECT_EQ(get.imr_ifindex, 0); } -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfSetAddrGetReqn) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetAddrGetReqn) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); in_addr set = {}; set.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &set, + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set, sizeof(set)), SyscallSucceeds()); ip_mreqn get = {}; socklen_t size = sizeof(get); ASSERT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), + getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), SyscallSucceeds()); // getsockopt(IP_MULTICAST_IF) can only return an in_addr, so it treats the @@ -1095,19 +1076,20 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfSetAddrGetReqn) { EXPECT_EQ(get.imr_ifindex, 0); } -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfSetReqAddrGetReqn) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetReqAddrGetReqn) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); ip_mreq set = {}; set.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &set, + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set, sizeof(set)), SyscallSucceeds()); ip_mreqn get = {}; socklen_t size = sizeof(get); ASSERT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), + getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), SyscallSucceeds()); // getsockopt(IP_MULTICAST_IF) can only return an in_addr, so it treats the @@ -1119,19 +1101,20 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfSetReqAddrGetReqn) { EXPECT_EQ(get.imr_ifindex, 0); } -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfSetNicGetReqn) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetNicGetReqn) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); ip_mreqn set = {}; set.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &set, + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set, sizeof(set)), SyscallSucceeds()); ip_mreqn get = {}; socklen_t size = sizeof(get); ASSERT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), + getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), SyscallSucceeds()); EXPECT_EQ(size, sizeof(in_addr)); EXPECT_EQ(get.imr_multiaddr.s_addr, 0); @@ -1139,87 +1122,93 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfSetNicGetReqn) { EXPECT_EQ(get.imr_ifindex, 0); } -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfSetAddr) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetAddr) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); in_addr set = {}; set.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &set, + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set, sizeof(set)), SyscallSucceeds()); in_addr get = {}; socklen_t size = sizeof(get); ASSERT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), + getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), SyscallSucceeds()); EXPECT_EQ(size, sizeof(get)); EXPECT_EQ(get.s_addr, set.s_addr); } -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfSetReqAddr) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetReqAddr) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); ip_mreq set = {}; set.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &set, + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set, sizeof(set)), SyscallSucceeds()); in_addr get = {}; socklen_t size = sizeof(get); ASSERT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), + getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), SyscallSucceeds()); EXPECT_EQ(size, sizeof(get)); EXPECT_EQ(get.s_addr, set.imr_interface.s_addr); } -TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfSetNic) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetNic) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); ip_mreqn set = {}; set.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &set, + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set, sizeof(set)), SyscallSucceeds()); in_addr get = {}; socklen_t size = sizeof(get); ASSERT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), + getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), SyscallSucceeds()); EXPECT_EQ(size, sizeof(get)); EXPECT_EQ(get.s_addr, 0); } -TEST_P(IPv4UDPUnboundSocketPairTest, TestJoinGroupNoIf) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, TestJoinGroupNoIf) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); ip_mreqn group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallFailsWithErrno(ENODEV)); } -TEST_P(IPv4UDPUnboundSocketPairTest, TestJoinGroupInvalidIf) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, TestJoinGroupInvalidIf) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); ip_mreqn group = {}; group.imr_address.s_addr = inet_addr("255.255.255"); group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallFailsWithErrno(ENODEV)); } // Check that multiple memberships are not allowed on the same socket. -TEST_P(IPv4UDPUnboundSocketPairTest, TestMultipleJoinsOnSingleSocket) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - auto fd = sockets->first_fd(); +TEST_P(IPv4UDPUnboundSocketTest, TestMultipleJoinsOnSingleSocket) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto fd = socket1->get(); ip_mreqn group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); @@ -1234,41 +1223,44 @@ TEST_P(IPv4UDPUnboundSocketPairTest, TestMultipleJoinsOnSingleSocket) { } // Check that two sockets can join the same multicast group at the same time. -TEST_P(IPv4UDPUnboundSocketPairTest, TestTwoSocketsJoinSameMulticastGroup) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, TestTwoSocketsJoinSameMulticastGroup) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); ip_mreqn group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); - EXPECT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); // Drop the membership twice on each socket, the second call for each socket // should fail. - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_DROP_MEMBERSHIP, - &group, sizeof(group)), + EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_DROP_MEMBERSHIP, - &group, sizeof(group)), + EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group, + sizeof(group)), SyscallFailsWithErrno(EADDRNOTAVAIL)); - EXPECT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_DROP_MEMBERSHIP, - &group, sizeof(group)), + EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); - EXPECT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_DROP_MEMBERSHIP, - &group, sizeof(group)), + EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group, + sizeof(group)), SyscallFailsWithErrno(EADDRNOTAVAIL)); } // Check that two sockets can join the same multicast group at the same time, // and both will receive data on it. -TEST_P(IPv4UDPUnboundSocketPairTest, TestMcastReceptionOnTwoSockets) { +TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionOnTwoSockets) { std::unique_ptr<SocketPair> socket_pairs[2] = { - ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()), - ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair())}; + absl::make_unique<FDSocketPair>(ASSERT_NO_ERRNO_AND_VALUE(NewSocket()), + ASSERT_NO_ERRNO_AND_VALUE(NewSocket())), + absl::make_unique<FDSocketPair>(ASSERT_NO_ERRNO_AND_VALUE(NewSocket()), + ASSERT_NO_ERRNO_AND_VALUE(NewSocket()))}; ip_mreq iface = {}, group = {}; iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); @@ -1338,11 +1330,12 @@ TEST_P(IPv4UDPUnboundSocketPairTest, TestMcastReceptionOnTwoSockets) { // Check that on two sockets that joined a group and listen on ANY, dropping // memberships one by one will continue to deliver packets to both sockets until // both memberships have been dropped. -TEST_P(IPv4UDPUnboundSocketPairTest, - TestMcastReceptionWhenDroppingMemberships) { +TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionWhenDroppingMemberships) { std::unique_ptr<SocketPair> socket_pairs[2] = { - ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()), - ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair())}; + absl::make_unique<FDSocketPair>(ASSERT_NO_ERRNO_AND_VALUE(NewSocket()), + ASSERT_NO_ERRNO_AND_VALUE(NewSocket())), + absl::make_unique<FDSocketPair>(ASSERT_NO_ERRNO_AND_VALUE(NewSocket()), + ASSERT_NO_ERRNO_AND_VALUE(NewSocket()))}; ip_mreq iface = {}, group = {}; iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); @@ -1437,18 +1430,19 @@ TEST_P(IPv4UDPUnboundSocketPairTest, // Check that a receiving socket can bind to the multicast address before // joining the group and receive data once the group has been joined. -TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenJoinThenReceive) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenJoinThenReceive) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Bind second socket (receiver) to the multicast address. auto receiver_addr = V4Multicast(); - ASSERT_THAT(bind(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); // Update receiver_addr with the correct port number. socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->second_fd(), + ASSERT_THAT(getsockname(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -1458,30 +1452,29 @@ TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenJoinThenReceive) { ip_mreqn group = {}; group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); - ASSERT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, - &group, sizeof(group)), + ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), SyscallSucceeds()); // Send a multicast packet on the first socket out the loopback interface. ip_mreq iface = {}; iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, - &iface, sizeof(iface)), + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, + sizeof(iface)), SyscallSucceeds()); auto sendto_addr = V4Multicast(); reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&sendto_addr.addr), - sendto_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&sendto_addr.addr), + sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT), SyscallSucceedsWithValue(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); @@ -1489,18 +1482,19 @@ TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenJoinThenReceive) { // Check that a receiving socket can bind to the multicast address and won't // receive multicast data if it hasn't joined the group. -TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenNoJoinThenNoReceive) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenNoJoinThenNoReceive) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Bind second socket (receiver) to the multicast address. auto receiver_addr = V4Multicast(); - ASSERT_THAT(bind(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); // Update receiver_addr with the correct port number. socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->second_fd(), + ASSERT_THAT(getsockname(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -1509,40 +1503,40 @@ TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenNoJoinThenNoReceive) { // Send a multicast packet on the first socket out the loopback interface. ip_mreq iface = {}; iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, - &iface, sizeof(iface)), + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, + sizeof(iface)), SyscallSucceeds()); auto sendto_addr = V4Multicast(); reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&sendto_addr.addr), - sendto_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&sendto_addr.addr), + sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we don't receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT), SyscallFailsWithErrno(EAGAIN)); } // Check that a socket can bind to a multicast address and still send out // packets. -TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenSend) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenSend) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Bind second socket (receiver) to the ANY address. auto receiver_addr = V4Any(); - ASSERT_THAT(bind(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->second_fd(), + ASSERT_THAT(getsockname(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -1551,11 +1545,11 @@ TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenSend) { // Bind the first socket (sender) to the multicast address. auto sender_addr = V4Multicast(); ASSERT_THAT( - bind(sockets->first_fd(), reinterpret_cast<sockaddr*>(&sender_addr.addr), + bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), sender_addr.addr_len), SyscallSucceeds()); socklen_t sender_addr_len = sender_addr.addr_len; - ASSERT_THAT(getsockname(sockets->first_fd(), + ASSERT_THAT(getsockname(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), &sender_addr_len), SyscallSucceeds()); @@ -1567,15 +1561,14 @@ TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenSend) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&sendto_addr.addr), - sendto_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&sendto_addr.addr), + sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT), SyscallSucceedsWithValue(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); @@ -1583,46 +1576,46 @@ TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenSend) { // Check that a receiving socket can bind to the broadcast address and receive // broadcast packets. -TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToBcastThenReceive) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, TestBindToBcastThenReceive) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Bind second socket (receiver) to the broadcast address. auto receiver_addr = V4Broadcast(); - ASSERT_THAT(bind(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->second_fd(), + ASSERT_THAT(getsockname(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); // Send a broadcast packet on the first socket out the loopback interface. - EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_BROADCAST, - &kSockOptOn, sizeof(kSockOptOn)), + EXPECT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_BROADCAST, &kSockOptOn, + sizeof(kSockOptOn)), SyscallSucceedsWithValue(0)); // Note: Binding to the loopback interface makes the broadcast go out of it. auto sender_bind_addr = V4Loopback(); - ASSERT_THAT(bind(sockets->first_fd(), - reinterpret_cast<sockaddr*>(&sender_bind_addr.addr), - sender_bind_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_bind_addr.addr), + sender_bind_addr.addr_len), + SyscallSucceeds()); auto sendto_addr = V4Broadcast(); reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&sendto_addr.addr), - sendto_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&sendto_addr.addr), + sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT), SyscallSucceedsWithValue(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); @@ -1630,17 +1623,18 @@ TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToBcastThenReceive) { // Check that a socket can bind to the broadcast address and still send out // packets. -TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToBcastThenSend) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(IPv4UDPUnboundSocketTest, TestBindToBcastThenSend) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Bind second socket (receiver) to the ANY address. auto receiver_addr = V4Any(); - ASSERT_THAT(bind(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->second_fd(), + ASSERT_THAT(getsockname(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); @@ -1649,11 +1643,11 @@ TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToBcastThenSend) { // Bind the first socket (sender) to the broadcast address. auto sender_addr = V4Broadcast(); ASSERT_THAT( - bind(sockets->first_fd(), reinterpret_cast<sockaddr*>(&sender_addr.addr), + bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), sender_addr.addr_len), SyscallSucceeds()); socklen_t sender_addr_len = sender_addr.addr_len; - ASSERT_THAT(getsockname(sockets->first_fd(), + ASSERT_THAT(getsockname(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), &sender_addr_len), SyscallSucceeds()); @@ -1665,19 +1659,898 @@ TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToBcastThenSend) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT( - RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&sendto_addr.addr), - sendto_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&sendto_addr.addr), + sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT), SyscallSucceedsWithValue(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } +// Check that SO_REUSEADDR always delivers to the most recently bound socket. +// +// FIXME(gvisor.dev/issue/873): Endpoint order is not restored correctly. Enable +// random and co-op save (below) once that is fixed. +TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrDistribution_NoRandomSave) { + std::vector<std::unique_ptr<FileDescriptor>> sockets; + sockets.emplace_back(ASSERT_NO_ERRNO_AND_VALUE(NewSocket())); + + ASSERT_THAT(setsockopt(sockets[0]->get(), SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Bind the first socket to the loopback and take note of the selected port. + auto addr = V4Loopback(); + ASSERT_THAT(bind(sockets[0]->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + socklen_t addr_len = addr.addr_len; + ASSERT_THAT(getsockname(sockets[0]->get(), + reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), + SyscallSucceeds()); + EXPECT_EQ(addr_len, addr.addr_len); + + constexpr int kMessageSize = 200; + + // FIXME(gvisor.dev/issue/873): Endpoint order is not restored correctly. + const DisableSave ds; + + for (int i = 0; i < 10; i++) { + // Add a new receiver. + sockets.emplace_back(ASSERT_NO_ERRNO_AND_VALUE(NewSocket())); + auto& last = sockets.back(); + ASSERT_THAT(setsockopt(last->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(last->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + + // Send a new message to the SO_REUSEADDR group. We use a new socket each + // time so that a new ephemeral port will be used each time. This ensures + // that we aren't doing REUSEPORT-like hash load blancing. + auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + char send_buf[kMessageSize]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + EXPECT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Verify that the most recent socket got the message. We don't expect any + // of the other sockets to have received it, but we will check that later. + char recv_buf[sizeof(send_buf)] = {}; + EXPECT_THAT( + RetryEINTR(recv)(last->get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT), + SyscallSucceedsWithValue(sizeof(send_buf))); + EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); + } + + // Verify that no other messages were received. + for (auto& socket : sockets) { + char recv_buf[kMessageSize] = {}; + EXPECT_THAT(RetryEINTR(recv)(socket->get(), recv_buf, sizeof(recv_buf), + MSG_DONTWAIT), + SyscallFailsWithErrno(EAGAIN)); + } +} + +TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrThenReusePort) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + // Bind socket1 with REUSEADDR. + ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Bind the first socket to the loopback and take note of the selected port. + auto addr = V4Loopback(); + ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + socklen_t addr_len = addr.addr_len; + ASSERT_THAT(getsockname(socket1->get(), + reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), + SyscallSucceeds()); + EXPECT_EQ(addr_len, addr.addr_len); + + // Bind socket2 to the same address as socket1, only with REUSEPORT. + ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); +} + +TEST_P(IPv4UDPUnboundSocketTest, BindReusePortThenReuseAddr) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + // Bind socket1 with REUSEPORT. + ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Bind the first socket to the loopback and take note of the selected port. + auto addr = V4Loopback(); + ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + socklen_t addr_len = addr.addr_len; + ASSERT_THAT(getsockname(socket1->get(), + reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), + SyscallSucceeds()); + EXPECT_EQ(addr_len, addr.addr_len); + + // Bind socket2 to the same address as socket1, only with REUSEADDR. + ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); +} + +TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConvertibleToReusePort) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket3 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + // Bind socket1 with REUSEADDR and REUSEPORT. + ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Bind the first socket to the loopback and take note of the selected port. + auto addr = V4Loopback(); + ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + socklen_t addr_len = addr.addr_len; + ASSERT_THAT(getsockname(socket1->get(), + reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), + SyscallSucceeds()); + EXPECT_EQ(addr_len, addr.addr_len); + + // Bind socket2 to the same address as socket1, only with REUSEPORT. + ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + + // Bind socket3 to the same address as socket1, only with REUSEADDR. + ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); +} + +TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConvertibleToReuseAddr) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket3 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + // Bind socket1 with REUSEADDR and REUSEPORT. + ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Bind the first socket to the loopback and take note of the selected port. + auto addr = V4Loopback(); + ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + socklen_t addr_len = addr.addr_len; + ASSERT_THAT(getsockname(socket1->get(), + reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), + SyscallSucceeds()); + EXPECT_EQ(addr_len, addr.addr_len); + + // Bind socket2 to the same address as socket1, only with REUSEADDR. + ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + + // Bind socket3 to the same address as socket1, only with REUSEPORT. + ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); +} + +TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConversionReversable1) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket3 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + // Bind socket1 with REUSEADDR and REUSEPORT. + ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Bind the first socket to the loopback and take note of the selected port. + auto addr = V4Loopback(); + ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + socklen_t addr_len = addr.addr_len; + ASSERT_THAT(getsockname(socket1->get(), + reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), + SyscallSucceeds()); + EXPECT_EQ(addr_len, addr.addr_len); + + // Bind socket2 to the same address as socket1, only with REUSEPORT. + ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + + // Close socket2 to revert to just socket1 with REUSEADDR and REUSEPORT. + socket2->reset(); + + // Bind socket3 to the same address as socket1, only with REUSEADDR. + ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); +} + +TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConversionReversable2) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket3 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + // Bind socket1 with REUSEADDR and REUSEPORT. + ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Bind the first socket to the loopback and take note of the selected port. + auto addr = V4Loopback(); + ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + socklen_t addr_len = addr.addr_len; + ASSERT_THAT(getsockname(socket1->get(), + reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), + SyscallSucceeds()); + EXPECT_EQ(addr_len, addr.addr_len); + + // Bind socket2 to the same address as socket1, only with REUSEADDR. + ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + + // Close socket2 to revert to just socket1 with REUSEADDR and REUSEPORT. + socket2->reset(); + + // Bind socket3 to the same address as socket1, only with REUSEPORT. + ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); +} + +TEST_P(IPv4UDPUnboundSocketTest, BindDoubleReuseAddrReusePortThenReusePort) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket3 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + // Bind socket1 with REUSEADDR and REUSEPORT. + ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Bind the first socket to the loopback and take note of the selected port. + auto addr = V4Loopback(); + ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + socklen_t addr_len = addr.addr_len; + ASSERT_THAT(getsockname(socket1->get(), + reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), + SyscallSucceeds()); + EXPECT_EQ(addr_len, addr.addr_len); + + // Bind socket2 to the same address as socket1, also with REUSEADDR and + // REUSEPORT. + ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + + // Bind socket3 to the same address as socket1, only with REUSEPORT. + ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); +} + +TEST_P(IPv4UDPUnboundSocketTest, BindDoubleReuseAddrReusePortThenReuseAddr) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket3 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + // Bind socket1 with REUSEADDR and REUSEPORT. + ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Bind the first socket to the loopback and take note of the selected port. + auto addr = V4Loopback(); + ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + socklen_t addr_len = addr.addr_len; + ASSERT_THAT(getsockname(socket1->get(), + reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), + SyscallSucceeds()); + EXPECT_EQ(addr_len, addr.addr_len); + + // Bind socket2 to the same address as socket1, also with REUSEADDR and + // REUSEPORT. + ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + + // Bind socket3 to the same address as socket1, only with REUSEADDR. + ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); +} + +// Check that REUSEPORT takes precedence over REUSEADDR. +TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrReusePortDistribution) { + auto receiver1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto receiver2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + ASSERT_THAT(setsockopt(receiver1->get(), SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(setsockopt(receiver1->get(), SOL_SOCKET, SO_REUSEPORT, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Bind the first socket to the loopback and take note of the selected port. + auto addr = V4Loopback(); + ASSERT_THAT(bind(receiver1->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + socklen_t addr_len = addr.addr_len; + ASSERT_THAT(getsockname(receiver1->get(), + reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), + SyscallSucceeds()); + EXPECT_EQ(addr_len, addr.addr_len); + + // Bind receiver2 to the same address as socket1, also with REUSEADDR and + // REUSEPORT. + ASSERT_THAT(setsockopt(receiver2->get(), SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(setsockopt(receiver2->get(), SOL_SOCKET, SO_REUSEPORT, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(receiver2->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + + constexpr int kMessageSize = 10; + + for (int i = 0; i < 100; ++i) { + // Send a new message to the REUSEADDR/REUSEPORT group. We use a new socket + // each time so that a new ephemerial port will be used each time. This + // ensures that we cycle through hashes. + auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + char send_buf[kMessageSize] = {}; + EXPECT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + } + + // Check that both receivers got messages. This checks that we are using load + // balancing (REUSEPORT) instead of the most recently bound socket + // (REUSEADDR). + char recv_buf[kMessageSize] = {}; + EXPECT_THAT(RetryEINTR(recv)(receiver1->get(), recv_buf, sizeof(recv_buf), + MSG_DONTWAIT), + SyscallSucceedsWithValue(kMessageSize)); + EXPECT_THAT(RetryEINTR(recv)(receiver2->get(), recv_buf, sizeof(recv_buf), + MSG_DONTWAIT), + SyscallSucceedsWithValue(kMessageSize)); +} + +// Check that connect returns EADDRNOTAVAIL when out of local ephemeral ports. +// We disable S/R because this test creates a large number of sockets. +TEST_P(IPv4UDPUnboundSocketTest, UDPConnectPortExhaustion_NoRandomSave) { + auto receiver1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + constexpr int kClients = 65536; + // Bind the first socket to the loopback and take note of the selected port. + auto addr = V4Loopback(); + ASSERT_THAT(bind(receiver1->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + socklen_t addr_len = addr.addr_len; + ASSERT_THAT(getsockname(receiver1->get(), + reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), + SyscallSucceeds()); + EXPECT_EQ(addr_len, addr.addr_len); + + // Disable cooperative S/R as we are making too many syscalls. + DisableSave ds; + std::vector<std::unique_ptr<FileDescriptor>> sockets; + for (int i = 0; i < kClients; i++) { + auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + int ret = connect(s->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len); + if (ret == 0) { + sockets.push_back(std::move(s)); + continue; + } + ASSERT_THAT(ret, SyscallFailsWithErrno(EAGAIN)); + break; + } +} + +// Test that socket will receive packet info control message. +TEST_P(IPv4UDPUnboundSocketTest, SetAndReceiveIPPKTINFO) { + // TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet. + SKIP_IF((IsRunningWithHostinet())); + + auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto sender_addr = V4Loopback(); + int level = SOL_IP; + int type = IP_PKTINFO; + + ASSERT_THAT( + bind(receiver->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), + sender_addr.addr_len), + SyscallSucceeds()); + socklen_t sender_addr_len = sender_addr.addr_len; + ASSERT_THAT(getsockname(receiver->get(), + reinterpret_cast<sockaddr*>(&sender_addr.addr), + &sender_addr_len), + SyscallSucceeds()); + EXPECT_EQ(sender_addr_len, sender_addr.addr_len); + + auto receiver_addr = V4Loopback(); + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port = + reinterpret_cast<sockaddr_in*>(&sender_addr.addr)->sin_port; + ASSERT_THAT( + connect(sender->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + + // Allow socket to receive control message. + ASSERT_THAT( + setsockopt(receiver->get(), level, type, &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Prepare message to send. + constexpr size_t kDataLength = 1024; + msghdr sent_msg = {}; + iovec sent_iov = {}; + char sent_data[kDataLength]; + sent_iov.iov_base = sent_data; + sent_iov.iov_len = kDataLength; + sent_msg.msg_iov = &sent_iov; + sent_msg.msg_iovlen = 1; + sent_msg.msg_flags = 0; + + ASSERT_THAT(RetryEINTR(sendmsg)(sender->get(), &sent_msg, 0), + SyscallSucceedsWithValue(kDataLength)); + + msghdr received_msg = {}; + iovec received_iov = {}; + char received_data[kDataLength]; + char received_cmsg_buf[CMSG_SPACE(sizeof(in_pktinfo))] = {}; + size_t cmsg_data_len = sizeof(in_pktinfo); + received_iov.iov_base = received_data; + received_iov.iov_len = kDataLength; + received_msg.msg_iov = &received_iov; + received_msg.msg_iovlen = 1; + received_msg.msg_controllen = CMSG_LEN(cmsg_data_len); + received_msg.msg_control = received_cmsg_buf; + + ASSERT_THAT(RetryEINTR(recvmsg)(receiver->get(), &received_msg, 0), + SyscallSucceedsWithValue(kDataLength)); + + cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len)); + EXPECT_EQ(cmsg->cmsg_level, level); + EXPECT_EQ(cmsg->cmsg_type, type); + + // Get loopback index. + ifreq ifr = {}; + absl::SNPrintF(ifr.ifr_name, IFNAMSIZ, "lo"); + ASSERT_THAT(ioctl(sender->get(), SIOCGIFINDEX, &ifr), SyscallSucceeds()); + ASSERT_NE(ifr.ifr_ifindex, 0); + + // Check the data + in_pktinfo received_pktinfo = {}; + memcpy(&received_pktinfo, CMSG_DATA(cmsg), sizeof(in_pktinfo)); + EXPECT_EQ(received_pktinfo.ipi_ifindex, ifr.ifr_ifindex); + EXPECT_EQ(received_pktinfo.ipi_spec_dst.s_addr, htonl(INADDR_LOOPBACK)); + EXPECT_EQ(received_pktinfo.ipi_addr.s_addr, htonl(INADDR_LOOPBACK)); +} + +// Check that setting SO_RCVBUF below min is clamped to the minimum +// receive buffer size. +TEST_P(IPv4UDPUnboundSocketTest, SetSocketRecvBufBelowMin) { + auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + // Discover minimum buffer size by setting it to zero. + constexpr int kRcvBufSz = 0; + ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, + sizeof(kRcvBufSz)), + SyscallSucceeds()); + + int min = 0; + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &min, &min_len), + SyscallSucceeds()); + + // Linux doubles the value so let's use a value that when doubled will still + // be smaller than min. + int below_min = min / 2 - 1; + ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &below_min, + sizeof(below_min)), + SyscallSucceeds()); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &val, &val_len), + SyscallSucceeds()); + + ASSERT_EQ(min, val); +} + +// Check that setting SO_RCVBUF above max is clamped to the maximum +// receive buffer size. +TEST_P(IPv4UDPUnboundSocketTest, SetSocketRecvBufAboveMax) { + auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + // Discover maxmimum buffer size by setting to a really large value. + constexpr int kRcvBufSz = 0xffffffff; + ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, + sizeof(kRcvBufSz)), + SyscallSucceeds()); + + int max = 0; + socklen_t max_len = sizeof(max); + ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &max, &max_len), + SyscallSucceeds()); + + int above_max = max + 1; + ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &above_max, + sizeof(above_max)), + SyscallSucceeds()); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &val, &val_len), + SyscallSucceeds()); + ASSERT_EQ(max, val); +} + +// Check that setting SO_RCVBUF min <= rcvBufSz <= max is honored. +TEST_P(IPv4UDPUnboundSocketTest, SetSocketRecvBuf) { + auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + int max = 0; + int min = 0; + { + // Discover maxmimum buffer size by setting to a really large value. + constexpr int kRcvBufSz = 0xffffffff; + ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, + sizeof(kRcvBufSz)), + SyscallSucceeds()); + + max = 0; + socklen_t max_len = sizeof(max); + ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &max, &max_len), + SyscallSucceeds()); + } + + { + // Discover minimum buffer size by setting it to zero. + constexpr int kRcvBufSz = 0; + ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, + sizeof(kRcvBufSz)), + SyscallSucceeds()); + + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &min, &min_len), + SyscallSucceeds()); + } + + int quarter_sz = min + (max - min) / 4; + ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &quarter_sz, + sizeof(quarter_sz)), + SyscallSucceeds()); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &val, &val_len), + SyscallSucceeds()); + + // Linux doubles the value set by SO_SNDBUF/SO_RCVBUF. + if (!IsRunningOnGvisor()) { + quarter_sz *= 2; + } + ASSERT_EQ(quarter_sz, val); +} + +// Check that setting SO_SNDBUF below min is clamped to the minimum +// send buffer size. +TEST_P(IPv4UDPUnboundSocketTest, SetSocketSendBufBelowMin) { + auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + // Discover minimum buffer size by setting it to zero. + constexpr int kSndBufSz = 0; + ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &kSndBufSz, + sizeof(kSndBufSz)), + SyscallSucceeds()); + + int min = 0; + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &min, &min_len), + SyscallSucceeds()); + + // Linux doubles the value so let's use a value that when doubled will still + // be smaller than min. + int below_min = min / 2 - 1; + ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &below_min, + sizeof(below_min)), + SyscallSucceeds()); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &val, &val_len), + SyscallSucceeds()); + + ASSERT_EQ(min, val); +} + +// Check that setting SO_SNDBUF above max is clamped to the maximum +// send buffer size. +TEST_P(IPv4UDPUnboundSocketTest, SetSocketSendBufAboveMax) { + auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + // Discover maxmimum buffer size by setting to a really large value. + constexpr int kSndBufSz = 0xffffffff; + ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &kSndBufSz, + sizeof(kSndBufSz)), + SyscallSucceeds()); + + int max = 0; + socklen_t max_len = sizeof(max); + ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &max, &max_len), + SyscallSucceeds()); + + int above_max = max + 1; + ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &above_max, + sizeof(above_max)), + SyscallSucceeds()); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &val, &val_len), + SyscallSucceeds()); + ASSERT_EQ(max, val); +} + +// Check that setting SO_SNDBUF min <= kSndBufSz <= max is honored. +TEST_P(IPv4UDPUnboundSocketTest, SetSocketSendBuf) { + auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + int max = 0; + int min = 0; + { + // Discover maxmimum buffer size by setting to a really large value. + constexpr int kSndBufSz = 0xffffffff; + ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &kSndBufSz, + sizeof(kSndBufSz)), + SyscallSucceeds()); + + max = 0; + socklen_t max_len = sizeof(max); + ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &max, &max_len), + SyscallSucceeds()); + } + + { + // Discover minimum buffer size by setting it to zero. + constexpr int kSndBufSz = 0; + ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &kSndBufSz, + sizeof(kSndBufSz)), + SyscallSucceeds()); + + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &min, &min_len), + SyscallSucceeds()); + } + + int quarter_sz = min + (max - min) / 4; + ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &quarter_sz, + sizeof(quarter_sz)), + SyscallSucceeds()); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &val, &val_len), + SyscallSucceeds()); + + // Linux doubles the value set by SO_SNDBUF/SO_RCVBUF. + if (!IsRunningOnGvisor()) { + quarter_sz *= 2; + } + + ASSERT_EQ(quarter_sz, val); +} + +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIPPacketInfo) { + auto sender_socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto receiver_socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + // Bind the first FD to the loopback. This is an alternative to + // IP_MULTICAST_IF for setting the default send interface. + auto sender_addr = V4Loopback(); + ASSERT_THAT( + bind(sender_socket->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), + sender_addr.addr_len), + SyscallSucceeds()); + + // Bind the second FD to the v4 any address to ensure that we can receive the + // multicast packet. + auto receiver_addr = V4Any(); + ASSERT_THAT(bind(receiver_socket->get(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(receiver_socket->get(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + + // Register to receive multicast packets. + ip_mreqn group = {}; + group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); + group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); + ASSERT_THAT(setsockopt(receiver_socket->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, + &group, sizeof(group)), + SyscallSucceeds()); + + // Register to receive IP packet info. + const int one = 1; + ASSERT_THAT(setsockopt(receiver_socket->get(), IPPROTO_IP, IP_PKTINFO, &one, + sizeof(one)), + SyscallSucceeds()); + + // Send a multicast packet. + auto send_addr = V4Multicast(); + reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; + char send_buf[200]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + ASSERT_THAT( + RetryEINTR(sendto)(sender_socket->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&send_addr.addr), + send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Check that we received the multicast packet. + msghdr recv_msg = {}; + iovec recv_iov = {}; + char recv_buf[sizeof(send_buf)]; + char recv_cmsg_buf[CMSG_SPACE(sizeof(in_pktinfo))] = {}; + size_t cmsg_data_len = sizeof(in_pktinfo); + recv_iov.iov_base = recv_buf; + recv_iov.iov_len = sizeof(recv_buf); + recv_msg.msg_iov = &recv_iov; + recv_msg.msg_iovlen = 1; + recv_msg.msg_controllen = CMSG_LEN(cmsg_data_len); + recv_msg.msg_control = recv_cmsg_buf; + ASSERT_THAT(RetryEINTR(recvmsg)(receiver_socket->get(), &recv_msg, 0), + SyscallSucceedsWithValue(sizeof(send_buf))); + EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); + + // Check the IP_PKTINFO control message. + cmsghdr* cmsg = CMSG_FIRSTHDR(&recv_msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len)); + EXPECT_EQ(cmsg->cmsg_level, IPPROTO_IP); + EXPECT_EQ(cmsg->cmsg_type, IP_PKTINFO); + + // Get loopback index. + ifreq ifr = {}; + absl::SNPrintF(ifr.ifr_name, IFNAMSIZ, "lo"); + ASSERT_THAT(ioctl(receiver_socket->get(), SIOCGIFINDEX, &ifr), + SyscallSucceeds()); + ASSERT_NE(ifr.ifr_ifindex, 0); + + in_pktinfo received_pktinfo = {}; + memcpy(&received_pktinfo, CMSG_DATA(cmsg), sizeof(in_pktinfo)); + EXPECT_EQ(received_pktinfo.ipi_ifindex, ifr.ifr_ifindex); + if (IsRunningOnGvisor()) { + // This should actually be a unicast address assigned to the interface. + // + // TODO(gvisor.dev/issue/3556): This check is validating incorrect + // behaviour. We still include the test so that once the bug is + // resolved, this test will start to fail and the individual tasked + // with fixing this bug knows to also fix this test :). + EXPECT_EQ(received_pktinfo.ipi_spec_dst.s_addr, group.imr_multiaddr.s_addr); + } else { + EXPECT_EQ(received_pktinfo.ipi_spec_dst.s_addr, htonl(INADDR_LOOPBACK)); + } + EXPECT_EQ(received_pktinfo.ipi_addr.s_addr, group.imr_multiaddr.s_addr); +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.h b/test/syscalls/linux/socket_ipv4_udp_unbound.h index 8e07bfbbf..f64c57645 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound.h +++ b/test/syscalls/linux/socket_ipv4_udp_unbound.h @@ -20,8 +20,8 @@ namespace gvisor { namespace testing { -// Test fixture for tests that apply to pairs of IPv4 UDP sockets. -using IPv4UDPUnboundSocketPairTest = SocketPairTest; +// Test fixture for tests that apply to IPv4 UDP sockets. +using IPv4UDPUnboundSocketTest = SimpleSocketTest; } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc index 98ae414f3..b206137eb 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc @@ -41,63 +41,39 @@ TestAddress V4EmptyAddress() { return t; } -constexpr char kMulticastAddress[] = "224.0.2.1"; - -TestAddress V4Multicast() { - TestAddress t("V4Multicast"); - t.addr.ss_family = AF_INET; - t.addr_len = sizeof(sockaddr_in); - reinterpret_cast<sockaddr_in*>(&t.addr)->sin_addr.s_addr = - inet_addr(kMulticastAddress); - return t; -} - -TestAddress V4Broadcast() { - TestAddress t("V4Broadcast"); - t.addr.ss_family = AF_INET; - t.addr_len = sizeof(sockaddr_in); - reinterpret_cast<sockaddr_in*>(&t.addr)->sin_addr.s_addr = - htonl(INADDR_BROADCAST); - return t; -} - void IPv4UDPUnboundExternalNetworkingSocketTest::SetUp() { - got_if_infos_ = false; + // FIXME(b/137899561): Linux instance for syscall tests sometimes misses its + // IPv4 address on eth0. + found_net_interfaces_ = false; // Get interface list. - std::vector<std::string> if_names; ASSERT_NO_ERRNO(if_helper_.Load()); - if_names = if_helper_.InterfaceList(AF_INET); + std::vector<std::string> if_names = if_helper_.InterfaceList(AF_INET); if (if_names.size() != 2) { return; } // Figure out which interface is where. - int lo = 0, eth = 1; - if (if_names[lo] != "lo") { - lo = 1; - eth = 0; - } - - if (if_names[lo] != "lo") { - return; - } - - lo_if_idx_ = ASSERT_NO_ERRNO_AND_VALUE(if_helper_.GetIndex(if_names[lo])); - lo_if_addr_ = if_helper_.GetAddr(AF_INET, if_names[lo]); - if (lo_if_addr_ == nullptr) { + std::string lo = if_names[0]; + std::string eth = if_names[1]; + if (lo != "lo") std::swap(lo, eth); + if (lo != "lo") return; + + lo_if_idx_ = ASSERT_NO_ERRNO_AND_VALUE(if_helper_.GetIndex(lo)); + auto lo_if_addr = if_helper_.GetAddr(AF_INET, lo); + if (lo_if_addr == nullptr) { return; } - lo_if_sin_addr_ = reinterpret_cast<sockaddr_in*>(lo_if_addr_)->sin_addr; + lo_if_addr_ = *reinterpret_cast<const sockaddr_in*>(lo_if_addr); - eth_if_idx_ = ASSERT_NO_ERRNO_AND_VALUE(if_helper_.GetIndex(if_names[eth])); - eth_if_addr_ = if_helper_.GetAddr(AF_INET, if_names[eth]); - if (eth_if_addr_ == nullptr) { + eth_if_idx_ = ASSERT_NO_ERRNO_AND_VALUE(if_helper_.GetIndex(eth)); + auto eth_if_addr = if_helper_.GetAddr(AF_INET, eth); + if (eth_if_addr == nullptr) { return; } - eth_if_sin_addr_ = reinterpret_cast<sockaddr_in*>(eth_if_addr_)->sin_addr; + eth_if_addr_ = *reinterpret_cast<const sockaddr_in*>(eth_if_addr); - got_if_infos_ = true; + found_net_interfaces_ = true; } // Verifies that a newly instantiated UDP socket does not have the @@ -136,6 +112,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, SetUDPBroadcast) { // the destination port number. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, UDPBroadcastReceivedOnExpectedPort) { + SKIP_IF(!found_net_interfaces_); auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto rcvr1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto rcvr2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); @@ -211,9 +188,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // not a unicast address. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, UDPBroadcastReceivedOnExpectedAddresses) { - // FIXME(b/137899561): Linux instance for syscall tests sometimes misses its - // IPv4 address on eth0. - SKIP_IF(!got_if_infos_); + SKIP_IF(!found_net_interfaces_); auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto rcvr1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); @@ -262,7 +237,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // Bind the non-receiving socket to the unicast ethernet address. auto norecv_addr = rcv1_addr; reinterpret_cast<sockaddr_in*>(&norecv_addr.addr)->sin_addr = - eth_if_sin_addr_; + eth_if_addr_.sin_addr; ASSERT_THAT(bind(norcv->get(), reinterpret_cast<sockaddr*>(&norecv_addr.addr), norecv_addr.addr_len), SyscallSucceedsWithValue(0)); @@ -298,6 +273,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // (UDPBroadcastSendRecvOnSocketBoundToAny). TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, UDPBroadcastSendRecvOnSocketBoundToBroadcast) { + SKIP_IF(!found_net_interfaces_); auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Enable SO_BROADCAST. @@ -339,6 +315,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // (UDPBroadcastSendRecvOnSocketBoundToBroadcast). TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, UDPBroadcastSendRecvOnSocketBoundToAny) { + SKIP_IF(!found_net_interfaces_); auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Enable SO_BROADCAST. @@ -377,6 +354,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // Verifies that a UDP broadcast fails to send on a socket with SO_BROADCAST // disabled. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendBroadcast) { + SKIP_IF(!found_net_interfaces_); auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Broadcast a test message without having enabled SO_BROADCAST on the sending @@ -427,6 +405,8 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // multicast on gVisor. SKIP_IF(IsRunningOnGvisor()); + SKIP_IF(!found_net_interfaces_); + auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto bind_addr = V4Any(); @@ -461,6 +441,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // Check that multicast packets will be delivered to the sending socket without // setting an interface. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastSelf) { + SKIP_IF(!found_net_interfaces_); auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto bind_addr = V4Any(); @@ -504,6 +485,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastSelf) { // set interface and IP_MULTICAST_LOOP disabled. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastSelfLoopOff) { + SKIP_IF(!found_net_interfaces_); auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto bind_addr = V4Any(); @@ -554,6 +536,8 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastNoGroup) { // multicast on gVisor. SKIP_IF(IsRunningOnGvisor()); + SKIP_IF(!found_net_interfaces_); + auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); @@ -592,6 +576,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastNoGroup) { // Check that multicast packets will be delivered to another socket without // setting an interface. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticast) { + SKIP_IF(!found_net_interfaces_); auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); @@ -639,6 +624,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticast) { // set interface and IP_MULTICAST_LOOP disabled on the sending socket. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastSenderNoLoop) { + SKIP_IF(!found_net_interfaces_); auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); @@ -690,6 +676,8 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // setting an interface and IP_MULTICAST_LOOP disabled on the receiving socket. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastReceiverNoLoop) { + SKIP_IF(!found_net_interfaces_); + auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); @@ -742,6 +730,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // and both will receive data on it when bound to the ANY address. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastToTwoBoundToAny) { + SKIP_IF(!found_net_interfaces_); auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); std::unique_ptr<FileDescriptor> receivers[2] = { ASSERT_NO_ERRNO_AND_VALUE(NewSocket()), @@ -808,6 +797,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // and both will receive data on it when bound to the multicast address. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastToTwoBoundToMulticastAddress) { + SKIP_IF(!found_net_interfaces_); auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); std::unique_ptr<FileDescriptor> receivers[2] = { ASSERT_NO_ERRNO_AND_VALUE(NewSocket()), @@ -877,6 +867,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // multicast address, both will receive data. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastToTwoBoundToAnyAndMulticastAddress) { + SKIP_IF(!found_net_interfaces_); auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); std::unique_ptr<FileDescriptor> receivers[2] = { ASSERT_NO_ERRNO_AND_VALUE(NewSocket()), @@ -950,6 +941,8 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // is not a multicast address. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, IpMulticastLoopbackFromAddr) { + SKIP_IF(!found_net_interfaces_); + auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); @@ -1017,9 +1010,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // interface, a multicast packet sent out uses the latter as its source address. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, IpMulticastLoopbackIfNicAndAddr) { - // FIXME(b/137899561): Linux instance for syscall tests sometimes misses its - // IPv4 address on eth0. - SKIP_IF(!got_if_infos_); + SKIP_IF(!found_net_interfaces_); // Create receiver, bind to ANY and join the multicast group. auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); @@ -1048,7 +1039,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); ip_mreqn iface = {}; iface.imr_ifindex = lo_if_idx_; - iface.imr_address = eth_if_sin_addr_; + iface.imr_address = eth_if_addr_.sin_addr; ASSERT_THAT(setsockopt(sender->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, sizeof(iface)), SyscallSucceeds()); @@ -1078,16 +1069,14 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, SKIP_IF(IsRunningOnGvisor()); // Verify the received source address. - EXPECT_EQ(eth_if_sin_addr_.s_addr, src_addr_in->sin_addr.s_addr); + EXPECT_EQ(eth_if_addr_.sin_addr.s_addr, src_addr_in->sin_addr.s_addr); } // Check that when we are bound to one interface we can set IP_MULTICAST_IF to // another interface. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, IpMulticastLoopbackBindToOneIfSetMcastIfToAnother) { - // FIXME(b/137899561): Linux instance for syscall tests sometimes misses its - // IPv4 address on eth0. - SKIP_IF(!got_if_infos_); + SKIP_IF(!found_net_interfaces_); // FIXME (b/137790511): When bound to one interface it is not possible to set // IP_MULTICAST_IF to a different interface. @@ -1095,7 +1084,8 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // Create sender and bind to eth interface. auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - ASSERT_THAT(bind(sender->get(), eth_if_addr_, sizeof(sockaddr_in)), + ASSERT_THAT(bind(sender->get(), reinterpret_cast<sockaddr*>(ð_if_addr_), + sizeof(eth_if_addr_)), SyscallSucceeds()); // Run through all possible combinations of index and address for @@ -1105,9 +1095,9 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, struct in_addr imr_address; } test_data[] = { {lo_if_idx_, {}}, - {0, lo_if_sin_addr_}, - {lo_if_idx_, lo_if_sin_addr_}, - {lo_if_idx_, eth_if_sin_addr_}, + {0, lo_if_addr_.sin_addr}, + {lo_if_idx_, lo_if_addr_.sin_addr}, + {lo_if_idx_, eth_if_addr_.sin_addr}, }; for (auto t : test_data) { ip_mreqn iface = {}; diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h index bec2e96ee..0e9e70e8e 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h @@ -29,17 +29,15 @@ class IPv4UDPUnboundExternalNetworkingSocketTest : public SimpleSocketTest { IfAddrHelper if_helper_; - // got_if_infos_ is set to false if SetUp() could not obtain all interface - // infos that we need. - bool got_if_infos_; + // found_net_interfaces_ is set to false if SetUp() could not obtain + // all interface infos that we need. + bool found_net_interfaces_; // Interface infos. int lo_if_idx_; int eth_if_idx_; - sockaddr* lo_if_addr_; - sockaddr* eth_if_addr_; - in_addr lo_if_sin_addr_; - in_addr eth_if_sin_addr_; + sockaddr_in lo_if_addr_; + sockaddr_in eth_if_addr_; }; } // namespace testing diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking_test.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking_test.cc index 9d4e1ab97..f6e64c157 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking_test.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking_test.cc @@ -12,15 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h" + #include <vector> #include "test/syscalls/linux/ip_socket_test_util.h" -#include "test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/test_util.h" namespace gvisor { namespace testing { +namespace { std::vector<SocketKind> GetSockets() { return ApplyVec<SocketKind>( @@ -31,5 +33,7 @@ std::vector<SocketKind> GetSockets() { INSTANTIATE_TEST_SUITE_P(IPv4UDPUnboundSockets, IPv4UDPUnboundExternalNetworkingSocketTest, ::testing::ValuesIn(GetSockets())); + +} // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_loopback.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_loopback.cc index cb0105471..f121c044d 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound_loopback.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_loopback.cc @@ -22,14 +22,11 @@ namespace gvisor { namespace testing { -std::vector<SocketPairKind> GetSocketPairs() { - return ApplyVec<SocketPairKind>( - IPv4UDPUnboundSocketPair, - AllBitwiseCombinations(List<int>{0, SOCK_NONBLOCK})); -} - -INSTANTIATE_TEST_SUITE_P(IPv4UDPSockets, IPv4UDPUnboundSocketPairTest, - ::testing::ValuesIn(GetSocketPairs())); +INSTANTIATE_TEST_SUITE_P( + IPv4UDPSockets, IPv4UDPUnboundSocketTest, + ::testing::ValuesIn(ApplyVec<SocketKind>(IPv4UDPUnboundSocket, + AllBitwiseCombinations(List<int>{ + 0, SOCK_NONBLOCK})))); } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_netdevice.cc b/test/syscalls/linux/socket_netdevice.cc index 765f8e0e4..5f8d7f981 100644 --- a/test/syscalls/linux/socket_netdevice.cc +++ b/test/syscalls/linux/socket_netdevice.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include <linux/ethtool.h> #include <linux/netlink.h> #include <linux/rtnetlink.h> #include <linux/sockios.h> @@ -49,6 +50,7 @@ TEST(NetdeviceTest, Loopback) { // Check that the loopback is zero hardware address. ASSERT_THAT(ioctl(sock.get(), SIOCGIFHWADDR, &ifr), SyscallSucceeds()); + EXPECT_EQ(ifr.ifr_hwaddr.sa_family, ARPHRD_LOOPBACK); EXPECT_EQ(ifr.ifr_hwaddr.sa_data[0], 0); EXPECT_EQ(ifr.ifr_hwaddr.sa_data[1], 0); EXPECT_EQ(ifr.ifr_hwaddr.sa_data[2], 0); @@ -68,7 +70,8 @@ TEST(NetdeviceTest, Netmask) { // Use a netlink socket to get the netmask, which we'll then compare to the // netmask obtained via ioctl. - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get())); struct request { @@ -90,7 +93,7 @@ TEST(NetdeviceTest, Netmask) { int prefixlen = -1; ASSERT_NO_ERRNO(NetlinkRequestResponse( fd, &req, sizeof(req), - [&](const struct nlmsghdr *hdr) { + [&](const struct nlmsghdr* hdr) { EXPECT_THAT(hdr->nlmsg_type, AnyOf(Eq(RTM_NEWADDR), Eq(NLMSG_DONE))); EXPECT_TRUE((hdr->nlmsg_flags & NLM_F_MULTI) == NLM_F_MULTI) @@ -106,8 +109,8 @@ TEST(NetdeviceTest, Netmask) { // RTM_NEWADDR contains at least the header and ifaddrmsg. EXPECT_GE(hdr->nlmsg_len, sizeof(*hdr) + sizeof(struct ifaddrmsg)); - struct ifaddrmsg *ifaddrmsg = - reinterpret_cast<struct ifaddrmsg *>(NLMSG_DATA(hdr)); + struct ifaddrmsg* ifaddrmsg = + reinterpret_cast<struct ifaddrmsg*>(NLMSG_DATA(hdr)); if (ifaddrmsg->ifa_index == static_cast<uint32_t>(ifr.ifr_ifindex) && ifaddrmsg->ifa_family == AF_INET) { prefixlen = ifaddrmsg->ifa_prefixlen; @@ -126,8 +129,8 @@ TEST(NetdeviceTest, Netmask) { snprintf(ifr.ifr_name, IFNAMSIZ, "lo"); ASSERT_THAT(ioctl(sock.get(), SIOCGIFNETMASK, &ifr), SyscallSucceeds()); EXPECT_EQ(ifr.ifr_netmask.sa_family, AF_INET); - struct sockaddr_in *sin = - reinterpret_cast<struct sockaddr_in *>(&ifr.ifr_netmask); + struct sockaddr_in* sin = + reinterpret_cast<struct sockaddr_in*>(&ifr.ifr_netmask); EXPECT_EQ(sin->sin_addr.s_addr, mask); } @@ -177,6 +180,27 @@ TEST(NetdeviceTest, InterfaceMTU) { EXPECT_GT(ifr.ifr_mtu, 0); } +TEST(NetdeviceTest, EthtoolGetTSInfo) { + FileDescriptor sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + + struct ethtool_ts_info tsi = {}; + tsi.cmd = ETHTOOL_GET_TS_INFO; // Get NIC's Timestamping capabilities. + + // Prepare the request. + struct ifreq ifr = {}; + snprintf(ifr.ifr_name, IFNAMSIZ, "lo"); + ifr.ifr_data = (void*)&tsi; + + // Check that SIOCGIFMTU returns a nonzero MTU. + if (IsRunningOnGvisor()) { + ASSERT_THAT(ioctl(sock.get(), SIOCETHTOOL, &ifr), + SyscallFailsWithErrno(EOPNOTSUPP)); + return; + } + ASSERT_THAT(ioctl(sock.get(), SIOCETHTOOL, &ifr), SyscallSucceeds()); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/socket_netlink.cc b/test/syscalls/linux/socket_netlink.cc new file mode 100644 index 000000000..4ec0fd4fa --- /dev/null +++ b/test/syscalls/linux/socket_netlink.cc @@ -0,0 +1,153 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <linux/netlink.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/test_util.h" + +// Tests for all netlink socket protocols. + +namespace gvisor { +namespace testing { + +namespace { + +// NetlinkTest parameter is the protocol to test. +using NetlinkTest = ::testing::TestWithParam<int>; + +// Netlink sockets must be SOCK_DGRAM or SOCK_RAW. +TEST_P(NetlinkTest, Types) { + const int protocol = GetParam(); + + EXPECT_THAT(socket(AF_NETLINK, SOCK_STREAM, protocol), + SyscallFailsWithErrno(ESOCKTNOSUPPORT)); + EXPECT_THAT(socket(AF_NETLINK, SOCK_SEQPACKET, protocol), + SyscallFailsWithErrno(ESOCKTNOSUPPORT)); + EXPECT_THAT(socket(AF_NETLINK, SOCK_RDM, protocol), + SyscallFailsWithErrno(ESOCKTNOSUPPORT)); + EXPECT_THAT(socket(AF_NETLINK, SOCK_DCCP, protocol), + SyscallFailsWithErrno(ESOCKTNOSUPPORT)); + EXPECT_THAT(socket(AF_NETLINK, SOCK_PACKET, protocol), + SyscallFailsWithErrno(ESOCKTNOSUPPORT)); + + int fd; + EXPECT_THAT(fd = socket(AF_NETLINK, SOCK_DGRAM, protocol), SyscallSucceeds()); + EXPECT_THAT(close(fd), SyscallSucceeds()); + + EXPECT_THAT(fd = socket(AF_NETLINK, SOCK_RAW, protocol), SyscallSucceeds()); + EXPECT_THAT(close(fd), SyscallSucceeds()); +} + +TEST_P(NetlinkTest, AutomaticPort) { + const int protocol = GetParam(); + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_NETLINK, SOCK_RAW, protocol)); + + struct sockaddr_nl addr = {}; + addr.nl_family = AF_NETLINK; + + EXPECT_THAT( + bind(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)), + SyscallSucceeds()); + + socklen_t addrlen = sizeof(addr); + EXPECT_THAT(getsockname(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), + &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, sizeof(addr)); + // This is the only netlink socket in the process, so it should get the PID as + // the port id. + // + // N.B. Another process could theoretically have explicitly reserved our pid + // as a port ID, but that is very unlikely. + EXPECT_EQ(addr.nl_pid, getpid()); +} + +// Calling connect automatically binds to an automatic port. +TEST_P(NetlinkTest, ConnectBinds) { + const int protocol = GetParam(); + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_NETLINK, SOCK_RAW, protocol)); + + struct sockaddr_nl addr = {}; + addr.nl_family = AF_NETLINK; + + EXPECT_THAT(connect(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), + sizeof(addr)), + SyscallSucceeds()); + + socklen_t addrlen = sizeof(addr); + EXPECT_THAT(getsockname(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), + &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, sizeof(addr)); + + // Each test is running in a pid namespace, so another process can explicitly + // reserve our pid as a port ID. In this case, a negative portid value will be + // set. + if (static_cast<pid_t>(addr.nl_pid) > 0) { + EXPECT_EQ(addr.nl_pid, getpid()); + } + + memset(&addr, 0, sizeof(addr)); + addr.nl_family = AF_NETLINK; + + // Connecting again is allowed, but keeps the same port. + EXPECT_THAT(connect(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), + sizeof(addr)), + SyscallSucceeds()); + + addrlen = sizeof(addr); + EXPECT_THAT(getsockname(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), + &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, sizeof(addr)); + EXPECT_EQ(addr.nl_pid, getpid()); +} + +TEST_P(NetlinkTest, GetPeerName) { + const int protocol = GetParam(); + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_NETLINK, SOCK_RAW, protocol)); + + struct sockaddr_nl addr = {}; + socklen_t addrlen = sizeof(addr); + + EXPECT_THAT(getpeername(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), + &addrlen), + SyscallSucceeds()); + + EXPECT_EQ(addrlen, sizeof(addr)); + EXPECT_EQ(addr.nl_family, AF_NETLINK); + // Peer is the kernel if we didn't connect elsewhere. + EXPECT_EQ(addr.nl_pid, 0); +} + +INSTANTIATE_TEST_SUITE_P(ProtocolTest, NetlinkTest, + ::testing::Values(NETLINK_ROUTE, + NETLINK_KOBJECT_UEVENT)); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_netlink_route.cc b/test/syscalls/linux/socket_netlink_route.cc index dd4a11655..b3fcf8e7c 100644 --- a/test/syscalls/linux/socket_netlink_route.cc +++ b/test/syscalls/linux/socket_netlink_route.cc @@ -14,6 +14,7 @@ #include <arpa/inet.h> #include <ifaddrs.h> +#include <linux/if.h> #include <linux/netlink.h> #include <linux/rtnetlink.h> #include <sys/socket.h> @@ -25,8 +26,10 @@ #include "gtest/gtest.h" #include "absl/strings/str_format.h" +#include "test/syscalls/linux/socket_netlink_route_util.h" #include "test/syscalls/linux/socket_netlink_util.h" #include "test/syscalls/linux/socket_test_util.h" +#include "test/util/capability_util.h" #include "test/util/cleanup.h" #include "test/util/file_descriptor.h" #include "test/util/test_util.h" @@ -38,115 +41,12 @@ namespace testing { namespace { +constexpr uint32_t kSeq = 12345; + using ::testing::AnyOf; using ::testing::Eq; -// Netlink sockets must be SOCK_DGRAM or SOCK_RAW. -TEST(NetlinkRouteTest, Types) { - EXPECT_THAT(socket(AF_NETLINK, SOCK_STREAM, NETLINK_ROUTE), - SyscallFailsWithErrno(ESOCKTNOSUPPORT)); - EXPECT_THAT(socket(AF_NETLINK, SOCK_SEQPACKET, NETLINK_ROUTE), - SyscallFailsWithErrno(ESOCKTNOSUPPORT)); - EXPECT_THAT(socket(AF_NETLINK, SOCK_RDM, NETLINK_ROUTE), - SyscallFailsWithErrno(ESOCKTNOSUPPORT)); - EXPECT_THAT(socket(AF_NETLINK, SOCK_DCCP, NETLINK_ROUTE), - SyscallFailsWithErrno(ESOCKTNOSUPPORT)); - EXPECT_THAT(socket(AF_NETLINK, SOCK_PACKET, NETLINK_ROUTE), - SyscallFailsWithErrno(ESOCKTNOSUPPORT)); - - int fd; - EXPECT_THAT(fd = socket(AF_NETLINK, SOCK_DGRAM, NETLINK_ROUTE), - SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); - - EXPECT_THAT(fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE), - SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); -} - -TEST(NetlinkRouteTest, AutomaticPort) { - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE)); - - struct sockaddr_nl addr = {}; - addr.nl_family = AF_NETLINK; - - EXPECT_THAT( - bind(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)), - SyscallSucceeds()); - - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), - &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, sizeof(addr)); - // This is the only netlink socket in the process, so it should get the PID as - // the port id. - // - // N.B. Another process could theoretically have explicitly reserved our pid - // as a port ID, but that is very unlikely. - EXPECT_EQ(addr.nl_pid, getpid()); -} - -// Calling connect automatically binds to an automatic port. -TEST(NetlinkRouteTest, ConnectBinds) { - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE)); - - struct sockaddr_nl addr = {}; - addr.nl_family = AF_NETLINK; - - EXPECT_THAT(connect(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), - sizeof(addr)), - SyscallSucceeds()); - - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), - &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, sizeof(addr)); - - // Each test is running in a pid namespace, so another process can explicitly - // reserve our pid as a port ID. In this case, a negative portid value will be - // set. - if (static_cast<pid_t>(addr.nl_pid) > 0) { - EXPECT_EQ(addr.nl_pid, getpid()); - } - - memset(&addr, 0, sizeof(addr)); - addr.nl_family = AF_NETLINK; - - // Connecting again is allowed, but keeps the same port. - EXPECT_THAT(connect(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), - sizeof(addr)), - SyscallSucceeds()); - - addrlen = sizeof(addr); - EXPECT_THAT(getsockname(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), - &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, sizeof(addr)); - EXPECT_EQ(addr.nl_pid, getpid()); -} - -TEST(NetlinkRouteTest, GetPeerName) { - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE)); - - struct sockaddr_nl addr = {}; - socklen_t addrlen = sizeof(addr); - - EXPECT_THAT(getpeername(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), - &addrlen), - SyscallSucceeds()); - - EXPECT_EQ(addrlen, sizeof(addr)); - EXPECT_EQ(addr.nl_family, AF_NETLINK); - // Peer is the kernel if we didn't connect elsewhere. - EXPECT_EQ(addr.nl_pid, 0); -} - -// Parameters for GetSockOpt test. They are: +// Parameters for SockOptTest. They are: // 0: Socket option to query. // 1: A predicate to run on the returned sockopt value. Should return true if // the value is considered ok. @@ -195,7 +95,8 @@ INSTANTIATE_TEST_SUITE_P( std::make_tuple(SO_DOMAIN, IsEqual(AF_NETLINK), absl::StrFormat("AF_NETLINK (%d)", AF_NETLINK)), std::make_tuple(SO_PROTOCOL, IsEqual(NETLINK_ROUTE), - absl::StrFormat("NETLINK_ROUTE (%d)", NETLINK_ROUTE)))); + absl::StrFormat("NETLINK_ROUTE (%d)", NETLINK_ROUTE)), + std::make_tuple(SO_PASSCRED, IsEqual(0), "0"))); // Validates the reponses to RTM_GETLINK + NLM_F_DUMP. void CheckGetLinkResponse(const struct nlmsghdr* hdr, int seq, int port) { @@ -218,55 +119,170 @@ void CheckGetLinkResponse(const struct nlmsghdr* hdr, int seq, int port) { } TEST(NetlinkRouteTest, GetLinkDump) { - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get())); + // Loopback is common among all tests, check that it's found. + bool loopbackFound = false; + ASSERT_NO_ERRNO(DumpLinks(fd, kSeq, [&](const struct nlmsghdr* hdr) { + CheckGetLinkResponse(hdr, kSeq, port); + if (hdr->nlmsg_type != RTM_NEWLINK) { + return; + } + ASSERT_GE(hdr->nlmsg_len, NLMSG_SPACE(sizeof(struct ifinfomsg))); + const struct ifinfomsg* msg = + reinterpret_cast<const struct ifinfomsg*>(NLMSG_DATA(hdr)); + std::cout << "Found interface idx=" << msg->ifi_index + << ", type=" << std::hex << msg->ifi_type << std::endl; + if (msg->ifi_type == ARPHRD_LOOPBACK) { + loopbackFound = true; + EXPECT_NE(msg->ifi_flags & IFF_LOOPBACK, 0); + } + })); + EXPECT_TRUE(loopbackFound); +} + +// CheckLinkMsg checks a netlink message against an expected link. +void CheckLinkMsg(const struct nlmsghdr* hdr, const Link& link) { + ASSERT_THAT(hdr->nlmsg_type, Eq(RTM_NEWLINK)); + ASSERT_GE(hdr->nlmsg_len, NLMSG_SPACE(sizeof(struct ifinfomsg))); + const struct ifinfomsg* msg = + reinterpret_cast<const struct ifinfomsg*>(NLMSG_DATA(hdr)); + EXPECT_EQ(msg->ifi_index, link.index); + + const struct rtattr* rta = FindRtAttr(hdr, msg, IFLA_IFNAME); + EXPECT_NE(nullptr, rta) << "IFLA_IFNAME not found in message."; + if (rta != nullptr) { + std::string name(reinterpret_cast<const char*>(RTA_DATA(rta))); + EXPECT_EQ(name, link.name); + } +} + +TEST(NetlinkRouteTest, GetLinkByIndex) { + Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + struct request { struct nlmsghdr hdr; struct ifinfomsg ifm; }; - constexpr uint32_t kSeq = 12345; - struct request req = {}; req.hdr.nlmsg_len = sizeof(req); req.hdr.nlmsg_type = RTM_GETLINK; - req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; + req.hdr.nlmsg_flags = NLM_F_REQUEST; req.hdr.nlmsg_seq = kSeq; req.ifm.ifi_family = AF_UNSPEC; + req.ifm.ifi_index = loopback_link.index; - // Loopback is common among all tests, check that it's found. - bool loopbackFound = false; + bool found = false; ASSERT_NO_ERRNO(NetlinkRequestResponse( fd, &req, sizeof(req), [&](const struct nlmsghdr* hdr) { - CheckGetLinkResponse(hdr, kSeq, port); - if (hdr->nlmsg_type != RTM_NEWLINK) { - return; - } - ASSERT_GE(hdr->nlmsg_len, NLMSG_SPACE(sizeof(struct ifinfomsg))); - const struct ifinfomsg* msg = - reinterpret_cast<const struct ifinfomsg*>(NLMSG_DATA(hdr)); - std::cout << "Found interface idx=" << msg->ifi_index - << ", type=" << std::hex << msg->ifi_type; - if (msg->ifi_type == ARPHRD_LOOPBACK) { - loopbackFound = true; - EXPECT_NE(msg->ifi_flags & IFF_LOOPBACK, 0); - } + CheckLinkMsg(hdr, loopback_link); + found = true; }, false)); - EXPECT_TRUE(loopbackFound); + EXPECT_TRUE(found) << "Netlink response does not contain any links."; } -TEST(NetlinkRouteTest, MsgHdrMsgUnsuppType) { - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); +TEST(NetlinkRouteTest, GetLinkByName) { + Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); struct request { struct nlmsghdr hdr; struct ifinfomsg ifm; + struct rtattr rtattr; + char ifname[IFNAMSIZ]; + char pad[NLMSG_ALIGNTO + RTA_ALIGNTO]; }; - constexpr uint32_t kSeq = 12345; + struct request req = {}; + req.hdr.nlmsg_type = RTM_GETLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST; + req.hdr.nlmsg_seq = kSeq; + req.ifm.ifi_family = AF_UNSPEC; + req.rtattr.rta_type = IFLA_IFNAME; + req.rtattr.rta_len = RTA_LENGTH(loopback_link.name.size() + 1); + strncpy(req.ifname, loopback_link.name.c_str(), sizeof(req.ifname)); + req.hdr.nlmsg_len = + NLMSG_LENGTH(sizeof(req.ifm)) + NLMSG_ALIGN(req.rtattr.rta_len); + + bool found = false; + ASSERT_NO_ERRNO(NetlinkRequestResponse( + fd, &req, sizeof(req), + [&](const struct nlmsghdr* hdr) { + CheckLinkMsg(hdr, loopback_link); + found = true; + }, + false)); + EXPECT_TRUE(found) << "Netlink response does not contain any links."; +} + +TEST(NetlinkRouteTest, GetLinkByIndexNotFound) { + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifm; + }; + + struct request req = {}; + req.hdr.nlmsg_len = sizeof(req); + req.hdr.nlmsg_type = RTM_GETLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST; + req.hdr.nlmsg_seq = kSeq; + req.ifm.ifi_family = AF_UNSPEC; + req.ifm.ifi_index = 1234590; + + EXPECT_THAT(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req)), + PosixErrorIs(ENODEV, ::testing::_)); +} + +TEST(NetlinkRouteTest, GetLinkByNameNotFound) { + const std::string name = "nodevice?!"; + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifm; + struct rtattr rtattr; + char ifname[IFNAMSIZ]; + char pad[NLMSG_ALIGNTO + RTA_ALIGNTO]; + }; + + struct request req = {}; + req.hdr.nlmsg_type = RTM_GETLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST; + req.hdr.nlmsg_seq = kSeq; + req.ifm.ifi_family = AF_UNSPEC; + req.rtattr.rta_type = IFLA_IFNAME; + req.rtattr.rta_len = RTA_LENGTH(name.size() + 1); + strncpy(req.ifname, name.c_str(), sizeof(req.ifname)); + req.hdr.nlmsg_len = + NLMSG_LENGTH(sizeof(req.ifm)) + NLMSG_ALIGN(req.rtattr.rta_len); + + EXPECT_THAT(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req)), + PosixErrorIs(ENODEV, ::testing::_)); +} + +TEST(NetlinkRouteTest, MsgHdrMsgUnsuppType) { + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifm; + }; struct request req = {}; req.hdr.nlmsg_len = sizeof(req); @@ -277,30 +293,19 @@ TEST(NetlinkRouteTest, MsgHdrMsgUnsuppType) { req.hdr.nlmsg_seq = kSeq; req.ifm.ifi_family = AF_UNSPEC; - ASSERT_NO_ERRNO(NetlinkRequestResponse( - fd, &req, sizeof(req), - [&](const struct nlmsghdr* hdr) { - EXPECT_THAT(hdr->nlmsg_type, Eq(NLMSG_ERROR)); - EXPECT_EQ(hdr->nlmsg_seq, kSeq); - EXPECT_GE(hdr->nlmsg_len, sizeof(*hdr) + sizeof(struct nlmsgerr)); - - const struct nlmsgerr* msg = - reinterpret_cast<const struct nlmsgerr*>(NLMSG_DATA(hdr)); - EXPECT_EQ(msg->error, -EOPNOTSUPP); - }, - true)); + EXPECT_THAT(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req)), + PosixErrorIs(EOPNOTSUPP, ::testing::_)); } TEST(NetlinkRouteTest, MsgHdrMsgTrunc) { - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); struct request { struct nlmsghdr hdr; struct ifinfomsg ifm; }; - constexpr uint32_t kSeq = 12345; - struct request req = {}; req.hdr.nlmsg_len = sizeof(req); req.hdr.nlmsg_type = RTM_GETLINK; @@ -331,15 +336,14 @@ TEST(NetlinkRouteTest, MsgHdrMsgTrunc) { } TEST(NetlinkRouteTest, MsgTruncMsgHdrMsgTrunc) { - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); struct request { struct nlmsghdr hdr; struct ifinfomsg ifm; }; - constexpr uint32_t kSeq = 12345; - struct request req = {}; req.hdr.nlmsg_len = sizeof(req); req.hdr.nlmsg_type = RTM_GETLINK; @@ -372,7 +376,8 @@ TEST(NetlinkRouteTest, MsgTruncMsgHdrMsgTrunc) { } TEST(NetlinkRouteTest, ControlMessageIgnored) { - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get())); struct request { @@ -381,8 +386,6 @@ TEST(NetlinkRouteTest, ControlMessageIgnored) { struct ifinfomsg ifm; }; - constexpr uint32_t kSeq = 12345; - struct request req = {}; // This control message is ignored. We still receive a response for the @@ -407,7 +410,8 @@ TEST(NetlinkRouteTest, ControlMessageIgnored) { } TEST(NetlinkRouteTest, GetAddrDump) { - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get())); struct request { @@ -415,8 +419,6 @@ TEST(NetlinkRouteTest, GetAddrDump) { struct rtgenmsg rgm; }; - constexpr uint32_t kSeq = 12345; - struct request req; req.hdr.nlmsg_len = sizeof(req); req.hdr.nlmsg_type = RTM_GETADDR; @@ -465,9 +467,59 @@ TEST(NetlinkRouteTest, LookupAll) { ASSERT_GT(count, 0); } +TEST(NetlinkRouteTest, AddAddr) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifaddrmsg ifa; + struct rtattr rtattr; + struct in_addr addr; + char pad[NLMSG_ALIGNTO + RTA_ALIGNTO]; + }; + + struct request req = {}; + req.hdr.nlmsg_type = RTM_NEWADDR; + req.hdr.nlmsg_seq = kSeq; + req.ifa.ifa_family = AF_INET; + req.ifa.ifa_prefixlen = 24; + req.ifa.ifa_flags = 0; + req.ifa.ifa_scope = 0; + req.ifa.ifa_index = loopback_link.index; + req.rtattr.rta_type = IFA_LOCAL; + req.rtattr.rta_len = RTA_LENGTH(sizeof(req.addr)); + inet_pton(AF_INET, "10.0.0.1", &req.addr); + req.hdr.nlmsg_len = + NLMSG_LENGTH(sizeof(req.ifa)) + NLMSG_ALIGN(req.rtattr.rta_len); + + // Create should succeed, as no such address in kernel. + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_CREATE | NLM_F_ACK; + EXPECT_NO_ERRNO( + NetlinkRequestAckOrError(fd, req.hdr.nlmsg_seq, &req, req.hdr.nlmsg_len)); + + // Replace an existing address should succeed. + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_REPLACE | NLM_F_ACK; + req.hdr.nlmsg_seq++; + EXPECT_NO_ERRNO( + NetlinkRequestAckOrError(fd, req.hdr.nlmsg_seq, &req, req.hdr.nlmsg_len)); + + // Create exclusive should fail, as we created the address above. + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL | NLM_F_ACK; + req.hdr.nlmsg_seq++; + EXPECT_THAT( + NetlinkRequestAckOrError(fd, req.hdr.nlmsg_seq, &req, req.hdr.nlmsg_len), + PosixErrorIs(EEXIST, ::testing::_)); +} + // GetRouteDump tests a RTM_GETROUTE + NLM_F_DUMP request. TEST(NetlinkRouteTest, GetRouteDump) { - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get())); struct request { @@ -475,8 +527,6 @@ TEST(NetlinkRouteTest, GetRouteDump) { struct rtmsg rtm; }; - constexpr uint32_t kSeq = 12345; - struct request req = {}; req.hdr.nlmsg_len = sizeof(req); req.hdr.nlmsg_type = RTM_GETROUTE; @@ -527,7 +577,10 @@ TEST(NetlinkRouteTest, GetRouteDump) { std::cout << std::endl; - if (msg->rtm_table == RT_TABLE_MAIN) { + // If the test is running in a new network namespace, it will have only + // the local route table. + if (msg->rtm_table == RT_TABLE_MAIN || + (!IsRunningOnGvisor() && msg->rtm_table == RT_TABLE_LOCAL)) { routeFound = true; dstFound = rtDstFound && dstFound; } @@ -539,19 +592,102 @@ TEST(NetlinkRouteTest, GetRouteDump) { EXPECT_TRUE(dstFound); } +// GetRouteRequest tests a RTM_GETROUTE request with RTM_F_LOOKUP_TABLE flag. +TEST(NetlinkRouteTest, GetRouteRequest) { + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get())); + + struct request { + struct nlmsghdr hdr; + struct rtmsg rtm; + struct nlattr nla; + struct in_addr sin_addr; + }; + + constexpr uint32_t kSeq = 12345; + + struct request req = {}; + req.hdr.nlmsg_len = sizeof(req); + req.hdr.nlmsg_type = RTM_GETROUTE; + req.hdr.nlmsg_flags = NLM_F_REQUEST; + req.hdr.nlmsg_seq = kSeq; + + req.rtm.rtm_family = AF_INET; + req.rtm.rtm_dst_len = 32; + req.rtm.rtm_src_len = 0; + req.rtm.rtm_tos = 0; + req.rtm.rtm_table = RT_TABLE_UNSPEC; + req.rtm.rtm_protocol = RTPROT_UNSPEC; + req.rtm.rtm_scope = RT_SCOPE_UNIVERSE; + req.rtm.rtm_type = RTN_UNSPEC; + req.rtm.rtm_flags = RTM_F_LOOKUP_TABLE; + + req.nla.nla_len = 8; + req.nla.nla_type = RTA_DST; + inet_aton("127.0.0.2", &req.sin_addr); + + bool rtDstFound = false; + ASSERT_NO_ERRNO(NetlinkRequestResponseSingle( + fd, &req, sizeof(req), [&](const struct nlmsghdr* hdr) { + // Validate the reponse to RTM_GETROUTE request with RTM_F_LOOKUP_TABLE + // flag. + EXPECT_THAT(hdr->nlmsg_type, RTM_NEWROUTE); + + EXPECT_TRUE(hdr->nlmsg_flags == 0) << std::hex << hdr->nlmsg_flags; + + EXPECT_EQ(hdr->nlmsg_seq, kSeq); + EXPECT_EQ(hdr->nlmsg_pid, port); + + // RTM_NEWROUTE contains at least the header and rtmsg. + ASSERT_GE(hdr->nlmsg_len, NLMSG_SPACE(sizeof(struct rtmsg))); + const struct rtmsg* msg = + reinterpret_cast<const struct rtmsg*>(NLMSG_DATA(hdr)); + + // NOTE: rtmsg fields are char fields. + std::cout << "Found route table=" << static_cast<int>(msg->rtm_table) + << ", protocol=" << static_cast<int>(msg->rtm_protocol) + << ", scope=" << static_cast<int>(msg->rtm_scope) + << ", type=" << static_cast<int>(msg->rtm_type); + + EXPECT_EQ(msg->rtm_family, AF_INET); + EXPECT_EQ(msg->rtm_dst_len, 32); + EXPECT_TRUE((msg->rtm_flags & RTM_F_CLONED) == RTM_F_CLONED) + << std::hex << msg->rtm_flags; + + int len = RTM_PAYLOAD(hdr); + std::cout << ", len=" << len; + for (struct rtattr* attr = RTM_RTA(msg); RTA_OK(attr, len); + attr = RTA_NEXT(attr, len)) { + if (attr->rta_type == RTA_DST) { + char address[INET_ADDRSTRLEN] = {}; + inet_ntop(AF_INET, RTA_DATA(attr), address, sizeof(address)); + std::cout << ", dst=" << address; + rtDstFound = true; + } else if (attr->rta_type == RTA_OIF) { + const char* oif = reinterpret_cast<const char*>(RTA_DATA(attr)); + std::cout << ", oif=" << oif; + } + } + + std::cout << std::endl; + })); + // Found RTA_DST for RTM_F_LOOKUP_TABLE. + EXPECT_TRUE(rtDstFound); +} + // RecvmsgTrunc tests the recvmsg MSG_TRUNC flag with zero length output // buffer. MSG_TRUNC with a zero length buffer should consume subsequent // messages off the socket. TEST(NetlinkRouteTest, RecvmsgTrunc) { - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); struct request { struct nlmsghdr hdr; struct rtgenmsg rgm; }; - constexpr uint32_t kSeq = 12345; - struct request req; req.hdr.nlmsg_len = sizeof(req); req.hdr.nlmsg_type = RTM_GETADDR; @@ -619,15 +755,14 @@ TEST(NetlinkRouteTest, RecvmsgTrunc) { // it, so a properly sized buffer can be allocated to store the message. This // test tests that scenario. TEST(NetlinkRouteTest, RecvmsgTruncPeek) { - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); struct request { struct nlmsghdr hdr; struct rtgenmsg rgm; }; - constexpr uint32_t kSeq = 12345; - struct request req; req.hdr.nlmsg_len = sizeof(req); req.hdr.nlmsg_type = RTM_GETADDR; @@ -692,6 +827,111 @@ TEST(NetlinkRouteTest, RecvmsgTruncPeek) { } while (type != NLMSG_DONE && type != NLMSG_ERROR); } +// No SCM_CREDENTIALS are received without SO_PASSCRED set. +TEST(NetlinkRouteTest, NoPasscredNoCreds) { + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + + ASSERT_THAT(setsockopt(fd.get(), SOL_SOCKET, SO_PASSCRED, &kSockOptOff, + sizeof(kSockOptOff)), + SyscallSucceeds()); + + struct request { + struct nlmsghdr hdr; + struct rtgenmsg rgm; + }; + + struct request req; + req.hdr.nlmsg_len = sizeof(req); + req.hdr.nlmsg_type = RTM_GETADDR; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; + req.hdr.nlmsg_seq = kSeq; + req.rgm.rtgen_family = AF_UNSPEC; + + struct iovec iov = {}; + iov.iov_base = &req; + iov.iov_len = sizeof(req); + + struct msghdr msg = {}; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(sendmsg)(fd.get(), &msg, 0), SyscallSucceeds()); + + iov.iov_base = NULL; + iov.iov_len = 0; + + char control[CMSG_SPACE(sizeof(struct ucred))] = {}; + msg.msg_control = control; + msg.msg_controllen = sizeof(control); + + // Note: This test assumes at least one message is returned by the + // RTM_GETADDR request. + ASSERT_THAT(RetryEINTR(recvmsg)(fd.get(), &msg, 0), SyscallSucceeds()); + + // No control messages. + EXPECT_EQ(CMSG_FIRSTHDR(&msg), nullptr); +} + +// SCM_CREDENTIALS are received with SO_PASSCRED set. +TEST(NetlinkRouteTest, PasscredCreds) { + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + + ASSERT_THAT(setsockopt(fd.get(), SOL_SOCKET, SO_PASSCRED, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + struct request { + struct nlmsghdr hdr; + struct rtgenmsg rgm; + }; + + struct request req; + req.hdr.nlmsg_len = sizeof(req); + req.hdr.nlmsg_type = RTM_GETADDR; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; + req.hdr.nlmsg_seq = kSeq; + req.rgm.rtgen_family = AF_UNSPEC; + + struct iovec iov = {}; + iov.iov_base = &req; + iov.iov_len = sizeof(req); + + struct msghdr msg = {}; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(sendmsg)(fd.get(), &msg, 0), SyscallSucceeds()); + + iov.iov_base = NULL; + iov.iov_len = 0; + + char control[CMSG_SPACE(sizeof(struct ucred))] = {}; + msg.msg_control = control; + msg.msg_controllen = sizeof(control); + + // Note: This test assumes at least one message is returned by the + // RTM_GETADDR request. + ASSERT_THAT(RetryEINTR(recvmsg)(fd.get(), &msg, 0), SyscallSucceeds()); + + struct ucred creds; + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(creds))); + ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET); + ASSERT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS); + + memcpy(&creds, CMSG_DATA(cmsg), sizeof(creds)); + + // The peer is the kernel, which is "PID" 0. + EXPECT_EQ(creds.pid, 0); + // The kernel identifies as root. Also allow nobody in case this test is + // running in a userns without root mapped. + EXPECT_THAT(creds.uid, AnyOf(Eq(0), Eq(65534))); + EXPECT_THAT(creds.gid, AnyOf(Eq(0), Eq(65534))); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/socket_netlink_route_util.cc b/test/syscalls/linux/socket_netlink_route_util.cc new file mode 100644 index 000000000..bde1dbb4d --- /dev/null +++ b/test/syscalls/linux/socket_netlink_route_util.cc @@ -0,0 +1,162 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/syscalls/linux/socket_netlink_route_util.h" + +#include <linux/if.h> +#include <linux/netlink.h> +#include <linux/rtnetlink.h> + +#include "test/syscalls/linux/socket_netlink_util.h" + +namespace gvisor { +namespace testing { +namespace { + +constexpr uint32_t kSeq = 12345; + +} // namespace + +PosixError DumpLinks( + const FileDescriptor& fd, uint32_t seq, + const std::function<void(const struct nlmsghdr* hdr)>& fn) { + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifm; + }; + + struct request req = {}; + req.hdr.nlmsg_len = sizeof(req); + req.hdr.nlmsg_type = RTM_GETLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; + req.hdr.nlmsg_seq = seq; + req.ifm.ifi_family = AF_UNSPEC; + + return NetlinkRequestResponse(fd, &req, sizeof(req), fn, false); +} + +PosixErrorOr<std::vector<Link>> DumpLinks() { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); + + std::vector<Link> links; + RETURN_IF_ERRNO(DumpLinks(fd, kSeq, [&](const struct nlmsghdr* hdr) { + if (hdr->nlmsg_type != RTM_NEWLINK || + hdr->nlmsg_len < NLMSG_SPACE(sizeof(struct ifinfomsg))) { + return; + } + const struct ifinfomsg* msg = + reinterpret_cast<const struct ifinfomsg*>(NLMSG_DATA(hdr)); + const auto* rta = FindRtAttr(hdr, msg, IFLA_IFNAME); + if (rta == nullptr) { + // Ignore links that do not have a name. + return; + } + + links.emplace_back(); + links.back().index = msg->ifi_index; + links.back().type = msg->ifi_type; + links.back().name = + std::string(reinterpret_cast<const char*>(RTA_DATA(rta))); + })); + return links; +} + +PosixErrorOr<Link> LoopbackLink() { + ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks()); + for (const auto& link : links) { + if (link.type == ARPHRD_LOOPBACK) { + return link; + } + } + return PosixError(ENOENT, "loopback link not found"); +} + +PosixError LinkAddLocalAddr(int index, int family, int prefixlen, + const void* addr, int addrlen) { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifaddrmsg ifaddr; + char attrbuf[512]; + }; + + struct request req = {}; + req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifaddr)); + req.hdr.nlmsg_type = RTM_NEWADDR; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; + req.hdr.nlmsg_seq = kSeq; + req.ifaddr.ifa_index = index; + req.ifaddr.ifa_family = family; + req.ifaddr.ifa_prefixlen = prefixlen; + + struct rtattr* rta = reinterpret_cast<struct rtattr*>( + reinterpret_cast<int8_t*>(&req) + NLMSG_ALIGN(req.hdr.nlmsg_len)); + rta->rta_type = IFA_LOCAL; + rta->rta_len = RTA_LENGTH(addrlen); + req.hdr.nlmsg_len = NLMSG_ALIGN(req.hdr.nlmsg_len) + RTA_LENGTH(addrlen); + memcpy(RTA_DATA(rta), addr, addrlen); + + return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len); +} + +PosixError LinkChangeFlags(int index, unsigned int flags, unsigned int change) { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifinfo; + char pad[NLMSG_ALIGNTO]; + }; + + struct request req = {}; + req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifinfo)); + req.hdr.nlmsg_type = RTM_NEWLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; + req.hdr.nlmsg_seq = kSeq; + req.ifinfo.ifi_index = index; + req.ifinfo.ifi_flags = flags; + req.ifinfo.ifi_change = change; + + return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len); +} + +PosixError LinkSetMacAddr(int index, const void* addr, int addrlen) { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifinfo; + char attrbuf[512]; + }; + + struct request req = {}; + req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifinfo)); + req.hdr.nlmsg_type = RTM_NEWLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; + req.hdr.nlmsg_seq = kSeq; + req.ifinfo.ifi_index = index; + + struct rtattr* rta = reinterpret_cast<struct rtattr*>( + reinterpret_cast<int8_t*>(&req) + NLMSG_ALIGN(req.hdr.nlmsg_len)); + rta->rta_type = IFLA_ADDRESS; + rta->rta_len = RTA_LENGTH(addrlen); + req.hdr.nlmsg_len = NLMSG_ALIGN(req.hdr.nlmsg_len) + RTA_LENGTH(addrlen); + memcpy(RTA_DATA(rta), addr, addrlen); + + return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len); +} + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_netlink_route_util.h b/test/syscalls/linux/socket_netlink_route_util.h new file mode 100644 index 000000000..149c4a7f6 --- /dev/null +++ b/test/syscalls/linux/socket_netlink_route_util.h @@ -0,0 +1,55 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_ROUTE_UTIL_H_ +#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_ROUTE_UTIL_H_ + +#include <linux/netlink.h> +#include <linux/rtnetlink.h> + +#include <vector> + +#include "test/syscalls/linux/socket_netlink_util.h" + +namespace gvisor { +namespace testing { + +struct Link { + int index; + int16_t type; + std::string name; +}; + +PosixError DumpLinks(const FileDescriptor& fd, uint32_t seq, + const std::function<void(const struct nlmsghdr* hdr)>& fn); + +PosixErrorOr<std::vector<Link>> DumpLinks(); + +// Returns the loopback link on the system. ENOENT if not found. +PosixErrorOr<Link> LoopbackLink(); + +// LinkAddLocalAddr sets IFA_LOCAL attribute on the interface. +PosixError LinkAddLocalAddr(int index, int family, int prefixlen, + const void* addr, int addrlen); + +// LinkChangeFlags changes interface flags. E.g. IFF_UP. +PosixError LinkChangeFlags(int index, unsigned int flags, unsigned int change); + +// LinkSetMacAddr sets IFLA_ADDRESS attribute of the interface. +PosixError LinkSetMacAddr(int index, const void* addr, int addrlen); + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_ROUTE_UTIL_H_ diff --git a/test/syscalls/linux/socket_netlink_uevent.cc b/test/syscalls/linux/socket_netlink_uevent.cc new file mode 100644 index 000000000..da425bed4 --- /dev/null +++ b/test/syscalls/linux/socket_netlink_uevent.cc @@ -0,0 +1,83 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <linux/filter.h> +#include <linux/netlink.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "test/syscalls/linux/socket_netlink_util.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/test_util.h" + +// Tests for NETLINK_KOBJECT_UEVENT sockets. +// +// gVisor never sends any messages on these sockets, so we don't test the events +// themselves. + +namespace gvisor { +namespace testing { + +namespace { + +// SO_PASSCRED can be enabled. Since no messages are sent in gVisor, we don't +// actually test receiving credentials. +TEST(NetlinkUeventTest, PassCred) { + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_KOBJECT_UEVENT)); + + EXPECT_THAT(setsockopt(fd.get(), SOL_SOCKET, SO_PASSCRED, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); +} + +// SO_DETACH_FILTER fails without a filter already installed. +TEST(NetlinkUeventTest, DetachNoFilter) { + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_KOBJECT_UEVENT)); + + int opt; + EXPECT_THAT( + setsockopt(fd.get(), SOL_SOCKET, SO_DETACH_FILTER, &opt, sizeof(opt)), + SyscallFailsWithErrno(ENOENT)); +} + +// We can attach a BPF filter. +TEST(NetlinkUeventTest, AttachFilter) { + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_KOBJECT_UEVENT)); + + // Minimal BPF program: a single ret. + struct sock_filter filter = {0x6, 0, 0, 0}; + struct sock_fprog prog = {}; + prog.len = 1; + prog.filter = &filter; + + EXPECT_THAT( + setsockopt(fd.get(), SOL_SOCKET, SO_ATTACH_FILTER, &prog, sizeof(prog)), + SyscallSucceeds()); + + int opt; + EXPECT_THAT( + setsockopt(fd.get(), SOL_SOCKET, SO_DETACH_FILTER, &opt, sizeof(opt)), + SyscallSucceeds()); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_netlink_util.cc b/test/syscalls/linux/socket_netlink_util.cc index fcb8f8a88..952eecfe8 100644 --- a/test/syscalls/linux/socket_netlink_util.cc +++ b/test/syscalls/linux/socket_netlink_util.cc @@ -12,24 +12,24 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include <sys/socket.h> +#include "test/syscalls/linux/socket_netlink_util.h" #include <linux/if_arp.h> #include <linux/netlink.h> #include <linux/rtnetlink.h> +#include <sys/socket.h> #include <vector> #include "absl/strings/str_cat.h" -#include "test/syscalls/linux/socket_netlink_util.h" #include "test/syscalls/linux/socket_test_util.h" namespace gvisor { namespace testing { -PosixErrorOr<FileDescriptor> NetlinkBoundSocket() { +PosixErrorOr<FileDescriptor> NetlinkBoundSocket(int protocol) { FileDescriptor fd; - ASSIGN_OR_RETURN_ERRNO(fd, Socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE)); + ASSIGN_OR_RETURN_ERRNO(fd, Socket(AF_NETLINK, SOCK_RAW, protocol)); struct sockaddr_nl addr = {}; addr.nl_family = AF_NETLINK; @@ -72,9 +72,10 @@ PosixError NetlinkRequestResponse( iov.iov_base = buf.data(); iov.iov_len = buf.size(); - // Response is a series of NLM_F_MULTI messages, ending with a NLMSG_DONE - // message. + // If NLM_F_MULTI is set, response is a series of messages that ends with a + // NLMSG_DONE message. int type = -1; + int flags = 0; do { int len; RETURN_ERROR_IF_SYSCALL_FAIL(len = RetryEINTR(recvmsg)(fd.get(), &msg, 0)); @@ -90,6 +91,7 @@ PosixError NetlinkRequestResponse( for (struct nlmsghdr* hdr = reinterpret_cast<struct nlmsghdr*>(buf.data()); NLMSG_OK(hdr, len); hdr = NLMSG_NEXT(hdr, len)) { fn(hdr); + flags = hdr->nlmsg_flags; type = hdr->nlmsg_type; // Done should include an integer payload for dump_done_errno. // See net/netlink/af_netlink.c:netlink_dump @@ -99,15 +101,87 @@ PosixError NetlinkRequestResponse( EXPECT_GE(hdr->nlmsg_len, NLMSG_LENGTH(sizeof(int))); } } - } while (type != NLMSG_DONE && type != NLMSG_ERROR); + } while ((flags & NLM_F_MULTI) && type != NLMSG_DONE && type != NLMSG_ERROR); if (expect_nlmsgerr) { EXPECT_EQ(type, NLMSG_ERROR); - } else { + } else if (flags & NLM_F_MULTI) { EXPECT_EQ(type, NLMSG_DONE); } return NoError(); } +PosixError NetlinkRequestResponseSingle( + const FileDescriptor& fd, void* request, size_t len, + const std::function<void(const struct nlmsghdr* hdr)>& fn) { + struct iovec iov = {}; + iov.iov_base = request; + iov.iov_len = len; + + struct msghdr msg = {}; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + // No destination required; it defaults to pid 0, the kernel. + + RETURN_ERROR_IF_SYSCALL_FAIL(RetryEINTR(sendmsg)(fd.get(), &msg, 0)); + + constexpr size_t kBufferSize = 4096; + std::vector<char> buf(kBufferSize); + iov.iov_base = buf.data(); + iov.iov_len = buf.size(); + + int ret; + RETURN_ERROR_IF_SYSCALL_FAIL(ret = RetryEINTR(recvmsg)(fd.get(), &msg, 0)); + + // We don't bother with the complexity of dealing with truncated messages. + // We must allocate a large enough buffer up front. + if ((msg.msg_flags & MSG_TRUNC) == MSG_TRUNC) { + return PosixError( + EIO, + absl::StrCat("Received truncated message with flags: ", msg.msg_flags)); + } + + for (struct nlmsghdr* hdr = reinterpret_cast<struct nlmsghdr*>(buf.data()); + NLMSG_OK(hdr, ret); hdr = NLMSG_NEXT(hdr, ret)) { + fn(hdr); + } + + return NoError(); +} + +PosixError NetlinkRequestAckOrError(const FileDescriptor& fd, uint32_t seq, + void* request, size_t len) { + // Dummy negative number for no error message received. + // We won't get a negative error number so there will be no confusion. + int err = -42; + RETURN_IF_ERRNO(NetlinkRequestResponse( + fd, request, len, + [&](const struct nlmsghdr* hdr) { + EXPECT_EQ(NLMSG_ERROR, hdr->nlmsg_type); + EXPECT_EQ(hdr->nlmsg_seq, seq); + EXPECT_GE(hdr->nlmsg_len, sizeof(*hdr) + sizeof(struct nlmsgerr)); + + const struct nlmsgerr* msg = + reinterpret_cast<const struct nlmsgerr*>(NLMSG_DATA(hdr)); + err = -msg->error; + }, + true)); + return PosixError(err); +} + +const struct rtattr* FindRtAttr(const struct nlmsghdr* hdr, + const struct ifinfomsg* msg, int16_t attr) { + const int ifi_space = NLMSG_SPACE(sizeof(*msg)); + int attrlen = hdr->nlmsg_len - ifi_space; + const struct rtattr* rta = reinterpret_cast<const struct rtattr*>( + reinterpret_cast<const uint8_t*>(hdr) + NLMSG_ALIGN(ifi_space)); + for (; RTA_OK(rta, attrlen); rta = RTA_NEXT(rta, attrlen)) { + if (rta->rta_type == attr) { + return rta; + } + } + return nullptr; +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_netlink_util.h b/test/syscalls/linux/socket_netlink_util.h index db8639a2f..e13ead406 100644 --- a/test/syscalls/linux/socket_netlink_util.h +++ b/test/syscalls/linux/socket_netlink_util.h @@ -15,6 +15,8 @@ #ifndef GVISOR_TEST_SYSCALLS_SOCKET_NETLINK_UTIL_H_ #define GVISOR_TEST_SYSCALLS_SOCKET_NETLINK_UTIL_H_ +#include <sys/socket.h> +// socket.h has to be included before if_arp.h. #include <linux/if_arp.h> #include <linux/netlink.h> #include <linux/rtnetlink.h> @@ -25,18 +27,35 @@ namespace gvisor { namespace testing { -// Returns a bound NETLINK_ROUTE socket. -PosixErrorOr<FileDescriptor> NetlinkBoundSocket(); +// Returns a bound netlink socket. +PosixErrorOr<FileDescriptor> NetlinkBoundSocket(int protocol); // Returns the port ID of the passed socket. PosixErrorOr<uint32_t> NetlinkPortID(int fd); -// Send the passed request and call fn will all response netlink messages. +// Send the passed request and call fn on all response netlink messages. +// +// To be used on requests with NLM_F_MULTI reponses. PosixError NetlinkRequestResponse( const FileDescriptor& fd, void* request, size_t len, const std::function<void(const struct nlmsghdr* hdr)>& fn, bool expect_nlmsgerr); +// Send the passed request and call fn on all response netlink messages. +// +// To be used on requests without NLM_F_MULTI reponses. +PosixError NetlinkRequestResponseSingle( + const FileDescriptor& fd, void* request, size_t len, + const std::function<void(const struct nlmsghdr* hdr)>& fn); + +// Send the passed request then expect and return an ack or error. +PosixError NetlinkRequestAckOrError(const FileDescriptor& fd, uint32_t seq, + void* request, size_t len); + +// Find rtnetlink attribute in message. +const struct rtattr* FindRtAttr(const struct nlmsghdr* hdr, + const struct ifinfomsg* msg, int16_t attr); + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_non_stream.cc b/test/syscalls/linux/socket_non_stream.cc index d91c5ed39..c61817f14 100644 --- a/test/syscalls/linux/socket_non_stream.cc +++ b/test/syscalls/linux/socket_non_stream.cc @@ -113,7 +113,7 @@ TEST_P(NonStreamSocketPairTest, RecvmsgMsghdrFlagMsgTrunc) { EXPECT_EQ(0, memcmp(received_data, sent_data, sizeof(received_data))); // Check that msghdr flags were updated. - EXPECT_EQ(msg.msg_flags, MSG_TRUNC); + EXPECT_EQ(msg.msg_flags & MSG_TRUNC, MSG_TRUNC); } // Stream sockets allow data sent with multiple sends to be peeked at in a @@ -193,7 +193,7 @@ TEST_P(NonStreamSocketPairTest, MsgTruncTruncationRecvmsgMsghdrFlagMsgTrunc) { EXPECT_EQ(0, memcmp(received_data, sent_data, sizeof(received_data))); // Check that msghdr flags were updated. - EXPECT_EQ(msg.msg_flags, MSG_TRUNC); + EXPECT_EQ(msg.msg_flags & MSG_TRUNC, MSG_TRUNC); } TEST_P(NonStreamSocketPairTest, MsgTruncSameSize) { @@ -224,5 +224,114 @@ TEST_P(NonStreamSocketPairTest, MsgTruncNotFull) { EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); } +// This test tests reading from a socket with MSG_TRUNC and a zero length +// receive buffer. The user should be able to get the message length. +TEST_P(NonStreamSocketPairTest, RecvmsgMsgTruncZeroLen) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[10]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + ASSERT_THAT( + RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0), + SyscallSucceedsWithValue(sizeof(sent_data))); + + // The receive buffer is of zero length. + char received_data[0] = {}; + + struct iovec iov; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + struct msghdr msg = {}; + msg.msg_flags = -1; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + // The syscall succeeds returning the full size of the message on the socket. + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_TRUNC), + SyscallSucceedsWithValue(sizeof(sent_data))); + + // Check that MSG_TRUNC is set on msghdr flags. + EXPECT_EQ(msg.msg_flags & MSG_TRUNC, MSG_TRUNC); +} + +// This test tests reading from a socket with MSG_TRUNC | MSG_PEEK and a zero +// length receive buffer. The user should be able to get the message length +// without reading data off the socket. +TEST_P(NonStreamSocketPairTest, RecvmsgMsgTruncMsgPeekZeroLen) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[10]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + ASSERT_THAT( + RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0), + SyscallSucceedsWithValue(sizeof(sent_data))); + + // The receive buffer is of zero length. + char peek_data[0] = {}; + + struct iovec peek_iov; + peek_iov.iov_base = peek_data; + peek_iov.iov_len = sizeof(peek_data); + struct msghdr peek_msg = {}; + peek_msg.msg_flags = -1; + peek_msg.msg_iov = &peek_iov; + peek_msg.msg_iovlen = 1; + + // The syscall succeeds returning the full size of the message on the socket. + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &peek_msg, + MSG_TRUNC | MSG_PEEK), + SyscallSucceedsWithValue(sizeof(sent_data))); + + // Check that MSG_TRUNC is set on msghdr flags because the receive buffer is + // smaller than the message size. + EXPECT_EQ(peek_msg.msg_flags & MSG_TRUNC, MSG_TRUNC); + + char received_data[sizeof(sent_data)] = {}; + + struct iovec received_iov; + received_iov.iov_base = received_data; + received_iov.iov_len = sizeof(received_data); + struct msghdr received_msg = {}; + received_msg.msg_flags = -1; + received_msg.msg_iov = &received_iov; + received_msg.msg_iovlen = 1; + + // Next we can read the actual data. + ASSERT_THAT( + RetryEINTR(recvmsg)(sockets->second_fd(), &received_msg, MSG_TRUNC), + SyscallSucceedsWithValue(sizeof(sent_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + // Check that MSG_TRUNC is not set on msghdr flags because we read the whole + // message. + EXPECT_EQ(received_msg.msg_flags & MSG_TRUNC, 0); +} + +// This test tests reading from a socket with MSG_TRUNC | MSG_PEEK and a zero +// length receive buffer and MSG_DONTWAIT. The user should be able to get an +// EAGAIN or EWOULDBLOCK error response. +TEST_P(NonStreamSocketPairTest, RecvmsgTruncPeekDontwaitZeroLen) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // NOTE: We don't send any data on the socket. + + // The receive buffer is of zero length. + char peek_data[0] = {}; + + struct iovec peek_iov; + peek_iov.iov_base = peek_data; + peek_iov.iov_len = sizeof(peek_data); + struct msghdr peek_msg = {}; + peek_msg.msg_flags = -1; + peek_msg.msg_iov = &peek_iov; + peek_msg.msg_iovlen = 1; + + // recvmsg fails with EAGAIN because no data is available on the socket. + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &peek_msg, + MSG_TRUNC | MSG_PEEK | MSG_DONTWAIT), + SyscallFailsWithErrno(EAGAIN)); +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_non_stream_blocking.cc b/test/syscalls/linux/socket_non_stream_blocking.cc index 62d87c1af..b052f6e61 100644 --- a/test/syscalls/linux/socket_non_stream_blocking.cc +++ b/test/syscalls/linux/socket_non_stream_blocking.cc @@ -25,6 +25,7 @@ #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" +#include "test/util/thread_util.h" namespace gvisor { namespace testing { @@ -44,5 +45,41 @@ TEST_P(BlockingNonStreamSocketPairTest, RecvLessThanBufferWaitAll) { SyscallSucceedsWithValue(sizeof(sent_data))); } +// This test tests reading from a socket with MSG_TRUNC | MSG_PEEK and a zero +// length receive buffer and MSG_DONTWAIT. The recvmsg call should block on +// reading the data. +TEST_P(BlockingNonStreamSocketPairTest, + RecvmsgTruncPeekDontwaitZeroLenBlocking) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // NOTE: We don't initially send any data on the socket. + const int data_size = 10; + char sent_data[data_size]; + RandomizeBuffer(sent_data, data_size); + + // The receive buffer is of zero length. + char peek_data[0] = {}; + + struct iovec peek_iov; + peek_iov.iov_base = peek_data; + peek_iov.iov_len = sizeof(peek_data); + struct msghdr peek_msg = {}; + peek_msg.msg_flags = -1; + peek_msg.msg_iov = &peek_iov; + peek_msg.msg_iovlen = 1; + + ScopedThread t([&]() { + // The syscall succeeds returning the full size of the message on the + // socket. This should block until there is data on the socket. + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &peek_msg, + MSG_TRUNC | MSG_PEEK), + SyscallSucceedsWithValue(data_size)); + }); + + absl::SleepFor(absl::Seconds(1)); + ASSERT_THAT(RetryEINTR(send)(sockets->first_fd(), sent_data, data_size, 0), + SyscallSucceedsWithValue(data_size)); +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_stream.cc b/test/syscalls/linux/socket_stream.cc index 346443f96..6522b2e01 100644 --- a/test/syscalls/linux/socket_stream.cc +++ b/test/syscalls/linux/socket_stream.cc @@ -104,7 +104,60 @@ TEST_P(StreamSocketPairTest, RecvmsgMsghdrFlagsNoMsgTrunc) { EXPECT_EQ(0, memcmp(received_data, sent_data, sizeof(received_data))); // Check that msghdr flags were cleared (MSG_TRUNC was not set). - EXPECT_EQ(msg.msg_flags, 0); + ASSERT_EQ(msg.msg_flags & MSG_TRUNC, 0); +} + +TEST_P(StreamSocketPairTest, RecvmsgTruncZeroLen) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[10]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + ASSERT_THAT( + RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0), + SyscallSucceedsWithValue(sizeof(sent_data))); + + char received_data[0] = {}; + + struct iovec iov; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + struct msghdr msg = {}; + msg.msg_flags = -1; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_TRUNC), + SyscallSucceedsWithValue(0)); + + // Check that msghdr flags were cleared (MSG_TRUNC was not set). + ASSERT_EQ(msg.msg_flags & MSG_TRUNC, 0); +} + +TEST_P(StreamSocketPairTest, RecvmsgTruncPeekZeroLen) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[10]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + ASSERT_THAT( + RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0), + SyscallSucceedsWithValue(sizeof(sent_data))); + + char received_data[0] = {}; + + struct iovec iov; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + struct msghdr msg = {}; + msg.msg_flags = -1; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT( + RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_TRUNC | MSG_PEEK), + SyscallSucceedsWithValue(0)); + + // Check that msghdr flags were cleared (MSG_TRUNC was not set). + ASSERT_EQ(msg.msg_flags & MSG_TRUNC, 0); } TEST_P(StreamSocketPairTest, MsgTrunc) { diff --git a/test/syscalls/linux/socket_stream_blocking.cc b/test/syscalls/linux/socket_stream_blocking.cc index e9cc082bf..538ee2268 100644 --- a/test/syscalls/linux/socket_stream_blocking.cc +++ b/test/syscalls/linux/socket_stream_blocking.cc @@ -32,38 +32,38 @@ namespace gvisor { namespace testing { TEST_P(BlockingStreamSocketPairTest, BlockPartialWriteClosed) { - // FIXME(b/35921550): gVisor doesn't support SO_SNDBUF on UDS, nor does it - // enforce any limit; it will write arbitrary amounts of data without - // blocking. - SKIP_IF(IsRunningOnGvisor()); - - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - int buffer_size; - socklen_t length = sizeof(buffer_size); - ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDBUF, - &buffer_size, &length), - SyscallSucceeds()); - - int wfd = sockets->first_fd(); - ScopedThread t([wfd, buffer_size]() { - std::vector<char> buf(2 * buffer_size); - // Write more than fits in the buffer. Blocks then returns partial write - // when the other end is closed. The next call returns EPIPE. - // - // N.B. writes occur in chunks, so we may see less than buffer_size from - // the first call. - ASSERT_THAT(write(wfd, buf.data(), buf.size()), - SyscallSucceedsWithValue(::testing::Gt(0))); - ASSERT_THAT(write(wfd, buf.data(), buf.size()), - ::testing::AnyOf(SyscallFailsWithErrno(EPIPE), - SyscallFailsWithErrno(ECONNRESET))); - }); - - // Leave time for write to become blocked. - absl::SleepFor(absl::Seconds(1)); - - ASSERT_THAT(close(sockets->release_second_fd()), SyscallSucceeds()); + // FIXME(b/35921550): gVisor doesn't support SO_SNDBUF on UDS, nor does it + // enforce any limit; it will write arbitrary amounts of data without + // blocking. + SKIP_IF(IsRunningOnGvisor()); + + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int buffer_size; + socklen_t length = sizeof(buffer_size); + ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDBUF, + &buffer_size, &length), + SyscallSucceeds()); + + int wfd = sockets->first_fd(); + ScopedThread t([wfd, buffer_size]() { + std::vector<char> buf(2 * buffer_size); + // Write more than fits in the buffer. Blocks then returns partial write + // when the other end is closed. The next call returns EPIPE. + // + // N.B. writes occur in chunks, so we may see less than buffer_size from + // the first call. + ASSERT_THAT(write(wfd, buf.data(), buf.size()), + SyscallSucceedsWithValue(::testing::Gt(0))); + ASSERT_THAT(write(wfd, buf.data(), buf.size()), + ::testing::AnyOf(SyscallFailsWithErrno(EPIPE), + SyscallFailsWithErrno(ECONNRESET))); + }); + + // Leave time for write to become blocked. + absl::SleepFor(absl::Seconds(1)); + + ASSERT_THAT(close(sockets->release_second_fd()), SyscallSucceeds()); } // Random save may interrupt the call to sendmsg() in SendLargeSendMsg(), diff --git a/test/syscalls/linux/socket_test_util.cc b/test/syscalls/linux/socket_test_util.cc index eff7d577e..53b678e94 100644 --- a/test/syscalls/linux/socket_test_util.cc +++ b/test/syscalls/linux/socket_test_util.cc @@ -18,10 +18,13 @@ #include <poll.h> #include <sys/socket.h> +#include <memory> + #include "gtest/gtest.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/time/clock.h" +#include "absl/types/optional.h" #include "test/util/file_descriptor.h" #include "test/util/posix_error.h" #include "test/util/temp_path.h" @@ -109,7 +112,10 @@ Creator<SocketPair> AcceptBindSocketPairCreator(bool abstract, int domain, MaybeSave(); // Unlinked path. } - return absl::make_unique<AddrFDSocketPair>(connected, accepted, bind_addr, + // accepted is before connected to destruct connected before accepted. + // Destructors for nonstatic member objects are called in the reverse order + // in which they appear in the class declaration. + return absl::make_unique<AddrFDSocketPair>(accepted, connected, bind_addr, extra_addr); }; } @@ -311,11 +317,16 @@ PosixErrorOr<T> BindIP(int fd, bool dual_stack) { } template <typename T> -PosixErrorOr<std::unique_ptr<AddrFDSocketPair>> CreateTCPAcceptBindSocketPair( - int bound, int connected, int type, bool dual_stack) { - ASSIGN_OR_RETURN_ERRNO(T bind_addr, BindIP<T>(bound, dual_stack)); - RETURN_ERROR_IF_SYSCALL_FAIL(listen(bound, /* backlog = */ 5)); +PosixErrorOr<T> TCPBindAndListen(int fd, bool dual_stack) { + ASSIGN_OR_RETURN_ERRNO(T addr, BindIP<T>(fd, dual_stack)); + RETURN_ERROR_IF_SYSCALL_FAIL(listen(fd, /* backlog = */ 5)); + return addr; +} +template <typename T> +PosixErrorOr<std::unique_ptr<AddrFDSocketPair>> +CreateTCPConnectAcceptSocketPair(int bound, int connected, int type, + bool dual_stack, T bind_addr) { int connect_result = 0; RETURN_ERROR_IF_SYSCALL_FAIL( (connect_result = RetryEINTR(connect)( @@ -353,19 +364,25 @@ PosixErrorOr<std::unique_ptr<AddrFDSocketPair>> CreateTCPAcceptBindSocketPair( } MaybeSave(); // Successful accept. - // FIXME(b/110484944) - if (connect_result == -1) { - absl::SleepFor(absl::Seconds(1)); - } + T extra_addr = {}; + LocalhostAddr(&extra_addr, dual_stack); + return absl::make_unique<AddrFDSocketPair>(connected, accepted, bind_addr, + extra_addr); +} + +template <typename T> +PosixErrorOr<std::unique_ptr<AddrFDSocketPair>> CreateTCPAcceptBindSocketPair( + int bound, int connected, int type, bool dual_stack) { + ASSIGN_OR_RETURN_ERRNO(T bind_addr, TCPBindAndListen<T>(bound, dual_stack)); + + auto result = CreateTCPConnectAcceptSocketPair(bound, connected, type, + dual_stack, bind_addr); // Cleanup no longer needed resources. RETURN_ERROR_IF_SYSCALL_FAIL(close(bound)); MaybeSave(); // Successful close. - T extra_addr = {}; - LocalhostAddr(&extra_addr, dual_stack); - return absl::make_unique<AddrFDSocketPair>(connected, accepted, bind_addr, - extra_addr); + return result; } Creator<SocketPair> TCPAcceptBindSocketPairCreator(int domain, int type, @@ -389,6 +406,63 @@ Creator<SocketPair> TCPAcceptBindSocketPairCreator(int domain, int type, }; } +Creator<SocketPair> TCPAcceptBindPersistentListenerSocketPairCreator( + int domain, int type, int protocol, bool dual_stack) { + // These are lazily initialized below, on the first call to the returned + // lambda. These values are private to each returned lambda, but shared across + // invocations of a specific lambda. + // + // The sharing allows pairs created with the same parameters to share a + // listener. This prevents future connects from failing if the connecting + // socket selects a port which had previously been used by a listening socket + // that still has some connections in TIME-WAIT. + // + // The lazy initialization is to avoid creating sockets during parameter + // enumeration. This is important because parameters are enumerated during the + // build process where networking may not be available. + auto listener = std::make_shared<absl::optional<int>>(absl::optional<int>()); + auto addr4 = std::make_shared<absl::optional<sockaddr_in>>( + absl::optional<sockaddr_in>()); + auto addr6 = std::make_shared<absl::optional<sockaddr_in6>>( + absl::optional<sockaddr_in6>()); + + return [=]() -> PosixErrorOr<std::unique_ptr<AddrFDSocketPair>> { + int connected; + RETURN_ERROR_IF_SYSCALL_FAIL(connected = socket(domain, type, protocol)); + MaybeSave(); // Successful socket creation. + + // Share the listener across invocations. + if (!listener->has_value()) { + int fd = socket(domain, type, protocol); + if (fd < 0) { + return PosixError(errno, absl::StrCat("socket(", domain, ", ", type, + ", ", protocol, ")")); + } + listener->emplace(fd); + MaybeSave(); // Successful socket creation. + } + + // Bind the listener once, but create a new connect/accept pair each + // time. + if (domain == AF_INET) { + if (!addr4->has_value()) { + addr4->emplace( + TCPBindAndListen<sockaddr_in>(listener->value(), dual_stack) + .ValueOrDie()); + } + return CreateTCPConnectAcceptSocketPair(listener->value(), connected, + type, dual_stack, addr4->value()); + } + if (!addr6->has_value()) { + addr6->emplace( + TCPBindAndListen<sockaddr_in6>(listener->value(), dual_stack) + .ValueOrDie()); + } + return CreateTCPConnectAcceptSocketPair(listener->value(), connected, type, + dual_stack, addr6->value()); + }; +} + template <typename T> PosixErrorOr<std::unique_ptr<AddrFDSocketPair>> CreateUDPBoundSocketPair( int sock1, int sock2, int type, bool dual_stack) { @@ -518,8 +592,8 @@ size_t CalculateUnixSockAddrLen(const char* sun_path) { if (sun_path[0] == 0) { return sizeof(sockaddr_un); } - // Filesystem addresses use the address length plus the 2 byte sun_family and - // null terminator. + // Filesystem addresses use the address length plus the 2 byte sun_family + // and null terminator. return strlen(sun_path) + 3; } @@ -726,6 +800,24 @@ TestAddress V4MappedLoopback() { return t; } +TestAddress V4Multicast() { + TestAddress t("V4Multicast"); + t.addr.ss_family = AF_INET; + t.addr_len = sizeof(sockaddr_in); + reinterpret_cast<sockaddr_in*>(&t.addr)->sin_addr.s_addr = + inet_addr(kMulticastAddress); + return t; +} + +TestAddress V4Broadcast() { + TestAddress t("V4Broadcast"); + t.addr.ss_family = AF_INET; + t.addr_len = sizeof(sockaddr_in); + reinterpret_cast<sockaddr_in*>(&t.addr)->sin_addr.s_addr = + htonl(INADDR_BROADCAST); + return t; +} + TestAddress V6Any() { TestAddress t("V6Any"); t.addr.ss_family = AF_INET6; diff --git a/test/syscalls/linux/socket_test_util.h b/test/syscalls/linux/socket_test_util.h index be38907c2..734b48b96 100644 --- a/test/syscalls/linux/socket_test_util.h +++ b/test/syscalls/linux/socket_test_util.h @@ -114,6 +114,9 @@ class FDSocketPair : public SocketPair { public: FDSocketPair(int first_fd, int second_fd) : first_(first_fd), second_(second_fd) {} + FDSocketPair(std::unique_ptr<FileDescriptor> first_fd, + std::unique_ptr<FileDescriptor> second_fd) + : first_(first_fd->release()), second_(second_fd->release()) {} int first_fd() const override { return first_.get(); } int second_fd() const override { return second_.get(); } @@ -270,6 +273,12 @@ Creator<SocketPair> TCPAcceptBindSocketPairCreator(int domain, int type, int protocol, bool dual_stack); +// TCPAcceptBindPersistentListenerSocketPairCreator is like +// TCPAcceptBindSocketPairCreator, except it uses the same listening socket to +// create all SocketPairs. +Creator<SocketPair> TCPAcceptBindPersistentListenerSocketPairCreator( + int domain, int type, int protocol, bool dual_stack); + // UDPBidirectionalBindSocketPairCreator returns a Creator<SocketPair> that // obtains file descriptors by invoking the bind() and connect() syscalls on UDP // sockets. @@ -475,10 +484,15 @@ struct TestAddress { : description(std::move(description)), addr(), addr_len() {} }; +constexpr char kMulticastAddress[] = "224.0.2.1"; +constexpr char kBroadcastAddress[] = "255.255.255.255"; + TestAddress V4Any(); +TestAddress V4Broadcast(); TestAddress V4Loopback(); TestAddress V4MappedAny(); TestAddress V4MappedLoopback(); +TestAddress V4Multicast(); TestAddress V6Any(); TestAddress V6Loopback(); diff --git a/test/syscalls/linux/socket_unix.cc b/test/syscalls/linux/socket_unix.cc index 8a28202a8..591cab3fd 100644 --- a/test/syscalls/linux/socket_unix.cc +++ b/test/syscalls/linux/socket_unix.cc @@ -65,6 +65,21 @@ TEST_P(UnixSocketPairTest, BindToBadName) { SyscallFailsWithErrno(ENOENT)); } +TEST_P(UnixSocketPairTest, BindToBadFamily) { + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + constexpr char kBadName[] = "/some/path/that/does/not/exist"; + sockaddr_un sockaddr; + sockaddr.sun_family = AF_INET; + memcpy(sockaddr.sun_path, kBadName, sizeof(kBadName)); + + EXPECT_THAT( + bind(pair->first_fd(), reinterpret_cast<struct sockaddr*>(&sockaddr), + sizeof(sockaddr)), + SyscallFailsWithErrno(EINVAL)); +} + TEST_P(UnixSocketPairTest, RecvmmsgTimeoutAfterRecv) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); char sent_data[10]; @@ -241,8 +256,9 @@ TEST_P(UnixSocketPairTest, ShutdownWrite) { } TEST_P(UnixSocketPairTest, SocketReopenFromProcfs) { - // TODO(b/122310852): We should be returning ENXIO and NOT EIO. - SKIP_IF(IsRunningOnGvisor()); + // TODO(gvisor.dev/issue/1624): In VFS1, we return EIO instead of ENXIO (see + // b/122310852). Remove this skip once VFS1 is deleted. + SKIP_IF(IsRunningWithVFS1()); auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); // Opening a socket pair via /proc/self/fd/X is a ENXIO. diff --git a/test/syscalls/linux/socket_unix_abstract_nonblock.cc b/test/syscalls/linux/socket_unix_abstract_nonblock.cc index be31ab2a7..8bef76b67 100644 --- a/test/syscalls/linux/socket_unix_abstract_nonblock.cc +++ b/test/syscalls/linux/socket_unix_abstract_nonblock.cc @@ -21,6 +21,7 @@ namespace gvisor { namespace testing { +namespace { std::vector<SocketPairKind> GetSocketPairs() { return ApplyVec<SocketPairKind>( @@ -33,5 +34,6 @@ INSTANTIATE_TEST_SUITE_P( NonBlockingAbstractUnixSockets, NonBlockingSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); +} // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_blocking_local.cc b/test/syscalls/linux/socket_unix_blocking_local.cc index 1994139e6..77cb8c6d6 100644 --- a/test/syscalls/linux/socket_unix_blocking_local.cc +++ b/test/syscalls/linux/socket_unix_blocking_local.cc @@ -12,16 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "test/syscalls/linux/socket_blocking.h" - #include <vector> +#include "test/syscalls/linux/socket_blocking.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" namespace gvisor { namespace testing { +namespace { std::vector<SocketPairKind> GetSocketPairs() { return VecCat<SocketPairKind>( @@ -40,5 +40,6 @@ INSTANTIATE_TEST_SUITE_P( NonBlockingUnixDomainSockets, BlockingSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); +} // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_cmsg.cc b/test/syscalls/linux/socket_unix_cmsg.cc index 1159c5229..a16899493 100644 --- a/test/syscalls/linux/socket_unix_cmsg.cc +++ b/test/syscalls/linux/socket_unix_cmsg.cc @@ -149,6 +149,35 @@ TEST_P(UnixSocketPairCmsgTest, BadFDPass) { SyscallFailsWithErrno(EBADF)); } +TEST_P(UnixSocketPairCmsgTest, ShortCmsg) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + int sent_fd = -1; + + struct msghdr msg = {}; + char control[CMSG_SPACE(sizeof(sent_fd))]; + msg.msg_control = control; + msg.msg_controllen = sizeof(control); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + cmsg->cmsg_len = 1; + cmsg->cmsg_level = SOL_SOCKET; + cmsg->cmsg_type = SCM_RIGHTS; + memcpy(CMSG_DATA(cmsg), &sent_fd, sizeof(sent_fd)); + + struct iovec iov; + iov.iov_base = sent_data; + iov.iov_len = sizeof(sent_data); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(sendmsg)(sockets->first_fd(), &msg, 0), + SyscallFailsWithErrno(EINVAL)); +} + // BasicFDPassNoSpace starts off by sending a single FD just like BasicFDPass. // The difference is that when calling recvmsg, no space for FDs is provided, // only space for the cmsg header. diff --git a/test/syscalls/linux/socket_unix_dgram.cc b/test/syscalls/linux/socket_unix_dgram.cc index 3245cf7c9..af0df4fb4 100644 --- a/test/syscalls/linux/socket_unix_dgram.cc +++ b/test/syscalls/linux/socket_unix_dgram.cc @@ -16,6 +16,7 @@ #include <stdio.h> #include <sys/un.h> + #include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" diff --git a/test/syscalls/linux/socket_unix_dgram_local.cc b/test/syscalls/linux/socket_unix_dgram_local.cc index 9134fcdf7..31d2d5216 100644 --- a/test/syscalls/linux/socket_unix_dgram_local.cc +++ b/test/syscalls/linux/socket_unix_dgram_local.cc @@ -23,6 +23,7 @@ namespace gvisor { namespace testing { +namespace { std::vector<SocketPairKind> GetSocketPairs() { return VecCat<SocketPairKind>(VecCat<SocketPairKind>( @@ -52,5 +53,6 @@ INSTANTIATE_TEST_SUITE_P( DgramUnixSockets, NonStreamSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); +} // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_dgram_non_blocking.cc b/test/syscalls/linux/socket_unix_dgram_non_blocking.cc index cd4fba25c..2db8b68d3 100644 --- a/test/syscalls/linux/socket_unix_dgram_non_blocking.cc +++ b/test/syscalls/linux/socket_unix_dgram_non_blocking.cc @@ -14,6 +14,7 @@ #include <stdio.h> #include <sys/un.h> + #include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" diff --git a/test/syscalls/linux/socket_unix_domain.cc b/test/syscalls/linux/socket_unix_domain.cc index fa3efc7f8..f7dff8b4d 100644 --- a/test/syscalls/linux/socket_unix_domain.cc +++ b/test/syscalls/linux/socket_unix_domain.cc @@ -21,6 +21,7 @@ namespace gvisor { namespace testing { +namespace { std::vector<SocketPairKind> GetSocketPairs() { return ApplyVec<SocketPairKind>( @@ -33,5 +34,6 @@ INSTANTIATE_TEST_SUITE_P( AllUnixDomainSockets, AllSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); +} // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_filesystem_nonblock.cc b/test/syscalls/linux/socket_unix_filesystem_nonblock.cc index 8ba7af971..6700b4d90 100644 --- a/test/syscalls/linux/socket_unix_filesystem_nonblock.cc +++ b/test/syscalls/linux/socket_unix_filesystem_nonblock.cc @@ -21,6 +21,7 @@ namespace gvisor { namespace testing { +namespace { std::vector<SocketPairKind> GetSocketPairs() { return ApplyVec<SocketPairKind>( @@ -33,5 +34,6 @@ INSTANTIATE_TEST_SUITE_P( NonBlockingFilesystemUnixSockets, NonBlockingSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); +} // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_non_stream.cc b/test/syscalls/linux/socket_unix_non_stream.cc index 276a94eb8..884319e1d 100644 --- a/test/syscalls/linux/socket_unix_non_stream.cc +++ b/test/syscalls/linux/socket_unix_non_stream.cc @@ -109,7 +109,7 @@ PosixErrorOr<std::vector<Mapping>> CreateFragmentedRegion(const int size, } // A contiguous iov that is heavily fragmented in FileMem can still be sent -// successfully. +// successfully. See b/115833655. TEST_P(UnixNonStreamSocketPairTest, FragmentedSendMsg) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); @@ -165,7 +165,7 @@ TEST_P(UnixNonStreamSocketPairTest, FragmentedSendMsg) { } // A contiguous iov that is heavily fragmented in FileMem can still be received -// into successfully. +// into successfully. Regression test for b/115833655. TEST_P(UnixNonStreamSocketPairTest, FragmentedRecvMsg) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); diff --git a/test/syscalls/linux/socket_unix_non_stream_blocking_local.cc b/test/syscalls/linux/socket_unix_non_stream_blocking_local.cc index da762cd83..fddcdf1c5 100644 --- a/test/syscalls/linux/socket_unix_non_stream_blocking_local.cc +++ b/test/syscalls/linux/socket_unix_non_stream_blocking_local.cc @@ -12,16 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "test/syscalls/linux/socket_non_stream_blocking.h" - #include <vector> +#include "test/syscalls/linux/socket_non_stream_blocking.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" namespace gvisor { namespace testing { +namespace { std::vector<SocketPairKind> GetSocketPairs() { return VecCat<SocketPairKind>( @@ -37,5 +37,6 @@ INSTANTIATE_TEST_SUITE_P( BlockingNonStreamUnixSockets, BlockingNonStreamSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); +} // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_pair.cc b/test/syscalls/linux/socket_unix_pair.cc index 411fb4518..85999db04 100644 --- a/test/syscalls/linux/socket_unix_pair.cc +++ b/test/syscalls/linux/socket_unix_pair.cc @@ -22,6 +22,7 @@ namespace gvisor { namespace testing { +namespace { std::vector<SocketPairKind> GetSocketPairs() { return VecCat<SocketPairKind>(ApplyVec<SocketPairKind>( @@ -38,5 +39,6 @@ INSTANTIATE_TEST_SUITE_P( AllUnixDomainSockets, UnixSocketPairCmsgTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); +} // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_pair_nonblock.cc b/test/syscalls/linux/socket_unix_pair_nonblock.cc index 3135d325f..281410a9a 100644 --- a/test/syscalls/linux/socket_unix_pair_nonblock.cc +++ b/test/syscalls/linux/socket_unix_pair_nonblock.cc @@ -21,6 +21,7 @@ namespace gvisor { namespace testing { +namespace { std::vector<SocketPairKind> GetSocketPairs() { return ApplyVec<SocketPairKind>( @@ -33,5 +34,6 @@ INSTANTIATE_TEST_SUITE_P( NonBlockingUnixSockets, NonBlockingSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); +} // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_seqpacket.cc b/test/syscalls/linux/socket_unix_seqpacket.cc index 60fa9e38a..6d03df4d9 100644 --- a/test/syscalls/linux/socket_unix_seqpacket.cc +++ b/test/syscalls/linux/socket_unix_seqpacket.cc @@ -16,6 +16,7 @@ #include <stdio.h> #include <sys/un.h> + #include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" @@ -42,6 +43,24 @@ TEST_P(SeqpacketUnixSocketPairTest, ReadOneSideClosed) { SyscallSucceedsWithValue(0)); } +TEST_P(SeqpacketUnixSocketPairTest, Sendto) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + struct sockaddr_un addr = {}; + addr.sun_family = AF_UNIX; + constexpr char kPath[] = "\0nonexistent"; + memcpy(addr.sun_path, kPath, sizeof(kPath)); + + constexpr char kStr[] = "abc"; + ASSERT_THAT(sendto(sockets->second_fd(), kStr, 3, 0, (struct sockaddr*)&addr, + sizeof(addr)), + SyscallSucceedsWithValue(3)); + + char data[10] = {}; + ASSERT_THAT(read(sockets->first_fd(), data, sizeof(data)), + SyscallSucceedsWithValue(3)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/socket_unix_seqpacket_local.cc b/test/syscalls/linux/socket_unix_seqpacket_local.cc index dff75a532..69a5f150d 100644 --- a/test/syscalls/linux/socket_unix_seqpacket_local.cc +++ b/test/syscalls/linux/socket_unix_seqpacket_local.cc @@ -23,6 +23,7 @@ namespace gvisor { namespace testing { +namespace { std::vector<SocketPairKind> GetSocketPairs() { return VecCat<SocketPairKind>(VecCat<SocketPairKind>( @@ -52,5 +53,6 @@ INSTANTIATE_TEST_SUITE_P( SeqpacketUnixSockets, UnixNonStreamSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); +} // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_stream.cc b/test/syscalls/linux/socket_unix_stream.cc index 563467365..99e77b89e 100644 --- a/test/syscalls/linux/socket_unix_stream.cc +++ b/test/syscalls/linux/socket_unix_stream.cc @@ -89,6 +89,20 @@ TEST_P(StreamUnixSocketPairTest, ReadOneSideClosedWithUnreadData) { SyscallFailsWithErrno(ECONNRESET)); } +TEST_P(StreamUnixSocketPairTest, Sendto) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + struct sockaddr_un addr = {}; + addr.sun_family = AF_UNIX; + constexpr char kPath[] = "\0nonexistent"; + memcpy(addr.sun_path, kPath, sizeof(kPath)); + + constexpr char kStr[] = "abc"; + ASSERT_THAT(sendto(sockets->second_fd(), kStr, 3, 0, (struct sockaddr*)&addr, + sizeof(addr)), + SyscallFailsWithErrno(EISCONN)); +} + INSTANTIATE_TEST_SUITE_P( AllUnixDomainSockets, StreamUnixSocketPairTest, ::testing::ValuesIn(IncludeReversals(VecCat<SocketPairKind>( diff --git a/test/syscalls/linux/socket_unix_stream_blocking_local.cc b/test/syscalls/linux/socket_unix_stream_blocking_local.cc index fa0a9d367..8429bd429 100644 --- a/test/syscalls/linux/socket_unix_stream_blocking_local.cc +++ b/test/syscalls/linux/socket_unix_stream_blocking_local.cc @@ -12,16 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "test/syscalls/linux/socket_stream_blocking.h" - #include <vector> +#include "test/syscalls/linux/socket_stream_blocking.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" namespace gvisor { namespace testing { +namespace { std::vector<SocketPairKind> GetSocketPairs() { return { @@ -35,5 +35,6 @@ INSTANTIATE_TEST_SUITE_P( BlockingStreamUnixSockets, BlockingStreamSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); +} // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_stream_local.cc b/test/syscalls/linux/socket_unix_stream_local.cc index 65eef1a81..a7e3449a9 100644 --- a/test/syscalls/linux/socket_unix_stream_local.cc +++ b/test/syscalls/linux/socket_unix_stream_local.cc @@ -21,6 +21,7 @@ namespace gvisor { namespace testing { +namespace { std::vector<SocketPairKind> GetSocketPairs() { return VecCat<SocketPairKind>( @@ -42,5 +43,6 @@ INSTANTIATE_TEST_SUITE_P( StreamUnixSockets, StreamSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); +} // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_stream_nonblock_local.cc b/test/syscalls/linux/socket_unix_stream_nonblock_local.cc index ec777c59f..4b763c8e2 100644 --- a/test/syscalls/linux/socket_unix_stream_nonblock_local.cc +++ b/test/syscalls/linux/socket_unix_stream_nonblock_local.cc @@ -11,16 +11,16 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -#include "test/syscalls/linux/socket_stream_nonblock.h" - #include <vector> +#include "test/syscalls/linux/socket_stream_nonblock.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" namespace gvisor { namespace testing { +namespace { std::vector<SocketPairKind> GetSocketPairs() { return { @@ -34,5 +34,6 @@ INSTANTIATE_TEST_SUITE_P( NonBlockingStreamUnixSockets, NonBlockingStreamSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); +} // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_unbound_abstract.cc b/test/syscalls/linux/socket_unix_unbound_abstract.cc index 7f5816ace..8b1762000 100644 --- a/test/syscalls/linux/socket_unix_unbound_abstract.cc +++ b/test/syscalls/linux/socket_unix_unbound_abstract.cc @@ -14,6 +14,7 @@ #include <stdio.h> #include <sys/un.h> + #include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" diff --git a/test/syscalls/linux/socket_unix_unbound_filesystem.cc b/test/syscalls/linux/socket_unix_unbound_filesystem.cc index b14f24086..cab912152 100644 --- a/test/syscalls/linux/socket_unix_unbound_filesystem.cc +++ b/test/syscalls/linux/socket_unix_unbound_filesystem.cc @@ -14,6 +14,7 @@ #include <stdio.h> #include <sys/un.h> + #include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" diff --git a/test/syscalls/linux/socket_unix_unbound_seqpacket.cc b/test/syscalls/linux/socket_unix_unbound_seqpacket.cc index 50ffa1d04..cb99030f5 100644 --- a/test/syscalls/linux/socket_unix_unbound_seqpacket.cc +++ b/test/syscalls/linux/socket_unix_unbound_seqpacket.cc @@ -14,6 +14,7 @@ #include <stdio.h> #include <sys/un.h> + #include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" diff --git a/test/syscalls/linux/socket_unix_unbound_stream.cc b/test/syscalls/linux/socket_unix_unbound_stream.cc index 344918c34..f185dded3 100644 --- a/test/syscalls/linux/socket_unix_unbound_stream.cc +++ b/test/syscalls/linux/socket_unix_unbound_stream.cc @@ -14,6 +14,7 @@ #include <stdio.h> #include <sys/un.h> + #include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" diff --git a/test/syscalls/linux/splice.cc b/test/syscalls/linux/splice.cc index 85232cb1f..08fc4b1b7 100644 --- a/test/syscalls/linux/splice.cc +++ b/test/syscalls/linux/splice.cc @@ -13,6 +13,7 @@ // limitations under the License. #include <fcntl.h> +#include <linux/unistd.h> #include <sys/eventfd.h> #include <sys/resource.h> #include <sys/sendfile.h> @@ -60,6 +61,62 @@ TEST(SpliceTest, TwoRegularFiles) { SyscallFailsWithErrno(EINVAL)); } +int memfd_create(const std::string& name, unsigned int flags) { + return syscall(__NR_memfd_create, name.c_str(), flags); +} + +TEST(SpliceTest, NegativeOffset) { + // Create a new pipe. + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + const FileDescriptor rfd(fds[0]); + const FileDescriptor wfd(fds[1]); + + // Fill the pipe. + std::vector<char> buf(kPageSize); + RandomizeBuffer(buf.data(), buf.size()); + ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()), + SyscallSucceedsWithValue(kPageSize)); + + // Open the output file as write only. + int fd; + EXPECT_THAT(fd = memfd_create("negative", 0), SyscallSucceeds()); + const FileDescriptor out_fd(fd); + + loff_t out_offset = 0xffffffffffffffffull; + constexpr int kSize = 2; + EXPECT_THAT(splice(rfd.get(), nullptr, out_fd.get(), &out_offset, kSize, 0), + SyscallFailsWithErrno(EINVAL)); +} + +// Write offset + size overflows int64. +// +// This is a regression test for b/148041624. +TEST(SpliceTest, WriteOverflow) { + // Create a new pipe. + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + const FileDescriptor rfd(fds[0]); + const FileDescriptor wfd(fds[1]); + + // Fill the pipe. + std::vector<char> buf(kPageSize); + RandomizeBuffer(buf.data(), buf.size()); + ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()), + SyscallSucceedsWithValue(kPageSize)); + + // Open the output file. + int fd; + EXPECT_THAT(fd = memfd_create("overflow", 0), SyscallSucceeds()); + const FileDescriptor out_fd(fd); + + // out_offset + kSize overflows INT64_MAX. + loff_t out_offset = 0x7ffffffffffffffeull; + constexpr int kSize = 3; + EXPECT_THAT(splice(rfd.get(), nullptr, out_fd.get(), &out_offset, kSize, 0), + SyscallFailsWithErrno(EINVAL)); +} + TEST(SpliceTest, SamePipe) { // Create a new pipe. int fds[2]; @@ -373,6 +430,55 @@ TEST(SpliceTest, TwoPipes) { EXPECT_EQ(memcmp(rbuf.data(), buf.data(), kPageSize), 0); } +TEST(SpliceTest, TwoPipesCircular) { + // This test deadlocks the sentry on VFS1 because VFS1 splice ordering is + // based on fs.File.UniqueID, which does not prevent circular ordering between + // e.g. inode-level locks taken by fs.FileOperations. + SKIP_IF(IsRunningWithVFS1()); + + // Create two pipes. + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + const FileDescriptor first_rfd(fds[0]); + const FileDescriptor first_wfd(fds[1]); + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + const FileDescriptor second_rfd(fds[0]); + const FileDescriptor second_wfd(fds[1]); + + // On Linux, each pipe is normally limited to + // include/linux/pipe_fs_i.h:PIPE_DEF_BUFFERS buffers worth of data. + constexpr size_t PIPE_DEF_BUFFERS = 16; + + // Write some data to each pipe. Below we splice 1 byte at a time between + // pipes, which very quickly causes each byte to be stored in a separate + // buffer, so we must ensure that the total amount of data in the system is <= + // PIPE_DEF_BUFFERS bytes. + std::vector<char> buf(PIPE_DEF_BUFFERS / 2); + RandomizeBuffer(buf.data(), buf.size()); + ASSERT_THAT(write(first_wfd.get(), buf.data(), buf.size()), + SyscallSucceedsWithValue(buf.size())); + ASSERT_THAT(write(second_wfd.get(), buf.data(), buf.size()), + SyscallSucceedsWithValue(buf.size())); + + // Have another thread splice from the second pipe to the first, while we + // splice from the first to the second. The test passes if this does not + // deadlock. + const int kIterations = 1000; + DisableSave ds; + ScopedThread t([&]() { + for (int i = 0; i < kIterations; i++) { + ASSERT_THAT( + splice(second_rfd.get(), nullptr, first_wfd.get(), nullptr, 1, 0), + SyscallSucceedsWithValue(1)); + } + }); + for (int i = 0; i < kIterations; i++) { + ASSERT_THAT( + splice(first_rfd.get(), nullptr, second_wfd.get(), nullptr, 1, 0), + SyscallSucceedsWithValue(1)); + } +} + TEST(SpliceTest, Blocking) { // Create two new pipes. int first[2], second[2]; diff --git a/test/syscalls/linux/stat.cc b/test/syscalls/linux/stat.cc index 30de2f8ff..2503960f3 100644 --- a/test/syscalls/linux/stat.cc +++ b/test/syscalls/linux/stat.cc @@ -34,6 +34,13 @@ #include "test/util/temp_path.h" #include "test/util/test_util.h" +#ifndef AT_STATX_FORCE_SYNC +#define AT_STATX_FORCE_SYNC 0x2000 +#endif +#ifndef AT_STATX_DONT_SYNC +#define AT_STATX_DONT_SYNC 0x4000 +#endif + namespace gvisor { namespace testing { @@ -557,6 +564,8 @@ TEST(SimpleStatTest, AnonDeviceAllocatesUniqueInodesAcrossSaveRestore) { #ifndef SYS_statx #if defined(__x86_64__) #define SYS_statx 332 +#elif defined(__aarch64__) +#define SYS_statx 291 #else #error "Unknown architecture" #endif @@ -599,13 +608,13 @@ struct kernel_statx { uint64_t __spare2[14]; }; -int statx(int dirfd, const char *pathname, int flags, unsigned int mask, - struct kernel_statx *statxbuf) { +int statx(int dirfd, const char* pathname, int flags, unsigned int mask, + struct kernel_statx* statxbuf) { return syscall(SYS_statx, dirfd, pathname, flags, mask, statxbuf); } TEST_F(StatTest, StatxAbsPath) { - SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, 0) < 0 && + SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 && errno == ENOSYS); struct kernel_statx stx; @@ -615,7 +624,7 @@ TEST_F(StatTest, StatxAbsPath) { } TEST_F(StatTest, StatxRelPathDirFD) { - SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, 0) < 0 && + SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 && errno == ENOSYS); struct kernel_statx stx; @@ -629,7 +638,7 @@ TEST_F(StatTest, StatxRelPathDirFD) { } TEST_F(StatTest, StatxRelPathCwd) { - SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, 0) < 0 && + SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 && errno == ENOSYS); ASSERT_THAT(chdir(GetAbsoluteTestTmpdir().c_str()), SyscallSucceeds()); @@ -641,7 +650,7 @@ TEST_F(StatTest, StatxRelPathCwd) { } TEST_F(StatTest, StatxEmptyPath) { - SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, 0) < 0 && + SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 && errno == ENOSYS); const auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDONLY)); @@ -651,6 +660,60 @@ TEST_F(StatTest, StatxEmptyPath) { EXPECT_TRUE(S_ISREG(stx.stx_mode)); } +TEST_F(StatTest, StatxDoesNotRejectExtraneousMaskBits) { + SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 && + errno == ENOSYS); + + struct kernel_statx stx; + // Set all mask bits except for STATX__RESERVED. + uint mask = 0xffffffff & ~0x80000000; + EXPECT_THAT(statx(-1, test_file_name_.c_str(), 0, mask, &stx), + SyscallSucceeds()); + EXPECT_TRUE(S_ISREG(stx.stx_mode)); +} + +TEST_F(StatTest, StatxRejectsReservedMaskBit) { + SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 && + errno == ENOSYS); + + struct kernel_statx stx; + // Set STATX__RESERVED in the mask. + EXPECT_THAT(statx(-1, test_file_name_.c_str(), 0, 0x80000000, &stx), + SyscallFailsWithErrno(EINVAL)); +} + +TEST_F(StatTest, StatxSymlink) { + SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 && + errno == ENOSYS); + + std::string parent_dir = "/tmp"; + TempPath link = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateSymlinkTo(parent_dir, test_file_name_)); + std::string p = link.path(); + + struct kernel_statx stx; + EXPECT_THAT(statx(AT_FDCWD, p.c_str(), AT_SYMLINK_NOFOLLOW, STATX_ALL, &stx), + SyscallSucceeds()); + EXPECT_TRUE(S_ISLNK(stx.stx_mode)); + EXPECT_THAT(statx(AT_FDCWD, p.c_str(), 0, STATX_ALL, &stx), + SyscallSucceeds()); + EXPECT_TRUE(S_ISREG(stx.stx_mode)); +} + +TEST_F(StatTest, StatxInvalidFlags) { + SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 && + errno == ENOSYS); + + struct kernel_statx stx; + EXPECT_THAT(statx(AT_FDCWD, test_file_name_.c_str(), 12345, 0, &stx), + SyscallFailsWithErrno(EINVAL)); + + // Sync flags are mutually exclusive. + EXPECT_THAT(statx(AT_FDCWD, test_file_name_.c_str(), + AT_STATX_FORCE_SYNC | AT_STATX_DONT_SYNC, 0, &stx), + SyscallFailsWithErrno(EINVAL)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/sticky.cc b/test/syscalls/linux/sticky.cc index 7e73325bf..4afed6d08 100644 --- a/test/syscalls/linux/sticky.cc +++ b/test/syscalls/linux/sticky.cc @@ -40,10 +40,17 @@ namespace { TEST(StickyTest, StickyBitPermDenied) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETUID))); - auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - EXPECT_THAT(chmod(dir.path().c_str(), 0777 | S_ISVTX), SyscallSucceeds()); - std::string path = JoinPath(dir.path(), "NewDir"); - ASSERT_THAT(mkdir(path.c_str(), 0755), SyscallSucceeds()); + const TempPath parent = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + EXPECT_THAT(chmod(parent.path().c_str(), 0777 | S_ISVTX), SyscallSucceeds()); + + // After changing credentials below, we need to use an open fd to make + // modifications in the parent dir, because there is no guarantee that we will + // still have the ability to open it. + const FileDescriptor parent_fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(parent.path(), O_DIRECTORY)); + ASSERT_THAT(openat(parent_fd.get(), "file", O_CREAT), SyscallSucceeds()); + ASSERT_THAT(mkdirat(parent_fd.get(), "dir", 0777), SyscallSucceeds()); + ASSERT_THAT(symlinkat("xyz", parent_fd.get(), "link"), SyscallSucceeds()); // Drop privileges and change IDs only in child thread, or else this parent // thread won't be able to open some log files after the test ends. @@ -61,17 +68,31 @@ TEST(StickyTest, StickyBitPermDenied) { syscall(SYS_setresuid, -1, absl::GetFlag(FLAGS_scratch_uid), -1), SyscallSucceeds()); - EXPECT_THAT(rmdir(path.c_str()), SyscallFailsWithErrno(EPERM)); + EXPECT_THAT(renameat(parent_fd.get(), "file", parent_fd.get(), "file2"), + SyscallFailsWithErrno(EPERM)); + EXPECT_THAT(unlinkat(parent_fd.get(), "file", 0), + SyscallFailsWithErrno(EPERM)); + EXPECT_THAT(unlinkat(parent_fd.get(), "dir", AT_REMOVEDIR), + SyscallFailsWithErrno(EPERM)); + EXPECT_THAT(unlinkat(parent_fd.get(), "link", 0), + SyscallFailsWithErrno(EPERM)); }); } TEST(StickyTest, StickyBitSameUID) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETUID))); - auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - EXPECT_THAT(chmod(dir.path().c_str(), 0777 | S_ISVTX), SyscallSucceeds()); - std::string path = JoinPath(dir.path(), "NewDir"); - ASSERT_THAT(mkdir(path.c_str(), 0755), SyscallSucceeds()); + const TempPath parent = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + EXPECT_THAT(chmod(parent.path().c_str(), 0777 | S_ISVTX), SyscallSucceeds()); + + // After changing credentials below, we need to use an open fd to make + // modifications in the parent dir, because there is no guarantee that we will + // still have the ability to open it. + const FileDescriptor parent_fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(parent.path(), O_DIRECTORY)); + ASSERT_THAT(openat(parent_fd.get(), "file", O_CREAT), SyscallSucceeds()); + ASSERT_THAT(mkdirat(parent_fd.get(), "dir", 0777), SyscallSucceeds()); + ASSERT_THAT(symlinkat("xyz", parent_fd.get(), "link"), SyscallSucceeds()); // Drop privileges and change IDs only in child thread, or else this parent // thread won't be able to open some log files after the test ends. @@ -87,17 +108,29 @@ TEST(StickyTest, StickyBitSameUID) { SyscallSucceeds()); // We still have the same EUID. - EXPECT_THAT(rmdir(path.c_str()), SyscallSucceeds()); + EXPECT_THAT(renameat(parent_fd.get(), "file", parent_fd.get(), "file2"), + SyscallSucceeds()); + EXPECT_THAT(unlinkat(parent_fd.get(), "file2", 0), SyscallSucceeds()); + EXPECT_THAT(unlinkat(parent_fd.get(), "dir", AT_REMOVEDIR), + SyscallSucceeds()); + EXPECT_THAT(unlinkat(parent_fd.get(), "link", 0), SyscallSucceeds()); }); } TEST(StickyTest, StickyBitCapFOWNER) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETUID))); - auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - EXPECT_THAT(chmod(dir.path().c_str(), 0777 | S_ISVTX), SyscallSucceeds()); - std::string path = JoinPath(dir.path(), "NewDir"); - ASSERT_THAT(mkdir(path.c_str(), 0755), SyscallSucceeds()); + const TempPath parent = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + EXPECT_THAT(chmod(parent.path().c_str(), 0777 | S_ISVTX), SyscallSucceeds()); + + // After changing credentials below, we need to use an open fd to make + // modifications in the parent dir, because there is no guarantee that we will + // still have the ability to open it. + const FileDescriptor parent_fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(parent.path(), O_DIRECTORY)); + ASSERT_THAT(openat(parent_fd.get(), "file", O_CREAT), SyscallSucceeds()); + ASSERT_THAT(mkdirat(parent_fd.get(), "dir", 0777), SyscallSucceeds()); + ASSERT_THAT(symlinkat("xyz", parent_fd.get(), "link"), SyscallSucceeds()); // Drop privileges and change IDs only in child thread, or else this parent // thread won't be able to open some log files after the test ends. @@ -114,7 +147,12 @@ TEST(StickyTest, StickyBitCapFOWNER) { SyscallSucceeds()); EXPECT_NO_ERRNO(SetCapability(CAP_FOWNER, true)); - EXPECT_THAT(rmdir(path.c_str()), SyscallSucceeds()); + EXPECT_THAT(renameat(parent_fd.get(), "file", parent_fd.get(), "file2"), + SyscallSucceeds()); + EXPECT_THAT(unlinkat(parent_fd.get(), "file2", 0), SyscallSucceeds()); + EXPECT_THAT(unlinkat(parent_fd.get(), "dir", AT_REMOVEDIR), + SyscallSucceeds()); + EXPECT_THAT(unlinkat(parent_fd.get(), "link", 0), SyscallSucceeds()); }); } } // namespace diff --git a/test/syscalls/linux/symlink.cc b/test/syscalls/linux/symlink.cc index b249ff91f..a17ff62e9 100644 --- a/test/syscalls/linux/symlink.cc +++ b/test/syscalls/linux/symlink.cc @@ -20,6 +20,7 @@ #include <string> #include "gtest/gtest.h" +#include "absl/time/clock.h" #include "test/util/capability_util.h" #include "test/util/file_descriptor.h" #include "test/util/fs_util.h" @@ -38,7 +39,7 @@ mode_t FilePermission(const std::string& path) { } // Test that name collisions are checked on the new link path, not the source -// path. +// path. Regression test for b/31782115. TEST(SymlinkTest, CanCreateSymlinkWithCachedSourceDirent) { const std::string srcname = NewTempAbsPath(); const std::string newname = NewTempAbsPath(); @@ -272,6 +273,30 @@ TEST(SymlinkTest, ChmodSymlink) { EXPECT_EQ(FilePermission(newpath), 0777); } +// Test that following a symlink updates the atime on the symlink. +TEST(SymlinkTest, FollowUpdatesATime) { + const auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const std::string link = NewTempAbsPath(); + EXPECT_THAT(symlink(file.path().c_str(), link.c_str()), SyscallSucceeds()); + + // Lstat the symlink. + struct stat st_before_follow; + ASSERT_THAT(lstat(link.c_str(), &st_before_follow), SyscallSucceeds()); + + // Let the clock advance. + absl::SleepFor(absl::Seconds(1)); + + // Open the file via the symlink. + int fd; + ASSERT_THAT(fd = open(link.c_str(), O_RDWR, 0666), SyscallSucceeds()); + FileDescriptor fd_closer(fd); + + // Lstat the symlink again, and check that atime is updated. + struct stat st_after_follow; + ASSERT_THAT(lstat(link.c_str(), &st_after_follow), SyscallSucceeds()); + EXPECT_LT(st_before_follow.st_atime, st_after_follow.st_atime); +} + class ParamSymlinkTest : public ::testing::TestWithParam<std::string> {}; // Test that creating an existing symlink with creat will create the target. diff --git a/test/syscalls/linux/sync.cc b/test/syscalls/linux/sync.cc index fe479390d..8aa2525a9 100644 --- a/test/syscalls/linux/sync.cc +++ b/test/syscalls/linux/sync.cc @@ -14,10 +14,9 @@ #include <fcntl.h> #include <stdio.h> -#include <unistd.h> - #include <sys/syscall.h> #include <unistd.h> + #include <string> #include "gtest/gtest.h" diff --git a/test/syscalls/linux/sysret.cc b/test/syscalls/linux/sysret.cc index 819fa655a..19ffbd85b 100644 --- a/test/syscalls/linux/sysret.cc +++ b/test/syscalls/linux/sysret.cc @@ -14,6 +14,8 @@ // Tests to verify that the behavior of linux and gvisor matches when // 'sysret' returns to bad (aka non-canonical) %rip or %rsp. + +#include <linux/elf.h> #include <sys/ptrace.h> #include <sys/user.h> @@ -32,6 +34,7 @@ constexpr uint64_t kNonCanonicalRsp = 0xFFFF000000000000; class SysretTest : public ::testing::Test { protected: struct user_regs_struct regs_; + struct iovec iov; pid_t child_; void SetUp() override { @@ -48,10 +51,15 @@ class SysretTest : public ::testing::Test { // Parent. int status; + memset(&iov, 0, sizeof(iov)); ASSERT_THAT(pid, SyscallSucceeds()); // Might still be < 0. ASSERT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid)); EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP); - ASSERT_THAT(ptrace(PTRACE_GETREGS, pid, 0, ®s_), SyscallSucceeds()); + + iov.iov_base = ®s_; + iov.iov_len = sizeof(regs_); + ASSERT_THAT(ptrace(PTRACE_GETREGSET, pid, NT_PRSTATUS, &iov), + SyscallSucceeds()); child_ = pid; } @@ -61,13 +69,27 @@ class SysretTest : public ::testing::Test { } void SetRip(uint64_t newrip) { +#if defined(__x86_64__) regs_.rip = newrip; - ASSERT_THAT(ptrace(PTRACE_SETREGS, child_, 0, ®s_), SyscallSucceeds()); +#elif defined(__aarch64__) + regs_.pc = newrip; +#else +#error "Unknown architecture" +#endif + ASSERT_THAT(ptrace(PTRACE_SETREGSET, child_, NT_PRSTATUS, &iov), + SyscallSucceeds()); } void SetRsp(uint64_t newrsp) { +#if defined(__x86_64__) regs_.rsp = newrsp; - ASSERT_THAT(ptrace(PTRACE_SETREGS, child_, 0, ®s_), SyscallSucceeds()); +#elif defined(__aarch64__) + regs_.sp = newrsp; +#else +#error "Unknown architecture" +#endif + ASSERT_THAT(ptrace(PTRACE_SETREGSET, child_, NT_PRSTATUS, &iov), + SyscallSucceeds()); } // Wait waits for the child pid and returns the exit status. @@ -104,8 +126,15 @@ TEST_F(SysretTest, BadRsp) { SetRsp(kNonCanonicalRsp); Detach(); int status = Wait(); +#if defined(__x86_64__) EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGBUS) << "status = " << status; +#elif defined(__aarch64__) + EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGSEGV) + << "status = " << status; +#else +#error "Unknown architecture" +#endif } } // namespace diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index bfa031bce..a6325a761 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -13,6 +13,9 @@ // limitations under the License. #include <fcntl.h> +#ifndef __fuchsia__ +#include <linux/filter.h> +#endif // __fuchsia__ #include <netinet/in.h> #include <netinet/tcp.h> #include <poll.h> @@ -130,6 +133,33 @@ void TcpSocketTest::TearDown() { } } +TEST_P(TcpSocketTest, ConnectOnEstablishedConnection) { + sockaddr_storage addr = + ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); + socklen_t addrlen = sizeof(addr); + + ASSERT_THAT( + connect(s_, reinterpret_cast<const struct sockaddr*>(&addr), addrlen), + SyscallFailsWithErrno(EISCONN)); + ASSERT_THAT( + connect(t_, reinterpret_cast<const struct sockaddr*>(&addr), addrlen), + SyscallFailsWithErrno(EISCONN)); +} + +TEST_P(TcpSocketTest, ShutdownWriteInTimeWait) { + EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceeds()); + EXPECT_THAT(shutdown(s_, SHUT_RDWR), SyscallSucceeds()); + absl::SleepFor(absl::Seconds(1)); // Wait to enter TIME_WAIT. + EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); +} + +TEST_P(TcpSocketTest, ShutdownWriteInFinWait1) { + EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceeds()); + EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceeds()); + absl::SleepFor(absl::Seconds(1)); // Wait to enter FIN-WAIT2. + EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceeds()); +} + TEST_P(TcpSocketTest, DataCoalesced) { char buf[10]; @@ -231,7 +261,8 @@ TEST_P(TcpSocketTest, ZeroWriteAllowed) { } // Test that a non-blocking write with a buffer that is larger than the send -// buffer size will not actually write the whole thing at once. +// buffer size will not actually write the whole thing at once. Regression test +// for b/64438887. TEST_P(TcpSocketTest, NonblockingLargeWrite) { // Set the FD to O_NONBLOCK. int opts; @@ -394,8 +425,15 @@ TEST_P(TcpSocketTest, PollWithFullBufferBlocks) { sizeof(tcp_nodelay_flag)), SyscallSucceeds()); + // Set a 256KB send/receive buffer. + int buf_sz = 1 << 18; + EXPECT_THAT(setsockopt(t_, SOL_SOCKET, SO_RCVBUF, &buf_sz, sizeof(buf_sz)), + SyscallSucceedsWithValue(0)); + EXPECT_THAT(setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &buf_sz, sizeof(buf_sz)), + SyscallSucceedsWithValue(0)); + // Create a large buffer that will be used for sending. - std::vector<char> buf(10 * sendbuf_size_); + std::vector<char> buf(1 << 16); // Write until we receive an error. while (RetryEINTR(send)(s_, buf.data(), buf.size(), 0) != -1) { @@ -405,6 +443,11 @@ TEST_P(TcpSocketTest, PollWithFullBufferBlocks) { } // The last error should have been EWOULDBLOCK. ASSERT_EQ(errno, EWOULDBLOCK); + + // Now polling on the FD with a timeout should return 0 corresponding to no + // FDs ready. + struct pollfd poll_fd = {s_, POLLOUT, 0}; + EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10), SyscallSucceedsWithValue(0)); } TEST_P(TcpSocketTest, MsgTrunc) { @@ -677,6 +720,30 @@ TEST_P(TcpSocketTest, TcpSCMPriority) { ASSERT_EQ(cmsg, nullptr); } +TEST_P(TcpSocketTest, TimeWaitPollHUP) { + shutdown(s_, SHUT_RDWR); + ScopedThread t([&]() { + constexpr int kTimeout = 10000; + constexpr int16_t want_events = POLLHUP; + struct pollfd pfd = { + .fd = s_, + .events = want_events, + }; + ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); + }); + shutdown(t_, SHUT_RDWR); + t.Join(); + // At this point s_ should be in TIME-WAIT and polling for POLLHUP should + // return with 1 FD. + constexpr int kTimeout = 10000; + constexpr int16_t want_events = POLLHUP; + struct pollfd pfd = { + .fd = s_, + .events = want_events, + }; + ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, TcpSocketTest, ::testing::Values(AF_INET, AF_INET6)); @@ -789,6 +856,20 @@ TEST_P(TcpSocketTest, FullBuffer) { t_ = -1; } +TEST_P(TcpSocketTest, PollAfterShutdown) { + ScopedThread client_thread([this]() { + EXPECT_THAT(shutdown(s_, SHUT_WR), SyscallSucceedsWithValue(0)); + struct pollfd poll_fd = {s_, POLLIN | POLLERR | POLLHUP, 0}; + EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10000), + SyscallSucceedsWithValue(1)); + }); + + EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceedsWithValue(0)); + struct pollfd poll_fd = {t_, POLLIN | POLLERR | POLLHUP, 0}; + EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10000), + SyscallSucceedsWithValue(1)); +} + TEST_P(SimpleTcpSocketTest, NonBlockingConnectNoListener) { // Initialize address to the loopback one. sockaddr_storage addr = @@ -942,6 +1023,78 @@ TEST_P(SimpleTcpSocketTest, BlockingConnectRefused) { EXPECT_THAT(close(s.release()), SyscallSucceeds()); } +// Test that connecting to a non-listening port and thus receiving a RST is +// handled appropriately by the socket - the port that the socket was bound to +// is released and the expected error is returned. +TEST_P(SimpleTcpSocketTest, CleanupOnConnectionRefused) { + // Create a socket that is known to not be listening. As is it bound but not + // listening, when another socket connects to the port, it will refuse.. + FileDescriptor bound_s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + sockaddr_storage bound_addr = + ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); + socklen_t bound_addrlen = sizeof(bound_addr); + + ASSERT_THAT( + bind(bound_s.get(), reinterpret_cast<struct sockaddr*>(&bound_addr), + bound_addrlen), + SyscallSucceeds()); + + // Get the addresses the socket is bound to because the port is chosen by the + // stack. + ASSERT_THAT(getsockname(bound_s.get(), + reinterpret_cast<struct sockaddr*>(&bound_addr), + &bound_addrlen), + SyscallSucceeds()); + + // Create, initialize, and bind the socket that is used to test connecting to + // the non-listening port. + FileDescriptor client_s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + // Initialize client address to the loopback one. + sockaddr_storage client_addr = + ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); + socklen_t client_addrlen = sizeof(client_addr); + + ASSERT_THAT( + bind(client_s.get(), reinterpret_cast<struct sockaddr*>(&client_addr), + client_addrlen), + SyscallSucceeds()); + + ASSERT_THAT(getsockname(client_s.get(), + reinterpret_cast<struct sockaddr*>(&client_addr), + &client_addrlen), + SyscallSucceeds()); + + // Now the test: connect to the bound but not listening socket with the + // client socket. The bound socket should return a RST and cause the client + // socket to return an error and clean itself up immediately. + // The error being ECONNREFUSED diverges with RFC 793, page 37, but does what + // Linux does. + ASSERT_THAT(connect(client_s.get(), + reinterpret_cast<const struct sockaddr*>(&bound_addr), + bound_addrlen), + SyscallFailsWithErrno(ECONNREFUSED)); + + FileDescriptor new_s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + // Test binding to the address from the client socket. This should be okay + // if it was dropped correctly. + ASSERT_THAT( + bind(new_s.get(), reinterpret_cast<struct sockaddr*>(&client_addr), + client_addrlen), + SyscallSucceeds()); + + // Attempt #2, with the new socket and reused addr our connect should fail in + // the same way as before, not with an EADDRINUSE. + ASSERT_THAT(connect(client_s.get(), + reinterpret_cast<const struct sockaddr*>(&bound_addr), + bound_addrlen), + SyscallFailsWithErrno(ECONNREFUSED)); +} + // Test that we get an ECONNREFUSED with a nonblocking socket. TEST_P(SimpleTcpSocketTest, NonBlockingConnectRefused) { FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( @@ -1150,6 +1303,346 @@ TEST_P(SimpleTcpSocketTest, SetMaxSegFailsForInvalidMSSValues) { } } +TEST_P(SimpleTcpSocketTest, SetTCPUserTimeout) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + { + constexpr int kTCPUserTimeout = -1; + EXPECT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &kTCPUserTimeout, sizeof(kTCPUserTimeout)), + SyscallFailsWithErrno(EINVAL)); + } + + // kTCPUserTimeout is in milliseconds. + constexpr int kTCPUserTimeout = 100; + ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &kTCPUserTimeout, sizeof(kTCPUserTimeout)), + SyscallSucceedsWithValue(0)); + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kTCPUserTimeout); +} + +TEST_P(SimpleTcpSocketTest, SetTCPDeferAcceptNeg) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + // -ve TCP_DEFER_ACCEPT is same as setting it to zero. + constexpr int kNeg = -1; + EXPECT_THAT( + setsockopt(s.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, &kNeg, sizeof(kNeg)), + SyscallSucceeds()); + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, 0); +} + +TEST_P(SimpleTcpSocketTest, GetTCPDeferAcceptDefault) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, 0); +} + +TEST_P(SimpleTcpSocketTest, SetTCPDeferAcceptGreaterThanZero) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + // kTCPDeferAccept is in seconds. + // NOTE: linux translates seconds to # of retries and back from + // #of retries to seconds. Which means only certain values + // translate back exactly. That's why we use 3 here, a value of + // 5 will result in us getting back 7 instead of 5 in the + // getsockopt. + constexpr int kTCPDeferAccept = 3; + ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, + &kTCPDeferAccept, sizeof(kTCPDeferAccept)), + SyscallSucceeds()); + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, &get, &get_len), + SyscallSucceeds()); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kTCPDeferAccept); +} + +TEST_P(SimpleTcpSocketTest, RecvOnClosedSocket) { + auto s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + char buf[1]; + EXPECT_THAT(recv(s.get(), buf, 0, 0), SyscallFailsWithErrno(ENOTCONN)); + EXPECT_THAT(recv(s.get(), buf, sizeof(buf), 0), + SyscallFailsWithErrno(ENOTCONN)); +} + +TEST_P(SimpleTcpSocketTest, TCPConnectSoRcvBufRace) { + auto s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(GetParam(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); + sockaddr_storage addr = + ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); + socklen_t addrlen = sizeof(addr); + + RetryEINTR(connect)(s.get(), reinterpret_cast<struct sockaddr*>(&addr), + addrlen); + int buf_sz = 1 << 18; + EXPECT_THAT( + setsockopt(s.get(), SOL_SOCKET, SO_RCVBUF, &buf_sz, sizeof(buf_sz)), + SyscallSucceedsWithValue(0)); +} + +TEST_P(SimpleTcpSocketTest, SetTCPSynCntLessThanOne) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT(getsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + int default_syn_cnt = get; + + { + // TCP_SYNCNT less than 1 should be rejected with an EINVAL. + constexpr int kZero = 0; + EXPECT_THAT( + setsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &kZero, sizeof(kZero)), + SyscallFailsWithErrno(EINVAL)); + + // TCP_SYNCNT less than 1 should be rejected with an EINVAL. + constexpr int kNeg = -1; + EXPECT_THAT( + setsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &kNeg, sizeof(kNeg)), + SyscallFailsWithErrno(EINVAL)); + + int get = -1; + socklen_t get_len = sizeof(get); + + ASSERT_THAT(getsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(default_syn_cnt, get); + } +} + +TEST_P(SimpleTcpSocketTest, GetTCPSynCntDefault) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + int get = -1; + socklen_t get_len = sizeof(get); + constexpr int kDefaultSynCnt = 6; + + ASSERT_THAT(getsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kDefaultSynCnt); +} + +TEST_P(SimpleTcpSocketTest, SetTCPSynCntGreaterThanOne) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + constexpr int kTCPSynCnt = 20; + ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &kTCPSynCnt, + sizeof(kTCPSynCnt)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT(getsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &get, &get_len), + SyscallSucceeds()); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kTCPSynCnt); +} + +TEST_P(SimpleTcpSocketTest, SetTCPSynCntAboveMax) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT(getsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + int default_syn_cnt = get; + { + constexpr int kTCPSynCnt = 256; + ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &kTCPSynCnt, + sizeof(kTCPSynCnt)), + SyscallFailsWithErrno(EINVAL)); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT(getsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &get, &get_len), + SyscallSucceeds()); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, default_syn_cnt); + } +} + +TEST_P(SimpleTcpSocketTest, SetTCPWindowClampBelowMinRcvBuf) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + // Discover minimum receive buf by setting a really low value + // for the receive buffer. + constexpr int kZero = 0; + EXPECT_THAT(setsockopt(s.get(), SOL_SOCKET, SO_RCVBUF, &kZero, sizeof(kZero)), + SyscallSucceeds()); + + // Now retrieve the minimum value for SO_RCVBUF as the set above should + // have caused SO_RCVBUF for the socket to be set to the minimum. + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT(getsockopt(s.get(), SOL_SOCKET, SO_RCVBUF, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + int min_so_rcvbuf = get; + + { + // TCP_WINDOW_CLAMP less than min_so_rcvbuf/2 should be set to + // min_so_rcvbuf/2. + int below_half_min_rcvbuf = min_so_rcvbuf / 2 - 1; + EXPECT_THAT( + setsockopt(s.get(), IPPROTO_TCP, TCP_WINDOW_CLAMP, + &below_half_min_rcvbuf, sizeof(below_half_min_rcvbuf)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_WINDOW_CLAMP, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(min_so_rcvbuf / 2, get); + } +} + +TEST_P(SimpleTcpSocketTest, SetTCPWindowClampZeroClosedSocket) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + constexpr int kZero = 0; + ASSERT_THAT( + setsockopt(s.get(), IPPROTO_TCP, TCP_WINDOW_CLAMP, &kZero, sizeof(kZero)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_WINDOW_CLAMP, &get, &get_len), + SyscallSucceeds()); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kZero); +} + +TEST_P(SimpleTcpSocketTest, SetTCPWindowClampAboveHalfMinRcvBuf) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + // Discover minimum receive buf by setting a really low value + // for the receive buffer. + constexpr int kZero = 0; + EXPECT_THAT(setsockopt(s.get(), SOL_SOCKET, SO_RCVBUF, &kZero, sizeof(kZero)), + SyscallSucceeds()); + + // Now retrieve the minimum value for SO_RCVBUF as the set above should + // have caused SO_RCVBUF for the socket to be set to the minimum. + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT(getsockopt(s.get(), SOL_SOCKET, SO_RCVBUF, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + int min_so_rcvbuf = get; + + { + int above_half_min_rcv_buf = min_so_rcvbuf / 2 + 1; + EXPECT_THAT( + setsockopt(s.get(), IPPROTO_TCP, TCP_WINDOW_CLAMP, + &above_half_min_rcv_buf, sizeof(above_half_min_rcv_buf)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_WINDOW_CLAMP, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(above_half_min_rcv_buf, get); + } +} + +#ifndef __fuchsia__ + +// TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. +// gVisor currently silently ignores attaching a filter. +TEST_P(SimpleTcpSocketTest, SetSocketAttachDetachFilter) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + // Program generated using sudo tcpdump -i lo tcp and port 1234 -dd + struct sock_filter code[] = { + {0x28, 0, 0, 0x0000000c}, {0x15, 0, 6, 0x000086dd}, + {0x30, 0, 0, 0x00000014}, {0x15, 0, 15, 0x00000006}, + {0x28, 0, 0, 0x00000036}, {0x15, 12, 0, 0x000004d2}, + {0x28, 0, 0, 0x00000038}, {0x15, 10, 11, 0x000004d2}, + {0x15, 0, 10, 0x00000800}, {0x30, 0, 0, 0x00000017}, + {0x15, 0, 8, 0x00000006}, {0x28, 0, 0, 0x00000014}, + {0x45, 6, 0, 0x00001fff}, {0xb1, 0, 0, 0x0000000e}, + {0x48, 0, 0, 0x0000000e}, {0x15, 2, 0, 0x000004d2}, + {0x48, 0, 0, 0x00000010}, {0x15, 0, 1, 0x000004d2}, + {0x6, 0, 0, 0x00040000}, {0x6, 0, 0, 0x00000000}, + }; + struct sock_fprog bpf = { + .len = ABSL_ARRAYSIZE(code), + .filter = code, + }; + ASSERT_THAT( + setsockopt(s.get(), SOL_SOCKET, SO_ATTACH_FILTER, &bpf, sizeof(bpf)), + SyscallSucceeds()); + + constexpr int val = 0; + ASSERT_THAT( + setsockopt(s.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), + SyscallSucceeds()); +} + +TEST_P(SimpleTcpSocketTest, SetSocketDetachFilterNoInstalledFilter) { + // TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. + SKIP_IF(IsRunningOnGvisor()); + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + constexpr int val = 0; + ASSERT_THAT( + setsockopt(s.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), + SyscallFailsWithErrno(ENOENT)); +} + +TEST_P(SimpleTcpSocketTest, GetSocketDetachFilter) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, &val_len), + SyscallFailsWithErrno(ENOPROTOOPT)); +} + +#endif // __fuchsia__ + INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest, ::testing::Values(AF_INET, AF_INET6)); diff --git a/test/syscalls/linux/time.cc b/test/syscalls/linux/time.cc index c7eead17e..e75bba669 100644 --- a/test/syscalls/linux/time.cc +++ b/test/syscalls/linux/time.cc @@ -26,6 +26,7 @@ namespace { constexpr long kFudgeSeconds = 5; +#if defined(__x86_64__) || defined(__i386__) // Mimics the time(2) wrapper from glibc prior to 2.15. time_t vsyscall_time(time_t* t) { constexpr uint64_t kVsyscallTimeEntry = 0xffffffffff600400; @@ -62,6 +63,7 @@ TEST(TimeTest, VsyscallTime_InvalidAddressSIGSEGV) { ::testing::KilledBySignal(SIGSEGV), ""); } +// Mimics the gettimeofday(2) wrapper from the Go runtime <= 1.2. int vsyscall_gettimeofday(struct timeval* tv, struct timezone* tz) { constexpr uint64_t kVsyscallGettimeofdayEntry = 0xffffffffff600000; return reinterpret_cast<int (*)(struct timeval*, struct timezone*)>( @@ -97,6 +99,7 @@ TEST(TimeTest, VsyscallGettimeofday_InvalidAddressSIGSEGV) { reinterpret_cast<struct timezone*>(0x1)), ::testing::KilledBySignal(SIGSEGV), ""); } +#endif } // namespace diff --git a/test/syscalls/linux/timerfd.cc b/test/syscalls/linux/timerfd.cc index 86ed87b7c..c4f8fdd7a 100644 --- a/test/syscalls/linux/timerfd.cc +++ b/test/syscalls/linux/timerfd.cc @@ -204,16 +204,33 @@ TEST_P(TimerfdTest, SetAbsoluteTime) { EXPECT_EQ(1, val); } -TEST_P(TimerfdTest, IllegalReadWrite) { +TEST_P(TimerfdTest, IllegalSeek) { + auto const tfd = ASSERT_NO_ERRNO_AND_VALUE(TimerfdCreate(GetParam(), 0)); + if (!IsRunningWithVFS1()) { + EXPECT_THAT(lseek(tfd.get(), 0, SEEK_SET), SyscallFailsWithErrno(ESPIPE)); + } +} + +TEST_P(TimerfdTest, IllegalPread) { + auto const tfd = ASSERT_NO_ERRNO_AND_VALUE(TimerfdCreate(GetParam(), 0)); + int val; + EXPECT_THAT(pread(tfd.get(), &val, sizeof(val), 0), + SyscallFailsWithErrno(ESPIPE)); +} + +TEST_P(TimerfdTest, IllegalPwrite) { + auto const tfd = ASSERT_NO_ERRNO_AND_VALUE(TimerfdCreate(GetParam(), 0)); + EXPECT_THAT(pwrite(tfd.get(), "x", 1, 0), SyscallFailsWithErrno(ESPIPE)); + if (!IsRunningWithVFS1()) { + } +} + +TEST_P(TimerfdTest, IllegalWrite) { auto const tfd = ASSERT_NO_ERRNO_AND_VALUE(TimerfdCreate(GetParam(), TFD_NONBLOCK)); uint64_t val = 0; - EXPECT_THAT(PreadFd(tfd.get(), &val, sizeof(val), 0), - SyscallFailsWithErrno(ESPIPE)); - EXPECT_THAT(WriteFd(tfd.get(), &val, sizeof(val)), + EXPECT_THAT(write(tfd.get(), &val, sizeof(val)), SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(PwriteFd(tfd.get(), &val, sizeof(val), 0), - SyscallFailsWithErrno(ESPIPE)); } std::string PrintClockId(::testing::TestParamInfo<int> info) { diff --git a/test/syscalls/linux/timers.cc b/test/syscalls/linux/timers.cc index 3db18d7ac..4b3c44527 100644 --- a/test/syscalls/linux/timers.cc +++ b/test/syscalls/linux/timers.cc @@ -297,9 +297,13 @@ class IntervalTimer { PosixErrorOr<IntervalTimer> TimerCreate(clockid_t clockid, const struct sigevent& sev) { int timerid; - if (syscall(SYS_timer_create, clockid, &sev, &timerid) < 0) { + int ret = syscall(SYS_timer_create, clockid, &sev, &timerid); + if (ret < 0) { return PosixError(errno, "timer_create"); } + if (ret > 0) { + return PosixError(EINVAL, "timer_create should never return positive"); + } MaybeSave(); return IntervalTimer(timerid); } @@ -317,6 +321,18 @@ TEST(IntervalTimerTest, IsInitiallyStopped) { EXPECT_EQ(0, its.it_value.tv_nsec); } +// Kernel can create multiple timers without issue. +// +// Regression test for gvisor.dev/issue/1738. +TEST(IntervalTimerTest, MultipleTimers) { + struct sigevent sev = {}; + sev.sigev_notify = SIGEV_NONE; + const auto timer1 = + ASSERT_NO_ERRNO_AND_VALUE(TimerCreate(CLOCK_MONOTONIC, sev)); + const auto timer2 = + ASSERT_NO_ERRNO_AND_VALUE(TimerCreate(CLOCK_MONOTONIC, sev)); +} + TEST(IntervalTimerTest, SingleShotSilent) { struct sigevent sev = {}; sev.sigev_notify = SIGEV_NONE; @@ -642,5 +658,5 @@ int main(int argc, char** argv) { } } - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/tkill.cc b/test/syscalls/linux/tkill.cc index bae377c69..8d8ebbb24 100644 --- a/test/syscalls/linux/tkill.cc +++ b/test/syscalls/linux/tkill.cc @@ -54,7 +54,7 @@ void SigHandler(int sig, siginfo_t* info, void* context) { TEST_CHECK(info->si_code == SI_TKILL); } -// Test with a real signal. +// Test with a real signal. Regression test for b/24790092. TEST(TkillTest, ValidTIDAndRealSignal) { struct sigaction sa; sa.sa_sigaction = SigHandler; diff --git a/test/syscalls/linux/truncate.cc b/test/syscalls/linux/truncate.cc index e5cc5d97c..c988c6380 100644 --- a/test/syscalls/linux/truncate.cc +++ b/test/syscalls/linux/truncate.cc @@ -19,6 +19,7 @@ #include <sys/vfs.h> #include <time.h> #include <unistd.h> + #include <iostream> #include <string> diff --git a/test/syscalls/linux/tuntap.cc b/test/syscalls/linux/tuntap.cc new file mode 100644 index 000000000..97d554e72 --- /dev/null +++ b/test/syscalls/linux/tuntap.cc @@ -0,0 +1,422 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <arpa/inet.h> +#include <linux/capability.h> +#include <linux/if_arp.h> +#include <linux/if_ether.h> +#include <linux/if_tun.h> +#include <netinet/ip.h> +#include <netinet/ip_icmp.h> +#include <sys/ioctl.h> +#include <sys/socket.h> +#include <sys/types.h> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_split.h" +#include "test/syscalls/linux/socket_netlink_route_util.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/capability_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/fs_util.h" +#include "test/util/posix_error.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { +namespace { + +constexpr int kIPLen = 4; + +constexpr const char kDevNetTun[] = "/dev/net/tun"; +constexpr const char kTapName[] = "tap0"; + +constexpr const uint8_t kMacA[ETH_ALEN] = {0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}; +constexpr const uint8_t kMacB[ETH_ALEN] = {0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB}; + +PosixErrorOr<std::set<std::string>> DumpLinkNames() { + ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks()); + std::set<std::string> names; + for (const auto& link : links) { + names.emplace(link.name); + } + return names; +} + +PosixErrorOr<Link> GetLinkByName(const std::string& name) { + ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks()); + for (const auto& link : links) { + if (link.name == name) { + return link; + } + } + return PosixError(ENOENT, "interface not found"); +} + +struct pihdr { + uint16_t pi_flags; + uint16_t pi_protocol; +} __attribute__((packed)); + +struct ping_pkt { + pihdr pi; + struct ethhdr eth; + struct iphdr ip; + struct icmphdr icmp; + char payload[64]; +} __attribute__((packed)); + +ping_pkt CreatePingPacket(const uint8_t srcmac[ETH_ALEN], const char* srcip, + const uint8_t dstmac[ETH_ALEN], const char* dstip) { + ping_pkt pkt = {}; + + pkt.pi.pi_protocol = htons(ETH_P_IP); + + memcpy(pkt.eth.h_dest, dstmac, sizeof(pkt.eth.h_dest)); + memcpy(pkt.eth.h_source, srcmac, sizeof(pkt.eth.h_source)); + pkt.eth.h_proto = htons(ETH_P_IP); + + pkt.ip.ihl = 5; + pkt.ip.version = 4; + pkt.ip.tos = 0; + pkt.ip.tot_len = htons(sizeof(struct iphdr) + sizeof(struct icmphdr) + + sizeof(pkt.payload)); + pkt.ip.id = 1; + pkt.ip.frag_off = 1 << 6; // Do not fragment + pkt.ip.ttl = 64; + pkt.ip.protocol = IPPROTO_ICMP; + inet_pton(AF_INET, dstip, &pkt.ip.daddr); + inet_pton(AF_INET, srcip, &pkt.ip.saddr); + pkt.ip.check = IPChecksum(pkt.ip); + + pkt.icmp.type = ICMP_ECHO; + pkt.icmp.code = 0; + pkt.icmp.checksum = 0; + pkt.icmp.un.echo.sequence = 1; + pkt.icmp.un.echo.id = 1; + + strncpy(pkt.payload, "abcd", sizeof(pkt.payload)); + pkt.icmp.checksum = ICMPChecksum(pkt.icmp, pkt.payload, sizeof(pkt.payload)); + + return pkt; +} + +struct arp_pkt { + pihdr pi; + struct ethhdr eth; + struct arphdr arp; + uint8_t arp_sha[ETH_ALEN]; + uint8_t arp_spa[kIPLen]; + uint8_t arp_tha[ETH_ALEN]; + uint8_t arp_tpa[kIPLen]; +} __attribute__((packed)); + +std::string CreateArpPacket(const uint8_t srcmac[ETH_ALEN], const char* srcip, + const uint8_t dstmac[ETH_ALEN], const char* dstip) { + std::string buffer; + buffer.resize(sizeof(arp_pkt)); + + arp_pkt* pkt = reinterpret_cast<arp_pkt*>(&buffer[0]); + { + pkt->pi.pi_protocol = htons(ETH_P_ARP); + + memcpy(pkt->eth.h_dest, kMacA, sizeof(pkt->eth.h_dest)); + memcpy(pkt->eth.h_source, kMacB, sizeof(pkt->eth.h_source)); + pkt->eth.h_proto = htons(ETH_P_ARP); + + pkt->arp.ar_hrd = htons(ARPHRD_ETHER); + pkt->arp.ar_pro = htons(ETH_P_IP); + pkt->arp.ar_hln = ETH_ALEN; + pkt->arp.ar_pln = kIPLen; + pkt->arp.ar_op = htons(ARPOP_REPLY); + + memcpy(pkt->arp_sha, srcmac, sizeof(pkt->arp_sha)); + inet_pton(AF_INET, srcip, pkt->arp_spa); + memcpy(pkt->arp_tha, dstmac, sizeof(pkt->arp_tha)); + inet_pton(AF_INET, dstip, pkt->arp_tpa); + } + return buffer; +} + +} // namespace + +TEST(TuntapStaticTest, NetTunExists) { + struct stat statbuf; + ASSERT_THAT(stat(kDevNetTun, &statbuf), SyscallSucceeds()); + // Check that it's a character device with rw-rw-rw- permissions. + EXPECT_EQ(statbuf.st_mode, S_IFCHR | 0666); +} + +class TuntapTest : public ::testing::Test { + protected: + void TearDown() override { + if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))) { + // Bring back capability if we had dropped it in test case. + ASSERT_NO_ERRNO(SetCapability(CAP_NET_ADMIN, true)); + } + } +}; + +TEST_F(TuntapTest, CreateInterfaceNoCap) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + ASSERT_NO_ERRNO(SetCapability(CAP_NET_ADMIN, false)); + + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); + + struct ifreq ifr = {}; + ifr.ifr_flags = IFF_TAP; + strncpy(ifr.ifr_name, kTapName, IFNAMSIZ); + + EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr), SyscallFailsWithErrno(EPERM)); +} + +TEST_F(TuntapTest, CreateFixedNameInterface) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); + + struct ifreq ifr_set = {}; + ifr_set.ifr_flags = IFF_TAP; + strncpy(ifr_set.ifr_name, kTapName, IFNAMSIZ); + EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr_set), + SyscallSucceedsWithValue(0)); + + struct ifreq ifr_get = {}; + EXPECT_THAT(ioctl(fd.get(), TUNGETIFF, &ifr_get), + SyscallSucceedsWithValue(0)); + + struct ifreq ifr_expect = ifr_set; + // See __tun_chr_ioctl() in net/drivers/tun.c. + ifr_expect.ifr_flags |= IFF_NOFILTER; + + EXPECT_THAT(DumpLinkNames(), + IsPosixErrorOkAndHolds(::testing::Contains(kTapName))); + EXPECT_THAT(memcmp(&ifr_expect, &ifr_get, sizeof(ifr_get)), ::testing::Eq(0)); +} + +TEST_F(TuntapTest, CreateInterface) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); + + struct ifreq ifr = {}; + ifr.ifr_flags = IFF_TAP; + // Empty ifr.ifr_name. Let kernel assign. + + EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr), SyscallSucceedsWithValue(0)); + + struct ifreq ifr_get = {}; + EXPECT_THAT(ioctl(fd.get(), TUNGETIFF, &ifr_get), + SyscallSucceedsWithValue(0)); + + std::string ifname = ifr_get.ifr_name; + EXPECT_THAT(ifname, ::testing::StartsWith("tap")); + EXPECT_THAT(DumpLinkNames(), + IsPosixErrorOkAndHolds(::testing::Contains(ifname))); +} + +TEST_F(TuntapTest, InvalidReadWrite) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); + + char buf[128] = {}; + EXPECT_THAT(read(fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(EBADFD)); + EXPECT_THAT(write(fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(EBADFD)); +} + +TEST_F(TuntapTest, WriteToDownDevice) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + // FIXME(b/110961832): gVisor always creates enabled/up'd interfaces. + SKIP_IF(IsRunningOnGvisor()); + + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); + + // Device created should be down by default. + struct ifreq ifr = {}; + ifr.ifr_flags = IFF_TAP; + EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr), SyscallSucceedsWithValue(0)); + + char buf[128] = {}; + EXPECT_THAT(write(fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(EIO)); +} + +PosixErrorOr<FileDescriptor> OpenAndAttachTap( + const std::string& dev_name, const std::string& dev_ipv4_addr) { + // Interface creation. + ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, Open(kDevNetTun, O_RDWR)); + + struct ifreq ifr_set = {}; + ifr_set.ifr_flags = IFF_TAP; + strncpy(ifr_set.ifr_name, dev_name.c_str(), IFNAMSIZ); + if (ioctl(fd.get(), TUNSETIFF, &ifr_set) < 0) { + return PosixError(errno); + } + + ASSIGN_OR_RETURN_ERRNO(auto link, GetLinkByName(dev_name)); + + // Interface setup. + struct in_addr addr; + inet_pton(AF_INET, dev_ipv4_addr.c_str(), &addr); + EXPECT_NO_ERRNO(LinkAddLocalAddr(link.index, AF_INET, /*prefixlen=*/24, &addr, + sizeof(addr))); + + if (!IsRunningOnGvisor()) { + // FIXME(b/110961832): gVisor doesn't support setting MAC address on + // interfaces yet. + RETURN_IF_ERRNO(LinkSetMacAddr(link.index, kMacA, sizeof(kMacA))); + + // FIXME(b/110961832): gVisor always creates enabled/up'd interfaces. + RETURN_IF_ERRNO(LinkChangeFlags(link.index, IFF_UP, IFF_UP)); + } + + return fd; +} + +// This test sets up a TAP device and pings kernel by sending ICMP echo request. +// +// It works as the following: +// * Open /dev/net/tun, and create kTapName interface. +// * Use rtnetlink to do initial setup of the interface: +// * Assign IP address 10.0.0.1/24 to kernel. +// * MAC address: kMacA +// * Bring up the interface. +// * Send an ICMP echo reqest (ping) packet from 10.0.0.2 (kMacB) to kernel. +// * Loop to receive packets from TAP device/fd: +// * If packet is an ICMP echo reply, it stops and passes the test. +// * If packet is an ARP request, it responds with canned reply and resends +// the +// ICMP request packet. +TEST_F(TuntapTest, PingKernel) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(OpenAndAttachTap(kTapName, "10.0.0.1")); + ping_pkt ping_req = CreatePingPacket(kMacB, "10.0.0.2", kMacA, "10.0.0.1"); + std::string arp_rep = CreateArpPacket(kMacB, "10.0.0.2", kMacA, "10.0.0.1"); + + // Send ping, this would trigger an ARP request on Linux. + EXPECT_THAT(write(fd.get(), &ping_req, sizeof(ping_req)), + SyscallSucceedsWithValue(sizeof(ping_req))); + + // Receive loop to process inbound packets. + struct inpkt { + union { + pihdr pi; + ping_pkt ping; + arp_pkt arp; + }; + }; + while (1) { + inpkt r = {}; + int n = read(fd.get(), &r, sizeof(r)); + EXPECT_THAT(n, SyscallSucceeds()); + + if (n < sizeof(pihdr)) { + std::cerr << "Ignored packet, protocol: " << r.pi.pi_protocol + << " len: " << n << std::endl; + continue; + } + + // Process ARP packet. + if (n >= sizeof(arp_pkt) && r.pi.pi_protocol == htons(ETH_P_ARP)) { + // Respond with canned ARP reply. + EXPECT_THAT(write(fd.get(), arp_rep.data(), arp_rep.size()), + SyscallSucceedsWithValue(arp_rep.size())); + // First ping request might have been dropped due to mac address not in + // ARP cache. Send it again. + EXPECT_THAT(write(fd.get(), &ping_req, sizeof(ping_req)), + SyscallSucceedsWithValue(sizeof(ping_req))); + } + + // Process ping response packet. + if (n >= sizeof(ping_pkt) && r.pi.pi_protocol == ping_req.pi.pi_protocol && + r.ping.ip.protocol == ping_req.ip.protocol && + !memcmp(&r.ping.ip.saddr, &ping_req.ip.daddr, kIPLen) && + !memcmp(&r.ping.ip.daddr, &ping_req.ip.saddr, kIPLen) && + r.ping.icmp.type == 0 && r.ping.icmp.code == 0) { + // Ends and passes the test. + break; + } + } +} + +TEST_F(TuntapTest, SendUdpTriggersArpResolution) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(OpenAndAttachTap(kTapName, "10.0.0.1")); + + // Send a UDP packet to remote. + int sock = socket(AF_INET, SOCK_DGRAM, IPPROTO_IP); + ASSERT_THAT(sock, SyscallSucceeds()); + + struct sockaddr_in remote = {}; + remote.sin_family = AF_INET; + remote.sin_port = htons(42); + inet_pton(AF_INET, "10.0.0.2", &remote.sin_addr); + int ret = sendto(sock, "hello", 5, 0, reinterpret_cast<sockaddr*>(&remote), + sizeof(remote)); + ASSERT_THAT(ret, ::testing::AnyOf(SyscallSucceeds(), + SyscallFailsWithErrno(EHOSTDOWN))); + + struct inpkt { + union { + pihdr pi; + arp_pkt arp; + }; + }; + while (1) { + inpkt r = {}; + int n = read(fd.get(), &r, sizeof(r)); + EXPECT_THAT(n, SyscallSucceeds()); + + if (n < sizeof(pihdr)) { + std::cerr << "Ignored packet, protocol: " << r.pi.pi_protocol + << " len: " << n << std::endl; + continue; + } + + if (n >= sizeof(arp_pkt) && r.pi.pi_protocol == htons(ETH_P_ARP)) { + break; + } + } +} + +// Write hang bug found by syskaller: b/155928773 +// https://syzkaller.appspot.com/bug?id=065b893bd8d1d04a4e0a1d53c578537cde1efe99 +TEST_F(TuntapTest, WriteHangBug155928773) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(OpenAndAttachTap(kTapName, "10.0.0.1")); + + int sock = socket(AF_INET, SOCK_DGRAM, 0); + ASSERT_THAT(sock, SyscallSucceeds()); + + struct sockaddr_in remote = {}; + remote.sin_family = AF_INET; + remote.sin_port = htons(42); + inet_pton(AF_INET, "10.0.0.1", &remote.sin_addr); + // Return values do not matter in this test. + connect(sock, reinterpret_cast<struct sockaddr*>(&remote), sizeof(remote)); + write(sock, "hello", 5); +} + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/tuntap_hostinet.cc b/test/syscalls/linux/tuntap_hostinet.cc new file mode 100644 index 000000000..1513fb9d5 --- /dev/null +++ b/test/syscalls/linux/tuntap_hostinet.cc @@ -0,0 +1,38 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <sys/stat.h> +#include <sys/types.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +TEST(TuntapHostInetTest, NoNetTun) { + SKIP_IF(!IsRunningOnGvisor()); + SKIP_IF(!IsRunningWithHostinet()); + + struct stat statbuf; + ASSERT_THAT(stat("/dev/net/tun", &statbuf), SyscallFailsWithErrno(ENOENT)); +} + +} // namespace +} // namespace testing + +} // namespace gvisor diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc index 111dbacdf..7a8ac30a4 100644 --- a/test/syscalls/linux/udp_socket.cc +++ b/test/syscalls/linux/udp_socket.cc @@ -12,1332 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include <arpa/inet.h> -#include <fcntl.h> -#include <linux/errqueue.h> -#include <netinet/in.h> -#include <sys/ioctl.h> -#include <sys/socket.h> -#include <sys/types.h> - -#include "gtest/gtest.h" -#include "absl/base/macros.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" +#include "test/syscalls/linux/udp_socket_test_cases.h" namespace gvisor { namespace testing { namespace { -// The initial port to be be used on gvisor. -constexpr int TestPort = 40000; - -// Fixture for tests parameterized by the address family to use (AF_INET and -// AF_INET6) when creating sockets. -class UdpSocketTest : public ::testing::TestWithParam<AddressFamily> { - protected: - // Creates two sockets that will be used by test cases. - void SetUp() override; - - // Closes the sockets created by SetUp(). - void TearDown() override { - EXPECT_THAT(close(s_), SyscallSucceeds()); - EXPECT_THAT(close(t_), SyscallSucceeds()); - - for (size_t i = 0; i < ABSL_ARRAYSIZE(ports_); ++i) { - ASSERT_NO_ERRNO(FreeAvailablePort(ports_[i])); - } - } - - // First UDP socket. - int s_; - - // Second UDP socket. - int t_; - - // The length of the socket address. - socklen_t addrlen_; - - // Initialized address pointing to loopback and port TestPort+i. - struct sockaddr* addr_[3]; - - // Initialize "any" address. - struct sockaddr* anyaddr_; - - // Used ports. - int ports_[3]; - - private: - // Storage for the loopback addresses. - struct sockaddr_storage addr_storage_[3]; - - // Storage for the "any" address. - struct sockaddr_storage anyaddr_storage_; -}; - -// Gets a pointer to the port component of the given address. -uint16_t* Port(struct sockaddr_storage* addr) { - switch (addr->ss_family) { - case AF_INET: { - auto sin = reinterpret_cast<struct sockaddr_in*>(addr); - return &sin->sin_port; - } - case AF_INET6: { - auto sin6 = reinterpret_cast<struct sockaddr_in6*>(addr); - return &sin6->sin6_port; - } - } - - return nullptr; -} - -void UdpSocketTest::SetUp() { - int type; - if (GetParam() == AddressFamily::kIpv4) { - type = AF_INET; - auto sin = reinterpret_cast<struct sockaddr_in*>(&anyaddr_storage_); - addrlen_ = sizeof(*sin); - sin->sin_addr.s_addr = htonl(INADDR_ANY); - } else { - type = AF_INET6; - auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&anyaddr_storage_); - addrlen_ = sizeof(*sin6); - if (GetParam() == AddressFamily::kIpv6) { - sin6->sin6_addr = IN6ADDR_ANY_INIT; - } else { - TestAddress const& v4_mapped_any = V4MappedAny(); - sin6->sin6_addr = - reinterpret_cast<const struct sockaddr_in6*>(&v4_mapped_any.addr) - ->sin6_addr; - } - } - ASSERT_THAT(s_ = socket(type, SOCK_DGRAM, IPPROTO_UDP), SyscallSucceeds()); - - ASSERT_THAT(t_ = socket(type, SOCK_DGRAM, IPPROTO_UDP), SyscallSucceeds()); - - memset(&anyaddr_storage_, 0, sizeof(anyaddr_storage_)); - anyaddr_ = reinterpret_cast<struct sockaddr*>(&anyaddr_storage_); - anyaddr_->sa_family = type; - - if (gvisor::testing::IsRunningOnGvisor()) { - for (size_t i = 0; i < ABSL_ARRAYSIZE(ports_); ++i) { - ports_[i] = TestPort + i; - } - } else { - // When not under gvisor, use utility function to pick port. Assert that - // all ports are different. - std::string error; - for (size_t i = 0; i < ABSL_ARRAYSIZE(ports_); ++i) { - // Find an unused port, we specify port 0 to allow the kernel to provide - // the port. - bool unique = true; - do { - ports_[i] = ASSERT_NO_ERRNO_AND_VALUE(PortAvailable( - 0, AddressFamily::kDualStack, SocketType::kUdp, false)); - ASSERT_GT(ports_[i], 0); - for (size_t j = 0; j < i; ++j) { - if (ports_[j] == ports_[i]) { - unique = false; - break; - } - } - } while (!unique); - } - } - - // Initialize the sockaddrs. - for (size_t i = 0; i < ABSL_ARRAYSIZE(addr_); ++i) { - memset(&addr_storage_[i], 0, sizeof(addr_storage_[i])); - - addr_[i] = reinterpret_cast<struct sockaddr*>(&addr_storage_[i]); - addr_[i]->sa_family = type; - - switch (type) { - case AF_INET: { - auto sin = reinterpret_cast<struct sockaddr_in*>(addr_[i]); - sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK); - sin->sin_port = htons(ports_[i]); - break; - } - case AF_INET6: { - auto sin6 = reinterpret_cast<struct sockaddr_in6*>(addr_[i]); - sin6->sin6_addr = in6addr_loopback; - sin6->sin6_port = htons(ports_[i]); - break; - } - } - } -} - -TEST_P(UdpSocketTest, Creation) { - int type = AF_INET6; - if (GetParam() == AddressFamily::kIpv4) { - type = AF_INET; - } - - int s_; - - ASSERT_THAT(s_ = socket(type, SOCK_DGRAM, IPPROTO_UDP), SyscallSucceeds()); - EXPECT_THAT(close(s_), SyscallSucceeds()); - - ASSERT_THAT(s_ = socket(type, SOCK_DGRAM, 0), SyscallSucceeds()); - EXPECT_THAT(close(s_), SyscallSucceeds()); - - ASSERT_THAT(s_ = socket(type, SOCK_STREAM, IPPROTO_UDP), SyscallFails()); -} - -TEST_P(UdpSocketTest, Getsockname) { - // Check that we're not bound. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, anyaddr_, addrlen_), 0); - - // Bind, then check that we get the right address. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, addr_[0], addrlen_), 0); -} - -TEST_P(UdpSocketTest, Getpeername) { - // Check that we're not connected. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); - - // Connect, then check that we get the right address. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - - addrlen = sizeof(addr); - EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, addr_[0], addrlen_), 0); -} - -TEST_P(UdpSocketTest, SendNotConnected) { - // Do send & write, they must fail. - char buf[512]; - EXPECT_THAT(send(s_, buf, sizeof(buf), 0), - SyscallFailsWithErrno(EDESTADDRREQ)); - - EXPECT_THAT(write(s_, buf, sizeof(buf)), SyscallFailsWithErrno(EDESTADDRREQ)); - - // Use sendto. - ASSERT_THAT(sendto(s_, buf, sizeof(buf), 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Check that we're bound now. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_NE(*Port(&addr), 0); -} - -TEST_P(UdpSocketTest, ConnectBinds) { - // Connect the socket. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Check that we're bound now. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_NE(*Port(&addr), 0); -} - -TEST_P(UdpSocketTest, ReceiveNotBound) { - char buf[512]; - EXPECT_THAT(recv(s_, buf, sizeof(buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); -} - -TEST_P(UdpSocketTest, Bind) { - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Try to bind again. - EXPECT_THAT(bind(s_, addr_[1], addrlen_), SyscallFailsWithErrno(EINVAL)); - - // Check that we're still bound to the original address. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, addr_[0], addrlen_), 0); -} - -TEST_P(UdpSocketTest, BindInUse) { - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Try to bind again. - EXPECT_THAT(bind(t_, addr_[0], addrlen_), SyscallFailsWithErrno(EADDRINUSE)); -} - -TEST_P(UdpSocketTest, ReceiveAfterConnect) { - // Connect s_ to loopback:TestPort, and bind t_ to loopback:TestPort. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(bind(t_, addr_[0], addrlen_), SyscallSucceeds()); - - // Get the address s_ was bound to during connect. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - - // Send from t_ to s_. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, - reinterpret_cast<sockaddr*>(&addr), addrlen), - SyscallSucceedsWithValue(sizeof(buf))); - - // Receive the data. - char received[sizeof(buf)]; - EXPECT_THAT(recv(s_, received, sizeof(received), 0), - SyscallSucceedsWithValue(sizeof(received))); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); -} - -TEST_P(UdpSocketTest, ReceiveAfterDisconnect) { - // Connect s_ to loopback:TestPort, and bind t_ to loopback:TestPort. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(bind(t_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(t_, addr_[1], addrlen_), SyscallSucceeds()); - - // Get the address s_ was bound to during connect. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - - for (int i = 0; i < 2; i++) { - // Send from t_ to s_. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, - reinterpret_cast<sockaddr*>(&addr), addrlen), - SyscallSucceedsWithValue(sizeof(buf))); - - // Receive the data. - char received[sizeof(buf)]; - EXPECT_THAT(recv(s_, received, sizeof(received), 0), - SyscallSucceedsWithValue(sizeof(received))); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); - - // Disconnect s_. - struct sockaddr addr = {}; - addr.sa_family = AF_UNSPEC; - ASSERT_THAT(connect(s_, &addr, sizeof(addr.sa_family)), SyscallSucceeds()); - // Connect s_ loopback:TestPort. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - } -} - -TEST_P(UdpSocketTest, Connect) { - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Check that we're connected to the right peer. - struct sockaddr_storage peer; - socklen_t peerlen = sizeof(peer); - EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&peer), &peerlen), - SyscallSucceeds()); - EXPECT_EQ(peerlen, addrlen_); - EXPECT_EQ(memcmp(&peer, addr_[0], addrlen_), 0); - - // Try to bind after connect. - EXPECT_THAT(bind(s_, addr_[1], addrlen_), SyscallFailsWithErrno(EINVAL)); - - // Try to connect again. - EXPECT_THAT(connect(s_, addr_[2], addrlen_), SyscallSucceeds()); - - // Check that peer name changed. - peerlen = sizeof(peer); - EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&peer), &peerlen), - SyscallSucceeds()); - EXPECT_EQ(peerlen, addrlen_); - EXPECT_EQ(memcmp(&peer, addr_[2], addrlen_), 0); -} - -void ConnectAny(AddressFamily family, int sockfd, uint16_t port) { - struct sockaddr_storage addr = {}; - - // Precondition check. - { - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getsockname(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - - if (family == AddressFamily::kIpv4) { - auto addr_out = reinterpret_cast<struct sockaddr_in*>(&addr); - EXPECT_EQ(addrlen, sizeof(*addr_out)); - EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_ANY)); - } else { - auto addr_out = reinterpret_cast<struct sockaddr_in6*>(&addr); - EXPECT_EQ(addrlen, sizeof(*addr_out)); - struct in6_addr any = IN6ADDR_ANY_INIT; - EXPECT_EQ(memcmp(&addr_out->sin6_addr, &any, sizeof(in6_addr)), 0); - } - - { - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getpeername(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); - } - - struct sockaddr_storage baddr = {}; - if (family == AddressFamily::kIpv4) { - auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr); - addrlen = sizeof(*addr_in); - addr_in->sin_family = AF_INET; - addr_in->sin_addr.s_addr = htonl(INADDR_ANY); - addr_in->sin_port = port; - } else { - auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr); - addrlen = sizeof(*addr_in); - addr_in->sin6_family = AF_INET6; - addr_in->sin6_port = port; - if (family == AddressFamily::kIpv6) { - addr_in->sin6_addr = IN6ADDR_ANY_INIT; - } else { - TestAddress const& v4_mapped_any = V4MappedAny(); - addr_in->sin6_addr = - reinterpret_cast<const struct sockaddr_in6*>(&v4_mapped_any.addr) - ->sin6_addr; - } - } - - // TODO(b/138658473): gVisor doesn't allow connecting to the zero port. - if (port == 0) { - SKIP_IF(IsRunningOnGvisor()); - } - - ASSERT_THAT(connect(sockfd, reinterpret_cast<sockaddr*>(&baddr), addrlen), - SyscallSucceeds()); - } - - // Postcondition check. - { - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getsockname(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - - if (family == AddressFamily::kIpv4) { - auto addr_out = reinterpret_cast<struct sockaddr_in*>(&addr); - EXPECT_EQ(addrlen, sizeof(*addr_out)); - EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_LOOPBACK)); - } else { - auto addr_out = reinterpret_cast<struct sockaddr_in6*>(&addr); - EXPECT_EQ(addrlen, sizeof(*addr_out)); - struct in6_addr loopback; - if (family == AddressFamily::kIpv6) { - loopback = IN6ADDR_LOOPBACK_INIT; - } else { - TestAddress const& v4_mapped_loopback = V4MappedLoopback(); - loopback = reinterpret_cast<const struct sockaddr_in6*>( - &v4_mapped_loopback.addr) - ->sin6_addr; - } - - EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0); - } - - addrlen = sizeof(addr); - if (port == 0) { - EXPECT_THAT( - getpeername(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); - } else { - EXPECT_THAT( - getpeername(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - } - } -} - -TEST_P(UdpSocketTest, ConnectAny) { ConnectAny(GetParam(), s_, 0); } - -TEST_P(UdpSocketTest, ConnectAnyWithPort) { - auto port = *Port(reinterpret_cast<struct sockaddr_storage*>(addr_[1])); - ConnectAny(GetParam(), s_, port); -} - -void DisconnectAfterConnectAny(AddressFamily family, int sockfd, int port) { - struct sockaddr_storage addr = {}; - - socklen_t addrlen = sizeof(addr); - struct sockaddr_storage baddr = {}; - if (family == AddressFamily::kIpv4) { - auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr); - addrlen = sizeof(*addr_in); - addr_in->sin_family = AF_INET; - addr_in->sin_addr.s_addr = htonl(INADDR_ANY); - addr_in->sin_port = port; - } else { - auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr); - addrlen = sizeof(*addr_in); - addr_in->sin6_family = AF_INET6; - addr_in->sin6_port = port; - if (family == AddressFamily::kIpv6) { - addr_in->sin6_addr = IN6ADDR_ANY_INIT; - } else { - TestAddress const& v4_mapped_any = V4MappedAny(); - addr_in->sin6_addr = - reinterpret_cast<const struct sockaddr_in6*>(&v4_mapped_any.addr) - ->sin6_addr; - } - } - - // TODO(b/138658473): gVisor doesn't allow connecting to the zero port. - if (port == 0) { - SKIP_IF(IsRunningOnGvisor()); - } - - ASSERT_THAT(connect(sockfd, reinterpret_cast<sockaddr*>(&baddr), addrlen), - SyscallSucceeds()); - // Now the socket is bound to the loopback address. - - // Disconnect - addrlen = sizeof(addr); - addr.ss_family = AF_UNSPEC; - ASSERT_THAT(connect(sockfd, reinterpret_cast<sockaddr*>(&addr), addrlen), - SyscallSucceeds()); - - // Check that after disconnect the socket is bound to the ANY address. - EXPECT_THAT(getsockname(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - if (family == AddressFamily::kIpv4) { - auto addr_out = reinterpret_cast<struct sockaddr_in*>(&addr); - EXPECT_EQ(addrlen, sizeof(*addr_out)); - EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_ANY)); - } else { - auto addr_out = reinterpret_cast<struct sockaddr_in6*>(&addr); - EXPECT_EQ(addrlen, sizeof(*addr_out)); - struct in6_addr loopback = IN6ADDR_ANY_INIT; - - EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0); - } -} - -TEST_P(UdpSocketTest, DisconnectAfterConnectAny) { - DisconnectAfterConnectAny(GetParam(), s_, 0); -} - -TEST_P(UdpSocketTest, DisconnectAfterConnectAnyWithPort) { - auto port = *Port(reinterpret_cast<struct sockaddr_storage*>(addr_[1])); - DisconnectAfterConnectAny(GetParam(), s_, port); -} - -TEST_P(UdpSocketTest, DisconnectAfterBind) { - ASSERT_THAT(bind(s_, addr_[1], addrlen_), SyscallSucceeds()); - // Connect the socket. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - - struct sockaddr_storage addr = {}; - addr.ss_family = AF_UNSPEC; - EXPECT_THAT( - connect(s_, reinterpret_cast<sockaddr*>(&addr), sizeof(addr.ss_family)), - SyscallSucceeds()); - - // Check that we're still bound. - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, addr_[1], addrlen_), 0); - - addrlen = sizeof(addr); - EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); -} - -TEST_P(UdpSocketTest, DisconnectAfterBindToAny) { - struct sockaddr_storage baddr = {}; - socklen_t addrlen; - auto port = *Port(reinterpret_cast<struct sockaddr_storage*>(addr_[1])); - if (GetParam() == AddressFamily::kIpv4) { - auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr); - addr_in->sin_family = AF_INET; - addr_in->sin_port = port; - addr_in->sin_addr.s_addr = htonl(INADDR_ANY); - } else { - auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr); - addr_in->sin6_family = AF_INET6; - addr_in->sin6_port = port; - addr_in->sin6_scope_id = 0; - addr_in->sin6_addr = IN6ADDR_ANY_INIT; - } - ASSERT_THAT(bind(s_, reinterpret_cast<sockaddr*>(&baddr), addrlen_), - SyscallSucceeds()); - // Connect the socket. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - - struct sockaddr_storage addr = {}; - addr.ss_family = AF_UNSPEC; - EXPECT_THAT( - connect(s_, reinterpret_cast<sockaddr*>(&addr), sizeof(addr.ss_family)), - SyscallSucceeds()); - - // Check that we're still bound. - addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, &baddr, addrlen), 0); - - addrlen = sizeof(addr); - EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); -} - -TEST_P(UdpSocketTest, Disconnect) { - for (int i = 0; i < 2; i++) { - // Try to connect again. - EXPECT_THAT(connect(s_, addr_[2], addrlen_), SyscallSucceeds()); - - // Check that we're connected to the right peer. - struct sockaddr_storage peer; - socklen_t peerlen = sizeof(peer); - EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&peer), &peerlen), - SyscallSucceeds()); - EXPECT_EQ(peerlen, addrlen_); - EXPECT_EQ(memcmp(&peer, addr_[2], addrlen_), 0); - - // Try to disconnect. - struct sockaddr_storage addr = {}; - addr.ss_family = AF_UNSPEC; - EXPECT_THAT( - connect(s_, reinterpret_cast<sockaddr*>(&addr), sizeof(addr.ss_family)), - SyscallSucceeds()); - - peerlen = sizeof(peer); - EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&peer), &peerlen), - SyscallFailsWithErrno(ENOTCONN)); - - // Check that we're still bound. - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(*Port(&addr), 0); - } -} - -TEST_P(UdpSocketTest, ConnectBadAddress) { - struct sockaddr addr = {}; - addr.sa_family = addr_[0]->sa_family; - ASSERT_THAT(connect(s_, &addr, sizeof(addr.sa_family)), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_P(UdpSocketTest, SendToAddressOtherThanConnected) { - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Send to a different destination than we're connected to. - char buf[512]; - EXPECT_THAT(sendto(s_, buf, sizeof(buf), 0, addr_[1], addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); -} - -TEST_P(UdpSocketTest, ZerolengthWriteAllowed) { - // Bind s_ to loopback:TestPort, and connect to loopback:TestPort+1. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds()); - - // Bind t_ to loopback:TestPort+1. - ASSERT_THAT(bind(t_, addr_[1], addrlen_), SyscallSucceeds()); - - char buf[3]; - // Send zero length packet from s_ to t_. - ASSERT_THAT(write(s_, buf, 0), SyscallSucceedsWithValue(0)); - // Receive the packet. - char received[3]; - EXPECT_THAT(read(t_, received, sizeof(received)), - SyscallSucceedsWithValue(0)); -} - -TEST_P(UdpSocketTest, ZerolengthWriteAllowedNonBlockRead) { - // Bind s_ to loopback:TestPort, and connect to loopback:TestPort+1. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds()); - - // Bind t_ to loopback:TestPort+1. - ASSERT_THAT(bind(t_, addr_[1], addrlen_), SyscallSucceeds()); - - // Set t_ to non-blocking. - int opts = 0; - ASSERT_THAT(opts = fcntl(t_, F_GETFL), SyscallSucceeds()); - ASSERT_THAT(fcntl(t_, F_SETFL, opts | O_NONBLOCK), SyscallSucceeds()); - - char buf[3]; - // Send zero length packet from s_ to t_. - ASSERT_THAT(write(s_, buf, 0), SyscallSucceedsWithValue(0)); - // Receive the packet. - char received[3]; - EXPECT_THAT(read(t_, received, sizeof(received)), - SyscallSucceedsWithValue(0)); - EXPECT_THAT(read(t_, received, sizeof(received)), - SyscallFailsWithErrno(EAGAIN)); -} - -TEST_P(UdpSocketTest, SendAndReceiveNotConnected) { - // Bind s_ to loopback. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Send some data to s_. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - - ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Receive the data. - char received[sizeof(buf)]; - EXPECT_THAT(recv(s_, received, sizeof(received), 0), - SyscallSucceedsWithValue(sizeof(received))); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); -} - -TEST_P(UdpSocketTest, SendAndReceiveConnected) { - // Bind s_ to loopback:TestPort, and connect to loopback:TestPort+1. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds()); - - // Bind t_ to loopback:TestPort+1. - ASSERT_THAT(bind(t_, addr_[1], addrlen_), SyscallSucceeds()); - - // Send some data from t_ to s_. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - - ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Receive the data. - char received[sizeof(buf)]; - EXPECT_THAT(recv(s_, received, sizeof(received), 0), - SyscallSucceedsWithValue(sizeof(received))); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); -} - -TEST_P(UdpSocketTest, ReceiveFromNotConnected) { - // Bind s_ to loopback:TestPort, and connect to loopback:TestPort+1. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds()); - - // Bind t_ to loopback:TestPort+2. - ASSERT_THAT(bind(t_, addr_[2], addrlen_), SyscallSucceeds()); - - // Send some data from t_ to s_. - char buf[512]; - ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Check that the data isn't_ received because it was sent from a different - // address than we're connected. - EXPECT_THAT(recv(s_, buf, sizeof(buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); -} - -TEST_P(UdpSocketTest, ReceiveBeforeConnect) { - // Bind s_ to loopback:TestPort. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Bind t_ to loopback:TestPort+2. - ASSERT_THAT(bind(t_, addr_[2], addrlen_), SyscallSucceeds()); - - // Send some data from t_ to s_. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - - ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Connect to loopback:TestPort+1. - ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds()); - - // Receive the data. It works because it was sent before the connect. - char received[sizeof(buf)]; - EXPECT_THAT(recv(s_, received, sizeof(received), 0), - SyscallSucceedsWithValue(sizeof(received))); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); - - // Send again. This time it should not be received. - ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - EXPECT_THAT(recv(s_, buf, sizeof(buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); -} - -TEST_P(UdpSocketTest, ReceiveFrom) { - // Bind s_ to loopback:TestPort, and connect to loopback:TestPort+1. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds()); - - // Bind t_ to loopback:TestPort+1. - ASSERT_THAT(bind(t_, addr_[1], addrlen_), SyscallSucceeds()); - - // Send some data from t_ to s_. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - - ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Receive the data and sender address. - char received[sizeof(buf)]; - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(recvfrom(s_, received, sizeof(received), 0, - reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceedsWithValue(sizeof(received))); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, addr_[1], addrlen_), 0); -} - -TEST_P(UdpSocketTest, Listen) { - ASSERT_THAT(listen(s_, SOMAXCONN), SyscallFailsWithErrno(EOPNOTSUPP)); -} - -TEST_P(UdpSocketTest, Accept) { - ASSERT_THAT(accept(s_, nullptr, nullptr), SyscallFailsWithErrno(EOPNOTSUPP)); -} - -// This test validates that a read shutdown with pending data allows the read -// to proceed with the data before returning EAGAIN. -TEST_P(UdpSocketTest, ReadShutdownNonblockPendingData) { - char received[512]; - - // Bind t_ to loopback:TestPort+2. - ASSERT_THAT(bind(t_, addr_[2], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(t_, addr_[1], addrlen_), SyscallSucceeds()); - - // Connect the socket, then try to shutdown again. - ASSERT_THAT(bind(s_, addr_[1], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(s_, addr_[2], addrlen_), SyscallSucceeds()); - - // Verify that we get EWOULDBLOCK when there is nothing to read. - EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - const char* buf = "abc"; - EXPECT_THAT(write(t_, buf, 3), SyscallSucceedsWithValue(3)); - - int opts = 0; - ASSERT_THAT(opts = fcntl(s_, F_GETFL), SyscallSucceeds()); - ASSERT_THAT(fcntl(s_, F_SETFL, opts | O_NONBLOCK), SyscallSucceeds()); - ASSERT_THAT(opts = fcntl(s_, F_GETFL), SyscallSucceeds()); - ASSERT_NE(opts & O_NONBLOCK, 0); - - EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds()); - - // We should get the data even though read has been shutdown. - EXPECT_THAT(recv(s_, received, 2, 0), SyscallSucceedsWithValue(2)); - - // Because we read less than the entire packet length, since it's a packet - // based socket any subsequent reads should return EWOULDBLOCK. - EXPECT_THAT(recv(s_, received, 1, 0), SyscallFailsWithErrno(EWOULDBLOCK)); -} - -// This test is validating that even after a socket is shutdown if it's -// reconnected it will reset the shutdown state. -TEST_P(UdpSocketTest, ReadShutdownSameSocketResetsShutdownState) { - char received[512]; - EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallFailsWithErrno(ENOTCONN)); - - EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - // Connect the socket, then try to shutdown again. - ASSERT_THAT(bind(s_, addr_[1], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(s_, addr_[2], addrlen_), SyscallSucceeds()); - - EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); -} - -TEST_P(UdpSocketTest, ReadShutdown) { - char received[512]; - EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallFailsWithErrno(ENOTCONN)); - - EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - // Connect the socket, then try to shutdown again. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - - EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds()); - - EXPECT_THAT(recv(s_, received, sizeof(received), 0), - SyscallSucceedsWithValue(0)); -} - -TEST_P(UdpSocketTest, ReadShutdownDifferentThread) { - char received[512]; - EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - // Connect the socket, then shutdown from another thread. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - - EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - ScopedThread t([&] { - absl::SleepFor(absl::Milliseconds(200)); - EXPECT_THAT(shutdown(this->s_, SHUT_RD), SyscallSucceeds()); - }); - EXPECT_THAT(RetryEINTR(recv)(s_, received, sizeof(received), 0), - SyscallSucceedsWithValue(0)); - t.Join(); - - EXPECT_THAT(RetryEINTR(recv)(s_, received, sizeof(received), 0), - SyscallSucceedsWithValue(0)); -} - -TEST_P(UdpSocketTest, WriteShutdown) { - EXPECT_THAT(shutdown(s_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - EXPECT_THAT(shutdown(s_, SHUT_WR), SyscallSucceeds()); -} - -TEST_P(UdpSocketTest, SynchronousReceive) { - // Bind s_ to loopback. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Send some data to s_ from another thread. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - - // Receive the data prior to actually starting the other thread. - char received[512]; - EXPECT_THAT(RetryEINTR(recv)(s_, received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - // Start the thread. - ScopedThread t([&] { - absl::SleepFor(absl::Milliseconds(200)); - ASSERT_THAT( - sendto(this->t_, buf, sizeof(buf), 0, this->addr_[0], this->addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - }); - - EXPECT_THAT(RetryEINTR(recv)(s_, received, sizeof(received), 0), - SyscallSucceedsWithValue(512)); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); -} - -TEST_P(UdpSocketTest, BoundaryPreserved_SendRecv) { - // Bind s_ to loopback:TestPort. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Send 3 packets from t_ to s_. - constexpr int psize = 100; - char buf[3 * psize]; - RandomizeBuffer(buf, sizeof(buf)); - - for (int i = 0; i < 3; ++i) { - ASSERT_THAT(sendto(t_, buf + i * psize, psize, 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(psize)); - } - - // Receive the data as 3 separate packets. - char received[6 * psize]; - for (int i = 0; i < 3; ++i) { - EXPECT_THAT(recv(s_, received + i * psize, 3 * psize, 0), - SyscallSucceedsWithValue(psize)); - } - EXPECT_EQ(memcmp(buf, received, 3 * psize), 0); -} - -TEST_P(UdpSocketTest, BoundaryPreserved_WritevReadv) { - // Bind s_ to loopback:TestPort. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Direct writes from t_ to s_. - ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); - - // Send 2 packets from t_ to s_, where each packet's data consists of 2 - // discontiguous iovecs. - constexpr size_t kPieceSize = 100; - char buf[4 * kPieceSize]; - RandomizeBuffer(buf, sizeof(buf)); - - for (int i = 0; i < 2; i++) { - struct iovec iov[2]; - for (int j = 0; j < 2; j++) { - iov[j].iov_base = reinterpret_cast<void*>( - reinterpret_cast<uintptr_t>(buf) + (i + 2 * j) * kPieceSize); - iov[j].iov_len = kPieceSize; - } - ASSERT_THAT(writev(t_, iov, 2), SyscallSucceedsWithValue(2 * kPieceSize)); - } - - // Receive the data as 2 separate packets. - char received[6 * kPieceSize]; - for (int i = 0; i < 2; i++) { - struct iovec iov[3]; - for (int j = 0; j < 3; j++) { - iov[j].iov_base = reinterpret_cast<void*>( - reinterpret_cast<uintptr_t>(received) + (i + 2 * j) * kPieceSize); - iov[j].iov_len = kPieceSize; - } - ASSERT_THAT(readv(s_, iov, 3), SyscallSucceedsWithValue(2 * kPieceSize)); - } - EXPECT_EQ(memcmp(buf, received, 4 * kPieceSize), 0); -} - -TEST_P(UdpSocketTest, BoundaryPreserved_SendMsgRecvMsg) { - // Bind s_ to loopback:TestPort. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Send 2 packets from t_ to s_, where each packet's data consists of 2 - // discontiguous iovecs. - constexpr size_t kPieceSize = 100; - char buf[4 * kPieceSize]; - RandomizeBuffer(buf, sizeof(buf)); - - for (int i = 0; i < 2; i++) { - struct iovec iov[2]; - for (int j = 0; j < 2; j++) { - iov[j].iov_base = reinterpret_cast<void*>( - reinterpret_cast<uintptr_t>(buf) + (i + 2 * j) * kPieceSize); - iov[j].iov_len = kPieceSize; - } - struct msghdr msg = {}; - msg.msg_name = addr_[0]; - msg.msg_namelen = addrlen_; - msg.msg_iov = iov; - msg.msg_iovlen = 2; - ASSERT_THAT(sendmsg(t_, &msg, 0), SyscallSucceedsWithValue(2 * kPieceSize)); - } - - // Receive the data as 2 separate packets. - char received[6 * kPieceSize]; - for (int i = 0; i < 2; i++) { - struct iovec iov[3]; - for (int j = 0; j < 3; j++) { - iov[j].iov_base = reinterpret_cast<void*>( - reinterpret_cast<uintptr_t>(received) + (i + 2 * j) * kPieceSize); - iov[j].iov_len = kPieceSize; - } - struct msghdr msg = {}; - msg.msg_iov = iov; - msg.msg_iovlen = 3; - ASSERT_THAT(recvmsg(s_, &msg, 0), SyscallSucceedsWithValue(2 * kPieceSize)); - } - EXPECT_EQ(memcmp(buf, received, 4 * kPieceSize), 0); -} - -TEST_P(UdpSocketTest, FIONREADShutdown) { - int n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - // A UDP socket must be connected before it can be shutdown. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - - n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds()); - - n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); -} - -TEST_P(UdpSocketTest, FIONREADWriteShutdown) { - int n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - // Bind s_ to loopback:TestPort. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // A UDP socket must be connected before it can be shutdown. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - - n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - const char str[] = "abc"; - ASSERT_THAT(send(s_, str, sizeof(str), 0), - SyscallSucceedsWithValue(sizeof(str))); - - n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, sizeof(str)); - - EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds()); - - n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, sizeof(str)); -} - -TEST_P(UdpSocketTest, FIONREAD) { - // Bind s_ to loopback:TestPort. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Check that the bound socket with an empty buffer reports an empty first - // packet. - int n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - // Send 3 packets from t_ to s_. - constexpr int psize = 100; - char buf[3 * psize]; - RandomizeBuffer(buf, sizeof(buf)); - - for (int i = 0; i < 3; ++i) { - ASSERT_THAT(sendto(t_, buf + i * psize, psize, 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(psize)); - - // Check that regardless of how many packets are in the queue, the size - // reported is that of a single packet. - n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, psize); - } -} - -TEST_P(UdpSocketTest, FIONREADZeroLengthPacket) { - // Bind s_ to loopback:TestPort. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // Check that the bound socket with an empty buffer reports an empty first - // packet. - int n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - // Send 3 packets from t_ to s_. - constexpr int psize = 100; - char buf[3 * psize]; - RandomizeBuffer(buf, sizeof(buf)); - - for (int i = 0; i < 3; ++i) { - ASSERT_THAT(sendto(t_, buf + i * psize, 0, 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(0)); - - // Check that regardless of how many packets are in the queue, the size - // reported is that of a single packet. - n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - } -} - -TEST_P(UdpSocketTest, FIONREADZeroLengthWriteShutdown) { - int n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - // Bind s_ to loopback:TestPort. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - - // A UDP socket must be connected before it can be shutdown. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - - n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - const char str[] = "abc"; - ASSERT_THAT(send(s_, str, 0, 0), SyscallSucceedsWithValue(0)); - - n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds()); - - n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); -} - -TEST_P(UdpSocketTest, ErrorQueue) { - char cmsgbuf[CMSG_SPACE(sizeof(sock_extended_err))]; - msghdr msg; - memset(&msg, 0, sizeof(msg)); - iovec iov; - memset(&iov, 0, sizeof(iov)); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - msg.msg_control = cmsgbuf; - msg.msg_controllen = sizeof(cmsgbuf); - - // recv*(MSG_ERRQUEUE) never blocks, even without MSG_DONTWAIT. - EXPECT_THAT(RetryEINTR(recvmsg)(s_, &msg, MSG_ERRQUEUE), - SyscallFailsWithErrno(EAGAIN)); -} - -TEST_P(UdpSocketTest, SoTimestampOffByDefault) { - int v = -1; - socklen_t optlen = sizeof(v); - ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_TIMESTAMP, &v, &optlen), - SyscallSucceeds()); - ASSERT_EQ(v, kSockOptOff); - ASSERT_EQ(optlen, sizeof(v)); -} - -TEST_P(UdpSocketTest, SoTimestamp) { - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); - - int v = 1; - ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_TIMESTAMP, &v, sizeof(v)), - SyscallSucceeds()); - - char buf[3]; - // Send zero length packet from t_ to s_. - ASSERT_THAT(RetryEINTR(write)(t_, buf, 0), SyscallSucceedsWithValue(0)); - - char cmsgbuf[CMSG_SPACE(sizeof(struct timeval))]; - msghdr msg; - memset(&msg, 0, sizeof(msg)); - iovec iov; - memset(&iov, 0, sizeof(iov)); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - msg.msg_control = cmsgbuf; - msg.msg_controllen = sizeof(cmsgbuf); - - ASSERT_THAT(RetryEINTR(recvmsg)(s_, &msg, 0), SyscallSucceedsWithValue(0)); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET); - ASSERT_EQ(cmsg->cmsg_type, SO_TIMESTAMP); - ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(struct timeval))); - - struct timeval tv = {}; - memcpy(&tv, CMSG_DATA(cmsg), sizeof(struct timeval)); - - ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); - - // There should be nothing to get via ioctl. - ASSERT_THAT(ioctl(s_, SIOCGSTAMP, &tv), SyscallFailsWithErrno(ENOENT)); -} - -TEST_P(UdpSocketTest, WriteShutdownNotConnected) { - EXPECT_THAT(shutdown(s_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); -} - -TEST_P(UdpSocketTest, TimestampIoctl) { - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); - - char buf[3]; - // Send packet from t_ to s_. - ASSERT_THAT(RetryEINTR(write)(t_, buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - // There should be no control messages. - char recv_buf[sizeof(buf)]; - ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(s_, recv_buf, sizeof(recv_buf))); - - // A nonzero timeval should be available via ioctl. - struct timeval tv = {}; - ASSERT_THAT(ioctl(s_, SIOCGSTAMP, &tv), SyscallSucceeds()); - ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); -} - -TEST_P(UdpSocketTest, TimetstampIoctlNothingRead) { - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); - - struct timeval tv = {}; - ASSERT_THAT(ioctl(s_, SIOCGSTAMP, &tv), SyscallFailsWithErrno(ENOENT)); -} - -// Test that the timestamp accessed via SIOCGSTAMP is still accessible after -// SO_TIMESTAMP is enabled and used to retrieve a timestamp. -TEST_P(UdpSocketTest, TimestampIoctlPersistence) { - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); - - char buf[3]; - // Send packet from t_ to s_. - ASSERT_THAT(RetryEINTR(write)(t_, buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - ASSERT_THAT(RetryEINTR(write)(t_, buf, 0), SyscallSucceedsWithValue(0)); - - // There should be no control messages. - char recv_buf[sizeof(buf)]; - ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(s_, recv_buf, sizeof(recv_buf))); - - // A nonzero timeval should be available via ioctl. - struct timeval tv = {}; - ASSERT_THAT(ioctl(s_, SIOCGSTAMP, &tv), SyscallSucceeds()); - ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); - - // Enable SO_TIMESTAMP and send a message. - int v = 1; - EXPECT_THAT(setsockopt(s_, SOL_SOCKET, SO_TIMESTAMP, &v, sizeof(v)), - SyscallSucceeds()); - ASSERT_THAT(RetryEINTR(write)(t_, buf, 0), SyscallSucceedsWithValue(0)); - - // There should be a message for SO_TIMESTAMP. - char cmsgbuf[CMSG_SPACE(sizeof(struct timeval))]; - msghdr msg = {}; - iovec iov = {}; - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - msg.msg_control = cmsgbuf; - msg.msg_controllen = sizeof(cmsgbuf); - ASSERT_THAT(RetryEINTR(recvmsg)(s_, &msg, 0), SyscallSucceedsWithValue(0)); - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - - // The ioctl should return the exact same values as before. - struct timeval tv2 = {}; - ASSERT_THAT(ioctl(s_, SIOCGSTAMP, &tv2), SyscallSucceeds()); - ASSERT_EQ(tv.tv_sec, tv2.tv_sec); - ASSERT_EQ(tv.tv_usec, tv2.tv_usec); -} - INSTANTIATE_TEST_SUITE_P(AllInetTests, UdpSocketTest, ::testing::Values(AddressFamily::kIpv4, AddressFamily::kIpv6, diff --git a/test/syscalls/linux/udp_socket_errqueue_test_case.cc b/test/syscalls/linux/udp_socket_errqueue_test_case.cc new file mode 100644 index 000000000..54a0594f7 --- /dev/null +++ b/test/syscalls/linux/udp_socket_errqueue_test_case.cc @@ -0,0 +1,57 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef __fuchsia__ + +#include <arpa/inet.h> +#include <fcntl.h> +#include <linux/errqueue.h> +#include <netinet/in.h> +#include <sys/ioctl.h> +#include <sys/socket.h> +#include <sys/types.h> + +#include "gtest/gtest.h" +#include "absl/base/macros.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/syscalls/linux/udp_socket_test_cases.h" +#include "test/syscalls/linux/unix_domain_socket_test_util.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +TEST_P(UdpSocketTest, ErrorQueue) { + char cmsgbuf[CMSG_SPACE(sizeof(sock_extended_err))]; + msghdr msg; + memset(&msg, 0, sizeof(msg)); + iovec iov; + memset(&iov, 0, sizeof(iov)); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + msg.msg_control = cmsgbuf; + msg.msg_controllen = sizeof(cmsgbuf); + + // recv*(MSG_ERRQUEUE) never blocks, even without MSG_DONTWAIT. + EXPECT_THAT(RetryEINTR(recvmsg)(bind_.get(), &msg, MSG_ERRQUEUE), + SyscallFailsWithErrno(EAGAIN)); +} + +} // namespace testing +} // namespace gvisor + +#endif // __fuchsia__ diff --git a/test/syscalls/linux/udp_socket_test_cases.cc b/test/syscalls/linux/udp_socket_test_cases.cc new file mode 100644 index 000000000..60c48ed6e --- /dev/null +++ b/test/syscalls/linux/udp_socket_test_cases.cc @@ -0,0 +1,1781 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/syscalls/linux/udp_socket_test_cases.h" + +#include <arpa/inet.h> +#include <fcntl.h> +#ifndef __fuchsia__ +#include <linux/filter.h> +#endif // __fuchsia__ +#include <netinet/in.h> +#include <poll.h> +#include <sys/ioctl.h> +#include <sys/socket.h> +#include <sys/types.h> + +#include "absl/strings/str_format.h" +#ifndef SIOCGSTAMP +#include <linux/sockios.h> +#endif + +#include "gtest/gtest.h" +#include "absl/base/macros.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "test/syscalls/linux/ip_socket_test_util.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/syscalls/linux/unix_domain_socket_test_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/posix_error.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +// Gets a pointer to the port component of the given address. +uint16_t* Port(struct sockaddr_storage* addr) { + switch (addr->ss_family) { + case AF_INET: { + auto sin = reinterpret_cast<struct sockaddr_in*>(addr); + return &sin->sin_port; + } + case AF_INET6: { + auto sin6 = reinterpret_cast<struct sockaddr_in6*>(addr); + return &sin6->sin6_port; + } + } + + return nullptr; +} + +// Sets addr port to "port". +void SetPort(struct sockaddr_storage* addr, uint16_t port) { + switch (addr->ss_family) { + case AF_INET: { + auto sin = reinterpret_cast<struct sockaddr_in*>(addr); + sin->sin_port = port; + break; + } + case AF_INET6: { + auto sin6 = reinterpret_cast<struct sockaddr_in6*>(addr); + sin6->sin6_port = port; + break; + } + } +} + +void UdpSocketTest::SetUp() { + addrlen_ = GetAddrLength(); + + bind_ = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP)); + memset(&bind_addr_storage_, 0, sizeof(bind_addr_storage_)); + bind_addr_ = reinterpret_cast<struct sockaddr*>(&bind_addr_storage_); + + sock_ = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP)); +} + +int UdpSocketTest::GetFamily() { + if (GetParam() == AddressFamily::kIpv4) { + return AF_INET; + } + return AF_INET6; +} + +PosixError UdpSocketTest::BindLoopback() { + bind_addr_storage_ = InetLoopbackAddr(); + struct sockaddr* bind_addr_ = + reinterpret_cast<struct sockaddr*>(&bind_addr_storage_); + return BindSocket(bind_.get(), bind_addr_); +} + +PosixError UdpSocketTest::BindAny() { + bind_addr_storage_ = InetAnyAddr(); + struct sockaddr* bind_addr_ = + reinterpret_cast<struct sockaddr*>(&bind_addr_storage_); + return BindSocket(bind_.get(), bind_addr_); +} + +PosixError UdpSocketTest::BindSocket(int socket, struct sockaddr* addr) { + socklen_t len = sizeof(bind_addr_storage_); + + // Bind, then check that we get the right address. + RETURN_ERROR_IF_SYSCALL_FAIL(bind(socket, addr, addrlen_)); + + RETURN_ERROR_IF_SYSCALL_FAIL(getsockname(socket, addr, &len)); + + if (addrlen_ != len) { + return PosixError( + EINVAL, + absl::StrFormat("getsockname len: %u expected: %u", len, addrlen_)); + } + return PosixError(0); +} + +socklen_t UdpSocketTest::GetAddrLength() { + struct sockaddr_storage addr; + if (GetFamily() == AF_INET) { + auto sin = reinterpret_cast<struct sockaddr_in*>(&addr); + return sizeof(*sin); + } + + auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr); + return sizeof(*sin6); +} + +sockaddr_storage UdpSocketTest::InetAnyAddr() { + struct sockaddr_storage addr; + memset(&addr, 0, sizeof(addr)); + reinterpret_cast<struct sockaddr*>(&addr)->sa_family = GetFamily(); + + if (GetFamily() == AF_INET) { + auto sin = reinterpret_cast<struct sockaddr_in*>(&addr); + sin->sin_addr.s_addr = htonl(INADDR_ANY); + sin->sin_port = htons(0); + return addr; + } + + auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr); + sin6->sin6_addr = IN6ADDR_ANY_INIT; + sin6->sin6_port = htons(0); + return addr; +} + +sockaddr_storage UdpSocketTest::InetLoopbackAddr() { + struct sockaddr_storage addr; + memset(&addr, 0, sizeof(addr)); + reinterpret_cast<struct sockaddr*>(&addr)->sa_family = GetFamily(); + + if (GetFamily() == AF_INET) { + auto sin = reinterpret_cast<struct sockaddr_in*>(&addr); + sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK); + sin->sin_port = htons(0); + return addr; + } + auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr); + sin6->sin6_addr = in6addr_loopback; + sin6->sin6_port = htons(0); + return addr; +} + +void UdpSocketTest::Disconnect(int sockfd) { + sockaddr_storage addr_storage = InetAnyAddr(); + sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + socklen_t addrlen = sizeof(addr_storage); + + addr->sa_family = AF_UNSPEC; + ASSERT_THAT(connect(sockfd, addr, addrlen), SyscallSucceeds()); + + // Check that after disconnect the socket is bound to the ANY address. + EXPECT_THAT(getsockname(sockfd, addr, &addrlen), SyscallSucceeds()); + if (GetParam() == AddressFamily::kIpv4) { + auto addr_out = reinterpret_cast<struct sockaddr_in*>(addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_ANY)); + } else { + auto addr_out = reinterpret_cast<struct sockaddr_in6*>(addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + struct in6_addr loopback = IN6ADDR_ANY_INIT; + + EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0); + } +} + +TEST_P(UdpSocketTest, Creation) { + FileDescriptor sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP)); + EXPECT_THAT(close(sock.release()), SyscallSucceeds()); + + sock = ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, 0)); + EXPECT_THAT(close(sock.release()), SyscallSucceeds()); + + ASSERT_THAT(socket(GetFamily(), SOCK_STREAM, IPPROTO_UDP), SyscallFails()); +} + +TEST_P(UdpSocketTest, Getsockname) { + // Check that we're not bound. + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getsockname(bind_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + struct sockaddr_storage any = InetAnyAddr(); + EXPECT_EQ(memcmp(&addr, reinterpret_cast<struct sockaddr*>(&any), addrlen_), + 0); + + ASSERT_NO_ERRNO(BindLoopback()); + + EXPECT_THAT( + getsockname(bind_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(memcmp(&addr, bind_addr_, addrlen_), 0); +} + +TEST_P(UdpSocketTest, Getpeername) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Check that we're not connected. + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); + + // Connect, then check that we get the right address. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + addrlen = sizeof(addr); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(memcmp(&addr, bind_addr_, addrlen_), 0); +} + +TEST_P(UdpSocketTest, SendNotConnected) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Do send & write, they must fail. + char buf[512]; + EXPECT_THAT(send(sock_.get(), buf, sizeof(buf), 0), + SyscallFailsWithErrno(EDESTADDRREQ)); + + EXPECT_THAT(write(sock_.get(), buf, sizeof(buf)), + SyscallFailsWithErrno(EDESTADDRREQ)); + + // Use sendto. + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + // Check that we're bound now. + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + EXPECT_NE(*Port(&addr), 0); +} + +TEST_P(UdpSocketTest, ConnectBinds) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Connect the socket. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + // Check that we're bound now. + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + EXPECT_NE(*Port(&addr), 0); +} + +TEST_P(UdpSocketTest, ReceiveNotBound) { + char buf[512]; + EXPECT_THAT(recv(sock_.get(), buf, sizeof(buf), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); +} + +TEST_P(UdpSocketTest, Bind) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Try to bind again. + EXPECT_THAT(bind(bind_.get(), bind_addr_, addrlen_), + SyscallFailsWithErrno(EINVAL)); + + // Check that we're still bound to the original address. + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getsockname(bind_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(memcmp(&addr, bind_addr_, addrlen_), 0); +} + +TEST_P(UdpSocketTest, BindInUse) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Try to bind again. + EXPECT_THAT(bind(sock_.get(), bind_addr_, addrlen_), + SyscallFailsWithErrno(EADDRINUSE)); +} + +TEST_P(UdpSocketTest, ReceiveAfterConnect) { + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + // Send from sock_ to bind_ + char buf[512]; + RandomizeBuffer(buf, sizeof(buf)); + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + // Receive the data. + char received[sizeof(buf)]; + EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0), + SyscallSucceedsWithValue(sizeof(received))); + EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); +} + +TEST_P(UdpSocketTest, ReceiveAfterDisconnect) { + ASSERT_NO_ERRNO(BindLoopback()); + + for (int i = 0; i < 2; i++) { + // Connet sock_ to bound address. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + + // Send from sock to bind_. + char buf[512]; + RandomizeBuffer(buf, sizeof(buf)); + + ASSERT_THAT(sendto(bind_.get(), buf, sizeof(buf), 0, + reinterpret_cast<sockaddr*>(&addr), addrlen), + SyscallSucceedsWithValue(sizeof(buf))); + + // Receive the data. + char received[sizeof(buf)]; + EXPECT_THAT(recv(sock_.get(), received, sizeof(received), 0), + SyscallSucceedsWithValue(sizeof(received))); + EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); + + // Disconnect sock_. + struct sockaddr unspec = {}; + unspec.sa_family = AF_UNSPEC; + ASSERT_THAT(connect(sock_.get(), &unspec, sizeof(unspec.sa_family)), + SyscallSucceeds()); + } +} + +TEST_P(UdpSocketTest, Connect) { + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + // Check that we're connected to the right peer. + struct sockaddr_storage peer; + socklen_t peerlen = sizeof(peer); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen), + SyscallSucceeds()); + EXPECT_EQ(peerlen, addrlen_); + EXPECT_EQ(memcmp(&peer, bind_addr_, addrlen_), 0); + + // Try to bind after connect. + struct sockaddr_storage any = InetAnyAddr(); + EXPECT_THAT( + bind(sock_.get(), reinterpret_cast<struct sockaddr*>(&any), addrlen_), + SyscallFailsWithErrno(EINVAL)); + + struct sockaddr_storage bind2_storage = InetLoopbackAddr(); + struct sockaddr* bind2_addr = + reinterpret_cast<struct sockaddr*>(&bind2_storage); + FileDescriptor bind2 = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP)); + ASSERT_NO_ERRNO(BindSocket(bind2.get(), bind2_addr)); + + // Try to connect again. + EXPECT_THAT(connect(sock_.get(), bind2_addr, addrlen_), SyscallSucceeds()); + + // Check that peer name changed. + peerlen = sizeof(peer); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen), + SyscallSucceeds()); + EXPECT_EQ(peerlen, addrlen_); + EXPECT_EQ(memcmp(&peer, bind2_addr, addrlen_), 0); +} + +TEST_P(UdpSocketTest, ConnectAnyZero) { + // TODO(138658473): Enable when we can connect to port 0 with gVisor. + SKIP_IF(IsRunningOnGvisor()); + + struct sockaddr_storage any = InetAnyAddr(); + EXPECT_THAT( + connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&any), addrlen_), + SyscallSucceeds()); + + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); +} + +TEST_P(UdpSocketTest, ConnectAnyWithPort) { + ASSERT_NO_ERRNO(BindAny()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); +} + +TEST_P(UdpSocketTest, DisconnectAfterConnectAny) { + // TODO(138658473): Enable when we can connect to port 0 with gVisor. + SKIP_IF(IsRunningOnGvisor()); + struct sockaddr_storage any = InetAnyAddr(); + EXPECT_THAT( + connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&any), addrlen_), + SyscallSucceeds()); + + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); + + Disconnect(sock_.get()); +} + +TEST_P(UdpSocketTest, DisconnectAfterConnectAnyWithPort) { + ASSERT_NO_ERRNO(BindAny()); + EXPECT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(*Port(&bind_addr_storage_), *Port(&addr)); + + Disconnect(sock_.get()); +} + +TEST_P(UdpSocketTest, DisconnectAfterBind) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Bind to the next port above bind_. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_NO_ERRNO(BindSocket(sock_.get(), addr)); + + // Connect the socket. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + struct sockaddr_storage unspec = {}; + unspec.ss_family = AF_UNSPEC; + EXPECT_THAT(connect(sock_.get(), reinterpret_cast<sockaddr*>(&unspec), + sizeof(unspec.ss_family)), + SyscallSucceeds()); + + // Check that we're still bound. + socklen_t addrlen = sizeof(unspec); + EXPECT_THAT( + getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&unspec), &addrlen), + SyscallSucceeds()); + + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(memcmp(addr, &unspec, addrlen_), 0); + + addrlen = sizeof(addr); + EXPECT_THAT(getpeername(sock_.get(), addr, &addrlen), + SyscallFailsWithErrno(ENOTCONN)); +} + +TEST_P(UdpSocketTest, BindToAnyConnnectToLocalhost) { + ASSERT_NO_ERRNO(BindAny()); + + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + socklen_t addrlen = sizeof(addr); + + // Connect the socket. + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); + + EXPECT_THAT(getsockname(bind_.get(), addr, &addrlen), SyscallSucceeds()); + + // If the socket is bound to ANY and connected to a loopback address, + // getsockname() has to return the loopback address. + if (GetParam() == AddressFamily::kIpv4) { + auto addr_out = reinterpret_cast<struct sockaddr_in*>(addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_LOOPBACK)); + } else { + auto addr_out = reinterpret_cast<struct sockaddr_in6*>(addr); + struct in6_addr loopback = IN6ADDR_LOOPBACK_INIT; + EXPECT_EQ(addrlen, sizeof(*addr_out)); + EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0); + } +} + +TEST_P(UdpSocketTest, DisconnectAfterBindToAny) { + ASSERT_NO_ERRNO(BindLoopback()); + + struct sockaddr_storage any_storage = InetAnyAddr(); + struct sockaddr* any = reinterpret_cast<struct sockaddr*>(&any_storage); + SetPort(&any_storage, *Port(&bind_addr_storage_) + 1); + + ASSERT_NO_ERRNO(BindSocket(sock_.get(), any)); + + // Connect the socket. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + Disconnect(sock_.get()); + + // Check that we're still bound. + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(memcmp(&addr, any, addrlen), 0); + + addrlen = sizeof(addr); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); +} + +TEST_P(UdpSocketTest, Disconnect) { + ASSERT_NO_ERRNO(BindLoopback()); + + struct sockaddr_storage any_storage = InetAnyAddr(); + struct sockaddr* any = reinterpret_cast<struct sockaddr*>(&any_storage); + SetPort(&any_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_NO_ERRNO(BindSocket(sock_.get(), any)); + + for (int i = 0; i < 2; i++) { + // Try to connect again. + EXPECT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + // Check that we're connected to the right peer. + struct sockaddr_storage peer; + socklen_t peerlen = sizeof(peer); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen), + SyscallSucceeds()); + EXPECT_EQ(peerlen, addrlen_); + EXPECT_EQ(memcmp(&peer, bind_addr_, addrlen_), 0); + + // Try to disconnect. + struct sockaddr_storage addr = {}; + addr.ss_family = AF_UNSPEC; + EXPECT_THAT(connect(sock_.get(), reinterpret_cast<sockaddr*>(&addr), + sizeof(addr.ss_family)), + SyscallSucceeds()); + + peerlen = sizeof(peer); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen), + SyscallFailsWithErrno(ENOTCONN)); + + // Check that we're still bound. + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(*Port(&addr), *Port(&any_storage)); + } +} + +TEST_P(UdpSocketTest, ConnectBadAddress) { + struct sockaddr addr = {}; + addr.sa_family = GetFamily(); + ASSERT_THAT(connect(sock_.get(), &addr, sizeof(addr.sa_family)), + SyscallFailsWithErrno(EINVAL)); +} + +TEST_P(UdpSocketTest, SendToAddressOtherThanConnected) { + ASSERT_NO_ERRNO(BindLoopback()); + + struct sockaddr_storage addr_storage = InetAnyAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + // Send to a different destination than we're connected to. + char buf[512]; + EXPECT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, addr, addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); +} + +TEST_P(UdpSocketTest, ZerolengthWriteAllowed) { + // TODO(gvisor.dev/issue/1202): Hostinet does not support zero length writes. + SKIP_IF(IsRunningWithHostinet()); + + ASSERT_NO_ERRNO(BindLoopback()); + // Connect to loopback:bind_addr_+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); + + // Bind sock to loopback:bind_addr_+1. + ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); + + char buf[3]; + // Send zero length packet from bind_ to sock_. + ASSERT_THAT(write(bind_.get(), buf, 0), SyscallSucceedsWithValue(0)); + + struct pollfd pfd = {sock_.get(), POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout*/ 1000), + SyscallSucceedsWithValue(1)); + + // Receive the packet. + char received[3]; + EXPECT_THAT(read(sock_.get(), received, sizeof(received)), + SyscallSucceedsWithValue(0)); +} + +TEST_P(UdpSocketTest, ZerolengthWriteAllowedNonBlockRead) { + // TODO(gvisor.dev/issue/1202): Hostinet does not support zero length writes. + SKIP_IF(IsRunningWithHostinet()); + + ASSERT_NO_ERRNO(BindLoopback()); + + // Connect to loopback:bind_addr_port+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); + + // Bind sock to loopback:bind_addr_port+1. + ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); + + // Set sock to non-blocking. + int opts = 0; + ASSERT_THAT(opts = fcntl(sock_.get(), F_GETFL), SyscallSucceeds()); + ASSERT_THAT(fcntl(sock_.get(), F_SETFL, opts | O_NONBLOCK), + SyscallSucceeds()); + + char buf[3]; + // Send zero length packet from bind_ to sock_. + ASSERT_THAT(write(bind_.get(), buf, 0), SyscallSucceedsWithValue(0)); + + struct pollfd pfd = {sock_.get(), POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + + // Receive the packet. + char received[3]; + EXPECT_THAT(read(sock_.get(), received, sizeof(received)), + SyscallSucceedsWithValue(0)); + EXPECT_THAT(read(sock_.get(), received, sizeof(received)), + SyscallFailsWithErrno(EAGAIN)); +} + +TEST_P(UdpSocketTest, SendAndReceiveNotConnected) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Send some data to bind_. + char buf[512]; + RandomizeBuffer(buf, sizeof(buf)); + + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + // Receive the data. + char received[sizeof(buf)]; + EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0), + SyscallSucceedsWithValue(sizeof(received))); + EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); +} + +TEST_P(UdpSocketTest, SendAndReceiveConnected) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Connect to loopback:bind_addr_port+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); + + // Bind sock to loopback:TestPort+1. + ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); + + // Send some data from sock to bind_. + char buf[512]; + RandomizeBuffer(buf, sizeof(buf)); + + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + // Receive the data. + char received[sizeof(buf)]; + EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0), + SyscallSucceedsWithValue(sizeof(received))); + EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); +} + +TEST_P(UdpSocketTest, ReceiveFromNotConnected) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Connect to loopback:bind_addr_port+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); + + // Bind sock to loopback:bind_addr_port+2. + struct sockaddr_storage addr2_storage = InetLoopbackAddr(); + struct sockaddr* addr2 = reinterpret_cast<struct sockaddr*>(&addr2_storage); + SetPort(&addr2_storage, *Port(&bind_addr_storage_) + 2); + ASSERT_THAT(bind(sock_.get(), addr2, addrlen_), SyscallSucceeds()); + + // Send some data from sock to bind_. + char buf[512]; + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + // Check that the data isn't received because it was sent from a different + // address than we're connected. + EXPECT_THAT(recv(sock_.get(), buf, sizeof(buf), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); +} + +TEST_P(UdpSocketTest, ReceiveBeforeConnect) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Bind sock to loopback:bind_addr_port+2. + struct sockaddr_storage addr2_storage = InetLoopbackAddr(); + struct sockaddr* addr2 = reinterpret_cast<struct sockaddr*>(&addr2_storage); + SetPort(&addr2_storage, *Port(&bind_addr_storage_) + 2); + ASSERT_THAT(bind(sock_.get(), addr2, addrlen_), SyscallSucceeds()); + + // Send some data from sock to bind_. + char buf[512]; + RandomizeBuffer(buf, sizeof(buf)); + + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + // Connect to loopback:TestPort+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); + + // Receive the data. It works because it was sent before the connect. + char received[sizeof(buf)]; + EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0), + SyscallSucceedsWithValue(sizeof(received))); + EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); + + // Send again. This time it should not be received. + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + EXPECT_THAT(recv(bind_.get(), buf, sizeof(buf), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); +} + +TEST_P(UdpSocketTest, ReceiveFrom) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Connect to loopback:bind_addr_port+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); + + // Bind sock to loopback:TestPort+1. + ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); + + // Send some data from sock to bind_. + char buf[512]; + RandomizeBuffer(buf, sizeof(buf)); + + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + // Receive the data and sender address. + char received[sizeof(buf)]; + struct sockaddr_storage addr2; + socklen_t addr2len = sizeof(addr2); + EXPECT_THAT(recvfrom(bind_.get(), received, sizeof(received), 0, + reinterpret_cast<sockaddr*>(&addr2), &addr2len), + SyscallSucceedsWithValue(sizeof(received))); + EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); + EXPECT_EQ(addr2len, addrlen_); + EXPECT_EQ(memcmp(addr, &addr2, addrlen_), 0); +} + +TEST_P(UdpSocketTest, Listen) { + ASSERT_THAT(listen(sock_.get(), SOMAXCONN), + SyscallFailsWithErrno(EOPNOTSUPP)); +} + +TEST_P(UdpSocketTest, Accept) { + ASSERT_THAT(accept(sock_.get(), nullptr, nullptr), + SyscallFailsWithErrno(EOPNOTSUPP)); +} + +// This test validates that a read shutdown with pending data allows the read +// to proceed with the data before returning EAGAIN. +TEST_P(UdpSocketTest, ReadShutdownNonblockPendingData) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Connect to loopback:bind_addr_port+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); + + // Bind to loopback:bind_addr_port+1 and connect to bind_addr_. + ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + // Verify that we get EWOULDBLOCK when there is nothing to read. + char received[512]; + EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + const char* buf = "abc"; + EXPECT_THAT(write(sock_.get(), buf, 3), SyscallSucceedsWithValue(3)); + + int opts = 0; + ASSERT_THAT(opts = fcntl(bind_.get(), F_GETFL), SyscallSucceeds()); + ASSERT_THAT(fcntl(bind_.get(), F_SETFL, opts | O_NONBLOCK), + SyscallSucceeds()); + ASSERT_THAT(opts = fcntl(bind_.get(), F_GETFL), SyscallSucceeds()); + ASSERT_NE(opts & O_NONBLOCK, 0); + + EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallSucceeds()); + + struct pollfd pfd = {bind_.get(), POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + + // We should get the data even though read has been shutdown. + EXPECT_THAT(recv(bind_.get(), received, 2, 0), SyscallSucceedsWithValue(2)); + + // Because we read less than the entire packet length, since it's a packet + // based socket any subsequent reads should return EWOULDBLOCK. + EXPECT_THAT(recv(bind_.get(), received, 1, 0), + SyscallFailsWithErrno(EWOULDBLOCK)); +} + +// This test is validating that even after a socket is shutdown if it's +// reconnected it will reset the shutdown state. +TEST_P(UdpSocketTest, ReadShutdownSameSocketResetsShutdownState) { + char received[512]; + EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallFailsWithErrno(ENOTCONN)); + + EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + // Connect the socket, then try to shutdown again. + ASSERT_NO_ERRNO(BindLoopback()); + + // Connect to loopback:bind_addr_port+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); + + EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); +} + +TEST_P(UdpSocketTest, ReadShutdown) { + // TODO(gvisor.dev/issue/1202): Calling recv() after shutdown without + // MSG_DONTWAIT blocks indefinitely. + SKIP_IF(IsRunningWithHostinet()); + + ASSERT_NO_ERRNO(BindLoopback()); + + char received[512]; + EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallFailsWithErrno(ENOTCONN)); + + EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + // Connect the socket, then try to shutdown again. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallSucceeds()); + + EXPECT_THAT(recv(sock_.get(), received, sizeof(received), 0), + SyscallSucceedsWithValue(0)); +} + +TEST_P(UdpSocketTest, ReadShutdownDifferentThread) { + // TODO(gvisor.dev/issue/1202): Calling recv() after shutdown without + // MSG_DONTWAIT blocks indefinitely. + SKIP_IF(IsRunningWithHostinet()); + ASSERT_NO_ERRNO(BindLoopback()); + + char received[512]; + EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + // Connect the socket, then shutdown from another thread. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + ScopedThread t([&] { + absl::SleepFor(absl::Milliseconds(200)); + EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallSucceeds()); + }); + EXPECT_THAT(RetryEINTR(recv)(sock_.get(), received, sizeof(received), 0), + SyscallSucceedsWithValue(0)); + t.Join(); + + EXPECT_THAT(RetryEINTR(recv)(sock_.get(), received, sizeof(received), 0), + SyscallSucceedsWithValue(0)); +} + +TEST_P(UdpSocketTest, WriteShutdown) { + ASSERT_NO_ERRNO(BindLoopback()); + EXPECT_THAT(shutdown(sock_.get(), SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + EXPECT_THAT(shutdown(sock_.get(), SHUT_WR), SyscallSucceeds()); +} + +TEST_P(UdpSocketTest, SynchronousReceive) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Send some data to bind_ from another thread. + char buf[512]; + RandomizeBuffer(buf, sizeof(buf)); + + // Receive the data prior to actually starting the other thread. + char received[512]; + EXPECT_THAT( + RetryEINTR(recv)(bind_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + // Start the thread. + ScopedThread t([&] { + absl::SleepFor(absl::Milliseconds(200)); + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, this->bind_addr_, + this->addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + }); + + EXPECT_THAT(RetryEINTR(recv)(bind_.get(), received, sizeof(received), 0), + SyscallSucceedsWithValue(512)); + EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); +} + +TEST_P(UdpSocketTest, BoundaryPreserved_SendRecv) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Send 3 packets from sock to bind_. + constexpr int psize = 100; + char buf[3 * psize]; + RandomizeBuffer(buf, sizeof(buf)); + + for (int i = 0; i < 3; ++i) { + ASSERT_THAT( + sendto(sock_.get(), buf + i * psize, psize, 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(psize)); + } + + // Receive the data as 3 separate packets. + char received[6 * psize]; + for (int i = 0; i < 3; ++i) { + EXPECT_THAT(recv(bind_.get(), received + i * psize, 3 * psize, 0), + SyscallSucceedsWithValue(psize)); + } + EXPECT_EQ(memcmp(buf, received, 3 * psize), 0); +} + +TEST_P(UdpSocketTest, BoundaryPreserved_WritevReadv) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Direct writes from sock to bind_. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + // Send 2 packets from sock to bind_, where each packet's data consists of + // 2 discontiguous iovecs. + constexpr size_t kPieceSize = 100; + char buf[4 * kPieceSize]; + RandomizeBuffer(buf, sizeof(buf)); + + for (int i = 0; i < 2; i++) { + struct iovec iov[2]; + for (int j = 0; j < 2; j++) { + iov[j].iov_base = reinterpret_cast<void*>( + reinterpret_cast<uintptr_t>(buf) + (i + 2 * j) * kPieceSize); + iov[j].iov_len = kPieceSize; + } + ASSERT_THAT(writev(sock_.get(), iov, 2), + SyscallSucceedsWithValue(2 * kPieceSize)); + } + + // Receive the data as 2 separate packets. + char received[6 * kPieceSize]; + for (int i = 0; i < 2; i++) { + struct iovec iov[3]; + for (int j = 0; j < 3; j++) { + iov[j].iov_base = reinterpret_cast<void*>( + reinterpret_cast<uintptr_t>(received) + (i + 2 * j) * kPieceSize); + iov[j].iov_len = kPieceSize; + } + ASSERT_THAT(readv(bind_.get(), iov, 3), + SyscallSucceedsWithValue(2 * kPieceSize)); + } + EXPECT_EQ(memcmp(buf, received, 4 * kPieceSize), 0); +} + +TEST_P(UdpSocketTest, BoundaryPreserved_SendMsgRecvMsg) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Send 2 packets from sock to bind_, where each packet's data consists of + // 2 discontiguous iovecs. + constexpr size_t kPieceSize = 100; + char buf[4 * kPieceSize]; + RandomizeBuffer(buf, sizeof(buf)); + + for (int i = 0; i < 2; i++) { + struct iovec iov[2]; + for (int j = 0; j < 2; j++) { + iov[j].iov_base = reinterpret_cast<void*>( + reinterpret_cast<uintptr_t>(buf) + (i + 2 * j) * kPieceSize); + iov[j].iov_len = kPieceSize; + } + struct msghdr msg = {}; + msg.msg_name = bind_addr_; + msg.msg_namelen = addrlen_; + msg.msg_iov = iov; + msg.msg_iovlen = 2; + ASSERT_THAT(sendmsg(sock_.get(), &msg, 0), + SyscallSucceedsWithValue(2 * kPieceSize)); + } + + // Receive the data as 2 separate packets. + char received[6 * kPieceSize]; + for (int i = 0; i < 2; i++) { + struct iovec iov[3]; + for (int j = 0; j < 3; j++) { + iov[j].iov_base = reinterpret_cast<void*>( + reinterpret_cast<uintptr_t>(received) + (i + 2 * j) * kPieceSize); + iov[j].iov_len = kPieceSize; + } + struct msghdr msg = {}; + msg.msg_iov = iov; + msg.msg_iovlen = 3; + ASSERT_THAT(recvmsg(bind_.get(), &msg, 0), + SyscallSucceedsWithValue(2 * kPieceSize)); + } + EXPECT_EQ(memcmp(buf, received, 4 * kPieceSize), 0); +} + +TEST_P(UdpSocketTest, FIONREADShutdown) { + ASSERT_NO_ERRNO(BindLoopback()); + + int n = -1; + EXPECT_THAT(ioctl(sock_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + // A UDP socket must be connected before it can be shutdown. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + n = -1; + EXPECT_THAT(ioctl(sock_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallSucceeds()); + + n = -1; + EXPECT_THAT(ioctl(sock_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); +} + +TEST_P(UdpSocketTest, FIONREADWriteShutdown) { + int n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + ASSERT_NO_ERRNO(BindLoopback()); + + // A UDP socket must be connected before it can be shutdown. + ASSERT_THAT(connect(bind_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + const char str[] = "abc"; + ASSERT_THAT(send(bind_.get(), str, sizeof(str), 0), + SyscallSucceedsWithValue(sizeof(str))); + + struct pollfd pfd = {bind_.get(), POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + + n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, sizeof(str)); + + EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallSucceeds()); + + n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, sizeof(str)); +} + +// NOTE: Do not use `FIONREAD` as test name because it will be replaced by the +// corresponding macro and become `0x541B`. +TEST_P(UdpSocketTest, Fionread) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Check that the bound socket with an empty buffer reports an empty first + // packet. + int n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + // Send 3 packets from sock to bind_. + constexpr int psize = 100; + char buf[3 * psize]; + RandomizeBuffer(buf, sizeof(buf)); + + struct pollfd pfd = {bind_.get(), POLLIN, 0}; + for (int i = 0; i < 3; ++i) { + ASSERT_THAT( + sendto(sock_.get(), buf + i * psize, psize, 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(psize)); + + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + + // Check that regardless of how many packets are in the queue, the size + // reported is that of a single packet. + n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, psize); + } +} + +TEST_P(UdpSocketTest, FIONREADZeroLengthPacket) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Check that the bound socket with an empty buffer reports an empty first + // packet. + int n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + // Send 3 packets from sock to bind_. + constexpr int psize = 100; + char buf[3 * psize]; + RandomizeBuffer(buf, sizeof(buf)); + + struct pollfd pfd = {bind_.get(), POLLIN, 0}; + for (int i = 0; i < 3; ++i) { + ASSERT_THAT( + sendto(sock_.get(), buf + i * psize, 0, 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(0)); + + // TODO(gvisor.dev/issue/2726): sending a zero-length message to a hostinet + // socket does not cause a poll event to be triggered. + if (!IsRunningWithHostinet()) { + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + } + + // Check that regardless of how many packets are in the queue, the size + // reported is that of a single packet. + n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + } +} + +TEST_P(UdpSocketTest, FIONREADZeroLengthWriteShutdown) { + int n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + ASSERT_NO_ERRNO(BindLoopback()); + + // A UDP socket must be connected before it can be shutdown. + ASSERT_THAT(connect(bind_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + const char str[] = "abc"; + ASSERT_THAT(send(bind_.get(), str, 0, 0), SyscallSucceedsWithValue(0)); + + n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallSucceeds()); + + n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); +} + +TEST_P(UdpSocketTest, SoNoCheckOffByDefault) { + // TODO(gvisor.dev/issue/1202): SO_NO_CHECK socket option not supported by + // hostinet. + SKIP_IF(IsRunningWithHostinet()); + + int v = -1; + socklen_t optlen = sizeof(v); + ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, &optlen), + SyscallSucceeds()); + ASSERT_EQ(v, kSockOptOff); + ASSERT_EQ(optlen, sizeof(v)); +} + +TEST_P(UdpSocketTest, SoNoCheck) { + // TODO(gvisor.dev/issue/1202): SO_NO_CHECK socket option not supported by + // hostinet. + SKIP_IF(IsRunningWithHostinet()); + + int v = kSockOptOn; + socklen_t optlen = sizeof(v); + ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, optlen), + SyscallSucceeds()); + v = -1; + ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, &optlen), + SyscallSucceeds()); + ASSERT_EQ(v, kSockOptOn); + ASSERT_EQ(optlen, sizeof(v)); + + v = kSockOptOff; + ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, optlen), + SyscallSucceeds()); + v = -1; + ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, &optlen), + SyscallSucceeds()); + ASSERT_EQ(v, kSockOptOff); + ASSERT_EQ(optlen, sizeof(v)); +} + +TEST_P(UdpSocketTest, SoTimestampOffByDefault) { + // TODO(gvisor.dev/issue/1202): SO_TIMESTAMP socket option not supported by + // hostinet. + SKIP_IF(IsRunningWithHostinet()); + + int v = -1; + socklen_t optlen = sizeof(v); + ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_TIMESTAMP, &v, &optlen), + SyscallSucceeds()); + ASSERT_EQ(v, kSockOptOff); + ASSERT_EQ(optlen, sizeof(v)); +} + +TEST_P(UdpSocketTest, SoTimestamp) { + // TODO(gvisor.dev/issue/1202): ioctl() and SO_TIMESTAMP socket option are not + // supported by hostinet. + SKIP_IF(IsRunningWithHostinet()); + + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + int v = 1; + ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_TIMESTAMP, &v, sizeof(v)), + SyscallSucceeds()); + + char buf[3]; + // Send zero length packet from sock to bind_. + ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, 0), + SyscallSucceedsWithValue(0)); + + struct pollfd pfd = {bind_.get(), POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + + char cmsgbuf[CMSG_SPACE(sizeof(struct timeval))]; + msghdr msg; + memset(&msg, 0, sizeof(msg)); + iovec iov; + memset(&iov, 0, sizeof(iov)); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + msg.msg_control = cmsgbuf; + msg.msg_controllen = sizeof(cmsgbuf); + + ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &msg, 0), + SyscallSucceedsWithValue(0)); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET); + ASSERT_EQ(cmsg->cmsg_type, SO_TIMESTAMP); + ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(struct timeval))); + + struct timeval tv = {}; + memcpy(&tv, CMSG_DATA(cmsg), sizeof(struct timeval)); + + ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); + + // There should be nothing to get via ioctl. + ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv), + SyscallFailsWithErrno(ENOENT)); +} + +TEST_P(UdpSocketTest, WriteShutdownNotConnected) { + EXPECT_THAT(shutdown(bind_.get(), SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); +} + +TEST_P(UdpSocketTest, TimestampIoctl) { + // TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet. + SKIP_IF(IsRunningWithHostinet()); + + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + char buf[3]; + // Send packet from sock to bind_. + ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + + struct pollfd pfd = {bind_.get(), POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + + // There should be no control messages. + char recv_buf[sizeof(buf)]; + ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(bind_.get(), recv_buf, sizeof(recv_buf))); + + // A nonzero timeval should be available via ioctl. + struct timeval tv = {}; + ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv), SyscallSucceeds()); + ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); +} + +TEST_P(UdpSocketTest, TimestampIoctlNothingRead) { + // TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet. + SKIP_IF(IsRunningWithHostinet()); + + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + struct timeval tv = {}; + ASSERT_THAT(ioctl(sock_.get(), SIOCGSTAMP, &tv), + SyscallFailsWithErrno(ENOENT)); +} + +// Test that the timestamp accessed via SIOCGSTAMP is still accessible after +// SO_TIMESTAMP is enabled and used to retrieve a timestamp. +TEST_P(UdpSocketTest, TimestampIoctlPersistence) { + // TODO(gvisor.dev/issue/1202): ioctl() and SO_TIMESTAMP socket option are not + // supported by hostinet. + SKIP_IF(IsRunningWithHostinet()); + + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + char buf[3]; + // Send packet from sock to bind_. + ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, 0), + SyscallSucceedsWithValue(0)); + + struct pollfd pfd = {bind_.get(), POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + + // There should be no control messages. + char recv_buf[sizeof(buf)]; + ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(bind_.get(), recv_buf, sizeof(recv_buf))); + + // A nonzero timeval should be available via ioctl. + struct timeval tv = {}; + ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv), SyscallSucceeds()); + ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); + + // Enable SO_TIMESTAMP and send a message. + int v = 1; + EXPECT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_TIMESTAMP, &v, sizeof(v)), + SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, 0), + SyscallSucceedsWithValue(0)); + + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + + // There should be a message for SO_TIMESTAMP. + char cmsgbuf[CMSG_SPACE(sizeof(struct timeval))]; + msghdr msg = {}; + iovec iov = {}; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + msg.msg_control = cmsgbuf; + msg.msg_controllen = sizeof(cmsgbuf); + ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &msg, 0), + SyscallSucceedsWithValue(0)); + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + + // The ioctl should return the exact same values as before. + struct timeval tv2 = {}; + ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv2), SyscallSucceeds()); + ASSERT_EQ(tv.tv_sec, tv2.tv_sec); + ASSERT_EQ(tv.tv_usec, tv2.tv_usec); +} + +// Test that a socket with IP_TOS or IPV6_TCLASS set will set the TOS byte on +// outgoing packets, and that a receiving socket with IP_RECVTOS or +// IPV6_RECVTCLASS will create the corresponding control message. +TEST_P(UdpSocketTest, SetAndReceiveTOS) { + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + // Allow socket to receive control message. + int recv_level = SOL_IP; + int recv_type = IP_RECVTOS; + if (GetParam() != AddressFamily::kIpv4) { + recv_level = SOL_IPV6; + recv_type = IPV6_RECVTCLASS; + } + ASSERT_THAT(setsockopt(bind_.get(), recv_level, recv_type, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Set socket TOS. + int sent_level = recv_level; + int sent_type = IP_TOS; + if (sent_level == SOL_IPV6) { + sent_type = IPV6_TCLASS; + } + int sent_tos = IPTOS_LOWDELAY; // Choose some TOS value. + ASSERT_THAT(setsockopt(sock_.get(), sent_level, sent_type, &sent_tos, + sizeof(sent_tos)), + SyscallSucceeds()); + + // Prepare message to send. + constexpr size_t kDataLength = 1024; + struct msghdr sent_msg = {}; + struct iovec sent_iov = {}; + char sent_data[kDataLength]; + sent_iov.iov_base = &sent_data[0]; + sent_iov.iov_len = kDataLength; + sent_msg.msg_iov = &sent_iov; + sent_msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(sendmsg)(sock_.get(), &sent_msg, 0), + SyscallSucceedsWithValue(kDataLength)); + + // Receive message. + struct msghdr received_msg = {}; + struct iovec received_iov = {}; + char received_data[kDataLength]; + received_iov.iov_base = &received_data[0]; + received_iov.iov_len = kDataLength; + received_msg.msg_iov = &received_iov; + received_msg.msg_iovlen = 1; + size_t cmsg_data_len = sizeof(int8_t); + if (sent_type == IPV6_TCLASS) { + cmsg_data_len = sizeof(int); + } + std::vector<char> received_cmsgbuf(CMSG_SPACE(cmsg_data_len)); + received_msg.msg_control = &received_cmsgbuf[0]; + received_msg.msg_controllen = received_cmsgbuf.size(); + ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &received_msg, 0), + SyscallSucceedsWithValue(kDataLength)); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len)); + EXPECT_EQ(cmsg->cmsg_level, sent_level); + EXPECT_EQ(cmsg->cmsg_type, sent_type); + int8_t received_tos = 0; + memcpy(&received_tos, CMSG_DATA(cmsg), sizeof(received_tos)); + EXPECT_EQ(received_tos, sent_tos); +} + +// Test that sendmsg with IP_TOS and IPV6_TCLASS control messages will set the +// TOS byte on outgoing packets, and that a receiving socket with IP_RECVTOS or +// IPV6_RECVTCLASS will create the corresponding control message. +TEST_P(UdpSocketTest, SendAndReceiveTOS) { + // TODO(b/146661005): Setting TOS via cmsg not supported for netstack. + SKIP_IF(IsRunningOnGvisor() && !IsRunningWithHostinet()); + + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + // Allow socket to receive control message. + int recv_level = SOL_IP; + int recv_type = IP_RECVTOS; + if (GetParam() != AddressFamily::kIpv4) { + recv_level = SOL_IPV6; + recv_type = IPV6_RECVTCLASS; + } + int recv_opt = kSockOptOn; + ASSERT_THAT(setsockopt(bind_.get(), recv_level, recv_type, &recv_opt, + sizeof(recv_opt)), + SyscallSucceeds()); + + // Prepare message to send. + constexpr size_t kDataLength = 1024; + int sent_level = recv_level; + int sent_type = IP_TOS; + int sent_tos = IPTOS_LOWDELAY; // Choose some TOS value. + + struct msghdr sent_msg = {}; + struct iovec sent_iov = {}; + char sent_data[kDataLength]; + sent_iov.iov_base = &sent_data[0]; + sent_iov.iov_len = kDataLength; + sent_msg.msg_iov = &sent_iov; + sent_msg.msg_iovlen = 1; + size_t cmsg_data_len = sizeof(int8_t); + if (sent_level == SOL_IPV6) { + sent_type = IPV6_TCLASS; + cmsg_data_len = sizeof(int); + } + std::vector<char> sent_cmsgbuf(CMSG_SPACE(cmsg_data_len)); + sent_msg.msg_control = &sent_cmsgbuf[0]; + sent_msg.msg_controllen = CMSG_LEN(cmsg_data_len); + + // Manually add control message. + struct cmsghdr* sent_cmsg = CMSG_FIRSTHDR(&sent_msg); + sent_cmsg->cmsg_len = CMSG_LEN(cmsg_data_len); + sent_cmsg->cmsg_level = sent_level; + sent_cmsg->cmsg_type = sent_type; + *(int8_t*)CMSG_DATA(sent_cmsg) = sent_tos; + + ASSERT_THAT(RetryEINTR(sendmsg)(sock_.get(), &sent_msg, 0), + SyscallSucceedsWithValue(kDataLength)); + + // Receive message. + struct msghdr received_msg = {}; + struct iovec received_iov = {}; + char received_data[kDataLength]; + received_iov.iov_base = &received_data[0]; + received_iov.iov_len = kDataLength; + received_msg.msg_iov = &received_iov; + received_msg.msg_iovlen = 1; + std::vector<char> received_cmsgbuf(CMSG_SPACE(cmsg_data_len)); + received_msg.msg_control = &received_cmsgbuf[0]; + received_msg.msg_controllen = CMSG_LEN(cmsg_data_len); + ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &received_msg, 0), + SyscallSucceedsWithValue(kDataLength)); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len)); + EXPECT_EQ(cmsg->cmsg_level, sent_level); + EXPECT_EQ(cmsg->cmsg_type, sent_type); + int8_t received_tos = 0; + memcpy(&received_tos, CMSG_DATA(cmsg), sizeof(received_tos)); + EXPECT_EQ(received_tos, sent_tos); +} + +TEST_P(UdpSocketTest, RecvBufLimitsEmptyRcvBuf) { + // Discover minimum buffer size by setting it to zero. + constexpr int kRcvBufSz = 0; + ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, + sizeof(kRcvBufSz)), + SyscallSucceeds()); + + int min = 0; + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &min, &min_len), + SyscallSucceeds()); + + // Bind bind_ to loopback. + ASSERT_NO_ERRNO(BindLoopback()); + + { + // Send data of size min and verify that it's received. + std::vector<char> buf(min); + RandomizeBuffer(buf.data(), buf.size()); + ASSERT_THAT( + sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(buf.size())); + std::vector<char> received(buf.size()); + EXPECT_THAT( + recv(bind_.get(), received.data(), received.size(), MSG_DONTWAIT), + SyscallSucceedsWithValue(received.size())); + } + + { + // Send data of size min + 1 and verify that its received. Both linux and + // Netstack accept a dgram that exceeds rcvBuf limits if the receive buffer + // is currently empty. + std::vector<char> buf(min + 1); + RandomizeBuffer(buf.data(), buf.size()); + ASSERT_THAT( + sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(buf.size())); + + std::vector<char> received(buf.size()); + EXPECT_THAT( + recv(bind_.get(), received.data(), received.size(), MSG_DONTWAIT), + SyscallSucceedsWithValue(received.size())); + } +} + +// Test that receive buffer limits are enforced. +TEST_P(UdpSocketTest, RecvBufLimits) { + // Bind s_ to loopback. + ASSERT_NO_ERRNO(BindLoopback()); + + int min = 0; + { + // Discover minimum buffer size by trying to set it to zero. + constexpr int kRcvBufSz = 0; + ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, + sizeof(kRcvBufSz)), + SyscallSucceeds()); + + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &min, &min_len), + SyscallSucceeds()); + } + + // Now set the limit to min * 4. + int new_rcv_buf_sz = min * 4; + if (!IsRunningOnGvisor() || IsRunningWithHostinet()) { + // Linux doubles the value specified so just set to min * 2. + new_rcv_buf_sz = min * 2; + } + + ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &new_rcv_buf_sz, + sizeof(new_rcv_buf_sz)), + SyscallSucceeds()); + int rcv_buf_sz = 0; + { + socklen_t rcv_buf_len = sizeof(rcv_buf_sz); + ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &rcv_buf_sz, + &rcv_buf_len), + SyscallSucceeds()); + } + + { + std::vector<char> buf(min); + RandomizeBuffer(buf.data(), buf.size()); + + ASSERT_THAT( + sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(buf.size())); + ASSERT_THAT( + sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(buf.size())); + ASSERT_THAT( + sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(buf.size())); + ASSERT_THAT( + sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(buf.size())); + int sent = 4; + if (IsRunningOnGvisor() && !IsRunningWithHostinet()) { + // Linux seems to drop the 4th packet even though technically it should + // fit in the receive buffer. + ASSERT_THAT( + sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(buf.size())); + sent++; + } + + for (int i = 0; i < sent - 1; i++) { + // Receive the data. + std::vector<char> received(buf.size()); + EXPECT_THAT( + recv(bind_.get(), received.data(), received.size(), MSG_DONTWAIT), + SyscallSucceedsWithValue(received.size())); + EXPECT_EQ(memcmp(buf.data(), received.data(), buf.size()), 0); + } + + // The last receive should fail with EAGAIN as the last packet should have + // been dropped due to lack of space in the receive buffer. + std::vector<char> received(buf.size()); + EXPECT_THAT( + recv(bind_.get(), received.data(), received.size(), MSG_DONTWAIT), + SyscallFailsWithErrno(EAGAIN)); + } +} + +#ifndef __fuchsia__ + +// TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. +// gVisor currently silently ignores attaching a filter. +TEST_P(UdpSocketTest, SetSocketDetachFilter) { + // Program generated using sudo tcpdump -i lo udp and port 1234 -dd + struct sock_filter code[] = { + {0x28, 0, 0, 0x0000000c}, {0x15, 0, 6, 0x000086dd}, + {0x30, 0, 0, 0x00000014}, {0x15, 0, 15, 0x00000011}, + {0x28, 0, 0, 0x00000036}, {0x15, 12, 0, 0x000004d2}, + {0x28, 0, 0, 0x00000038}, {0x15, 10, 11, 0x000004d2}, + {0x15, 0, 10, 0x00000800}, {0x30, 0, 0, 0x00000017}, + {0x15, 0, 8, 0x00000011}, {0x28, 0, 0, 0x00000014}, + {0x45, 6, 0, 0x00001fff}, {0xb1, 0, 0, 0x0000000e}, + {0x48, 0, 0, 0x0000000e}, {0x15, 2, 0, 0x000004d2}, + {0x48, 0, 0, 0x00000010}, {0x15, 0, 1, 0x000004d2}, + {0x6, 0, 0, 0x00040000}, {0x6, 0, 0, 0x00000000}, + }; + struct sock_fprog bpf = { + .len = ABSL_ARRAYSIZE(code), + .filter = code, + }; + ASSERT_THAT( + setsockopt(sock_.get(), SOL_SOCKET, SO_ATTACH_FILTER, &bpf, sizeof(bpf)), + SyscallSucceeds()); + + constexpr int val = 0; + ASSERT_THAT( + setsockopt(sock_.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), + SyscallSucceeds()); +} + +TEST_P(UdpSocketTest, SetSocketDetachFilterNoInstalledFilter) { + // TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. + SKIP_IF(IsRunningOnGvisor()); + constexpr int val = 0; + ASSERT_THAT( + setsockopt(sock_.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), + SyscallFailsWithErrno(ENOENT)); +} + +TEST_P(UdpSocketTest, GetSocketDetachFilter) { + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT( + getsockopt(sock_.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, &val_len), + SyscallFailsWithErrno(ENOPROTOOPT)); +} + +#endif // __fuchsia__ + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/udp_socket_test_cases.h b/test/syscalls/linux/udp_socket_test_cases.h new file mode 100644 index 000000000..f7e25c805 --- /dev/null +++ b/test/syscalls/linux/udp_socket_test_cases.h @@ -0,0 +1,82 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_GOLANG_GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_H_ +#define THIRD_PARTY_GOLANG_GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_H_ + +#include <sys/socket.h> + +#include "gtest/gtest.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/posix_error.h" + +namespace gvisor { +namespace testing { + +// The initial port to be be used on gvisor. +constexpr int TestPort = 40000; + +// Fixture for tests parameterized by the address family to use (AF_INET and +// AF_INET6) when creating sockets. +class UdpSocketTest + : public ::testing::TestWithParam<gvisor::testing::AddressFamily> { + protected: + // Creates two sockets that will be used by test cases. + void SetUp() override; + + // Binds the socket bind_ to the loopback and updates bind_addr_. + PosixError BindLoopback(); + + // Binds the socket bind_ to Any and updates bind_addr_. + PosixError BindAny(); + + // Binds given socket to address addr and updates. + PosixError BindSocket(int socket, struct sockaddr* addr); + + // Return initialized Any address to port 0. + struct sockaddr_storage InetAnyAddr(); + + // Return initialized Loopback address to port 0. + struct sockaddr_storage InetLoopbackAddr(); + + // Disconnects socket sockfd. + void Disconnect(int sockfd); + + // Get family for the test. + int GetFamily(); + + // Socket used by Bind methods + FileDescriptor bind_; + + // Second socket used for tests. + FileDescriptor sock_; + + // Address for bind_ socket. + struct sockaddr* bind_addr_; + + // Initialized to the length based on GetFamily(). + socklen_t addrlen_; + + // Storage for bind_addr_. + struct sockaddr_storage bind_addr_storage_; + + private: + // Helper to initialize addrlen_ for the test case. + socklen_t GetAddrLength(); +}; +} // namespace testing +} // namespace gvisor + +#endif // THIRD_PARTY_GOLANG_GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_H_ diff --git a/test/syscalls/linux/uidgid.cc b/test/syscalls/linux/uidgid.cc index 6218fbce1..64d6d0b8f 100644 --- a/test/syscalls/linux/uidgid.cc +++ b/test/syscalls/linux/uidgid.cc @@ -14,6 +14,7 @@ #include <errno.h> #include <grp.h> +#include <sys/resource.h> #include <sys/types.h> #include <unistd.h> @@ -249,6 +250,26 @@ TEST(UidGidRootTest, Setgroups) { SyscallFailsWithErrno(EFAULT)); } +TEST(UidGidRootTest, Setuid_prlimit) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsRoot())); + + // Do seteuid in a separate thread so that after finishing this test, the + // process can still open files the test harness created before starting this + // test. Otherwise, the files are created by root (UID before the test), but + // cannot be opened by the `uid` set below after the test. + ScopedThread([&] { + // Use syscall instead of glibc setuid wrapper because we want this seteuid + // call to only apply to this task. POSIX threads, however, require that all + // threads have the same UIDs, so using the seteuid wrapper sets all + // threads' UID. + EXPECT_THAT(syscall(SYS_setreuid, -1, 65534), SyscallSucceeds()); + + // Despite the UID change, we should be able to get our own limits. + struct rlimit rl = {}; + EXPECT_THAT(prlimit(0, RLIMIT_NOFILE, NULL, &rl), SyscallSucceeds()); + }); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/unix_domain_socket_test_util.cc b/test/syscalls/linux/unix_domain_socket_test_util.cc index 7fb9eed8d..b05ab2900 100644 --- a/test/syscalls/linux/unix_domain_socket_test_util.cc +++ b/test/syscalls/linux/unix_domain_socket_test_util.cc @@ -15,6 +15,7 @@ #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include <sys/un.h> + #include <vector> #include "gtest/gtest.h" diff --git a/test/syscalls/linux/unix_domain_socket_test_util.h b/test/syscalls/linux/unix_domain_socket_test_util.h index 5eca0b7f0..b8073db17 100644 --- a/test/syscalls/linux/unix_domain_socket_test_util.h +++ b/test/syscalls/linux/unix_domain_socket_test_util.h @@ -16,6 +16,7 @@ #define GVISOR_TEST_SYSCALLS_UNIX_DOMAIN_SOCKET_TEST_UTIL_H_ #include <string> + #include "test/syscalls/linux/socket_test_util.h" namespace gvisor { diff --git a/test/syscalls/linux/utimes.cc b/test/syscalls/linux/utimes.cc index 80716859a..e647d2896 100644 --- a/test/syscalls/linux/utimes.cc +++ b/test/syscalls/linux/utimes.cc @@ -20,6 +20,7 @@ #include <time.h> #include <unistd.h> #include <utime.h> + #include <string> #include "absl/time/time.h" @@ -33,17 +34,10 @@ namespace testing { namespace { -// TODO(b/36516566): utimes(nullptr) does not pick the "now" time in the -// application's time domain, so when asserting that times are within a window, -// we expand the window to allow for differences between the time domains. -constexpr absl::Duration kClockSlack = absl::Milliseconds(100); - // TimeBoxed runs fn, setting before and after to (coarse realtime) times // guaranteed* to come before and after fn started and completed, respectively. // // fn may be called more than once if the clock is adjusted. -// -// * See the comment on kClockSlack. gVisor breaks this guarantee. void TimeBoxed(absl::Time* before, absl::Time* after, std::function<void()> const& fn) { do { @@ -54,12 +48,15 @@ void TimeBoxed(absl::Time* before, absl::Time* after, // filesystems set it to 1, so we don't do any truncation. struct timespec ts; EXPECT_THAT(clock_gettime(CLOCK_REALTIME_COARSE, &ts), SyscallSucceeds()); - *before = absl::TimeFromTimespec(ts); + // FIXME(b/132819225): gVisor filesystem timestamps inconsistently use the + // internal or host clock, which may diverge slightly. Allow some slack on + // times to account for the difference. + *before = absl::TimeFromTimespec(ts) - absl::Seconds(1); fn(); EXPECT_THAT(clock_gettime(CLOCK_REALTIME_COARSE, &ts), SyscallSucceeds()); - *after = absl::TimeFromTimespec(ts); + *after = absl::TimeFromTimespec(ts) + absl::Seconds(1); if (*after < *before) { // Clock jumped backwards; retry. @@ -68,23 +65,17 @@ void TimeBoxed(absl::Time* before, absl::Time* after, // which could lead to test failures, but that is very unlikely to happen. continue; } - - if (IsRunningOnGvisor()) { - // See comment on kClockSlack. - *before -= kClockSlack; - *after += kClockSlack; - } } while (*after < *before); } void TestUtimesOnPath(std::string const& path) { struct stat statbuf; - struct timeval times[2] = {{1, 0}, {2, 0}}; + struct timeval times[2] = {{10, 0}, {20, 0}}; EXPECT_THAT(utimes(path.c_str(), times), SyscallSucceeds()); EXPECT_THAT(stat(path.c_str(), &statbuf), SyscallSucceeds()); - EXPECT_EQ(1, statbuf.st_atime); - EXPECT_EQ(2, statbuf.st_mtime); + EXPECT_EQ(10, statbuf.st_atime); + EXPECT_EQ(20, statbuf.st_mtime); absl::Time before; absl::Time after; @@ -115,18 +106,18 @@ TEST(UtimesTest, OnDir) { TEST(UtimesTest, MissingPath) { auto path = NewTempAbsPath(); - struct timeval times[2] = {{1, 0}, {2, 0}}; + struct timeval times[2] = {{10, 0}, {20, 0}}; EXPECT_THAT(utimes(path.c_str(), times), SyscallFailsWithErrno(ENOENT)); } void TestFutimesat(int dirFd, std::string const& path) { struct stat statbuf; - struct timeval times[2] = {{1, 0}, {2, 0}}; + struct timeval times[2] = {{10, 0}, {20, 0}}; EXPECT_THAT(futimesat(dirFd, path.c_str(), times), SyscallSucceeds()); EXPECT_THAT(fstatat(dirFd, path.c_str(), &statbuf, 0), SyscallSucceeds()); - EXPECT_EQ(1, statbuf.st_atime); - EXPECT_EQ(2, statbuf.st_mtime); + EXPECT_EQ(10, statbuf.st_atime); + EXPECT_EQ(20, statbuf.st_mtime); absl::Time before; absl::Time after; @@ -162,12 +153,12 @@ TEST(FutimesatTest, OnRelPath) { TEST(FutimesatTest, InvalidNsec) { auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); struct timeval times[4][2] = {{ - {0, 1}, // Valid + {0, 1}, // Valid {1, static_cast<int64_t>(1e7)} // Invalid }, { {1, static_cast<int64_t>(1e7)}, // Invalid - {0, 1} // Valid + {0, 1} // Valid }, { {0, 1}, // Valid @@ -187,11 +178,11 @@ TEST(FutimesatTest, InvalidNsec) { void TestUtimensat(int dirFd, std::string const& path) { struct stat statbuf; - const struct timespec times[2] = {{1, 0}, {2, 0}}; + const struct timespec times[2] = {{10, 0}, {20, 0}}; EXPECT_THAT(utimensat(dirFd, path.c_str(), times, 0), SyscallSucceeds()); EXPECT_THAT(fstatat(dirFd, path.c_str(), &statbuf, 0), SyscallSucceeds()); - EXPECT_EQ(1, statbuf.st_atime); - EXPECT_EQ(2, statbuf.st_mtime); + EXPECT_EQ(10, statbuf.st_atime); + EXPECT_EQ(20, statbuf.st_mtime); // Test setting with UTIME_NOW and UTIME_OMIT. struct stat statbuf2; @@ -234,10 +225,7 @@ void TestUtimensat(int dirFd, std::string const& path) { EXPECT_GE(mtime3, before); EXPECT_LE(mtime3, after); - if (!IsRunningOnGvisor()) { - // FIXME(b/36516566): Gofers set atime and mtime to different "now" times. - EXPECT_EQ(atime3, mtime3); - } + EXPECT_EQ(atime3, mtime3); } TEST(UtimensatTest, OnAbsPath) { @@ -287,14 +275,15 @@ TEST(UtimeTest, ZeroAtimeandMtime) { TEST(UtimensatTest, InvalidNsec) { auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - struct timespec times[2][2] = {{ - {0, UTIME_OMIT}, // Valid - {2, static_cast<int64_t>(1e10)} // Invalid - }, - { - {2, static_cast<int64_t>(1e10)}, // Invalid - {0, UTIME_OMIT} // Valid - }}; + struct timespec times[2][2] = { + { + {0, UTIME_OMIT}, // Valid + {2, static_cast<int64_t>(1e10)} // Invalid + }, + { + {2, static_cast<int64_t>(1e10)}, // Invalid + {0, UTIME_OMIT} // Valid + }}; for (unsigned int i = 0; i < sizeof(times) / sizeof(times[0]); i++) { std::cout << "test:" << i << "\n"; @@ -315,13 +304,13 @@ TEST(Utimensat, NullPath) { auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDWR)); struct stat statbuf; - const struct timespec times[2] = {{1, 0}, {2, 0}}; + const struct timespec times[2] = {{10, 0}, {20, 0}}; // Call syscall directly. EXPECT_THAT(syscall(SYS_utimensat, fd.get(), NULL, times, 0), SyscallSucceeds()); EXPECT_THAT(fstatat(0, f.path().c_str(), &statbuf, 0), SyscallSucceeds()); - EXPECT_EQ(1, statbuf.st_atime); - EXPECT_EQ(2, statbuf.st_mtime); + EXPECT_EQ(10, statbuf.st_atime); + EXPECT_EQ(20, statbuf.st_mtime); } } // namespace diff --git a/test/syscalls/linux/vdso_clock_gettime.cc b/test/syscalls/linux/vdso_clock_gettime.cc index 40c0014b9..ce1899f45 100644 --- a/test/syscalls/linux/vdso_clock_gettime.cc +++ b/test/syscalls/linux/vdso_clock_gettime.cc @@ -17,6 +17,7 @@ #include <syscall.h> #include <time.h> #include <unistd.h> + #include <map> #include <string> #include <utility> diff --git a/test/syscalls/linux/vfork.cc b/test/syscalls/linux/vfork.cc index 0aaba482d..19d05998e 100644 --- a/test/syscalls/linux/vfork.cc +++ b/test/syscalls/linux/vfork.cc @@ -191,5 +191,5 @@ int main(int argc, char** argv) { return gvisor::testing::RunChild(); } - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/vsyscall.cc b/test/syscalls/linux/vsyscall.cc index 2c2303358..ae4377108 100644 --- a/test/syscalls/linux/vsyscall.cc +++ b/test/syscalls/linux/vsyscall.cc @@ -24,6 +24,7 @@ namespace testing { namespace { +#if defined(__x86_64__) || defined(__i386__) time_t vsyscall_time(time_t* t) { constexpr uint64_t kVsyscallTimeEntry = 0xffffffffff600400; return reinterpret_cast<time_t (*)(time_t*)>(kVsyscallTimeEntry)(t); @@ -37,6 +38,7 @@ TEST(VsyscallTest, VsyscallAlwaysAvailableOnGvisor) { time_t t; EXPECT_THAT(vsyscall_time(&t), SyscallSucceeds()); } +#endif } // namespace diff --git a/test/syscalls/linux/write.cc b/test/syscalls/linux/write.cc index 9b219cfd6..39b5b2f56 100644 --- a/test/syscalls/linux/write.cc +++ b/test/syscalls/linux/write.cc @@ -31,14 +31,8 @@ namespace gvisor { namespace testing { namespace { -// This test is currently very rudimentary. -// -// TODO(edahlgren): -// * bad buffer states (EFAULT). -// * bad fds (wrong permission, wrong type of file, EBADF). -// * check offset is incremented. -// * check for EOF. -// * writing to pipes, symlinks, special files. + +// TODO(gvisor.dev/issue/2370): This test is currently very rudimentary. class WriteTest : public ::testing::Test { public: ssize_t WriteBytes(int fd, int bytes) { diff --git a/test/syscalls/linux/xattr.cc b/test/syscalls/linux/xattr.cc new file mode 100644 index 000000000..cbcf08451 --- /dev/null +++ b/test/syscalls/linux/xattr.cc @@ -0,0 +1,610 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <limits.h> +#include <sys/types.h> +#include <sys/xattr.h> +#include <unistd.h> + +#include <string> +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/container/flat_hash_set.h" +#include "test/syscalls/linux/file_base.h" +#include "test/util/capability_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/posix_error.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class XattrTest : public FileTest {}; + +TEST_F(XattrTest, XattrNonexistentFile) { + const char* path = "/does/not/exist"; + const char* name = "user.test"; + EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), + SyscallFailsWithErrno(ENOENT)); + EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENOENT)); + EXPECT_THAT(listxattr(path, nullptr, 0), SyscallFailsWithErrno(ENOENT)); + EXPECT_THAT(removexattr(path, name), SyscallFailsWithErrno(ENOENT)); +} + +TEST_F(XattrTest, XattrNullName) { + const char* path = test_file_name_.c_str(); + + EXPECT_THAT(setxattr(path, nullptr, nullptr, 0, /*flags=*/0), + SyscallFailsWithErrno(EFAULT)); + EXPECT_THAT(getxattr(path, nullptr, nullptr, 0), + SyscallFailsWithErrno(EFAULT)); + EXPECT_THAT(removexattr(path, nullptr), SyscallFailsWithErrno(EFAULT)); +} + +TEST_F(XattrTest, XattrEmptyName) { + const char* path = test_file_name_.c_str(); + + EXPECT_THAT(setxattr(path, "", nullptr, 0, /*flags=*/0), + SyscallFailsWithErrno(ERANGE)); + EXPECT_THAT(getxattr(path, "", nullptr, 0), SyscallFailsWithErrno(ERANGE)); + EXPECT_THAT(removexattr(path, ""), SyscallFailsWithErrno(ERANGE)); +} + +TEST_F(XattrTest, XattrLargeName) { + const char* path = test_file_name_.c_str(); + std::string name = "user."; + name += std::string(XATTR_NAME_MAX - name.length(), 'a'); + + if (!IsRunningOnGvisor()) { + // In gVisor, access to xattrs is controlled with an explicit list of + // allowed names. This name isn't going to be configured to allow access, so + // don't test it. + EXPECT_THAT(setxattr(path, name.c_str(), nullptr, 0, /*flags=*/0), + SyscallSucceeds()); + EXPECT_THAT(getxattr(path, name.c_str(), nullptr, 0), + SyscallSucceedsWithValue(0)); + } + + name += "a"; + EXPECT_THAT(setxattr(path, name.c_str(), nullptr, 0, /*flags=*/0), + SyscallFailsWithErrno(ERANGE)); + EXPECT_THAT(getxattr(path, name.c_str(), nullptr, 0), + SyscallFailsWithErrno(ERANGE)); + EXPECT_THAT(removexattr(path, name.c_str()), SyscallFailsWithErrno(ERANGE)); +} + +TEST_F(XattrTest, XattrInvalidPrefix) { + const char* path = test_file_name_.c_str(); + std::string name(XATTR_NAME_MAX, 'a'); + EXPECT_THAT(setxattr(path, name.c_str(), nullptr, 0, /*flags=*/0), + SyscallFailsWithErrno(EOPNOTSUPP)); + EXPECT_THAT(getxattr(path, name.c_str(), nullptr, 0), + SyscallFailsWithErrno(EOPNOTSUPP)); + EXPECT_THAT(removexattr(path, name.c_str()), + SyscallFailsWithErrno(EOPNOTSUPP)); +} + +// Do not allow save/restore cycles after making the test file read-only, as +// the restore will fail to open it with r/w permissions. +TEST_F(XattrTest, XattrReadOnly_NoRandomSave) { + // Drop capabilities that allow us to override file and directory permissions. + ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + + const char* path = test_file_name_.c_str(); + const char name[] = "user.test"; + char val = 'a'; + size_t size = sizeof(val); + + EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds()); + + DisableSave ds; + ASSERT_NO_ERRNO(testing::Chmod(test_file_name_, S_IRUSR)); + + EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), + SyscallFailsWithErrno(EACCES)); + EXPECT_THAT(removexattr(path, name), SyscallFailsWithErrno(EACCES)); + + char buf = '-'; + EXPECT_THAT(getxattr(path, name, &buf, size), SyscallSucceedsWithValue(size)); + EXPECT_EQ(buf, val); + + char list[sizeof(name)]; + EXPECT_THAT(listxattr(path, list, sizeof(list)), + SyscallSucceedsWithValue(sizeof(name))); + EXPECT_STREQ(list, name); +} + +// Do not allow save/restore cycles after making the test file write-only, as +// the restore will fail to open it with r/w permissions. +TEST_F(XattrTest, XattrWriteOnly_NoRandomSave) { + // Drop capabilities that allow us to override file and directory permissions. + ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + + DisableSave ds; + ASSERT_NO_ERRNO(testing::Chmod(test_file_name_, S_IWUSR)); + + const char* path = test_file_name_.c_str(); + const char name[] = "user.test"; + char val = 'a'; + size_t size = sizeof(val); + + EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds()); + + EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(EACCES)); + + // listxattr will succeed even without read permissions. + char list[sizeof(name)]; + EXPECT_THAT(listxattr(path, list, sizeof(list)), + SyscallSucceedsWithValue(sizeof(name))); + EXPECT_STREQ(list, name); + + EXPECT_THAT(removexattr(path, name), SyscallSucceeds()); +} + +TEST_F(XattrTest, XattrTrustedWithNonadmin) { + // TODO(b/148380782): Support setxattr and getxattr with "trusted" prefix. + SKIP_IF(IsRunningOnGvisor()); + SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); + + const char* path = test_file_name_.c_str(); + const char name[] = "trusted.abc"; + EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), + SyscallFailsWithErrno(EPERM)); + EXPECT_THAT(removexattr(path, name), SyscallFailsWithErrno(EPERM)); + EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA)); +} + +TEST_F(XattrTest, XattrOnDirectory) { + TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + const char name[] = "user.test"; + EXPECT_THAT(setxattr(dir.path().c_str(), name, nullptr, 0, /*flags=*/0), + SyscallSucceeds()); + EXPECT_THAT(getxattr(dir.path().c_str(), name, nullptr, 0), + SyscallSucceedsWithValue(0)); + + char list[sizeof(name)]; + EXPECT_THAT(listxattr(dir.path().c_str(), list, sizeof(list)), + SyscallSucceedsWithValue(sizeof(name))); + EXPECT_STREQ(list, name); + + EXPECT_THAT(removexattr(dir.path().c_str(), name), SyscallSucceeds()); +} + +TEST_F(XattrTest, XattrOnSymlink) { + TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + TempPath link = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateSymlinkTo(dir.path(), test_file_name_)); + const char name[] = "user.test"; + EXPECT_THAT(setxattr(link.path().c_str(), name, nullptr, 0, /*flags=*/0), + SyscallSucceeds()); + EXPECT_THAT(getxattr(link.path().c_str(), name, nullptr, 0), + SyscallSucceedsWithValue(0)); + + char list[sizeof(name)]; + EXPECT_THAT(listxattr(link.path().c_str(), list, sizeof(list)), + SyscallSucceedsWithValue(sizeof(name))); + EXPECT_STREQ(list, name); + + EXPECT_THAT(removexattr(link.path().c_str(), name), SyscallSucceeds()); +} + +TEST_F(XattrTest, XattrOnInvalidFileTypes) { + const char name[] = "user.test"; + + char char_device[] = "/dev/zero"; + EXPECT_THAT(setxattr(char_device, name, nullptr, 0, /*flags=*/0), + SyscallFailsWithErrno(EPERM)); + EXPECT_THAT(getxattr(char_device, name, nullptr, 0), + SyscallFailsWithErrno(ENODATA)); + EXPECT_THAT(listxattr(char_device, nullptr, 0), SyscallSucceedsWithValue(0)); + + // Use tmpfs, where creation of named pipes is supported. + const std::string fifo = NewTempAbsPathInDir("/dev/shm"); + const char* path = fifo.c_str(); + EXPECT_THAT(mknod(path, S_IFIFO | S_IRUSR | S_IWUSR, 0), SyscallSucceeds()); + EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), + SyscallFailsWithErrno(EPERM)); + EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA)); + EXPECT_THAT(listxattr(path, nullptr, 0), SyscallSucceedsWithValue(0)); + EXPECT_THAT(removexattr(path, name), SyscallFailsWithErrno(EPERM)); +} + +TEST_F(XattrTest, SetxattrSizeSmallerThanValue) { + const char* path = test_file_name_.c_str(); + const char name[] = "user.test"; + std::vector<char> val = {'a', 'a'}; + size_t size = 1; + EXPECT_THAT(setxattr(path, name, val.data(), size, /*flags=*/0), + SyscallSucceeds()); + + std::vector<char> buf = {'-', '-'}; + std::vector<char> expected_buf = {'a', '-'}; + EXPECT_THAT(getxattr(path, name, buf.data(), buf.size()), + SyscallSucceedsWithValue(size)); + EXPECT_EQ(buf, expected_buf); +} + +TEST_F(XattrTest, SetxattrZeroSize) { + const char* path = test_file_name_.c_str(); + const char name[] = "user.test"; + char val = 'a'; + EXPECT_THAT(setxattr(path, name, &val, 0, /*flags=*/0), SyscallSucceeds()); + + char buf = '-'; + EXPECT_THAT(getxattr(path, name, &buf, XATTR_SIZE_MAX), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(buf, '-'); +} + +TEST_F(XattrTest, SetxattrSizeTooLarge) { + const char* path = test_file_name_.c_str(); + const char name[] = "user.test"; + + // Note that each particular fs implementation may stipulate a lower size + // limit, in which case we actually may fail (e.g. error with ENOSPC) for + // some sizes under XATTR_SIZE_MAX. + size_t size = XATTR_SIZE_MAX + 1; + std::vector<char> val(size); + EXPECT_THAT(setxattr(path, name, val.data(), size, /*flags=*/0), + SyscallFailsWithErrno(E2BIG)); + + EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA)); +} + +TEST_F(XattrTest, SetxattrNullValueAndNonzeroSize) { + const char* path = test_file_name_.c_str(); + const char name[] = "user.test"; + EXPECT_THAT(setxattr(path, name, nullptr, 1, /*flags=*/0), + SyscallFailsWithErrno(EFAULT)); + + EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA)); +} + +TEST_F(XattrTest, SetxattrNullValueAndZeroSize) { + const char* path = test_file_name_.c_str(); + const char name[] = "user.test"; + EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds()); + + EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallSucceedsWithValue(0)); +} + +TEST_F(XattrTest, SetxattrValueTooLargeButOKSize) { + const char* path = test_file_name_.c_str(); + const char name[] = "user.test"; + std::vector<char> val(XATTR_SIZE_MAX + 1); + std::fill(val.begin(), val.end(), 'a'); + size_t size = 1; + EXPECT_THAT(setxattr(path, name, val.data(), size, /*flags=*/0), + SyscallSucceeds()); + + std::vector<char> buf = {'-', '-'}; + std::vector<char> expected_buf = {'a', '-'}; + EXPECT_THAT(getxattr(path, name, buf.data(), size), + SyscallSucceedsWithValue(size)); + EXPECT_EQ(buf, expected_buf); +} + +TEST_F(XattrTest, SetxattrReplaceWithSmaller) { + const char* path = test_file_name_.c_str(); + const char name[] = "user.test"; + std::vector<char> val = {'a', 'a'}; + EXPECT_THAT(setxattr(path, name, val.data(), 2, /*flags=*/0), + SyscallSucceeds()); + EXPECT_THAT(setxattr(path, name, val.data(), 1, /*flags=*/0), + SyscallSucceeds()); + + std::vector<char> buf = {'-', '-'}; + std::vector<char> expected_buf = {'a', '-'}; + EXPECT_THAT(getxattr(path, name, buf.data(), 2), SyscallSucceedsWithValue(1)); + EXPECT_EQ(buf, expected_buf); +} + +TEST_F(XattrTest, SetxattrReplaceWithLarger) { + const char* path = test_file_name_.c_str(); + const char name[] = "user.test"; + std::vector<char> val = {'a', 'a'}; + EXPECT_THAT(setxattr(path, name, val.data(), 1, /*flags=*/0), + SyscallSucceeds()); + EXPECT_THAT(setxattr(path, name, val.data(), 2, /*flags=*/0), + SyscallSucceeds()); + + std::vector<char> buf = {'-', '-'}; + EXPECT_THAT(getxattr(path, name, buf.data(), 2), SyscallSucceedsWithValue(2)); + EXPECT_EQ(buf, val); +} + +TEST_F(XattrTest, SetxattrCreateFlag) { + const char* path = test_file_name_.c_str(); + const char name[] = "user.test"; + EXPECT_THAT(setxattr(path, name, nullptr, 0, XATTR_CREATE), + SyscallSucceeds()); + EXPECT_THAT(setxattr(path, name, nullptr, 0, XATTR_CREATE), + SyscallFailsWithErrno(EEXIST)); + + EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallSucceedsWithValue(0)); +} + +TEST_F(XattrTest, SetxattrReplaceFlag) { + const char* path = test_file_name_.c_str(); + const char name[] = "user.test"; + EXPECT_THAT(setxattr(path, name, nullptr, 0, XATTR_REPLACE), + SyscallFailsWithErrno(ENODATA)); + EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds()); + EXPECT_THAT(setxattr(path, name, nullptr, 0, XATTR_REPLACE), + SyscallSucceeds()); + + EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallSucceedsWithValue(0)); +} + +TEST_F(XattrTest, SetxattrInvalidFlags) { + const char* path = test_file_name_.c_str(); + int invalid_flags = 0xff; + EXPECT_THAT(setxattr(path, nullptr, nullptr, 0, invalid_flags), + SyscallFailsWithErrno(EINVAL)); +} + +TEST_F(XattrTest, Getxattr) { + const char* path = test_file_name_.c_str(); + const char name[] = "user.test"; + int val = 1234; + size_t size = sizeof(val); + EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds()); + + int buf = 0; + EXPECT_THAT(getxattr(path, name, &buf, size), SyscallSucceedsWithValue(size)); + EXPECT_EQ(buf, val); +} + +TEST_F(XattrTest, GetxattrSizeSmallerThanValue) { + const char* path = test_file_name_.c_str(); + const char name[] = "user.test"; + std::vector<char> val = {'a', 'a'}; + size_t size = val.size(); + EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds()); + + char buf = '-'; + EXPECT_THAT(getxattr(path, name, &buf, 1), SyscallFailsWithErrno(ERANGE)); + EXPECT_EQ(buf, '-'); +} + +TEST_F(XattrTest, GetxattrSizeLargerThanValue) { + const char* path = test_file_name_.c_str(); + const char name[] = "user.test"; + char val = 'a'; + EXPECT_THAT(setxattr(path, name, &val, 1, /*flags=*/0), SyscallSucceeds()); + + std::vector<char> buf(XATTR_SIZE_MAX); + std::fill(buf.begin(), buf.end(), '-'); + std::vector<char> expected_buf = buf; + expected_buf[0] = 'a'; + EXPECT_THAT(getxattr(path, name, buf.data(), buf.size()), + SyscallSucceedsWithValue(1)); + EXPECT_EQ(buf, expected_buf); +} + +TEST_F(XattrTest, GetxattrZeroSize) { + const char* path = test_file_name_.c_str(); + const char name[] = "user.test"; + char val = 'a'; + EXPECT_THAT(setxattr(path, name, &val, sizeof(val), /*flags=*/0), + SyscallSucceeds()); + + char buf = '-'; + EXPECT_THAT(getxattr(path, name, &buf, 0), + SyscallSucceedsWithValue(sizeof(val))); + EXPECT_EQ(buf, '-'); +} + +TEST_F(XattrTest, GetxattrSizeTooLarge) { + const char* path = test_file_name_.c_str(); + const char name[] = "user.test"; + char val = 'a'; + EXPECT_THAT(setxattr(path, name, &val, sizeof(val), /*flags=*/0), + SyscallSucceeds()); + + std::vector<char> buf(XATTR_SIZE_MAX + 1); + std::fill(buf.begin(), buf.end(), '-'); + std::vector<char> expected_buf = buf; + expected_buf[0] = 'a'; + EXPECT_THAT(getxattr(path, name, buf.data(), buf.size()), + SyscallSucceedsWithValue(sizeof(val))); + EXPECT_EQ(buf, expected_buf); +} + +TEST_F(XattrTest, GetxattrNullValue) { + const char* path = test_file_name_.c_str(); + const char name[] = "user.test"; + char val = 'a'; + size_t size = sizeof(val); + EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds()); + + EXPECT_THAT(getxattr(path, name, nullptr, size), + SyscallFailsWithErrno(EFAULT)); +} + +TEST_F(XattrTest, GetxattrNullValueAndZeroSize) { + const char* path = test_file_name_.c_str(); + const char name[] = "user.test"; + char val = 'a'; + size_t size = sizeof(val); + // Set value with zero size. + EXPECT_THAT(setxattr(path, name, &val, 0, /*flags=*/0), SyscallSucceeds()); + // Get value with nonzero size. + EXPECT_THAT(getxattr(path, name, nullptr, size), SyscallSucceedsWithValue(0)); + + // Set value with nonzero size. + EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds()); + // Get value with zero size. + EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallSucceedsWithValue(size)); +} + +TEST_F(XattrTest, GetxattrNonexistentName) { + const char* path = test_file_name_.c_str(); + const char name[] = "user.test"; + EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA)); +} + +TEST_F(XattrTest, Listxattr) { + const char* path = test_file_name_.c_str(); + const std::string name = "user.test"; + const std::string name2 = "user.test2"; + const std::string name3 = "user.test3"; + EXPECT_THAT(setxattr(path, name.c_str(), nullptr, 0, /*flags=*/0), + SyscallSucceeds()); + EXPECT_THAT(setxattr(path, name2.c_str(), nullptr, 0, /*flags=*/0), + SyscallSucceeds()); + EXPECT_THAT(setxattr(path, name3.c_str(), nullptr, 0, /*flags=*/0), + SyscallSucceeds()); + + std::vector<char> list(name.size() + 1 + name2.size() + 1 + name3.size() + 1); + char* buf = list.data(); + EXPECT_THAT(listxattr(path, buf, XATTR_SIZE_MAX), + SyscallSucceedsWithValue(list.size())); + + absl::flat_hash_set<std::string> got = {}; + for (char* p = buf; p < buf + list.size(); p += strlen(p) + 1) { + got.insert(std::string{p}); + } + + absl::flat_hash_set<std::string> expected = {name, name2, name3}; + EXPECT_EQ(got, expected); +} + +TEST_F(XattrTest, ListxattrNoXattrs) { + const char* path = test_file_name_.c_str(); + + std::vector<char> list, expected; + EXPECT_THAT(listxattr(path, list.data(), sizeof(list)), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(list, expected); + + // Listxattr should succeed if there are no attributes, even if the buffer + // passed in is a nullptr. + EXPECT_THAT(listxattr(path, nullptr, sizeof(list)), + SyscallSucceedsWithValue(0)); +} + +TEST_F(XattrTest, ListxattrNullBuffer) { + const char* path = test_file_name_.c_str(); + const char name[] = "user.test"; + EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds()); + + EXPECT_THAT(listxattr(path, nullptr, sizeof(name)), + SyscallFailsWithErrno(EFAULT)); +} + +TEST_F(XattrTest, ListxattrSizeTooSmall) { + const char* path = test_file_name_.c_str(); + const char name[] = "user.test"; + EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds()); + + char list[sizeof(name) - 1]; + EXPECT_THAT(listxattr(path, list, sizeof(list)), + SyscallFailsWithErrno(ERANGE)); +} + +TEST_F(XattrTest, ListxattrZeroSize) { + const char* path = test_file_name_.c_str(); + const char name[] = "user.test"; + EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds()); + EXPECT_THAT(listxattr(path, nullptr, 0), + SyscallSucceedsWithValue(sizeof(name))); +} + +TEST_F(XattrTest, RemoveXattr) { + const char* path = test_file_name_.c_str(); + const char name[] = "user.test"; + EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds()); + EXPECT_THAT(removexattr(path, name), SyscallSucceeds()); + EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA)); +} + +TEST_F(XattrTest, RemoveXattrNonexistentName) { + const char* path = test_file_name_.c_str(); + const char name[] = "user.test"; + EXPECT_THAT(removexattr(path, name), SyscallFailsWithErrno(ENODATA)); +} + +TEST_F(XattrTest, LXattrOnSymlink) { + const char name[] = "user.test"; + TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + TempPath link = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateSymlinkTo(dir.path(), test_file_name_)); + + EXPECT_THAT(lsetxattr(link.path().c_str(), name, nullptr, 0, 0), + SyscallFailsWithErrno(EPERM)); + EXPECT_THAT(lgetxattr(link.path().c_str(), name, nullptr, 0), + SyscallFailsWithErrno(ENODATA)); + EXPECT_THAT(llistxattr(link.path().c_str(), nullptr, 0), + SyscallSucceedsWithValue(0)); + EXPECT_THAT(lremovexattr(link.path().c_str(), name), + SyscallFailsWithErrno(EPERM)); +} + +TEST_F(XattrTest, LXattrOnNonsymlink) { + const char* path = test_file_name_.c_str(); + const char name[] = "user.test"; + int val = 1234; + size_t size = sizeof(val); + EXPECT_THAT(lsetxattr(path, name, &val, size, /*flags=*/0), + SyscallSucceeds()); + + int buf = 0; + EXPECT_THAT(lgetxattr(path, name, &buf, size), + SyscallSucceedsWithValue(size)); + EXPECT_EQ(buf, val); + + char list[sizeof(name)]; + EXPECT_THAT(llistxattr(path, list, sizeof(list)), + SyscallSucceedsWithValue(sizeof(name))); + EXPECT_STREQ(list, name); + + EXPECT_THAT(lremovexattr(path, name), SyscallSucceeds()); +} + +TEST_F(XattrTest, XattrWithFD) { + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_.c_str(), 0)); + const char name[] = "user.test"; + int val = 1234; + size_t size = sizeof(val); + EXPECT_THAT(fsetxattr(fd.get(), name, &val, size, /*flags=*/0), + SyscallSucceeds()); + + int buf = 0; + EXPECT_THAT(fgetxattr(fd.get(), name, &buf, size), + SyscallSucceedsWithValue(size)); + EXPECT_EQ(buf, val); + + char list[sizeof(name)]; + EXPECT_THAT(flistxattr(fd.get(), list, sizeof(list)), + SyscallSucceedsWithValue(sizeof(name))); + EXPECT_STREQ(list, name); + + EXPECT_THAT(fremovexattr(fd.get(), name), SyscallSucceeds()); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/syscall_test_runner.sh b/test/syscalls/syscall_test_runner.sh deleted file mode 100755 index 864bb2de4..000000000 --- a/test/syscalls/syscall_test_runner.sh +++ /dev/null @@ -1,34 +0,0 @@ -#!/bin/bash - -# Copyright 2018 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# syscall_test_runner.sh is a simple wrapper around the go syscall test runner. -# It exists so that we can build the syscall test runner once, and use it for -# all syscall tests, rather than build it for each test run. - -set -euf -x -o pipefail - -echo -- "$@" - -if [[ -n "${TEST_UNDECLARED_OUTPUTS_DIR}" ]]; then - mkdir -p "${TEST_UNDECLARED_OUTPUTS_DIR}" - chmod a+rwx "${TEST_UNDECLARED_OUTPUTS_DIR}" -fi - -# Get location of syscall_test_runner binary. -readonly runner=$(find "${TEST_SRCDIR}" -name syscall_test_runner) - -# Pass the arguments of this script directly to the runner. -exec "${runner}" "$@" diff --git a/test/uds/BUILD b/test/uds/BUILD index a3843e699..51e2c7ce8 100644 --- a/test/uds/BUILD +++ b/test/uds/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package( default_visibility = ["//:sandbox"], @@ -9,7 +9,6 @@ go_library( name = "uds", testonly = 1, srcs = ["uds.go"], - importpath = "gvisor.dev/gvisor/test/uds", deps = [ "//pkg/log", "//pkg/unet", diff --git a/test/util/BUILD b/test/util/BUILD index 5d2a9cc2c..2a17c33ee 100644 --- a/test/util/BUILD +++ b/test/util/BUILD @@ -1,5 +1,4 @@ -load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test") -load("//test/syscalls:build_defs.bzl", "select_for_linux") +load("//tools:defs.bzl", "cc_library", "cc_test", "gbenchmark", "gtest", "select_system") package( default_visibility = ["//:sandbox"], @@ -42,7 +41,7 @@ cc_library( ":save_util", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@com_google_googletest//:gtest", + gtest, ], ) @@ -56,7 +55,7 @@ cc_library( ":posix_error", ":test_util", "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", + gtest, ], ) @@ -68,7 +67,7 @@ cc_test( ":proc_util", ":test_main", ":test_util", - "@com_google_googletest//:gtest", + gtest, ], ) @@ -88,7 +87,7 @@ cc_library( ":file_descriptor", ":posix_error", "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", + gtest, ], ) @@ -102,7 +101,7 @@ cc_test( ":temp_path", ":test_main", ":test_util", - "@com_google_googletest//:gtest", + gtest, ], ) @@ -135,19 +134,20 @@ cc_library( ":cleanup", ":posix_error", ":test_util", - "@com_google_googletest//:gtest", + gtest, ], ) cc_library( name = "save_util", testonly = 1, - srcs = ["save_util.cc"] + - select_for_linux( - ["save_util_linux.cc"], - ["save_util_other.cc"], - ), + srcs = [ + "save_util.cc", + "save_util_linux.cc", + "save_util_other.cc", + ], hdrs = ["save_util.h"], + defines = select_system(), ) cc_library( @@ -166,6 +166,14 @@ cc_library( ) cc_library( + name = "platform_util", + testonly = 1, + srcs = ["platform_util.cc"], + hdrs = ["platform_util.h"], + deps = [":test_util"], +) + +cc_library( name = "posix_error", testonly = 1, srcs = ["posix_error.cc"], @@ -175,7 +183,7 @@ cc_library( "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:variant", - "@com_google_googletest//:gtest", + gtest, ], ) @@ -186,7 +194,7 @@ cc_test( deps = [ ":posix_error", ":test_main", - "@com_google_googletest//:gtest", + gtest, ], ) @@ -210,7 +218,7 @@ cc_library( ":cleanup", ":posix_error", ":test_util", - "@com_google_googletest//:gtest", + gtest, ], ) @@ -225,27 +233,34 @@ cc_library( ":test_util", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", + gtest, ], ) cc_library( name = "test_util", testonly = 1, - srcs = ["test_util.cc"], + srcs = [ + "test_util.cc", + "test_util_impl.cc", + "test_util_runfiles.cc", + ], hdrs = ["test_util.h"], + defines = select_system(), deps = [ ":fs_util", ":logging", ":posix_error", ":save_util", + "@bazel_tools//tools/cpp/runfiles", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:parse", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", + gtest, + gbenchmark, ], ) @@ -277,7 +292,7 @@ cc_library( ":posix_error", ":test_util", "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", + gtest, ], ) @@ -288,7 +303,7 @@ cc_test( deps = [ ":test_main", ":test_util", - "@com_google_googletest//:gtest", + gtest, ], ) @@ -308,7 +323,7 @@ cc_library( ":file_descriptor", ":posix_error", ":save_util", - "@com_google_googletest//:gtest", + gtest, ], ) @@ -335,3 +350,9 @@ cc_library( ":save_util", ], ) + +cc_library( + name = "temp_umask", + testonly = 1, + hdrs = ["temp_umask.h"], +) diff --git a/test/util/capability_util.cc b/test/util/capability_util.cc index 5d733887b..a1b994c45 100644 --- a/test/util/capability_util.cc +++ b/test/util/capability_util.cc @@ -36,10 +36,10 @@ PosixErrorOr<bool> CanCreateUserNamespace() { ASSIGN_OR_RETURN_ERRNO( auto child_stack, MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE)); - int const child_pid = - clone(+[](void*) { return 0; }, - reinterpret_cast<void*>(child_stack.addr() + kPageSize), - CLONE_NEWUSER | SIGCHLD, /* arg = */ nullptr); + int const child_pid = clone( + +[](void*) { return 0; }, + reinterpret_cast<void*>(child_stack.addr() + kPageSize), + CLONE_NEWUSER | SIGCHLD, /* arg = */ nullptr); if (child_pid > 0) { int status; int const ret = waitpid(child_pid, &status, /* options = */ 0); @@ -63,13 +63,13 @@ PosixErrorOr<bool> CanCreateUserNamespace() { // is in a chroot environment (i.e., the caller's root directory does // not match the root directory of the mount namespace in which it // resides)." - std::cerr << "clone(CLONE_NEWUSER) failed with EPERM"; + std::cerr << "clone(CLONE_NEWUSER) failed with EPERM" << std::endl; return false; } else if (errno == EUSERS) { // "(since Linux 3.11) CLONE_NEWUSER was specified in flags, and the call // would cause the limit on the number of nested user namespaces to be // exceeded. See user_namespaces(7)." - std::cerr << "clone(CLONE_NEWUSER) failed with EUSERS"; + std::cerr << "clone(CLONE_NEWUSER) failed with EUSERS" << std::endl; return false; } else { // Unexpected error code; indicate an actual error. diff --git a/test/util/fs_util.cc b/test/util/fs_util.cc index 88b1e7911..5418948fe 100644 --- a/test/util/fs_util.cc +++ b/test/util/fs_util.cc @@ -105,6 +105,15 @@ PosixErrorOr<struct stat> Stat(absl::string_view path) { return stat_buf; } +PosixErrorOr<struct stat> Lstat(absl::string_view path) { + struct stat stat_buf; + int res = lstat(std::string(path).c_str(), &stat_buf); + if (res < 0) { + return PosixError(errno, absl::StrCat("lstat ", path)); + } + return stat_buf; +} + PosixErrorOr<struct stat> Fstat(int fd) { struct stat stat_buf; int res = fstat(fd, &stat_buf); @@ -116,18 +125,18 @@ PosixErrorOr<struct stat> Fstat(int fd) { PosixErrorOr<bool> Exists(absl::string_view path) { struct stat stat_buf; - int res = stat(std::string(path).c_str(), &stat_buf); + int res = lstat(std::string(path).c_str(), &stat_buf); if (res < 0) { if (errno == ENOENT) { return false; } - return PosixError(errno, absl::StrCat("stat ", path)); + return PosixError(errno, absl::StrCat("lstat ", path)); } return true; } PosixErrorOr<bool> IsDirectory(absl::string_view path) { - ASSIGN_OR_RETURN_ERRNO(struct stat stat_buf, Stat(path)); + ASSIGN_OR_RETURN_ERRNO(struct stat stat_buf, Lstat(path)); if (S_ISDIR(stat_buf.st_mode)) { return true; } @@ -443,7 +452,7 @@ PosixErrorOr<std::string> MakeAbsolute(absl::string_view filename, std::string CleanPath(const absl::string_view unclean_path) { std::string path = std::string(unclean_path); - const char *src = path.c_str(); + const char* src = path.c_str(); std::string::iterator dst = path.begin(); // Check for absolute path and determine initial backtrack limit. diff --git a/test/util/fs_util.h b/test/util/fs_util.h index ee1b341d7..8cdac23a1 100644 --- a/test/util/fs_util.h +++ b/test/util/fs_util.h @@ -26,6 +26,17 @@ namespace gvisor { namespace testing { + +// O_LARGEFILE as defined by Linux. glibc tries to be clever by setting it to 0 +// because "it isn't needed", even though Linux can return it via F_GETFL. +#if defined(__x86_64__) +constexpr int kOLargeFile = 00100000; +#elif defined(__aarch64__) +constexpr int kOLargeFile = 00400000; +#else +#error "Unknown architecture" +#endif + // Returns a status or the current working directory. PosixErrorOr<std::string> GetCWD(); @@ -33,9 +44,14 @@ PosixErrorOr<std::string> GetCWD(); // can't be determined. PosixErrorOr<bool> Exists(absl::string_view path); -// Returns a stat structure for the given path or an error. +// Returns a stat structure for the given path or an error. If the path +// represents a symlink, it will be traversed. PosixErrorOr<struct stat> Stat(absl::string_view path); +// Returns a stat structure for the given path or an error. If the path +// represents a symlink, it will not be traversed. +PosixErrorOr<struct stat> Lstat(absl::string_view path); + // Returns a stat struct for the given fd. PosixErrorOr<struct stat> Fstat(int fd); diff --git a/test/util/fs_util_test.cc b/test/util/fs_util_test.cc index 2a200320a..657b6a46e 100644 --- a/test/util/fs_util_test.cc +++ b/test/util/fs_util_test.cc @@ -12,12 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "test/util/fs_util.h" + #include <errno.h> + #include <vector> #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "test/util/fs_util.h" #include "test/util/posix_error.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" diff --git a/test/util/mount_util.h b/test/util/mount_util.h index 38ec6c8a1..09e2281eb 100644 --- a/test/util/mount_util.h +++ b/test/util/mount_util.h @@ -17,6 +17,7 @@ #include <errno.h> #include <sys/mount.h> + #include <functional> #include <string> @@ -30,10 +31,10 @@ namespace testing { // Mount mounts the filesystem, and unmounts when the returned reference is // destroyed. -inline PosixErrorOr<Cleanup> Mount(const std::string &source, - const std::string &target, - const std::string &fstype, uint64_t mountflags, - const std::string &data, +inline PosixErrorOr<Cleanup> Mount(const std::string& source, + const std::string& target, + const std::string& fstype, + uint64_t mountflags, const std::string& data, uint64_t umountflags) { if (mount(source.c_str(), target.c_str(), fstype.c_str(), mountflags, data.c_str()) == -1) { diff --git a/test/util/multiprocess_util.h b/test/util/multiprocess_util.h index 61526b4e7..2f3bf4a6f 100644 --- a/test/util/multiprocess_util.h +++ b/test/util/multiprocess_util.h @@ -99,11 +99,13 @@ inline PosixErrorOr<Cleanup> ForkAndExec(const std::string& filename, const ExecveArray& argv, const ExecveArray& envv, pid_t* child, int* execve_errno) { - return ForkAndExec(filename, argv, envv, [] {}, child, execve_errno); + return ForkAndExec( + filename, argv, envv, [] {}, child, execve_errno); } // Equivalent to ForkAndExec, except using dirfd and flags with execveat. -PosixErrorOr<Cleanup> ForkAndExecveat(int32_t dirfd, const std::string& pathname, +PosixErrorOr<Cleanup> ForkAndExecveat(int32_t dirfd, + const std::string& pathname, const ExecveArray& argv, const ExecveArray& envv, int flags, const std::function<void()>& fn, diff --git a/test/util/platform_util.cc b/test/util/platform_util.cc new file mode 100644 index 000000000..c9200d381 --- /dev/null +++ b/test/util/platform_util.cc @@ -0,0 +1,48 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/util/platform_util.h" + +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +PlatformSupport PlatformSupport32Bit() { + if (GvisorPlatform() == Platform::kPtrace || + GvisorPlatform() == Platform::kKVM) { + return PlatformSupport::NotSupported; + } else { + return PlatformSupport::Allowed; + } +} + +PlatformSupport PlatformSupportAlignmentCheck() { + return PlatformSupport::Allowed; +} + +PlatformSupport PlatformSupportMultiProcess() { + return PlatformSupport::Allowed; +} + +PlatformSupport PlatformSupportInt3() { + if (GvisorPlatform() == Platform::kKVM) { + return PlatformSupport::NotSupported; + } else { + return PlatformSupport::Allowed; + } +} + +} // namespace testing +} // namespace gvisor diff --git a/test/util/platform_util.h b/test/util/platform_util.h new file mode 100644 index 000000000..28cc92371 --- /dev/null +++ b/test/util/platform_util.h @@ -0,0 +1,56 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_UTIL_PLATFORM_UTIL_H_ +#define GVISOR_TEST_UTIL_PLATFORM_UTIL_H_ + +namespace gvisor { +namespace testing { + +// PlatformSupport is a generic enumeration of classes of support. +// +// It is up to the individual functions and callers to agree on the precise +// definition for each case. The document here generally refers to 32-bit +// as an example. Many cases will use only NotSupported and Allowed. +enum class PlatformSupport { + // The feature is not supported on the current platform. + // + // In the case of 32-bit, this means that calls will generally be interpreted + // as 64-bit calls, and there is no support for 32-bit binaries, long calls, + // etc. This usually means that the underlying implementation just pretends + // that 32-bit doesn't exist. + NotSupported, + + // Calls will be ignored by the kernel with a fixed error. + Ignored, + + // Calls will result in a SIGSEGV or similar fault. + Segfault, + + // The feature is supported as expected. + // + // In the case of 32-bit, this means that the system call or far call will be + // handled properly. + Allowed, +}; + +PlatformSupport PlatformSupport32Bit(); +PlatformSupport PlatformSupportAlignmentCheck(); +PlatformSupport PlatformSupportMultiProcess(); +PlatformSupport PlatformSupportInt3(); + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_UTIL_PLATFORM_UTL_H_ diff --git a/test/util/posix_error_test.cc b/test/util/posix_error_test.cc index d67270842..bf9465abb 100644 --- a/test/util/posix_error_test.cc +++ b/test/util/posix_error_test.cc @@ -15,6 +15,7 @@ #include "test/util/posix_error.h" #include <errno.h> + #include "gmock/gmock.h" #include "gtest/gtest.h" diff --git a/test/util/pty_util.cc b/test/util/pty_util.cc index c0fd9a095..c01f916aa 100644 --- a/test/util/pty_util.cc +++ b/test/util/pty_util.cc @@ -24,6 +24,14 @@ namespace gvisor { namespace testing { PosixErrorOr<FileDescriptor> OpenSlave(const FileDescriptor& master) { + PosixErrorOr<int> n = SlaveID(master); + if (!n.ok()) { + return PosixErrorOr<FileDescriptor>(n.error()); + } + return Open(absl::StrCat("/dev/pts/", n.ValueOrDie()), O_RDWR | O_NONBLOCK); +} + +PosixErrorOr<int> SlaveID(const FileDescriptor& master) { // Get pty index. int n; int ret = ioctl(master.get(), TIOCGPTN, &n); @@ -38,7 +46,7 @@ PosixErrorOr<FileDescriptor> OpenSlave(const FileDescriptor& master) { return PosixError(errno, "ioctl(TIOSPTLCK) failed"); } - return Open(absl::StrCat("/dev/pts/", n), O_RDWR | O_NONBLOCK); + return n; } } // namespace testing diff --git a/test/util/pty_util.h b/test/util/pty_util.h index 367b14f15..0722da379 100644 --- a/test/util/pty_util.h +++ b/test/util/pty_util.h @@ -24,6 +24,9 @@ namespace testing { // Opens the slave end of the passed master as R/W and nonblocking. PosixErrorOr<FileDescriptor> OpenSlave(const FileDescriptor& master); +// Get the number of the slave end of the master. +PosixErrorOr<int> SlaveID(const FileDescriptor& master); + } // namespace testing } // namespace gvisor diff --git a/test/util/rlimit_util.cc b/test/util/rlimit_util.cc index 684253f78..d7bfc1606 100644 --- a/test/util/rlimit_util.cc +++ b/test/util/rlimit_util.cc @@ -15,6 +15,7 @@ #include "test/util/rlimit_util.h" #include <sys/resource.h> + #include <cerrno> #include "test/util/cleanup.h" diff --git a/test/util/save_util_linux.cc b/test/util/save_util_linux.cc index 7a0f14342..d0aea8e6a 100644 --- a/test/util/save_util_linux.cc +++ b/test/util/save_util_linux.cc @@ -12,22 +12,38 @@ // See the License for the specific language governing permissions and // limitations under the License. +#ifdef __linux__ + #include <errno.h> #include <sys/syscall.h> #include <unistd.h> #include "test/util/save_util.h" +#if defined(__x86_64__) || defined(__i386__) +#define SYS_TRIGGER_SAVE SYS_create_module +#elif defined(__aarch64__) +#define SYS_TRIGGER_SAVE SYS_finit_module +#else +#error "Unknown architecture" +#endif + namespace gvisor { namespace testing { void MaybeSave() { if (internal::ShouldSave()) { int orig_errno = errno; - syscall(SYS_create_module, nullptr, 0); + // We use it to trigger saving the sentry state + // when this syscall is called. + // Notice: this needs to be a valid syscall + // that is not used in any of the syscall tests. + syscall(SYS_TRIGGER_SAVE, nullptr, 0); errno = orig_errno; } } } // namespace testing } // namespace gvisor + +#endif diff --git a/test/util/save_util_other.cc b/test/util/save_util_other.cc index 1aca663b7..931af2c29 100644 --- a/test/util/save_util_other.cc +++ b/test/util/save_util_other.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#ifndef __linux__ + namespace gvisor { namespace testing { @@ -21,3 +23,5 @@ void MaybeSave() { } // namespace testing } // namespace gvisor + +#endif diff --git a/test/util/signal_util.cc b/test/util/signal_util.cc index 26738864f..5ee95ee80 100644 --- a/test/util/signal_util.cc +++ b/test/util/signal_util.cc @@ -15,6 +15,7 @@ #include "test/util/signal_util.h" #include <signal.h> + #include <ostream> #include "gtest/gtest.h" diff --git a/test/util/signal_util.h b/test/util/signal_util.h index 7fd2af015..e7b66aa51 100644 --- a/test/util/signal_util.h +++ b/test/util/signal_util.h @@ -18,6 +18,7 @@ #include <signal.h> #include <sys/syscall.h> #include <unistd.h> + #include <ostream> #include "gmock/gmock.h" @@ -84,6 +85,20 @@ inline void FixupFault(ucontext_t* ctx) { // The encoding is 0x48 0xab 0x00. ctx->uc_mcontext.gregs[REG_RIP] += 3; } +#elif __aarch64__ +inline void Fault() { + // Zero and dereference x0. + asm("mov xzr, x0\r\n" + "str xzr, [x0]\r\n" + : + : + : "x0"); +} + +inline void FixupFault(ucontext_t* ctx) { + // Skip the bad instruction above. + ctx->uc_mcontext.pc += 4; +} #endif } // namespace testing diff --git a/test/util/temp_path.cc b/test/util/temp_path.cc index 35aacb172..e1bdee7fd 100644 --- a/test/util/temp_path.cc +++ b/test/util/temp_path.cc @@ -56,7 +56,7 @@ void TryDeleteRecursively(std::string const& path) { if (undeleted_dirs || undeleted_files || !status.ok()) { std::cerr << path << ": failed to delete " << undeleted_dirs << " directories and " << undeleted_files - << " files: " << status; + << " files: " << status << std::endl; } } } @@ -77,6 +77,7 @@ std::string NewTempAbsPath() { std::string NewTempRelPath() { return NextTempBasename(); } std::string GetAbsoluteTestTmpdir() { + // Note that TEST_TMPDIR is guaranteed to be set. char* env_tmpdir = getenv("TEST_TMPDIR"); std::string tmp_dir = env_tmpdir != nullptr ? std::string(env_tmpdir) : "/tmp"; diff --git a/test/util/temp_path.h b/test/util/temp_path.h index 92d669503..9e5ac11f4 100644 --- a/test/util/temp_path.h +++ b/test/util/temp_path.h @@ -16,6 +16,7 @@ #define GVISOR_TEST_UTIL_TEMP_PATH_H_ #include <sys/stat.h> + #include <string> #include <utility> diff --git a/test/syscalls/linux/temp_umask.h b/test/util/temp_umask.h index 81a25440c..e7de84a54 100644 --- a/test/syscalls/linux/temp_umask.h +++ b/test/util/temp_umask.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef GVISOR_TEST_SYSCALLS_TEMP_UMASK_H_ -#define GVISOR_TEST_SYSCALLS_TEMP_UMASK_H_ +#ifndef GVISOR_TEST_UTIL_TEMP_UMASK_H_ +#define GVISOR_TEST_UTIL_TEMP_UMASK_H_ #include <sys/stat.h> #include <sys/types.h> @@ -36,4 +36,4 @@ class TempUmask { } // namespace testing } // namespace gvisor -#endif // GVISOR_TEST_SYSCALLS_TEMP_UMASK_H_ +#endif // GVISOR_TEST_UTIL_TEMP_UMASK_H_ diff --git a/test/util/test_main.cc b/test/util/test_main.cc index 5c7ee0064..1f389e58f 100644 --- a/test/util/test_main.cc +++ b/test/util/test_main.cc @@ -16,5 +16,5 @@ int main(int argc, char** argv) { gvisor::testing::TestInit(&argc, &argv); - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/util/test_util.cc b/test/util/test_util.cc index ba0dcf7d0..d0c1d6426 100644 --- a/test/util/test_util.cc +++ b/test/util/test_util.cc @@ -40,24 +40,38 @@ namespace gvisor { namespace testing { -#define TEST_ON_GVISOR "TEST_ON_GVISOR" +constexpr char kGvisorNetwork[] = "GVISOR_NETWORK"; +constexpr char kGvisorVfs[] = "GVISOR_VFS"; +constexpr char kFuseEnabled[] = "FUSE_ENABLED"; bool IsRunningOnGvisor() { return GvisorPlatform() != Platform::kNative; } -Platform GvisorPlatform() { +const std::string GvisorPlatform() { // Set by runner.go. - char* env = getenv(TEST_ON_GVISOR); + const char* env = getenv(kTestOnGvisor); if (!env) { return Platform::kNative; } - if (strcmp(env, "ptrace") == 0) { - return Platform::kPtrace; - } - if (strcmp(env, "kvm") == 0) { - return Platform::kKVM; + return std::string(env); +} + +bool IsRunningWithHostinet() { + const char* env = getenv(kGvisorNetwork); + return env && strcmp(env, "host") == 0; +} + +bool IsRunningWithVFS1() { + const char* env = getenv(kGvisorVfs); + if (env == nullptr) { + // If not set, it's running on Linux. + return false; } - std::cerr << "unknown platform " << env; - abort(); + return strcmp(env, "VFS1") == 0; +} + +bool IsFUSEEnabled() { + const char* env = getenv(kFuseEnabled); + return env && strcmp(env, "TRUE") == 0; } // Inline cpuid instruction. Preserve %ebx/%rbx register. In PIC compilations @@ -70,7 +84,6 @@ Platform GvisorPlatform() { "xchg %%rdi, %%rbx\n" \ : "=a"(a), "=D"(b), "=c"(c), "=d"(d) \ : "a"(a_inp), "2"(c_inp)) -#endif // defined(__x86_64__) CPUVendor GetCPUVendor() { uint32_t eax, ebx, ecx, edx; @@ -87,6 +100,7 @@ CPUVendor GetCPUVendor() { } return CPUVendor::kUnknownVendor; } +#endif // defined(__x86_64__) bool operator==(const KernelVersion& first, const KernelVersion& second) { return first.major == second.major && first.minor == second.minor && @@ -116,9 +130,6 @@ PosixErrorOr<KernelVersion> GetKernelVersion() { return ParseKernelVersion(buf.release); } -void SetupGvisorDeathTest() { -} - std::string CPUSetToString(const cpu_set_t& set, size_t cpus) { std::string str = "cpuset["; for (unsigned int n = 0; n < cpus; n++) { @@ -224,15 +235,5 @@ bool Equivalent(uint64_t current, uint64_t target, double tolerance) { return abs_diff <= static_cast<uint64_t>(tolerance * target); } -void TestInit(int* argc, char*** argv) { - ::testing::InitGoogleTest(argc, *argv); - ::absl::ParseCommandLine(*argc, *argv); - - // Always mask SIGPIPE as it's common and tests aren't expected to handle it. - struct sigaction sa = {}; - sa.sa_handler = SIG_IGN; - TEST_CHECK(sigaction(SIGPIPE, &sa, nullptr) == 0); -} - } // namespace testing } // namespace gvisor diff --git a/test/util/test_util.h b/test/util/test_util.h index b9d2dc2ba..373c54f32 100644 --- a/test/util/test_util.h +++ b/test/util/test_util.h @@ -26,16 +26,13 @@ // IsRunningOnGvisor returns true if the test is known to be running on gVisor. // GvisorPlatform can be used to get more detail: // -// switch (GvisorPlatform()) { -// case Platform::kNative: -// case Platform::kGvisor: -// EXPECT_THAT(mmap(...), SyscallSucceeds()); -// break; -// case Platform::kPtrace: -// EXPECT_THAT(mmap(...), SyscallFailsWithErrno(ENOSYS)); -// break; +// if (GvisorPlatform() == Platform::kPtrace) { +// ... // } // +// SetupGvisorDeathTest ensures that signal handling does not interfere with +/// tests that rely on fatal signals. +// // Matchers // ======== // @@ -198,6 +195,8 @@ namespace gvisor { namespace testing { +constexpr char kTestOnGvisor[] = "TEST_ON_GVISOR"; + // TestInit must be called prior to RUN_ALL_TESTS. // // This parses all arguments and adjusts argc and argv appropriately. @@ -213,15 +212,24 @@ void TestInit(int* argc, char*** argv); if (expr) GTEST_SKIP() << #expr; \ } while (0) -enum class Platform { - kNative, - kKVM, - kPtrace, -}; +// Platform contains platform names. +namespace Platform { +constexpr char kNative[] = "native"; +constexpr char kPtrace[] = "ptrace"; +constexpr char kKVM[] = "kvm"; +constexpr char kFuchsia[] = "fuchsia"; +} // namespace Platform + bool IsRunningOnGvisor(); -Platform GvisorPlatform(); +const std::string GvisorPlatform(); +bool IsRunningWithHostinet(); +// TODO(gvisor.dev/issue/1624): Delete once VFS1 is gone. +bool IsRunningWithVFS1(); +bool IsFUSEEnabled(); +#ifdef __linux__ void SetupGvisorDeathTest(); +#endif struct KernelVersion { int major; @@ -560,6 +568,25 @@ ssize_t ApplyFileIoSyscall(F const& f, size_t const count) { } // namespace internal +inline PosixErrorOr<std::string> ReadAllFd(int fd) { + std::string all; + all.reserve(128 * 1024); // arbitrary. + + std::vector<char> buffer(16 * 1024); + for (;;) { + auto const bytes = RetryEINTR(read)(fd, buffer.data(), buffer.size()); + if (bytes < 0) { + return PosixError(errno, "file read"); + } + if (bytes == 0) { + return std::move(all); + } + if (bytes > 0) { + all.append(buffer.data(), bytes); + } + } +} + inline ssize_t ReadFd(int fd, void* buf, size_t count) { return internal::ApplyFileIoSyscall( [&](size_t completed) { @@ -762,7 +789,14 @@ MATCHER_P2(EquivalentWithin, target, tolerance, return Equivalent(arg, target, tolerance); } +// Returns the absolute path to the a data dependency. 'path' is the runfile +// location relative to workspace root. +#ifdef __linux__ +std::string RunfilePath(std::string path); +#endif + void TestInit(int* argc, char*** argv); +int RunAllTests(void); } // namespace testing } // namespace gvisor diff --git a/test/util/test_util_impl.cc b/test/util/test_util_impl.cc new file mode 100644 index 000000000..7e1ad9e66 --- /dev/null +++ b/test/util/test_util_impl.cc @@ -0,0 +1,52 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <signal.h> + +#include "gtest/gtest.h" +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" +#include "benchmark/benchmark.h" +#include "test/util/logging.h" + +extern bool FLAGS_benchmark_list_tests; +extern std::string FLAGS_benchmark_filter; + +namespace gvisor { +namespace testing { + +void SetupGvisorDeathTest() {} + +void TestInit(int* argc, char*** argv) { + ::testing::InitGoogleTest(argc, *argv); + benchmark::Initialize(argc, *argv); + ::absl::ParseCommandLine(*argc, *argv); + + // Always mask SIGPIPE as it's common and tests aren't expected to handle it. + struct sigaction sa = {}; + sa.sa_handler = SIG_IGN; + TEST_CHECK(sigaction(SIGPIPE, &sa, nullptr) == 0); +} + +int RunAllTests() { + if (FLAGS_benchmark_list_tests || FLAGS_benchmark_filter != ".") { + benchmark::RunSpecifiedBenchmarks(); + return 0; + } else { + return RUN_ALL_TESTS(); + } +} + +} // namespace testing +} // namespace gvisor diff --git a/test/util/test_util_runfiles.cc b/test/util/test_util_runfiles.cc new file mode 100644 index 000000000..694d21692 --- /dev/null +++ b/test/util/test_util_runfiles.cc @@ -0,0 +1,50 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef __fuchsia__ + +#include <iostream> +#include <string> + +#include "test/util/fs_util.h" +#include "test/util/test_util.h" +#include "tools/cpp/runfiles/runfiles.h" + +namespace gvisor { +namespace testing { + +std::string RunfilePath(std::string path) { + static const bazel::tools::cpp::runfiles::Runfiles* const runfiles = [] { + std::string error; + auto* runfiles = + bazel::tools::cpp::runfiles::Runfiles::CreateForTest(&error); + if (runfiles == nullptr) { + std::cerr << "Unable to find runfiles: " << error << std::endl; + } + return runfiles; + }(); + + if (!runfiles) { + // Can't find runfiles? This probably won't work, but __main__/path is our + // best guess. + return JoinPath("__main__", path); + } + + return runfiles->Rlocation(JoinPath("__main__", path)); +} + +} // namespace testing +} // namespace gvisor + +#endif // __fuchsia__ diff --git a/test/util/test_util_test.cc b/test/util/test_util_test.cc index b7300d9e5..f42100374 100644 --- a/test/util/test_util_test.cc +++ b/test/util/test_util_test.cc @@ -15,6 +15,7 @@ #include "test/util/test_util.h" #include <errno.h> + #include <vector> #include "gmock/gmock.h" |